├── main.py
├── zuo
    ├── predict_similarity.py
    ├── util.py
    ├── generate_training_data.py
    └── run_classifier_predict_online.py
├── bert
    ├── optimization_finetuning.py
    ├── tokenization.py
    └── modeling.py
├── README.md
└── run_classifier.py
/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json 3 | import numpy as np 4 | from zuo.predict_similarity import predict_single 5 | # from zuo.predict_similarity import predict as predict_combine 6 | #from predict_similarity_2 import predict as predict_2 # todo todo todo 7 | # from main_buer import predict_standardized_output as predict_2 8 | # from predict_similarity_s import predict as predict_baili 9 | input_path_labor = "/input/labor/input.json" 10 | output_path_labor = "/output/labor/output.json" 11 | input_path_divorce = "/input/divorce/input.json" 12 | output_path_divorce = "/output/divorce/output.json" 13 | input_path_loan = "/input/loan/input.json" 14 | output_path_loan = "/output/loan/output.json" 15 | 16 | # todo todo todo todo ==============================要在本地测试main.py,打开这几行代码 17 | input_path_labor = "zuo/data_all/labor/data_small_selected.json" 18 | output_path_labor = "zuo/data_all/labor/output.json" 19 | input_path_divorce = "zuo/data_all/divorce/data_small_selected.json" 20 | output_path_divorce = "zuo/data_all/divorce/output.json" 21 | input_path_loan = "zuo/data_all/loan/data_small_selected.json" 22 | output_path_loan = "zuo/data_all/loan/output.json" 23 | ## todo todo todo todo ==============================要在本地测试main.py,打开这几行代码 24 | # roeberta_zh_L-24_H-768_A-12.zip 25 | 26 | def predict(input_path, output_path): 27 | category_en = '' 28 | if 'labor' in input_path: 29 | category_en = 'labor' 30 | if 'divorce' in input_path: 31 | category_en = 'divorce' 32 | if 'loan' in input_path: 33 | category_en = 'loan' 34 | 35 | inf = open(input_path, "r", encoding='utf-8') 36 | ouf = open(output_path, "w", encoding='utf-8') 37 | count = 0 38 | for line in inf: 39 | pre_doc = json.loads(line) 40 | new_pre_doc = [] 41 | for sent in pre_doc: 42 | labels_prob_big=predict_single(str(sent['sentence']), category_en) # buer.
labels_prob_buer: {'LB1':0.01,'LB2':0.02,'LB3':0.0,...} 43 | print("###labels_prob_big:",labels_prob_big) 44 | label_list=get_label_list_single(labels_prob_big) 45 | print("###label_list:",label_list) 46 | sent['labels'] = label_list 47 | new_pre_doc.append(sent) 48 | #count = count + 1 49 | json.dump(new_pre_doc, ouf, ensure_ascii=False) 50 | ouf.write('\n') 51 | 52 | inf.close() 53 | ouf.close() 54 | 55 | def check_whether_has_any_candidate(labels_with_prob_dict): 56 | """ 57 | 检测是否有有效候选项 58 | :param labels_with_prob_dict: 59 | :return: True if has candidate; False is not has candidate 60 | """ 61 | candidate_list=[] 62 | for lable_tag_en, p_temp in labels_with_prob_dict.items(): 63 | if float(p_temp)>0.01: 64 | candidate_list.append(lable_tag_en) 65 | 66 | return candidate_list 67 | #if len(candidate_list)>0: 68 | # return candidate_list 69 | #else: 70 | # return False 71 | 72 | def combine_prob(labels_prob_1,labels_prob_2,weight_1=0.50): 73 | """ 74 | 整合两个概率,概率取加权平均 75 | :param labels_prob_1: 76 | :param labels_prob_2: 77 | :return: 加权平均后的概率 78 | """ 79 | result_dict={} 80 | for tag, p_1 in labels_prob_1.items(): 81 | # print("tag:==",tag,"===;p_1:",p_1) # tag: LB1 ;p_1: 0.0 82 | p_2=labels_prob_2[tag] 83 | p_avg=float(p_1)*weight_1+float(p_2) *(1.0-weight_1) 84 | result_dict[tag]=p_avg 85 | return result_dict 86 | 87 | def get_label_list(labels_prob_1,labels_prob_2,threshold=0.5): 88 | """ 89 | 90 | :param labels_prob_1: {'LB1': 0.0016595844645053148, 'LB2': 0.11449998617172241, 'LB3': 0.003680239664390683, 91 | :param labels_prob_2: {'LB1': 1.810735193430446e-05, 'LB2': 0.0016248620037610333, 'LB3': 1.8363494746154174e-05 92 | :return: 93 | """ 94 | label_list=[] 95 | for tag_en, possibility_1 in labels_prob_1.items(): 96 | possibility_2=labels_prob_2[tag_en] 97 | possibility=np.average([possibility_1,possibility_2]) 98 | if possibility>threshold: 99 | label_list.append(tag_en) 100 | return label_list 101 | 102 | def get_label_list_single(labels_prob_1,threshold=0.5): 103 | """ 104 | 105 | :param labels_prob_1: {'LB1': 0.0016595844645053148, 'LB2': 0.11449998617172241, 'LB3': 0.003680239664390683, 106 | :param labels_prob_2: {'LB1': 1.810735193430446e-05, 'LB2': 0.0016248620037610333, 'LB3': 1.8363494746154174e-05 107 | :return: 108 | """ 109 | label_list=[] 110 | for tag_en, possibility_1 in labels_prob_1.items(): 111 | #possibility_2=labels_prob_2[tag_en] 112 | #possibility=np.average([possibility_1,possibility_2]) 113 | if possibility_1>threshold: 114 | label_list.append(tag_en) 115 | return label_list 116 | 117 | # labor领域预测 118 | predict(input_path_labor, output_path_labor) 119 | 120 | # loan领域预测 121 | predict(input_path_loan, output_path_loan) 122 | 123 | # divorce领域预测 124 | predict(input_path_divorce, output_path_divorce) 125 | -------------------------------------------------------------------------------- /zuo/predict_similarity.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # coding:utf8 3 | """ 4 | @author: Cong Yu 5 | @time: 2019-07-07 10:16 6 | """ 7 | 8 | import json 9 | import time 10 | import numpy as np 11 | import random 12 | from zuo.util import category2tags_dict, category2tag2id_dict, category2selectedtags, category2selectedtags_dict, \ 13 | category_cn_dict, load_factor_with_additional_info,sentence_match_single 14 | from zuo.run_classifier_predict_online import predict_online as predict_online_1 15 | 16 | data_path = 'zuo/data_all/' 17 | factorzh_additionalinfo_dict, _ = 
load_factor_with_additional_info(data_path + 'factor_desc_represent_add_re.csv')#'factor_desc_represent.csv') 18 | 19 | 20 | def predict_single(sentence, category_en, candidate_list=None, threshold=0.5): 21 | """ 22 | 通过预测得到标签的列表: 会综合使用多个标签或句子来做预测 23 | :param sentence: 句子 24 | :param category_en: 类别信息,如 'labor' 25 | : return: label_result_list: 预测除的标签. e.g. label_result_list=['LN1','LN13'] 26 | : return: tag_en_possibility_dict,这个模型融合时使用 e.g. tag_en_possibility_dict['LN1']=0.6 27 | """ 28 | tags_dict, tag2id_dict, selectedtags, selectedtags_dict = category2tags_dict[category_en], category2tag2id_dict[ 29 | category_en], category2selectedtags[category_en], category2selectedtags_dict[category_en] 30 | # print("###tags_dict:",tags_dict,";tag2id_dict:",tag2id_dict,";selectedtags:",selectedtags,";selectedtags_dict:",selectedtags_dict) 31 | selectedtag2index_dict = {v: k for k, v in selectedtags_dict.items()} 32 | 33 | label_result_list = [] 34 | tag_en_possibility_dict = {} # e.g. tag_en_possibility_dict['LN1']=0.6 35 | 36 | if candidate_list is not None and len(candidate_list) > 0: # 有传入的候选项不为空,那么更新selectedtags即更新模型预测的标签范围 37 | index_list = [tag2id_dict[tag_en_] for tag_en_ in candidate_list] # [0,1,3] 38 | selectedtags = [selectedtags_dict[indexx] for indexx in index_list] 39 | 40 | # print("#####bxul.selectedtags:",selectedtags) 41 | for candidate_tag_cn in selectedtags: 42 | if isinstance(sentence, float): return label_result_list 43 | # if len(sentence) > 90: sentence = sentence[0:45] + '。' + sentence[-45:] # OLD 44 | if len(sentence) > 250: sentence = sentence[0:125] + '。' + sentence[-125:] # OLD 45 | 46 | key = category_cn_dict[category_en] + "_" + candidate_tag_cn 47 | possibility_list = [] 48 | list_allow = get_allow_list_tag(factorzh_additionalinfo_dict[key]) 49 | for k, candi_tag_cn in factorzh_additionalinfo_dict[key].items(): # k:'label_zh',candi_tag_cn: '拒绝履行偿还' 50 | if k not in list_allow: continue # # if k not in ['label_zh','desc','sentence_repres1']:continue 51 | type_information = category_cn_dict[category_en] + '的' + candi_tag_cn 52 | ############################################################### 53 | # 添加一个判断,减少计算量即多数时候,只用一个模型计算就可以了,少数情况用两个模型 54 | _, possibility = predict_online_6(sentence, type_information) 55 | weight = 1 # 0.3333 #0.175 if k!='label_zh' else 0.3 56 | possibility_pos=possibility[1] 57 | possibility_list.append((possibility_pos, weight, k)) 58 | p_list = [e[0] * e[1] for e in possibility_list] 59 | possibility_pos_final = np.average(p_list) 60 | 61 | index = selectedtag2index_dict[candidate_tag_cn] 62 | tag_en = tags_dict[index] # e.g. 
tag_en='LN1' 63 | tag_en_possibility_dict[tag_en] = possibility_pos_final 64 | 65 | if possibility_pos_final > threshold: # 如果超过阀值,加入到预测除的标签列表 66 | label_result_list.append(tag_en) 67 | ##################这里使用正则####这里使用正则####这里使用正则###################################################################################### 68 | reugular_expression = factorzh_additionalinfo_dict[category_cn_dict[category_en] + "_" + candidate_tag_cn][ 69 | 'reugular_expression'] 70 | if reugular_expression != '' and str(reugular_expression) != 'nan': 71 | flag = sentence_match_single(sentence, reugular_expression) 72 | if flag == True: 73 | tag_en_possibility_dict[tag_en] = 1.0 74 | ##################这里使用正则###这里使用正则###这里使用正则######################################################################################## 75 | return tag_en_possibility_dict # label_result_list, 76 | 77 | 78 | 79 | def get_allow_list_tag(k_candi_tag_dict): # k,candi_tag_cn 80 | """ 81 | 获取允许的候选项 82 | :param k_candi_tag_dict: '借款纠纷_拒绝履行偿还': {'label_zh': '拒绝履行偿还', 'desc': '未按时偿还借款|拒不偿还借款', 'sentence_repres1': '如果XX未按指定的期间履行给付金钱义务,应当依照xx规定,加倍支付迟延履行期间的债务利息', 'sentence_repres2': '被告本人没有向信用社贷款,没有签过合同和借款凭证,也没有实际领取贷款本金,因此不同意还款。', 'sentence_repres3': '合同约定的借款期限届满之日,XXX未能履行还款义务'} 83 | :return: 84 | """ 85 | allow_tag_list = [] 86 | allow_tag_list.append('label_zh') 87 | # tag_list = random.sample(['desc', 'sentence_repres1', 'sentence_repres2', 'sentence_repres3'], 2) 88 | #allow_tag_list.extend(tag_list) 89 | # todo 90 | return allow_tag_list 91 | 92 | 93 | if __name__ == "__main__": 94 | text = "二、威海市文登区畜牧兽医技术服务中心向宋忠文支付2016年度带薪年休假工资1561.68元,于本判决生效后十日内付清;" 95 | domain = "labor" 96 | 97 | labels_prob = predict(text, domain) 98 | print(labels_prob) 99 | -------------------------------------------------------------------------------- /zuo/util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | import codecs 4 | import pandas as pd 5 | import re 6 | category_cn_dict = {'divorce': '婚姻家庭', 'labor': '劳动争议', 'loan': '借款纠纷'} 7 | 8 | category_list = ['divorce', 'labor', 'loan'] 9 | data_path = 'zuo/data_all/' # 'zuo/data_all/' 10 | 11 | negative_word_list = ['不', '未', '没', '没有', '无', '非', '并未', '不再', '不能', '无法', '不足以', '不至于', '不存在', 12 | '不能证明', '不认可','尚未', '不行', '不到', '不满', '未满', '未到', '没到', '没满', '没法'] # 否定词列表 13 | negative_word = '(' + '|'.join(negative_word_list) + ')' # 否定词正则 14 | 15 | def load_factor_with_additional_info(source_file): 16 | """ 17 | 加载要素-描述-代表性描述的文件 18 | :param source_file: 19 | :return: 返回一个dict: dict[factor_zh]={desc,sentence_repres1,sentence_repres2,sentence_repres3} 20 | """ 21 | # 1.加载文件 22 | data = pd.read_csv(source_file) 23 | # 2.遍历并放入dict中 24 | factorzh_additionalinfo_dict = {} # 中文标签对应的额外信息:描述、好的句子1,2,3 25 | factorzh_neg_sentenceorlabel_dict = {} # 中文标签对应的候选负样本信息 26 | categorycn2labelcn2other = {} 27 | # 获得每个类别下所有的中文标签、描述、好的句子的集合 28 | category_cn_list = [v for k, v in category_cn_dict.items()] 29 | categorycn2totallist_dict = {} 30 | for indexx, row_ in data.iterrows(): 31 | category_ = row_['纠纷类型'] 32 | # label_en_ = row_['标签'] 33 | label_zh_ = row_['中文标签'] 34 | desc_ = row_['一句话标准描述'] 35 | sentence_repres1_ = row_['好的句子1'] 36 | sentence_repres2_ = row_['好的句子2'] 37 | sentence_repres3_ = row_['好的句子3'] 38 | sublist = categorycn2totallist_dict.get(category_, []) 39 | sublist.extend([label_zh_, desc_, sentence_repres1_, sentence_repres2_, sentence_repres3_]) 40 | categorycn2totallist_dict[category_] = sublist 41 | 42 | # 
print("按类别统计的所有的标签、描述、有用的句子的集合:",len(categorycn2totallist_dict),";total_list:",categorycn2totallist_dict) 43 | for index, row in data.iterrows(): 44 | category = row['纠纷类型'] 45 | # label_en=row['标签'] 46 | label_zh = row['中文标签'] 47 | desc = row['一句话标准描述'] 48 | sentence_repres1 = row['好的句子1'] 49 | sentence_repres2 = row['好的句子2'] 50 | sentence_repres3 = row['好的句子3'] 51 | # reugular_expression=row.get('re','') # this field is not use at all. 52 | key = category + "_" + label_zh 53 | factorzh_additionalinfo_dict[key] = {"label_zh": label_zh, "desc": desc, "sentence_repres1": sentence_repres1, 54 | "sentence_repres2": sentence_repres2, "sentence_repres3": sentence_repres3,'reugular_expression':''} 55 | other_desc_list = [x for x in categorycn2totallist_dict[category] if 56 | x not in [label_zh, desc, sentence_repres1, sentence_repres2, sentence_repres3]] 57 | factorzh_neg_sentenceorlabel_dict[key] = list(set(other_desc_list)) 58 | return factorzh_additionalinfo_dict, factorzh_neg_sentenceorlabel_dict 59 | 60 | 61 | def read_source_flies(data_path, category): 62 | """ 63 | 读取原始数据 64 | :param data_path: 65 | :param category: 66 | :return: 67 | """ 68 | data_path_group = data_path 69 | file_path_divorce = data_path + category 70 | # tags 71 | divorce_tags_file = file_path_divorce + '/tags.txt' 72 | divorce_tags_object = open(divorce_tags_file, 'r') 73 | divorce_tags = divorce_tags_object.readlines(); 74 | divorce_tags = [x.strip() for x in divorce_tags] 75 | divorce_tags_dict = {j: xx for j, xx in enumerate(divorce_tags)} 76 | divorce_tag2id_dict = {yy: xx for xx, yy in divorce_tags_dict.items()} 77 | # selectedtags 78 | divorce_selectedtags_file = file_path_divorce + '/selectedtags.txt' 79 | divorce_selectedtags_object = codecs.open(divorce_selectedtags_file, 'r', 80 | 'utf-8') # open(divorce_selectedtags_file, 'r') 81 | divorce_selectedtags = divorce_selectedtags_object.readlines() 82 | divorce_selectedtags = [xx.strip() for xx in divorce_selectedtags] 83 | divorce_selectedtags_dict = {jj: xxx for jj, xxx in enumerate(divorce_selectedtags)} # {1:'婚后有子女',2:'限制行为能力子女抚养',...} 84 | # raw data 85 | if 'big' not in data_path: 86 | divorce_data_file = file_path_divorce + '/data_small_selected.json' 87 | else: 88 | divorce_data_file = file_path_divorce + '/train_selected.json' 89 | divorce_data_object = codecs.open(divorce_data_file, 'r', 'utf-8') # open(divorce_data_file, 'r'); 90 | divorce_lines = divorce_data_object.readlines() 91 | return divorce_tags_dict, divorce_tag2id_dict, divorce_selectedtags, divorce_selectedtags_dict, divorce_lines 92 | 93 | 94 | category2tags_dict = {} 95 | category2tag2id_dict = {} 96 | category2selectedtags = {} 97 | category2selectedtags_dict = {} 98 | # factorzh_additionalinfo_dict={} 99 | 100 | factorzh_additionalinfo_dict, factorzh_neg_sentenceorlabel_dict = load_factor_with_additional_info( 101 | data_path + 'factor_desc_represent.csv') # 'factor_desc_represent.csv' 102 | #print("####factorzh_additionalinfo_dict:",factorzh_additionalinfo_dict) 103 | # '借款纠纷_拒绝履行偿还': {'label_zh': '拒绝履行偿还', 'desc': '未按时偿还借款|拒不偿还借款', 'sentence_repres1': '如果XX未按指定的期间履行给付金钱义务,应当依照xx规定,加倍支付迟延履行期间的债务利息', 'sentence_repres2': '被告本人没有向信用社贷款,没有签过合同和借款凭证,也没有实际领取贷款本金,因此不同意还款。', 'sentence_repres3': '合同约定的借款期限届满之日,XXX未能履行还款义务'} 104 | # print("-------------------------------------------------------------------------------------------------") 105 | count_k = 0 106 | # for k,v in factorzh_neg_sentenceorlabel_dict.items(): 107 | # print(count_k,"k:",k,";v:",v) 108 | # count_k=count_k+1 109 | 110 | for 
i, category in enumerate(category_list): 111 | tags_dict, tag2id_dict, selectedtags, selectedtags_dict, _ = read_source_flies(data_path, category) 112 | category2tags_dict[category] = tags_dict 113 | category2tag2id_dict[category] = tag2id_dict 114 | category2selectedtags[category] = selectedtags 115 | category2selectedtags_dict[category] = selectedtags_dict 116 | 117 | # reugular_expression = factorzh_additionalinfo_dict[domain_zn + "_" + tag_cn_candidate]['reugular_expression'] 118 | def sentence_match_single(sentence,keyword): 119 | flag_positive = len(re.findall(keyword, sentence)) > 0 120 | 121 | flag_neg_1 = len(re.findall(negative_word+keyword, sentence)) > 0 122 | flag_neg_2 = len(re.findall(keyword+negative_word, sentence)) > 0 123 | if flag_positive: 124 | if not flag_neg_1 and not flag_neg_2: 125 | return True 126 | return False 127 | 128 | #result=sentence_match_single('原、被告于1980年结婚,婚后生有女儿范某香,大儿子范某荣,二儿子范某华均已成家另居。','divorce') 129 | #print("result:",result) 130 | # print("category2tags_dict:",category2tags_dict) 131 | # print("category2tag2id_dict:",category2tag2id_dict) 132 | # print("category2selectedtags_dict:",category2selectedtags_dict) 133 | -------------------------------------------------------------------------------- /bert/optimization_finetuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, # 0.98 ONLY USED FOR PRETRAIN. MUST CHANGE AT FINE-TUNING 0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 
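      # (For reference, the update computed below is the standard Adam rule:
      #    m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
      #    v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2
      #    update_t = m_t / (sqrt(v_t) + epsilon)
      #  followed by the decoupled weight-decay term added further down.)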
131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /zuo/generate_training_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import random 4 | from util import category_cn_dict, read_source_flies, factorzh_additionalinfo_dict, factorzh_neg_sentenceorlabel_dict 5 | 6 | 7 | def generate_train_data(data_path,examples_source_path): 8 | """ 9 | 输入3个数据文件和 10 | """ 11 | # 分别获取三个文件下的数据(divorce,labor,loan) 12 | category_list = ['divorce', 'labor', 'loan'] 13 | count = 0 14 | count_empty_list = 0 15 | total_pos = 0 # 正样本的总数量 16 | total_neg = 0 # 负样本的总数量 17 | total_list = [] 18 | avg_length = 0 # 平均长度 19 | ignore_count = 0 20 | 21 | category2tag_en_example_dict=load_pos_examples_files(examples_source_path) 22 | for i, category in enumerate(category_list): 23 | # 读取每个目录下的数据 24 | divorce_tags_dict, divorce_tag2id_dict, divorce_selectedtags, divorce_selectedtags_dict, divorce_lines = read_source_flies(data_path, category) 25 | divorce_id2tag_dict={v:k for k,v in divorce_tag2id_dict.items()} 26 | divorce_selectedtags2id_dict={v:k for k,v in divorce_selectedtags_dict.items()} 27 | for k, line in enumerate(divorce_lines): 28 | list_element = json.loads(line.strip()) 29 | for m, element in enumerate(list_element): 30 | # 单个样本级别(句子和标签列表) 31 | sentence = element['sentence'] 32 | sentence = sentence.replace("\t\n", "。").replace("\t", " ") 33 | if 'check-yuhan@gridsum.com' in sentence: ignore_count = ignore_count + 1;continue 34 | if len(str(sentence)) > 122: sentence = sentence[0:60] + "。" + sentence[-62:] # 处理超长的文本 # if len(str(sentence)) > 90: sentence = sentence[0:43] + "。" + sentence[-45:] # 处理超长的文本 35 | 36 | labels = element['labels'] 37 | 38 | if len(labels) < 1:count_empty_list = count_empty_list + 1 # 统计空标签的行数 39 | labels_word_list = 
[divorce_selectedtags_dict[divorce_tag2id_dict[x]] for x in labels] # 得到中文标签名的列表 40 | #print('###labels_word_list.',k, m, "sentence:", sentence, ";labels:", labels, ";labels_word_list:", labels_word_list) 41 | count = count + 1 42 | 43 | # 传入一个标签列表和文本,正样本(从标准问法来) 44 | total_list, avg_length, total_pos=get_pos_example_data(labels_word_list, category,divorce_selectedtags2id_dict, divorce_id2tag_dict, 45 | category2tag_en_example_dict,sentence, count, total_pos, avg_length, total_list) 46 | 47 | # 传入一个标签列表和文本,负样本(经典:其他的中文标签) 48 | total_list, avg_length, total_neg=get_neg_example_data(divorce_selectedtags_dict, labels_word_list, category, sentence, count,divorce_selectedtags2id_dict, 49 | divorce_id2tag_dict, category2tag_en_example_dict,total_list, total_neg, avg_length) 50 | if k%5000==0: 51 | print("Generate training data.length of total_list:",len(total_list),";total_pos:",total_pos,";total_neg:",total_neg) 52 | 53 | # 打印一些信息 54 | num_example = len(total_list) 55 | print("总行数:", count, ";空标签的行数:", count_empty_list, ";非空标签的行数:",count - count_empty_list) # 总行数:22500 ;空标签的行数:13800; 非空行数9000。 56 | print("正样本总数量:", total_pos, ";负样本总数量:", total_neg, ";总样本数量:", total_pos + total_neg) 57 | print(category + "_tags_dict:",divorce_tags_dict) # loan_tags_dict={0: 'LN1', 1: 'LN2', 2: 'LN3', 3: 'LN4', 4: 'LN5', 5: 'LN6', 6: 'LN7', 7: 'LN8', 8: 'LN9', 9: 'LN10', 10: 'LN11', 11: 'LN12', 12: 'LN13', 13: 'LN14', 14: 'LN15', 15: 'LN16', 16: 'LN17', 17: 'LN18', 18: 'LN19', 19: 'LN20'} 58 | print(category + "_selectedtags_dict:",divorce_selectedtags_dict) # _selectedtags_dict={0: '债权人转让债权', 1: '借款金额x万元', 2: '有借贷证明', 3: '贷款人系金融机构', 4: '返还借款', 5: '公司|单位|其他组织借款', 6: '连带保证', 7: '催告还款', 8: '支付利息', 9: '订立保证合同', 10: '有书面还款承诺', 11: '担保合同无效|撤销|解除', 12: '拒绝履行偿还', 13: '免除保证人保证责任', 14: '保证人不承担保证责任', 15: '质押人系公司', 16: '贷款人未按照约定的日期|数额提供借款', 17: '多人借款', 18: '债务人转让债务', 19: '约定利率不明'} 59 | avg_length = float(avg_length) / float(num_example) 60 | print("平均长度:", avg_length, ";ignore_count:", ignore_count) # 74. 120会够用 61 | 62 | # 写文件 train.tsv; dev.tsv 63 | write_data_to_file_system(total_list, data_path + 'train_data/') 64 | 65 | def get_pos_example_data(labels_word_list,category,divorce_selectedtags2id_dict,divorce_id2tag_dict,category2tag_en_example_dict,sentence,count,total_pos,avg_length,total_list): 66 | """ 67 | 得到正样本的训练数据 68 | :return: 69 | """ 70 | sub_result_list=[] 71 | for label_word in labels_word_list: # e.g. 
label_word: '婚后有子女'; 其他可用的:1.一句话标准描述;2.好的句子1;3.好的句子2;4.好的句子3 72 | labelzh_and_additional_info_list_big = [] 73 | # 中文标签对应的样本 74 | labelzh_and_additional_info_dict = factorzh_additionalinfo_dict[category_cn_dict[category] + "_" + label_word] 75 | labelzh_and_additional_info_list = [v for k, v in labelzh_and_additional_info_dict.items() if (not isinstance(v, float) and '*' not in v)] # 去掉有问题的元素 76 | labelzh_and_additional_info_sub_list= random.sample(labelzh_and_additional_info_list,2) 77 | labelzh_and_additional_info_list_big.extend(labelzh_and_additional_info_sub_list) 78 | ##################################################################################################### 79 | # 添加标签对应的例子的随机选的5个例子 80 | # 从中文标签到index 81 | temp_index = divorce_selectedtags2id_dict[label_word] 82 | # 从index中英文标签 83 | temp_tag_en = divorce_id2tag_dict[temp_index] 84 | labels_word_list_examples = category2tag_en_example_dict[category].get(temp_tag_en, []) # 中文标签对应的例子的列表 85 | if len(labels_word_list_examples) > 0: 86 | temp_num_examples = len(labels_word_list_examples) 87 | temp_top5_example_list = random.sample(labels_word_list_examples, min(temp_num_examples, 3)) # 最多选3个 88 | labelzh_and_additional_info_list_big.extend(temp_top5_example_list) 89 | ###################################################################################################### 90 | for label_candidate in labelzh_and_additional_info_list_big: 91 | if label_candidate == sentence: continue # 如果要做句子对任务的双方是相同的,那么直接忽略 92 | if len(str(label_candidate)) > 122: label_candidate = label_candidate[0:60] + "。" + label_candidate[-62:] # 处理超长的文本 # if len(str(sentence)) > 90: sentence = sentence[0:43] + "。" + sentence[-45:] # 处理超长的文本 93 | strings = '1' + '\t' + category_cn_dict[category] + "的" + label_candidate + "\t" + sentence + "\n" # label_word 94 | if isinstance(label_candidate, float): continue # 跳过太短的,或有问题的例子 95 | sub_result_list.append(strings) 96 | total_pos = total_pos + 1 97 | avg_length = avg_length + len(strings) 98 | 99 | total_list.extend(sub_result_list) 100 | return total_list,avg_length,total_pos 101 | ##打印############################################## 102 | #if len(labels_word_list)>0: 103 | # print(";labels_word_list:",labels_word_list," ;sentence:",sentence) 104 | # for lll,ee in enumerate(sub_result_list): 105 | # print("lll:",lll," ;ee:",ee) 106 | ################################################### 107 | 108 | def get_neg_example_data(divorce_selectedtags_dict,labels_word_list,category,sentence,count,divorce_selectedtags2id_dict,divorce_id2tag_dict,category2tag_en_example_dict,total_list,total_neg,avg_length): 109 | """ 110 | 获得负样本 111 | :param divorce_selectedtags_dict: 112 | :param labels_word_list: 113 | :param category: 114 | :param sentence: 115 | :param count: 116 | :param divorce_selectedtags2id_dict: 117 | :param divorce_id2tag_dict: 118 | :param category2tag_en_example_dict: 119 | :param total_list: 120 | :param total_neg: 121 | :param avg_length: 122 | :return: 123 | """ 124 | divorce_selectedtags_list = [v for k, v in divorce_selectedtags_dict.items()] 125 | #print("###divorce_selectedtags_list:", divorce_selectedtags_list) 126 | sub_neg_result_list=[] 127 | for label_word_neg in divorce_selectedtags_list: # 标准负样本 128 | random_number = random.random() 129 | if label_word_neg not in labels_word_list: # 不在标签中的,但在标签集合中的标签,皆为负样本 130 | if random_number > 0.8: # TODO TODO TODO 通过改变这个值,你可以决定生成负训练样本的数量。将这个数改大一点,如果你希望产生更少的数据 131 | # a.负样本1:其他的中文标签 132 | strings = '0' + '\t' + category_cn_dict[category] + "的" + label_word_neg + 
"\t" + sentence + "\n" 133 | if isinstance(label_word_neg, float): continue # 跳过太短的,或有问题的例子 134 | sub_neg_result_list.append(strings) 135 | total_neg = total_neg + 1 136 | avg_length = avg_length + len(strings) 137 | 138 | # b.负样本2:额外的标签的描述、好的句子1,2,3;外加标签对应的例子 139 | ############################################################################################ 140 | temp_neg_index = divorce_selectedtags2id_dict[label_word_neg] 141 | # 从index中英文标签 142 | temp_neg_tag_en = divorce_id2tag_dict[temp_neg_index] 143 | labels_word_list_neg_examples = category2tag_en_example_dict[category].get(temp_neg_tag_en, []) 144 | neg_examples_list_big = [] 145 | # 从标签对应的例子中采样出一部分 146 | if len(labels_word_list_neg_examples) > 0: 147 | temp_num_neg_examples = len(labels_word_list_neg_examples) 148 | temp_top5_neg_example_list = random.sample(labels_word_list_neg_examples,min(temp_num_neg_examples, 5)) 149 | neg_examples_list_big.extend(temp_top5_neg_example_list) 150 | ############################################################################################ 151 | neg_sublist = factorzh_neg_sentenceorlabel_dict[category_cn_dict[category] + "_" + label_word_neg] 152 | neg_sublist = [xx for xx in neg_sublist if not isinstance(xx, float)] # 去掉空值 153 | neg_examples_list_big.extend(neg_sublist) 154 | neg_sample_final_list = random.sample(neg_examples_list_big, 2) # 从描述和案例中随机选出2个 155 | # 添加到列表中 156 | for neg_sample in neg_sample_final_list: 157 | if isinstance(neg_sample, float): continue 158 | if len(str(neg_sample)) > 122: neg_sample = neg_sample[0:60] + "。" + neg_sample[-62:] 159 | strings_neg = '0' + '\t' + category_cn_dict[category] + "的" + neg_sample + "\t" + sentence + "\n" 160 | if neg_sample != label_word_neg: # sample出来的负样本不应该是负样本本身 161 | sub_neg_result_list.append(strings_neg) 162 | total_neg = total_neg + 1 163 | avg_length = avg_length + len(strings) 164 | 165 | total_list.extend(sub_neg_result_list) 166 | return total_list,avg_length,total_neg 167 | # 打印一些例子 168 | #for kkk, neg_elment in enumerate(sub_neg_result_list): 169 | # print("kkk:",kkk," ;neg_elment:",neg_elment) 170 | 171 | 172 | def write_file(data_list, target_file, file_type): 173 | """ 174 | 写单个文件 175 | :param data_list: 176 | :param target_file: 177 | :return: 178 | """ 179 | random.shuffle(data_list) 180 | target_object = open(target_file, 'w') 181 | count_pos = 0 182 | count_neg = 0 183 | for string in data_list: 184 | #print("##string:",string) 185 | label_string = string.split("\t")[0] 186 | # 统计正负比例 187 | if file_type == 'train': 188 | if label_string == '0': count_neg = count_neg + 1 189 | if label_string == '1': count_pos = count_pos + 1 190 | target_object.write(string) 191 | if file_type == 'dev': # 对于验证集,去掉一部分负样本 192 | if label_string == '1': 193 | target_object.write(string) 194 | count_pos = count_pos + 1 195 | else: 196 | if random.random() > 0.7:# TODO CHANGE AT08-27.0.6: 197 | target_object.write(string) 198 | count_neg = count_neg + 1 199 | print(file_type, "count_pos:", count_pos, ";count_neg:", count_neg, ";pert of pos:",(float(count_pos) / float(count_pos + count_neg))) 200 | target_object.close() 201 | 202 | 203 | def write_data_to_file_system(total_list, target_path): 204 | """ 205 | 写多个文件 206 | :param total_list: 207 | :param target_path: 208 | :return: 209 | """ 210 | random.shuffle(total_list) 211 | num_example = len(total_list) 212 | # 写训练集 213 | num_train = int(num_example * 0.95) 214 | train_list = total_list[0:num_train] 215 | target_train_file = target_path + 'train.tsv' 216 | write_file(train_list, 
target_train_file, 'train') 217 | # 写验证集 218 | dev_list = total_list[num_train:] 219 | target_dev_file = target_path + 'dev.tsv' 220 | write_file(dev_list, target_dev_file, 'dev') 221 | 222 | def load_pos_examples_files(examples_source_path): 223 | """ 224 | 从文件中读取标签对应的代表性的正样本 225 | :param examples_source_path: 文件所在位置 226 | :return: 227 | """ 228 | category2tag_en_example_dict={} 229 | category_list = ['divorce', 'labor', 'loan'] 230 | for category_en in category_list: 231 | source_file=examples_source_path+category_en+'_pos_examples.json' 232 | source_object=open(source_file,'r') 233 | line=source_object.readline() 234 | temp_dict=json.loads(line) 235 | print("type of temp_dict:",type(temp_dict)) 236 | category2tag_en_example_dict[category_en]=temp_dict 237 | return category2tag_en_example_dict 238 | 239 | examples_source_path= 'zuo/data_all/pos_examples/' 240 | # category2tag_en_example_dict=load_pos_examples_files(examples_source_path) 241 | # divorce_tagen_examples=category2tag_en_example_dict['labor'] 242 | # examples=divorce_tagen_examples['LB18'] 243 | # print("##examples:",examples) 244 | 245 | data_path = 'zuo/data_all/' # './data/' 246 | generate_train_data(data_path,examples_source_path) 247 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-label Classification 2 | Transform multi-label classification into a sentence pair task & 3 | 4 | together with generating more training data, using more information and external knowledge 5 | 6 | *** UPDATE *** 7 | 8 | We are going to release a Chinese version of the pre-trained model ALBERT at albert_zh, with state-of-the-art performance on benchmarks and 30% fewer parameters than BERT_large and other models. 9 | 10 | 一键运行10个数据集、9个基线模型、不同任务上模型效果的详细对比,见中文任务基准测评 CLUE benchmark 11 | 12 | 13 | Introduction 多标签分类介绍 14 | -------------------------------------------------------------------- 15 | Multi-label classification is a classification problem where multiple labels may be assigned to each instance. 16 | 17 | Zero, one or multiple labels can be associated with an instance (or example). It is more general than multi-class 18 | 19 | classification, where exactly one label is assigned to each example. 20 | 21 | You can think of the problem as multiple binary classifications that map the input (X) to 22 | 23 | a value (y) of either 0 or 1 for each label in the label space. According to Wikipedia: 24 | 25 | Multi-label classification is the problem of finding a model that maps inputs x to binary vectors y (assigning a value of 0 or 1 for each element (label) in y). 26 | 27 | Transform Multi-label Classification as Sentence Pair Task 28 | -------------------------------------------------------------------- 29 | ### 任务转化:将多标签分类转换为句子对任务 30 | 31 | It is natural to get a baseline by using a single classification model to predict all labels for each input string. 32 | 33 | This is also time-efficient during inference, since only one computation is needed per input. However, the performance is quite poor, especially when 34 | 35 | you only have a few instances or examples per label. One of the reasons for the poor performance is that such a model tries to map the input to the target labels directly, 36 | 37 | fails to use additional information, and the training examples are not enough. 38 | 39 | By casting the problem as a sentence pair task, it is easy to use more information, including label information, instance information, and key words from each label.
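To make the reformulation concrete, here is a minimal, self-contained sketch (not the exact code in this repository): each candidate label is turned into a short text, and a binary classifier scores every (label text, input sentence) pair. `score_pair` below is only a toy stand-in for the fine-tuned BERT sentence-pair model used in this project.

```python
# Minimal sketch: multi-label classification recast as one binary
# sentence-pair decision per label. `score_pair` is a toy character-overlap
# scorer standing in for the fine-tuned BERT sentence-pair classifier.
label_descriptions = {
    "DV1": "婚姻家庭的婚后有子女",            # "<category>的<Chinese label>", the text form built in zuo/generate_training_data.py
    "DV2": "婚姻家庭的限制行为能力子女抚养",
    "DV4": "婚姻家庭的支付抚养费",
}

def score_pair(label_text, sentence):
    # Toy scorer: fraction of label characters that also appear in the sentence.
    hits = sum(1 for ch in label_text if ch in sentence)
    return hits / max(len(label_text), 1)

def predict_labels(sentence, threshold=0.5):
    """Return zero, one or several labels by scoring each (label_text, sentence) pair."""
    return [tag for tag, text in label_descriptions.items()
            if score_pair(text, sentence) > threshold]
```

At inference time the project does exactly this one-pair-per-label scoring, pairing the raw sentence with each candidate label's Chinese description (see `predict_single` in `zuo/predict_similarity.py`).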
40 | 41 | Generating more training data and Information 42 | -------------------------------------------------------------------- 43 | ### 产生更多训练数据、结合更多信息和额外的知识 44 | 45 | Let's talk about how to use additional information and generate more data. 46 | 47 | A sentence pair example has the form <sentence_1, sentence_2, label>. Here sentence_1 comes from the original input string, 48 | 49 | so an example becomes <original_input, sentence_2, label>. So where does sentence_2 come from? It is another piece of text, which can come from: 50 | 51 | 1) the Chinese label string, or 52 | 53 | 2) a sentence which can represent the label: you can randomly pick a sentence from a specific label to represent that label, 54 | 55 | or you can manually write a description of the label to represent it; 56 | 57 | 3) keywords learned from the label, put together to represent the label. 58 | 59 | In a word, there are many ways to generate sentence_2 for the sentence pair task. As a result, we generate more than 60 | 61 | 1 million training instances for the sentence pair task from only 30k labeled instances. 62 | 63 | 直接的多标签分类去预测,由于设法直接从输入的文本去映射标签,没有使用额外的信息,训练数据也有限,效果比较差。 64 | 65 | 通过将其转化为句子对任务,我们可以比较容易地利用额外的信息,并且产生极大数量的训练样本。这些额外的信息包括但不限于: 66 | 67 | 特定标签对应的训练样本中的部分输入文本、中文的标签信息、标签对应的top关键词的组合。这些额外的信息可用来代表这个标签。 68 | 69 | 在#Task Description和#Generate Training Data,我将详细展开并举例。 70 | 71 | 72 | Procedure 流程 73 | -------------------------------------------------------------------- 74 | 1) Transform multi-label classification into a sentence pair task, using random instances from each label ---> 75 | 76 | 2) Additional information: add label information (which is in Chinese), key words from the label, or external knowledge ---> 77 | 78 | 3) Additional domain knowledge: large-scale domain pre-training 79 | 80 | 81 | Performance 效果对比 82 | -------------------------------------------------------------------- 83 | 84 | | No. 序列| Model | 描述 Description | Online 线上 | 85 | | :-------| :------- | :------- | :---------: | 86 | |0 | Multi-label Classification (TextCNN)|多标签分类 | 61 | 87 | |1 | Multi-label Classification (Bert) |多标签分类| 64.9 | 88 | |2| Sentence-pair Task |句子对任务,标签对应的随机样本与输入文本| 68.9 | 89 | |3 |#2 + Instance Information |加上中文标签与输入文本的数据| 69.5 | 90 | |4 |#3 + bert_wwm_ext_law |bert_wwm_ext基础上的领域预训练| 70.7 | 91 | |5 |#4 + RoBERTa-zh-domain | 结合RoBERTa和wwm的大规模领域预训练| 72.1 | 92 | |6 |#5 + RoBERTa-zh-Large-domain |24层的RoBERTa的预训练| 73.0| 93 | |7 |#6 + Ensemble| 使用3个模型概率求平均 & 候选项召回并预测加速 | 75.5| 94 | 95 | For additional information on pre-trained Chinese RoBERTa models, check: Roberta_zh 96 | 97 | Task Description 任务介绍 98 | -------------------------------------------------------------------- 99 | ### About Task 任务是什么? 100 | The purpose of this task is to extract important facts from the description of a legal case, 101 | 102 | and to map the case description to case elements according to a system designed by experts in the field. 103 | 104 | For each sentence in a paragraph from a judicial document, the model needs to identify the key element(s); 105 | 106 | multiple or zero elements may exist in a sentence. 107 | 108 | 本任务的主要目的是为了将案件描述中重要事实描述自动抽取出来,并根据领域专家设计的案情要素体系进行分类。 109 | 110 | 案情要素抽取的结果可以用于案情摘要、可解释性的类案推送以及相关知识推荐等司法领域的实际业务需求中。 111 | 112 | 具体地,给定司法文书中的相关段落,系统需针对文书中每个句子进行判断,识别其中的关键案情要素。 113 | 114 | 本任务共涉及三个领域,包括婚姻家庭、劳动争议、借款合同等领域。 115 | 116 | Check this for more details on this task: 中国法研杯_CAIL2019(要素识别赛道) 117 | 118 | This is the 2nd place solution (out of 188) for this task.
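Because a sentence may carry zero, one or several elements, the final decision is a simple per-label probability threshold. A minimal sketch, mirroring `get_label_list_single` in `main.py`:

```python
# Final decision step (mirrors get_label_list_single() in main.py): keep every
# label whose predicted probability exceeds the threshold, so a sentence can
# end up with zero, one or several labels.
def get_label_list_single(labels_prob, threshold=0.5):
    return [tag for tag, prob in labels_prob.items() if prob > threshold]

# Example with made-up probabilities:
print(get_label_list_single({"DV1": 0.91, "DV2": 0.74, "DV4": 0.65, "DV8": 0.03}))
# -> ['DV1', 'DV2', 'DV4']
```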
119 | 120 | ### Examples of Data 数据介绍 121 | 122 | 本任务所使用的数据集主要来自于“中国裁判文书网”公开的法律文书,每条训练数据由一份法律文书的案情描述片段构成,其中每个句子都被标记了对应的类别标签 123 | 124 | (需要特别注意的是,每个句子对应的类别标签个数不定),例如: 125 | 126 | {"labels": ["DV1", "DV4", "DV2"],"sentence": "In our opinion, according to the agreement between the two parties at the time of divorce, the plaintiff has paid 22210.00 yuan for the child's upbringing on the basis of ten-year upbringing. We can confirm that the plaintiff pays 200.00 yuan for the child's upbringing on a monthly basis."} 127 | 128 | {"labels": ["DV1", "DV4", "DV2"],"sentence": "本院认为,依据双方离婚时的协议约定,原告已按尚抚养十年一性支付孩子抚养费22210.00元,可确认原告每月支付孩子抚养费为200.00元。"}, 129 | {"labels": [],"sentence": "父母离婚后子女的抚养问题,应从有利于子女身心健康以及保障子女合法权益出发,结合父母双方的抚养能力和抚养条件等具体情况妥善解决。"}, 130 | {"labels": ["DV1", "DV8", "DV4", "DV2"],"sentence": "二、关于抚养费的承担问题,本院根据子女的实际需要、父母双方的负担能力和当地的实际生活水平确定,由被告赵某甲每月给付孩子抚养费600.00元;"}, 131 | {"labels": ["DV14", "DV9", "DV12"], "sentence": "原告诉称,原被告原系夫妻关系,双方于2015年3月18日经河南省焦作市山阳区人民法院一审判决离婚,离婚后原告才发现被告在婚姻关系存续期间,与他人同居怀孕并生下一男孩,给原告造成极大伤害。"}, 132 | {"labels": [], "sentence": "特诉至贵院,请求判决被告赔偿原告精神损害抚慰金3万元。"}, 133 | {"labels": [], "sentence": "被告辩称,1、原告在焦作市山阳区人民法院离婚诉讼中,提交答辩状期间就同意被告的离婚请求,被告并未出现与他人同居怀孕情况,原告不具备损害赔偿权利主体资格。"}, {"labels": [], "sentence": "2、被告没有与他人持续稳定的同居生活,不具备损害赔偿责任主体资格。"}, {"labels": [], "sentence": "3、原被告婚姻关系存续期间,原告经常对被告实施家庭暴力。"}, {"labels": [], "sentence": "原告存在较大过错,无权提起本案损失赔偿请求。"}, {"labels": ["DV1"], "sentence": "经审理查明:原被告于××××年××月××日登记结婚,××××年××月××日生育女孩都某乙。"}, {"labels": ["DV9"], "sentence": "被告于2014年9月23日向焦作市山阳区人民法院起诉与原告离婚,该院于2015年3月18日判决准予原被告离婚后,原告不服上诉至焦作市中级人民法院,该院于2015年6月15日作出终审判决,驳回原告上诉,维持原判。"} 134 | 135 | for each english label like DV1, it associated with a chinese label, such as 婚后有子女, which means 'Having children after marriage'. 136 | 137 | 对于每个英文标签,都有一个对应的中文标签名称。 如 DV1、DV2、DV4对应的中文标签分别为:婚后有子女、限制行为能力子女抚养、支付抚养费。DV8对应:按月给付抚养费 138 | 139 | Generate Training Data 生成训练数据 140 | -------------------------------------------------------------------- 141 | ### 标签下的代表性样本的产生 142 | 我们首先选取每个标签下一定数量样本,如随机的选取5个样本,来代表这个标签;并且由于我们知道这个标签对应的中文名称,所以,对于任何一个标签,我们都构造了6个句子来代表这个标签, 143 | 144 | 记为集合{representation_set}. 需要注意的是,为了使得样本更有代表性,在随机选取样本过程中,我们优先选择哪些只有一个标签的样本作为我们的代表性的样本。 145 | 146 | 如对于DV1,我们6个样本包括: 147 | 148 | 婚后有子女(来自于中文标签) 149 | 150 | ××××年××月××日生育女儿赵某乙。(来自于DV1标签下的样本) 151 | 152 | 综合全案证据及庭审调查,本院确认以下法律事实:××××年××月××日,原告林某和被告赵某甲办理结婚登记手续,婚后于××××年××月××日生育婚生女儿赵某乙。(来自于DV1标签下的样本) 153 | 154 | 四、抚养权变更后,被告赵某甲有探望女儿赵某乙的权利,原告林某应为被告赵某甲探望孩子提供必要便利。(来自于DV1标签下的样本) 155 | 156 | 原始样本: {"labels": ["DV1", "DV4", "DV2"],"sentence": "本院认为,依据双方离婚时的协议约定,原告已按尚抚养十年一性支付孩子抚养费22210.00元,可确认原告每月支付孩子抚养费为200.00元。"} 157 | 158 | ### 正样本的构造 159 | 160 | 我们需要构造句子对任务: 161 | 162 | 句子对任务中的第一部分的输入为,原始的文本,即本院认为,依据双方离婚时的协议约定...),第二部分的输入是数据产生过程的关键所在。 163 | 164 | 对于DV1这个标签,我们得到六个句子即representation_set[DV1],并且对于原始输入与这里面的每个样本的组合,我们给label赋值为1。如: 165 | 166 | <"本院认为,依据双方离婚时的协议约定,原告已按尚抚养十年一性支付孩子抚养费22210.00元,可确认原告每月支付孩子抚养费为200.00元。", "婚后有子女", 1> 167 | 168 | <"本院认为,依据双方离婚时的协议约定,原告已按尚抚养十年一性支付孩子抚养费22210.00元,可确认原告每月支付孩子抚养费为200.00元。", “××××年××月××日生育女儿赵某乙。", 1> 169 | 170 | <"本院认为,依据双方离婚时的协议约定,原告已按尚抚养十年一性支付孩子抚养费22210.00元,可确认原告每月支付孩子抚养费为200.00元。", “综合全案证据及庭审调查,本院确认以下法律事实:××××年××月××日,原告林某和被告赵某甲办理结婚登记手续,婚后于××××年××月××日生育婚生女儿赵某乙。", 1> 171 | 172 | .... 
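A condensed sketch of this positive-pair construction (the full logic lives in `get_pos_example_data` in `zuo/generate_training_data.py`, which additionally samples per-label example sentences from a separate file and truncates very long text):

```python
# Simplified positive-pair generation: for every gold label of a sentence, each
# entry in that label's representation set becomes one positive training pair,
# written in the project's "label \t <category>的<representation> \t sentence" format.
def build_positive_pairs(sentence, gold_labels, representation_set, category_cn="婚姻家庭"):
    """representation_set maps a label (e.g. 'DV1') to the texts representing it:
    its Chinese name plus a few sentences sampled from that label."""
    pairs = []
    for tag in gold_labels:
        for rep_text in representation_set.get(tag, []):
            if rep_text == sentence:
                continue  # skip degenerate pairs
            pairs.append("1" + "\t" + category_cn + "的" + rep_text + "\t" + sentence + "\n")
    return pairs

reps = {"DV1": ["婚后有子女", "××××年××月××日生育女儿赵某乙。"]}
print(build_positive_pairs("可确认原告每月支付孩子抚养费为200.00元。", ["DV1"], reps))
```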
173 | 174 | 通过这个形式,我们将正样本扩大了6倍;我们可以通过标签下采样更多的例子,以及使用标签下统计得到的关键词的组合,来进一步扩大正样本的集合。 175 | 176 | ### 负样本的构造 177 | 178 | 对于一个标签集合,如婚姻家庭下有20个标签,那么对于一个句子,只要没有被打上标签的,就是负样本。我们也构造了一个标签下的可能的所有样本的集合外加中文标签,并通过随机样本的方式来得到需要的样本。 179 | 180 | 如对于我们的原始样本的输入文本:"本院认为,依据双方离婚时的协议约定,原告已按尚抚养十年一性支付孩子抚养费22210.00元,可确认原告每月支付孩子抚养费为200.00元。", 181 | 182 | 它被打上了三个标签["DV1", "DV4", "DV2"],那么其他标签["DV3","DV5",...,"DV20"],都可以用来构造负样本。对应DV3,我们选择4+1即5个样本。3即从DV3下随机的找出三个样本,1即中文标签做为样本。 183 | 184 | 185 | ### 正负样本分布 Training data and Its distribution 186 | 187 | 由于正样本量扩大了至少6倍,负样本扩大了20倍左右,最终我们从原始的1万个样本中产生了100多万的数据。当然为了样本分布受控,我们也对负样本有部分下采样。 188 | 189 | 另外为了得到与任务目标接近的验证集,我们在验证集上对负样本精选了更大程度的下采样。 190 | 191 | 分布大致如下: 192 | 193 | train: 194 | count_pos: 261001 ;count_neg: 1011322 ;pert of pos: 0.20513737470752316 195 | 196 | dev: 197 | count_pos: 13863 ;count_neg: 16121 ;pert of pos: 0.4623465848452508 198 | 199 | 200 | 生成训练数据的命令 Run command to generate training data : 201 | 202 | python3 -u zuo/generate_training_data 203 | 204 | 205 | Relationship with Few-Shot Learning 与Few-Shot Learning有什么关系 206 | -------------------------------------------------------------------- 207 | According to Quora: with the term “few-shot learning”, the “few” usually lies between zero and five, 208 | 209 | meaning that training a model with zero examples is known as zero-shot learning, one example is one-shot learning, and so on. 210 | 211 | All of these variants are trying to solve the same problem with differing levels of training material. 212 | 213 | We are not doing few shot learning. however we try to get maximum performance for tasks with not so many examples. 214 | 215 | And also try to use information and knowledge,whether come from task or outside the task, as much as possible, to boost the performance. 216 | 217 | Download Data 下载数据 218 | -------------------------------------------------------------------- 219 | 点击这里下载数据,并解压缩到zuo目录下。这样你就有了一个新的包含所有需要的数据的目录./zuo/data_all 220 | 221 | Training 训练模型 222 | -------------------------------------------------------------------- 223 | Run Command to Train the model: 224 | 225 | export BERT_BASE_DIR=./RoBERTa_zh_Large 226 | export TEXT_DIR=./zuo/data_all/train_data 227 | 228 | nohup python3 run_classifier.py --task_name=sentence_pair --do_train=true --do_eval=true --data_dir=$TEXT_DIR \ 229 | --vocab_file=$BERT_BASE_DIR/vocab.txt --bert_config_file=$BERT_BASE_DIR/bert_config_big.json \ 230 | --init_checkpoint=$BERT_BASE_DIR/roberta_zh_large_model.ckpt --max_seq_length=256 --train_batch_size=128 \ 231 | --learning_rate=1e-5 --num_train_epochs=3 --output_dir=zuo/model_files/roberta-zh-large_law & 232 | 233 | 如果你从现有的模型基础上训练,指定一下BERT_BASE_DIR的路径,并确保bert_config_file和init_checkpoint两个参数的值能对应到相应的文件上。 234 | 这里假设你下载了roberta的模型并放在本项目的这个RoBERTa_zh_Large目录下。 235 | 236 | 237 | Inference and its acceleration 预测加速 238 | -------------------------------------------------------------------- 239 | 训练完成后,运行命令来进行预测 Run Command to Train Model: 240 | 241 | python3 -u main.py 242 | 243 | 需要注意的是,你需要确保相应的目录有训练好的模型,见zuo/run_classifier_predict_online.py,特别注意这两个参数要能对应上: 244 | 245 | init_checkpoint和bert_config_file 246 | 247 | ### 句子对任务的构建 Construct Sentence Pair 248 | 249 | 虽然训练阶段使用了很多信息和知识来训练,但是预测阶段我们只采用<原始输入的句子,候选标签对应的中文标签>来构造句子对任务。我们认为样本下的标签虽然能 250 | 251 | 代表标签,但中文标签具有最好的代表性,预测效果也好一些。 252 | 253 | ### 预测阶段加速 Accelerate Inference Time 254 | 255 | 由于采用了sentence pair任务即句子对形式,对于一个输入,有20个标签,每个标签都需要预测,那么总共需要预测20次,这会导致预测时间过长。 256 | 257 | 所以,我们采用的是快速召回+精细预测的形式。 258 | 259 | 快速召回,采用多标签分类的形式,只需一次预测就可以将可能的候选项找到,如概率大于0.05的标签都是候选项。实践中,多数时候候选的标签为只有0个或1个。 
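A minimal sketch of this recall step (`check_whether_has_any_candidate` in `main.py` plays a similar filtering role; the exact threshold is a tuning choice):

```python
# Fast recall: a cheap multi-label pass yields a rough probability per label,
# and only labels above a small threshold survive as candidates for the more
# expensive sentence-pair model.
def recall_candidates(coarse_probs, threshold=0.05):
    """coarse_probs: dict such as {'LB1': 0.01, 'LB2': 0.62, ...} from the fast model."""
    return [tag for tag, prob in coarse_probs.items() if float(prob) > threshold]

print(recall_candidates({"LB1": 0.01, "LB2": 0.62, "LB3": 0.00, "LB4": 0.07}))
# -> ['LB2', 'LB4']
```

Each surviving candidate is then paired with the input sentence and re-scored by the sentence-pair model, as described next.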
260 | 261 | 对于每个候选的标签,都会使用句子对任务(原始句子,标签对应的中文描述)的模型去预测这个标签的概率;当某个标签的概率超过0.5的时候,即认为是目标标签。 262 | 263 | 对于响应速度要求不是特别严格的时候,我们也可以通过训练层数比较少的句子对模型来作为快速召回模块。 264 | 265 | Multi-label classification directly with Bert 使用Bert直接做多标签分类 266 | -------------------------------------------------------------------- 267 | Check this 3.BERT, You can find code here: 268 | 269 | train_bert_multi-label.py 270 | 271 | Unfinished Work 未完成的工作 272 | -------------------------------------------------------------------- 273 | 利用标签间的关系,构造更多数据和更难的任务。 274 | 275 | 由于标签之间具有共现关系或排斥关系等,通过标签关系来生成更多数据,也是未来可以研究的一个方向。 276 | 277 | 项目贡献者,还包括: 278 | -------------------------------------------------------------------- 279 | skyhawk1990 280 | 281 | YC-wind 282 | 283 | 284 | Reference 285 | -------------------------------------------------------------------- 286 | 1、BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 287 | 288 | 2、RoBERTa中文预训练模型:Roberta_zh 289 | 290 | 3、Pre-Training with Whole Word Masking for Chinese BERT 291 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 
35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | #print("items:",items) #['[CLS]', '日', '##期', ',', '但', '被', '##告', '金', '##东', '##福', '载', '##明', '[MASK]', 'U', '##N', '##K', ']', '保', '##证', '本', '##月', '1', '##4', '[MASK]', '到', '##位', ',', '2', '##0', '##1', '##5', '年', '6', '[MASK]', '1', '##1', '日', '[', 'U', '##N', '##K', ']', ',', '原', '##告', '[MASK]', '认', '##可', '于', '2', '##0', '##1', '##5', '[MASK]', '6', '月', '[MASK]', '[MASK]', '日', '##向', '被', '##告', '主', '##张', '权', '##利', '。', '而', '[MASK]', '[MASK]', '自', '[MASK]', '[MASK]', '[MASK]', '[MASK]', '年', '6', '月', '1', '##1', '日', '[SEP]', '原', '##告', '于', '2', '##0', '##1', '##6', '[MASK]', '6', '[MASK]', '2', '##4', '日', '起', '##诉', ',', '主', '##张', '保', '##证', '责', '##任', ',', '已', '超', '##过', '保', '##证', '期', '##限', '[MASK]', '保', '##证', '人', '依', '##法', '不', '##再', '承', '##担', '保', '##证', '[MASK]', '[MASK]', '[MASK]', '[SEP]'] 140 | for i,item in enumerate(items): 141 | #print(i,"item:",item) # ##期 142 | output.append(vocab[item]) 143 | return output 144 | 145 | 146 | def convert_tokens_to_ids(vocab, tokens): 147 | return convert_by_vocab(vocab, tokens) 148 | 149 | 150 | def convert_ids_to_tokens(inv_vocab, ids): 151 | return convert_by_vocab(inv_vocab, ids) 152 | 153 | 154 | def whitespace_tokenize(text): 155 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 156 | text = text.strip() 157 | if not text: 158 | return [] 159 | tokens = text.split() 160 | return tokens 161 | 162 | 163 | class FullTokenizer(object): 164 | """Runs end-to-end tokenziation.""" 165 | 166 | def __init__(self, vocab_file, do_lower_case=True): 167 | self.vocab = load_vocab(vocab_file) 168 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 169 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 170 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 171 | 172 | def tokenize(self, text): 173 | split_tokens = [] 174 | for token in self.basic_tokenizer.tokenize(text): 175 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 176 | split_tokens.append(sub_token) 177 | 178 | return split_tokens 179 | 180 | def convert_tokens_to_ids(self, tokens): 181 | return convert_by_vocab(self.vocab, tokens) 182 | 183 | def convert_ids_to_tokens(self, ids): 184 | return convert_by_vocab(self.inv_vocab, ids) 185 | 186 | 187 | class BasicTokenizer(object): 188 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 189 | 190 | def __init__(self, do_lower_case=True): 191 | 
"""Constructs a BasicTokenizer. 192 | 193 | Args: 194 | do_lower_case: Whether to lower case the input. 195 | """ 196 | self.do_lower_case = do_lower_case 197 | 198 | def tokenize(self, text): 199 | """Tokenizes a piece of text.""" 200 | text = convert_to_unicode(text) 201 | text = self._clean_text(text) 202 | 203 | # This was added on November 1st, 2018 for the multilingual and Chinese 204 | # models. This is also applied to the English models now, but it doesn't 205 | # matter since the English models were not trained on any Chinese data 206 | # and generally don't have any Chinese data in them (there are Chinese 207 | # characters in the vocabulary because Wikipedia does have some Chinese 208 | # words in the English Wikipedia.). 209 | text = self._tokenize_chinese_chars(text) 210 | 211 | orig_tokens = whitespace_tokenize(text) 212 | split_tokens = [] 213 | for token in orig_tokens: 214 | if self.do_lower_case: 215 | token = token.lower() 216 | token = self._run_strip_accents(token) 217 | split_tokens.extend(self._run_split_on_punc(token)) 218 | 219 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 220 | return output_tokens 221 | 222 | def _run_strip_accents(self, text): 223 | """Strips accents from a piece of text.""" 224 | text = unicodedata.normalize("NFD", text) 225 | output = [] 226 | for char in text: 227 | cat = unicodedata.category(char) 228 | if cat == "Mn": 229 | continue 230 | output.append(char) 231 | return "".join(output) 232 | 233 | def _run_split_on_punc(self, text): 234 | """Splits punctuation on a piece of text.""" 235 | chars = list(text) 236 | i = 0 237 | start_new_word = True 238 | output = [] 239 | while i < len(chars): 240 | char = chars[i] 241 | if _is_punctuation(char): 242 | output.append([char]) 243 | start_new_word = True 244 | else: 245 | if start_new_word: 246 | output.append([]) 247 | start_new_word = False 248 | output[-1].append(char) 249 | i += 1 250 | 251 | return ["".join(x) for x in output] 252 | 253 | def _tokenize_chinese_chars(self, text): 254 | """Adds whitespace around any CJK character.""" 255 | output = [] 256 | for char in text: 257 | cp = ord(char) 258 | if self._is_chinese_char(cp): 259 | output.append(" ") 260 | output.append(char) 261 | output.append(" ") 262 | else: 263 | output.append(char) 264 | return "".join(output) 265 | 266 | def _is_chinese_char(self, cp): 267 | """Checks whether CP is the codepoint of a CJK character.""" 268 | # This defines a "chinese character" as anything in the CJK Unicode block: 269 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 270 | # 271 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 272 | # despite its name. The modern Korean Hangul alphabet is a different block, 273 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 274 | # space-separated words, so they are not treated specially and handled 275 | # like the all of the other languages. 
276 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 277 | (cp >= 0x3400 and cp <= 0x4DBF) or # 278 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 279 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 280 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 281 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 282 | (cp >= 0xF900 and cp <= 0xFAFF) or # 283 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 284 | return True 285 | 286 | return False 287 | 288 | def _clean_text(self, text): 289 | """Performs invalid character removal and whitespace cleanup on text.""" 290 | output = [] 291 | for char in text: 292 | cp = ord(char) 293 | if cp == 0 or cp == 0xfffd or _is_control(char): 294 | continue 295 | if _is_whitespace(char): 296 | output.append(" ") 297 | else: 298 | output.append(char) 299 | return "".join(output) 300 | 301 | 302 | class WordpieceTokenizer(object): 303 | """Runs WordPiece tokenziation.""" 304 | 305 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 306 | self.vocab = vocab 307 | self.unk_token = unk_token 308 | self.max_input_chars_per_word = max_input_chars_per_word 309 | 310 | def tokenize(self, text): 311 | """Tokenizes a piece of text into its word pieces. 312 | 313 | This uses a greedy longest-match-first algorithm to perform tokenization 314 | using the given vocabulary. 315 | 316 | For example: 317 | input = "unaffable" 318 | output = ["un", "##aff", "##able"] 319 | 320 | Args: 321 | text: A single token or whitespace separated tokens. This should have 322 | already been passed through `BasicTokenizer. 323 | 324 | Returns: 325 | A list of wordpiece tokens. 326 | """ 327 | 328 | text = convert_to_unicode(text) 329 | 330 | output_tokens = [] 331 | for token in whitespace_tokenize(text): 332 | chars = list(token) 333 | if len(chars) > self.max_input_chars_per_word: 334 | output_tokens.append(self.unk_token) 335 | continue 336 | 337 | is_bad = False 338 | start = 0 339 | sub_tokens = [] 340 | while start < len(chars): 341 | end = len(chars) 342 | cur_substr = None 343 | while start < end: 344 | substr = "".join(chars[start:end]) 345 | if start > 0: 346 | substr = "##" + substr 347 | if substr in self.vocab: 348 | cur_substr = substr 349 | break 350 | end -= 1 351 | if cur_substr is None: 352 | is_bad = True 353 | break 354 | sub_tokens.append(cur_substr) 355 | start = end 356 | 357 | if is_bad: 358 | output_tokens.append(self.unk_token) 359 | else: 360 | output_tokens.extend(sub_tokens) 361 | return output_tokens 362 | 363 | 364 | def _is_whitespace(char): 365 | """Checks whether `chars` is a whitespace character.""" 366 | # \t, \n, and \r are technically contorl characters but we treat them 367 | # as whitespace since they are generally considered as such. 368 | if char == " " or char == "\t" or char == "\n" or char == "\r": 369 | return True 370 | cat = unicodedata.category(char) 371 | if cat == "Zs": 372 | return True 373 | return False 374 | 375 | 376 | def _is_control(char): 377 | """Checks whether `chars` is a control character.""" 378 | # These are technically control characters but we count them as whitespace 379 | # characters. 380 | if char == "\t" or char == "\n" or char == "\r": 381 | return False 382 | cat = unicodedata.category(char) 383 | if cat in ("Cc", "Cf"): 384 | return True 385 | return False 386 | 387 | 388 | def _is_punctuation(char): 389 | """Checks whether `chars` is a punctuation character.""" 390 | cp = ord(char) 391 | # We treat all non-letter/number ASCII as punctuation. 
392 | # Characters such as "^", "$", and "`" are not in the Unicode 393 | # Punctuation class but we treat them as punctuation anyways, for 394 | # consistency. 395 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 396 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 397 | return True 398 | cat = unicodedata.category(char) 399 | if cat.startswith("P"): 400 | return True 401 | return False 402 | -------------------------------------------------------------------------------- /zuo/run_classifier_predict_online.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner of classification for online prediction. input is a list. output is a label.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import csv 22 | import os 23 | import zuo.bert.modeling as modeling 24 | import zuo.bert.tokenization as tokenization 25 | import tensorflow as tf 26 | import numpy as np 27 | import logging 28 | import time 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | ## Required parameters 35 | BERT_BASE_DIR = "zuo/model_files/roberta-zh-large_law/" # "../model_files/inference_with_reason/checkpoint_bert/" 36 | # BERT_BASE_DIR_NEW = "zuo/model_files/roberta-layer12-01-new/" # "../model_files/inference_with_reason/checkpoint_bert/" 37 | 38 | flags.DEFINE_string("bert_config_file", BERT_BASE_DIR + "bert_config_large.json", 39 | "The config json file corresponding to the pre-trained BERT model. " 40 | "This specifies the model architecture.") 41 | 42 | flags.DEFINE_string("task_name", "sentence_pair", "The name of the task to train.") 43 | 44 | flags.DEFINE_string("vocab_file", BERT_BASE_DIR + "vocab.txt", 45 | "The vocabulary file that the BERT model was trained on.") 46 | 47 | flags.DEFINE_string("init_checkpoint", BERT_BASE_DIR, # model.ckpt-66870--> /model.ckpt-66870 48 | "Initial checkpoint (usually from a pre-trained BERT model).") 49 | 50 | flags.DEFINE_integer("max_seq_length", 256, # 128 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded.") 54 | 55 | flags.DEFINE_bool( 56 | "do_lower_case", True, 57 | "Whether to lower case the input text. 
Should be True for uncased " 58 | "models and False for cased models.") 59 | 60 | flags.DEFINE_string("c", "gunicorn.conf", 61 | "gunicorn.conf") # data/sgns.target.word-word.dynwin5.thr10.neg5.dim300.iter5--->data/news_12g_baidubaike_20g_novel_90g_embedding_64.bin--->sgns.merge.char 62 | 63 | 64 | class InputExample(object): 65 | """A single training/test example for simple sequence classification.""" 66 | 67 | def __init__(self, guid, text_a, text_b=None, label=None): 68 | """Constructs a InputExample. 69 | 70 | Args: 71 | guid: Unique id for the example. 72 | text_a: string. The untokenized text of the first sequence. For single 73 | sequence tasks, only this sequence must be specified. 74 | text_b: (Optional) string. The untokenized text of the second sequence. 75 | Only must be specified for sequence pair tasks. 76 | label: (Optional) string. The label of the example. This should be 77 | specified for train and dev examples, but not for test examples. 78 | """ 79 | self.guid = guid 80 | self.text_a = text_a 81 | self.text_b = text_b 82 | self.label = label 83 | 84 | 85 | class InputFeatures(object): 86 | """A single set of features of data.""" 87 | 88 | def __init__(self, input_ids, input_mask, segment_ids, label_id): 89 | self.input_ids = input_ids 90 | self.input_mask = input_mask 91 | self.segment_ids = segment_ids 92 | self.label_id = label_id 93 | 94 | 95 | class DataProcessor(object): 96 | """Base class for data converters for sequence classification data sets.""" 97 | 98 | def get_train_examples(self, data_dir): 99 | """Gets a collection of `InputExample`s for the train set.""" 100 | raise NotImplementedError() 101 | 102 | def get_dev_examples(self, data_dir): 103 | """Gets a collection of `InputExample`s for the dev set.""" 104 | raise NotImplementedError() 105 | 106 | def get_test_examples(self, data_dir): 107 | """Gets a collection of `InputExample`s for prediction.""" 108 | raise NotImplementedError() 109 | 110 | def get_labels(self): 111 | """Gets the list of labels for this data set.""" 112 | raise NotImplementedError() 113 | 114 | @classmethod 115 | def _read_tsv(cls, input_file, quotechar=None): 116 | """Reads a tab separated value file.""" 117 | with tf.gfile.Open(input_file, "r") as f: 118 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 119 | lines = [] 120 | for line in reader: 121 | lines.append(line) 122 | return lines 123 | 124 | 125 | class SentencePairClassificationProcessor(DataProcessor): 126 | """Processor for the internal data set. 
sentence pair classification""" 127 | 128 | def __init__(self): 129 | self.language = "zh" 130 | 131 | def get_train_examples(self, data_dir): 132 | """See base class.""" 133 | return self._create_examples( 134 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 135 | 136 | def get_dev_examples(self, data_dir): 137 | """See base class.""" 138 | return self._create_examples( 139 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 140 | 141 | def get_test_examples(self, data_dir): 142 | """See base class.""" 143 | return self._create_examples( 144 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 145 | 146 | def get_labels(self): 147 | """See base class.""" 148 | return ["0", "1"] 149 | 150 | def _create_examples(self, lines, set_type): 151 | """Creates examples for the training and dev sets.""" 152 | examples = [] 153 | for (i, line) in enumerate(lines): 154 | if i == 0: 155 | continue 156 | guid = "%s-%s" % (set_type, i) 157 | label = tokenization.convert_to_unicode(line[0]) 158 | text_a = tokenization.convert_to_unicode(line[1]) 159 | text_b = tokenization.convert_to_unicode(line[2]) 160 | examples.append( 161 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 162 | return examples 163 | 164 | 165 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer): 166 | """Converts a single `InputExample` into a single `InputFeatures`.""" 167 | label_map = {} 168 | for (i, label) in enumerate(label_list): 169 | label_map[label] = i 170 | 171 | tokens_a = tokenizer.tokenize(example.text_a) 172 | tokens_b = None 173 | if example.text_b: 174 | tokens_b = tokenizer.tokenize(example.text_b) 175 | 176 | if tokens_b: 177 | # Modifies `tokens_a` and `tokens_b` in place so that the total 178 | # length is less than the specified length. 179 | # Account for [CLS], [SEP], [SEP] with "- 3" 180 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 181 | else: 182 | # Account for [CLS] and [SEP] with "- 2" 183 | if len(tokens_a) > max_seq_length - 2: 184 | tokens_a = tokens_a[0:(max_seq_length - 2)] 185 | 186 | # The convention in BERT is: 187 | # (a) For sequence pairs: 188 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 189 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 190 | # (b) For single sequences: 191 | # tokens: [CLS] the dog is hairy . [SEP] 192 | # type_ids: 0 0 0 0 0 0 0 193 | # 194 | # Where "type_ids" are used to indicate whether this is the first 195 | # sequence or the second sequence. The embedding vectors for `type=0` and 196 | # `type=1` were learned during pre-training and are added to the wordpiece 197 | # embedding vector (and position vector). This is not *strictly* necessary 198 | # since the [SEP] token unambiguously separates the sequences, but it makes 199 | # it easier for the model to learn the concept of sequences. 200 | # 201 | # For classification tasks, the first vector (corresponding to [CLS]) is 202 | # used as as the "sentence vector". Note that this only makes sense because 203 | # the entire model is fine-tuned. 
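    # A compact sketch of what the code below produces (token ids are illustrative
    # placeholders, not values from the real vocab) for a short pair padded to
    # max_seq_length=10:
    #
    #   tokens:       [CLS]  劳    动   [SEP]  解    除   [SEP]
    #   segment_ids:    0    0     0     0     1     1     1    0  0  0
    #   input_ids:    101   2340  1214  102   6237  7370  102   0  0  0
    #   input_mask:     1    1     1     1     1     1     1    0  0  0
    #
    # Padding positions get id 0, mask 0 and segment 0, which is what the three
    # length asserts further down rely on.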
204 | tokens = [] 205 | segment_ids = [] 206 | tokens.append("[CLS]") 207 | segment_ids.append(0) 208 | for token in tokens_a: 209 | tokens.append(token) 210 | segment_ids.append(0) 211 | tokens.append("[SEP]") 212 | segment_ids.append(0) 213 | 214 | if tokens_b: 215 | for token in tokens_b: 216 | tokens.append(token) 217 | segment_ids.append(1) 218 | tokens.append("[SEP]") 219 | segment_ids.append(1) 220 | 221 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 222 | 223 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 224 | # tokens are attended to. 225 | input_mask = [1] * len(input_ids) 226 | 227 | # Zero-pad up to the sequence length. 228 | while len(input_ids) < max_seq_length: 229 | input_ids.append(0) 230 | input_mask.append(0) 231 | segment_ids.append(0) 232 | 233 | assert len(input_ids) == max_seq_length 234 | assert len(input_mask) == max_seq_length 235 | assert len(segment_ids) == max_seq_length 236 | 237 | label_id = label_map[example.label] 238 | if ex_index < 5: 239 | tf.logging.info("*** Example ***") 240 | tf.logging.info("guid: %s" % (example.guid)) 241 | tf.logging.info("tokens: %s" % " ".join( 242 | [tokenization.printable_text(x) for x in tokens])) 243 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 244 | # tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 245 | # tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 246 | # tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 247 | 248 | feature = InputFeatures( 249 | input_ids=input_ids, 250 | input_mask=input_mask, 251 | segment_ids=segment_ids, 252 | label_id=label_id) 253 | return feature 254 | 255 | 256 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 257 | """Truncates a sequence pair in place to the maximum length.""" 258 | 259 | # This is a simple heuristic which will always truncate the longer sequence 260 | # one token at a time. This makes more sense than truncating an equal percent 261 | # of tokens from each, since if one sequence is very short then each token 262 | # that's truncated likely contains more information than a longer sequence. 263 | while True: 264 | total_length = len(tokens_a) + len(tokens_b) 265 | if total_length <= max_length: 266 | break 267 | if len(tokens_a) > len(tokens_b): 268 | tokens_a.pop() 269 | else: 270 | tokens_b.pop() 271 | 272 | 273 | def create_int_feature(values): 274 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 275 | return f 276 | 277 | 278 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, labels, num_labels, 279 | use_one_hot_embeddings): 280 | """Creates a classification model.""" 281 | print("create_model.is_training:", is_training) 282 | model = modeling.BertModel( 283 | config=bert_config, 284 | is_training=is_training, 285 | input_ids=input_ids, 286 | input_mask=input_mask, 287 | token_type_ids=segment_ids, 288 | use_one_hot_embeddings=use_one_hot_embeddings) 289 | 290 | # In the demo, we are doing a simple classification task on the entire 291 | # segment. 292 | # 293 | # If you want to use the token-level output, use model.get_sequence_output() 294 | # instead. 
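  # Shape note: get_pooled_output() is [batch_size, hidden_size]; the head below is a
  # single dense layer, so `logits` and `probabilities` come out as
  # [batch_size, num_labels] (num_labels is 2 for the "0"/"1" sentence-pair task in
  # this repo), and predict_online() simply reads the probability vector of the one
  # example it feeds in.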
295 | output_layer = model.get_pooled_output() 296 | 297 | hidden_size = output_layer.shape[-1].value 298 | output_weights = tf.get_variable( 299 | "output_weights", [num_labels, hidden_size], 300 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 301 | 302 | output_bias = tf.get_variable( 303 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 304 | 305 | with tf.variable_scope("loss"): 306 | if is_training: 307 | # I.e., 0.1 dropout 308 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 309 | 310 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 311 | logits = tf.nn.bias_add(logits, output_bias) 312 | probabilities = tf.nn.softmax(logits, axis=-1) 313 | log_probs = tf.nn.log_softmax(logits, axis=-1) 314 | 315 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 316 | 317 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 318 | loss = tf.reduce_mean(per_example_loss) 319 | 320 | return (loss, per_example_loss, logits, probabilities, model) 321 | 322 | 323 | tf.logging.set_verbosity(tf.logging.ERROR) # INFO 324 | processors = { 325 | "sentence_pair": SentencePairClassificationProcessor, 326 | } 327 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 328 | task_name = FLAGS.task_name.lower() 329 | print("task_name:", task_name) 330 | processor = processors[task_name]() 331 | label_list = processor.get_labels() 332 | # lines_dev=processor.get_dev_examples("./TEXT_DIR") 333 | index2label = {i: label_list[i] for i in range(len(label_list))} 334 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 335 | 336 | 337 | def main(_): 338 | pass 339 | 340 | 341 | # init mode and session 342 | # move something codes outside of function, so that this code will run only once during online prediction when predict_online is invoked. 343 | is_training = False 344 | use_one_hot_embeddings = False 345 | batch_size = 1 346 | num_labels = len(label_list) 347 | gpu_config = tf.ConfigProto() 348 | gpu_config.gpu_options.allow_growth = True 349 | #sess = tf.Session(config=gpu_config) 350 | model = None 351 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None 352 | if not os.path.exists(FLAGS.init_checkpoint + "checkpoint"): 353 | raise Exception("failed to get checkpoint. going to return. 
init_checkpoint:", FLAGS.init_checkpoint) 354 | 355 | global graph 356 | graph = tf.Graph() # tf.get_default_graph() 357 | global sess 358 | sess = tf.Session(config=gpu_config, graph=graph) 359 | # with sess2: 360 | with graph.as_default(): 361 | print("BERT.going to restore checkpoint:"+FLAGS.init_checkpoint) 362 | # sess.run(tf.global_variables_initializer()) 363 | input_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_ids") 364 | input_mask_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_mask") 365 | label_ids_p = tf.placeholder(tf.int32, [batch_size], name="label_ids") 366 | segment_ids_p = tf.placeholder(tf.int32, [FLAGS.max_seq_length], name="segment_ids") 367 | total_loss, per_example_loss, logits, probabilities, model = create_model(bert_config, is_training, input_ids_p, 368 | input_mask_p, segment_ids_p, label_ids_p, 369 | num_labels, use_one_hot_embeddings) 370 | saver = tf.train.Saver() 371 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.init_checkpoint)) 372 | ######################################################################### 373 | # trainable_variable_list=tf.trainable_variables() 374 | # trainable_variable_list=[x for x in trainable_variable_list if 'adam' not in x.name] 375 | # saver_new=tf.train.Saver(trainable_variable_list) 376 | # saver_new.save(sess,BERT_BASE_DIR_NEW) 377 | # print("save new checkpiont completed...") 378 | ######################################################################### 379 | 380 | print("checkpoint1:",FLAGS.init_checkpoint) 381 | 382 | 383 | def predict_online(content, type_information): 384 | """ 385 | do online prediction. each time make prediction for one instance. 386 | you can change to a batch if you want. 387 | 388 | :param line: a list. element is: [dummy_label,text_a,text_b] 389 | :return: 390 | """ 391 | # print("bert.predict_online.content:"+str(content)+";type_information:"+str(type_information)) 392 | label = '1' # tokenization.convert_to_unicode(line[0]) # this should compatible with format you defined in processor. 
393 | text_a = tokenization.convert_to_unicode(type_information) 394 | text_b = tokenization.convert_to_unicode(content) 395 | example = InputExample(guid=0, text_a=text_a, text_b=text_b, label=label) 396 | feature = convert_single_example(0, example, label_list, FLAGS.max_seq_length, tokenizer) 397 | input_ids = np.reshape([feature.input_ids], (1, FLAGS.max_seq_length)) 398 | input_mask = np.reshape([feature.input_mask], (1, FLAGS.max_seq_length)) 399 | segment_ids = np.reshape([feature.segment_ids], (FLAGS.max_seq_length)) 400 | label_ids = [feature.label_id] 401 | 402 | global graph 403 | with graph.as_default(): 404 | feed_dict = {input_ids_p: input_ids, input_mask_p: input_mask, segment_ids_p: segment_ids, 405 | label_ids_p: label_ids} 406 | possibility = sess.run([probabilities], feed_dict) 407 | possibility = possibility[0][0] # get first label 408 | label_index = np.argmax(possibility) 409 | label_predict = index2label[label_index] 410 | return label_predict, possibility 411 | 412 | 413 | if __name__ == "__main__": 414 | # tf.app.run() 415 | # 0 劳动争议的经济性裁员 2010年10月21日原告向咸阳市劳动争议仲裁委员会申请劳动争议仲裁,该会于2010年向原告送达了咸劳仲不字第(2010)第38号不予受理通知书。 416 | time1 = time.time() 417 | content = '2010年10月21日原告向咸阳市劳动争议仲裁委员会申请劳动争议仲裁,该会于2010年向原告送达了咸劳仲不字第(2010)第38号不予受理通知书。' 418 | type_information = '劳动争议的经济性裁员' 419 | result = predict_online(content, type_information) 420 | time2 = time.time() 421 | print("result:", result, (time2 - time1)) 422 | 423 | time1 = time.time() 424 | content = '一、确认原告虞爱玲与被告浦江荣建置业有限公司的劳动关系已解除;' 425 | type_information = '劳动争议的解除劳动关系' 426 | result = predict_online(content, type_information) 427 | time2 = time.time() 428 | print("result:", result, (time2 - time1)) 429 | 430 | time1 = time.time() 431 | content = ' 婚后原、被告夫妻感情一般。' 432 | type_information = '婚姻家庭的限制行为能力子女抚养' 433 | result = predict_online(content, type_information) 434 | time2 = time.time() 435 | print("result:", result, (time2 - time1)) 436 | -------------------------------------------------------------------------------- /bert/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import numpy as np 27 | import six 28 | import tensorflow as tf 29 | 30 | 31 | class BertConfig(object): 32 | """Configuration for `BertModel`.""" 33 | 34 | def __init__(self, 35 | vocab_size, 36 | hidden_size=768, 37 | num_hidden_layers=12, 38 | num_attention_heads=12, 39 | intermediate_size=3072, 40 | hidden_act="gelu", 41 | hidden_dropout_prob=0.1, 42 | attention_probs_dropout_prob=0.1, 43 | max_position_embeddings=512, 44 | type_vocab_size=16, 45 | initializer_range=0.02): 46 | """Constructs BertConfig. 47 | 48 | Args: 49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 50 | hidden_size: Size of the encoder layers and the pooler layer. 51 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 52 | num_attention_heads: Number of attention heads for each attention layer in 53 | the Transformer encoder. 54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 55 | layer in the Transformer encoder. 56 | hidden_act: The non-linear activation function (function or string) in the 57 | encoder and pooler. 58 | hidden_dropout_prob: The dropout probability for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | attention_probs_dropout_prob: The dropout ratio for the attention 61 | probabilities. 62 | max_position_embeddings: The maximum sequence length that this model might 63 | ever be used with. Typically set this to something large just in case 64 | (e.g., 512 or 1024 or 2048). 65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 66 | `BertModel`. 67 | initializer_range: The stdev of the truncated_normal_initializer for 68 | initializing all weight matrices. 69 | """ 70 | self.vocab_size = vocab_size 71 | self.hidden_size = hidden_size 72 | self.num_hidden_layers = num_hidden_layers 73 | self.num_attention_heads = num_attention_heads 74 | self.hidden_act = hidden_act 75 | self.intermediate_size = intermediate_size 76 | self.hidden_dropout_prob = hidden_dropout_prob 77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 78 | self.max_position_embeddings = max_position_embeddings 79 | self.type_vocab_size = type_vocab_size 80 | self.initializer_range = initializer_range 81 | 82 | @classmethod 83 | def from_dict(cls, json_object): 84 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 85 | config = BertConfig(vocab_size=None) 86 | for (key, value) in six.iteritems(json_object): 87 | config.__dict__[key] = value 88 | return config 89 | 90 | @classmethod 91 | def from_json_file(cls, json_file): 92 | """Constructs a `BertConfig` from a json file of parameters.""" 93 | with tf.gfile.GFile(json_file, "r") as reader: 94 | text = reader.read() 95 | return cls.from_dict(json.loads(text)) 96 | 97 | def to_dict(self): 98 | """Serializes this instance to a Python dictionary.""" 99 | output = copy.deepcopy(self.__dict__) 100 | return output 101 | 102 | def to_json_string(self): 103 | """Serializes this instance to a JSON string.""" 104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 105 | 106 | 107 | class BertModel(object): 108 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 
109 | 110 | Example usage: 111 | 112 | ```python 113 | # Already been converted into WordPiece token ids 114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 117 | 118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 120 | 121 | model = modeling.BertModel(config=config, is_training=True, 122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 123 | 124 | label_embeddings = tf.get_variable(...) 125 | pooled_output = model.get_pooled_output() 126 | logits = tf.matmul(pooled_output, label_embeddings) 127 | ... 128 | ``` 129 | """ 130 | 131 | def __init__(self, 132 | config, 133 | is_training, 134 | input_ids, 135 | input_mask=None, 136 | token_type_ids=None, 137 | use_one_hot_embeddings=False, 138 | scope=None): 139 | """Constructor for BertModel. 140 | 141 | Args: 142 | config: `BertConfig` instance. 143 | is_training: bool. true for training model, false for eval model. Controls 144 | whether dropout will be applied. 145 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 149 | embeddings or tf.embedding_lookup() for the word embeddings. 150 | scope: (optional) variable scope. Defaults to "bert". 151 | 152 | Raises: 153 | ValueError: The config is invalid or one of the input tensor shapes 154 | is invalid. 155 | """ 156 | config = copy.deepcopy(config) 157 | if not is_training: 158 | config.hidden_dropout_prob = 0.0 159 | config.attention_probs_dropout_prob = 0.0 160 | 161 | input_shape = get_shape_list(input_ids, expected_rank=2) 162 | batch_size = input_shape[0] 163 | seq_length = input_shape[1] 164 | 165 | if input_mask is None: 166 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 167 | 168 | if token_type_ids is None: 169 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 170 | 171 | with tf.variable_scope(scope, default_name="bert"): 172 | with tf.variable_scope("embeddings"): 173 | # Perform embedding lookup on the word ids. 174 | (self.embedding_output, self.embedding_table) = embedding_lookup( 175 | input_ids=input_ids, 176 | vocab_size=config.vocab_size, 177 | embedding_size=config.hidden_size, 178 | initializer_range=config.initializer_range, 179 | word_embedding_name="word_embeddings", 180 | use_one_hot_embeddings=use_one_hot_embeddings) 181 | 182 | # Add positional embeddings and token type embeddings, then layer 183 | # normalize and perform dropout. 
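        # Shape note: embedding_postprocessor() keeps the tensor at
        # [batch_size, seq_length, hidden_size]; it only adds the learned position and
        # token-type embeddings element-wise and then applies layer norm and dropout.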
184 | self.embedding_output = embedding_postprocessor( 185 | input_tensor=self.embedding_output, 186 | use_token_type=True, 187 | token_type_ids=token_type_ids, 188 | token_type_vocab_size=config.type_vocab_size, 189 | token_type_embedding_name="token_type_embeddings", 190 | use_position_embeddings=True, 191 | position_embedding_name="position_embeddings", 192 | initializer_range=config.initializer_range, 193 | max_position_embeddings=config.max_position_embeddings, 194 | dropout_prob=config.hidden_dropout_prob) 195 | 196 | with tf.variable_scope("encoder"): 197 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 198 | # mask of shape [batch_size, seq_length, seq_length] which is used 199 | # for the attention scores. 200 | attention_mask = create_attention_mask_from_input_mask( 201 | input_ids, input_mask) 202 | 203 | # Run the stacked transformer. 204 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 205 | self.all_encoder_layers = transformer_model( 206 | input_tensor=self.embedding_output, 207 | attention_mask=attention_mask, 208 | hidden_size=config.hidden_size, 209 | num_hidden_layers=config.num_hidden_layers, 210 | num_attention_heads=config.num_attention_heads, 211 | intermediate_size=config.intermediate_size, 212 | intermediate_act_fn=get_activation(config.hidden_act), 213 | hidden_dropout_prob=config.hidden_dropout_prob, 214 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 215 | initializer_range=config.initializer_range, 216 | do_return_all_layers=True) 217 | 218 | self.sequence_output = self.all_encoder_layers[-1] # [batch_size, seq_length, hidden_size] 219 | # The "pooler" converts the encoded sequence tensor of shape 220 | # [batch_size, seq_length, hidden_size] to a tensor of shape 221 | # [batch_size, hidden_size]. This is necessary for segment-level 222 | # (or segment-pair-level) classification tasks where we need a fixed 223 | # dimensional representation of the segment. 224 | with tf.variable_scope("pooler"): 225 | # We "pool" the model by simply taking the hidden state corresponding 226 | # to the first token. We assume that this has been pre-trained 227 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 228 | self.pooled_output = tf.layers.dense( 229 | first_token_tensor, 230 | config.hidden_size, 231 | activation=tf.tanh, 232 | kernel_initializer=create_initializer(config.initializer_range)) 233 | 234 | def get_pooled_output(self): 235 | return self.pooled_output 236 | 237 | def get_sequence_output(self): 238 | """Gets final hidden layer of encoder. 239 | 240 | Returns: 241 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 242 | to the final hidden of the transformer encoder. 243 | """ 244 | return self.sequence_output 245 | 246 | def get_all_encoder_layers(self): 247 | return self.all_encoder_layers 248 | 249 | def get_embedding_output(self): 250 | """Gets output of the embedding lookup (i.e., input to the transformer). 251 | 252 | Returns: 253 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 254 | to the output of the embedding layer, after summing the word 255 | embeddings with the positional embeddings and the token type embeddings, 256 | then performing layer normalization. This is the input to the transformer. 257 | """ 258 | return self.embedding_output 259 | 260 | def get_embedding_table(self): 261 | return self.embedding_table 262 | 263 | 264 | def gelu(x): 265 | """Gaussian Error Linear Unit. 
266 | 267 | This is a smoother version of the RELU. 268 | Original paper: https://arxiv.org/abs/1606.08415 269 | Args: 270 | x: float Tensor to perform activation. 271 | 272 | Returns: 273 | `x` with the GELU activation applied. 274 | """ 275 | cdf = 0.5 * (1.0 + tf.tanh( 276 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) 277 | return x * cdf 278 | 279 | 280 | def get_activation(activation_string): 281 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 282 | 283 | Args: 284 | activation_string: String name of the activation function. 285 | 286 | Returns: 287 | A Python function corresponding to the activation function. If 288 | `activation_string` is None, empty, or "linear", this will return None. 289 | If `activation_string` is not a string, it will return `activation_string`. 290 | 291 | Raises: 292 | ValueError: The `activation_string` does not correspond to a known 293 | activation. 294 | """ 295 | 296 | # We assume that anything that"s not a string is already an activation 297 | # function, so we just return it. 298 | if not isinstance(activation_string, six.string_types): 299 | return activation_string 300 | 301 | if not activation_string: 302 | return None 303 | 304 | act = activation_string.lower() 305 | if act == "linear": 306 | return None 307 | elif act == "relu": 308 | return tf.nn.relu 309 | elif act == "gelu": 310 | return gelu 311 | elif act == "tanh": 312 | return tf.tanh 313 | else: 314 | raise ValueError("Unsupported activation: %s" % act) 315 | 316 | 317 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 318 | """Compute the union of the current variables and checkpoint variables.""" 319 | assignment_map = {} 320 | initialized_variable_names = {} 321 | 322 | name_to_variable = collections.OrderedDict() 323 | for var in tvars: 324 | name = var.name 325 | m = re.match("^(.*):\\d+$", name) 326 | if m is not None: 327 | name = m.group(1) 328 | name_to_variable[name] = var 329 | 330 | init_vars = tf.train.list_variables(init_checkpoint) 331 | 332 | assignment_map = collections.OrderedDict() 333 | for x in init_vars: 334 | (name, var) = (x[0], x[1]) 335 | if name not in name_to_variable: 336 | continue 337 | assignment_map[name] = name 338 | initialized_variable_names[name] = 1 339 | initialized_variable_names[name + ":0"] = 1 340 | 341 | return (assignment_map, initialized_variable_names) 342 | 343 | 344 | def dropout(input_tensor, dropout_prob): 345 | """Perform dropout. 346 | 347 | Args: 348 | input_tensor: float Tensor. 349 | dropout_prob: Python float. The probability of dropping out a value (NOT of 350 | *keeping* a dimension as in `tf.nn.dropout`). 351 | 352 | Returns: 353 | A version of `input_tensor` with dropout applied. 
354 | """ 355 | if dropout_prob is None or dropout_prob == 0.0: 356 | return input_tensor 357 | 358 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 359 | return output 360 | 361 | 362 | def layer_norm(input_tensor, name=None): 363 | """Run layer normalization on the last dimension of the tensor.""" 364 | return tf.contrib.layers.layer_norm( 365 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 366 | 367 | 368 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 369 | """Runs layer normalization followed by dropout.""" 370 | output_tensor = layer_norm(input_tensor, name) 371 | output_tensor = dropout(output_tensor, dropout_prob) 372 | return output_tensor 373 | 374 | 375 | def create_initializer(initializer_range=0.02): 376 | """Creates a `truncated_normal_initializer` with the given range.""" 377 | return tf.truncated_normal_initializer(stddev=initializer_range) 378 | 379 | 380 | def embedding_lookup(input_ids, 381 | vocab_size, 382 | embedding_size=128, 383 | initializer_range=0.02, 384 | word_embedding_name="word_embeddings", 385 | use_one_hot_embeddings=False): 386 | """Looks up words embeddings for id tensor. 387 | 388 | Args: 389 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 390 | ids. 391 | vocab_size: int. Size of the embedding vocabulary. 392 | embedding_size: int. Width of the word embeddings. 393 | initializer_range: float. Embedding initialization range. 394 | word_embedding_name: string. Name of the embedding table. 395 | use_one_hot_embeddings: bool. If True, use one-hot method for word 396 | embeddings. If False, use `tf.gather()`. 397 | 398 | Returns: 399 | float Tensor of shape [batch_size, seq_length, embedding_size]. 400 | """ 401 | # This function assumes that the input is of shape [batch_size, seq_length, 402 | # num_inputs]. 403 | # 404 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 405 | # reshape to [batch_size, seq_length, 1]. 406 | if input_ids.shape.ndims == 2: 407 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 408 | 409 | embedding_table = tf.get_variable( 410 | name=word_embedding_name, 411 | shape=[vocab_size, embedding_size], 412 | initializer=create_initializer(initializer_range)) 413 | 414 | flat_input_ids = tf.reshape(input_ids, [-1]) 415 | if use_one_hot_embeddings: 416 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 417 | output = tf.matmul(one_hot_input_ids, embedding_table) 418 | else: 419 | output = tf.gather(embedding_table, flat_input_ids) 420 | 421 | input_shape = get_shape_list(input_ids) 422 | 423 | output = tf.reshape(output, 424 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 425 | return (output, embedding_table) 426 | 427 | 428 | def embedding_postprocessor(input_tensor, 429 | use_token_type=False, 430 | token_type_ids=None, 431 | token_type_vocab_size=16, 432 | token_type_embedding_name="token_type_embeddings", 433 | use_position_embeddings=True, 434 | position_embedding_name="position_embeddings", 435 | initializer_range=0.02, 436 | max_position_embeddings=512, 437 | dropout_prob=0.1): 438 | """Performs various post-processing on a word embedding tensor. 439 | 440 | Args: 441 | input_tensor: float Tensor of shape [batch_size, seq_length, 442 | embedding_size]. 443 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 444 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 445 | Must be specified if `use_token_type` is True. 446 | token_type_vocab_size: int. 
The vocabulary size of `token_type_ids`. 447 | token_type_embedding_name: string. The name of the embedding table variable 448 | for token type ids. 449 | use_position_embeddings: bool. Whether to add position embeddings for the 450 | position of each token in the sequence. 451 | position_embedding_name: string. The name of the embedding table variable 452 | for positional embeddings. 453 | initializer_range: float. Range of the weight initialization. 454 | max_position_embeddings: int. Maximum sequence length that might ever be 455 | used with this model. This can be longer than the sequence length of 456 | input_tensor, but cannot be shorter. 457 | dropout_prob: float. Dropout probability applied to the final output tensor. 458 | 459 | Returns: 460 | float tensor with same shape as `input_tensor`. 461 | 462 | Raises: 463 | ValueError: One of the tensor shapes or input values is invalid. 464 | """ 465 | input_shape = get_shape_list(input_tensor, expected_rank=3) 466 | batch_size = input_shape[0] 467 | seq_length = input_shape[1] 468 | width = input_shape[2] 469 | 470 | output = input_tensor 471 | 472 | if use_token_type: 473 | if token_type_ids is None: 474 | raise ValueError("`token_type_ids` must be specified if" 475 | "`use_token_type` is True.") 476 | token_type_table = tf.get_variable( 477 | name=token_type_embedding_name, 478 | shape=[token_type_vocab_size, width], 479 | initializer=create_initializer(initializer_range)) 480 | # This vocab will be small so we always do one-hot here, since it is always 481 | # faster for a small vocabulary. 482 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 483 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 484 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 485 | token_type_embeddings = tf.reshape(token_type_embeddings, 486 | [batch_size, seq_length, width]) 487 | output += token_type_embeddings 488 | 489 | if use_position_embeddings: 490 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 491 | with tf.control_dependencies([assert_op]): 492 | full_position_embeddings = tf.get_variable( 493 | name=position_embedding_name, 494 | shape=[max_position_embeddings, width], 495 | initializer=create_initializer(initializer_range)) 496 | # Since the position embedding table is a learned variable, we create it 497 | # using a (long) sequence length `max_position_embeddings`. The actual 498 | # sequence length might be shorter than this, for faster training of 499 | # tasks that do not have long sequences. 500 | # 501 | # So `full_position_embeddings` is effectively an embedding table 502 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 503 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 504 | # perform a slice. 505 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 506 | [seq_length, -1]) 507 | num_dims = len(output.shape.as_list()) 508 | 509 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 510 | # we broadcast among the first dimensions, which is typically just 511 | # the batch size. 
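      # Concretely, with `output` of shape [batch_size, seq_length, width] the slice
      # above has shape [seq_length, width]; it is reshaped below to
      # [1, seq_length, width] so that `output += position_embeddings` broadcasts the
      # same position table across every example in the batch.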
512 | position_broadcast_shape = [] 513 | for _ in range(num_dims - 2): 514 | position_broadcast_shape.append(1) 515 | position_broadcast_shape.extend([seq_length, width]) 516 | position_embeddings = tf.reshape(position_embeddings, 517 | position_broadcast_shape) 518 | output += position_embeddings 519 | 520 | output = layer_norm_and_dropout(output, dropout_prob) 521 | return output 522 | 523 | 524 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 525 | """Create 3D attention mask from a 2D tensor mask. 526 | 527 | Args: 528 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 529 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 530 | 531 | Returns: 532 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 533 | """ 534 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 535 | batch_size = from_shape[0] 536 | from_seq_length = from_shape[1] 537 | 538 | to_shape = get_shape_list(to_mask, expected_rank=2) 539 | to_seq_length = to_shape[1] 540 | 541 | to_mask = tf.cast( 542 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 543 | 544 | # We don't assume that `from_tensor` is a mask (although it could be). We 545 | # don't actually care if we attend *from* padding tokens (only *to* padding) 546 | # tokens so we create a tensor of all ones. 547 | # 548 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 549 | broadcast_ones = tf.ones( 550 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 551 | 552 | # Here we broadcast along two dimensions to create the mask. 553 | mask = broadcast_ones * to_mask 554 | 555 | return mask 556 | 557 | 558 | def attention_layer(from_tensor, 559 | to_tensor, 560 | attention_mask=None, 561 | num_attention_heads=1, 562 | size_per_head=512, 563 | query_act=None, 564 | key_act=None, 565 | value_act=None, 566 | attention_probs_dropout_prob=0.0, 567 | initializer_range=0.02, 568 | do_return_2d_tensor=False, 569 | batch_size=None, 570 | from_seq_length=None, 571 | to_seq_length=None): 572 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 573 | 574 | This is an implementation of multi-headed attention based on "Attention 575 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 576 | this is self-attention. Each timestep in `from_tensor` attends to the 577 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 578 | 579 | This function first projects `from_tensor` into a "query" tensor and 580 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 581 | of tensors of length `num_attention_heads`, where each tensor is of shape 582 | [batch_size, seq_length, size_per_head]. 583 | 584 | Then, the query and key tensors are dot-producted and scaled. These are 585 | softmaxed to obtain attention probabilities. The value tensors are then 586 | interpolated by these probabilities, then concatenated back to a single 587 | tensor and returned. 588 | 589 | In practice, the multi-headed attention are done with transposes and 590 | reshapes rather than actual separate tensors. 591 | 592 | Args: 593 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 594 | from_width]. 595 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 596 | attention_mask: (optional) int32 Tensor of shape [batch_size, 597 | from_seq_length, to_seq_length]. The values should be 1 or 0. 
The 598 | attention scores will effectively be set to -infinity for any positions in 599 | the mask that are 0, and will be unchanged for positions that are 1. 600 | num_attention_heads: int. Number of attention heads. 601 | size_per_head: int. Size of each attention head. 602 | query_act: (optional) Activation function for the query transform. 603 | key_act: (optional) Activation function for the key transform. 604 | value_act: (optional) Activation function for the value transform. 605 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 606 | attention probabilities. 607 | initializer_range: float. Range of the weight initializer. 608 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 609 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 610 | output will be of shape [batch_size, from_seq_length, num_attention_heads 611 | * size_per_head]. 612 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 613 | of the 3D version of the `from_tensor` and `to_tensor`. 614 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 615 | of the 3D version of the `from_tensor`. 616 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 617 | of the 3D version of the `to_tensor`. 618 | 619 | Returns: 620 | float Tensor of shape [batch_size, from_seq_length, 621 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 622 | true, this will be of shape [batch_size * from_seq_length, 623 | num_attention_heads * size_per_head]). 624 | 625 | Raises: 626 | ValueError: Any of the arguments or tensor shapes are invalid. 627 | """ 628 | 629 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 630 | seq_length, width): 631 | output_tensor = tf.reshape( 632 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 633 | 634 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 635 | return output_tensor 636 | 637 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 638 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 639 | 640 | if len(from_shape) != len(to_shape): 641 | raise ValueError( 642 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 643 | 644 | if len(from_shape) == 3: 645 | batch_size = from_shape[0] 646 | from_seq_length = from_shape[1] 647 | to_seq_length = to_shape[1] 648 | elif len(from_shape) == 2: 649 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 650 | raise ValueError( 651 | "When passing in rank 2 tensors to attention_layer, the values " 652 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 653 | "must all be specified.") 654 | 655 | # Scalar dimensions referenced here: 656 | # B = batch size (number of sequences) 657 | # F = `from_tensor` sequence length 658 | # T = `to_tensor` sequence length 659 | # N = `num_attention_heads` 660 | # H = `size_per_head` 661 | 662 | from_tensor_2d = reshape_to_matrix(from_tensor) 663 | to_tensor_2d = reshape_to_matrix(to_tensor) 664 | 665 | # `query_layer` = [B*F, N*H] 666 | query_layer = tf.layers.dense( 667 | from_tensor_2d, 668 | num_attention_heads * size_per_head, 669 | activation=query_act, 670 | name="query", 671 | kernel_initializer=create_initializer(initializer_range)) 672 | 673 | # `key_layer` = [B*T, N*H] 674 | key_layer = tf.layers.dense( 675 | to_tensor_2d, 676 | num_attention_heads * size_per_head, 677 | activation=key_act, 678 | name="key", 679 | 
kernel_initializer=create_initializer(initializer_range)) 680 | 681 | # `value_layer` = [B*T, N*H] 682 | value_layer = tf.layers.dense( 683 | to_tensor_2d, 684 | num_attention_heads * size_per_head, 685 | activation=value_act, 686 | name="value", 687 | kernel_initializer=create_initializer(initializer_range)) 688 | 689 | # `query_layer` = [B, N, F, H] 690 | query_layer = transpose_for_scores(query_layer, batch_size, 691 | num_attention_heads, from_seq_length, 692 | size_per_head) 693 | 694 | # `key_layer` = [B, N, T, H] 695 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 696 | to_seq_length, size_per_head) 697 | 698 | # Take the dot product between "query" and "key" to get the raw 699 | # attention scores. 700 | # `attention_scores` = [B, N, F, T] 701 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 702 | attention_scores = tf.multiply(attention_scores, 703 | 1.0 / math.sqrt(float(size_per_head))) 704 | 705 | if attention_mask is not None: 706 | # `attention_mask` = [B, 1, F, T] 707 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 708 | 709 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 710 | # masked positions, this operation will create a tensor which is 0.0 for 711 | # positions we want to attend and -10000.0 for masked positions. 712 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 713 | 714 | # Since we are adding it to the raw scores before the softmax, this is 715 | # effectively the same as removing these entirely. 716 | attention_scores += adder 717 | 718 | # Normalize the attention scores to probabilities. 719 | # `attention_probs` = [B, N, F, T] 720 | attention_probs = tf.nn.softmax(attention_scores) 721 | 722 | # This is actually dropping out entire tokens to attend to, which might 723 | # seem a bit unusual, but is taken from the original Transformer paper. 724 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 725 | 726 | # `value_layer` = [B, T, N, H] 727 | value_layer = tf.reshape( 728 | value_layer, 729 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 730 | 731 | # `value_layer` = [B, N, T, H] 732 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 733 | 734 | # `context_layer` = [B, N, F, H] 735 | context_layer = tf.matmul(attention_probs, value_layer) 736 | 737 | # `context_layer` = [B, F, N, H] 738 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 739 | 740 | if do_return_2d_tensor: 741 | # `context_layer` = [B*F, N*H] 742 | context_layer = tf.reshape( 743 | context_layer, 744 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 745 | else: 746 | # `context_layer` = [B, F, N*H] 747 | context_layer = tf.reshape( 748 | context_layer, 749 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 750 | 751 | return context_layer 752 | 753 | 754 | def transformer_model(input_tensor, 755 | attention_mask=None, 756 | hidden_size=768, 757 | num_hidden_layers=12, 758 | num_attention_heads=12, 759 | intermediate_size=3072, 760 | intermediate_act_fn=gelu, 761 | hidden_dropout_prob=0.1, 762 | attention_probs_dropout_prob=0.1, 763 | initializer_range=0.02, 764 | do_return_all_layers=False): 765 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 766 | 767 | This is almost an exact implementation of the original Transformer encoder. 
768 | 769 | See the original paper: 770 | https://arxiv.org/abs/1706.03762 771 | 772 | Also see: 773 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 774 | 775 | Args: 776 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 777 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 778 | seq_length], with 1 for positions that can be attended to and 0 in 779 | positions that should not be. 780 | hidden_size: int. Hidden size of the Transformer. 781 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 782 | num_attention_heads: int. Number of attention heads in the Transformer. 783 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 784 | forward) layer. 785 | intermediate_act_fn: function. The non-linear activation function to apply 786 | to the output of the intermediate/feed-forward layer. 787 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 788 | attention_probs_dropout_prob: float. Dropout probability of the attention 789 | probabilities. 790 | initializer_range: float. Range of the initializer (stddev of truncated 791 | normal). 792 | do_return_all_layers: Whether to also return all layers or just the final 793 | layer. 794 | 795 | Returns: 796 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 797 | hidden layer of the Transformer. 798 | 799 | Raises: 800 | ValueError: A Tensor shape or parameter is invalid. 801 | """ 802 | if hidden_size % num_attention_heads != 0: 803 | raise ValueError( 804 | "The hidden size (%d) is not a multiple of the number of attention " 805 | "heads (%d)" % (hidden_size, num_attention_heads)) 806 | 807 | attention_head_size = int(hidden_size / num_attention_heads) 808 | input_shape = get_shape_list(input_tensor, expected_rank=3) 809 | batch_size = input_shape[0] 810 | seq_length = input_shape[1] 811 | input_width = input_shape[2] 812 | 813 | # The Transformer performs sum residuals on all layers so the input needs 814 | # to be the same as the hidden size. 815 | if input_width != hidden_size: 816 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 817 | (input_width, hidden_size)) 818 | 819 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 820 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 821 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 822 | # help the optimizer. 
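  # Concretely, reshape_to_matrix() below flattens [batch_size, seq_length, hidden_size]
  # into [batch_size * seq_length, hidden_size] (an illustrative [8, 128, 768] batch
  # becomes [1024, 768]); the per-layer outputs are only reshaped back to 3D at the very
  # end through reshape_from_matrix().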
823 | prev_output = reshape_to_matrix(input_tensor) 824 | 825 | all_layer_outputs = [] 826 | for layer_idx in range(num_hidden_layers): 827 | with tf.variable_scope("layer_%d" % layer_idx): 828 | layer_input = prev_output 829 | 830 | with tf.variable_scope("attention"): 831 | attention_heads = [] 832 | with tf.variable_scope("self"): 833 | attention_head = attention_layer( 834 | from_tensor=layer_input, 835 | to_tensor=layer_input, 836 | attention_mask=attention_mask, 837 | num_attention_heads=num_attention_heads, 838 | size_per_head=attention_head_size, 839 | attention_probs_dropout_prob=attention_probs_dropout_prob, 840 | initializer_range=initializer_range, 841 | do_return_2d_tensor=True, 842 | batch_size=batch_size, 843 | from_seq_length=seq_length, 844 | to_seq_length=seq_length) 845 | attention_heads.append(attention_head) 846 | 847 | attention_output = None 848 | if len(attention_heads) == 1: 849 | attention_output = attention_heads[0] 850 | else: 851 | # In the case where we have other sequences, we just concatenate 852 | # them to the self-attention head before the projection. 853 | attention_output = tf.concat(attention_heads, axis=-1) 854 | 855 | # Run a linear projection of `hidden_size` then add a residual 856 | # with `layer_input`. 857 | with tf.variable_scope("output"): 858 | attention_output = tf.layers.dense( 859 | attention_output, 860 | hidden_size, 861 | kernel_initializer=create_initializer(initializer_range)) 862 | attention_output = dropout(attention_output, hidden_dropout_prob) 863 | attention_output = layer_norm(attention_output + layer_input) 864 | 865 | # The activation is only applied to the "intermediate" hidden layer. 866 | with tf.variable_scope("intermediate"): 867 | intermediate_output = tf.layers.dense( 868 | attention_output, 869 | intermediate_size, 870 | activation=intermediate_act_fn, 871 | kernel_initializer=create_initializer(initializer_range)) 872 | 873 | # Down-project back to `hidden_size` then add the residual. 874 | with tf.variable_scope("output"): 875 | layer_output = tf.layers.dense( 876 | intermediate_output, 877 | hidden_size, 878 | kernel_initializer=create_initializer(initializer_range)) 879 | layer_output = dropout(layer_output, hidden_dropout_prob) 880 | layer_output = layer_norm(layer_output + attention_output) 881 | prev_output = layer_output 882 | all_layer_outputs.append(layer_output) 883 | 884 | if do_return_all_layers: 885 | final_outputs = [] 886 | for layer_output in all_layer_outputs: 887 | final_output = reshape_from_matrix(layer_output, input_shape) 888 | final_outputs.append(final_output) 889 | return final_outputs 890 | else: 891 | final_output = reshape_from_matrix(prev_output, input_shape) 892 | return final_output 893 | 894 | 895 | def get_shape_list(tensor, expected_rank=None, name=None): 896 | """Returns a list of the shape of tensor, preferring static dimensions. 897 | 898 | Args: 899 | tensor: A tf.Tensor object to find the shape of. 900 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 901 | specified and the `tensor` has a different rank, and exception will be 902 | thrown. 903 | name: Optional name of the tensor for the error message. 904 | 905 | Returns: 906 | A list of dimensions of the shape of tensor. All static dimensions will 907 | be returned as python integers, and dynamic dimensions will be returned 908 | as tf.Tensor scalars. 
909 | """ 910 | if name is None: 911 | name = tensor.name 912 | 913 | if expected_rank is not None: 914 | assert_rank(tensor, expected_rank, name) 915 | 916 | shape = tensor.shape.as_list() 917 | 918 | non_static_indexes = [] 919 | for (index, dim) in enumerate(shape): 920 | if dim is None: 921 | non_static_indexes.append(index) 922 | 923 | if not non_static_indexes: 924 | return shape 925 | 926 | dyn_shape = tf.shape(tensor) 927 | for index in non_static_indexes: 928 | shape[index] = dyn_shape[index] 929 | return shape 930 | 931 | 932 | def reshape_to_matrix(input_tensor): 933 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 934 | ndims = input_tensor.shape.ndims 935 | if ndims < 2: 936 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 937 | (input_tensor.shape)) 938 | if ndims == 2: 939 | return input_tensor 940 | 941 | width = input_tensor.shape[-1] 942 | output_tensor = tf.reshape(input_tensor, [-1, width]) 943 | return output_tensor 944 | 945 | 946 | def reshape_from_matrix(output_tensor, orig_shape_list): 947 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 948 | if len(orig_shape_list) == 2: 949 | return output_tensor 950 | 951 | output_shape = get_shape_list(output_tensor) 952 | 953 | orig_dims = orig_shape_list[0:-1] 954 | width = output_shape[-1] 955 | 956 | return tf.reshape(output_tensor, orig_dims + [width]) 957 | 958 | 959 | def assert_rank(tensor, expected_rank, name=None): 960 | """Raises an exception if the tensor rank is not of the expected rank. 961 | 962 | Args: 963 | tensor: A tf.Tensor to check the rank of. 964 | expected_rank: Python integer or list of integers, expected rank. 965 | name: Optional name of the tensor for the error message. 966 | 967 | Raises: 968 | ValueError: If the expected shape doesn't match the actual shape. 969 | """ 970 | if name is None: 971 | name = tensor.name 972 | 973 | expected_rank_dict = {} 974 | if isinstance(expected_rank, six.integer_types): 975 | expected_rank_dict[expected_rank] = True 976 | else: 977 | for x in expected_rank: 978 | expected_rank_dict[x] = True 979 | 980 | actual_rank = tensor.shape.ndims 981 | if actual_rank not in expected_rank_dict: 982 | scope_name = tf.get_variable_scope().name 983 | raise ValueError( 984 | "For the tensor `%s` in scope `%s`, the actual rank " 985 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 986 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 987 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import bert.modeling as modeling 25 | import bert.optimization_finetuning as optimization 26 | import bert.tokenization as tokenization 27 | import tensorflow as tf 28 | # from loss import bi_tempered_logistic_loss 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | ## Required parameters 35 | flags.DEFINE_string( 36 | "data_dir", None, 37 | "The input data dir. Should contain the .tsv files (or other data files) " 38 | "for the task.") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. " 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 46 | 47 | flags.DEFINE_string("vocab_file", None, 48 | "The vocabulary file that the BERT model was trained on.") 49 | 50 | flags.DEFINE_string( 51 | "output_dir", None, 52 | "The output directory where the model checkpoints will be written.") 53 | 54 | ## Other parameters 55 | 56 | flags.DEFINE_string( 57 | "init_checkpoint", None, 58 | "Initial checkpoint (usually from a pre-trained BERT model).") 59 | 60 | flags.DEFINE_bool( 61 | "do_lower_case", True, 62 | "Whether to lower case the input text. Should be True for uncased " 63 | "models and False for cased models.") 64 | 65 | flags.DEFINE_integer( 66 | "max_seq_length", 128, 67 | "The maximum total input sequence length after WordPiece tokenization. " 68 | "Sequences longer than this will be truncated, and sequences shorter " 69 | "than this will be padded.") 70 | 71 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 72 | 73 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 74 | 75 | flags.DEFINE_bool( 76 | "do_predict", False, 77 | "Whether to run the model in inference mode on the test set.") 78 | 79 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 80 | 81 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 82 | 83 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 84 | 85 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 86 | 87 | flags.DEFINE_float("num_train_epochs", 3.0, 88 | "Total number of training epochs to perform.") 89 | 90 | flags.DEFINE_float( 91 | "warmup_proportion", 0.1, 92 | "Proportion of training to perform linear learning rate warmup for. " 93 | "E.g., 0.1 = 10% of training.") 94 | 95 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 96 | "How often to save the model checkpoint.") 97 | 98 | flags.DEFINE_integer("iterations_per_loop", 1000, 99 | "How many steps to make in each estimator call.") 100 | 101 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 102 | 103 | tf.flags.DEFINE_string( 104 | "tpu_name", None, 105 | "The Cloud TPU to use for training. This should be either the name " 106 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 107 | "url.") 108 | 109 | tf.flags.DEFINE_string( 110 | "tpu_zone", None, 111 | "[Optional] GCE zone where the Cloud TPU is located in. 
If not " 112 | "specified, we will attempt to automatically detect the GCE zone from " 113 | "metadata.") 114 | 115 | tf.flags.DEFINE_string( 116 | "gcp_project", None, 117 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 118 | "specified, we will attempt to automatically detect the GCE project from " 119 | "metadata.") 120 | 121 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 122 | 123 | flags.DEFINE_integer( 124 | "num_tpu_cores", 8, 125 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 126 | 127 | 128 | class InputExample(object): 129 | """A single training/test example for simple sequence classification.""" 130 | 131 | def __init__(self, guid, text_a, text_b=None, label=None): 132 | """Constructs an InputExample. 133 | Args: 134 | guid: Unique id for the example. 135 | text_a: string. The untokenized text of the first sequence. For single 136 | sequence tasks, only this sequence must be specified. 137 | text_b: (Optional) string. The untokenized text of the second sequence. 138 | Must be specified only for sequence pair tasks. 139 | label: (Optional) string. The label of the example. This should be 140 | specified for train and dev examples, but not for test examples. 141 | """ 142 | self.guid = guid 143 | self.text_a = text_a 144 | self.text_b = text_b 145 | self.label = label 146 | 147 | 148 | class PaddingInputExample(object): 149 | """Fake example so the number of input examples is a multiple of the batch size. 150 | When running eval/predict on the TPU, we need to pad the number of examples 151 | to be a multiple of the batch size, because the TPU requires a fixed batch 152 | size. The alternative is to drop the last batch, which is bad because it means 153 | the entire output data won't be generated. 154 | We use this class instead of `None` because treating `None` as padding 155 | batches could cause silent errors.
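  For example (illustrative numbers): with eval_batch_size=8 and 13 real eval
  examples, three PaddingInputExample instances are appended so that 16 examples
  split evenly into batches; they are converted with is_real_example=False and
  therefore receive zero weight in the eval metrics.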
156 | """ 157 | 158 | 159 | class InputFeatures(object): 160 | """A single set of features of data.""" 161 | 162 | def __init__(self, 163 | input_ids, 164 | input_mask, 165 | segment_ids, 166 | label_id, 167 | is_real_example=True): 168 | self.input_ids = input_ids 169 | self.input_mask = input_mask 170 | self.segment_ids = segment_ids 171 | self.label_id = label_id 172 | self.is_real_example = is_real_example 173 | 174 | 175 | class DataProcessor(object): 176 | """Base class for data converters for sequence classification data sets.""" 177 | 178 | def get_train_examples(self, data_dir): 179 | """Gets a collection of `InputExample`s for the train set.""" 180 | raise NotImplementedError() 181 | 182 | def get_dev_examples(self, data_dir): 183 | """Gets a collection of `InputExample`s for the dev set.""" 184 | raise NotImplementedError() 185 | 186 | def get_test_examples(self, data_dir): 187 | """Gets a collection of `InputExample`s for prediction.""" 188 | raise NotImplementedError() 189 | 190 | def get_labels(self): 191 | """Gets the list of labels for this data set.""" 192 | raise NotImplementedError() 193 | 194 | @classmethod 195 | def _read_tsv(cls, input_file, quotechar=None): 196 | """Reads a tab separated value file.""" 197 | with tf.gfile.Open(input_file, "r") as f: 198 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 199 | lines = [] 200 | for line in reader: 201 | lines.append(line) 202 | return lines 203 | 204 | 205 | class XnliProcessor(DataProcessor): 206 | """Processor for the XNLI data set.""" 207 | 208 | def __init__(self): 209 | self.language = "zh" 210 | 211 | def get_train_examples(self, data_dir): 212 | """See base class.""" 213 | lines = self._read_tsv( 214 | os.path.join(data_dir, "multinli", 215 | "multinli.train.%s.tsv" % self.language)) 216 | examples = [] 217 | for (i, line) in enumerate(lines): 218 | if i == 0: 219 | continue 220 | guid = "train-%d" % (i) 221 | text_a = tokenization.convert_to_unicode(line[0]) 222 | text_b = tokenization.convert_to_unicode(line[1]) 223 | label = tokenization.convert_to_unicode(line[2]) 224 | if label == tokenization.convert_to_unicode("contradictory"): 225 | label = tokenization.convert_to_unicode("contradiction") 226 | examples.append( 227 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 228 | return examples 229 | 230 | def get_dev_examples(self, data_dir): 231 | """See base class.""" 232 | lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) 233 | examples = [] 234 | for (i, line) in enumerate(lines): 235 | if i == 0: 236 | continue 237 | guid = "dev-%d" % (i) 238 | language = tokenization.convert_to_unicode(line[0]) 239 | if language != tokenization.convert_to_unicode(self.language): 240 | continue 241 | text_a = tokenization.convert_to_unicode(line[6]) 242 | text_b = tokenization.convert_to_unicode(line[7]) 243 | label = tokenization.convert_to_unicode(line[1]) 244 | examples.append( 245 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 246 | return examples 247 | 248 | def get_labels(self): 249 | """See base class.""" 250 | return ["contradiction", "entailment", "neutral"] 251 | 252 | 253 | class MnliProcessor(DataProcessor): 254 | """Processor for the MultiNLI data set (GLUE version).""" 255 | 256 | def get_train_examples(self, data_dir): 257 | """See base class.""" 258 | return self._create_examples( 259 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 260 | 261 | def get_dev_examples(self, data_dir): 262 | """See base class.""" 263 | return 
self._create_examples( 264 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 265 | "dev_matched") 266 | 267 | def get_test_examples(self, data_dir): 268 | """See base class.""" 269 | return self._create_examples( 270 | self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") 271 | 272 | def get_labels(self): 273 | """See base class.""" 274 | return ["contradiction", "entailment", "neutral"] 275 | 276 | def _create_examples(self, lines, set_type): 277 | """Creates examples for the training and dev sets.""" 278 | examples = [] 279 | for (i, line) in enumerate(lines): 280 | if i == 0: 281 | continue 282 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 283 | text_a = tokenization.convert_to_unicode(line[8]) 284 | text_b = tokenization.convert_to_unicode(line[9]) 285 | if set_type == "test": 286 | label = "contradiction" 287 | else: 288 | label = tokenization.convert_to_unicode(line[-1]) 289 | examples.append( 290 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 291 | return examples 292 | 293 | 294 | class MrpcProcessor(DataProcessor): 295 | """Processor for the MRPC data set (GLUE version).""" 296 | 297 | def get_train_examples(self, data_dir): 298 | """See base class.""" 299 | return self._create_examples( 300 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 301 | 302 | def get_dev_examples(self, data_dir): 303 | """See base class.""" 304 | return self._create_examples( 305 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 306 | 307 | def get_test_examples(self, data_dir): 308 | """See base class.""" 309 | return self._create_examples( 310 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 311 | 312 | def get_labels(self): 313 | """See base class.""" 314 | return ["0", "1"] 315 | 316 | def _create_examples(self, lines, set_type): 317 | """Creates examples for the training and dev sets.""" 318 | examples = [] 319 | for (i, line) in enumerate(lines): 320 | if i == 0: 321 | continue 322 | guid = "%s-%s" % (set_type, i) 323 | text_a = tokenization.convert_to_unicode(line[3]) 324 | text_b = tokenization.convert_to_unicode(line[4]) 325 | if set_type == "test": 326 | label = "0" 327 | else: 328 | label = tokenization.convert_to_unicode(line[0]) 329 | examples.append( 330 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 331 | return examples 332 | 333 | 334 | class ColaProcessor(DataProcessor): 335 | """Processor for the CoLA data set (GLUE version).""" 336 | 337 | def get_train_examples(self, data_dir): 338 | """See base class.""" 339 | return self._create_examples( 340 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 341 | 342 | def get_dev_examples(self, data_dir): 343 | """See base class.""" 344 | return self._create_examples( 345 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 346 | 347 | def get_test_examples(self, data_dir): 348 | """See base class.""" 349 | return self._create_examples( 350 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 351 | 352 | def get_labels(self): 353 | """See base class.""" 354 | return ["0", "1"] 355 | 356 | def _create_examples(self, lines, set_type): 357 | """Creates examples for the training and dev sets.""" 358 | examples = [] 359 | for (i, line) in enumerate(lines): 360 | # Only the test set has a header 361 | if set_type == "test" and i == 0: 362 | continue 363 | guid = "%s-%s" % (set_type, i) 364 | if set_type == "test": 365 | text_a = tokenization.convert_to_unicode(line[1]) 366 | label = "0" 367 
| else: 368 | text_a = tokenization.convert_to_unicode(line[3]) 369 | label = tokenization.convert_to_unicode(line[1]) 370 | examples.append( 371 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 372 | return examples 373 | 374 | 375 | def convert_single_example(ex_index, example, label_list, max_seq_length, 376 | tokenizer): 377 | """Converts a single `InputExample` into a single `InputFeatures`.""" 378 | 379 | if isinstance(example, PaddingInputExample): 380 | return InputFeatures( 381 | input_ids=[0] * max_seq_length, 382 | input_mask=[0] * max_seq_length, 383 | segment_ids=[0] * max_seq_length, 384 | label_id=0, 385 | is_real_example=False) 386 | 387 | label_map = {} 388 | for (i, label) in enumerate(label_list): 389 | label_map[label] = i 390 | 391 | tokens_a = tokenizer.tokenize(example.text_a) 392 | tokens_b = None 393 | if example.text_b: 394 | tokens_b = tokenizer.tokenize(example.text_b) 395 | 396 | if tokens_b: 397 | # Modifies `tokens_a` and `tokens_b` in place so that the total 398 | # length is less than the specified length. 399 | # Account for [CLS], [SEP], [SEP] with "- 3" 400 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 401 | else: 402 | # Account for [CLS] and [SEP] with "- 2" 403 | if len(tokens_a) > max_seq_length - 2: 404 | tokens_a = tokens_a[0:(max_seq_length - 2)] 405 | 406 | # The convention in BERT is: 407 | # (a) For sequence pairs: 408 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 409 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 410 | # (b) For single sequences: 411 | # tokens: [CLS] the dog is hairy . [SEP] 412 | # type_ids: 0 0 0 0 0 0 0 413 | # 414 | # Where "type_ids" are used to indicate whether this is the first 415 | # sequence or the second sequence. The embedding vectors for `type=0` and 416 | # `type=1` were learned during pre-training and are added to the wordpiece 417 | # embedding vector (and position vector). This is not *strictly* necessary 418 | # since the [SEP] token unambiguously separates the sequences, but it makes 419 | # it easier for the model to learn the concept of sequences. 420 | # 421 | # For classification tasks, the first vector (corresponding to [CLS]) is 422 | # used as the "sentence vector". Note that this only makes sense because 423 | # the entire model is fine-tuned. 424 | tokens = [] 425 | segment_ids = [] 426 | tokens.append("[CLS]") 427 | segment_ids.append(0) 428 | for token in tokens_a: 429 | tokens.append(token) 430 | segment_ids.append(0) 431 | tokens.append("[SEP]") 432 | segment_ids.append(0) 433 | 434 | if tokens_b: 435 | for token in tokens_b: 436 | tokens.append(token) 437 | segment_ids.append(1) 438 | tokens.append("[SEP]") 439 | segment_ids.append(1) 440 | 441 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 442 | 443 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 444 | # tokens are attended to. 445 | input_mask = [1] * len(input_ids) 446 | 447 | # Zero-pad up to the sequence length. 
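  # Illustrative example (values invented for clarity): with max_seq_length=8 and
  # a single segment ["[CLS]", tok_1, tok_2, "[SEP]"], the loop below appends four
  # zeros to each list, so input_ids has length 8, input_mask becomes
  # [1, 1, 1, 1, 0, 0, 0, 0], and segment_ids stays all zeros.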
448 | while len(input_ids) < max_seq_length: 449 | input_ids.append(0) 450 | input_mask.append(0) 451 | segment_ids.append(0) 452 | 453 | assert len(input_ids) == max_seq_length 454 | assert len(input_mask) == max_seq_length 455 | assert len(segment_ids) == max_seq_length 456 | 457 | label_id = label_map[example.label] 458 | if ex_index < 5: 459 | tf.logging.info("*** Example ***") 460 | tf.logging.info("guid: %s" % (example.guid)) 461 | tf.logging.info("tokens: %s" % " ".join( 462 | [tokenization.printable_text(x) for x in tokens])) 463 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 464 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 465 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 466 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 467 | 468 | feature = InputFeatures( 469 | input_ids=input_ids, 470 | input_mask=input_mask, 471 | segment_ids=segment_ids, 472 | label_id=label_id, 473 | is_real_example=True) 474 | return feature 475 | 476 | 477 | def file_based_convert_examples_to_features( 478 | examples, label_list, max_seq_length, tokenizer, output_file): 479 | """Convert a set of `InputExample`s to a TFRecord file.""" 480 | 481 | writer = tf.python_io.TFRecordWriter(output_file) 482 | 483 | for (ex_index, example) in enumerate(examples): 484 | if ex_index % 10000 == 0: 485 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 486 | 487 | feature = convert_single_example(ex_index, example, label_list, 488 | max_seq_length, tokenizer) 489 | 490 | def create_int_feature(values): 491 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 492 | return f 493 | 494 | features = collections.OrderedDict() 495 | features["input_ids"] = create_int_feature(feature.input_ids) 496 | features["input_mask"] = create_int_feature(feature.input_mask) 497 | features["segment_ids"] = create_int_feature(feature.segment_ids) 498 | features["label_ids"] = create_int_feature([feature.label_id]) 499 | features["is_real_example"] = create_int_feature( 500 | [int(feature.is_real_example)]) 501 | 502 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 503 | writer.write(tf_example.SerializeToString()) 504 | writer.close() 505 | 506 | 507 | def file_based_input_fn_builder(input_file, seq_length, is_training, 508 | drop_remainder): 509 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 510 | 511 | name_to_features = { 512 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 513 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 514 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 515 | "label_ids": tf.FixedLenFeature([], tf.int64), 516 | "is_real_example": tf.FixedLenFeature([], tf.int64), 517 | } 518 | 519 | def _decode_record(record, name_to_features): 520 | """Decodes a record to a TensorFlow example.""" 521 | example = tf.parse_single_example(record, name_to_features) 522 | 523 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 524 | # So cast all int64 to int32. 525 | for name in list(example.keys()): 526 | t = example[name] 527 | if t.dtype == tf.int64: 528 | t = tf.to_int32(t) 529 | example[name] = t 530 | 531 | return example 532 | 533 | def input_fn(params): 534 | """The actual input function.""" 535 | batch_size = params["batch_size"] 536 | 537 | # For training, we want a lot of parallel reading and shuffling. 
538 | # For eval, we want no shuffling and parallel reading doesn't matter. 539 | d = tf.data.TFRecordDataset(input_file) 540 | if is_training: 541 | d = d.repeat() 542 | d = d.shuffle(buffer_size=100) 543 | 544 | d = d.apply( 545 | tf.contrib.data.map_and_batch( 546 | lambda record: _decode_record(record, name_to_features), 547 | batch_size=batch_size, 548 | drop_remainder=drop_remainder)) 549 | 550 | return d 551 | 552 | return input_fn 553 | 554 | 555 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 556 | """Truncates a sequence pair in place to the maximum length.""" 557 | 558 | # This is a simple heuristic which will always truncate the longer sequence 559 | # one token at a time. This makes more sense than truncating an equal percent 560 | # of tokens from each, since if one sequence is very short then each token 561 | # that's truncated likely contains more information than a longer sequence. 562 | while True: 563 | total_length = len(tokens_a) + len(tokens_b) 564 | if total_length <= max_length: 565 | break 566 | if len(tokens_a) > len(tokens_b): 567 | tokens_a.pop() 568 | else: 569 | tokens_b.pop() 570 | 571 | 572 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 573 | labels, num_labels, use_one_hot_embeddings): 574 | """Creates a classification model.""" 575 | model = modeling.BertModel( 576 | config=bert_config, 577 | is_training=is_training, 578 | input_ids=input_ids, 579 | input_mask=input_mask, 580 | token_type_ids=segment_ids, 581 | use_one_hot_embeddings=use_one_hot_embeddings) 582 | 583 | # In the demo, we are doing a simple classification task on the entire 584 | # segment. 585 | # 586 | # If you want to use the token-level output, use model.get_sequence_output() 587 | # instead. 588 | output_layer = model.get_pooled_output() 589 | 590 | hidden_size = output_layer.shape[-1].value 591 | 592 | output_weights = tf.get_variable( 593 | "output_weights", [num_labels, hidden_size], 594 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 595 | 596 | output_bias = tf.get_variable( 597 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 598 | 599 | with tf.variable_scope("loss"): 600 | if is_training: 601 | # I.e., 0.1 dropout 602 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 603 | 604 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 605 | logits = tf.nn.bias_add(logits, output_bias) 606 | probabilities = tf.nn.softmax(logits, axis=-1) 607 | log_probs = tf.nn.log_softmax(logits, axis=-1) 608 | 609 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 610 | 611 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) # todo 08-29 try temp-loss 612 | ###############bi_tempered_logistic_loss############################################################################ 613 | # print("##cross entropy loss is used...."); tf.logging.info("##cross entropy loss is used....") 614 | # t1=0.9 #t1=0.90 615 | # t2=1.05 #t2=1.05 616 | # per_example_loss=bi_tempered_logistic_loss(log_probs,one_hot_labels,t1,t2,label_smoothing=0.1,num_iters=5) # TODO label_smoothing=0.0 617 | #tf.logging.info("per_example_loss:"+str(per_example_loss.shape)) 618 | ##############bi_tempered_logistic_loss############################################################################# 619 | 620 | loss = tf.reduce_mean(per_example_loss) 621 | 622 | return (loss, per_example_loss, logits, probabilities) 623 | 624 | 625 | def model_fn_builder(bert_config, num_labels, init_checkpoint, 
learning_rate, 626 | num_train_steps, num_warmup_steps, use_tpu, 627 | use_one_hot_embeddings): 628 | """Returns `model_fn` closure for TPUEstimator.""" 629 | 630 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 631 | """The `model_fn` for TPUEstimator.""" 632 | 633 | tf.logging.info("*** Features ***") 634 | for name in sorted(features.keys()): 635 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 636 | 637 | input_ids = features["input_ids"] 638 | input_mask = features["input_mask"] 639 | segment_ids = features["segment_ids"] 640 | label_ids = features["label_ids"] 641 | is_real_example = None 642 | if "is_real_example" in features: 643 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 644 | else: 645 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 646 | 647 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 648 | 649 | (total_loss, per_example_loss, logits, probabilities) = create_model( 650 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 651 | num_labels, use_one_hot_embeddings) 652 | 653 | tvars = tf.trainable_variables() 654 | initialized_variable_names = {} 655 | scaffold_fn = None 656 | if init_checkpoint: 657 | (assignment_map, initialized_variable_names 658 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 659 | if use_tpu: 660 | 661 | def tpu_scaffold(): 662 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 663 | return tf.train.Scaffold() 664 | 665 | scaffold_fn = tpu_scaffold 666 | else: 667 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 668 | 669 | tf.logging.info("**** Trainable Variables ****") 670 | for var in tvars: 671 | init_string = "" 672 | if var.name in initialized_variable_names: 673 | init_string = ", *INIT_FROM_CKPT*" 674 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 675 | init_string) 676 | 677 | output_spec = None 678 | if mode == tf.estimator.ModeKeys.TRAIN: 679 | 680 | train_op = optimization.create_optimizer( 681 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 682 | 683 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 684 | mode=mode, 685 | loss=total_loss, 686 | train_op=train_op, 687 | scaffold_fn=scaffold_fn) 688 | elif mode == tf.estimator.ModeKeys.EVAL: 689 | 690 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 691 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 692 | accuracy = tf.metrics.accuracy( 693 | labels=label_ids, predictions=predictions, weights=is_real_example) 694 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 695 | return { 696 | "eval_accuracy": accuracy, 697 | "eval_loss": loss, 698 | } 699 | 700 | eval_metrics = (metric_fn, 701 | [per_example_loss, label_ids, logits, is_real_example]) 702 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 703 | mode=mode, 704 | loss=total_loss, 705 | eval_metrics=eval_metrics, 706 | scaffold_fn=scaffold_fn) 707 | else: 708 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 709 | mode=mode, 710 | predictions={"probabilities": probabilities}, 711 | scaffold_fn=scaffold_fn) 712 | return output_spec 713 | 714 | return model_fn 715 | 716 | 717 | # This function is not used by this file but is still used by the Colab and 718 | # people who depend on it. 
719 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 720 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 721 | 722 | all_input_ids = [] 723 | all_input_mask = [] 724 | all_segment_ids = [] 725 | all_label_ids = [] 726 | 727 | for feature in features: 728 | all_input_ids.append(feature.input_ids) 729 | all_input_mask.append(feature.input_mask) 730 | all_segment_ids.append(feature.segment_ids) 731 | all_label_ids.append(feature.label_id) 732 | 733 | def input_fn(params): 734 | """The actual input function.""" 735 | batch_size = params["batch_size"] 736 | 737 | num_examples = len(features) 738 | 739 | # This is for demo purposes and does NOT scale to large data sets. We do 740 | # not use Dataset.from_generator() because that uses tf.py_func which is 741 | # not TPU compatible. The right way to load data is with TFRecordReader. 742 | d = tf.data.Dataset.from_tensor_slices({ 743 | "input_ids": 744 | tf.constant( 745 | all_input_ids, shape=[num_examples, seq_length], 746 | dtype=tf.int32), 747 | "input_mask": 748 | tf.constant( 749 | all_input_mask, 750 | shape=[num_examples, seq_length], 751 | dtype=tf.int32), 752 | "segment_ids": 753 | tf.constant( 754 | all_segment_ids, 755 | shape=[num_examples, seq_length], 756 | dtype=tf.int32), 757 | "label_ids": 758 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 759 | }) 760 | 761 | if is_training: 762 | d = d.repeat() 763 | d = d.shuffle(buffer_size=100) 764 | 765 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 766 | return d 767 | 768 | return input_fn 769 | 770 | class LCQMCPairClassificationProcessor(DataProcessor): # TODO NEED CHANGE2 771 | """Processor for the internal data set. sentence pair classification""" 772 | def __init__(self): 773 | self.language = "zh" 774 | 775 | def get_train_examples(self, data_dir): 776 | """See base class.""" 777 | return self._create_examples( 778 | self._read_tsv(os.path.join(data_dir, "train.txt")), "train") 779 | # dev_0827.tsv 780 | 781 | def get_dev_examples(self, data_dir): 782 | """See base class.""" 783 | return self._create_examples( 784 | self._read_tsv(os.path.join(data_dir, "test.txt")), "dev") # todo change temp for test purpose 785 | 786 | def get_test_examples(self, data_dir): 787 | """See base class.""" 788 | return self._create_examples( 789 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test") 790 | 791 | def get_labels(self): 792 | """See base class.""" 793 | return ["0", "1"] 794 | #return ["-1","0", "1"] 795 | 796 | def _create_examples(self, lines, set_type): 797 | """Creates examples for the training and dev sets.""" 798 | examples = [] 799 | print("length of lines:",len(lines)) 800 | for (i, line) in enumerate(lines): 801 | #print('#i:',i,line) 802 | if i == 0: 803 | continue 804 | guid = "%s-%s" % (set_type, i) 805 | try: 806 | label = tokenization.convert_to_unicode(line[2]) 807 | text_a = tokenization.convert_to_unicode(line[0]) 808 | text_b = tokenization.convert_to_unicode(line[1]) 809 | examples.append( 810 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 811 | except Exception: 812 | print('###error.i:', i, line) 813 | return examples 814 | 815 | class SentencePairClassificationProcessor(DataProcessor): 816 | """Processor for the internal data set. 
sentence pair classification""" 817 | def __init__(self): 818 | self.language = "zh" 819 | 820 | def get_train_examples(self, data_dir): 821 | """See base class.""" 822 | return self._create_examples( 823 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 824 | # dev_0827.tsv 825 | 826 | def get_dev_examples(self, data_dir): 827 | """See base class.""" 828 | return self._create_examples( 829 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 830 | 831 | def get_test_examples(self, data_dir): 832 | """See base class.""" 833 | return self._create_examples( 834 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 835 | 836 | def get_labels(self): 837 | """See base class.""" 838 | return ["0", "1"] 839 | 840 | def _create_examples(self, lines, set_type): 841 | """Creates examples for the training and dev sets.""" 842 | examples = [] 843 | print("length of lines:",len(lines)) 844 | for (i, line) in enumerate(lines): 845 | #print('#i:',i,line) 846 | if i == 0: 847 | continue 848 | guid = "%s-%s" % (set_type, i) 849 | try: 850 | label = tokenization.convert_to_unicode(line[0]) 851 | text_a = tokenization.convert_to_unicode(line[1]) 852 | text_b = tokenization.convert_to_unicode(line[2]) 853 | examples.append( 854 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 855 | except Exception: 856 | print('###error.i:', i, line) 857 | return examples 858 | 859 | # This function is not used by this file but is still used by the Colab and 860 | # people who depend on it. 861 | def convert_examples_to_features(examples, label_list, max_seq_length, 862 | tokenizer): 863 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 864 | 865 | features = [] 866 | for (ex_index, example) in enumerate(examples): 867 | if ex_index % 10000 == 0: 868 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 869 | 870 | feature = convert_single_example(ex_index, example, label_list, 871 | max_seq_length, tokenizer) 872 | 873 | features.append(feature) 874 | return features 875 | 876 | 877 | def main(_): 878 | tf.logging.set_verbosity(tf.logging.INFO) 879 | 880 | processors = { 881 | "cola": ColaProcessor, 882 | "mnli": MnliProcessor, 883 | "mrpc": MrpcProcessor, 884 | "xnli": XnliProcessor, 885 | "sentence_pair": SentencePairClassificationProcessor, 886 | "lcqmc_pair":LCQMCPairClassificationProcessor 887 | 888 | 889 | } 890 | 891 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 892 | FLAGS.init_checkpoint) 893 | 894 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 895 | raise ValueError( 896 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 897 | 898 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 899 | 900 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 901 | raise ValueError( 902 | "Cannot use sequence length %d because the BERT model " 903 | "was only trained up to sequence length %d" % 904 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 905 | 906 | tf.gfile.MakeDirs(FLAGS.output_dir) 907 | 908 | task_name = FLAGS.task_name.lower() 909 | 910 | if task_name not in processors: 911 | raise ValueError("Task not found: %s" % (task_name)) 912 | 913 | processor = processors[task_name]() 914 | 915 | label_list = processor.get_labels() 916 | 917 | tokenizer = tokenization.FullTokenizer( 918 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 919 | 920 | tpu_cluster_resolver = None 921 | if FLAGS.use_tpu and 
FLAGS.tpu_name: 922 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 923 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 924 | 925 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 926 | # Cloud TPU: Invalid TPU configuration, ensure ClusterResolver is passed to tpu. 927 | print("###tpu_cluster_resolver:",tpu_cluster_resolver) 928 | run_config = tf.contrib.tpu.RunConfig( 929 | cluster=tpu_cluster_resolver, 930 | master=FLAGS.master, 931 | model_dir=FLAGS.output_dir, 932 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 933 | tpu_config=tf.contrib.tpu.TPUConfig( 934 | iterations_per_loop=FLAGS.iterations_per_loop, 935 | num_shards=FLAGS.num_tpu_cores, 936 | per_host_input_for_training=is_per_host)) 937 | 938 | train_examples = None 939 | num_train_steps = None 940 | num_warmup_steps = None 941 | if FLAGS.do_train: 942 | train_examples =processor.get_train_examples(FLAGS.data_dir) # TODO 943 | print("###length of total train_examples:",len(train_examples)) 944 | num_train_steps = int(len(train_examples)/ FLAGS.train_batch_size * FLAGS.num_train_epochs) 945 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 946 | 947 | model_fn = model_fn_builder( 948 | bert_config=bert_config, 949 | num_labels=len(label_list), 950 | init_checkpoint=FLAGS.init_checkpoint, 951 | learning_rate=FLAGS.learning_rate, 952 | num_train_steps=num_train_steps, 953 | num_warmup_steps=num_warmup_steps, 954 | use_tpu=FLAGS.use_tpu, 955 | use_one_hot_embeddings=FLAGS.use_tpu) 956 | 957 | # If TPU is not available, this will fall back to normal Estimator on CPU 958 | # or GPU. 959 | estimator = tf.contrib.tpu.TPUEstimator( 960 | use_tpu=FLAGS.use_tpu, 961 | model_fn=model_fn, 962 | config=run_config, 963 | train_batch_size=FLAGS.train_batch_size, 964 | eval_batch_size=FLAGS.eval_batch_size, 965 | predict_batch_size=FLAGS.predict_batch_size) 966 | 967 | if FLAGS.do_train: 968 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 969 | train_file_exists=os.path.exists(train_file) 970 | print("###train_file_exists:", train_file_exists," ;train_file:",train_file) 971 | if not train_file_exists: # if tf_record file not exist, convert from raw text file. # TODO 972 | file_based_convert_examples_to_features(train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 973 | tf.logging.info("***** Running training *****") 974 | tf.logging.info(" Num examples = %d", len(train_examples)) 975 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 976 | tf.logging.info(" Num steps = %d", num_train_steps) 977 | train_input_fn = file_based_input_fn_builder( 978 | input_file=train_file, 979 | seq_length=FLAGS.max_seq_length, 980 | is_training=True, 981 | drop_remainder=True) 982 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 983 | 984 | if FLAGS.do_eval: 985 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 986 | num_actual_eval_examples = len(eval_examples) 987 | if FLAGS.use_tpu: 988 | # TPU requires a fixed batch size for all batches, therefore the number 989 | # of examples must be a multiple of the batch size, or else examples 990 | # will get dropped. So we pad with fake examples which are ignored 991 | # later on. These do NOT count towards the metric (all tf.metrics 992 | # support a per-instance weight, and these get a weight of 0.0). 
993 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 994 | eval_examples.append(PaddingInputExample()) 995 | 996 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 997 | file_based_convert_examples_to_features( 998 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 999 | 1000 | tf.logging.info("***** Running evaluation *****") 1001 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 1002 | len(eval_examples), num_actual_eval_examples, 1003 | len(eval_examples) - num_actual_eval_examples) 1004 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 1005 | 1006 | # This tells the estimator to run through the entire set. 1007 | eval_steps = None 1008 | # However, if running eval on the TPU, you will need to specify the 1009 | # number of steps. 1010 | if FLAGS.use_tpu: 1011 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 1012 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 1013 | 1014 | eval_drop_remainder = True if FLAGS.use_tpu else False 1015 | eval_input_fn = file_based_input_fn_builder( 1016 | input_file=eval_file, 1017 | seq_length=FLAGS.max_seq_length, 1018 | is_training=False, 1019 | drop_remainder=eval_drop_remainder) 1020 | 1021 | ####################################################################################################################### 1022 | # evaluate all checkpoints 1023 | steps_and_files = [] 1024 | filenames = tf.gfile.ListDirectory(FLAGS.output_dir) 1025 | for filename in filenames: 1026 | if filename.endswith(".index"): 1027 | ckpt_name = filename[:-6] 1028 | cur_filename = os.path.join(FLAGS.output_dir, ckpt_name) 1029 | global_step = int(cur_filename.split("-")[-1]) 1030 | tf.logging.info("Add {} to eval list.".format(cur_filename)) 1031 | steps_and_files.append([global_step, cur_filename]) 1032 | steps_and_files = sorted(steps_and_files, key=lambda x: x[0]) 1033 | 1034 | output_eval_file = os.path.join(FLAGS.data_dir, "eval_results16-layer24-4million-2.txt") # finetuning-layer24-4million 1035 | print("output_eval_file:",output_eval_file) 1036 | tf.logging.info("output_eval_file:"+output_eval_file) 1037 | with tf.gfile.GFile(output_eval_file, "w") as writer: 1038 | for global_step, filename in sorted(steps_and_files, key=lambda x: x[0]): 1039 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps, checkpoint_path=filename) 1040 | 1041 | tf.logging.info("***** Eval results %s *****" % (filename)) 1042 | writer.write("***** Eval results %s *****\n" % (filename)) 1043 | for key in sorted(result.keys()): 1044 | tf.logging.info(" %s = %s", key, str(result[key])) 1045 | writer.write("%s = %s\n" % (key, str(result[key]))) 1046 | ####################################################################################################################### 1047 | 1048 | #result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 1049 | # 1050 | #output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 1051 | #with tf.gfile.GFile(output_eval_file, "w") as writer: 1052 | # tf.logging.info("***** Eval results *****") 1053 | # for key in sorted(result.keys()): 1054 | # tf.logging.info(" %s = %s", key, str(result[key])) 1055 | # writer.write("%s = %s\n" % (key, str(result[key]))) 1056 | 1057 | if FLAGS.do_predict: 1058 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 1059 | num_actual_predict_examples = len(predict_examples) 1060 | if FLAGS.use_tpu: 1061 | # TPU requires a fixed batch size for all batches, therefore the number 1062 | # of examples
must be a multiple of the batch size, or else examples 1063 | # will get dropped. So we pad with fake examples which are ignored 1064 | # later on. 1065 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 1066 | predict_examples.append(PaddingInputExample()) 1067 | 1068 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 1069 | file_based_convert_examples_to_features(predict_examples, label_list, 1070 | FLAGS.max_seq_length, tokenizer, 1071 | predict_file) 1072 | 1073 | tf.logging.info("***** Running prediction*****") 1074 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 1075 | len(predict_examples), num_actual_predict_examples, 1076 | len(predict_examples) - num_actual_predict_examples) 1077 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 1078 | 1079 | predict_drop_remainder = True if FLAGS.use_tpu else False 1080 | predict_input_fn = file_based_input_fn_builder( 1081 | input_file=predict_file, 1082 | seq_length=FLAGS.max_seq_length, 1083 | is_training=False, 1084 | drop_remainder=predict_drop_remainder) 1085 | 1086 | result = estimator.predict(input_fn=predict_input_fn) 1087 | 1088 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 1089 | with tf.gfile.GFile(output_predict_file, "w") as writer: 1090 | num_written_lines = 0 1091 | tf.logging.info("***** Predict results *****") 1092 | for (i, prediction) in enumerate(result): 1093 | probabilities = prediction["probabilities"] 1094 | if i >= num_actual_predict_examples: 1095 | break 1096 | output_line = "\t".join( 1097 | str(class_probability) 1098 | for class_probability in probabilities) + "\n" 1099 | writer.write(output_line) 1100 | num_written_lines += 1 1101 | assert num_written_lines == num_actual_predict_examples 1102 | 1103 | 1104 | if __name__ == "__main__": 1105 | flags.mark_flag_as_required("data_dir") 1106 | flags.mark_flag_as_required("task_name") 1107 | flags.mark_flag_as_required("vocab_file") 1108 | flags.mark_flag_as_required("bert_config_file") 1109 | flags.mark_flag_as_required("output_dir") 1110 | tf.app.run() --------------------------------------------------------------------------------
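Usage note: the do_predict branch above writes one line of tab-separated class probabilities per real example to test_results.tsv, in the order returned by the processor's get_labels(). Below is a minimal sketch of consuming that file, assuming the two-label list ["0", "1"] used by the pair-classification processors; the script name, the "./output" directory, and the plain argmax decision are illustrative choices, not part of this repo.

# read_test_results.py -- illustrative sketch, not part of the original repository.
import csv
import os

LABELS = ["0", "1"]  # same order as get_labels() in the pair-classification processors


def read_predictions(output_dir):
    """Yields (predicted_label, {label: probability}) for each line of test_results.tsv."""
    path = os.path.join(output_dir, "test_results.tsv")
    with open(path, "r", encoding="utf-8") as f:
        for row in csv.reader(f, delimiter="\t"):
            probs = dict(zip(LABELS, (float(p) for p in row)))
            # run_classifier.py writes softmax probabilities, so argmax picks the label.
            yield max(probs, key=probs.get), probs


if __name__ == "__main__":
    for label, probs in read_predictions("./output"):
        print(label, probs)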