├── .gitignore ├── data ├── FN_data_gen.py ├── FNexp │ └── emtext2triplelist.py └── utils.py ├── lstm ├── BERT_tf2 │ ├── bert4keras │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── layers.py │ │ ├── models.py │ │ ├── optimizers.py │ │ ├── snippets.py │ │ └── tokenizers.py │ └── bert_tools.py ├── extraction.py └── utils.py ├── readme.md ├── requirements.txt └── rere ├── BERT_tf2 ├── bert4keras │ ├── __init__.py │ ├── backend.py │ ├── layers.py │ ├── models.py │ ├── optimizers.py │ ├── snippets.py │ └── tokenizers.py └── bert_tools.py ├── batch_run.py ├── bert-to-h5.py ├── bert4keras ├── __init__.py ├── backend.py ├── layers.py ├── models.py ├── optimizers.py ├── snippets.py └── tokenizers.py ├── config.py ├── extraction.py ├── model └── readme.md ├── readme.md ├── tfhub ├── chinese_roberta_wwm_ext_L-12_H-768_A-12 │ └── readme.md └── uncased_L-12_H-768_A-12 │ └── readme.md └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | *.zip 3 | *.h5 4 | *.txt 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | tfhub/* 14 | **/.vs/* 15 | 16 | 17 | # other files 18 | plot/* 19 | get_partial/* 20 | # rere/BERT_TF2/* 21 | data/* 22 | !data/*/ 23 | !data/*.py -------------------------------------------------------------------------------- /data/FN_data_gen.py: -------------------------------------------------------------------------------- 1 | import utils,os,json,random 2 | # folders = ['Wiki-KBP','WebNLG','NYT','NYT11-HRL','NYT10-HRL','ske2019'][3:] 3 | 4 | 5 | folders = ['ske2019'] 6 | ds = ['train','valid','dev','test'] 7 | fnratio = [0.1,0.2,0.3,0.4,0.5] 8 | 9 | rel_set = set() 10 | for folder in folders: 11 | for r in fnratio: 12 | for d in ds: 13 | dname = f'{folder}' 14 | odname = f'FNexp/{folder}@{r}' 15 | print(os.path.isdir(odname)) 16 | if not os.path.isdir(odname): os.makedirs(odname) 17 | fn = f'{dname}/new_{d}.json' 18 | # print(fn) 19 | # print(os.path.exists(fn)) 20 | if not os.path.exists(fn): continue 21 | li = utils.LoadJsons(fn) 22 | if d=='train': 23 | newli = [] 24 | for dic in li: 25 | newspos = [] 26 | if not dic['relationMentions']:continue 27 | for spo in dic['relationMentions']: 28 | if random.random()>float(r): 29 | newspos.append(spo) 30 | dic['relationMentions'] = newspos 31 | if not dic['relationMentions']:continue 32 | newli.append(dic) 33 | else: 34 | newli = li 35 | utils.SaveList(map(lambda x:json.dumps(x, ensure_ascii=False), newli), f'{odname}/new_{d}.json') 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /data/FNexp/emtext2triplelist.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | folders = ['NYT10-HRL','NYT11-HRL','ske2019'] 4 | folders=['ske2019'] 5 | ds = ['train','dev','test'] 6 | fnratio = [0] 7 | for r in fnratio: 8 | for folder in folders: 9 | for dd in ds: 10 | rel_set = set() 11 | data = [] 12 | name= f'FNexp/{folder}@{r}/new_{dd}.json' 13 | print(name) 14 | load_dic=[] 15 | with open('./'+name,'r',encoding='utf-8') as load_f: 16 | for l in load_f.readlines(): 17 | a = json.loads(l) 18 | print(l) 19 | if not a['relationMentions']: 20 | continue 21 | line = { 22 | 'text': a['sentText'].lstrip('\"').strip('\r\n').rstrip('\"'), 23 | 'triple_list': [(i['em1Text'], i['label'], i['em2Text']) for i in a['relationMentions'] if i['label'] != 'None'] 24 | } 25 | if not line['triple_list']: 26 | 
continue 27 | data.append(line) 28 | for rm in a['relationMentions']: 29 | if rm['label'] != 'None': 30 | rel_set.add(rm['label']) 31 | with open(f'FNexp/{folder}@{r}/{dd}_triples.json', 'w', encoding='utf-8') as f: 32 | json.dump(data, f, indent=4, ensure_ascii=False) 33 | 34 | id2predicate = {i:j for i,j in enumerate(sorted(rel_set))} 35 | predicate2id = {j:i for i,j in id2predicate.items()} 36 | 37 | 38 | with open(f'FNexp/{folder}@{r}/rel2id.json', 'w', encoding='utf-8') as f: 39 | json.dump([id2predicate, predicate2id], f, indent=4, ensure_ascii=False) 40 | 41 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | # coding = utf-8 2 | 3 | import os, re, sys, random, urllib.parse, json 4 | from collections import defaultdict 5 | 6 | def WriteLine(fout, lst): 7 | fout.write('\t'.join([str(x) for x in lst]) + '\n') 8 | 9 | def RM(patt, sr): 10 | mat = re.search(patt, sr, re.DOTALL | re.MULTILINE) 11 | return mat.group(1) if mat else '' 12 | 13 | try: import requests 14 | except: pass 15 | def GetPage(url, cookie='', proxy='', timeout=5): 16 | try: 17 | headers = {'User-Agent':'Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.90 Safari/537.36'} 18 | if cookie != '': headers['cookie'] = cookie 19 | if proxy != '': 20 | proxies = {'http': proxy, 'https': proxy} 21 | resp = requests.get(url, headers=headers, proxies=proxies, timeout=timeout) 22 | else: resp = requests.get(url, headers=headers, timeout=timeout) 23 | content = resp.content 24 | try: 25 | import chardet 26 | charset = chardet.detect(content).get('encoding','utf-8') 27 | if charset.lower().startswith('gb'): charset = 'gbk' 28 | content = content.decode(charset, errors='replace') 29 | except: 30 | headc = content[:min([3000,len(content)])].decode(errors='ignore') 31 | charset = RM('charset="?([-a-zA-Z0-9]+)', headc) 32 | if charset == '': charset = 'utf-8' 33 | content = content.decode(charset, errors='replace') 34 | except Exception as e: 35 | print(e) 36 | content = '' 37 | return content 38 | 39 | def GetJson(url, cookie='', proxy='', timeout=5.0): 40 | try: 41 | headers = {'User-Agent':'Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.90 Safari/537.36'} 42 | if cookie != '': headers['cookie'] = cookie 43 | if proxy != '': 44 | proxies = {'http': proxy, 'https': proxy} 45 | resp = requests.get(url, headers=headers, proxies=proxies, timeout=timeout) 46 | else: resp = requests.get(url, headers=headers, timeout=timeout) 47 | return resp.json() 48 | except Exception as e: 49 | print(e) 50 | content = {} 51 | return content 52 | 53 | def FindAllHrefs(url, content=None, regex=''): 54 | ret = set() 55 | if content == None: content = GetPage(url) 56 | patt = re.compile('href="?([a-zA-Z0-9-_:/.%]+)') 57 | for xx in re.findall(patt, content): 58 | ret.add( urllib.parse.urljoin(url, xx) ) 59 | if regex != '': ret = (x for x in ret if re.match(regex, x)) 60 | return list(ret) 61 | 62 | def Translate(txt): 63 | postdata = {'from': 'en', 'to': 'zh', 'transtype': 'realtime', 'query': txt} 64 | url = "http://fanyi.baidu.com/v2transapi" 65 | try: 66 | resp = requests.post(url, data=postdata, 67 | headers={'Referer': 'http://fanyi.baidu.com/'}) 68 | ret = resp.json() 69 | ret = ret['trans_result']['data'][0]['dst'] 70 | except Exception as e: 71 | print(e) 72 | ret = '' 73 | return ret 74 | 75 | def IsChsStr(z): 76 | return 
re.search('^[\u4e00-\u9fa5]+$', z) is not None 77 | 78 | def FreqDict2List(dt): 79 | return sorted(dt.items(), key=lambda d:d[-1], reverse=True) 80 | 81 | def SelectRowsbyCol(fn, ofn, st, num = 0): 82 | with open(fn, encoding = "utf-8") as fin: 83 | with open(ofn, "w", encoding = "utf-8") as fout: 84 | for line in (ll for ll in fin.read().split('\n') if ll != ""): 85 | if line.split('\t')[num] in st: 86 | fout.write(line + '\n') 87 | 88 | def MergeFiles(dir, objfile, regstr = ".*"): 89 | with open(objfile, "w", encoding = "utf-8") as fout: 90 | for file in os.listdir(dir): 91 | if re.match(regstr, file): 92 | with open(os.path.join(dir, file), encoding = "utf-8") as filein: 93 | fout.write(filein.read()) 94 | 95 | def JoinFiles(fnx, fny, ofn): 96 | with open(fnx, encoding = "utf-8") as fin: 97 | lx = [vv for vv in fin.read().split('\n') if vv != ""] 98 | with open(fny, encoding = "utf-8") as fin: 99 | ly = [vv for vv in fin.read().split('\n') if vv != ""] 100 | with open(ofn, "w", encoding = "utf-8") as fout: 101 | for i in range(min(len(lx), len(ly))): 102 | fout.write(lx[i] + "\t" + ly[i] + "\n") 103 | 104 | 105 | def RemoveDupRows(file, fobj='*'): 106 | st = set() 107 | if fobj == '*': fobj = file 108 | with open(file, encoding = "utf-8") as fin: 109 | for line in fin.read().split('\n'): 110 | if line == "": continue 111 | st.add(line) 112 | with open(fobj, "w", encoding = "utf-8") as fout: 113 | for line in st: 114 | fout.write(line + '\n') 115 | 116 | def LoadCSV(fn): 117 | ret = [] 118 | with open(fn, encoding='utf-8') as fin: 119 | for line in fin: 120 | lln = line.rstrip('\r\n').split('\t') 121 | ret.append(lln) 122 | return ret 123 | 124 | def LoadCSVg(fn): 125 | with open(fn, encoding='utf-8') as fin: 126 | for line in fin: 127 | lln = line.rstrip('\r\n').split('\t') 128 | yield lln 129 | 130 | def SaveCSV(csv, fn): 131 | with open(fn, 'w', encoding='utf-8') as fout: 132 | for x in csv: 133 | WriteLine(fout, x) 134 | 135 | def SplitTables(fn, limit=3): 136 | rst = set() 137 | with open(fn, encoding='utf-8') as fin: 138 | for line in fin: 139 | lln = line.rstrip('\r\n').split('\t') 140 | rst.add(len(lln)) 141 | if len(rst) > limit: 142 | print('%d tables, exceed limit %d' % (len(rst), limit)) 143 | return 144 | for ii in rst: 145 | print('%d columns' % ii) 146 | with open(fn.replace('.txt', '') + '.split.%d.txt' % ii, 'w', encoding='utf-8') as fout: 147 | with open(fn, encoding='utf-8') as fin: 148 | for line in fin: 149 | lln = line.rstrip('\r\n').split('\t') 150 | if len(lln) == ii: 151 | fout.write(line) 152 | 153 | def LoadSet(fn): 154 | with open(fn, encoding="utf-8") as fin: 155 | st = set(ll for ll in fin.read().split('\n') if ll != "") 156 | return st 157 | 158 | def LoadList(fn): 159 | with open(fn, encoding="utf-8") as fin: 160 | st = list(ll for ll in fin.read().split('\n') if ll != "") 161 | return st 162 | 163 | def LoadJsonsg(fn): return map(json.loads, LoadListg(fn)) 164 | def LoadJsons(fn): return list(LoadJsonsg(fn)) 165 | 166 | def LoadListg(fn): 167 | with open(fn, encoding="utf-8") as fin: 168 | for ll in fin: 169 | ll = ll.strip() 170 | if ll != '': yield ll 171 | 172 | def LoadDict(fn, func=str): 173 | dict = {} 174 | with open(fn, encoding = "utf-8") as fin: 175 | for lv in (ll.split('\t', 1) for ll in fin.read().split('\n') if ll != ""): 176 | dict[lv[0]] = func(lv[1]) 177 | return dict 178 | 179 | def SaveDict(dict, ofn, output0 = True): 180 | with open(ofn, "w", encoding = "utf-8") as fout: 181 | for k in dict.keys(): 182 | if output0 or dict[k] != 0: 183 
| fout.write(str(k) + "\t" + str(dict[k]) + "\n") 184 | 185 | def SaveList(st, ofn): 186 | with open(ofn, "w", encoding = "utf-8") as fout: 187 | for k in st: 188 | fout.write(str(k) + "\n") 189 | 190 | def ListDirFiles(dir, filter=None): 191 | if filter is None: 192 | return [os.path.join(dir, x) for x in os.listdir(dir)] 193 | return [os.path.join(dir, x) for x in os.listdir(dir) if filter(x)] 194 | 195 | def ProcessDir(dir, func, param): 196 | for file in os.listdir(dir): 197 | print(file) 198 | func(os.path.join(dir, file), param) 199 | 200 | def GetLines(fn): 201 | with open(fn, encoding = "utf-8", errors = 'ignore') as fin: 202 | lines = list(map(str.strip, fin.readlines())) 203 | return lines 204 | 205 | 206 | def SortRows(file, fobj, cid, type=int, rev = True): 207 | lines = LoadCSV(file) 208 | dat = [] 209 | for dv in lines: 210 | if len(dv) <= cid: continue 211 | dat.append((type(dv[cid]), dv)) 212 | with open(fobj, "w", encoding = "utf-8") as fout: 213 | for dd in sorted(dat, reverse = rev): 214 | fout.write('\t'.join(dd[1]) + '\n') 215 | 216 | def SampleRows(file, fobj, num): 217 | zz = list(open(file, encoding='utf-8')) 218 | num = min([num, len(zz)]) 219 | zz = random.sample(zz, num) 220 | with open(fobj, 'w', encoding='utf-8') as fout: 221 | for xx in zz: fout.write(xx) 222 | 223 | def SetProduct(file1, file2, fobj): 224 | l1, l2 = GetLines(file1), GetLines(file2) 225 | with open(fobj, 'w', encoding='utf-8') as fout: 226 | for z1 in l1: 227 | for z2 in l2: 228 | fout.write(z1 + z2 + '\n') 229 | 230 | class TokenList: 231 | def __init__(self, file, low_freq=2, source=None, func=None, save_low_freq=2, special_marks=[]): 232 | if not os.path.exists(file): 233 | tdict = defaultdict(int) 234 | for i, xx in enumerate(special_marks): tdict[xx] = 100000000 - i 235 | for xx in source: 236 | for token in func(xx): tdict[token] += 1 237 | tokens = FreqDict2List(tdict) 238 | tokens = [x for x in tokens if x[1] >= save_low_freq] 239 | SaveCSV(tokens, file) 240 | self.id2t = ['', ''] + \ 241 | [x for x,y in LoadCSV(file) if float(y) >= low_freq] 242 | self.t2id = {v:k for k,v in enumerate(self.id2t)} 243 | def get_id(self, token): return self.t2id.get(token, 1) 244 | def get_token(self, ii): return self.id2t[ii] 245 | def get_num(self): return len(self.id2t) 246 | 247 | def CalcF1(correct, output, golden): 248 | prec = correct / max(output, 1); reca = correct / max(golden, 1); 249 | f1 = 2 * prec * reca / max(1e-9, prec + reca) 250 | pstr = 'Prec: %.4f %d/%d, Reca: %.4f %d/%d, F1: %.4f' % (prec, correct, output, reca, correct, golden, f1) 251 | return pstr 252 | 253 | def Upgradeljqpy(url=None): 254 | if url is None: url = 'http://gdm.fudan.edu.cn/files1/ljq/ljqpy.py' 255 | dirs = [dir for dir in reversed(sys.path) if os.path.isdir(dir) and 'ljqpy.py' in os.listdir(dir)] 256 | if len(dirs) == 0: raise Exception("package directory no found") 257 | dir = dirs[0] 258 | print('downloading ljqpy.py from %s to %s' % (url, dir)) 259 | resp = requests.get(url) 260 | if b'Upgradeljqpy' not in resp.content: raise Exception('bad file') 261 | with open(os.path.join(dir, 'ljqpy.py'), 'wb') as fout: 262 | fout.write(resp.content) 263 | print('success') 264 | 265 | def sql(cmd=''): 266 | if cmd == '': cmd = input("> ") 267 | cts = [x for x in cmd.strip().lower()] 268 | instr = False 269 | for i in range(len(cts)): 270 | if cts[i] == '"' and cts[i-1] != '\\': instr = not instr 271 | if cts[i] == ' ' and instr: cts[i] = " " 272 | cmds = "".join(cts).split(' ') 273 | keyw = { 'select', 'from', 'to', 
'where' } 274 | ct, kn = {}, '' 275 | for xx in cmds: 276 | if xx in keyw: kn = xx 277 | else: ct[kn] = ct.get(kn, "") + " " + xx 278 | 279 | for xx in ct.keys(): 280 | ct[xx] = ct[xx].replace(" ", " ").strip() 281 | 282 | if ct.get('where', "") == "": ct['where'] = 'True' 283 | 284 | if os.path.isdir(ct['from']): fl = [os.path.join(ct['from'], x) for x in os.listdir(ct['from'])] 285 | else: fl = ct['from'].split('+') 286 | 287 | if ct.get('to', "") == "": ct['to'] = 'temp.txt' 288 | 289 | for xx in ct.keys(): 290 | print(xx + " : " + ct[xx]) 291 | 292 | total = 0 293 | with open(ct['to'], 'w', encoding = 'utf-8') as fout: 294 | for fn in fl: 295 | print('selecting ' + fn) 296 | for xx in open(fn, encoding = 'utf-8'): 297 | x = xx.rstrip('\r\n').split('\t') 298 | if eval(ct['where']): 299 | if ct['select'] == '*': res = "\t".join(x) + '\n' 300 | else: res = "\t".join(eval('[' + ct['select'] + ']')) + '\n' 301 | fout.write(res) 302 | total += 1 303 | 304 | print('completed, ' + str(total) + " records") 305 | 306 | def cmd(): 307 | while True: 308 | cmd = input("> ") 309 | sql(cmd) 310 | -------------------------------------------------------------------------------- /lstm/BERT_tf2/bert4keras/__init__.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | 3 | __version__ = '0.8.8' 4 | -------------------------------------------------------------------------------- /lstm/BERT_tf2/bert4keras/backend.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 分离后端函数,主要是为了同时兼容原生keras和tf.keras 3 | # 通过设置环境变量TF_KERAS=1来切换tf.keras 4 | 5 | import os, sys 6 | from distutils.util import strtobool 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.python.util import nest, tf_inspect 10 | from tensorflow.python.eager import tape 11 | from tensorflow.python.ops.custom_gradient import _graph_mode_decorator 12 | 13 | import tensorflow.keras as keras 14 | import tensorflow.keras.backend as K 15 | sys.modules['keras'] = keras 16 | 17 | is_tf_keras = True 18 | 19 | # 判断是否启用重计算(通过时间换空间) 20 | do_recompute = strtobool(os.environ.get('RECOMPUTE', '0')) 21 | 22 | def gelu_erf(x): 23 | """基于Erf直接计算的gelu函数 24 | """ 25 | return 0.5 * x * (1.0 + tf.math.erf(x / np.sqrt(2.0))) 26 | 27 | 28 | def gelu_tanh(x): 29 | """基于Tanh近似计算的gelu函数 30 | """ 31 | cdf = 0.5 * ( 32 | 1.0 + K.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * K.pow(x, 3)))) 33 | ) 34 | return x * cdf 35 | 36 | 37 | def set_gelu(version): 38 | """设置gelu版本 39 | """ 40 | version = version.lower() 41 | assert version in ['erf', 'tanh'], 'gelu version must be erf or tanh' 42 | if version == 'erf': 43 | keras.utils.get_custom_objects()['gelu'] = gelu_erf 44 | else: 45 | keras.utils.get_custom_objects()['gelu'] = gelu_tanh 46 | 47 | 48 | def piecewise_linear(t, schedule): 49 | """分段线性函数 50 | 其中schedule是形如{1000: 1, 2000: 0.1}的字典, 51 | 表示 t ∈ [0, 1000]时,输出从0均匀增加至1,而 52 | t ∈ [1000, 2000]时,输出从1均匀降低到0.1,最后 53 | t > 2000时,保持0.1不变。 54 | """ 55 | schedule = sorted(schedule.items()) 56 | if schedule[0][0] != 0: 57 | schedule = [(0, 0.0)] + schedule 58 | 59 | x = K.constant(schedule[0][1], dtype=K.floatx()) 60 | t = K.cast(t, K.floatx()) 61 | for i in range(len(schedule)): 62 | t_begin = schedule[i][0] 63 | x_begin = x 64 | if i != len(schedule) - 1: 65 | dx = schedule[i + 1][1] - schedule[i][1] 66 | dt = schedule[i + 1][0] - schedule[i][0] 67 | slope = 1.0 * dx / dt 68 | x = schedule[i][1] + slope * (t - t_begin) 69 | else: 70 | x = 
K.constant(schedule[i][1], dtype=K.floatx()) 71 | x = K.switch(t >= t_begin, x, x_begin) 72 | 73 | return x 74 | 75 | 76 | def search_layer(inputs, name, exclude_from=None): 77 | """根据inputs和name来搜索层 78 | 说明:inputs为某个层或某个层的输出;name为目标层的名字。 79 | 实现:根据inputs一直往上递归搜索,直到发现名字为name的层为止; 80 | 如果找不到,那就返回None。 81 | """ 82 | if exclude_from is None: 83 | exclude_from = set() 84 | 85 | if isinstance(inputs, keras.layers.Layer): 86 | layer = inputs 87 | else: 88 | layer = inputs._keras_history[0] 89 | 90 | if layer.name == name: 91 | return layer 92 | elif layer in exclude_from: 93 | return None 94 | else: 95 | exclude_from.add(layer) 96 | if isinstance(layer, keras.models.Model): 97 | model = layer 98 | for layer in model.layers: 99 | if layer.name == name: 100 | return layer 101 | inbound_layers = layer._inbound_nodes[0].inbound_layers 102 | if not isinstance(inbound_layers, list): 103 | inbound_layers = [inbound_layers] 104 | if len(inbound_layers) > 0: 105 | for layer in inbound_layers: 106 | layer = search_layer(layer, name, exclude_from) 107 | if layer is not None: 108 | return layer 109 | 110 | 111 | def sequence_masking(x, mask, mode=0, axis=None): 112 | """为序列条件mask的函数 113 | mask: 形如(batch_size, seq_len)的0-1矩阵; 114 | mode: 如果是0,则直接乘以mask; 115 | 如果是1,则在padding部分减去一个大正数。 116 | axis: 序列所在轴,默认为1; 117 | """ 118 | if mask is None or mode not in [0, 1]: 119 | return x 120 | else: 121 | if axis is None: 122 | axis = 1 123 | if axis == -1: 124 | axis = K.ndim(x) - 1 125 | assert axis > 0, 'axis must be greater than 0' 126 | for _ in range(axis - 1): 127 | mask = K.expand_dims(mask, 1) 128 | for _ in range(K.ndim(x) - K.ndim(mask) - axis + 1): 129 | mask = K.expand_dims(mask, K.ndim(mask)) 130 | if mode == 0: 131 | return x * mask 132 | else: 133 | return x - (1 - mask) * 1e12 134 | 135 | 136 | def batch_gather(params, indices): 137 | """同tf旧版本的batch_gather 138 | """ 139 | if K.dtype(indices)[:3] != 'int': 140 | indices = K.cast(indices, 'int32') 141 | 142 | try: 143 | return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1) 144 | except Exception as e1: 145 | try: 146 | return tf.batch_gather(params, indices) 147 | except Exception as e2: 148 | raise ValueError('%s\n%s\n' % (e1.message, e2.message)) 149 | 150 | 151 | def pool1d( 152 | x, 153 | pool_size, 154 | strides=1, 155 | padding='valid', 156 | data_format=None, 157 | pool_mode='max' 158 | ): 159 | """向量序列的pool函数 160 | """ 161 | x = K.expand_dims(x, 1) 162 | x = K.pool2d( 163 | x, 164 | pool_size=(1, pool_size), 165 | strides=(1, strides), 166 | padding=padding, 167 | data_format=data_format, 168 | pool_mode=pool_mode 169 | ) 170 | return x[:, 0] 171 | 172 | 173 | def divisible_temporal_padding(x, n): 174 | """将一维向量序列右padding到长度能被n整除 175 | """ 176 | r_len = K.shape(x)[1] % n 177 | p_len = K.switch(r_len > 0, n - r_len, 0) 178 | return K.temporal_padding(x, (0, p_len)) 179 | 180 | 181 | def swish(x): 182 | """swish函数(这样封装过后才有 __name__ 属性) 183 | """ 184 | return tf.nn.swish(x) 185 | 186 | 187 | def leaky_relu(x, alpha=0.2): 188 | """leaky relu函数(这样封装过后才有 __name__ 属性) 189 | """ 190 | return tf.nn.leaky_relu(x, alpha=alpha) 191 | 192 | 193 | def symbolic(f): 194 | """恒等装饰器(兼容旧版本keras用) 195 | """ 196 | return f 197 | 198 | 199 | def graph_mode_decorator(f, *args, **kwargs): 200 | """tf 2.1与之前版本的传参方式不一样,这里做个同步 201 | """ 202 | if tf.__version__ < '2.1': 203 | return _graph_mode_decorator(f, *args, **kwargs) 204 | else: 205 | return _graph_mode_decorator(f, args, kwargs) 206 | 207 | 208 | def recompute_grad(call): 209 | 
"""重计算装饰器(用来装饰Keras层的call函数) 210 | 关于重计算,请参考:https://arxiv.org/abs/1604.06174 211 | """ 212 | if not do_recompute: 213 | return call 214 | 215 | def inner(self, inputs, **kwargs): 216 | """定义需要求梯度的函数以及重新定义求梯度过程 217 | (参考自官方自带的tf.recompute_grad函数) 218 | """ 219 | flat_inputs = nest.flatten(inputs) 220 | call_args = tf_inspect.getfullargspec(call).args 221 | for key in ['mask', 'training']: 222 | if key not in call_args and key in kwargs: 223 | del kwargs[key] 224 | 225 | def kernel_call(): 226 | """定义前向计算 227 | """ 228 | return call(self, inputs, **kwargs) 229 | 230 | def call_and_grad(*inputs): 231 | """定义前向计算和反向计算 232 | """ 233 | with tape.stop_recording(): 234 | outputs = kernel_call() 235 | outputs = tf.identity(outputs) 236 | 237 | def grad_fn(doutputs, variables=None): 238 | watches = list(inputs) 239 | if variables is not None: 240 | watches += list(variables) 241 | with tf.GradientTape() as t: 242 | t.watch(watches) 243 | with tf.control_dependencies([doutputs]): 244 | outputs = kernel_call() 245 | grads = t.gradient( 246 | outputs, watches, output_gradients=[doutputs] 247 | ) 248 | del t 249 | return grads[:len(inputs)], grads[len(inputs):] 250 | 251 | return outputs, grad_fn 252 | 253 | if True: # 仅在tf >= 2.0下可用 254 | outputs, grad_fn = call_and_grad(*flat_inputs) 255 | flat_outputs = nest.flatten(outputs) 256 | 257 | def actual_grad_fn(*doutputs): 258 | grads = grad_fn(*doutputs, variables=self.trainable_weights) 259 | return grads[0] + grads[1] 260 | 261 | watches = flat_inputs + self.trainable_weights 262 | watches = [tf.convert_to_tensor(x) for x in watches] 263 | tape.record_operation( 264 | call.__name__, flat_outputs, watches, actual_grad_fn 265 | ) 266 | return outputs 267 | 268 | return inner 269 | 270 | 271 | # 给旧版本keras新增symbolic方法(装饰器), 272 | # 以便兼容optimizers.py中的代码 273 | K.symbolic = getattr(K, 'symbolic', None) or symbolic 274 | 275 | custom_objects = { 276 | 'gelu_erf': gelu_erf, 277 | 'gelu_tanh': gelu_tanh, 278 | 'gelu': gelu_erf, 279 | 'swish': swish, 280 | 'leaky_relu': leaky_relu, 281 | } 282 | 283 | keras.utils.get_custom_objects().update(custom_objects) 284 | -------------------------------------------------------------------------------- /lstm/BERT_tf2/bert4keras/snippets.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 代码合集 3 | 4 | import six 5 | import logging 6 | import numpy as np 7 | import re 8 | import sys 9 | from collections import defaultdict 10 | import json 11 | 12 | _open_ = open 13 | is_py2 = six.PY2 14 | 15 | if not is_py2: 16 | basestring = str 17 | 18 | 19 | def to_array(*args): 20 | """批量转numpy的array 21 | """ 22 | results = [np.array(a) for a in args] 23 | if len(args) == 1: 24 | return results[0] 25 | else: 26 | return results 27 | 28 | 29 | def is_string(s): 30 | """判断是否是字符串 31 | """ 32 | return isinstance(s, basestring) 33 | 34 | 35 | def strQ2B(ustring): 36 | """全角符号转对应的半角符号 37 | """ 38 | rstring = '' 39 | for uchar in ustring: 40 | inside_code = ord(uchar) 41 | # 全角空格直接转换 42 | if inside_code == 12288: 43 | inside_code = 32 44 | # 全角字符(除空格)根据关系转化 45 | elif (inside_code >= 65281 and inside_code <= 65374): 46 | inside_code -= 65248 47 | rstring += unichr(inside_code) 48 | return rstring 49 | 50 | 51 | def string_matching(s, keywords): 52 | """判断s是否至少包含keywords中的至少一个字符串 53 | """ 54 | for k in keywords: 55 | if re.search(k, s): 56 | return True 57 | return False 58 | 59 | 60 | def convert_to_unicode(text, encoding='utf-8', errors='ignore'): 61 | """字符串转换为unicode格式(假设输入为utf-8格式) 62 | """ 63 | if is_py2: 64 | if isinstance(text, str): 65 | text = text.decode(encoding, errors=errors) 66 | else: 67 | if isinstance(text, bytes): 68 | text = text.decode(encoding, errors=errors) 69 | return text 70 | 71 | 72 | def convert_to_str(text, encoding='utf-8', errors='ignore'): 73 | """字符串转换为str格式(假设输入为utf-8格式) 74 | """ 75 | if is_py2: 76 | if isinstance(text, unicode): 77 | text = text.encode(encoding, errors=errors) 78 | else: 79 | if isinstance(text, bytes): 80 | text = text.decode(encoding, errors=errors) 81 | return text 82 | 83 | 84 | class open: 85 | """模仿python自带的open函数,主要是为了同时兼容py2和py3 86 | """ 87 | def __init__(self, name, mode='r', encoding=None, errors='ignore'): 88 | if is_py2: 89 | self.file = _open_(name, mode) 90 | else: 91 | self.file = _open_(name, mode, encoding=encoding, errors=errors) 92 | self.encoding = encoding 93 | self.errors = errors 94 | 95 | def __iter__(self): 96 | for l in self.file: 97 | if self.encoding: 98 | l = convert_to_unicode(l, self.encoding, self.errors) 99 | yield l 100 | 101 | def read(self): 102 | text = self.file.read() 103 | if self.encoding: 104 | text = convert_to_unicode(text, self.encoding, self.errors) 105 | return text 106 | 107 | def write(self, text): 108 | if self.encoding: 109 | text = convert_to_str(text, self.encoding, self.errors) 110 | self.file.write(text) 111 | 112 | def flush(self): 113 | self.file.flush() 114 | 115 | def close(self): 116 | self.file.close() 117 | 118 | def __enter__(self): 119 | return self 120 | 121 | def __exit__(self, type, value, tb): 122 | self.close() 123 | 124 | 125 | def parallel_apply( 126 | func, iterable, workers, max_queue_size, callback=None, dummy=False 127 | ): 128 | """多进程或多线程地将func应用到iterable的每个元素中。 129 | 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是 130 | 输出可能是func(c), func(a), func(b)。 131 | 参数: 132 | dummy: False是多进程/线性,True则是多线程/线性; 133 | callback: 处理单个输出的回调函数。 134 | """ 135 | if dummy: 136 | from multiprocessing.dummy import Pool, Queue 137 | else: 138 | from multiprocessing import Pool, Queue 139 | 140 | in_queue, out_queue = Queue(max_queue_size), Queue() 141 | 142 | def worker_step(in_queue, out_queue): 143 | # 单步函数包装成循环执行 144 | while True: 145 | i, d = in_queue.get() 146 | r = func(d) 147 | out_queue.put((i, r)) 148 | 149 | # 启动多进程/线程 150 | pool = Pool(workers, worker_step, 
(in_queue, out_queue)) 151 | 152 | if callback is None: 153 | results = [] 154 | 155 | # 后处理函数 156 | def process_out_queue(): 157 | out_count = 0 158 | for _ in range(out_queue.qsize()): 159 | i, d = out_queue.get() 160 | out_count += 1 161 | if callback is None: 162 | results.append((i, d)) 163 | else: 164 | callback(d) 165 | return out_count 166 | 167 | # 存入数据,取出结果 168 | in_count, out_count = 0, 0 169 | for i, d in enumerate(iterable): 170 | in_count += 1 171 | while True: 172 | try: 173 | in_queue.put((i, d), block=False) 174 | break 175 | except six.moves.queue.Full: 176 | out_count += process_out_queue() 177 | if in_count % max_queue_size == 0: 178 | out_count += process_out_queue() 179 | 180 | while out_count != in_count: 181 | out_count += process_out_queue() 182 | 183 | pool.terminate() 184 | 185 | if callback is None: 186 | results = sorted(results, key=lambda r: r[0]) 187 | return [r[1] for r in results] 188 | 189 | 190 | def sequence_padding(inputs, length=None, padding=0): 191 | """Numpy函数,将序列padding到同一长度 192 | """ 193 | if length is None: 194 | length = max([len(x) for x in inputs]) 195 | 196 | pad_width = [(0, 0) for _ in np.shape(inputs[0])] 197 | outputs = [] 198 | for x in inputs: 199 | x = x[:length] 200 | pad_width[0] = (0, length - len(x)) 201 | x = np.pad(x, pad_width, 'constant', constant_values=padding) 202 | outputs.append(x) 203 | 204 | return np.array(outputs) 205 | 206 | 207 | def text_segmentate(text, maxlen, seps='\n', strips=None): 208 | """将文本按照标点符号划分为若干个短句 209 | """ 210 | text = text.strip().strip(strips) 211 | if seps and len(text) > maxlen: 212 | pieces = text.split(seps[0]) 213 | text, texts = '', [] 214 | for i, p in enumerate(pieces): 215 | if text and p and len(text) + len(p) > maxlen - 1: 216 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips)) 217 | text = '' 218 | if i + 1 == len(pieces): 219 | text = text + p 220 | else: 221 | text = text + p + seps[0] 222 | if text: 223 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips)) 224 | return texts 225 | else: 226 | return [text] 227 | 228 | 229 | def is_one_of(x, ys): 230 | """判断x是否在ys之中 231 | 等价于x in ys,但有些情况下x in ys会报错 232 | """ 233 | for y in ys: 234 | if x is y: 235 | return True 236 | return False 237 | 238 | 239 | class DataGenerator(object): 240 | """数据生成器模版 241 | """ 242 | def __init__(self, data, batch_size=32, buffer_size=None): 243 | self.data = data 244 | self.batch_size = batch_size 245 | if hasattr(self.data, '__len__'): 246 | self.steps = len(self.data) // self.batch_size 247 | if len(self.data) % self.batch_size != 0: 248 | self.steps += 1 249 | else: 250 | self.steps = None 251 | self.buffer_size = buffer_size or batch_size * 1000 252 | 253 | def __len__(self): 254 | return self.steps 255 | 256 | def sample(self, random=False): 257 | """采样函数,每个样本同时返回一个is_end标记 258 | """ 259 | if random: 260 | if self.steps is None: 261 | 262 | def generator(): 263 | caches, isfull = [], False 264 | for d in self.data: 265 | caches.append(d) 266 | if isfull: 267 | i = np.random.randint(len(caches)) 268 | yield caches.pop(i) 269 | elif len(caches) == self.buffer_size: 270 | isfull = True 271 | while caches: 272 | i = np.random.randint(len(caches)) 273 | yield caches.pop(i) 274 | 275 | else: 276 | 277 | def generator(): 278 | indices = list(range(len(self.data))) 279 | np.random.shuffle(indices) 280 | for i in indices: 281 | yield self.data[i] 282 | 283 | data = generator() 284 | else: 285 | data = iter(self.data) 286 | 287 | d_current = next(data) 288 | for d_next in data: 289 | yield 
False, d_current 290 | d_current = d_next 291 | 292 | yield True, d_current 293 | 294 | def __iter__(self, random=False): 295 | raise NotImplementedError 296 | 297 | def forfit(self): 298 | while True: 299 | for d in self.__iter__(True): 300 | yield d 301 | 302 | 303 | class ViterbiDecoder(object): 304 | """Viterbi解码算法基类 305 | """ 306 | def __init__(self, trans, starts=None, ends=None): 307 | self.trans = trans 308 | self.num_labels = len(trans) 309 | self.non_starts = [] 310 | self.non_ends = [] 311 | if starts is not None: 312 | for i in range(self.num_labels): 313 | if i not in starts: 314 | self.non_starts.append(i) 315 | if ends is not None: 316 | for i in range(self.num_labels): 317 | if i not in ends: 318 | self.non_ends.append(i) 319 | 320 | def decode(self, nodes): 321 | """nodes.shape=[seq_len, num_labels] 322 | """ 323 | # 预处理 324 | nodes[0, self.non_starts] -= np.inf 325 | nodes[-1, self.non_ends] -= np.inf 326 | 327 | # 动态规划 328 | labels = np.arange(self.num_labels).reshape((1, -1)) 329 | scores = nodes[0].reshape((-1, 1)) 330 | paths = labels 331 | for l in range(1, len(nodes)): 332 | M = scores + self.trans + nodes[l].reshape((1, -1)) 333 | idxs = M.argmax(0) 334 | scores = M.max(0).reshape((-1, 1)) 335 | paths = np.concatenate([paths[:, idxs], labels], 0) 336 | 337 | # 最优路径 338 | return paths[:, scores[:, 0].argmax()] 339 | 340 | 341 | def softmax(x, axis=-1): 342 | """numpy版softmax 343 | """ 344 | x = x - x.max(axis=axis, keepdims=True) 345 | x = np.exp(x) 346 | return x / x.sum(axis=axis, keepdims=True) 347 | 348 | 349 | class AutoRegressiveDecoder(object): 350 | """通用自回归生成模型解码基类 351 | 包含beam search和random sample两种策略 352 | """ 353 | def __init__(self, start_id, end_id, maxlen, minlen=None): 354 | self.start_id = start_id 355 | self.end_id = end_id 356 | self.maxlen = maxlen 357 | self.minlen = minlen or 1 358 | if start_id is None: 359 | self.first_output_ids = np.empty((1, 0), dtype=int) 360 | else: 361 | self.first_output_ids = np.array([[self.start_id]]) 362 | 363 | @staticmethod 364 | def wraps(default_rtype='probas', use_states=False): 365 | """用来进一步完善predict函数 366 | 目前包含:1. 设置rtype参数,并做相应处理; 367 | 2. 
确定states的使用,并做相应处理。 368 | """ 369 | def actual_decorator(predict): 370 | def new_predict( 371 | self, inputs, output_ids, states, rtype=default_rtype 372 | ): 373 | assert rtype in ['probas', 'logits'] 374 | prediction = predict(self, inputs, output_ids, states) 375 | 376 | if not use_states: 377 | prediction = (prediction, None) 378 | 379 | if default_rtype == 'logits': 380 | prediction = (softmax(prediction[0]), prediction[1]) 381 | 382 | if rtype == 'probas': 383 | return prediction 384 | else: 385 | return np.log(prediction[0] + 1e-12), prediction[1] 386 | 387 | return new_predict 388 | 389 | return actual_decorator 390 | 391 | def predict(self, inputs, output_ids, states=None, rtype='logits'): 392 | """用户需自定义递归预测函数 393 | 说明:rtype为字符串logits或probas,用户定义的时候,应当根据rtype来 394 | 返回不同的结果,rtype=probas时返回归一化的概率,rtype=logits时 395 | 则返回softmax前的结果或者概率对数。 396 | 返回:二元组 (得分或概率, states) 397 | """ 398 | raise NotImplementedError 399 | 400 | def beam_search(self, inputs, topk, states=None, min_ends=1): 401 | """beam search解码 402 | 说明:这里的topk即beam size; 403 | 返回:最优解码序列。 404 | """ 405 | inputs = [np.array([i]) for i in inputs] 406 | output_ids, output_scores = self.first_output_ids, np.zeros(1) 407 | for step in range(self.maxlen): 408 | scores, states = self.predict( 409 | inputs, output_ids, states, 'logits' 410 | ) # 计算当前得分 411 | if step == 0: # 第1步预测后将输入重复topk次 412 | inputs = [np.repeat(i, topk, axis=0) for i in inputs] 413 | scores = output_scores.reshape((-1, 1)) + scores # 综合累积得分 414 | indices = scores.argpartition(-topk, axis=None)[-topk:] # 仅保留topk 415 | indices_1 = indices // scores.shape[1] # 行索引 416 | indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引 417 | output_ids = np.concatenate([output_ids[indices_1], indices_2], 418 | 1) # 更新输出 419 | output_scores = np.take_along_axis( 420 | scores, indices, axis=None 421 | ) # 更新得分 422 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记 423 | if output_ids.shape[1] >= self.minlen: # 最短长度判断 424 | best_one = output_scores.argmax() # 得分最大的那个 425 | if end_counts[best_one] == min_ends: # 如果已经终止 426 | return output_ids[best_one] # 直接输出 427 | else: # 否则,只保留未完成部分 428 | flag = (end_counts < min_ends) # 标记未完成序列 429 | if not flag.all(): # 如果有已完成的 430 | inputs = [i[flag] for i in inputs] # 扔掉已完成序列 431 | output_ids = output_ids[flag] # 扔掉已完成序列 432 | output_scores = output_scores[flag] # 扔掉已完成序列 433 | end_counts = end_counts[flag] # 扔掉已完成end计数 434 | topk = flag.sum() # topk相应变化 435 | # 达到长度直接输出 436 | return output_ids[output_scores.argmax()] 437 | 438 | def random_sample(self, inputs, n, topk=None, topp=None, states=None, min_ends=1): 439 | """随机采样n个结果 440 | 说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp 441 | 表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。 442 | 返回:n个解码序列组成的list。 443 | """ 444 | inputs = [np.array([i]) for i in inputs] 445 | output_ids = self.first_output_ids 446 | results = [] 447 | for step in range(self.maxlen): 448 | probas, states = self.predict( 449 | inputs, output_ids, states, 'probas' 450 | ) # 计算当前概率 451 | probas /= probas.sum(axis=1, keepdims=True) # 确保归一化 452 | if step == 0: # 第1步预测后将结果重复n次 453 | probas = np.repeat(probas, n, axis=0) 454 | inputs = [np.repeat(i, n, axis=0) for i in inputs] 455 | output_ids = np.repeat(output_ids, n, axis=0) 456 | if topk is not None: 457 | k_indices = probas.argpartition(-topk, 458 | axis=1)[:, -topk:] # 仅保留topk 459 | probas = np.take_along_axis(probas, k_indices, axis=1) # topk概率 460 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 461 | if topp is not None: 462 | p_indices = 
probas.argsort(axis=1)[:, ::-1] # 从高到低排序 463 | probas = np.take_along_axis(probas, p_indices, axis=1) # 排序概率 464 | cumsum_probas = np.cumsum(probas, axis=1) # 累积概率 465 | flag = np.roll(cumsum_probas >= topp, 1, axis=1) # 标记超过topp的部分 466 | flag[:, 0] = False # 结合上面的np.roll,实现平移一位的效果 467 | probas[flag] = 0 # 后面的全部置零 468 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 469 | sample_func = lambda p: np.random.choice(len(p), p=p) # 按概率采样函数 470 | sample_ids = np.apply_along_axis(sample_func, 1, probas) # 执行采样 471 | sample_ids = sample_ids.reshape((-1, 1)) # 对齐形状 472 | if topp is not None: 473 | sample_ids = np.take_along_axis(p_indices, sample_ids, axis=1) # 对齐原id 474 | if topk is not None: 475 | sample_ids = np.take_along_axis(k_indices, sample_ids, axis=1) # 对齐原id 476 | output_ids = np.concatenate([output_ids, sample_ids], 1) # 更新输出 477 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记 478 | if output_ids.shape[1] >= self.minlen: # 最短长度判断 479 | flag = (end_counts == min_ends) # 标记已完成序列 480 | if flag.any(): # 如果有已完成的 481 | for ids in output_ids[flag]: # 存好已完成序列 482 | results.append(ids) 483 | flag = (flag == False) # 标记未完成序列 484 | inputs = [i[flag] for i in inputs] # 只保留未完成部分输入 485 | output_ids = output_ids[flag] # 只保留未完成部分候选集 486 | end_counts = end_counts[flag] # 只保留未完成部分end计数 487 | if len(output_ids) == 0: 488 | break 489 | # 如果还有未完成序列,直接放入结果 490 | for ids in output_ids: 491 | results.append(ids) 492 | # 返回结果 493 | return results 494 | 495 | 496 | def insert_arguments(**arguments): 497 | """装饰器,为类方法增加参数 498 | (主要用于类的__init__方法) 499 | """ 500 | def actual_decorator(func): 501 | def new_func(self, *args, **kwargs): 502 | for k, v in arguments.items(): 503 | if k in kwargs: 504 | v = kwargs.pop(k) 505 | setattr(self, k, v) 506 | return func(self, *args, **kwargs) 507 | 508 | return new_func 509 | 510 | return actual_decorator 511 | 512 | 513 | def delete_arguments(*arguments): 514 | """装饰器,为类方法删除参数 515 | (主要用于类的__init__方法) 516 | """ 517 | def actual_decorator(func): 518 | def new_func(self, *args, **kwargs): 519 | for k in arguments: 520 | if k in kwargs: 521 | raise TypeError( 522 | '%s got an unexpected keyword argument \'%s\'' % 523 | (self.__class__.__name__, k) 524 | ) 525 | return func(self, *args, **kwargs) 526 | 527 | return new_func 528 | 529 | return actual_decorator 530 | 531 | 532 | def longest_common_substring(source, target): 533 | """最长公共子串(source和target的最长公共切片区间) 534 | 返回:子串长度, 所在区间(四元组) 535 | 注意:最长公共子串可能不止一个,所返回的区间只代表其中一个。 536 | """ 537 | c, l, span = defaultdict(int), 0, (0, 0, 0, 0) 538 | for i, si in enumerate(source, 1): 539 | for j, tj in enumerate(target, 1): 540 | if si == tj: 541 | c[i, j] = c[i - 1, j - 1] + 1 542 | if c[i, j] > l: 543 | l = c[i, j] 544 | span = (i - l, i, j - l, j) 545 | return l, span 546 | 547 | 548 | def longest_common_subsequence(source, target): 549 | """最长公共子序列(source和target的最长非连续子序列) 550 | 返回:子序列长度, 映射关系(映射对组成的list) 551 | 注意:最长公共子序列可能不止一个,所返回的映射只代表其中一个。 552 | """ 553 | c = defaultdict(int) 554 | for i, si in enumerate(source, 1): 555 | for j, tj in enumerate(target, 1): 556 | if si == tj: 557 | c[i, j] = c[i - 1, j - 1] + 1 558 | elif c[i, j - 1] > c[i - 1, j]: 559 | c[i, j] = c[i, j - 1] 560 | else: 561 | c[i, j] = c[i - 1, j] 562 | l, mapping = c[len(source), len(target)], [] 563 | i, j = len(source) - 1, len(target) - 1 564 | while len(mapping) < l: 565 | if source[i] == target[j]: 566 | mapping.append((i, j)) 567 | i, j = i - 1, j - 1 568 | elif c[i + 1, j] > c[i, j + 1]: 569 | j = j - 1 570 | else: 571 | i = i - 1 572 | 
return l, mapping[::-1] 573 | 574 | 575 | class WebServing(object): 576 | """简单的Web接口 577 | 用法: 578 | arguments = {'text': (None, True), 'n': (int, False)} 579 | web = WebServing(port=8864) 580 | web.route('/gen_synonyms', gen_synonyms, arguments) 581 | web.start() 582 | # 然后访问 http://127.0.0.1:8864/gen_synonyms?text=你好 583 | 说明: 584 | 基于bottlepy简单封装,仅作为临时测试使用,不保证性能。 585 | 目前仅保证支持 Tensorflow 1.x + Keras <= 2.3.1。 586 | 欢迎有经验的开发者帮忙改进。 587 | 依赖: 588 | pip install bottle 589 | pip install paste 590 | (如果不用 server='paste' 的话,可以不装paste库) 591 | """ 592 | def __init__(self, host='0.0.0.0', port=8000, server='paste'): 593 | 594 | import tensorflow as tf 595 | from bert4keras.backend import K 596 | import bottle 597 | 598 | self.host = host 599 | self.port = port 600 | self.server = server 601 | self.graph = tf.get_default_graph() 602 | self.sess = K.get_session() 603 | self.set_session = K.set_session 604 | self.bottle = bottle 605 | 606 | def wraps(self, func, arguments, method='GET'): 607 | """封装为接口函数 608 | 参数: 609 | func:要转换为接口的函数,需要保证输出可以json化,即需要 610 | 保证 json.dumps(func(inputs)) 能被执行成功; 611 | arguments:声明func所需参数,其中key为参数名,value[0]为 612 | 对应的转换函数(接口获取到的参数值都是字符串 613 | 型),value[1]为该参数是否必须; 614 | method:GET或者POST。 615 | """ 616 | def new_func(): 617 | outputs = {'code': 0, 'desc': u'succeeded', 'data': {}} 618 | kwargs = {} 619 | for key, value in arguments.items(): 620 | if method == 'GET': 621 | result = self.bottle.request.GET.get(key) 622 | else: 623 | result = self.bottle.request.POST.get(key) 624 | if result is None: 625 | if value[1]: 626 | outputs['code'] = 1 627 | outputs['desc'] = 'lack of "%s" argument' % key 628 | return json.dumps(outputs, ensure_ascii=False) 629 | else: 630 | if value[0] is None: 631 | result = convert_to_unicode(result) 632 | else: 633 | result = value[0](result) 634 | kwargs[key] = result 635 | try: 636 | with self.graph.as_default(): 637 | self.set_session(self.sess) 638 | outputs['data'] = func(**kwargs) 639 | except Exception as e: 640 | outputs['code'] = 2 641 | outputs['desc'] = str(e) 642 | return json.dumps(outputs, ensure_ascii=False) 643 | 644 | return new_func 645 | 646 | def route(self, path, func, arguments, method='GET'): 647 | """添加接口 648 | """ 649 | func = self.wraps(func, arguments, method) 650 | self.bottle.route(path, method=method)(func) 651 | 652 | def start(self): 653 | """启动服务 654 | """ 655 | self.bottle.run(host=self.host, port=self.port, server=self.server) 656 | 657 | 658 | class Hook: 659 | """注入uniout模块,实现import时才触发 660 | """ 661 | def __init__(self, module): 662 | self.module = module 663 | 664 | def __getattr__(self, attr): 665 | """使得 from bert4keras.backend import uniout 666 | 等效于 import uniout (自动识别Python版本,Python3 667 | 下则无操作。) 668 | """ 669 | if attr == 'uniout': 670 | if is_py2: 671 | import uniout 672 | else: 673 | return getattr(self.module, attr) 674 | 675 | 676 | Hook.__name__ = __name__ 677 | sys.modules[__name__] = Hook(sys.modules[__name__]) 678 | del Hook 679 | -------------------------------------------------------------------------------- /lstm/BERT_tf2/bert4keras/tokenizers.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 工具函数 3 | 4 | import unicodedata, re, os 5 | from bert4keras.snippets import is_string, is_py2 6 | from bert4keras.snippets import open 7 | 8 | 9 | def load_vocab(dict_path, encoding='utf-8', simplified=False, startswith=None): 10 | """从bert的词典文件中读取词典 11 | """ 12 | token_dict = {} 13 | if os.path.isdir(dict_path): 14 | dict_path = os.path.join(dict_path, 'vocab.txt') 15 | with open(dict_path, encoding=encoding) as reader: 16 | for line in reader: 17 | token = line.split() 18 | token = token[0] if token else line.strip() 19 | token_dict[token] = len(token_dict) 20 | 21 | if simplified: # 过滤冗余部分token 22 | new_token_dict, keep_tokens = {}, [] 23 | startswith = startswith or [] 24 | for t in startswith: 25 | new_token_dict[t] = len(new_token_dict) 26 | keep_tokens.append(token_dict[t]) 27 | 28 | for t, _ in sorted(token_dict.items(), key=lambda s: s[1]): 29 | if t not in new_token_dict: 30 | keep = True 31 | if len(t) > 1: 32 | for c in Tokenizer.stem(t): 33 | if ( 34 | Tokenizer._is_cjk_character(c) or 35 | Tokenizer._is_punctuation(c) 36 | ): 37 | keep = False 38 | break 39 | if keep: 40 | new_token_dict[t] = len(new_token_dict) 41 | keep_tokens.append(token_dict[t]) 42 | 43 | return new_token_dict, keep_tokens 44 | else: 45 | return token_dict 46 | 47 | 48 | def save_vocab(dict_path, token_dict, encoding='utf-8'): 49 | """将词典(比如精简过的)保存为文件 50 | """ 51 | with open(dict_path, 'w', encoding=encoding) as writer: 52 | for k, v in sorted(token_dict.items(), key=lambda s: s[1]): 53 | writer.write(k + '\n') 54 | 55 | 56 | class BasicTokenizer(object): 57 | """分词器基类 58 | """ 59 | def __init__(self, token_start='[CLS]', token_end='[SEP]'): 60 | """初始化 61 | """ 62 | self._token_pad = '[PAD]' 63 | self._token_unk = '[UNK]' 64 | self._token_mask = '[MASK]' 65 | self._token_start = token_start 66 | self._token_end = token_end 67 | 68 | def tokenize(self, text, maxlen=None): 69 | """分词函数 70 | """ 71 | tokens = self._tokenize(text) 72 | if self._token_start is not None: 73 | tokens.insert(0, self._token_start) 74 | if self._token_end is not None: 75 | tokens.append(self._token_end) 76 | 77 | if maxlen is not None: 78 | index = int(self._token_end is not None) + 1 79 | self.truncate_sequence(maxlen, tokens, None, -index) 80 | 81 | return tokens 82 | 83 | def token_to_id(self, token): 84 | """token转换为对应的id 85 | """ 86 | raise NotImplementedError 87 | 88 | def tokens_to_ids(self, tokens): 89 | """token序列转换为对应的id序列 90 | """ 91 | return [self.token_to_id(token) for token in tokens] 92 | 93 | def truncate_sequence( 94 | self, maxlen, first_sequence, second_sequence=None, pop_index=-1 95 | ): 96 | """截断总长度 97 | """ 98 | if second_sequence is None: 99 | second_sequence = [] 100 | 101 | while True: 102 | total_length = len(first_sequence) + len(second_sequence) 103 | if total_length <= maxlen: 104 | break 105 | elif len(first_sequence) > len(second_sequence): 106 | first_sequence.pop(pop_index) 107 | else: 108 | second_sequence.pop(pop_index) 109 | 110 | def encode( 111 | self, first_text, second_text=None, maxlen=None, pattern='S*E*E' 112 | ): 113 | """输出文本对应token id和segment id 114 | """ 115 | if is_string(first_text): 116 | first_tokens = self.tokenize(first_text) 117 | else: 118 | first_tokens = first_text 119 | 120 | if second_text is None: 121 | second_tokens = None 122 | elif is_string(second_text): 123 | if pattern == 'S*E*E': 124 | idx = int(bool(self._token_start)) 125 | second_tokens = self.tokenize(second_text)[idx:] 126 | elif pattern == 'S*ES*E': 127 | second_tokens = 
self.tokenize(second_text) 128 | else: 129 | second_tokens = second_text 130 | 131 | if maxlen is not None: 132 | self.truncate_sequence(maxlen, first_tokens, second_tokens, -2) 133 | 134 | first_token_ids = self.tokens_to_ids(first_tokens) 135 | first_segment_ids = [0] * len(first_token_ids) 136 | 137 | if second_text is not None: 138 | second_token_ids = self.tokens_to_ids(second_tokens) 139 | second_segment_ids = [1] * len(second_token_ids) 140 | first_token_ids.extend(second_token_ids) 141 | first_segment_ids.extend(second_segment_ids) 142 | 143 | return first_token_ids, first_segment_ids 144 | 145 | def id_to_token(self, i): 146 | """id序列为对应的token 147 | """ 148 | raise NotImplementedError 149 | 150 | def ids_to_tokens(self, ids): 151 | """id序列转换为对应的token序列 152 | """ 153 | return [self.id_to_token(i) for i in ids] 154 | 155 | def decode(self, ids): 156 | """转为可读文本 157 | """ 158 | raise NotImplementedError 159 | 160 | def _tokenize(self, text): 161 | """基本分词函数 162 | """ 163 | raise NotImplementedError 164 | 165 | 166 | class Tokenizer(BasicTokenizer): 167 | """Bert原生分词器 168 | 纯Python实现,代码修改自keras_bert的tokenizer实现 169 | """ 170 | def __init__( 171 | self, token_dict, do_lower_case=False, pre_tokenize=None, **kwargs 172 | ): 173 | """这里的pre_tokenize是外部传入的分词函数,用作对文本进行预分词。如果传入 174 | pre_tokenize,则先执行pre_tokenize(text),然后在它的基础上执行原本的 175 | tokenize函数。 176 | """ 177 | super(Tokenizer, self).__init__(**kwargs) 178 | if is_string(token_dict): 179 | token_dict = load_vocab(token_dict) 180 | 181 | self._do_lower_case = do_lower_case 182 | self._pre_tokenize = pre_tokenize 183 | self._token_dict = token_dict 184 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 185 | self._vocab_size = len(token_dict) 186 | 187 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 188 | try: 189 | _token_id = token_dict[getattr(self, '_token_%s' % token)] 190 | setattr(self, '_token_%s_id' % token, _token_id) 191 | except: 192 | pass 193 | 194 | def token_to_id(self, token): 195 | """token转换为对应的id 196 | """ 197 | return self._token_dict.get(token, self._token_unk_id) 198 | 199 | def id_to_token(self, i): 200 | """id转换为对应的token 201 | """ 202 | return self._token_dict_inv[i] 203 | 204 | def decode(self, ids, tokens=None): 205 | """转为可读文本 206 | """ 207 | tokens = tokens or self.ids_to_tokens(ids) 208 | tokens = [token for token in tokens if not self._is_special(token)] 209 | 210 | text, flag = '', False 211 | for i, token in enumerate(tokens): 212 | if token[:2] == '##': 213 | text += token[2:] 214 | elif len(token) == 1 and self._is_cjk_character(token): 215 | text += token 216 | elif len(token) == 1 and self._is_punctuation(token): 217 | text += token 218 | text += ' ' 219 | elif i > 0 and self._is_cjk_character(text[-1]): 220 | text += token 221 | else: 222 | text += ' ' 223 | text += token 224 | 225 | text = re.sub(' +', ' ', text) 226 | text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text) 227 | punctuation = self._cjk_punctuation() + '+-/={(<[' 228 | punctuation_regex = '|'.join([re.escape(p) for p in punctuation]) 229 | punctuation_regex = '(%s) ' % punctuation_regex 230 | text = re.sub(punctuation_regex, '\\1', text) 231 | text = re.sub('(\d\.) 
(\d)', '\\1\\2', text) 232 | 233 | return text.strip() 234 | 235 | def _tokenize(self, text, pre_tokenize=True): 236 | """基本分词函数 237 | """ 238 | if self._do_lower_case: 239 | if is_py2: 240 | text = unicode(text) 241 | text = text.lower() 242 | text = unicodedata.normalize('NFD', text) 243 | text = ''.join([ 244 | ch for ch in text if unicodedata.category(ch) != 'Mn' 245 | ]) 246 | 247 | if pre_tokenize and self._pre_tokenize is not None: 248 | tokens = [] 249 | for token in self._pre_tokenize(text): 250 | if token in self._token_dict: 251 | tokens.append(token) 252 | else: 253 | tokens.extend(self._tokenize(token, False)) 254 | return tokens 255 | 256 | spaced = '' 257 | for ch in text: 258 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 259 | spaced += ' ' + ch + ' ' 260 | elif self._is_space(ch): 261 | spaced += ' ' 262 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 263 | continue 264 | else: 265 | spaced += ch 266 | 267 | tokens = [] 268 | for word in spaced.strip().split(): 269 | tokens.extend(self._word_piece_tokenize(word)) 270 | 271 | return tokens 272 | 273 | def _word_piece_tokenize(self, word): 274 | """word内分成subword 275 | """ 276 | if word in self._token_dict: 277 | return [word] 278 | 279 | tokens = [] 280 | start, stop = 0, 0 281 | while start < len(word): 282 | stop = len(word) 283 | while stop > start: 284 | sub = word[start:stop] 285 | if start > 0: 286 | sub = '##' + sub 287 | if sub in self._token_dict: 288 | break 289 | stop -= 1 290 | if start == stop: 291 | stop += 1 292 | tokens.append(sub) 293 | start = stop 294 | 295 | return tokens 296 | 297 | @staticmethod 298 | def stem(token): 299 | """获取token的“词干”(如果是##开头,则自动去掉##) 300 | """ 301 | if token[:2] == '##': 302 | return token[2:] 303 | else: 304 | return token 305 | 306 | @staticmethod 307 | def _is_space(ch): 308 | """空格类字符判断 309 | """ 310 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ 311 | unicodedata.category(ch) == 'Zs' 312 | 313 | @staticmethod 314 | def _is_punctuation(ch): 315 | """标点符号类字符判断(全/半角均在此内) 316 | 提醒:unicodedata.category这个函数在py2和py3下的 317 | 表现可能不一样,比如u'§'字符,在py2下的结果为'So', 318 | 在py3下的结果是'Po'。 319 | """ 320 | code = ord(ch) 321 | return 33 <= code <= 47 or \ 322 | 58 <= code <= 64 or \ 323 | 91 <= code <= 96 or \ 324 | 123 <= code <= 126 or \ 325 | unicodedata.category(ch).startswith('P') 326 | 327 | @staticmethod 328 | def _cjk_punctuation(): 329 | return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002' 330 | 331 | @staticmethod 332 | def _is_cjk_character(ch): 333 | """CJK类字符判断(包括中文字符也在此列) 334 | 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 335 | """ 336 | code = ord(ch) 337 | return 0x4E00 <= code <= 0x9FFF or \ 338 | 0x3400 <= code <= 0x4DBF or \ 339 | 0x20000 <= code <= 0x2A6DF or \ 340 | 0x2A700 <= code <= 0x2B73F or \ 341 | 0x2B740 <= code <= 0x2B81F or \ 342 | 0x2B820 <= code <= 0x2CEAF or \ 343 | 0xF900 <= code <= 0xFAFF or \ 344 | 0x2F800 <= code <= 0x2FA1F 345 | 346 | @staticmethod 347 | def _is_control(ch): 348 | """控制类字符判断 349 | """ 350 | return unicodedata.category(ch) in 
('Cc', 'Cf') 351 | 352 | @staticmethod 353 | def _is_special(ch): 354 | """判断是不是有特殊含义的符号 355 | """ 356 | return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') 357 | 358 | def rematch(self, text, tokens): 359 | """给出原始的text和tokenize后的tokens的映射关系 360 | """ 361 | if is_py2: 362 | text = unicode(text) 363 | 364 | if self._do_lower_case: 365 | text = text.lower() 366 | 367 | normalized_text, char_mapping = '', [] 368 | for i, ch in enumerate(text): 369 | if self._do_lower_case: 370 | ch = unicodedata.normalize('NFD', ch) 371 | ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn']) 372 | ch = ''.join([ 373 | c for c in ch 374 | if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c)) 375 | ]) 376 | normalized_text += ch 377 | char_mapping.extend([i] * len(ch)) 378 | 379 | text, token_mapping, offset = normalized_text, [], 0 380 | for token in tokens: 381 | if self._is_special(token): 382 | token_mapping.append([]) 383 | else: 384 | token = self.stem(token) 385 | start = text[offset:].index(token) + offset 386 | end = start + len(token) 387 | token_mapping.append(char_mapping[start:end]) 388 | offset = end 389 | 390 | return token_mapping 391 | 392 | 393 | class SpTokenizer(BasicTokenizer): 394 | """基于SentencePiece模型的封装,使用上跟Tokenizer基本一致。 395 | """ 396 | def __init__(self, sp_model_path, **kwargs): 397 | super(SpTokenizer, self).__init__(**kwargs) 398 | import sentencepiece as spm 399 | self.sp_model = spm.SentencePieceProcessor() 400 | self.sp_model.Load(sp_model_path) 401 | self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id()) 402 | self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id()) 403 | self._vocab_size = self.sp_model.get_piece_size() 404 | 405 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 406 | try: 407 | _token = getattr(self, '_token_%s' % token) 408 | _token_id = self.sp_model.piece_to_id(_token) 409 | setattr(self, '_token_%s_id' % token, _token_id) 410 | except: 411 | pass 412 | 413 | def token_to_id(self, token): 414 | """token转换为对应的id 415 | """ 416 | return self.sp_model.piece_to_id(token) 417 | 418 | def id_to_token(self, i): 419 | """id转换为对应的token 420 | """ 421 | if i < self._vocab_size: 422 | return self.sp_model.id_to_piece(i) 423 | else: 424 | return '' 425 | 426 | def decode(self, ids): 427 | """转为可读文本 428 | """ 429 | ids = [i for i in ids if self._is_decodable(i)] 430 | text = self.sp_model.decode_ids(ids) 431 | return text.decode('utf-8') if is_py2 else text 432 | 433 | def _tokenize(self, text): 434 | """基本分词函数 435 | """ 436 | tokens = self.sp_model.encode_as_pieces(text) 437 | return tokens 438 | 439 | def _is_special(self, i): 440 | """判断是不是有特殊含义的符号 441 | """ 442 | return self.sp_model.is_control(i) or \ 443 | self.sp_model.is_unknown(i) or \ 444 | self.sp_model.is_unused(i) 445 | 446 | def _is_decodable(self, i): 447 | """判断是否应该被解码输出 448 | """ 449 | return (i < self._vocab_size) and not self._is_special(i) 450 | -------------------------------------------------------------------------------- /lstm/BERT_tf2/bert_tools.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, math, re 2 | import tensorflow as tf 3 | from tqdm import tqdm 4 | import numpy as np 5 | import tensorflow.keras.backend as K 6 | from tensorflow.keras.callbacks import Callback 7 | from tensorflow.keras.models import * 8 | 9 | from bert4keras.tokenizers import Tokenizer 10 | from bert4keras.snippets import to_array 11 | from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr, 
extend_with_weight_decay 12 | from bert4keras.models import build_transformer_model 13 | from bert4keras.layers import * 14 | from tensorflow.keras.initializers import TruncatedNormal 15 | 16 | en_dict_path = r'../tfhub/uncased_L-12_H-768_A-12/vocab.txt' 17 | cn_dict_path = r'../tfhub/chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt' 18 | 19 | tokenizer = Tokenizer(cn_dict_path, do_lower_case=True) 20 | language = 'cn' 21 | 22 | def switch_to_en(): 23 | global tokenizer, language 24 | tokenizer = Tokenizer(en_dict_path, do_lower_case=True) 25 | language = 'en' 26 | 27 | def convert_sentences(sents, maxlen=256): 28 | shape = (2, len(sents), maxlen) 29 | X = np.zeros(shape, dtype='int32') 30 | for ii, sent in tqdm(enumerate(sents), desc="Converting sentences"): 31 | tids, segs = tokenizer.encode(sent, maxlen=maxlen) 32 | X[0,ii,:len(tids)] = tids 33 | X[1,ii,:len(segs)] = segs 34 | return [X[0], X[1]] 35 | 36 | def convert_tokens(sents, maxlen=256): 37 | shape = (2, len(sents), maxlen) 38 | X = np.zeros(shape, dtype='int32') 39 | for ii, sent in tqdm(enumerate(sents), desc="Converting tokens"): 40 | tids = tokenizer.tokens_to_ids(sent) 41 | X[0,ii,:len(tids)] = tids 42 | return [X[0], X[1]] 43 | 44 | def lock_transformer_layers(transformer, layers=-1): 45 | def _filter(layers, prefix): 46 | return [x for x in transformer.layers if x.name.startswith(prefix)] 47 | if hasattr(transformer, 'model'): transformer = transformer.model 48 | if layers >= 0: 49 | print('locking', 'Embedding-*') 50 | for layer in _filter(transformer, 'Embedding-'): 51 | layer.trainable = False 52 | print('locking', 'Transformer-[%d-%d]-*' % (0, layers-1)) 53 | for index in range(layers): 54 | for layer in _filter(transformer, 'Transformer-%d-' % index): 55 | layer.trainable = False 56 | 57 | def unlock_transformer_layers(transformer): 58 | if hasattr(transformer, 'model'): transformer = transformer.model 59 | for layer in transformer.layers: 60 | layer.trainable = True 61 | 62 | def get_suggested_optimizer(init_lr=5e-5, total_steps=None): 63 | lr_schedule = {1000:1, 10000:0.01} 64 | if total_steps is not None: 65 | lr_schedule = {total_steps//10:1, total_steps:0.1} 66 | optimizer = extend_with_weight_decay(Adam) 67 | optimizer = extend_with_piecewise_linear_lr(optimizer) 68 | optimizer_params = { 69 | 'learning_rate': init_lr, 70 | 'lr_schedule': lr_schedule, 71 | 'weight_decay_rate': 0.01, 72 | 'exclude_from_weight_decay': ['Norm', 'bias'], 73 | 'bias_correction': False, 74 | } 75 | optimizer = optimizer(**optimizer_params) 76 | return optimizer 77 | 78 | def convert_single_setences(sens, maxlen, tokenizer, details=False): 79 | X = np.zeros((len(sens), maxlen), dtype='int32') 80 | datas = [] 81 | for i, s in enumerate(sens): 82 | tokens = tokenizer.tokenize(s)[:maxlen-2] 83 | if details: 84 | otokens = restore_token_list(s, tokens) 85 | datas.append({'id':i, 's':s, 'otokens':otokens}) 86 | tt = ['[CLS]'] + tokens + ['[SEP]'] 87 | tids = tokenizer.convert_tokens_to_ids(tt) 88 | X[i,:len(tids)] = tids 89 | if details: return datas, X 90 | return X 91 | 92 | def build_classifier(classes, bert_h5=None): 93 | if bert_h5 is None: 94 | bert_h5 = '../tfhub/chinese_roberta_wwm_ext.h5' if language == 'cn' else '../tfhub/bert_uncased.h5' 95 | bert = load_model(bert_h5) 96 | output = Lambda(lambda x: x[:,0], name='CLS-token')(bert.output) 97 | if classes == 2: 98 | output = Dense(1, activation='sigmoid', kernel_initializer=TruncatedNormal(stddev=0.02))(output) 99 | else: 100 | output = Dense(classes, activation='softmax', 
kernel_initializer=TruncatedNormal(stddev=0.02))(output) 101 | model = Model(bert.input, output) 102 | model.bert_encoder = bert 103 | return model 104 | 105 | 106 | ## THESE FUNCTIONS ARE TESTED FOR CHS LANGUAGE ONLY 107 | def gen_token_list_inv_pointer(sent, token_list): 108 | zz = tokenizer.rematch(sent, token_list) 109 | return [x[0] for x in zz if len(x) > 0] 110 | sent = sent.lower() 111 | otiis = []; iis = 0 112 | for it, token in enumerate(token_list): 113 | otoken = token.lstrip('#') 114 | if token[0] == '[' and token[-1] == ']': otoken = '' 115 | niis = iis 116 | while niis <= len(sent): 117 | if sent[niis:].startswith(otoken): break 118 | if otoken in '-"' and sent[niis][0] in '—“”': break 119 | niis += 1 120 | if niis >= len(sent): niis = iis 121 | otiis.append(niis) 122 | iis = niis + max(1, len(otoken)) 123 | for tt, ii in zip(token_list, otiis): print(tt, sent[ii:ii+len(tt.lstrip('#'))]) 124 | for i, iis in enumerate(otiis): 125 | assert iis < len(sent) 126 | otoken = token_list[i].strip('#') 127 | assert otoken == '[UNK]' or sent[iis:iis+len(otoken)] == otoken 128 | return otiis 129 | 130 | # restore [UNK] tokens to the original tokens 131 | def restore_token_list(sent, token_list): 132 | if token_list[0] == '[CLS]': token_list = token_list[1:-1] 133 | invp = gen_token_list_inv_pointer(sent, token_list) 134 | invp.append(len(sent)) 135 | otokens = [sent[u:v] for u,v in zip(invp, invp[1:])] 136 | processed = -1 137 | for ii, tk in enumerate(token_list): 138 | if tk != '[UNK]': continue 139 | if ii < processed: continue 140 | for jj in range(ii+1, len(token_list)): 141 | if token_list[jj] != '[UNK]': break 142 | else: jj = len(token_list) 143 | allseg = sent[invp[ii]:invp[jj]] 144 | 145 | if ii + 1 == jj: continue 146 | seppts = [0] + [i for i, x in enumerate(allseg) if i > 0 and i+1 < len(allseg) and x == ' ' and allseg[i-1] != ' '] 147 | if allseg[seppts[-1]:].replace(' ', '') == '': seppts = seppts[:-1] 148 | seppts.append(len(allseg)) 149 | if len(seppts) == jj - ii + 1: 150 | for k, (u,v) in enumerate(zip(seppts, seppts[1:])): 151 | otokens[ii+k] = allseg[u:v] 152 | processed = jj + 1 153 | if invp[0] > 0: otokens[0] = sent[:invp[0]] + otokens[0] 154 | if ''.join(otokens) != sent: 155 | raise Exception('restore tokens failed, text and restored:\n%s\n%s' % (sent, ''.join(otokens))) 156 | return otokens 157 | 158 | def gen_word_level_labels(sent, token_list, word_list, pos_list=None): 159 | otiis = gen_token_list_inv_pointer(sent, token_list) 160 | wdiis = []; iis = 0 161 | for ip, pword in enumerate(word_list): 162 | niis = iis 163 | while niis < len(sent): 164 | if pword == '' or sent[niis:].startswith(pword[0]): break 165 | niis += 1 166 | wdiis.append(niis) 167 | iis = niis + len(pword) 168 | #for tt, ii in zip(word_list, wdiis): print(tt, sent[ii:ii+len(tt)]) 169 | 170 | rlist = []; ip = 0 171 | for it, iis in enumerate(otiis): 172 | while ip + 1 < len(wdiis) and wdiis[ip+1] <= iis: ip += 1 173 | if iis == wdiis[ip]: rr = 'B' 174 | elif iis > wdiis[ip]: rr = 'I' 175 | rr += '-' + pos_list[ip] 176 | rlist.append(rr) 177 | #for rr, tt in zip(rlist, token_list): print(rr, tt) 178 | return rlist 179 | 180 | def normalize_sentence(text): 181 | text = re.sub('[“”]', '"', text) 182 | text = re.sub('[—]', '-', text) 183 | text = re.sub('[^\u0000-\u007f\u4e00-\u9fa5\u3001-\u303f\uff00-\uffef·—]', ' \u2800 ', text) 184 | return text 185 | 186 | if __name__ == '__main__': 187 | switch_to_en() 188 | sent = 'French is the national language of France where the leaders are François 
Hollande and Manuel Valls . Barny cakes , made with sponge cake , can be found in France .' 189 | tokens = tokenizer.tokenize(sent) 190 | otokens = restore_token_list(sent, tokens) 191 | print(tokens) 192 | print(otokens) 193 | print('done') -------------------------------------------------------------------------------- /lstm/extraction.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, re, utils, json 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | sys.path.append('./BERT_tf2') 4 | import bert_tools as bt 5 | import tensorflow as tf 6 | from tensorflow.keras.models import load_model 7 | from tensorflow.keras.layers import * 8 | from tensorflow.keras.callbacks import * 9 | from tensorflow.keras.optimizers import * 10 | from bert4keras.backend import keras, K 11 | from collections import defaultdict 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | datan=sys.argv[1] 16 | dname = f'{datan}_lstm' 17 | datadir = f'../data/{datan}' 18 | trains = utils.LoadJsons(os.path.join(datadir, 'new_train.json')) 19 | valids = utils.LoadJsons(os.path.join(datadir, 'new_valid.json')) 20 | tests = utils.LoadJsons(os.path.join(datadir, 'new_test.json')) 21 | 22 | if not os.path.isdir('model/'+dname): os.makedirs('model/'+dname) 23 | 24 | def wdir(x): return 'model/'+dname+'/'+x 25 | 26 | rels = utils.TokenList(wdir('rels.txt'), 1, trains, lambda z:[x['label'] for x in z['relationMentions']]) 27 | print('rels:', rels.get_num()) 28 | 29 | bt.switch_to_en() 30 | 31 | maxlen = 128 32 | 33 | def dgcnn_block(x, dim, dila=1): 34 | y1 = Conv1D(dim, 3, padding='same', dilation_rate=dila)(x) 35 | y2 = Conv1D(dim, 3, padding='same', dilation_rate=dila, activation='sigmoid')(x) 36 | yy = multiply([y1, y2]) 37 | if yy.shape[-1] == x.shape[-1]: yy = add([yy, x]) 38 | return yy 39 | 40 | def neg_log_mean_loss(y_true, y_pred): 41 | eps = 1e-6 42 | pos = - K.sum(y_true * K.log(y_pred+eps), 1) / K.maximum(eps, K.sum(y_true, 1)) 43 | neg = K.sum((1-y_true) * y_pred, 1) / K.maximum(eps, K.sum(1-y_true, 1)) 44 | neg = - K.log(1 - neg + eps) 45 | return K.mean(pos + neg * 15) 46 | 47 | def FindValuePos(sent, value): 48 | ret = []; 49 | value = value.replace(' ', '').lower() 50 | if value == '': return ret 51 | ss = [x.replace(' ', '').lower() for x in sent] 52 | for k, v in enumerate(ss): 53 | if not value.startswith(v): continue 54 | vi = 0 55 | for j in range(k, len(ss)): 56 | if value[vi:].startswith(ss[j]): 57 | vi += len(ss[j]) 58 | if vi == len(value): 59 | ret.append( (k, j+1) ) 60 | else: break 61 | return ret 62 | 63 | def GetTopSpans(tokens, rr, K=40): 64 | cands = defaultdict(float) 65 | start_indexes = sorted(enumerate(rr[:,0]), key=lambda x:-x[1])[:K] 66 | end_indexes = sorted(enumerate(rr[:,1]), key=lambda x:-x[1])[:K] 67 | for start_index, start_score in start_indexes: 68 | if start_score < 0.1: continue 69 | if start_index >= len(tokens): continue 70 | for end_index, end_score in end_indexes: 71 | if end_score < 0.1: continue 72 | if end_index >= len(tokens): continue 73 | if end_index < start_index: continue 74 | length = end_index - start_index + 1 75 | if length > 40: continue 76 | ans = ''.join(tokens[start_index:end_index+1]).strip() 77 | if '》' in ans: continue 78 | if '、' in ans and len(ans.split('、')) > 2 and ',' not in ans and ',' not in ans: 79 | aas = ans.split('、') 80 | for aa in aas: cands[aa.strip()] += start_score * end_score / len(aas) 81 | continue 82 | cands[ans] += start_score * end_score 83 | 84 | cand_list = 
sorted(cands.items(), key=lambda x:len(x[0])) 85 | removes = set() 86 | contains = {} 87 | for i, (x, y) in enumerate(cand_list): 88 | for j, (xx, yy) in enumerate(cand_list[:i]): 89 | if xx in x and len(xx) < len(x): 90 | contains.setdefault(x, []).append(xx) 91 | 92 | for i, (x, y) in enumerate(cand_list): 93 | sump = sum(cands[z] for z in contains.get(x, []) if z not in removes) 94 | suml = sum(len(z) for z in contains.get(x, []) if z not in removes) 95 | if suml > 0: sump = sump * min(1, len(x) / suml) 96 | if sump > y: removes.add(x) 97 | else: 98 | for z in contains.get(x, []): removes.add(z) 99 | 100 | ret = [x for x in cand_list if x[0] not in removes] 101 | ret.sort(key=lambda x:-x[1]) 102 | return ret[:K] 103 | 104 | def GenTriple(p, x, y): 105 | return {'label':p, 'em1Text':x, 'em2Text':y} 106 | 107 | 108 | class RCModel: 109 | def __init__(self): 110 | inp_words = Input((None,), dtype='int32') 111 | inp_seg = Input((None,), dtype='int32') 112 | xx = Embedding(bt.tokenizer._vocab_size, 256, mask_zero=True)(inp_words) 113 | xx = Bidirectional(LSTM(256, return_sequences=True))(xx) 114 | xx = Bidirectional(LSTM(256, return_sequences=True))(xx) 115 | xx = GlobalAveragePooling1D()(xx) 116 | #xx = Lambda(lambda x:x[:,0])(xx) 117 | pos = Dense(rels.get_num(), activation='sigmoid')(xx) 118 | self.model = tf.keras.models.Model(inputs=[inp_words, inp_seg], outputs=pos) 119 | #bt.lock_transformer_layers(self.bert, 8) 120 | self.model_ready = False 121 | 122 | def gen_golden_y(self, datas): 123 | for dd in datas: 124 | dd['rc_obj'] = list(set(x['label'] for x in dd.get('relationMentions', []))) 125 | 126 | def make_model_data(self, datas): 127 | self.gen_golden_y(datas) 128 | for dd in tqdm(datas, desc='tokenize'): 129 | s = dd['sentText'] 130 | tokens = bt.tokenizer.tokenize(s, maxlen=maxlen) 131 | dd['tokens'] = tokens 132 | N = len(datas) 133 | X = [np.zeros((N, maxlen), dtype='int32'), np.zeros((N, maxlen), dtype='int32')] 134 | Y = np.zeros((N, rels.get_num())) 135 | for i, dd in enumerate(tqdm(datas, desc='gen XY', total=N)): 136 | tokens = dd['tokens'] 137 | X[0][i][:len(tokens)] = bt.tokenizer.tokens_to_ids(tokens) 138 | for x in dd['rc_obj']: Y[i][rels.get_id(x)] = 1 139 | return X, Y 140 | 141 | def load_model(self): 142 | self.model.load_weights(wdir('rc.h5')) 143 | self.model.compile('adam', 'binary_crossentropy', metrics=['accuracy']) 144 | self.model_ready = True 145 | 146 | def train(self, datas, batch_size=32, epochs=10): 147 | self.X, self.Y = self.make_model_data(datas) 148 | #self.optimizer = bt.get_suggested_optimizer(5e-5, len(datas) * epochs // batch_size) 149 | self.optimizer = RMSprop(5e-5) 150 | self.model.compile(self.optimizer, 'binary_crossentropy', metrics=['accuracy']) 151 | self.cb_mcb = ModelCheckpoint(wdir('rc.h5'), save_weights_only=True, verbose=1) 152 | self.model.fit(self.X, self.Y, batch_size, epochs=epochs, shuffle=True, 153 | validation_split=0.01, callbacks=[self.cb_mcb]) 154 | self.model_ready = True 155 | 156 | def get_output(self, datas, pred, threshold=0.5): 157 | for dd, pp in zip(datas, pred): 158 | dd['rc_pred'] = list(rels.get_token(i) for i, sc in enumerate(pp) if sc > threshold) 159 | 160 | def evaluate(self, datas): 161 | ccnt, gcnt, ecnt = 0, 0, 0 162 | for dd in datas: 163 | plabels = set(dd['rc_pred']) 164 | ecnt += len(plabels) 165 | gcnt += len(set(dd['rc_obj'])) 166 | ccnt += len(plabels & set(dd['rc_obj'])) 167 | return utils.CalcF1(ccnt, ecnt, gcnt) 168 | 169 | def predict(self, datas, threshold=0.5, ofile=None): 170 | if not 
self.model_ready: self.load_model() 171 | self.vX, self.vY = self.make_model_data(datas) 172 | pred = self.model.predict(self.vX, batch_size=64, verbose=1) 173 | self.get_output(datas, pred, threshold) 174 | f1str = self.evaluate(datas) 175 | if ofile is not None: 176 | utils.SaveList(map(lambda x:json.dumps(x, ensure_ascii=False), datas), wdir(ofile)) 177 | print(f1str) 178 | return f1str 179 | 180 | class EEModel: 181 | def __init__(self): 182 | inp_words = Input((None,), dtype='int32') 183 | inp_seg = Input((None,), dtype='int32') 184 | xx = Embedding(bt.tokenizer._vocab_size, 256, mask_zero=True)(inp_words) 185 | xx = Bidirectional(LSTM(256, return_sequences=True))(xx) 186 | xx = Bidirectional(LSTM(256, return_sequences=True))(xx) 187 | pos = Dense(4, activation='sigmoid')(xx) 188 | self.model = tf.keras.models.Model(inputs=[inp_words, inp_seg], outputs=pos) 189 | self.model_ready = False 190 | 191 | def make_model_data(self, datas): 192 | if 'tokens' not in datas[0]: 193 | for dd in tqdm(datas, desc='tokenize'): 194 | s = dd['sentText'] 195 | tokens = bt.tokenizer.tokenize(s, maxlen=maxlen) 196 | dd['tokens'] = tokens 197 | N = 0 198 | for dd in tqdm(datas, desc='matching'): 199 | otokens = bt.restore_token_list(dd['sentText'], dd['tokens']) 200 | dd['otokens'] = otokens 201 | if '' in otokens: 202 | print(dd['sentText']) 203 | print(dd['tokens']) 204 | print(otokens) 205 | # assert '' not in otokens 206 | ys = {} 207 | if 'rc_pred' in dd: 208 | plist = dd['rc_pred'] 209 | else: 210 | for x in dd.get('relationMentions', []): 211 | ys.setdefault(x['label'], []).append( (x['em1Text'], x['em2Text']) ) 212 | plist = sorted(ys.keys()) 213 | yys = [] 214 | for pp in plist: 215 | spos, opos = [], [] 216 | for s, o in ys.get(pp, []): 217 | ss, oo = FindValuePos(otokens, s), FindValuePos(otokens, o) 218 | if len(ss) == 0 and len(oo) == 0: continue 219 | spos.extend(ss) 220 | opos.extend(oo) 221 | yys.append( {'pp':pp, 'spos':spos, 'opos':opos} ) 222 | dd['ee_obj'] = yys 223 | N += len(yys) 224 | X = [np.zeros((N, maxlen), dtype='int32'), np.zeros((N, maxlen), dtype='int8')] 225 | Y = np.zeros((N, maxlen, 4), dtype='int8') 226 | ii = 0 227 | for dd in tqdm(datas, desc='gen EE XY'): 228 | tokens = dd['tokens'] 229 | for item in dd['ee_obj']: 230 | pp, spos, opos = item['pp'], item['spos'], item['opos'] 231 | first = bt.tokenizer.tokenize(pp) 232 | offset = len(first) 233 | item['offset'] = offset 234 | tts = (first + tokens[1:])[:maxlen] 235 | X[0][ii][:len(tts)] = bt.tokenizer.tokens_to_ids(tts) 236 | X[1][ii][offset:offset+len(tokens)-1] = 1 237 | for u, v in spos: 238 | try: 239 | Y[ii][offset+u,0] = 1 240 | Y[ii][offset+v-1,1] = 1 241 | except: pass 242 | for u, v in opos: 243 | try: 244 | Y[ii][offset+u,2] = 1 245 | Y[ii][offset+v-1,3] = 1 246 | except: pass 247 | ii += 1 248 | return X, Y 249 | 250 | def train(self, datas, batch_size=32, epochs=10): 251 | self.X, self.Y = self.make_model_data(datas) 252 | #self.optimizer = bt.get_suggested_optimizer(5e-5, len(datas) * epochs // batch_size) 253 | self.optimizer = RMSprop(5e-5) 254 | self.model.compile(self.optimizer, 'binary_crossentropy', metrics=['accuracy']) 255 | self.cb_mcb = ModelCheckpoint(wdir('ee.h5'), save_weights_only=True, verbose=1) 256 | self.model.fit(self.X, self.Y, batch_size, epochs=epochs, shuffle=True, 257 | validation_split=0.01, callbacks=[self.cb_mcb]) 258 | self.model_ready = True 259 | 260 | def load_model(self): 261 | self.model.load_weights(wdir('ee.h5')) 262 | self.model.compile('adam', 'binary_crossentropy', 
metrics=['accuracy']) 263 | self.model_ready = True 264 | 265 | def get_output(self, datas, pred, threshold=0.5): 266 | ii = 0 267 | for dd in datas: 268 | rtriples = [] 269 | for item in dd['ee_obj']: 270 | predicate, offset = item['pp'], item['offset'] 271 | rr = pred[ii]; ii += 1 272 | subs = GetTopSpans(dd['otokens'], rr[offset:,:2]) 273 | objs = GetTopSpans(dd['otokens'], rr[offset:,2:]) 274 | 275 | vv1 = [x for x,y in subs if y >= 0.1] 276 | vv2 = [x for x,y in objs if y >= 0.1] 277 | 278 | subv = {x:y for x,y in subs} 279 | objv = {x:y for x,y in objs} 280 | 281 | #mats = None 282 | #if len(vv1) * len(vv2) >= 4: 283 | # sent = ''.join(data[2]) 284 | # mats = set(Match(sent, vv1, vv2)) 285 | 286 | for sv1, sv2 in [(sv1, sv2) for sv1 in vv1 for sv2 in vv2] : 287 | if sv1 == sv2: continue 288 | score = min(subv[sv1], objv[sv2]) 289 | #if mats is not None and (sv1, sv2) not in mats: score -= 0.5 290 | if score < threshold: continue 291 | rtriples.append( GenTriple(predicate, sv1, sv2) ) 292 | 293 | dd['ee_pred'] = rtriples 294 | # assert '' not in dd['otokens'] 295 | 296 | def evaluate(self, datas): 297 | ccnt, gcnt, ecnt = 0, 0, 0 298 | for dd in datas: 299 | golden = set(); predict = set() 300 | for x in dd['relationMentions']: 301 | ss = '|'.join([x[nn] for nn in ['label', 'em1Text', 'em2Text']]) 302 | golden.add(ss) 303 | for x in dd['ee_pred']: 304 | ss = '|'.join([x[nn] for nn in ['label', 'em1Text', 'em2Text']]) 305 | predict.add(ss) 306 | ecnt += len(predict) 307 | gcnt += len(golden) 308 | ccnt += len(predict & golden) 309 | return utils.CalcF1(ccnt, ecnt, gcnt) 310 | 311 | def predict(self, datas, threshold=0.5, ofile=None): 312 | if not self.model_ready: self.load_model() 313 | self.vX, self.vY = self.make_model_data(datas) 314 | pred = self.model.predict(self.vX, batch_size=64, verbose=1) 315 | self.get_output(datas, pred, threshold=threshold) 316 | if ofile is not None: 317 | utils.SaveList(map(lambda x:json.dumps(x, ensure_ascii=False), datas), wdir(ofile)) 318 | f1str = self.evaluate(datas) 319 | print(f1str) 320 | return f1str 321 | 322 | 323 | if __name__ == '__main__': 324 | rc = RCModel() 325 | if 'trainrc' in sys.argv: 326 | rc.train(trains, batch_size=64, epochs=10) 327 | if not 'eeonly' in sys.argv: 328 | rc.predict(tests, threshold=0.4, ofile='valid_rc.json') 329 | tests = utils.LoadJsons(wdir('valid_rc.json')) 330 | ee = EEModel() 331 | if 'trainee' in sys.argv: 332 | ee.train(trains, batch_size=32, epochs=10) 333 | ee.predict(tests, threshold=0.2, ofile='valid_ee.json') 334 | print('done') -------------------------------------------------------------------------------- /lstm/utils.py: -------------------------------------------------------------------------------- 1 | # coding = utf-8 2 | 3 | import os, re, sys, random, urllib.parse, json 4 | from collections import defaultdict 5 | 6 | def WriteLine(fout, lst): 7 | fout.write('\t'.join([str(x) for x in lst]) + '\n') 8 | 9 | def RM(patt, sr): 10 | mat = re.search(patt, sr, re.DOTALL | re.MULTILINE) 11 | return mat.group(1) if mat else '' 12 | 13 | try: import requests 14 | except: pass 15 | def GetPage(url, cookie='', proxy='', timeout=5): 16 | try: 17 | headers = {'User-Agent':'Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.90 Safari/537.36'} 18 | if cookie != '': headers['cookie'] = cookie 19 | if proxy != '': 20 | proxies = {'http': proxy, 'https': proxy} 21 | resp = requests.get(url, headers=headers, proxies=proxies, timeout=timeout) 22 | else: resp = 
requests.get(url, headers=headers, timeout=timeout) 23 | content = resp.content 24 | try: 25 | import chardet 26 | charset = chardet.detect(content).get('encoding','utf-8') 27 | if charset.lower().startswith('gb'): charset = 'gbk' 28 | content = content.decode(charset, errors='replace') 29 | except: 30 | headc = content[:min([3000,len(content)])].decode(errors='ignore') 31 | charset = RM('charset="?([-a-zA-Z0-9]+)', headc) 32 | if charset == '': charset = 'utf-8' 33 | content = content.decode(charset, errors='replace') 34 | except Exception as e: 35 | print(e) 36 | content = '' 37 | return content 38 | 39 | def GetJson(url, cookie='', proxy='', timeout=5.0): 40 | try: 41 | headers = {'User-Agent':'Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.90 Safari/537.36'} 42 | if cookie != '': headers['cookie'] = cookie 43 | if proxy != '': 44 | proxies = {'http': proxy, 'https': proxy} 45 | resp = requests.get(url, headers=headers, proxies=proxies, timeout=timeout) 46 | else: resp = requests.get(url, headers=headers, timeout=timeout) 47 | return resp.json() 48 | except Exception as e: 49 | print(e) 50 | content = {} 51 | return content 52 | 53 | def FindAllHrefs(url, content=None, regex=''): 54 | ret = set() 55 | if content == None: content = GetPage(url) 56 | patt = re.compile('href="?([a-zA-Z0-9-_:/.%]+)') 57 | for xx in re.findall(patt, content): 58 | ret.add( urllib.parse.urljoin(url, xx) ) 59 | if regex != '': ret = (x for x in ret if re.match(regex, x)) 60 | return list(ret) 61 | 62 | def Translate(txt): 63 | postdata = {'from': 'en', 'to': 'zh', 'transtype': 'realtime', 'query': txt} 64 | url = "http://fanyi.baidu.com/v2transapi" 65 | try: 66 | resp = requests.post(url, data=postdata, 67 | headers={'Referer': 'http://fanyi.baidu.com/'}) 68 | ret = resp.json() 69 | ret = ret['trans_result']['data'][0]['dst'] 70 | except Exception as e: 71 | print(e) 72 | ret = '' 73 | return ret 74 | 75 | def IsChsStr(z): 76 | return re.search('^[\u4e00-\u9fa5]+$', z) is not None 77 | 78 | def FreqDict2List(dt): 79 | return sorted(dt.items(), key=lambda d:d[-1], reverse=True) 80 | 81 | def SelectRowsbyCol(fn, ofn, st, num = 0): 82 | with open(fn, encoding = "utf-8") as fin: 83 | with open(ofn, "w", encoding = "utf-8") as fout: 84 | for line in (ll for ll in fin.read().split('\n') if ll != ""): 85 | if line.split('\t')[num] in st: 86 | fout.write(line + '\n') 87 | 88 | def MergeFiles(dir, objfile, regstr = ".*"): 89 | with open(objfile, "w", encoding = "utf-8") as fout: 90 | for file in os.listdir(dir): 91 | if re.match(regstr, file): 92 | with open(os.path.join(dir, file), encoding = "utf-8") as filein: 93 | fout.write(filein.read()) 94 | 95 | def JoinFiles(fnx, fny, ofn): 96 | with open(fnx, encoding = "utf-8") as fin: 97 | lx = [vv for vv in fin.read().split('\n') if vv != ""] 98 | with open(fny, encoding = "utf-8") as fin: 99 | ly = [vv for vv in fin.read().split('\n') if vv != ""] 100 | with open(ofn, "w", encoding = "utf-8") as fout: 101 | for i in range(min(len(lx), len(ly))): 102 | fout.write(lx[i] + "\t" + ly[i] + "\n") 103 | 104 | 105 | def RemoveDupRows(file, fobj='*'): 106 | st = set() 107 | if fobj == '*': fobj = file 108 | with open(file, encoding = "utf-8") as fin: 109 | for line in fin.read().split('\n'): 110 | if line == "": continue 111 | st.add(line) 112 | with open(fobj, "w", encoding = "utf-8") as fout: 113 | for line in st: 114 | fout.write(line + '\n') 115 | 116 | def LoadCSV(fn): 117 | ret = [] 118 | with open(fn, encoding='utf-8') as fin: 
119 | for line in fin: 120 | lln = line.rstrip('\r\n').split('\t') 121 | ret.append(lln) 122 | return ret 123 | 124 | def LoadCSVg(fn): 125 | with open(fn, encoding='utf-8') as fin: 126 | for line in fin: 127 | lln = line.rstrip('\r\n').split('\t') 128 | yield lln 129 | 130 | def SaveCSV(csv, fn): 131 | with open(fn, 'w', encoding='utf-8') as fout: 132 | for x in csv: 133 | WriteLine(fout, x) 134 | 135 | def SplitTables(fn, limit=3): 136 | rst = set() 137 | with open(fn, encoding='utf-8') as fin: 138 | for line in fin: 139 | lln = line.rstrip('\r\n').split('\t') 140 | rst.add(len(lln)) 141 | if len(rst) > limit: 142 | print('%d tables, exceed limit %d' % (len(rst), limit)) 143 | return 144 | for ii in rst: 145 | print('%d columns' % ii) 146 | with open(fn.replace('.txt', '') + '.split.%d.txt' % ii, 'w', encoding='utf-8') as fout: 147 | with open(fn, encoding='utf-8') as fin: 148 | for line in fin: 149 | lln = line.rstrip('\r\n').split('\t') 150 | if len(lln) == ii: 151 | fout.write(line) 152 | 153 | def LoadSet(fn): 154 | with open(fn, encoding="utf-8") as fin: 155 | st = set(ll for ll in fin.read().split('\n') if ll != "") 156 | return st 157 | 158 | def LoadList(fn): 159 | with open(fn, encoding="utf-8") as fin: 160 | st = list(ll for ll in fin.read().split('\n') if ll != "") 161 | return st 162 | 163 | def LoadJsonsg(fn): return map(json.loads, LoadListg(fn)) 164 | def LoadJsons(fn): return list(LoadJsonsg(fn)) 165 | 166 | def LoadListg(fn): 167 | with open(fn, encoding="utf-8") as fin: 168 | for ll in fin: 169 | ll = ll.strip() 170 | if ll != '': yield ll 171 | 172 | def LoadDict(fn, func=str): 173 | dict = {} 174 | with open(fn, encoding = "utf-8") as fin: 175 | for lv in (ll.split('\t', 1) for ll in fin.read().split('\n') if ll != ""): 176 | dict[lv[0]] = func(lv[1]) 177 | return dict 178 | 179 | def SaveDict(dict, ofn, output0 = True): 180 | with open(ofn, "w", encoding = "utf-8") as fout: 181 | for k in dict.keys(): 182 | if output0 or dict[k] != 0: 183 | fout.write(str(k) + "\t" + str(dict[k]) + "\n") 184 | 185 | def SaveList(st, ofn): 186 | with open(ofn, "w", encoding = "utf-8") as fout: 187 | for k in st: 188 | fout.write(str(k) + "\n") 189 | 190 | def ListDirFiles(dir, filter=None): 191 | if filter is None: 192 | return [os.path.join(dir, x) for x in os.listdir(dir)] 193 | return [os.path.join(dir, x) for x in os.listdir(dir) if filter(x)] 194 | 195 | def ProcessDir(dir, func, param): 196 | for file in os.listdir(dir): 197 | print(file) 198 | func(os.path.join(dir, file), param) 199 | 200 | def GetLines(fn): 201 | with open(fn, encoding = "utf-8", errors = 'ignore') as fin: 202 | lines = list(map(str.strip, fin.readlines())) 203 | return lines 204 | 205 | 206 | def SortRows(file, fobj, cid, type=int, rev = True): 207 | lines = LoadCSV(file) 208 | dat = [] 209 | for dv in lines: 210 | if len(dv) <= cid: continue 211 | dat.append((type(dv[cid]), dv)) 212 | with open(fobj, "w", encoding = "utf-8") as fout: 213 | for dd in sorted(dat, reverse = rev): 214 | fout.write('\t'.join(dd[1]) + '\n') 215 | 216 | def SampleRows(file, fobj, num): 217 | zz = list(open(file, encoding='utf-8')) 218 | num = min([num, len(zz)]) 219 | zz = random.sample(zz, num) 220 | with open(fobj, 'w', encoding='utf-8') as fout: 221 | for xx in zz: fout.write(xx) 222 | 223 | def SetProduct(file1, file2, fobj): 224 | l1, l2 = GetLines(file1), GetLines(file2) 225 | with open(fobj, 'w', encoding='utf-8') as fout: 226 | for z1 in l1: 227 | for z2 in l2: 228 | fout.write(z1 + z2 + '\n') 229 | 230 | class TokenList: 
231 | def __init__(self, file, low_freq=2, source=None, func=None, save_low_freq=2, special_marks=[]): 232 | if not os.path.exists(file): 233 | tdict = defaultdict(int) 234 | for i, xx in enumerate(special_marks): tdict[xx] = 100000000 - i 235 | for xx in source: 236 | for token in func(xx): tdict[token] += 1 237 | tokens = FreqDict2List(tdict) 238 | tokens = [x for x in tokens if x[1] >= save_low_freq] 239 | SaveCSV(tokens, file) 240 | self.id2t = ['', ''] + \ 241 | [x for x,y in LoadCSV(file) if float(y) >= low_freq] 242 | self.t2id = {v:k for k,v in enumerate(self.id2t)} 243 | def get_id(self, token): return self.t2id.get(token, 1) 244 | def get_token(self, ii): return self.id2t[ii] 245 | def get_num(self): return len(self.id2t) 246 | 247 | def CalcF1(correct, output, golden): 248 | prec = correct / max(output, 1); reca = correct / max(golden, 1); 249 | f1 = 2 * prec * reca / max(1e-9, prec + reca) 250 | pstr = 'Prec: %.4f %d/%d, Reca: %.4f %d/%d, F1: %.4f' % (prec, correct, output, reca, correct, golden, f1) 251 | return pstr 252 | 253 | def Upgradeljqpy(url=None): 254 | if url is None: url = 'http://gdm.fudan.edu.cn/files1/ljq/ljqpy.py' 255 | dirs = [dir for dir in reversed(sys.path) if os.path.isdir(dir) and 'ljqpy.py' in os.listdir(dir)] 256 | if len(dirs) == 0: raise Exception("package directory no found") 257 | dir = dirs[0] 258 | print('downloading ljqpy.py from %s to %s' % (url, dir)) 259 | resp = requests.get(url) 260 | if b'Upgradeljqpy' not in resp.content: raise Exception('bad file') 261 | with open(os.path.join(dir, 'ljqpy.py'), 'wb') as fout: 262 | fout.write(resp.content) 263 | print('success') 264 | 265 | def sql(cmd=''): 266 | if cmd == '': cmd = input("> ") 267 | cts = [x for x in cmd.strip().lower()] 268 | instr = False 269 | for i in range(len(cts)): 270 | if cts[i] == '"' and cts[i-1] != '\\': instr = not instr 271 | if cts[i] == ' ' and instr: cts[i] = " " 272 | cmds = "".join(cts).split(' ') 273 | keyw = { 'select', 'from', 'to', 'where' } 274 | ct, kn = {}, '' 275 | for xx in cmds: 276 | if xx in keyw: kn = xx 277 | else: ct[kn] = ct.get(kn, "") + " " + xx 278 | 279 | for xx in ct.keys(): 280 | ct[xx] = ct[xx].replace(" ", " ").strip() 281 | 282 | if ct.get('where', "") == "": ct['where'] = 'True' 283 | 284 | if os.path.isdir(ct['from']): fl = [os.path.join(ct['from'], x) for x in os.listdir(ct['from'])] 285 | else: fl = ct['from'].split('+') 286 | 287 | if ct.get('to', "") == "": ct['to'] = 'temp.txt' 288 | 289 | for xx in ct.keys(): 290 | print(xx + " : " + ct[xx]) 291 | 292 | total = 0 293 | with open(ct['to'], 'w', encoding = 'utf-8') as fout: 294 | for fn in fl: 295 | print('selecting ' + fn) 296 | for xx in open(fn, encoding = 'utf-8'): 297 | x = xx.rstrip('\r\n').split('\t') 298 | if eval(ct['where']): 299 | if ct['select'] == '*': res = "\t".join(x) + '\n' 300 | else: res = "\t".join(eval('[' + ct['select'] + ']')) + '\n' 301 | fout.write(res) 302 | total += 1 303 | 304 | print('completed, ' + str(total) + " records") 305 | 306 | def cmd(): 307 | while True: 308 | cmd = input("> ") 309 | sql(cmd) 310 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Revisiting the Negative Data of Distantly Supervised Relation Extraction 2 | 3 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-the-negative-data-of-distantly/relation-extraction-on-nyt10-hrl)](https://paperswithcode.com/sota/relation-extraction-on-nyt10-hrl?p=revisiting-the-negative-data-of-distantly) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-the-negative-data-of-distantly/relation-extraction-on-nyt11-hrl)](https://paperswithcode.com/sota/relation-extraction-on-nyt11-hrl?p=revisiting-the-negative-data-of-distantly) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revisiting-the-negative-data-of-distantly/relation-extraction-on-nyt21)](https://paperswithcode.com/sota/relation-extraction-on-nyt21?p=revisiting-the-negative-data-of-distantly) 6 | 7 | This repository contains the source code and dataset for the paper: **Revisiting the Negative Data of Distantly Supervised Relation Extraction**. Chenhao Xie, Jiaqing Liang, Jingping Liu, Chengsong Huang, Wenhao Huang, Yanghua Xiao. ACL 2021. [paper](https://arxiv.org/pdf/2105.10158.pdf) 8 | 9 | 10 | ## How to reproduce 11 | 12 | Install all the dependencies in `requirements.txt`. 13 | 14 | Download the BERT-related files and follow the instructions in `tfhub/*/readme.md`. 15 | Run `rere/bert-to-h5.py` to produce `bert_uncased.h5` and `chinese_roberta_wwm_ext.h5`. 16 | 17 | The two models in Table 3, ReRe and ReRe_LSTM, are provided for reproduction in the directories `rere` and `lstm`, respectively. 18 | `extraction.py` is the main file in each directory. 19 | To train a model, run `python extraction.py {data_set_name} train`. 20 | To load a trained model and predict, run `python extraction.py {data_set_name}`, for example `python extraction.py NYT11-HRL` (a minimal Python sketch of this two-stage prediction pipeline is given before the Citation section below). We can provide the pre-trained models for reproducing exactly the same results as in the paper. 21 | 22 | The data sets used in Figure 3 are provided in `data/FNexp`. They are generated by `data/FN_data_gen.py`. 23 | To train the corresponding model, run `python extraction.py FNexp/{data_set_name}@{ratio} train`, for example `python extraction.py FNexp/ske2019@0.1 train`. 24 | 25 | ## Datasets 26 | Datasets are provided separately in [this repo](https://github.com/redreamality/-RERE-data.git). 27 | They include two new datasets, NYT21 and SKE21 (the labeled test set of SKE2019). 28 | 29 | ## Usage and troubleshooting 30 | 31 | The copy of `bert4keras` that we provide in `./rere/BERT_tf2` can alternatively be installed via pip, but we **don't** guarantee that its latest version works with our code; if trouble happens, please run `pip uninstall bert4keras`. 32 | If the pre-trained models are needed for reproduction, please contact the authors; we are willing to provide them. 33 | 34 | ## Environment details 35 | NVIDIA-SMI 455.23.04 36 | 37 | Driver Version: 455.23.04 38 | 39 | CUDA Version: 11.1 40 | 41 | GeForce RTX 3090 42 | 43 | Python 3.7.9 44 | 45 | A `requirements.txt` is provided for setting up the conda virtual environment.
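The prediction command above (e.g. `python extraction.py NYT11-HRL`) runs a two-stage pipeline: a relation-classification model first predicts the candidate relation labels of a sentence, and an entity-extraction model then tags the subject/object spans for each predicted relation. Below is a minimal sketch of that flow based on the `__main__` block of `lstm/extraction.py`. It assumes you run it from the `lstm/` directory with prepared data under `../data/NYT11-HRL` and trained weights `rc.h5`/`ee.h5` under `model/NYT11-HRL_lstm/`; in practice you would simply run the command line above.

    # Minimal sketch of the prediction path in lstm/extraction.py (not the exact script).
    import sys
    sys.argv = ['extraction.py', 'NYT11-HRL']    # the dataset name is read from argv[1] at import time
    import extraction as ex                      # loads new_train/new_valid/new_test for that dataset

    rc = ex.RCModel()                            # stage 1: sentence-level relation classification
    rc.predict(ex.tests, threshold=0.4, ofile='valid_rc.json')

    tests = ex.utils.LoadJsons(ex.wdir('valid_rc.json'))
    ee = ex.EEModel()                            # stage 2: subject/object span extraction per predicted relation
    ee.predict(tests, threshold=0.2, ofile='valid_ee.json')

The intermediate file `valid_rc.json` carries the stage-1 predictions (`rc_pred`), which stage 2 conditions on; the thresholds 0.4 and 0.2 match those used in the script's own `__main__`.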
46 | 47 | ## Citation 48 | 49 | @inproceedings{xie2021revisiting, 50 | title={Revisiting the Negative Data of Distantly Supervised Relation Extraction}, 51 | author={Xie, Chenhao and Liang, Jiaqing and Liu, Jingping and Huang, Chengsong and Huang, Wenhao and Xiao, Yanghua}, 52 | booktitle={Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics}, 53 | year={2021} 54 | } 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redreamality/RERE-relation-extraction/8715c45807d26445a6449be82fb7f2e91078165e/requirements.txt -------------------------------------------------------------------------------- /rere/BERT_tf2/bert4keras/__init__.py: -------------------------------------------------------------------------------- 1 | #! -*- coding: utf-8 -*- 2 | 3 | __version__ = '0.8.8' 4 | -------------------------------------------------------------------------------- /rere/BERT_tf2/bert4keras/backend.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 分离后端函数,主要是为了同时兼容原生keras和tf.keras 3 | # 通过设置环境变量TF_KERAS=1来切换tf.keras 4 | 5 | import os, sys 6 | from distutils.util import strtobool 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.python.util import nest, tf_inspect 10 | from tensorflow.python.eager import tape 11 | from tensorflow.python.ops.custom_gradient import _graph_mode_decorator 12 | 13 | import tensorflow.keras as keras 14 | import tensorflow.keras.backend as K 15 | sys.modules['keras'] = keras 16 | 17 | is_tf_keras = True 18 | 19 | # 判断是否启用重计算(通过时间换空间) 20 | do_recompute = strtobool(os.environ.get('RECOMPUTE', '0')) 21 | 22 | def gelu_erf(x): 23 | """基于Erf直接计算的gelu函数 24 | """ 25 | return 0.5 * x * (1.0 + tf.math.erf(x / np.sqrt(2.0))) 26 | 27 | 28 | def gelu_tanh(x): 29 | """基于Tanh近似计算的gelu函数 30 | """ 31 | cdf = 0.5 * ( 32 | 1.0 + K.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * K.pow(x, 3)))) 33 | ) 34 | return x * cdf 35 | 36 | 37 | def set_gelu(version): 38 | """设置gelu版本 39 | """ 40 | version = version.lower() 41 | assert version in ['erf', 'tanh'], 'gelu version must be erf or tanh' 42 | if version == 'erf': 43 | keras.utils.get_custom_objects()['gelu'] = gelu_erf 44 | else: 45 | keras.utils.get_custom_objects()['gelu'] = gelu_tanh 46 | 47 | 48 | def piecewise_linear(t, schedule): 49 | """分段线性函数 50 | 其中schedule是形如{1000: 1, 2000: 0.1}的字典, 51 | 表示 t ∈ [0, 1000]时,输出从0均匀增加至1,而 52 | t ∈ [1000, 2000]时,输出从1均匀降低到0.1,最后 53 | t > 2000时,保持0.1不变。 54 | """ 55 | schedule = sorted(schedule.items()) 56 | if schedule[0][0] != 0: 57 | schedule = [(0, 0.0)] + schedule 58 | 59 | x = K.constant(schedule[0][1], dtype=K.floatx()) 60 | t = K.cast(t, K.floatx()) 61 | for i in range(len(schedule)): 62 | t_begin = schedule[i][0] 63 | x_begin = x 64 | if i != len(schedule) - 1: 65 | dx = schedule[i + 1][1] - schedule[i][1] 66 | dt = schedule[i + 1][0] - schedule[i][0] 67 | slope = 1.0 * dx / dt 68 | x = schedule[i][1] + slope * (t - t_begin) 69 | else: 70 | x = K.constant(schedule[i][1], dtype=K.floatx()) 71 | x = K.switch(t >= t_begin, x, x_begin) 72 | 73 | return x 74 | 75 | 76 | def search_layer(inputs, name, exclude_from=None): 77 | """根据inputs和name来搜索层 78 | 说明:inputs为某个层或某个层的输出;name为目标层的名字。 79 | 实现:根据inputs一直往上递归搜索,直到发现名字为name的层为止; 80 | 如果找不到,那就返回None。 81 | """ 82 | if exclude_from is None: 83 | exclude_from = set() 84 | 85 | if 
isinstance(inputs, keras.layers.Layer): 86 | layer = inputs 87 | else: 88 | layer = inputs._keras_history[0] 89 | 90 | if layer.name == name: 91 | return layer 92 | elif layer in exclude_from: 93 | return None 94 | else: 95 | exclude_from.add(layer) 96 | if isinstance(layer, keras.models.Model): 97 | model = layer 98 | for layer in model.layers: 99 | if layer.name == name: 100 | return layer 101 | inbound_layers = layer._inbound_nodes[0].inbound_layers 102 | if not isinstance(inbound_layers, list): 103 | inbound_layers = [inbound_layers] 104 | if len(inbound_layers) > 0: 105 | for layer in inbound_layers: 106 | layer = search_layer(layer, name, exclude_from) 107 | if layer is not None: 108 | return layer 109 | 110 | 111 | def sequence_masking(x, mask, mode=0, axis=None): 112 | """为序列条件mask的函数 113 | mask: 形如(batch_size, seq_len)的0-1矩阵; 114 | mode: 如果是0,则直接乘以mask; 115 | 如果是1,则在padding部分减去一个大正数。 116 | axis: 序列所在轴,默认为1; 117 | """ 118 | if mask is None or mode not in [0, 1]: 119 | return x 120 | else: 121 | if axis is None: 122 | axis = 1 123 | if axis == -1: 124 | axis = K.ndim(x) - 1 125 | assert axis > 0, 'axis must be greater than 0' 126 | for _ in range(axis - 1): 127 | mask = K.expand_dims(mask, 1) 128 | for _ in range(K.ndim(x) - K.ndim(mask) - axis + 1): 129 | mask = K.expand_dims(mask, K.ndim(mask)) 130 | if mode == 0: 131 | return x * mask 132 | else: 133 | return x - (1 - mask) * 1e12 134 | 135 | 136 | def batch_gather(params, indices): 137 | """同tf旧版本的batch_gather 138 | """ 139 | if K.dtype(indices)[:3] != 'int': 140 | indices = K.cast(indices, 'int32') 141 | 142 | try: 143 | return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1) 144 | except Exception as e1: 145 | try: 146 | return tf.batch_gather(params, indices) 147 | except Exception as e2: 148 | raise ValueError('%s\n%s\n' % (e1.message, e2.message)) 149 | 150 | 151 | def pool1d( 152 | x, 153 | pool_size, 154 | strides=1, 155 | padding='valid', 156 | data_format=None, 157 | pool_mode='max' 158 | ): 159 | """向量序列的pool函数 160 | """ 161 | x = K.expand_dims(x, 1) 162 | x = K.pool2d( 163 | x, 164 | pool_size=(1, pool_size), 165 | strides=(1, strides), 166 | padding=padding, 167 | data_format=data_format, 168 | pool_mode=pool_mode 169 | ) 170 | return x[:, 0] 171 | 172 | 173 | def divisible_temporal_padding(x, n): 174 | """将一维向量序列右padding到长度能被n整除 175 | """ 176 | r_len = K.shape(x)[1] % n 177 | p_len = K.switch(r_len > 0, n - r_len, 0) 178 | return K.temporal_padding(x, (0, p_len)) 179 | 180 | 181 | def swish(x): 182 | """swish函数(这样封装过后才有 __name__ 属性) 183 | """ 184 | return tf.nn.swish(x) 185 | 186 | 187 | def leaky_relu(x, alpha=0.2): 188 | """leaky relu函数(这样封装过后才有 __name__ 属性) 189 | """ 190 | return tf.nn.leaky_relu(x, alpha=alpha) 191 | 192 | 193 | def symbolic(f): 194 | """恒等装饰器(兼容旧版本keras用) 195 | """ 196 | return f 197 | 198 | 199 | def graph_mode_decorator(f, *args, **kwargs): 200 | """tf 2.1与之前版本的传参方式不一样,这里做个同步 201 | """ 202 | if tf.__version__ < '2.1': 203 | return _graph_mode_decorator(f, *args, **kwargs) 204 | else: 205 | return _graph_mode_decorator(f, args, kwargs) 206 | 207 | 208 | def recompute_grad(call): 209 | """重计算装饰器(用来装饰Keras层的call函数) 210 | 关于重计算,请参考:https://arxiv.org/abs/1604.06174 211 | """ 212 | if not do_recompute: 213 | return call 214 | 215 | def inner(self, inputs, **kwargs): 216 | """定义需要求梯度的函数以及重新定义求梯度过程 217 | (参考自官方自带的tf.recompute_grad函数) 218 | """ 219 | flat_inputs = nest.flatten(inputs) 220 | call_args = tf_inspect.getfullargspec(call).args 221 | for key in ['mask', 'training']: 222 | if key 
not in call_args and key in kwargs: 223 | del kwargs[key] 224 | 225 | def kernel_call(): 226 | """定义前向计算 227 | """ 228 | return call(self, inputs, **kwargs) 229 | 230 | def call_and_grad(*inputs): 231 | """定义前向计算和反向计算 232 | """ 233 | with tape.stop_recording(): 234 | outputs = kernel_call() 235 | outputs = tf.identity(outputs) 236 | 237 | def grad_fn(doutputs, variables=None): 238 | watches = list(inputs) 239 | if variables is not None: 240 | watches += list(variables) 241 | with tf.GradientTape() as t: 242 | t.watch(watches) 243 | with tf.control_dependencies([doutputs]): 244 | outputs = kernel_call() 245 | grads = t.gradient( 246 | outputs, watches, output_gradients=[doutputs] 247 | ) 248 | del t 249 | return grads[:len(inputs)], grads[len(inputs):] 250 | 251 | return outputs, grad_fn 252 | 253 | if True: # 仅在tf >= 2.0下可用 254 | outputs, grad_fn = call_and_grad(*flat_inputs) 255 | flat_outputs = nest.flatten(outputs) 256 | 257 | def actual_grad_fn(*doutputs): 258 | grads = grad_fn(*doutputs, variables=self.trainable_weights) 259 | return grads[0] + grads[1] 260 | 261 | watches = flat_inputs + self.trainable_weights 262 | watches = [tf.convert_to_tensor(x) for x in watches] 263 | tape.record_operation( 264 | call.__name__, flat_outputs, watches, actual_grad_fn 265 | ) 266 | return outputs 267 | 268 | return inner 269 | 270 | 271 | # 给旧版本keras新增symbolic方法(装饰器), 272 | # 以便兼容optimizers.py中的代码 273 | K.symbolic = getattr(K, 'symbolic', None) or symbolic 274 | 275 | custom_objects = { 276 | 'gelu_erf': gelu_erf, 277 | 'gelu_tanh': gelu_tanh, 278 | 'gelu': gelu_erf, 279 | 'swish': swish, 280 | 'leaky_relu': leaky_relu, 281 | } 282 | 283 | keras.utils.get_custom_objects().update(custom_objects) 284 | -------------------------------------------------------------------------------- /rere/BERT_tf2/bert4keras/snippets.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 代码合集 3 | 4 | import six 5 | import logging 6 | import numpy as np 7 | import re 8 | import sys 9 | from collections import defaultdict 10 | import json 11 | 12 | _open_ = open 13 | is_py2 = six.PY2 14 | 15 | if not is_py2: 16 | basestring = str 17 | 18 | 19 | def to_array(*args): 20 | """批量转numpy的array 21 | """ 22 | results = [np.array(a) for a in args] 23 | if len(args) == 1: 24 | return results[0] 25 | else: 26 | return results 27 | 28 | 29 | def is_string(s): 30 | """判断是否是字符串 31 | """ 32 | return isinstance(s, basestring) 33 | 34 | 35 | def strQ2B(ustring): 36 | """全角符号转对应的半角符号 37 | """ 38 | rstring = '' 39 | for uchar in ustring: 40 | inside_code = ord(uchar) 41 | # 全角空格直接转换 42 | if inside_code == 12288: 43 | inside_code = 32 44 | # 全角字符(除空格)根据关系转化 45 | elif (inside_code >= 65281 and inside_code <= 65374): 46 | inside_code -= 65248 47 | rstring += unichr(inside_code) 48 | return rstring 49 | 50 | 51 | def string_matching(s, keywords): 52 | """判断s是否至少包含keywords中的至少一个字符串 53 | """ 54 | for k in keywords: 55 | if re.search(k, s): 56 | return True 57 | return False 58 | 59 | 60 | def convert_to_unicode(text, encoding='utf-8', errors='ignore'): 61 | """字符串转换为unicode格式(假设输入为utf-8格式) 62 | """ 63 | if is_py2: 64 | if isinstance(text, str): 65 | text = text.decode(encoding, errors=errors) 66 | else: 67 | if isinstance(text, bytes): 68 | text = text.decode(encoding, errors=errors) 69 | return text 70 | 71 | 72 | def convert_to_str(text, encoding='utf-8', errors='ignore'): 73 | """字符串转换为str格式(假设输入为utf-8格式) 74 | """ 75 | if is_py2: 76 | if isinstance(text, unicode): 77 | text = text.encode(encoding, errors=errors) 78 | else: 79 | if isinstance(text, bytes): 80 | text = text.decode(encoding, errors=errors) 81 | return text 82 | 83 | 84 | class open: 85 | """模仿python自带的open函数,主要是为了同时兼容py2和py3 86 | """ 87 | def __init__(self, name, mode='r', encoding=None, errors='ignore'): 88 | if is_py2: 89 | self.file = _open_(name, mode) 90 | else: 91 | self.file = _open_(name, mode, encoding=encoding, errors=errors) 92 | self.encoding = encoding 93 | self.errors = errors 94 | 95 | def __iter__(self): 96 | for l in self.file: 97 | if self.encoding: 98 | l = convert_to_unicode(l, self.encoding, self.errors) 99 | yield l 100 | 101 | def read(self): 102 | text = self.file.read() 103 | if self.encoding: 104 | text = convert_to_unicode(text, self.encoding, self.errors) 105 | return text 106 | 107 | def write(self, text): 108 | if self.encoding: 109 | text = convert_to_str(text, self.encoding, self.errors) 110 | self.file.write(text) 111 | 112 | def flush(self): 113 | self.file.flush() 114 | 115 | def close(self): 116 | self.file.close() 117 | 118 | def __enter__(self): 119 | return self 120 | 121 | def __exit__(self, type, value, tb): 122 | self.close() 123 | 124 | 125 | def parallel_apply( 126 | func, iterable, workers, max_queue_size, callback=None, dummy=False 127 | ): 128 | """多进程或多线程地将func应用到iterable的每个元素中。 129 | 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是 130 | 输出可能是func(c), func(a), func(b)。 131 | 参数: 132 | dummy: False是多进程/线性,True则是多线程/线性; 133 | callback: 处理单个输出的回调函数。 134 | """ 135 | if dummy: 136 | from multiprocessing.dummy import Pool, Queue 137 | else: 138 | from multiprocessing import Pool, Queue 139 | 140 | in_queue, out_queue = Queue(max_queue_size), Queue() 141 | 142 | def worker_step(in_queue, out_queue): 143 | # 单步函数包装成循环执行 144 | while True: 145 | i, d = in_queue.get() 146 | r = func(d) 147 | out_queue.put((i, r)) 148 | 149 | # 启动多进程/线程 150 | pool = Pool(workers, worker_step, 
(in_queue, out_queue)) 151 | 152 | if callback is None: 153 | results = [] 154 | 155 | # 后处理函数 156 | def process_out_queue(): 157 | out_count = 0 158 | for _ in range(out_queue.qsize()): 159 | i, d = out_queue.get() 160 | out_count += 1 161 | if callback is None: 162 | results.append((i, d)) 163 | else: 164 | callback(d) 165 | return out_count 166 | 167 | # 存入数据,取出结果 168 | in_count, out_count = 0, 0 169 | for i, d in enumerate(iterable): 170 | in_count += 1 171 | while True: 172 | try: 173 | in_queue.put((i, d), block=False) 174 | break 175 | except six.moves.queue.Full: 176 | out_count += process_out_queue() 177 | if in_count % max_queue_size == 0: 178 | out_count += process_out_queue() 179 | 180 | while out_count != in_count: 181 | out_count += process_out_queue() 182 | 183 | pool.terminate() 184 | 185 | if callback is None: 186 | results = sorted(results, key=lambda r: r[0]) 187 | return [r[1] for r in results] 188 | 189 | 190 | def sequence_padding(inputs, length=None, padding=0): 191 | """Numpy函数,将序列padding到同一长度 192 | """ 193 | if length is None: 194 | length = max([len(x) for x in inputs]) 195 | 196 | pad_width = [(0, 0) for _ in np.shape(inputs[0])] 197 | outputs = [] 198 | for x in inputs: 199 | x = x[:length] 200 | pad_width[0] = (0, length - len(x)) 201 | x = np.pad(x, pad_width, 'constant', constant_values=padding) 202 | outputs.append(x) 203 | 204 | return np.array(outputs) 205 | 206 | 207 | def text_segmentate(text, maxlen, seps='\n', strips=None): 208 | """将文本按照标点符号划分为若干个短句 209 | """ 210 | text = text.strip().strip(strips) 211 | if seps and len(text) > maxlen: 212 | pieces = text.split(seps[0]) 213 | text, texts = '', [] 214 | for i, p in enumerate(pieces): 215 | if text and p and len(text) + len(p) > maxlen - 1: 216 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips)) 217 | text = '' 218 | if i + 1 == len(pieces): 219 | text = text + p 220 | else: 221 | text = text + p + seps[0] 222 | if text: 223 | texts.extend(text_segmentate(text, maxlen, seps[1:], strips)) 224 | return texts 225 | else: 226 | return [text] 227 | 228 | 229 | def is_one_of(x, ys): 230 | """判断x是否在ys之中 231 | 等价于x in ys,但有些情况下x in ys会报错 232 | """ 233 | for y in ys: 234 | if x is y: 235 | return True 236 | return False 237 | 238 | 239 | class DataGenerator(object): 240 | """数据生成器模版 241 | """ 242 | def __init__(self, data, batch_size=32, buffer_size=None): 243 | self.data = data 244 | self.batch_size = batch_size 245 | if hasattr(self.data, '__len__'): 246 | self.steps = len(self.data) // self.batch_size 247 | if len(self.data) % self.batch_size != 0: 248 | self.steps += 1 249 | else: 250 | self.steps = None 251 | self.buffer_size = buffer_size or batch_size * 1000 252 | 253 | def __len__(self): 254 | return self.steps 255 | 256 | def sample(self, random=False): 257 | """采样函数,每个样本同时返回一个is_end标记 258 | """ 259 | if random: 260 | if self.steps is None: 261 | 262 | def generator(): 263 | caches, isfull = [], False 264 | for d in self.data: 265 | caches.append(d) 266 | if isfull: 267 | i = np.random.randint(len(caches)) 268 | yield caches.pop(i) 269 | elif len(caches) == self.buffer_size: 270 | isfull = True 271 | while caches: 272 | i = np.random.randint(len(caches)) 273 | yield caches.pop(i) 274 | 275 | else: 276 | 277 | def generator(): 278 | indices = list(range(len(self.data))) 279 | np.random.shuffle(indices) 280 | for i in indices: 281 | yield self.data[i] 282 | 283 | data = generator() 284 | else: 285 | data = iter(self.data) 286 | 287 | d_current = next(data) 288 | for d_next in data: 289 | yield 
False, d_current 290 | d_current = d_next 291 | 292 | yield True, d_current 293 | 294 | def __iter__(self, random=False): 295 | raise NotImplementedError 296 | 297 | def forfit(self): 298 | while True: 299 | for d in self.__iter__(True): 300 | yield d 301 | 302 | 303 | class ViterbiDecoder(object): 304 | """Viterbi解码算法基类 305 | """ 306 | def __init__(self, trans, starts=None, ends=None): 307 | self.trans = trans 308 | self.num_labels = len(trans) 309 | self.non_starts = [] 310 | self.non_ends = [] 311 | if starts is not None: 312 | for i in range(self.num_labels): 313 | if i not in starts: 314 | self.non_starts.append(i) 315 | if ends is not None: 316 | for i in range(self.num_labels): 317 | if i not in ends: 318 | self.non_ends.append(i) 319 | 320 | def decode(self, nodes): 321 | """nodes.shape=[seq_len, num_labels] 322 | """ 323 | # 预处理 324 | nodes[0, self.non_starts] -= np.inf 325 | nodes[-1, self.non_ends] -= np.inf 326 | 327 | # 动态规划 328 | labels = np.arange(self.num_labels).reshape((1, -1)) 329 | scores = nodes[0].reshape((-1, 1)) 330 | paths = labels 331 | for l in range(1, len(nodes)): 332 | M = scores + self.trans + nodes[l].reshape((1, -1)) 333 | idxs = M.argmax(0) 334 | scores = M.max(0).reshape((-1, 1)) 335 | paths = np.concatenate([paths[:, idxs], labels], 0) 336 | 337 | # 最优路径 338 | return paths[:, scores[:, 0].argmax()] 339 | 340 | 341 | def softmax(x, axis=-1): 342 | """numpy版softmax 343 | """ 344 | x = x - x.max(axis=axis, keepdims=True) 345 | x = np.exp(x) 346 | return x / x.sum(axis=axis, keepdims=True) 347 | 348 | 349 | class AutoRegressiveDecoder(object): 350 | """通用自回归生成模型解码基类 351 | 包含beam search和random sample两种策略 352 | """ 353 | def __init__(self, start_id, end_id, maxlen, minlen=None): 354 | self.start_id = start_id 355 | self.end_id = end_id 356 | self.maxlen = maxlen 357 | self.minlen = minlen or 1 358 | if start_id is None: 359 | self.first_output_ids = np.empty((1, 0), dtype=int) 360 | else: 361 | self.first_output_ids = np.array([[self.start_id]]) 362 | 363 | @staticmethod 364 | def wraps(default_rtype='probas', use_states=False): 365 | """用来进一步完善predict函数 366 | 目前包含:1. 设置rtype参数,并做相应处理; 367 | 2. 
确定states的使用,并做相应处理。 368 | """ 369 | def actual_decorator(predict): 370 | def new_predict( 371 | self, inputs, output_ids, states, rtype=default_rtype 372 | ): 373 | assert rtype in ['probas', 'logits'] 374 | prediction = predict(self, inputs, output_ids, states) 375 | 376 | if not use_states: 377 | prediction = (prediction, None) 378 | 379 | if default_rtype == 'logits': 380 | prediction = (softmax(prediction[0]), prediction[1]) 381 | 382 | if rtype == 'probas': 383 | return prediction 384 | else: 385 | return np.log(prediction[0] + 1e-12), prediction[1] 386 | 387 | return new_predict 388 | 389 | return actual_decorator 390 | 391 | def predict(self, inputs, output_ids, states=None, rtype='logits'): 392 | """用户需自定义递归预测函数 393 | 说明:rtype为字符串logits或probas,用户定义的时候,应当根据rtype来 394 | 返回不同的结果,rtype=probas时返回归一化的概率,rtype=logits时 395 | 则返回softmax前的结果或者概率对数。 396 | 返回:二元组 (得分或概率, states) 397 | """ 398 | raise NotImplementedError 399 | 400 | def beam_search(self, inputs, topk, states=None, min_ends=1): 401 | """beam search解码 402 | 说明:这里的topk即beam size; 403 | 返回:最优解码序列。 404 | """ 405 | inputs = [np.array([i]) for i in inputs] 406 | output_ids, output_scores = self.first_output_ids, np.zeros(1) 407 | for step in range(self.maxlen): 408 | scores, states = self.predict( 409 | inputs, output_ids, states, 'logits' 410 | ) # 计算当前得分 411 | if step == 0: # 第1步预测后将输入重复topk次 412 | inputs = [np.repeat(i, topk, axis=0) for i in inputs] 413 | scores = output_scores.reshape((-1, 1)) + scores # 综合累积得分 414 | indices = scores.argpartition(-topk, axis=None)[-topk:] # 仅保留topk 415 | indices_1 = indices // scores.shape[1] # 行索引 416 | indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引 417 | output_ids = np.concatenate([output_ids[indices_1], indices_2], 418 | 1) # 更新输出 419 | output_scores = np.take_along_axis( 420 | scores, indices, axis=None 421 | ) # 更新得分 422 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记 423 | if output_ids.shape[1] >= self.minlen: # 最短长度判断 424 | best_one = output_scores.argmax() # 得分最大的那个 425 | if end_counts[best_one] == min_ends: # 如果已经终止 426 | return output_ids[best_one] # 直接输出 427 | else: # 否则,只保留未完成部分 428 | flag = (end_counts < min_ends) # 标记未完成序列 429 | if not flag.all(): # 如果有已完成的 430 | inputs = [i[flag] for i in inputs] # 扔掉已完成序列 431 | output_ids = output_ids[flag] # 扔掉已完成序列 432 | output_scores = output_scores[flag] # 扔掉已完成序列 433 | end_counts = end_counts[flag] # 扔掉已完成end计数 434 | topk = flag.sum() # topk相应变化 435 | # 达到长度直接输出 436 | return output_ids[output_scores.argmax()] 437 | 438 | def random_sample(self, inputs, n, topk=None, topp=None, states=None, min_ends=1): 439 | """随机采样n个结果 440 | 说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp 441 | 表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。 442 | 返回:n个解码序列组成的list。 443 | """ 444 | inputs = [np.array([i]) for i in inputs] 445 | output_ids = self.first_output_ids 446 | results = [] 447 | for step in range(self.maxlen): 448 | probas, states = self.predict( 449 | inputs, output_ids, states, 'probas' 450 | ) # 计算当前概率 451 | probas /= probas.sum(axis=1, keepdims=True) # 确保归一化 452 | if step == 0: # 第1步预测后将结果重复n次 453 | probas = np.repeat(probas, n, axis=0) 454 | inputs = [np.repeat(i, n, axis=0) for i in inputs] 455 | output_ids = np.repeat(output_ids, n, axis=0) 456 | if topk is not None: 457 | k_indices = probas.argpartition(-topk, 458 | axis=1)[:, -topk:] # 仅保留topk 459 | probas = np.take_along_axis(probas, k_indices, axis=1) # topk概率 460 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 461 | if topp is not None: 462 | p_indices = 
probas.argsort(axis=1)[:, ::-1] # 从高到低排序 463 | probas = np.take_along_axis(probas, p_indices, axis=1) # 排序概率 464 | cumsum_probas = np.cumsum(probas, axis=1) # 累积概率 465 | flag = np.roll(cumsum_probas >= topp, 1, axis=1) # 标记超过topp的部分 466 | flag[:, 0] = False # 结合上面的np.roll,实现平移一位的效果 467 | probas[flag] = 0 # 后面的全部置零 468 | probas /= probas.sum(axis=1, keepdims=True) # 重新归一化 469 | sample_func = lambda p: np.random.choice(len(p), p=p) # 按概率采样函数 470 | sample_ids = np.apply_along_axis(sample_func, 1, probas) # 执行采样 471 | sample_ids = sample_ids.reshape((-1, 1)) # 对齐形状 472 | if topp is not None: 473 | sample_ids = np.take_along_axis(p_indices, sample_ids, axis=1) # 对齐原id 474 | if topk is not None: 475 | sample_ids = np.take_along_axis(k_indices, sample_ids, axis=1) # 对齐原id 476 | output_ids = np.concatenate([output_ids, sample_ids], 1) # 更新输出 477 | end_counts = (output_ids == self.end_id).sum(1) # 统计出现的end标记 478 | if output_ids.shape[1] >= self.minlen: # 最短长度判断 479 | flag = (end_counts == min_ends) # 标记已完成序列 480 | if flag.any(): # 如果有已完成的 481 | for ids in output_ids[flag]: # 存好已完成序列 482 | results.append(ids) 483 | flag = (flag == False) # 标记未完成序列 484 | inputs = [i[flag] for i in inputs] # 只保留未完成部分输入 485 | output_ids = output_ids[flag] # 只保留未完成部分候选集 486 | end_counts = end_counts[flag] # 只保留未完成部分end计数 487 | if len(output_ids) == 0: 488 | break 489 | # 如果还有未完成序列,直接放入结果 490 | for ids in output_ids: 491 | results.append(ids) 492 | # 返回结果 493 | return results 494 | 495 | 496 | def insert_arguments(**arguments): 497 | """装饰器,为类方法增加参数 498 | (主要用于类的__init__方法) 499 | """ 500 | def actual_decorator(func): 501 | def new_func(self, *args, **kwargs): 502 | for k, v in arguments.items(): 503 | if k in kwargs: 504 | v = kwargs.pop(k) 505 | setattr(self, k, v) 506 | return func(self, *args, **kwargs) 507 | 508 | return new_func 509 | 510 | return actual_decorator 511 | 512 | 513 | def delete_arguments(*arguments): 514 | """装饰器,为类方法删除参数 515 | (主要用于类的__init__方法) 516 | """ 517 | def actual_decorator(func): 518 | def new_func(self, *args, **kwargs): 519 | for k in arguments: 520 | if k in kwargs: 521 | raise TypeError( 522 | '%s got an unexpected keyword argument \'%s\'' % 523 | (self.__class__.__name__, k) 524 | ) 525 | return func(self, *args, **kwargs) 526 | 527 | return new_func 528 | 529 | return actual_decorator 530 | 531 | 532 | def longest_common_substring(source, target): 533 | """最长公共子串(source和target的最长公共切片区间) 534 | 返回:子串长度, 所在区间(四元组) 535 | 注意:最长公共子串可能不止一个,所返回的区间只代表其中一个。 536 | """ 537 | c, l, span = defaultdict(int), 0, (0, 0, 0, 0) 538 | for i, si in enumerate(source, 1): 539 | for j, tj in enumerate(target, 1): 540 | if si == tj: 541 | c[i, j] = c[i - 1, j - 1] + 1 542 | if c[i, j] > l: 543 | l = c[i, j] 544 | span = (i - l, i, j - l, j) 545 | return l, span 546 | 547 | 548 | def longest_common_subsequence(source, target): 549 | """最长公共子序列(source和target的最长非连续子序列) 550 | 返回:子序列长度, 映射关系(映射对组成的list) 551 | 注意:最长公共子序列可能不止一个,所返回的映射只代表其中一个。 552 | """ 553 | c = defaultdict(int) 554 | for i, si in enumerate(source, 1): 555 | for j, tj in enumerate(target, 1): 556 | if si == tj: 557 | c[i, j] = c[i - 1, j - 1] + 1 558 | elif c[i, j - 1] > c[i - 1, j]: 559 | c[i, j] = c[i, j - 1] 560 | else: 561 | c[i, j] = c[i - 1, j] 562 | l, mapping = c[len(source), len(target)], [] 563 | i, j = len(source) - 1, len(target) - 1 564 | while len(mapping) < l: 565 | if source[i] == target[j]: 566 | mapping.append((i, j)) 567 | i, j = i - 1, j - 1 568 | elif c[i + 1, j] > c[i, j + 1]: 569 | j = j - 1 570 | else: 571 | i = i - 1 572 | 
return l, mapping[::-1] 573 | 574 | 575 | class WebServing(object): 576 | """简单的Web接口 577 | 用法: 578 | arguments = {'text': (None, True), 'n': (int, False)} 579 | web = WebServing(port=8864) 580 | web.route('/gen_synonyms', gen_synonyms, arguments) 581 | web.start() 582 | # 然后访问 http://127.0.0.1:8864/gen_synonyms?text=你好 583 | 说明: 584 | 基于bottlepy简单封装,仅作为临时测试使用,不保证性能。 585 | 目前仅保证支持 Tensorflow 1.x + Keras <= 2.3.1。 586 | 欢迎有经验的开发者帮忙改进。 587 | 依赖: 588 | pip install bottle 589 | pip install paste 590 | (如果不用 server='paste' 的话,可以不装paste库) 591 | """ 592 | def __init__(self, host='0.0.0.0', port=8000, server='paste'): 593 | 594 | import tensorflow as tf 595 | from bert4keras.backend import K 596 | import bottle 597 | 598 | self.host = host 599 | self.port = port 600 | self.server = server 601 | self.graph = tf.get_default_graph() 602 | self.sess = K.get_session() 603 | self.set_session = K.set_session 604 | self.bottle = bottle 605 | 606 | def wraps(self, func, arguments, method='GET'): 607 | """封装为接口函数 608 | 参数: 609 | func:要转换为接口的函数,需要保证输出可以json化,即需要 610 | 保证 json.dumps(func(inputs)) 能被执行成功; 611 | arguments:声明func所需参数,其中key为参数名,value[0]为 612 | 对应的转换函数(接口获取到的参数值都是字符串 613 | 型),value[1]为该参数是否必须; 614 | method:GET或者POST。 615 | """ 616 | def new_func(): 617 | outputs = {'code': 0, 'desc': u'succeeded', 'data': {}} 618 | kwargs = {} 619 | for key, value in arguments.items(): 620 | if method == 'GET': 621 | result = self.bottle.request.GET.get(key) 622 | else: 623 | result = self.bottle.request.POST.get(key) 624 | if result is None: 625 | if value[1]: 626 | outputs['code'] = 1 627 | outputs['desc'] = 'lack of "%s" argument' % key 628 | return json.dumps(outputs, ensure_ascii=False) 629 | else: 630 | if value[0] is None: 631 | result = convert_to_unicode(result) 632 | else: 633 | result = value[0](result) 634 | kwargs[key] = result 635 | try: 636 | with self.graph.as_default(): 637 | self.set_session(self.sess) 638 | outputs['data'] = func(**kwargs) 639 | except Exception as e: 640 | outputs['code'] = 2 641 | outputs['desc'] = str(e) 642 | return json.dumps(outputs, ensure_ascii=False) 643 | 644 | return new_func 645 | 646 | def route(self, path, func, arguments, method='GET'): 647 | """添加接口 648 | """ 649 | func = self.wraps(func, arguments, method) 650 | self.bottle.route(path, method=method)(func) 651 | 652 | def start(self): 653 | """启动服务 654 | """ 655 | self.bottle.run(host=self.host, port=self.port, server=self.server) 656 | 657 | 658 | class Hook: 659 | """注入uniout模块,实现import时才触发 660 | """ 661 | def __init__(self, module): 662 | self.module = module 663 | 664 | def __getattr__(self, attr): 665 | """使得 from bert4keras.backend import uniout 666 | 等效于 import uniout (自动识别Python版本,Python3 667 | 下则无操作。) 668 | """ 669 | if attr == 'uniout': 670 | if is_py2: 671 | import uniout 672 | else: 673 | return getattr(self.module, attr) 674 | 675 | 676 | Hook.__name__ = __name__ 677 | sys.modules[__name__] = Hook(sys.modules[__name__]) 678 | del Hook 679 | -------------------------------------------------------------------------------- /rere/BERT_tf2/bert4keras/tokenizers.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 工具函数 3 | 4 | import unicodedata, re, os 5 | from bert4keras.snippets import is_string, is_py2 6 | from bert4keras.snippets import open 7 | 8 | 9 | def load_vocab(dict_path, encoding='utf-8', simplified=False, startswith=None): 10 | """从bert的词典文件中读取词典 11 | """ 12 | token_dict = {} 13 | if os.path.isdir(dict_path): 14 | dict_path = os.path.join(dict_path, 'vocab.txt') 15 | with open(dict_path, encoding=encoding) as reader: 16 | for line in reader: 17 | token = line.split() 18 | token = token[0] if token else line.strip() 19 | token_dict[token] = len(token_dict) 20 | 21 | if simplified: # 过滤冗余部分token 22 | new_token_dict, keep_tokens = {}, [] 23 | startswith = startswith or [] 24 | for t in startswith: 25 | new_token_dict[t] = len(new_token_dict) 26 | keep_tokens.append(token_dict[t]) 27 | 28 | for t, _ in sorted(token_dict.items(), key=lambda s: s[1]): 29 | if t not in new_token_dict: 30 | keep = True 31 | if len(t) > 1: 32 | for c in Tokenizer.stem(t): 33 | if ( 34 | Tokenizer._is_cjk_character(c) or 35 | Tokenizer._is_punctuation(c) 36 | ): 37 | keep = False 38 | break 39 | if keep: 40 | new_token_dict[t] = len(new_token_dict) 41 | keep_tokens.append(token_dict[t]) 42 | 43 | return new_token_dict, keep_tokens 44 | else: 45 | return token_dict 46 | 47 | 48 | def save_vocab(dict_path, token_dict, encoding='utf-8'): 49 | """将词典(比如精简过的)保存为文件 50 | """ 51 | with open(dict_path, 'w', encoding=encoding) as writer: 52 | for k, v in sorted(token_dict.items(), key=lambda s: s[1]): 53 | writer.write(k + '\n') 54 | 55 | 56 | class BasicTokenizer(object): 57 | """分词器基类 58 | """ 59 | def __init__(self, token_start='[CLS]', token_end='[SEP]'): 60 | """初始化 61 | """ 62 | self._token_pad = '[PAD]' 63 | self._token_unk = '[UNK]' 64 | self._token_mask = '[MASK]' 65 | self._token_start = token_start 66 | self._token_end = token_end 67 | 68 | def tokenize(self, text, maxlen=None): 69 | """分词函数 70 | """ 71 | tokens = self._tokenize(text) 72 | if self._token_start is not None: 73 | tokens.insert(0, self._token_start) 74 | if self._token_end is not None: 75 | tokens.append(self._token_end) 76 | 77 | if maxlen is not None: 78 | index = int(self._token_end is not None) + 1 79 | self.truncate_sequence(maxlen, tokens, None, -index) 80 | 81 | return tokens 82 | 83 | def token_to_id(self, token): 84 | """token转换为对应的id 85 | """ 86 | raise NotImplementedError 87 | 88 | def tokens_to_ids(self, tokens): 89 | """token序列转换为对应的id序列 90 | """ 91 | return [self.token_to_id(token) for token in tokens] 92 | 93 | def truncate_sequence( 94 | self, maxlen, first_sequence, second_sequence=None, pop_index=-1 95 | ): 96 | """截断总长度 97 | """ 98 | if second_sequence is None: 99 | second_sequence = [] 100 | 101 | while True: 102 | total_length = len(first_sequence) + len(second_sequence) 103 | if total_length <= maxlen: 104 | break 105 | elif len(first_sequence) > len(second_sequence): 106 | first_sequence.pop(pop_index) 107 | else: 108 | second_sequence.pop(pop_index) 109 | 110 | def encode( 111 | self, first_text, second_text=None, maxlen=None, pattern='S*E*E' 112 | ): 113 | """输出文本对应token id和segment id 114 | """ 115 | if is_string(first_text): 116 | first_tokens = self.tokenize(first_text) 117 | else: 118 | first_tokens = first_text 119 | 120 | if second_text is None: 121 | second_tokens = None 122 | elif is_string(second_text): 123 | if pattern == 'S*E*E': 124 | idx = int(bool(self._token_start)) 125 | second_tokens = self.tokenize(second_text)[idx:] 126 | elif pattern == 'S*ES*E': 127 | second_tokens = 
self.tokenize(second_text) 128 | else: 129 | second_tokens = second_text 130 | 131 | if maxlen is not None: 132 | self.truncate_sequence(maxlen, first_tokens, second_tokens, -2) 133 | 134 | first_token_ids = self.tokens_to_ids(first_tokens) 135 | first_segment_ids = [0] * len(first_token_ids) 136 | 137 | if second_text is not None: 138 | second_token_ids = self.tokens_to_ids(second_tokens) 139 | second_segment_ids = [1] * len(second_token_ids) 140 | first_token_ids.extend(second_token_ids) 141 | first_segment_ids.extend(second_segment_ids) 142 | 143 | return first_token_ids, first_segment_ids 144 | 145 | def id_to_token(self, i): 146 | """id序列为对应的token 147 | """ 148 | raise NotImplementedError 149 | 150 | def ids_to_tokens(self, ids): 151 | """id序列转换为对应的token序列 152 | """ 153 | return [self.id_to_token(i) for i in ids] 154 | 155 | def decode(self, ids): 156 | """转为可读文本 157 | """ 158 | raise NotImplementedError 159 | 160 | def _tokenize(self, text): 161 | """基本分词函数 162 | """ 163 | raise NotImplementedError 164 | 165 | 166 | class Tokenizer(BasicTokenizer): 167 | """Bert原生分词器 168 | 纯Python实现,代码修改自keras_bert的tokenizer实现 169 | """ 170 | def __init__( 171 | self, token_dict, do_lower_case=False, pre_tokenize=None, **kwargs 172 | ): 173 | """这里的pre_tokenize是外部传入的分词函数,用作对文本进行预分词。如果传入 174 | pre_tokenize,则先执行pre_tokenize(text),然后在它的基础上执行原本的 175 | tokenize函数。 176 | """ 177 | super(Tokenizer, self).__init__(**kwargs) 178 | if is_string(token_dict): 179 | token_dict = load_vocab(token_dict) 180 | 181 | self._do_lower_case = do_lower_case 182 | self._pre_tokenize = pre_tokenize 183 | self._token_dict = token_dict 184 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 185 | self._vocab_size = len(token_dict) 186 | 187 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 188 | try: 189 | _token_id = token_dict[getattr(self, '_token_%s' % token)] 190 | setattr(self, '_token_%s_id' % token, _token_id) 191 | except: 192 | pass 193 | 194 | def token_to_id(self, token): 195 | """token转换为对应的id 196 | """ 197 | return self._token_dict.get(token, self._token_unk_id) 198 | 199 | def id_to_token(self, i): 200 | """id转换为对应的token 201 | """ 202 | return self._token_dict_inv[i] 203 | 204 | def decode(self, ids, tokens=None): 205 | """转为可读文本 206 | """ 207 | tokens = tokens or self.ids_to_tokens(ids) 208 | tokens = [token for token in tokens if not self._is_special(token)] 209 | 210 | text, flag = '', False 211 | for i, token in enumerate(tokens): 212 | if token[:2] == '##': 213 | text += token[2:] 214 | elif len(token) == 1 and self._is_cjk_character(token): 215 | text += token 216 | elif len(token) == 1 and self._is_punctuation(token): 217 | text += token 218 | text += ' ' 219 | elif i > 0 and self._is_cjk_character(text[-1]): 220 | text += token 221 | else: 222 | text += ' ' 223 | text += token 224 | 225 | text = re.sub(' +', ' ', text) 226 | text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text) 227 | punctuation = self._cjk_punctuation() + '+-/={(<[' 228 | punctuation_regex = '|'.join([re.escape(p) for p in punctuation]) 229 | punctuation_regex = '(%s) ' % punctuation_regex 230 | text = re.sub(punctuation_regex, '\\1', text) 231 | text = re.sub('(\d\.) 
(\d)', '\\1\\2', text) 232 | 233 | return text.strip() 234 | 235 | def _tokenize(self, text, pre_tokenize=True): 236 | """基本分词函数 237 | """ 238 | if self._do_lower_case: 239 | if is_py2: 240 | text = unicode(text) 241 | text = text.lower() 242 | text = unicodedata.normalize('NFD', text) 243 | text = ''.join([ 244 | ch for ch in text if unicodedata.category(ch) != 'Mn' 245 | ]) 246 | 247 | if pre_tokenize and self._pre_tokenize is not None: 248 | tokens = [] 249 | for token in self._pre_tokenize(text): 250 | if token in self._token_dict: 251 | tokens.append(token) 252 | else: 253 | tokens.extend(self._tokenize(token, False)) 254 | return tokens 255 | 256 | spaced = '' 257 | for ch in text: 258 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 259 | spaced += ' ' + ch + ' ' 260 | elif self._is_space(ch): 261 | spaced += ' ' 262 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 263 | continue 264 | else: 265 | spaced += ch 266 | 267 | tokens = [] 268 | for word in spaced.strip().split(): 269 | tokens.extend(self._word_piece_tokenize(word)) 270 | 271 | return tokens 272 | 273 | def _word_piece_tokenize(self, word): 274 | """word内分成subword 275 | """ 276 | if word in self._token_dict: 277 | return [word] 278 | 279 | tokens = [] 280 | start, stop = 0, 0 281 | while start < len(word): 282 | stop = len(word) 283 | while stop > start: 284 | sub = word[start:stop] 285 | if start > 0: 286 | sub = '##' + sub 287 | if sub in self._token_dict: 288 | break 289 | stop -= 1 290 | if start == stop: 291 | stop += 1 292 | tokens.append(sub) 293 | start = stop 294 | 295 | return tokens 296 | 297 | @staticmethod 298 | def stem(token): 299 | """获取token的“词干”(如果是##开头,则自动去掉##) 300 | """ 301 | if token[:2] == '##': 302 | return token[2:] 303 | else: 304 | return token 305 | 306 | @staticmethod 307 | def _is_space(ch): 308 | """空格类字符判断 309 | """ 310 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ 311 | unicodedata.category(ch) == 'Zs' 312 | 313 | @staticmethod 314 | def _is_punctuation(ch): 315 | """标点符号类字符判断(全/半角均在此内) 316 | 提醒:unicodedata.category这个函数在py2和py3下的 317 | 表现可能不一样,比如u'§'字符,在py2下的结果为'So', 318 | 在py3下的结果是'Po'。 319 | """ 320 | code = ord(ch) 321 | return 33 <= code <= 47 or \ 322 | 58 <= code <= 64 or \ 323 | 91 <= code <= 96 or \ 324 | 123 <= code <= 126 or \ 325 | unicodedata.category(ch).startswith('P') 326 | 327 | @staticmethod 328 | def _cjk_punctuation(): 329 | return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002' 330 | 331 | @staticmethod 332 | def _is_cjk_character(ch): 333 | """CJK类字符判断(包括中文字符也在此列) 334 | 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 335 | """ 336 | code = ord(ch) 337 | return 0x4E00 <= code <= 0x9FFF or \ 338 | 0x3400 <= code <= 0x4DBF or \ 339 | 0x20000 <= code <= 0x2A6DF or \ 340 | 0x2A700 <= code <= 0x2B73F or \ 341 | 0x2B740 <= code <= 0x2B81F or \ 342 | 0x2B820 <= code <= 0x2CEAF or \ 343 | 0xF900 <= code <= 0xFAFF or \ 344 | 0x2F800 <= code <= 0x2FA1F 345 | 346 | @staticmethod 347 | def _is_control(ch): 348 | """控制类字符判断 349 | """ 350 | return unicodedata.category(ch) in 
('Cc', 'Cf') 351 | 352 | @staticmethod 353 | def _is_special(ch): 354 | """判断是不是有特殊含义的符号 355 | """ 356 | return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') 357 | 358 | def rematch(self, text, tokens): 359 | """给出原始的text和tokenize后的tokens的映射关系 360 | """ 361 | if is_py2: 362 | text = unicode(text) 363 | 364 | if self._do_lower_case: 365 | text = text.lower() 366 | 367 | normalized_text, char_mapping = '', [] 368 | for i, ch in enumerate(text): 369 | if self._do_lower_case: 370 | ch = unicodedata.normalize('NFD', ch) 371 | ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn']) 372 | ch = ''.join([ 373 | c for c in ch 374 | if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c)) 375 | ]) 376 | normalized_text += ch 377 | char_mapping.extend([i] * len(ch)) 378 | 379 | text, token_mapping, offset = normalized_text, [], 0 380 | for token in tokens: 381 | if self._is_special(token): 382 | token_mapping.append([]) 383 | else: 384 | token = self.stem(token) 385 | start = text[offset:].index(token) + offset 386 | end = start + len(token) 387 | token_mapping.append(char_mapping[start:end]) 388 | offset = end 389 | 390 | return token_mapping 391 | 392 | 393 | class SpTokenizer(BasicTokenizer): 394 | """基于SentencePiece模型的封装,使用上跟Tokenizer基本一致。 395 | """ 396 | def __init__(self, sp_model_path, **kwargs): 397 | super(SpTokenizer, self).__init__(**kwargs) 398 | import sentencepiece as spm 399 | self.sp_model = spm.SentencePieceProcessor() 400 | self.sp_model.Load(sp_model_path) 401 | self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id()) 402 | self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id()) 403 | self._vocab_size = self.sp_model.get_piece_size() 404 | 405 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 406 | try: 407 | _token = getattr(self, '_token_%s' % token) 408 | _token_id = self.sp_model.piece_to_id(_token) 409 | setattr(self, '_token_%s_id' % token, _token_id) 410 | except: 411 | pass 412 | 413 | def token_to_id(self, token): 414 | """token转换为对应的id 415 | """ 416 | return self.sp_model.piece_to_id(token) 417 | 418 | def id_to_token(self, i): 419 | """id转换为对应的token 420 | """ 421 | if i < self._vocab_size: 422 | return self.sp_model.id_to_piece(i) 423 | else: 424 | return '' 425 | 426 | def decode(self, ids): 427 | """转为可读文本 428 | """ 429 | ids = [i for i in ids if self._is_decodable(i)] 430 | text = self.sp_model.decode_ids(ids) 431 | return text.decode('utf-8') if is_py2 else text 432 | 433 | def _tokenize(self, text): 434 | """基本分词函数 435 | """ 436 | tokens = self.sp_model.encode_as_pieces(text) 437 | return tokens 438 | 439 | def _is_special(self, i): 440 | """判断是不是有特殊含义的符号 441 | """ 442 | return self.sp_model.is_control(i) or \ 443 | self.sp_model.is_unknown(i) or \ 444 | self.sp_model.is_unused(i) 445 | 446 | def _is_decodable(self, i): 447 | """判断是否应该被解码输出 448 | """ 449 | return (i < self._vocab_size) and not self._is_special(i) 450 | -------------------------------------------------------------------------------- /rere/BERT_tf2/bert_tools.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, math, re 2 | import tensorflow as tf 3 | from tqdm import tqdm 4 | import numpy as np 5 | import tensorflow.keras.backend as K 6 | from tensorflow.keras.callbacks import Callback 7 | from tensorflow.keras.models import * 8 | 9 | from bert4keras.tokenizers import Tokenizer 10 | from bert4keras.snippets import to_array 11 | from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr, 
extend_with_weight_decay 12 | from bert4keras.models import build_transformer_model 13 | from bert4keras.layers import * 14 | from tensorflow.keras.initializers import TruncatedNormal 15 | 16 | en_dict_path = r'../tfhub/uncased_L-12_H-768_A-12/vocab.txt' 17 | cn_dict_path = r'../tfhub/chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt' 18 | 19 | # script_path = os.path.split(os.path.realpath(__file__))[0] 20 | # en_dict_path = os.path.join(script_path, '../tfhub/uncased_L-12_H-768_A-12') 21 | # cn_dict_path = os.path.join(script_path, '../tfhub/chinese_roberta_wwm_ext_L-12_H-768_A-12') 22 | tokenizer = Tokenizer(cn_dict_path, do_lower_case=True) 23 | language = 'cn' 24 | 25 | def switch_to_en(): 26 | global tokenizer, language 27 | tokenizer = Tokenizer(en_dict_path, do_lower_case=True) 28 | language = 'en' 29 | 30 | def convert_sentences(sents, maxlen=256): 31 | shape = (2, len(sents), maxlen) 32 | X = np.zeros(shape, dtype='int32') 33 | for ii, sent in tqdm(enumerate(sents), desc="Converting sentences"): 34 | tids, segs = tokenizer.encode(sent, maxlen=maxlen) 35 | X[0,ii,:len(tids)] = tids 36 | X[1,ii,:len(segs)] = segs 37 | return [X[0], X[1]] 38 | 39 | def convert_tokens(sents, maxlen=256): 40 | shape = (2, len(sents), maxlen) 41 | X = np.zeros(shape, dtype='int32') 42 | for ii, sent in tqdm(enumerate(sents), desc="Converting tokens"): 43 | tids = tokenizer.tokens_to_ids(sent) 44 | X[0,ii,:len(tids)] = tids 45 | return [X[0], X[1]] 46 | 47 | def lock_transformer_layers(transformer, layers=-1): 48 | def _filter(layers, prefix): 49 | return [x for x in transformer.layers if x.name.startswith(prefix)] 50 | if hasattr(transformer, 'model'): transformer = transformer.model 51 | if layers >= 0: 52 | print('locking', 'Embedding-*') 53 | for layer in _filter(transformer, 'Embedding-'): 54 | layer.trainable = False 55 | print('locking', 'Transformer-[%d-%d]-*' % (0, layers-1)) 56 | for index in range(layers): 57 | for layer in _filter(transformer, 'Transformer-%d-' % index): 58 | layer.trainable = False 59 | 60 | def unlock_transformer_layers(transformer): 61 | if hasattr(transformer, 'model'): transformer = transformer.model 62 | for layer in transformer.layers: 63 | layer.trainable = True 64 | 65 | def get_suggested_optimizer(init_lr=5e-5, total_steps=None): 66 | lr_schedule = {1000:1, 10000:0.01} 67 | if total_steps is not None: 68 | lr_schedule = {total_steps//10:1, total_steps:0.1} 69 | optimizer = extend_with_weight_decay(Adam) 70 | optimizer = extend_with_piecewise_linear_lr(optimizer) 71 | optimizer_params = { 72 | 'learning_rate': init_lr, 73 | 'lr_schedule': lr_schedule, 74 | 'weight_decay_rate': 0.01, 75 | 'exclude_from_weight_decay': ['Norm', 'bias'], 76 | 'bias_correction': False, 77 | } 78 | optimizer = optimizer(**optimizer_params) 79 | return optimizer 80 | 81 | def convert_single_setences(sens, maxlen, tokenizer, details=False): 82 | X = np.zeros((len(sens), maxlen), dtype='int32') 83 | datas = [] 84 | for i, s in enumerate(sens): 85 | tokens = tokenizer.tokenize(s)[:maxlen-2] 86 | if details: 87 | otokens = restore_token_list(s, tokens) 88 | datas.append({'id':i, 's':s, 'otokens':otokens}) 89 | tt = ['[CLS]'] + tokens + ['[SEP]'] 90 | tids = tokenizer.convert_tokens_to_ids(tt) 91 | X[i,:len(tids)] = tids 92 | if details: return datas, X 93 | return X 94 | 95 | def build_classifier(classes, bert_h5=None): 96 | if bert_h5 is None: 97 | bert_h5 = '../tfhub/chinese_roberta_wwm_ext.h5' if language == 'cn' else '../tfhub/bert_uncased.h5' 98 | bert = load_model(bert_h5) 99 | output = 
Lambda(lambda x: x[:,0], name='CLS-token')(bert.output) 100 | if classes == 2: 101 | output = Dense(1, activation='sigmoid', kernel_initializer=TruncatedNormal(stddev=0.02))(output) 102 | else: 103 | output = Dense(classes, activation='softmax', kernel_initializer=TruncatedNormal(stddev=0.02))(output) 104 | model = Model(bert.input, output) 105 | model.bert_encoder = bert 106 | return model 107 | 108 | 109 | ## THESE FUNCTIONS ARE TESTED FOR CHS LANGUAGE ONLY 110 | def gen_token_list_inv_pointer(sent, token_list): 111 | zz = tokenizer.rematch(sent, token_list) 112 | return [x[0] for x in zz if len(x) > 0] 113 | sent = sent.lower() 114 | otiis = []; iis = 0 115 | for it, token in enumerate(token_list): 116 | otoken = token.lstrip('#') 117 | if token[0] == '[' and token[-1] == ']': otoken = '' 118 | niis = iis 119 | while niis <= len(sent): 120 | if sent[niis:].startswith(otoken): break 121 | if otoken in '-"' and sent[niis][0] in '—“”': break 122 | niis += 1 123 | if niis >= len(sent): niis = iis 124 | otiis.append(niis) 125 | iis = niis + max(1, len(otoken)) 126 | for tt, ii in zip(token_list, otiis): print(tt, sent[ii:ii+len(tt.lstrip('#'))]) 127 | for i, iis in enumerate(otiis): 128 | assert iis < len(sent) 129 | otoken = token_list[i].strip('#') 130 | assert otoken == '[UNK]' or sent[iis:iis+len(otoken)] == otoken 131 | return otiis 132 | 133 | # restore [UNK] tokens to the original tokens 134 | def restore_token_list(sent, token_list): 135 | if token_list[0] == '[CLS]': token_list = token_list[1:-1] 136 | invp = gen_token_list_inv_pointer(sent, token_list) 137 | invp.append(len(sent)) 138 | otokens = [sent[u:v] for u,v in zip(invp, invp[1:])] 139 | processed = -1 140 | for ii, tk in enumerate(token_list): 141 | if tk != '[UNK]': continue 142 | if ii < processed: continue 143 | for jj in range(ii+1, len(token_list)): 144 | if token_list[jj] != '[UNK]': break 145 | else: jj = len(token_list) 146 | allseg = sent[invp[ii]:invp[jj]] 147 | 148 | if ii + 1 == jj: continue 149 | seppts = [0] + [i for i, x in enumerate(allseg) if i > 0 and i+1 < len(allseg) and x == ' ' and allseg[i-1] != ' '] 150 | if allseg[seppts[-1]:].replace(' ', '') == '': seppts = seppts[:-1] 151 | seppts.append(len(allseg)) 152 | if len(seppts) == jj - ii + 1: 153 | for k, (u,v) in enumerate(zip(seppts, seppts[1:])): 154 | otokens[ii+k] = allseg[u:v] 155 | processed = jj + 1 156 | if invp[0] > 0: otokens[0] = sent[:invp[0]] + otokens[0] 157 | if ''.join(otokens) != sent: 158 | raise Exception('restore tokens failed, text and restored:\n%s\n%s' % (sent, ''.join(otokens))) 159 | return otokens 160 | 161 | def gen_word_level_labels(sent, token_list, word_list, pos_list=None): 162 | otiis = gen_token_list_inv_pointer(sent, token_list) 163 | wdiis = []; iis = 0 164 | for ip, pword in enumerate(word_list): 165 | niis = iis 166 | while niis < len(sent): 167 | if pword == '' or sent[niis:].startswith(pword[0]): break 168 | niis += 1 169 | wdiis.append(niis) 170 | iis = niis + len(pword) 171 | #for tt, ii in zip(word_list, wdiis): print(tt, sent[ii:ii+len(tt)]) 172 | 173 | rlist = []; ip = 0 174 | for it, iis in enumerate(otiis): 175 | while ip + 1 < len(wdiis) and wdiis[ip+1] <= iis: ip += 1 176 | if iis == wdiis[ip]: rr = 'B' 177 | elif iis > wdiis[ip]: rr = 'I' 178 | rr += '-' + pos_list[ip] 179 | rlist.append(rr) 180 | #for rr, tt in zip(rlist, token_list): print(rr, tt) 181 | return rlist 182 | 183 | def normalize_sentence(text): 184 | text = re.sub('[“”]', '"', text) 185 | text = re.sub('[—]', '-', text) 186 | text = 
re.sub('[^\u0000-\u007f\u4e00-\u9fa5\u3001-\u303f\uff00-\uffef·—]', ' \u2800 ', text) 187 | return text 188 | 189 | if __name__ == '__main__': 190 | #from transformers import BertTokenizer 191 | #tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm') 192 | switch_to_en() 193 | sent = 'French is the national language of France where the leaders are François Hollande and Manuel Valls . Barny cakes , made with sponge cake , can be found in France .' 194 | tokens = tokenizer.tokenize(sent) 195 | otokens = restore_token_list(sent, tokens) 196 | print(tokens) 197 | print(otokens) 198 | print('done') -------------------------------------------------------------------------------- /rere/batch_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | folders = ['NYT10-HRL','NYT11-HRL'] 5 | # fnratio = [ str(round(a*0.1, 1)) for a in range(1,6)] 6 | for d in folders: 7 | # for ratio in fnratio: 8 | cmd = f'python extraction.py {d}' 9 | print(cmd) 10 | with open('result.csv','a',encoding='utf8') as f: f.write('\n'+cmd+'\n') 11 | os.system(cmd) 12 | for d in folders: 13 | # for ratio in fnratio: 14 | cmd = f'python extraction.py {d}' 15 | print(cmd) 16 | with open('result.csv','a',encoding='utf8') as f: f.write('\n'+cmd+'\n') 17 | os.system(cmd) -------------------------------------------------------------------------------- /rere/bert-to-h5.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tensorflow as tf 3 | sys.path.append('./BERT_tf2') 4 | from bert4keras.models import build_transformer_model 5 | 6 | bert_path = './tfhub/uncased_L-12_H-768_A-12' 7 | bert = build_transformer_model(bert_path, return_keras_model=False) 8 | bert.model.save('./tfhub/bert_uncased.h5') 9 | 10 | -------------------------------------------------------------------------------- /rere/bert4keras/__init__.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | 3 | __version__ = '0.9.5' 4 | -------------------------------------------------------------------------------- /rere/bert4keras/backend.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 分离后端函数,主要是为了同时兼容原生keras和tf.keras 3 | # 通过设置环境变量TF_KERAS=1来切换tf.keras 4 | 5 | import os, sys 6 | from distutils.util import strtobool 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.python.util import nest, tf_inspect 10 | from tensorflow.python.eager import tape 11 | from tensorflow.python.ops.custom_gradient import _graph_mode_decorator 12 | 13 | # 判断是tf.keras还是纯keras的标记 14 | is_tf_keras = strtobool(os.environ.get('TF_KERAS', '1')) # 15 | 16 | if is_tf_keras: 17 | import tensorflow.keras as keras 18 | import tensorflow.keras.backend as K 19 | sys.modules['keras'] = keras 20 | else: 21 | import keras 22 | import keras.backend as K 23 | 24 | # 判断是否启用重计算(通过时间换空间) 25 | do_recompute = strtobool(os.environ.get('RECOMPUTE', '0')) 26 | 27 | 28 | def gelu_erf(x): 29 | """基于Erf直接计算的gelu函数 30 | """ 31 | return 0.5 * x * (1.0 + tf.math.erf(x / np.sqrt(2.0))) 32 | 33 | 34 | def gelu_tanh(x): 35 | """基于Tanh近似计算的gelu函数 36 | """ 37 | cdf = 0.5 * ( 38 | 1.0 + K.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * K.pow(x, 3)))) 39 | ) 40 | return x * cdf 41 | 42 | 43 | def set_gelu(version): 44 | """设置gelu版本 45 | """ 46 | version = version.lower() 47 | assert version in ['erf', 'tanh'], 'gelu version must be erf or tanh' 48 | if version == 'erf': 49 | keras.utils.get_custom_objects()['gelu'] = gelu_erf 50 | else: 51 | keras.utils.get_custom_objects()['gelu'] = gelu_tanh 52 | 53 | 54 | def piecewise_linear(t, schedule): 55 | """分段线性函数 56 | 其中schedule是形如{1000: 1, 2000: 0.1}的字典, 57 | 表示 t ∈ [0, 1000]时,输出从0均匀增加至1,而 58 | t ∈ [1000, 2000]时,输出从1均匀降低到0.1,最后 59 | t > 2000时,保持0.1不变。 60 | """ 61 | schedule = sorted(schedule.items()) 62 | if schedule[0][0] != 0: 63 | schedule = [(0, 0.0)] + schedule 64 | 65 | x = K.constant(schedule[0][1], dtype=K.floatx()) 66 | t = K.cast(t, K.floatx()) 67 | for i in range(len(schedule)): 68 | t_begin = schedule[i][0] 69 | x_begin = x 70 | if i != len(schedule) - 1: 71 | dx = schedule[i + 1][1] - schedule[i][1] 72 | dt = schedule[i + 1][0] - schedule[i][0] 73 | slope = 1.0 * dx / dt 74 | x = schedule[i][1] + slope * (t - t_begin) 75 | else: 76 | x = K.constant(schedule[i][1], dtype=K.floatx()) 77 | x = K.switch(t >= t_begin, x, x_begin) 78 | 79 | return x 80 | 81 | 82 | def search_layer(inputs, name, exclude_from=None): 83 | """根据inputs和name来搜索层 84 | 说明:inputs为某个层或某个层的输出;name为目标层的名字。 85 | 实现:根据inputs一直往上递归搜索,直到发现名字为name的层为止; 86 | 如果找不到,那就返回None。 87 | """ 88 | if exclude_from is None: 89 | exclude_from = set() 90 | 91 | if isinstance(inputs, keras.layers.Layer): 92 | layer = inputs 93 | else: 94 | layer = inputs._keras_history[0] 95 | 96 | if layer.name == name: 97 | return layer 98 | elif layer in exclude_from: 99 | return None 100 | else: 101 | exclude_from.add(layer) 102 | if isinstance(layer, keras.models.Model): 103 | model = layer 104 | for layer in model.layers: 105 | if layer.name == name: 106 | return layer 107 | inbound_layers = layer._inbound_nodes[0].inbound_layers 108 | if not isinstance(inbound_layers, list): 109 | inbound_layers = [inbound_layers] 110 | if len(inbound_layers) > 0: 111 | for layer in inbound_layers: 112 | layer = search_layer(layer, name, exclude_from) 113 | if layer is not None: 114 | return layer 115 | 116 | 117 | def sequence_masking(x, mask, mode=0, axis=None): 
118 | """为序列条件mask的函数 119 | mask: 形如(batch_size, seq_len)的0-1矩阵; 120 | mode: 如果是0,则直接乘以mask; 121 | 如果是1,则在padding部分减去一个大正数。 122 | axis: 序列所在轴,默认为1; 123 | """ 124 | if mask is None or mode not in [0, 1]: 125 | return x 126 | else: 127 | if axis is None: 128 | axis = 1 129 | if axis == -1: 130 | axis = K.ndim(x) - 1 131 | assert axis > 0, 'axis must be greater than 0' 132 | for _ in range(axis - 1): 133 | mask = K.expand_dims(mask, 1) 134 | for _ in range(K.ndim(x) - K.ndim(mask) - axis + 1): 135 | mask = K.expand_dims(mask, K.ndim(mask)) 136 | if mode == 0: 137 | return x * mask 138 | else: 139 | return x - (1 - mask) * 1e12 140 | 141 | 142 | def batch_gather(params, indices): 143 | """同tf旧版本的batch_gather 144 | """ 145 | if K.dtype(indices)[:3] != 'int': 146 | indices = K.cast(indices, 'int32') 147 | 148 | try: 149 | return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1) 150 | except Exception as e1: 151 | try: 152 | return tf.batch_gather(params, indices) 153 | except Exception as e2: 154 | raise ValueError('%s\n%s\n' % (e1.message, e2.message)) 155 | 156 | 157 | def pool1d( 158 | x, 159 | pool_size, 160 | strides=1, 161 | padding='valid', 162 | data_format=None, 163 | pool_mode='max' 164 | ): 165 | """向量序列的pool函数 166 | """ 167 | x = K.expand_dims(x, 1) 168 | x = K.pool2d( 169 | x, 170 | pool_size=(1, pool_size), 171 | strides=(1, strides), 172 | padding=padding, 173 | data_format=data_format, 174 | pool_mode=pool_mode 175 | ) 176 | return x[:, 0] 177 | 178 | 179 | def divisible_temporal_padding(x, n): 180 | """将一维向量序列右padding到长度能被n整除 181 | """ 182 | r_len = K.shape(x)[1] % n 183 | p_len = K.switch(r_len > 0, n - r_len, 0) 184 | return K.temporal_padding(x, (0, p_len)) 185 | 186 | 187 | def swish(x): 188 | """swish函数(这样封装过后才有 __name__ 属性) 189 | """ 190 | return tf.nn.swish(x) 191 | 192 | 193 | def leaky_relu(x, alpha=0.2): 194 | """leaky relu函数(这样封装过后才有 __name__ 属性) 195 | """ 196 | return tf.nn.leaky_relu(x, alpha=alpha) 197 | 198 | 199 | class Sinusoidal(keras.initializers.Initializer): 200 | """Sin-Cos位置向量初始化器 201 | 来自:https://arxiv.org/abs/1706.03762 202 | """ 203 | def __call__(self, shape, dtype=None): 204 | """Sin-Cos形式的位置向量 205 | """ 206 | vocab_size, depth = shape 207 | embeddings = np.zeros(shape) 208 | for pos in range(vocab_size): 209 | for i in range(depth // 2): 210 | theta = pos / np.power(10000, 2. 
* i / depth) 211 | embeddings[pos, 2 * i] = np.sin(theta) 212 | embeddings[pos, 2 * i + 1] = np.cos(theta) 213 | return embeddings 214 | 215 | 216 | def symbolic(f): 217 | """恒等装饰器(兼容旧版本keras用) 218 | """ 219 | return f 220 | 221 | 222 | def graph_mode_decorator(f, *args, **kwargs): 223 | """tf 2.1与之前版本的传参方式不一样,这里做个同步 224 | """ 225 | if tf.__version__ < '2.1': 226 | return _graph_mode_decorator(f, *args, **kwargs) 227 | else: 228 | return _graph_mode_decorator(f, args, kwargs) 229 | 230 | 231 | def recompute_grad(call): 232 | """重计算装饰器(用来装饰Keras层的call函数) 233 | 关于重计算,请参考:https://arxiv.org/abs/1604.06174 234 | """ 235 | if not do_recompute: 236 | return call 237 | 238 | def inner(self, inputs, **kwargs): 239 | """定义需要求梯度的函数以及重新定义求梯度过程 240 | (参考自官方自带的tf.recompute_grad函数) 241 | """ 242 | flat_inputs = nest.flatten(inputs) 243 | call_args = tf_inspect.getfullargspec(call).args 244 | for key in ['mask', 'training']: 245 | if key not in call_args and key in kwargs: 246 | del kwargs[key] 247 | 248 | def kernel_call(): 249 | """定义前向计算 250 | """ 251 | return call(self, inputs, **kwargs) 252 | 253 | def call_and_grad(*inputs): 254 | """定义前向计算和反向计算 255 | """ 256 | if is_tf_keras: 257 | with tape.stop_recording(): 258 | outputs = kernel_call() 259 | outputs = tf.identity(outputs) 260 | else: 261 | outputs = kernel_call() 262 | 263 | def grad_fn(doutputs, variables=None): 264 | watches = list(inputs) 265 | if variables is not None: 266 | watches += list(variables) 267 | with tf.GradientTape() as t: 268 | t.watch(watches) 269 | with tf.control_dependencies([doutputs]): 270 | outputs = kernel_call() 271 | grads = t.gradient( 272 | outputs, watches, output_gradients=[doutputs] 273 | ) 274 | del t 275 | return grads[:len(inputs)], grads[len(inputs):] 276 | 277 | return outputs, grad_fn 278 | 279 | if is_tf_keras: # 仅在tf >= 2.0下可用 280 | outputs, grad_fn = call_and_grad(*flat_inputs) 281 | flat_outputs = nest.flatten(outputs) 282 | 283 | def actual_grad_fn(*doutputs): 284 | grads = grad_fn(*doutputs, variables=self.trainable_weights) 285 | return grads[0] + grads[1] 286 | 287 | watches = flat_inputs + self.trainable_weights 288 | watches = [tf.convert_to_tensor(x) for x in watches] 289 | tape.record_operation( 290 | call.__name__, flat_outputs, watches, actual_grad_fn 291 | ) 292 | return outputs 293 | else: # keras + tf >= 1.14 均可用 294 | return graph_mode_decorator(call_and_grad, *flat_inputs) 295 | 296 | return inner 297 | 298 | 299 | # 给旧版本keras新增symbolic方法(装饰器), 300 | # 以便兼容optimizers.py中的代码 301 | K.symbolic = getattr(K, 'symbolic', None) or symbolic 302 | 303 | custom_objects = { 304 | 'gelu_erf': gelu_erf, 305 | 'gelu_tanh': gelu_tanh, 306 | 'gelu': gelu_erf, 307 | 'swish': swish, 308 | 'leaky_relu': leaky_relu, 309 | 'Sinusoidal': Sinusoidal, 310 | } 311 | 312 | keras.utils.get_custom_objects().update(custom_objects) 313 | -------------------------------------------------------------------------------- /rere/bert4keras/tokenizers.py: -------------------------------------------------------------------------------- 1 | #! 
-*- coding: utf-8 -*- 2 | # 分词函数 3 | 4 | import unicodedata, re 5 | from bert4keras.snippets import is_string, is_py2 6 | # from bert4keras.snippets import open 7 | from bert4keras.snippets import convert_to_unicode 8 | 9 | 10 | def load_vocab(dict_path, encoding='utf-8', simplified=False, startswith=None): 11 | """从bert的词典文件中读取词典 12 | """ 13 | token_dict = {} 14 | with open(dict_path, encoding=encoding) as reader: 15 | for line in reader: 16 | token = line.split() 17 | token = token[0] if token else line.strip() 18 | token_dict[token] = len(token_dict) 19 | 20 | if simplified: # 过滤冗余部分token 21 | new_token_dict, keep_tokens = {}, [] 22 | startswith = startswith or [] 23 | for t in startswith: 24 | new_token_dict[t] = len(new_token_dict) 25 | keep_tokens.append(token_dict[t]) 26 | 27 | for t, _ in sorted(token_dict.items(), key=lambda s: s[1]): 28 | if t not in new_token_dict: 29 | keep = True 30 | if len(t) > 1: 31 | for c in Tokenizer.stem(t): 32 | if ( 33 | Tokenizer._is_cjk_character(c) or 34 | Tokenizer._is_punctuation(c) 35 | ): 36 | keep = False 37 | break 38 | if keep: 39 | new_token_dict[t] = len(new_token_dict) 40 | keep_tokens.append(token_dict[t]) 41 | 42 | return new_token_dict, keep_tokens 43 | else: 44 | return token_dict 45 | 46 | 47 | def save_vocab(dict_path, token_dict, encoding='utf-8'): 48 | """将词典(比如精简过的)保存为文件 49 | """ 50 | with open(dict_path, 'w', encoding=encoding) as writer: 51 | for k, v in sorted(token_dict.items(), key=lambda s: s[1]): 52 | writer.write(k + '\n') 53 | 54 | 55 | class TokenizerBase(object): 56 | """分词器基类 57 | """ 58 | def __init__( 59 | self, 60 | token_start='[CLS]', 61 | token_end='[SEP]', 62 | pre_tokenize=None, 63 | token_translate=None 64 | ): 65 | """参数说明: 66 | pre_tokenize:外部传入的分词函数,用作对文本进行预分词。如果传入 67 | pre_tokenize,则先执行pre_tokenize(text),然后在它 68 | 的基础上执行原本的tokenize函数; 69 | token_translate:映射字典,主要用在tokenize之后,将某些特殊的token 70 | 替换为对应的token。 71 | """ 72 | self._token_pad = '[PAD]' 73 | self._token_unk = '[UNK]' 74 | self._token_mask = '[MASK]' 75 | self._token_start = token_start 76 | self._token_end = token_end 77 | self._pre_tokenize = pre_tokenize 78 | self._token_translate = token_translate or {} 79 | self._token_translate_inv = { 80 | v: k 81 | for k, v in self._token_translate.items() 82 | } 83 | 84 | def tokenize(self, text, maxlen=None): 85 | """分词函数 86 | """ 87 | tokens = [ 88 | self._token_translate.get(token) or token 89 | for token in self._tokenize(text) 90 | ] 91 | if self._token_start is not None: 92 | tokens.insert(0, self._token_start) 93 | if self._token_end is not None: 94 | tokens.append(self._token_end) 95 | 96 | if maxlen is not None: 97 | index = int(self._token_end is not None) + 1 98 | self.truncate_sequence(maxlen, tokens, None, -index) 99 | 100 | return tokens 101 | 102 | def token_to_id(self, token): 103 | """token转换为对应的id 104 | """ 105 | raise NotImplementedError 106 | 107 | def tokens_to_ids(self, tokens): 108 | """token序列转换为对应的id序列 109 | """ 110 | return [self.token_to_id(token) for token in tokens] 111 | 112 | def truncate_sequence( 113 | self, maxlen, first_sequence, second_sequence=None, pop_index=-1 114 | ): 115 | """截断总长度 116 | """ 117 | if second_sequence is None: 118 | second_sequence = [] 119 | 120 | while True: 121 | total_length = len(first_sequence) + len(second_sequence) 122 | if total_length <= maxlen: 123 | break 124 | elif len(first_sequence) > len(second_sequence): 125 | first_sequence.pop(pop_index) 126 | else: 127 | second_sequence.pop(pop_index) 128 | 129 | def encode( 130 | self, first_text, 
second_text=None, maxlen=None, pattern='S*E*E' 131 | ): 132 | """输出文本对应token id和segment id 133 | """ 134 | if is_string(first_text): 135 | first_tokens = self.tokenize(first_text) 136 | else: 137 | first_tokens = first_text 138 | 139 | if second_text is None: 140 | second_tokens = None 141 | elif is_string(second_text): 142 | if pattern == 'S*E*E': 143 | idx = int(bool(self._token_start)) 144 | second_tokens = self.tokenize(second_text)[idx:] 145 | elif pattern == 'S*ES*E': 146 | second_tokens = self.tokenize(second_text) 147 | else: 148 | second_tokens = second_text 149 | 150 | if maxlen is not None: 151 | self.truncate_sequence(maxlen, first_tokens, second_tokens, -2) 152 | 153 | first_token_ids = self.tokens_to_ids(first_tokens) 154 | first_segment_ids = [0] * len(first_token_ids) 155 | 156 | if second_text is not None: 157 | second_token_ids = self.tokens_to_ids(second_tokens) 158 | second_segment_ids = [1] * len(second_token_ids) 159 | first_token_ids.extend(second_token_ids) 160 | first_segment_ids.extend(second_segment_ids) 161 | 162 | return first_token_ids, first_segment_ids 163 | 164 | def id_to_token(self, i): 165 | """id序列为对应的token 166 | """ 167 | raise NotImplementedError 168 | 169 | def ids_to_tokens(self, ids): 170 | """id序列转换为对应的token序列 171 | """ 172 | return [self.id_to_token(i) for i in ids] 173 | 174 | def decode(self, ids): 175 | """转为可读文本 176 | """ 177 | raise NotImplementedError 178 | 179 | def _tokenize(self, text): 180 | """基本分词函数 181 | """ 182 | raise NotImplementedError 183 | 184 | 185 | class Tokenizer(TokenizerBase): 186 | """Bert原生分词器 187 | 纯Python实现,代码修改自keras_bert的tokenizer实现 188 | """ 189 | def __init__(self, token_dict, do_lower_case=False, **kwargs): 190 | super(Tokenizer, self).__init__(**kwargs) 191 | if is_string(token_dict): 192 | token_dict = load_vocab(token_dict) 193 | 194 | self._do_lower_case = do_lower_case 195 | self._token_dict = token_dict 196 | self._token_dict_inv = {v: k for k, v in token_dict.items()} 197 | self._vocab_size = len(token_dict) 198 | 199 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 200 | try: 201 | _token_id = token_dict[getattr(self, '_token_%s' % token)] 202 | setattr(self, '_token_%s_id' % token, _token_id) 203 | except: 204 | pass 205 | 206 | def token_to_id(self, token): 207 | """token转换为对应的id 208 | """ 209 | return self._token_dict.get(token, self._token_unk_id) 210 | 211 | def id_to_token(self, i): 212 | """id转换为对应的token 213 | """ 214 | return self._token_dict_inv[i] 215 | 216 | def decode(self, ids, tokens=None): 217 | """转为可读文本 218 | """ 219 | tokens = tokens or self.ids_to_tokens(ids) 220 | tokens = [token for token in tokens if not self._is_special(token)] 221 | 222 | text, flag = '', False 223 | for i, token in enumerate(tokens): 224 | if token[:2] == '##': 225 | text += token[2:] 226 | elif len(token) == 1 and self._is_cjk_character(token): 227 | text += token 228 | elif len(token) == 1 and self._is_punctuation(token): 229 | text += token 230 | text += ' ' 231 | elif i > 0 and self._is_cjk_character(text[-1]): 232 | text += token 233 | else: 234 | text += ' ' 235 | text += token 236 | 237 | text = re.sub(' +', ' ', text) 238 | text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text) 239 | punctuation = self._cjk_punctuation() + '+-/={(<[' 240 | punctuation_regex = '|'.join([re.escape(p) for p in punctuation]) 241 | punctuation_regex = '(%s) ' % punctuation_regex 242 | text = re.sub(punctuation_regex, '\\1', text) 243 | text = re.sub('(\d\.) 
(\d)', '\\1\\2', text) 244 | 245 | return text.strip() 246 | 247 | def _tokenize(self, text, pre_tokenize=True): 248 | """基本分词函数 249 | """ 250 | if self._do_lower_case: 251 | if is_py2: 252 | text = unicode(text) 253 | text = text.lower() 254 | text = unicodedata.normalize('NFD', text) 255 | text = ''.join([ 256 | ch for ch in text if unicodedata.category(ch) != 'Mn' 257 | ]) 258 | 259 | if pre_tokenize and self._pre_tokenize is not None: 260 | tokens = [] 261 | for token in self._pre_tokenize(text): 262 | if token in self._token_dict: 263 | tokens.append(token) 264 | else: 265 | tokens.extend(self._tokenize(token, False)) 266 | return tokens 267 | 268 | spaced = '' 269 | for ch in text: 270 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 271 | spaced += ' ' + ch + ' ' 272 | elif self._is_space(ch): 273 | spaced += ' ' 274 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 275 | continue 276 | else: 277 | spaced += ch 278 | 279 | tokens = [] 280 | for word in spaced.strip().split(): 281 | tokens.extend(self._word_piece_tokenize(word)) 282 | 283 | return tokens 284 | 285 | def _word_piece_tokenize(self, word): 286 | """word内分成subword 287 | """ 288 | if word in self._token_dict: 289 | return [word] 290 | 291 | tokens = [] 292 | start, stop = 0, 0 293 | while start < len(word): 294 | stop = len(word) 295 | while stop > start: 296 | sub = word[start:stop] 297 | if start > 0: 298 | sub = '##' + sub 299 | if sub in self._token_dict: 300 | break 301 | stop -= 1 302 | if start == stop: 303 | stop += 1 304 | tokens.append(sub) 305 | start = stop 306 | 307 | return tokens 308 | 309 | @staticmethod 310 | def stem(token): 311 | """获取token的“词干”(如果是##开头,则自动去掉##) 312 | """ 313 | if token[:2] == '##': 314 | return token[2:] 315 | else: 316 | return token 317 | 318 | @staticmethod 319 | def _is_space(ch): 320 | """空格类字符判断 321 | """ 322 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ 323 | unicodedata.category(ch) == 'Zs' 324 | 325 | @staticmethod 326 | def _is_punctuation(ch): 327 | """标点符号类字符判断(全/半角均在此内) 328 | 提醒:unicodedata.category这个函数在py2和py3下的 329 | 表现可能不一样,比如u'§'字符,在py2下的结果为'So', 330 | 在py3下的结果是'Po'。 331 | """ 332 | code = ord(ch) 333 | return 33 <= code <= 47 or \ 334 | 58 <= code <= 64 or \ 335 | 91 <= code <= 96 or \ 336 | 123 <= code <= 126 or \ 337 | unicodedata.category(ch).startswith('P') 338 | 339 | @staticmethod 340 | def _cjk_punctuation(): 341 | return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002' 342 | 343 | @staticmethod 344 | def _is_cjk_character(ch): 345 | """CJK类字符判断(包括中文字符也在此列) 346 | 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 347 | """ 348 | code = ord(ch) 349 | return 0x4E00 <= code <= 0x9FFF or \ 350 | 0x3400 <= code <= 0x4DBF or \ 351 | 0x20000 <= code <= 0x2A6DF or \ 352 | 0x2A700 <= code <= 0x2B73F or \ 353 | 0x2B740 <= code <= 0x2B81F or \ 354 | 0x2B820 <= code <= 0x2CEAF or \ 355 | 0xF900 <= code <= 0xFAFF or \ 356 | 0x2F800 <= code <= 0x2FA1F 357 | 358 | @staticmethod 359 | def _is_control(ch): 360 | """控制类字符判断 361 | """ 362 | return unicodedata.category(ch) in 
('Cc', 'Cf') 363 | 364 | @staticmethod 365 | def _is_special(ch): 366 | """判断是不是有特殊含义的符号 367 | """ 368 | return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') 369 | 370 | def rematch(self, text, tokens): 371 | """给出原始的text和tokenize后的tokens的映射关系 372 | """ 373 | if is_py2: 374 | text = unicode(text) 375 | 376 | if self._do_lower_case: 377 | text = text.lower() 378 | 379 | normalized_text, char_mapping = '', [] 380 | for i, ch in enumerate(text): 381 | if self._do_lower_case: 382 | ch = unicodedata.normalize('NFD', ch) 383 | ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn']) 384 | ch = ''.join([ 385 | c for c in ch 386 | if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c)) 387 | ]) 388 | normalized_text += ch 389 | char_mapping.extend([i] * len(ch)) 390 | 391 | text, token_mapping, offset = normalized_text, [], 0 392 | for token in tokens: 393 | if self._is_special(token): 394 | token_mapping.append([]) 395 | else: 396 | token = self.stem(token) 397 | start = text[offset:].index(token) + offset 398 | end = start + len(token) 399 | token_mapping.append(char_mapping[start:end]) 400 | offset = end 401 | 402 | return token_mapping 403 | 404 | 405 | class SpTokenizer(TokenizerBase): 406 | """基于SentencePiece模型的封装,使用上跟Tokenizer基本一致。 407 | """ 408 | def __init__(self, sp_model_path, **kwargs): 409 | super(SpTokenizer, self).__init__(**kwargs) 410 | import sentencepiece as spm 411 | self.sp_model = spm.SentencePieceProcessor() 412 | self.sp_model.Load(sp_model_path) 413 | self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id()) 414 | self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id()) 415 | self._vocab_size = self.sp_model.get_piece_size() 416 | 417 | for token in ['pad', 'unk', 'mask', 'start', 'end']: 418 | try: 419 | _token = getattr(self, '_token_%s' % token) 420 | _token_id = self.sp_model.piece_to_id(_token) 421 | setattr(self, '_token_%s_id' % token, _token_id) 422 | except: 423 | pass 424 | 425 | def token_to_id(self, token): 426 | """token转换为对应的id 427 | """ 428 | return self.sp_model.piece_to_id(token) 429 | 430 | def id_to_token(self, i): 431 | """id转换为对应的token 432 | """ 433 | if i < self._vocab_size: 434 | return self.sp_model.id_to_piece(i) 435 | else: 436 | return '' 437 | 438 | def decode(self, ids): 439 | """转为可读文本 440 | """ 441 | tokens = [ 442 | self._token_translate_inv.get(token) or token 443 | for token in self.ids_to_tokens(ids) 444 | ] 445 | text = self.sp_model.decode_pieces(tokens) 446 | return convert_to_unicode(text) 447 | 448 | def _tokenize(self, text): 449 | """基本分词函数 450 | """ 451 | if self._pre_tokenize is not None: 452 | text = ' '.join(self._pre_tokenize(text)) 453 | 454 | tokens = self.sp_model.encode_as_pieces(text) 455 | return tokens 456 | 457 | def _is_special(self, i): 458 | """判断是不是有特殊含义的符号 459 | """ 460 | return self.sp_model.is_control(i) or \ 461 | self.sp_model.is_unknown(i) or \ 462 | self.sp_model.is_unused(i) 463 | 464 | def _is_decodable(self, i): 465 | """判断是否应该被解码输出 466 | """ 467 | return (i < self._vocab_size) and not self._is_special(i) 468 | -------------------------------------------------------------------------------- /rere/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | dic=dict() 3 | dic['NYT10-HRL']={'thre_rc':0.4,'thre_ee':0.2} 4 | dic['NYT11-HRL']={'thre_rc':0.56,'thre_ee':0.55} 5 | dic['NYT21-HRL']={'thre_rc':0.56,'thre_ee':0.55} 6 | dic['ske2019']={'thre_rc':0.5,'thre_ee':0.3} 7 | # with open('config.json','w') as f: 8 | 
print(dic) 9 | with open('config.json','w') as f: 10 | json.dump(dic,f) -------------------------------------------------------------------------------- /rere/extraction.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, re, utils, json 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | sys.path.append('./BERT_tf2') 4 | sys.path.append('./') 5 | # import torch 6 | import bert_tools as bt 7 | import tensorflow as tf 8 | from tensorflow.keras.models import load_model 9 | from tensorflow.keras.layers import * 10 | from tensorflow.keras.callbacks import * 11 | from bert4keras.backend import keras, K 12 | from collections import defaultdict 13 | import numpy as np 14 | from tqdm import tqdm 15 | import json 16 | 17 | dname = sys.argv[1] 18 | datadir = '../data/' 19 | trains = utils.LoadJsons(datadir+dname+'/new_train.json') 20 | valids = utils.LoadJsons(datadir+dname+'/new_valid.json') 21 | tests = utils.LoadJsons(datadir+dname+'/new_test.json') 22 | thre_rc=0.2 23 | thre_ee=0.4 24 | try: 25 | with open('config.json','r') as f: 26 | thres=json.load(f) 27 | thre_rc=thres[f'{dname}']['thre_rc'] 28 | thre_ee=thres[f'{dname}']['thre_ee'] 29 | except:pass 30 | if not os.path.isdir('model/'+dname): os.makedirs('model/'+dname) 31 | 32 | def wdir(x): return 'model/'+dname+'/BCE_'+x 33 | 34 | rels = utils.TokenList(wdir('rels.txt'), 1, trains, lambda z:[x['label'] for x in z['relationMentions']]) 35 | print('rels:', rels.get_num()) 36 | 37 | if not dname == 'ske2019': 38 | bt.switch_to_en() 39 | 40 | 41 | LOSS='BCE' 42 | maxlen = 128 43 | 44 | def cpu_mid_loss(y_true,y_pred,mid=0,pi=0.1,**kwargs): 45 | eps = 1e-6 46 | y_true=K.cast(y_true, 'float32') 47 | pos = K.sum(y_true * y_pred, 1) / K.maximum(eps, K.sum(y_true, 1)) 48 | pos = - K.log(pos + eps) 49 | neg = K.sum((1-y_true) * y_pred, 1) / K.maximum(eps, K.sum(1-y_true, 1)) 50 | neg = K.abs(neg-mid) 51 | neg = - K.log(1 - neg + eps) 52 | return K.mean(pi*pos + neg) 53 | 54 | def dgcnn_block(x, dim, dila=1): 55 | y1 = Conv1D(dim, 3, padding='same', dilation_rate=dila)(x) 56 | y2 = Conv1D(dim, 3, padding='same', dilation_rate=dila, activation='sigmoid')(x) 57 | yy = multiply([y1, y2]) 58 | if yy.shape[-1] == x.shape[-1]: yy = add([yy, x]) 59 | return yy 60 | 61 | def neg_log_mean_loss(y_true, y_pred): 62 | eps = 1e-6 63 | pos = - K.sum(y_true * K.log(y_pred+eps), 1) / K.maximum(eps, K.sum(y_true, 1)) 64 | neg = K.sum((1-y_true) * y_pred, 1) / K.maximum(eps, K.sum(1-y_true, 1)) 65 | neg = - K.log(1 - neg + eps) 66 | return K.mean(pos + neg * 15) 67 | 68 | def FindValuePos(sent, value): 69 | ret = []; 70 | value = value.replace(' ', '').lower() 71 | if value == '': return ret 72 | ss = [x.replace(' ', '').lower() for x in sent] 73 | for k, v in enumerate(ss): 74 | if not value.startswith(v): continue 75 | vi = 0 76 | for j in range(k, len(ss)): 77 | if value[vi:].startswith(ss[j]): 78 | vi += len(ss[j]) 79 | if vi == len(value): 80 | ret.append( (k, j+1) ) 81 | else: break 82 | return ret 83 | 84 | def GetTopSpans(tokens, rr, K=40): 85 | cands = defaultdict(float) 86 | start_indexes = sorted(enumerate(rr[:,0]), key=lambda x:-x[1])[:K] 87 | end_indexes = sorted(enumerate(rr[:,1]), key=lambda x:-x[1])[:K] 88 | for start_index, start_score in start_indexes: 89 | if start_score < 0.1: continue 90 | if start_index >= len(tokens): continue 91 | for end_index, end_score in end_indexes: 92 | if end_score < 0.1: continue 93 | if end_index >= len(tokens): continue 94 | if end_index < start_index: continue 
95 | length = end_index - start_index + 1 96 | if length > 40: continue 97 | ans = ''.join(tokens[start_index:end_index+1]).strip() 98 | if '》' in ans: continue 99 | if '、' in ans and len(ans.split('、')) > 2 and ',' not in ans and ',' not in ans: 100 | aas = ans.split('、') 101 | for aa in aas: cands[aa.strip()] += start_score * end_score / len(aas) 102 | continue 103 | cands[ans] += start_score * end_score 104 | 105 | cand_list = sorted(cands.items(), key=lambda x:len(x[0])) 106 | removes = set() 107 | contains = {} 108 | for i, (x, y) in enumerate(cand_list): 109 | for j, (xx, yy) in enumerate(cand_list[:i]): 110 | if xx in x and len(xx) < len(x): 111 | contains.setdefault(x, []).append(xx) 112 | 113 | for i, (x, y) in enumerate(cand_list): 114 | sump = sum(cands[z] for z in contains.get(x, []) if z not in removes) 115 | suml = sum(len(z) for z in contains.get(x, []) if z not in removes) 116 | if suml > 0: sump = sump * min(1, len(x) / suml) 117 | if sump > y: removes.add(x) 118 | else: 119 | for z in contains.get(x, []): removes.add(z) 120 | 121 | ret = [x for x in cand_list if x[0] not in removes] 122 | ret.sort(key=lambda x:-x[1]) 123 | return ret[:K] 124 | 125 | def GenTriple(p, x, y): 126 | return {'label':p, 'em1Text':x, 'em2Text':y} 127 | 128 | 129 | class RCModel: 130 | def __init__(self): 131 | if not dname == 'ske2019': 132 | self.bert = load_model('./bert_uncased.h5') 133 | else: 134 | self.bert = load_model('./chinese_roberta_wwm_ext.h5') 135 | xx = Lambda(lambda x:x[:,0])(self.bert.output) 136 | pos = Dense(rels.get_num(), activation='sigmoid')(xx) 137 | self.model = tf.keras.models.Model(inputs=self.bert.input, outputs=pos) 138 | bt.lock_transformer_layers(self.bert, 8) 139 | self.model_ready = False 140 | 141 | def gen_golden_y(self, datas): 142 | for dd in datas: 143 | dd['rc_obj'] = list(set(x['label'] for x in dd.get('relationMentions', []))) 144 | 145 | def make_model_data(self, datas): 146 | self.gen_golden_y(datas) 147 | for dd in tqdm(datas, desc='tokenize'): 148 | s = dd['sentText'] 149 | tokens = bt.tokenizer.tokenize(s, maxlen=maxlen) 150 | dd['tokens'] = tokens 151 | N = len(datas) 152 | X = [np.zeros((N, maxlen), dtype='int32'), np.zeros((N, maxlen), dtype='int32')] 153 | Y = np.zeros((N, rels.get_num())) 154 | for i, dd in enumerate(tqdm(datas, desc='gen XY', total=N)): 155 | tokens = dd['tokens'] 156 | X[0][i][:len(tokens)] = bt.tokenizer.tokens_to_ids(tokens) 157 | for x in dd['rc_obj']: Y[i][rels.get_id(x)] = 1 158 | return X, Y 159 | 160 | def load_model(self): 161 | self.model.load_weights(wdir('rc.h5')) 162 | self.model.compile('adam', 'binary_crossentropy', metrics=['accuracy']) 163 | self.model_ready = True 164 | 165 | def train(self, datas, batch_size=32, epochs=10): 166 | self.X, self.Y = self.make_model_data(datas) 167 | self.optimizer = bt.get_suggested_optimizer(5e-5, len(datas) * epochs // batch_size) 168 | if LOSS=='BCE': 169 | self.model.compile(self.optimizer, 'binary_crossentropy', metrics=['accuracy']) 170 | elif LOSS=='MID': 171 | from functools import partial 172 | PI_RC = np.sum(self.Y)/np.prod(self.Y.shape) 173 | FN_RATIO = 0.05 174 | mid_loss = partial(cpu_mid_loss, mid=PI_RC*(1+FN_RATIO),pi=0.04) 175 | self.model.compile(self.optimizer, mid_loss, metrics=['accuracy']) 176 | self.cb_mcb = keras.callbacks.ModelCheckpoint(wdir('rc.h5'), save_weights_only=True, verbose=1) 177 | self.model.fit(self.X, self.Y, batch_size, epochs=epochs, shuffle=True, 178 | validation_split=0.01, callbacks=[self.cb_mcb]) 179 | self.model_ready = True 180 | 181 
| def get_output(self, datas, pred, threshold=0.5): 182 | for dd, pp in zip(datas, pred): 183 | dd['rc_pred'] = list(rels.get_token(i) for i, sc in enumerate(pp) if sc > threshold) 184 | 185 | def evaluate(self, datas): 186 | ccnt, gcnt, ecnt = 0, 0, 0 187 | for dd in datas: 188 | plabels = set(dd['rc_pred']) 189 | ecnt += len(plabels) 190 | gcnt += len(set(dd['rc_obj'])) 191 | ccnt += len(plabels & set(dd['rc_obj'])) 192 | return utils.CalcF1(ccnt, ecnt, gcnt) 193 | 194 | def predict(self, datas, threshold=0.5, ofile=None): 195 | if not self.model_ready: self.load_model() 196 | self.vX, self.vY = self.make_model_data(datas) 197 | pred = self.model.predict(self.vX, batch_size=64, verbose=1) 198 | self.get_output(datas, pred, threshold) 199 | f1str = self.evaluate(datas) 200 | if ofile is not None: 201 | utils.SaveList(map(lambda x:json.dumps(x, ensure_ascii=False), datas), wdir(ofile)) 202 | print(f1str) 203 | return f1str 204 | 205 | class EEModel: 206 | def __init__(self): 207 | if not dname == 'ske2019': 208 | self.bert = load_model('./bert_uncased.h5') 209 | else: 210 | self.bert = load_model('./chinese_roberta_wwm_ext.h5') 211 | pos = Dense(4, activation='sigmoid')(self.bert.output) 212 | self.model = tf.keras.models.Model(inputs=self.bert.input, outputs=pos) 213 | bt.lock_transformer_layers(self.bert, 3) 214 | self.model_ready = False 215 | 216 | def make_model_data(self, datas): 217 | if 'tokens' not in datas[0]: 218 | for dd in tqdm(datas, desc='tokenize'): 219 | s = dd['sentText'] 220 | tokens = bt.tokenizer.tokenize(s, maxlen=maxlen) 221 | dd['tokens'] = tokens 222 | N = 0 223 | for dd in tqdm(datas, desc='matching'): 224 | otokens = bt.restore_token_list(dd['sentText'], dd['tokens']) 225 | dd['otokens'] = otokens 226 | ys = {} 227 | if 'rc_pred' in dd: 228 | plist = dd['rc_pred'] 229 | else: 230 | for x in dd.get('relationMentions', []): 231 | ys.setdefault(x['label'], []).append( (x['em1Text'], x['em2Text']) ) 232 | plist = sorted(ys.keys()) 233 | yys = [] 234 | for pp in plist: 235 | spos, opos = [], [] 236 | for s, o in ys.get(pp, []): 237 | ss, oo = FindValuePos(otokens, s), FindValuePos(otokens, o) 238 | if len(ss) == 0 and len(oo) == 0: continue 239 | spos.extend(ss) 240 | opos.extend(oo) 241 | yys.append( {'pp':pp, 'spos':spos, 'opos':opos} ) 242 | dd['ee_obj'] = yys 243 | N += len(yys) 244 | X = [np.zeros((N, maxlen), dtype='int32'), np.zeros((N, maxlen), dtype='int8')] 245 | Y = np.zeros((N, maxlen, 4), dtype='int8') 246 | ii = 0 247 | for dd in tqdm(datas, desc='gen EE XY'): 248 | tokens = dd['tokens'] 249 | for item in dd['ee_obj']: 250 | pp, spos, opos = item['pp'], item['spos'], item['opos'] 251 | first = bt.tokenizer.tokenize(pp) 252 | offset = len(first) 253 | item['offset'] = offset 254 | tts = (first + tokens[1:])[:maxlen] 255 | X[0][ii][:len(tts)] = bt.tokenizer.tokens_to_ids(tts) 256 | X[1][ii][offset:offset+len(tokens)-1] = 1 257 | for u, v in spos: 258 | try: 259 | Y[ii][offset+u,0] = 1 260 | Y[ii][offset+v-1,1] = 1 261 | except: pass 262 | for u, v in opos: 263 | try: 264 | Y[ii][offset+u,2] = 1 265 | Y[ii][offset+v-1,3] = 1 266 | except: pass 267 | ii += 1 268 | return X, Y 269 | 270 | def train(self, datas, batch_size=32, epochs=10): 271 | self.X, self.Y = self.make_model_data(datas) 272 | self.optimizer = bt.get_suggested_optimizer(5e-5, len(datas) * epochs // batch_size) 273 | self.model.compile(self.optimizer, 'binary_crossentropy', metrics=['accuracy']) 274 | self.cb_mcb = keras.callbacks.ModelCheckpoint(wdir('ee.h5'), save_weights_only=True, 
verbose=1) 275 | self.model.fit(self.X, self.Y, batch_size, epochs=epochs, shuffle=True, 276 | validation_split=0.01, callbacks=[self.cb_mcb]) 277 | self.model_ready = True 278 | 279 | def load_model(self): 280 | self.model.load_weights(wdir('ee.h5')) 281 | self.model.compile('adam', 'binary_crossentropy', metrics=['accuracy']) 282 | self.model_ready = True 283 | 284 | def get_output(self, datas, pred, threshold=0.5): 285 | ii = 0 286 | for dd in datas: 287 | rtriples = [] 288 | for item in dd['ee_obj']: 289 | predicate, offset = item['pp'], item['offset'] 290 | rr = pred[ii]; ii += 1 291 | subs = GetTopSpans(dd['otokens'], rr[offset:,:2]) 292 | objs = GetTopSpans(dd['otokens'], rr[offset:,2:]) 293 | 294 | vv1 = [x for x,y in subs if y >= 0.1] 295 | vv2 = [x for x,y in objs if y >= 0.1] 296 | 297 | subv = {x:y for x,y in subs} 298 | objv = {x:y for x,y in objs} 299 | 300 | #mats = None 301 | #if len(vv1) * len(vv2) >= 4: 302 | # sent = ''.join(data[2]) 303 | # mats = set(Match(sent, vv1, vv2)) 304 | 305 | for sv1, sv2 in [(sv1, sv2) for sv1 in vv1 for sv2 in vv2] : 306 | if sv1 == sv2: continue 307 | score = min(subv[sv1], objv[sv2]) 308 | #if mats is not None and (sv1, sv2) not in mats: score -= 0.5 309 | if score < threshold: continue 310 | rtriples.append( GenTriple(predicate, sv1, sv2) ) 311 | 312 | dd['ee_pred'] = rtriples 313 | # assert '' not in dd['otokens'] 314 | 315 | def evaluate(self, datas): 316 | ccnt, gcnt, ecnt = 0, 0, 0 317 | for dd in datas: 318 | golden = set(); predict = set() 319 | for x in dd['relationMentions']: 320 | ss = '|'.join([x[nn] for nn in ['label', 'em1Text', 'em2Text']]) 321 | golden.add(ss) 322 | for x in dd['ee_pred']: 323 | ss = '|'.join([x[nn] for nn in ['label', 'em1Text', 'em2Text']]) 324 | predict.add(ss) 325 | ecnt += len(predict) 326 | gcnt += len(golden) 327 | ccnt += len(predict & golden) 328 | return utils.CalcF1(ccnt, ecnt, gcnt) 329 | 330 | def predict(self, datas, threshold=0.5, ofile=None): 331 | ffout=open("ans.txt",'a') 332 | if not self.model_ready: self.load_model() 333 | self.vX, self.vY = self.make_model_data(datas) 334 | pred = self.model.predict(self.vX, batch_size=64, verbose=1) 335 | self.get_output(datas, pred, threshold=threshold) 336 | if ofile is not None: 337 | utils.SaveList(map(lambda x:json.dumps(x, ensure_ascii=False), datas), wdir(ofile)) 338 | f1str = self.evaluate(datas) 339 | ffout.write(f1str) 340 | print(f1str) 341 | ffout.close() 342 | return f1str 343 | 344 | if __name__ == '__main__': 345 | rc = RCModel() 346 | if 'train' in sys.argv: 347 | rc.train(trains, batch_size=32, epochs=20) 348 | rc.predict(tests, threshold=thre_rc, ofile='valid_rc.json') 349 | tests = utils.LoadJsons(wdir('valid_rc.json')) 350 | ee = EEModel() 351 | if 'train' in sys.argv: 352 | ee.train(trains, batch_size=32, epochs=20) 353 | ee.predict(tests, threshold=thre_ee, ofile='valid_ee.json') 354 | print('done') -------------------------------------------------------------------------------- /rere/model/readme.md: -------------------------------------------------------------------------------- 1 | put pretrained models here. 
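The pretrained .h5 files used by rere/extraction.py (bert_uncased.h5 and chinese_roberta_wwm_ext.h5) are, per rere/readme.md, generated by bert-to-h5.py from the checkpoints under rere/tfhub/. The snippet below is only a minimal sketch of that conversion, assuming bert-to-h5.py relies on the bundled bert4keras package; the config/checkpoint paths are illustrative and not taken from the script itself.

# Hedged sketch: bundled bert4keras assumed; the paths below are illustrative.
from bert4keras.models import build_transformer_model

cfg = 'tfhub/chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
ckpt = 'tfhub/chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'

# Build a Keras BERT model from the TensorFlow checkpoint and save it as a single .h5 file,
# matching the load_model('./chinese_roberta_wwm_ext.h5') call in extraction.py.
bert = build_transformer_model(config_path=cfg, checkpoint_path=ckpt, model='bert')
bert.save('chinese_roberta_wwm_ext.h5')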
-------------------------------------------------------------------------------- /rere/readme.md: -------------------------------------------------------------------------------- 1 | The files bert_uncased.h5 and chinese_roberta_wwm_ext.h5 are required in \rere_reproduce\rere. 2 | They can be generated from the pretrained checkpoints with bert-to-h5.py. 3 | -------------------------------------------------------------------------------- /rere/tfhub/chinese_roberta_wwm_ext_L-12_H-768_A-12/readme.md: -------------------------------------------------------------------------------- 1 | The pretrained model can be found in -------------------------------------------------------------------------------- /rere/tfhub/uncased_L-12_H-768_A-12/readme.md: -------------------------------------------------------------------------------- 1 | The model can be downloaded from https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-12_H-768_A-12.zip -------------------------------------------------------------------------------- /rere/utils.py: -------------------------------------------------------------------------------- 1 | # coding = utf-8 2 | 3 | import os, re, sys, random, urllib.parse, json 4 | from collections import defaultdict 5 | 6 | def WriteLine(fout, lst): 7 | fout.write('\t'.join([str(x) for x in lst]) + '\n') 8 | 9 | def RM(patt, sr): 10 | mat = re.search(patt, sr, re.DOTALL | re.MULTILINE) 11 | return mat.group(1) if mat else '' 12 | 13 | try: import requests 14 | except: pass 15 | def GetPage(url, cookie='', proxy='', timeout=5): 16 | try: 17 | headers = {'User-Agent':'Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.90 Safari/537.36'} 18 | if cookie != '': headers['cookie'] = cookie 19 | if proxy != '': 20 | proxies = {'http': proxy, 'https': proxy} 21 | resp = requests.get(url, headers=headers, proxies=proxies, timeout=timeout) 22 | else: resp = requests.get(url, headers=headers, timeout=timeout) 23 | content = resp.content 24 | try: 25 | import chardet 26 | charset = chardet.detect(content).get('encoding','utf-8') 27 | if charset.lower().startswith('gb'): charset = 'gbk' 28 | content = content.decode(charset, errors='replace') 29 | except: 30 | headc = content[:min([3000,len(content)])].decode(errors='ignore') 31 | charset = RM('charset="?([-a-zA-Z0-9]+)', headc) 32 | if charset == '': charset = 'utf-8' 33 | content = content.decode(charset, errors='replace') 34 | except Exception as e: 35 | print(e) 36 | content = '' 37 | return content 38 | 39 | def GetJson(url, cookie='', proxy='', timeout=5.0): 40 | try: 41 | headers = {'User-Agent':'Mozilla/5.0 (Windows NT 6.3; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/42.0.2311.90 Safari/537.36'} 42 | if cookie != '': headers['cookie'] = cookie 43 | if proxy != '': 44 | proxies = {'http': proxy, 'https': proxy} 45 | resp = requests.get(url, headers=headers, proxies=proxies, timeout=timeout) 46 | else: resp = requests.get(url, headers=headers, timeout=timeout) 47 | return resp.json() 48 | except Exception as e: 49 | print(e) 50 | content = {} 51 | return content 52 | 53 | def FindAllHrefs(url, content=None, regex=''): 54 | ret = set() 55 | if content == None: content = GetPage(url) 56 | patt = re.compile('href="?([a-zA-Z0-9-_:/.%]+)') 57 | for xx in re.findall(patt, content): 58 | ret.add( urllib.parse.urljoin(url, xx) ) 59 | if regex != '': ret = (x for x in ret if re.match(regex, x)) 60 | return list(ret) 61 | 62 | def Translate(txt): 63 | postdata = {'from': 'en', 'to': 'zh', 'transtype': 'realtime', 'query': txt} 64 | url =
"http://fanyi.baidu.com/v2transapi" 65 | try: 66 | resp = requests.post(url, data=postdata, 67 | headers={'Referer': 'http://fanyi.baidu.com/'}) 68 | ret = resp.json() 69 | ret = ret['trans_result']['data'][0]['dst'] 70 | except Exception as e: 71 | print(e) 72 | ret = '' 73 | return ret 74 | 75 | def IsChsStr(z): 76 | return re.search('^[\u4e00-\u9fa5]+$', z) is not None 77 | 78 | def FreqDict2List(dt): 79 | return sorted(dt.items(), key=lambda d:d[-1], reverse=True) 80 | 81 | def SelectRowsbyCol(fn, ofn, st, num = 0): 82 | with open(fn, encoding = "utf-8") as fin: 83 | with open(ofn, "w", encoding = "utf-8") as fout: 84 | for line in (ll for ll in fin.read().split('\n') if ll != ""): 85 | if line.split('\t')[num] in st: 86 | fout.write(line + '\n') 87 | 88 | def MergeFiles(dir, objfile, regstr = ".*"): 89 | with open(objfile, "w", encoding = "utf-8") as fout: 90 | for file in os.listdir(dir): 91 | if re.match(regstr, file): 92 | with open(os.path.join(dir, file), encoding = "utf-8") as filein: 93 | fout.write(filein.read()) 94 | 95 | def JoinFiles(fnx, fny, ofn): 96 | with open(fnx, encoding = "utf-8") as fin: 97 | lx = [vv for vv in fin.read().split('\n') if vv != ""] 98 | with open(fny, encoding = "utf-8") as fin: 99 | ly = [vv for vv in fin.read().split('\n') if vv != ""] 100 | with open(ofn, "w", encoding = "utf-8") as fout: 101 | for i in range(min(len(lx), len(ly))): 102 | fout.write(lx[i] + "\t" + ly[i] + "\n") 103 | 104 | 105 | def RemoveDupRows(file, fobj='*'): 106 | st = set() 107 | if fobj == '*': fobj = file 108 | with open(file, encoding = "utf-8") as fin: 109 | for line in fin.read().split('\n'): 110 | if line == "": continue 111 | st.add(line) 112 | with open(fobj, "w", encoding = "utf-8") as fout: 113 | for line in st: 114 | fout.write(line + '\n') 115 | 116 | def LoadCSV(fn): 117 | ret = [] 118 | with open(fn, encoding='utf-8') as fin: 119 | for line in fin: 120 | lln = line.rstrip('\r\n').split('\t') 121 | ret.append(lln) 122 | return ret 123 | 124 | def LoadCSVg(fn): 125 | with open(fn, encoding='utf-8') as fin: 126 | for line in fin: 127 | lln = line.rstrip('\r\n').split('\t') 128 | yield lln 129 | 130 | def SaveCSV(csv, fn): 131 | with open(fn, 'w', encoding='utf-8') as fout: 132 | for x in csv: 133 | WriteLine(fout, x) 134 | 135 | def SplitTables(fn, limit=3): 136 | rst = set() 137 | with open(fn, encoding='utf-8') as fin: 138 | for line in fin: 139 | lln = line.rstrip('\r\n').split('\t') 140 | rst.add(len(lln)) 141 | if len(rst) > limit: 142 | print('%d tables, exceed limit %d' % (len(rst), limit)) 143 | return 144 | for ii in rst: 145 | print('%d columns' % ii) 146 | with open(fn.replace('.txt', '') + '.split.%d.txt' % ii, 'w', encoding='utf-8') as fout: 147 | with open(fn, encoding='utf-8') as fin: 148 | for line in fin: 149 | lln = line.rstrip('\r\n').split('\t') 150 | if len(lln) == ii: 151 | fout.write(line) 152 | 153 | def LoadSet(fn): 154 | with open(fn, encoding="utf-8") as fin: 155 | st = set(ll for ll in fin.read().split('\n') if ll != "") 156 | return st 157 | 158 | def LoadList(fn): 159 | with open(fn, encoding="utf-8") as fin: 160 | st = list(ll for ll in fin.read().split('\n') if ll != "") 161 | return st 162 | 163 | def LoadJsonsg(fn): return map(json.loads, LoadListg(fn)) 164 | def LoadJsons(fn): return list(LoadJsonsg(fn)) 165 | 166 | def LoadListg(fn): 167 | with open(fn, encoding="utf-8") as fin: 168 | for ll in fin: 169 | ll = ll.strip() 170 | if ll != '': yield ll 171 | 172 | def LoadDict(fn, func=str): 173 | dict = {} 174 | with open(fn, 
encoding = "utf-8") as fin: 175 | for lv in (ll.split('\t', 1) for ll in fin.read().split('\n') if ll != ""): 176 | dict[lv[0]] = func(lv[1]) 177 | return dict 178 | 179 | def SaveDict(dict, ofn, output0 = True): 180 | with open(ofn, "w", encoding = "utf-8") as fout: 181 | for k in dict.keys(): 182 | if output0 or dict[k] != 0: 183 | fout.write(str(k) + "\t" + str(dict[k]) + "\n") 184 | 185 | def SaveList(st, ofn): 186 | with open(ofn, "w", encoding = "utf-8") as fout: 187 | for k in st: 188 | fout.write(str(k) + "\n") 189 | 190 | def ListDirFiles(dir, filter=None): 191 | if filter is None: 192 | return [os.path.join(dir, x) for x in os.listdir(dir)] 193 | return [os.path.join(dir, x) for x in os.listdir(dir) if filter(x)] 194 | 195 | def ProcessDir(dir, func, param): 196 | for file in os.listdir(dir): 197 | print(file) 198 | func(os.path.join(dir, file), param) 199 | 200 | def GetLines(fn): 201 | with open(fn, encoding = "utf-8", errors = 'ignore') as fin: 202 | lines = list(map(str.strip, fin.readlines())) 203 | return lines 204 | 205 | 206 | def SortRows(file, fobj, cid, type=int, rev = True): 207 | lines = LoadCSV(file) 208 | dat = [] 209 | for dv in lines: 210 | if len(dv) <= cid: continue 211 | dat.append((type(dv[cid]), dv)) 212 | with open(fobj, "w", encoding = "utf-8") as fout: 213 | for dd in sorted(dat, reverse = rev): 214 | fout.write('\t'.join(dd[1]) + '\n') 215 | 216 | def SampleRows(file, fobj, num): 217 | zz = list(open(file, encoding='utf-8')) 218 | num = min([num, len(zz)]) 219 | zz = random.sample(zz, num) 220 | with open(fobj, 'w', encoding='utf-8') as fout: 221 | for xx in zz: fout.write(xx) 222 | 223 | def SetProduct(file1, file2, fobj): 224 | l1, l2 = GetLines(file1), GetLines(file2) 225 | with open(fobj, 'w', encoding='utf-8') as fout: 226 | for z1 in l1: 227 | for z2 in l2: 228 | fout.write(z1 + z2 + '\n') 229 | 230 | class TokenList: 231 | def __init__(self, file, low_freq=2, source=None, func=None, save_low_freq=2, special_marks=[]): 232 | if not os.path.exists(file): 233 | tdict = defaultdict(int) 234 | for i, xx in enumerate(special_marks): tdict[xx] = 100000000 - i 235 | for xx in source: 236 | for token in func(xx): tdict[token] += 1 237 | tokens = FreqDict2List(tdict) 238 | tokens = [x for x in tokens if x[1] >= save_low_freq] 239 | SaveCSV(tokens, file) 240 | self.id2t = ['', ''] + \ 241 | [x for x,y in LoadCSV(file) if float(y) >= low_freq] 242 | self.t2id = {v:k for k,v in enumerate(self.id2t)} 243 | def get_id(self, token): return self.t2id.get(token, 1) 244 | def get_token(self, ii): return self.id2t[ii] 245 | def get_num(self): return len(self.id2t) 246 | 247 | def CalcF1(correct, output, golden): 248 | prec = correct / max(output, 1); reca = correct / max(golden, 1); 249 | f1 = 2 * prec * reca / max(1e-9, prec + reca) 250 | pstr = 'Prec: %.4f %d/%d, Reca: %.4f %d/%d, F1: %.4f' % (prec, correct, output, reca, correct, golden, f1) 251 | return pstr 252 | 253 | def Upgradeljqpy(url=None): 254 | if url is None: url = 'http://gdm.fudan.edu.cn/files1/ljq/ljqpy.py' 255 | dirs = [dir for dir in reversed(sys.path) if os.path.isdir(dir) and 'ljqpy.py' in os.listdir(dir)] 256 | if len(dirs) == 0: raise Exception("package directory no found") 257 | dir = dirs[0] 258 | print('downloading ljqpy.py from %s to %s' % (url, dir)) 259 | resp = requests.get(url) 260 | if b'Upgradeljqpy' not in resp.content: raise Exception('bad file') 261 | with open(os.path.join(dir, 'ljqpy.py'), 'wb') as fout: 262 | fout.write(resp.content) 263 | print('success') 264 | 265 | def 
sql(cmd=''): 266 | if cmd == '': cmd = input("> ") 267 | cts = [x for x in cmd.strip().lower()] 268 | instr = False 269 | for i in range(len(cts)): 270 | if cts[i] == '"' and cts[i-1] != '\\': instr = not instr 271 | if cts[i] == ' ' and instr: cts[i] = " " 272 | cmds = "".join(cts).split(' ') 273 | keyw = { 'select', 'from', 'to', 'where' } 274 | ct, kn = {}, '' 275 | for xx in cmds: 276 | if xx in keyw: kn = xx 277 | else: ct[kn] = ct.get(kn, "") + " " + xx 278 | 279 | for xx in ct.keys(): 280 | ct[xx] = ct[xx].replace(" ", " ").strip() 281 | 282 | if ct.get('where', "") == "": ct['where'] = 'True' 283 | 284 | if os.path.isdir(ct['from']): fl = [os.path.join(ct['from'], x) for x in os.listdir(ct['from'])] 285 | else: fl = ct['from'].split('+') 286 | 287 | if ct.get('to', "") == "": ct['to'] = 'temp.txt' 288 | 289 | for xx in ct.keys(): 290 | print(xx + " : " + ct[xx]) 291 | 292 | total = 0 293 | with open(ct['to'], 'w', encoding = 'utf-8') as fout: 294 | for fn in fl: 295 | print('selecting ' + fn) 296 | for xx in open(fn, encoding = 'utf-8'): 297 | x = xx.rstrip('\r\n').split('\t') 298 | if eval(ct['where']): 299 | if ct['select'] == '*': res = "\t".join(x) + '\n' 300 | else: res = "\t".join(eval('[' + ct['select'] + ']')) + '\n' 301 | fout.write(res) 302 | total += 1 303 | 304 | print('completed, ' + str(total) + " records") 305 | 306 | def cmd(): 307 | while True: 308 | cmd = input("> ") 309 | sql(cmd) 310 | --------------------------------------------------------------------------------
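A usage note on the sql() helper that closes rere/utils.py: it implements a tiny select/from/where/to query language over tab-separated text files, where x is the current row split on tabs and the select and where clauses are evaluated as Python expressions. Below is a hedged example; the file names are illustrative, and note that the whole command string is lowercased before parsing.

# Assumes the interpreter is started inside the rere directory so `import utils` resolves;
# data.txt and filtered.txt are illustrative file names.
from utils import sql

# Keep columns 0 and 2 of every row whose second column is an integer greater than 10,
# writing the result to filtered.txt (the 'to' clause; defaults to temp.txt if omitted).
sql('select x[0],x[2] from data.txt where int(x[1])>10 to filtered.txt')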