├── imgs ├── api_example.jpg ├── dataset_desc.jpg ├── deploy-success.jpg ├── serving_offline.jpg └── serving_deploy_arcticture.jpg ├── requirements.txt ├── model_convert.sh ├── bert_classify_server.sh ├── cmrc ├── run_cmrc.sh ├── cmrc_eval.py └── cmrc_tool │ ├── tokenization.py │ ├── run_squad_inf.py │ └── run_squad_inf_cmrc.py ├── train.sh ├── export.sh ├── train_roberta_tiny_clue_command_gpu.sh ├── api ├── api_service.py └── api_service_flask.py ├── .gitignore ├── optimization.py ├── test_serving.py ├── freeze_graph.py ├── README.md ├── run_savedModel_infer.py ├── test_tf_serving.py ├── run_pb_inference.py ├── tokenization.py └── modeling.py
/imgs/api_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/api_example.jpg --------------------------------------------------------------------------------
/imgs/dataset_desc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/dataset_desc.jpg --------------------------------------------------------------------------------
/imgs/deploy-success.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/deploy-success.jpg --------------------------------------------------------------------------------
/imgs/serving_offline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/serving_offline.jpg --------------------------------------------------------------------------------
/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow. 2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. --------------------------------------------------------------------------------
/imgs/serving_deploy_arcticture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/serving_deploy_arcticture.jpg --------------------------------------------------------------------------------
/model_convert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #description: Convert model from ckpt format to pb format 3 | #If label2id.pkl already exists in the model directory ($TRAINED_CLASSIFIER), the num_labels argument can be omitted here 4 | 5 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 6 | export TRAINED_CLASSIFIER=./output 7 | export EXP_NAME=mobile_0_roberta_base_epoch1 8 | 9 | python freeze_graph.py \ 10 | -bert_model_dir $BERT_BASE_DIR \ 11 | -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \ 12 | -max_seq_len 128 --------------------------------------------------------------------------------
/bert_classify_server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #chkconfig: 2345 80 90 3 | #description: Start the BERT classification serving model 4 | 5 | echo 'start BERT classify server...' 
6 | rm -rf tmp* 7 | 8 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 9 | export TRAINED_CLASSIFIER=./output 10 | export EXP_NAME=mobile_0_roberta_base_epoch1 11 | 12 | bert-base-serving-start \ 13 | -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \ 14 | -bert_model_dir $BERT_BASE_DIR \ 15 | -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \ 16 | -mode CLASS \ 17 | -max_seq_len 128 \ 18 | -port 5575 \ 19 | -port_out 5576 \ 20 | -device_map 1 21 | -------------------------------------------------------------------------------- /cmrc/run_cmrc.sh: -------------------------------------------------------------------------------- 1 | export SQUAD_DIR=./CMRC_DIR 2 | export BERT_BASE_DIR=/export/huangdongxiao/huangdongxiao/AI_QA/models/RoBERTa-tiny-clue 3 | python run_cmrc.py \ 4 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 5 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 6 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 7 | --do_train=False \ 8 | --do_export=True \ 9 | --do_predict=True \ 10 | --train_file=$SQUAD_DIR/train.json \ 11 | --predict_file=$SQUAD_DIR/dev.json \ 12 | --train_batch_size=16 \ 13 | --learning_rate=3e-5 \ 14 | --num_train_epochs=8.0 \ 15 | --max_seq_length=384 \ 16 | --doc_stride=128 \ 17 | --output_dir=./output/cmrc \ 18 | --version_2_with_negative=False -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #description: BERT fine-tuning 3 | 4 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 5 | export DATA_DIR=./dat 6 | export TRAINED_CLASSIFIER=./output 7 | export MODEL_NAME=mobile_0_roberta_base_epoch1 8 | 9 | python run_classifier_serving.py \ 10 | --task_name=setiment \ 11 | --do_train=true \ 12 | --do_eval=true \ 13 | --do_predict=False \ 14 | --data_dir=$DATA_DIR \ 15 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 16 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 17 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 18 | --max_seq_length=128 \ 19 | --train_batch_size=32 \ 20 | --learning_rate=2e-5 \ 21 | --num_train_epochs=1.0 \ 22 | --output_dir=$TRAINED_CLASSIFIER/$MODEL_NAME -------------------------------------------------------------------------------- /export.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #description: BERT fine-tuning 3 | 4 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 5 | export DATA_DIR=./dat 6 | export TRAINED_CLASSIFIER=./output 7 | export MODEL_NAME=mobile_0_roberta_base 8 | 9 | python run_classifier_serving.py \ 10 | --task_name=setiment \ 11 | --do_train=False \ 12 | --do_eval=False \ 13 | --do_predict=True \ 14 | --data_dir=$DATA_DIR \ 15 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 16 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 17 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 18 | --max_seq_length=128 \ 19 | --train_batch_size=32 \ 20 | --learning_rate=2e-5 \ 21 | --num_train_epochs=1.0 \ 22 | --output_dir=$TRAINED_CLASSIFIER/$MODEL_NAME \ 23 | --do_export=True \ 24 | --export_dir=exported -------------------------------------------------------------------------------- /train_roberta_tiny_clue_command_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #description: BERT fine-tuning 3 | 4 | export BERT_BASE_DIR=/export/huangdongxiao/huangdongxiao/AI_QA/models/RoBERTa-tiny-clue 5 | export DATA_DIR=/export/huangdongxiao/huangdongxiao/AI_QA/ALBERT/1-data/command 6 | 
export TRAINED_CLASSIFIER=./output 7 | export MODEL_NAME=roberta_tiny_clue_command_gpu 8 | 9 | export CUDA_VISIBLE_DEVICES=2 10 | python run_classifier_serving_gpu.py \ 11 | --task_name=command \ 12 | --do_train=false \ 13 | --do_eval=false \ 14 | --do_predict=true \ 15 | --do_export=false \ 16 | --do_frozen=true \ 17 | --data_dir=$DATA_DIR \ 18 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 19 | --test_file=test \ 20 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 21 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 22 | --max_seq_length=128 \ 23 | --train_batch_size=32 \ 24 | --learning_rate=1e-4 \ 25 | --num_train_epochs=6.0 \ 26 | --output_dir=$TRAINED_CLASSIFIER/$MODEL_NAME 27 | -------------------------------------------------------------------------------- /api/api_service.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: api_service_new 4 | Description : api客户端请求服务器,返回标签 5 | Author : 逸轩 6 | date: 2019/10/12 7 | 8 | """ 9 | 10 | import json 11 | import re 12 | import time 13 | from bert_base.client import BertClient 14 | 15 | bc = BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False, mode='CLASS') 16 | print('BertClient连接成功') 17 | 18 | # 切分句子 19 | def cut_sent(txt): 20 | # 先预处理去空格等 21 | txt = re.sub('([  \t]+)', r" ", txt) # blank word 22 | txt = txt.rstrip() # 段尾如果有多余的\n就去掉它 23 | nlist = txt.split(";") 24 | nnlist = [x for x in nlist if x.strip() != ''] # 过滤掉空行 25 | return nnlist 26 | 27 | 28 | # 对句子列表进行预测识别 29 | def class_pred(list_text): 30 | # 文本拆分成句子 31 | # list_text = cut_sent(text) 32 | print("total setance: %d" % (len(list_text))) 33 | # with BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False, 34 | # check_length=False, mode='CLASS') as bc: 35 | start_t = time.perf_counter() 36 | rst = bc.encode(list_text) 37 | # print('result:', rst) 38 | print('time used:{}s'.format(time.perf_counter() - start_t)) 39 | # 返回结构为: 40 | # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}] 41 | # 抽取出标注结果 42 | pred_label = rst[0]["pred_label"] 43 | result_txt = [[pred_label[i], list_text[i]] for i in range(len(pred_label))] 44 | return result_txt 45 | 46 | if __name__ == '__main__': 47 | while True: 48 | text = input(r'请输入句子(多个句子请用;分隔):') 49 | list_text = cut_sent(text) 50 | # print(list_text) 51 | result = class_pred(list_text) 52 | print(result) 53 | -------------------------------------------------------------------------------- /api/api_service_flask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: api_service_flask 4 | Description : api请求服务端,并基于flask提供接口待查询 5 | Author : 逸轩 6 | date: 2019/10/12 7 | 8 | """ 9 | 10 | import json 11 | import re 12 | import time 13 | from bert_base.client import BertClient 14 | from flask import Flask, request 15 | from flask_cors import CORS 16 | 17 | flaskAPP = Flask(import_name=__name__) 18 | CORS(flaskAPP, supports_credentials=True) 19 | 20 | bc = BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False, mode='CLASS') 21 | print('BertClient连接成功') 22 | 23 | # 切分句子 24 | def cut_sent(txt): 25 | # 先预处理去空格等 26 | txt = re.sub('([  \t]+)', r" ", txt) # blank word 27 | txt = txt.rstrip() # 段尾如果有多余的\n就去掉它 28 | nlist = txt.split("\n") 29 | nnlist = 
[x for x in nlist if x.strip() != ''] # 过滤掉空行 30 | return nnlist 31 | 32 | 33 | # 对句子列表进行预测识别 34 | def class_pred(list_text): 35 | # 文本拆分成句子 36 | # list_text = cut_sent(text) 37 | print("total setance: %d" % (len(list_text))) 38 | # with BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False, 39 | # check_length=False, mode='CLASS') as bc: 40 | start_t = time.perf_counter() 41 | rst = bc.encode(list_text) 42 | print('result:', rst) 43 | print('time used:{}'.format(time.perf_counter() - start_t)) 44 | # 返回结构为: 45 | # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}] 46 | # 抽取出标注结果 47 | pred_label = rst[0]["pred_label"] 48 | result_txt = [[pred_label[i], list_text[i]] for i in range(len(pred_label))] 49 | return result_txt 50 | 51 | @flaskAPP.route('/predict_online', methods=['GET', 'POST']) 52 | def predict_online(): 53 | text = request.args.get('text') 54 | print('服务器接收到字段:') 55 | print('text:', text) 56 | print('==============================') 57 | lstseg = cut_sent(text) 58 | print('-' * 30) 59 | print('结果,共%d个句子:' % (len(lstseg))) 60 | # for x in lstseg: 61 | # print("第%d句:【 %s】" % (lstseg.index(x), x)) 62 | print('-' * 30) 63 | # if request.method == 'POST' or 1: 64 | # res['result'] = class_pred(lstseg) 65 | result = class_pred(lstseg) 66 | new_res_list = [] 67 | for term in result: 68 | if term[0] == '1': 69 | label = '好评' 70 | if term[0] == '0': 71 | label = '中评' 72 | if term[0] == '-1': 73 | label = '差评' 74 | new_res_list.append([label, term[1]]) 75 | new_res = {'result': new_res_list} 76 | # print('result:%s' % str(res)) 77 | print('result:%s' % str(new_res)) 78 | # return jsonify(res) 79 | return json.dumps(new_res, ensure_ascii=False) 80 | 81 | flaskAPP.run(host='192.168.9.23', port=8910, debug=True) 82 | 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __MACOSX/ 3 | *.tar.gz 4 | settings.json 5 | .vscode/ 6 | *jd* 7 | *JD* 8 | *.data 9 | checkpoint 10 | *.index 11 | *.pickle 12 | *.pickle 13 | *.pkl 14 | *backup 15 | *.out 16 | *.csv 17 | *.meta 18 | *.data-* 19 | .idea 20 | *.pyc 21 | # corpus 22 | # resources 23 | .DS_Store 24 | *.zip 25 | *.txt 26 | __pycache__/ 27 | *.py[cod] 28 | *$py.class 29 | 30 | # C extensions 31 | *.so 32 | 33 | # Distribution / packaging 34 | .Python 35 | build/ 36 | develop-eggs/ 37 | dist/ 38 | downloads/ 39 | eggs/ 40 | .eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | 130 | # project related 131 | # Byte-compiled / optimized / DLL files 132 | __MACOSX/ 133 | *.tar.gz 134 | settings.json 135 | .vscode/ 136 | *jd* 137 | *JD* 138 | *.data 139 | checkpoint 140 | *.index 141 | *.pickle 142 | *.pickle 143 | *.pkl 144 | *backup 145 | *.out 146 | *.csv 147 | *.meta 148 | *.data-* 149 | .idea 150 | *.pyc 151 | # corpus 152 | # resources 153 | .DS_Store 154 | *.zip 155 | *.txt 156 | __pycache__/ 157 | *.py[cod] 158 | *$py.class 159 | 160 | # C extensions 161 | *.so 162 | 163 | # Distribution / packaging 164 | .Python 165 | build/ 166 | develop-eggs/ 167 | dist/ 168 | downloads/ 169 | eggs/ 170 | .eggs/ 171 | lib/ 172 | lib64/ 173 | parts/ 174 | sdist/ 175 | var/ 176 | wheels/ 177 | *.egg-info/ 178 | .installed.cfg 179 | *.egg 180 | MANIFEST 181 | 182 | # PyInstaller 183 | # Usually these files are written by a python script from a template 184 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
185 | *.manifest 186 | *.spec 187 | 188 | # Installer logs 189 | pip-log.txt 190 | pip-delete-this-directory.txt 191 | 192 | # Unit test / coverage reports 193 | htmlcov/ 194 | .tox/ 195 | .coverage 196 | .coverage.* 197 | .cache 198 | nosetests.xml 199 | coverage.xml 200 | *.cover 201 | .hypothesis/ 202 | .pytest_cache/ 203 | 204 | # Translations 205 | *.mo 206 | *.pot 207 | 208 | # Django stuff: 209 | *.log 210 | local_settings.py 211 | db.sqlite3 212 | 213 | # Flask stuff: 214 | instance/ 215 | .webassets-cache 216 | 217 | # Scrapy stuff: 218 | .scrapy 219 | 220 | # Sphinx documentation 221 | docs/_build/ 222 | 223 | # PyBuilder 224 | target/ 225 | 226 | # Jupyter Notebook 227 | .ipynb_checkpoints 228 | 229 | # pyenv 230 | .python-version 231 | 232 | # celery beat schedule file 233 | celerybeat-schedule 234 | 235 | # SageMath parsed files 236 | *.sage.py 237 | 238 | # Environments 239 | .env 240 | .venv 241 | env/ 242 | venv/ 243 | ENV/ 244 | env.bak/ 245 | venv.bak/ 246 | 247 | # Spyder project settings 248 | .spyderproject 249 | .spyproject 250 | 251 | # Rope project settings 252 | .ropeproject 253 | 254 | # mkdocs documentation 255 | /site 256 | 257 | # mypy 258 | .mypy_cache/ 259 | 260 | # project related 261 | cmrc_1588284341 262 | 1*/ 263 | -------------------------------------------------------------------------------- /cmrc/cmrc_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Evaluation script for CMRC 2018 4 | version: v5 - special 5 | Note: 6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets 7 | v5: formatted output, add usage description 8 | v4: fixed segmentation issues 9 | ''' 10 | from __future__ import print_function 11 | from collections import Counter, OrderedDict 12 | import string 13 | import re 14 | import argparse 15 | import json 16 | import sys 17 | if sys.version[0] == '2': 18 | reload(sys) 19 | sys.setdefaultencoding("utf-8") 20 | import nltk 21 | nltk.download('punkt') 22 | import pdb 23 | 24 | # split Chinese with English 25 | def mixed_segmentation(in_str, rm_punc=False): 26 | in_str = str(in_str).lower().strip() 27 | segs_out = [] 28 | temp_str = "" 29 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 30 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 31 | '「','」','(',')','-','~','『','』'] 32 | for char in in_str: 33 | if rm_punc and char in sp_char: 34 | continue 35 | if re.search(u'[\u4e00-\u9fa5]', char) or char in sp_char: # 匹配所有中文 36 | if temp_str != "": 37 | ss = nltk.word_tokenize(temp_str) 38 | segs_out.extend(ss) 39 | temp_str = "" 40 | segs_out.append(char) 41 | else: 42 | temp_str += char 43 | 44 | #handling last part 45 | if temp_str != "": 46 | ss = nltk.word_tokenize(temp_str) 47 | segs_out.extend(ss) 48 | 49 | return segs_out 50 | 51 | 52 | # remove punctuation 53 | def remove_punctuation(in_str): 54 | in_str = str(in_str).lower().strip() 55 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=', 56 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、', 57 | '「','」','(',')','-','~','『','』'] 58 | out_segs = [] 59 | for char in in_str: 60 | if char in sp_char: 61 | continue 62 | else: 63 | out_segs.append(char) 64 | return ''.join(out_segs) 65 | 66 | 67 | # find longest common string 68 | def find_lcs(s1, s2): 69 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)] 70 | mmax = 0 71 | p = 0 72 | for i in range(len(s1)): 73 | for j in range(len(s2)): 74 | if s1[i] == s2[j]: 75 | m[i+1][j+1] = m[i][j]+1 76 | if 
m[i+1][j+1] > mmax: 77 | mmax=m[i+1][j+1] 78 | p=i+1 79 | return s1[p-mmax:p], mmax 80 | 81 | # 82 | def evaluate(ground_truth_file, prediction_file): 83 | f1 = 0 84 | em = 0 85 | total_count = 0 86 | skip_count = 0 87 | for instance in ground_truth_file["data"]: 88 | #context_id = instance['context_id'].strip() 89 | #context_text = instance['context_text'].strip() 90 | for para in instance["paragraphs"]: 91 | for qas in para['qas']: 92 | total_count += 1 93 | query_id = qas['id'].strip() 94 | query_text = qas['question'].strip() 95 | answers = [x["text"] for x in qas['answers']] 96 | 97 | if query_id not in prediction_file: 98 | sys.stderr.write('Unanswered question: {}\n'.format(query_id)) 99 | skip_count += 1 100 | continue 101 | 102 | prediction = str(prediction_file[query_id]) 103 | f1 += calc_f1_score(answers, prediction) 104 | em += calc_em_score(answers, prediction) 105 | 106 | f1_score = 100.0 * f1 / total_count 107 | em_score = 100.0 * em / total_count 108 | return f1_score, em_score, total_count, skip_count 109 | 110 | 111 | def calc_f1_score(answers, prediction): 112 | f1_scores = [] 113 | for ans in answers: 114 | ans_segs = mixed_segmentation(ans, rm_punc=True) 115 | prediction_segs = mixed_segmentation(prediction, rm_punc=True) 116 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs) 117 | if lcs_len == 0: 118 | f1_scores.append(0) 119 | continue 120 | precision = 1.0*lcs_len/len(prediction_segs) 121 | recall = 1.0*lcs_len/len(ans_segs) 122 | f1 = (2*precision*recall)/(precision+recall) 123 | f1_scores.append(f1) 124 | return max(f1_scores) 125 | 126 | 127 | def calc_em_score(answers, prediction): 128 | em = 0 129 | for ans in answers: 130 | ans_ = remove_punctuation(ans) 131 | prediction_ = remove_punctuation(prediction) 132 | if ans_ == prediction_: 133 | em = 1 134 | break 135 | return em 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018') 139 | parser.add_argument('--dataset_file', type=str, default="/Users/huangdongxiao2/CodeRepos/SesameSt/bert/output/cmrc/dev.json" , help='Official dataset file') 140 | parser.add_argument('--prediction_file', type=str, default="/Users/huangdongxiao2/CodeRepos/SesameSt/bert/output/cmrc/predictions.json" ,help='Your prediction File') 141 | args = parser.parse_args() 142 | ground_truth_file = json.load(open(args.dataset_file, 'rb')) 143 | prediction_file = json.load(open(args.prediction_file, 'rb')) 144 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file) 145 | AVG = (EM+F1)*0.5 146 | output_result = OrderedDict() 147 | output_result['AVERAGE'] = '%.3f' % AVG 148 | output_result['F1'] = '%.3f' % F1 149 | output_result['EM'] = '%.3f' % EM 150 | output_result['TOTAL'] = TOTAL 151 | output_result['SKIP'] = SKIP 152 | output_result['FILE'] = args.prediction_file 153 | print(json.dumps(output_result)) -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 
82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /test_serving.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: test_serving 4 | Description : 5 | Author : 逸轩 6 | date: 2019/10/12 7 | 8 | """ 9 | 10 | import tensorflow as tf 11 | import tokenization 12 | 13 | class InputExample(object): 14 | """A single training/test example for simple sequence classification.""" 15 | 16 | def __init__(self, guid, text_a, text_b=None, label=None): 17 | """Constructs a InputExample. 18 | 19 | Args: 20 | guid: Unique id for the example. 21 | text_a: string. The untokenized text of the first sequence. For single 22 | sequence tasks, only this sequence must be specified. 23 | text_b: (Optional) string. The untokenized text of the second sequence. 24 | Only must be specified for sequence pair tasks. 25 | label: (Optional) string. The label of the example. This should be 26 | specified for train and dev examples, but not for test examples. 27 | """ 28 | self.guid = guid 29 | self.text_a = text_a 30 | self.text_b = text_b 31 | self.label = label 32 | 33 | class PaddingInputExample(object): 34 | """Fake example so the num input examples is a multiple of the batch size. 35 | 36 | When running eval/predict on the TPU, we need to pad the number of examples 37 | to be a multiple of the batch size, because the TPU requires a fixed batch 38 | size. The alternative is to drop the last batch, which is bad because it means 39 | the entire output data won't be generated. 40 | 41 | We use this class instead of `None` because treating `None` as padding 42 | battches could cause silent errors. 43 | """ 44 | 45 | 46 | class InputFeatures(object): 47 | """A single set of features of data.""" 48 | 49 | def __init__(self, 50 | input_ids, 51 | input_mask, 52 | segment_ids, 53 | label_id, 54 | is_real_example=True): 55 | self.input_ids = input_ids 56 | self.input_mask = input_mask 57 | self.segment_ids = segment_ids 58 | self.label_id = label_id 59 | self.is_real_example = is_real_example 60 | 61 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 62 | """Truncates a sequence pair in place to the maximum length.""" 63 | 64 | # This is a simple heuristic which will always truncate the longer sequence 65 | # one token at a time. This makes more sense than truncating an equal percent 66 | # of tokens from each, since if one sequence is very short then each token 67 | # that's truncated likely contains more information than a longer sequence. 
68 | while True: 69 | total_length = len(tokens_a) + len(tokens_b) 70 | if total_length <= max_length: 71 | break 72 | if len(tokens_a) > len(tokens_b): 73 | tokens_a.pop() 74 | else: 75 | tokens_b.pop() 76 | 77 | 78 | def convert_single_example(ex_index, example, label_list, max_seq_length, 79 | tokenizer): 80 | """Converts a single `InputExample` into a single `InputFeatures`.""" 81 | 82 | if isinstance(example, PaddingInputExample): 83 | return InputFeatures( 84 | input_ids=[0] * max_seq_length, 85 | input_mask=[0] * max_seq_length, 86 | segment_ids=[0] * max_seq_length, 87 | label_id=0, 88 | is_real_example=False) 89 | 90 | label_map = {} 91 | for (i, label) in enumerate(label_list): 92 | label_map[label] = i 93 | 94 | tokens_a = tokenizer.tokenize(example.text_a) 95 | tokens_b = None 96 | if example.text_b: 97 | tokens_b = tokenizer.tokenize(example.text_b) 98 | 99 | if tokens_b: 100 | # Modifies `tokens_a` and `tokens_b` in place so that the total 101 | # length is less than the specified length. 102 | # Account for [CLS], [SEP], [SEP] with "- 3" 103 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 104 | else: 105 | # Account for [CLS] and [SEP] with "- 2" 106 | if len(tokens_a) > max_seq_length - 2: 107 | tokens_a = tokens_a[0:(max_seq_length - 2)] 108 | 109 | # The convention in BERT is: 110 | # (a) For sequence pairs: 111 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 112 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 113 | # (b) For single sequences: 114 | # tokens: [CLS] the dog is hairy . [SEP] 115 | # type_ids: 0 0 0 0 0 0 0 116 | # 117 | # Where "type_ids" are used to indicate whether this is the first 118 | # sequence or the second sequence. The embedding vectors for `type=0` and 119 | # `type=1` were learned during pre-training and are added to the wordpiece 120 | # embedding vector (and position vector). This is not *strictly* necessary 121 | # since the [SEP] token unambiguously separates the sequences, but it makes 122 | # it easier for the model to learn the concept of sequences. 123 | # 124 | # For classification tasks, the first vector (corresponding to [CLS]) is 125 | # used as the "sentence vector". Note that this only makes sense because 126 | # the entire model is fine-tuned. 127 | tokens = [] 128 | segment_ids = [] 129 | tokens.append("[CLS]") 130 | segment_ids.append(0) 131 | for token in tokens_a: 132 | tokens.append(token) 133 | segment_ids.append(0) 134 | tokens.append("[SEP]") 135 | segment_ids.append(0) 136 | 137 | if tokens_b: 138 | for token in tokens_b: 139 | tokens.append(token) 140 | segment_ids.append(1) 141 | tokens.append("[SEP]") 142 | segment_ids.append(1) 143 | 144 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 145 | 146 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 147 | # tokens are attended to. 148 | input_mask = [1] * len(input_ids) 149 | 150 | # Zero-pad up to the sequence length. 
151 | while len(input_ids) < max_seq_length: 152 | input_ids.append(0) 153 | input_mask.append(0) 154 | segment_ids.append(0) 155 | 156 | assert len(input_ids) == max_seq_length 157 | assert len(input_mask) == max_seq_length 158 | assert len(segment_ids) == max_seq_length 159 | 160 | # debug xmxoxo 2019/3/13 161 | # print(ex_index,example.text_a) 162 | 163 | label_id = label_map[example.label] 164 | if ex_index < 5: 165 | tf.logging.info("*** Example ***") 166 | tf.logging.info("guid: %s" % (example.guid)) 167 | tf.logging.info("tokens: %s" % " ".join( 168 | [tokenization.printable_text(x) for x in tokens])) 169 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 170 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 171 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 172 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 173 | 174 | feature = InputFeatures( 175 | input_ids=input_ids, 176 | input_mask=input_mask, 177 | segment_ids=segment_ids, 178 | label_id=label_id, 179 | is_real_example=True) 180 | return feature 181 | 182 | if __name__ == '__main__': 183 | predict_fn = tf.contrib.predictor.from_saved_model('exported/1571054350') 184 | label_list = ["-1", "0", "1"] 185 | max_seq_length = 128 186 | tokenizer = tokenization.FullTokenizer(vocab_file='chinese_roberta_zh_l12/vocab.txt', do_lower_case=True) 187 | print('模型加载完毕!正在监听》》》') 188 | while True: 189 | question = input("> ") 190 | predict_example = InputExample("id", question, None, '0') 191 | feature = convert_single_example(100, predict_example, label_list, 192 | max_seq_length, tokenizer) 193 | 194 | prediction = predict_fn({ 195 | "input_ids": [feature.input_ids], 196 | "input_mask": [feature.input_mask], 197 | "segment_ids": [feature.segment_ids], 198 | "label_ids": [feature.label_id], 199 | }) 200 | probabilities = prediction["probabilities"] 201 | label = label_list[probabilities.argmax()] 202 | print(label) -------------------------------------------------------------------------------- /freeze_graph.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: test_serving 4 | Description : BERT模型文件 ckpt转pb 工具 5 | Author : stephen 6 | date: 2019/10/12 7 | 8 | """ 9 | 10 | import os 11 | from termcolor import colored 12 | import modeling 13 | import logging 14 | import tensorflow as tf 15 | import argparse 16 | import pickle 17 | 18 | 19 | def set_logger(context, verbose=False): 20 | if os.name == 'nt': # for Windows 21 | return NTLogger(context, verbose) 22 | 23 | logger = logging.getLogger(context) 24 | logger.setLevel(logging.DEBUG if verbose else logging.INFO) 25 | formatter = logging.Formatter( 26 | '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s', datefmt= 27 | '%m-%d %H:%M:%S') 28 | console_handler = logging.StreamHandler() 29 | console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) 30 | console_handler.setFormatter(formatter) 31 | logger.handlers = [] 32 | logger.addHandler(console_handler) 33 | return logger 34 | 35 | 36 | class NTLogger: 37 | def __init__(self, context, verbose): 38 | self.context = context 39 | self.verbose = verbose 40 | 41 | def info(self, msg, **kwargs): 42 | print('I:%s:%s' % (self.context, msg), flush=True) 43 | 44 | def debug(self, msg, **kwargs): 45 | if self.verbose: 46 | print('D:%s:%s' % (self.context, msg), flush=True) 47 | 48 | def error(self, msg, **kwargs): 49 | 
print('E:%s:%s' % (self.context, msg), flush=True) 50 | 51 | def warning(self, msg, **kwargs): 52 | print('W:%s:%s' % (self.context, msg), flush=True) 53 | 54 | def create_classification_model(bert_config, is_training, input_ids, input_mask, segment_ids, labels, num_labels): 55 | """ 56 | 57 | :param bert_config: 58 | :param is_training: 59 | :param input_ids: 60 | :param input_mask: 61 | :param segment_ids: 62 | :param labels: 63 | :param num_labels: 64 | :param use_one_hot_embedding: 65 | :return: 66 | """ 67 | 68 | #import tensorflow as tf 69 | #import modeling 70 | 71 | # 通过传入的训练数据,进行representation 72 | model = modeling.BertModel( 73 | config=bert_config, 74 | is_training=is_training, 75 | input_ids=input_ids, 76 | input_mask=input_mask, 77 | token_type_ids=segment_ids, 78 | ) 79 | 80 | embedding_layer = model.get_sequence_output() 81 | output_layer = model.get_pooled_output() 82 | hidden_size = output_layer.shape[-1].value 83 | 84 | # predict = CNN_Classification(embedding_chars=embedding_layer, 85 | # labels=labels, 86 | # num_tags=num_labels, 87 | # sequence_length=FLAGS.max_seq_length, 88 | # embedding_dims=embedding_layer.shape[-1].value, 89 | # vocab_size=0, 90 | # filter_sizes=[3, 4, 5], 91 | # num_filters=3, 92 | # dropout_keep_prob=FLAGS.dropout_keep_prob, 93 | # l2_reg_lambda=0.001) 94 | # loss, predictions, probabilities = predict.add_cnn_layer() 95 | 96 | output_weights = tf.get_variable( 97 | "output_weights", [num_labels, hidden_size], 98 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 99 | 100 | output_bias = tf.get_variable( 101 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 102 | 103 | with tf.variable_scope("loss"): 104 | if is_training: 105 | # I.e., 0.1 dropout 106 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 107 | 108 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 109 | logits = tf.nn.bias_add(logits, output_bias) 110 | probabilities = tf.nn.softmax(logits, axis=-1) 111 | log_probs = tf.nn.log_softmax(logits, axis=-1) 112 | 113 | if labels is not None: 114 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 115 | 116 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 117 | loss = tf.reduce_mean(per_example_loss) 118 | else: 119 | loss, per_example_loss = None, None 120 | return (loss, per_example_loss, logits, probabilities) 121 | 122 | 123 | def init_predict_var(path): 124 | num_labels = 2 125 | label2id = None 126 | id2label = None 127 | label2id_file = os.path.join(path, 'label2id.pkl') 128 | if os.path.exists(label2id_file): 129 | with open(label2id_file, 'rb') as rf: 130 | label2id = pickle.load(rf) 131 | id2label = {value: key for key, value in label2id.items()} 132 | num_labels = len(label2id.items()) 133 | print('num_labels:%d' % num_labels) 134 | else: 135 | print('Can\'t found %s' % label2id_file) 136 | return num_labels, label2id, id2label 137 | 138 | 139 | 140 | def optimize_class_model(args, logger=None): 141 | """ 142 | 加载中文分类模型 143 | :param args: 144 | :param num_labels: 145 | :param logger: 146 | :return: 147 | """ 148 | 149 | if not logger: 150 | logger = set_logger(colored('CLASSIFICATION_MODEL, Lodding...', 'cyan'), args.verbose) 151 | pass 152 | try: 153 | # 如果PB文件已经存在则,返回PB文件的路径,否则将模型转化为PB文件,并且返回存储PB文件的路径 154 | if args.model_pb_dir is None: 155 | tmp_dir = args.model_dir 156 | else: 157 | tmp_dir = args.model_pb_dir 158 | 159 | pb_file = os.path.join(tmp_dir, 'classification_model.pb') 160 | if os.path.exists(pb_file): 161 | 
print('pb_file exits', pb_file) 162 | return pb_file 163 | 164 | #增加 从label2id.pkl中读取num_labels, 这样也可以不用指定num_labels参数; 2019/4/17 165 | if not args.num_labels: 166 | num_labels, label2id, id2label = init_predict_var(tmp_dir) 167 | #--- 168 | 169 | graph = tf.Graph() 170 | with graph.as_default(): 171 | with tf.Session() as sess: 172 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids') 173 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask') 174 | 175 | bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json')) 176 | 177 | loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False, 178 | input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels) 179 | 180 | # pred_ids = tf.argmax(probabilities, axis=-1, output_type=tf.int32, name='pred_ids') 181 | # pred_ids = tf.identity(pred_ids, 'pred_ids') 182 | 183 | probabilities = tf.identity(probabilities, 'pred_prob') 184 | saver = tf.train.Saver() 185 | 186 | with tf.Session() as sess: 187 | sess.run(tf.global_variables_initializer()) 188 | latest_checkpoint = tf.train.latest_checkpoint(args.model_dir) 189 | logger.info('loading... %s ' % latest_checkpoint ) 190 | saver.restore(sess,latest_checkpoint ) 191 | logger.info('freeze...') 192 | from tensorflow.python.framework import graph_util 193 | tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob']) 194 | logger.info('predict cut finished !!!') 195 | 196 | # 存储二进制模型到文件中 197 | logger.info('write graph to a tmp file: %s' % pb_file) 198 | with tf.gfile.GFile(pb_file, 'wb') as f: 199 | f.write(tmp_g.SerializeToString()) 200 | return pb_file 201 | except Exception as e: 202 | logger.error('fail to optimize the graph! %s' % e, exc_info=True) 203 | 204 | 205 | 206 | 207 | if __name__ == '__main__': 208 | # pass 209 | 210 | """ 211 | bert_model_dir="/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12" 212 | model_dir="/mnt/sda1/transdat/bert-demo/bert/output/demo7" 213 | model_pb_dir=model_dir 214 | max_seq_len=128 215 | num_labels=2 216 | """ 217 | 218 | 219 | parser = argparse.ArgumentParser(description='Trans ckpt file to .pb file') 220 | parser.add_argument('-bert_model_dir', type=str, required=True, 221 | help='chinese google bert model path') 222 | parser.add_argument('-model_dir', type=str, required=True, 223 | help='directory of a pretrained BERT model') 224 | parser.add_argument('-model_pb_dir', type=str, default=None, 225 | help='directory of a pretrained BERT model,default = model_dir') 226 | parser.add_argument('-max_seq_len', type=int, default=128, 227 | help='maximum length of a sequence,default:128') 228 | parser.add_argument('-num_labels', type=int, default=None, 229 | help='length of all labels,default=2') 230 | parser.add_argument('-verbose', action='store_true', default=False, 231 | help='turn on tensorflow logging for debug') 232 | 233 | args = parser.parse_args() 234 | 235 | optimize_class_model(args, logger=None) 236 | 237 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TextClassifier_Transformer 2 | 个人基于谷歌开源的BERT编写的文本分类器(基于微调方式),可自由加载NLP领域知名的预训练语言模型BERT、 3 | Roberta、ALBert及其wwm版本,同时适配ERNIE1.0.
4 | The project supports two prediction modes:<br>
5 | (1) Offline real-time prediction<br>
6 | (2) Server-side real-time prediction<br>
7 | 8 | ## Changelog 9 | 2020-03-25:<br>
10 | (1) The project was renamed from 'TextClassifier_BERT' to 'TextClassifier_Transformer';<br>
11 | (2) Added two new pre-trained models: ELECTRA and ALBert.<br>
12 | **Note: when using ALBert, first replace the modeling.py in this project with the modeling.py from the ALBert project, and then run.**<br>
13 | 2020-03-04:<br>
14 | Added a tf-serving deployment option; see [This Blog](https://Vincent131499.github.io/2020/02/28/以BERT分类为例阐述模型部署关键技术) for the concrete steps 15 | 16 | ## Environment 17 | * Python3.6+<br>
18 | * Tensorflow1.10+/Tensorflow-gpu1.10+
19 |
20 | Download links for several well-known pre-trained language models (the Ernie model open-sourced by Baidu has already been converted to TF format):<br>
21 | Bert-base: link: https://pan.baidu.com/s/18h9zgvnlU5ztwaQNnzBXTg extraction code: 9r1z<br>
22 | Roberta: link: https://pan.baidu.com/s/1FBpok7U9ekYJRu1a8NSM-Q extraction code: i50r<br>
23 | Bert-wwm: link: https://pan.baidu.com/s/1lhoJCT_LkboC1_1YXk1ItQ extraction code: ejt7<br>
24 | ERNIE1.0: link: https://pan.baidu.com/s/1S6MI8rQyQ4U7dLszyb73Yw extraction code: gc6f<br>
25 | ELECTRA-Tiny: link: https://pan.baidu.com/s/11QaL7A4YSCYq4YlGyU1_vA extraction code: 27jb<br>
26 | AlBert-base: link: https://pan.baidu.com/s/1U7Zx73ngci2Oqp3SLaVOaw extraction code: uijw<br>
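After downloading and unpacking one of the checkpoints above, point `BERT_BASE_DIR` (see `train.sh`) at the unpacked folder. The exact archive contents can vary between releases; the layout below is only an illustrative sketch of what the scripts expect (`vocab.txt`, `bert_config.json` and the `bert_model.ckpt` files):

```
chinese_roberta_zh_l12/
├── bert_config.json
├── vocab.txt
├── bert_model.ckpt.index
├── bert_model.ckpt.meta
└── bert_model.ckpt.data-00000-of-00001
```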
27 |
28 | 29 | ## Project overview 30 | The project can be run in two modes:<br>
31 | Mode 1: offline real-time prediction<br>
32 | step1: data preparation<br>
33 | step2: model training<br>
34 | step3: model export<br>
35 | step4: offline real-time prediction<br>
36 | Mode 2: server-side real-time prediction 37 | step1: data preparation<br>
38 | step2: model training<br>
39 | step3: model conversion<br>
40 | step4: service deployment<br>
41 | step5: client application<br>
42 | ### Notes 43 | 1. If you only want to walk through the pipeline from model training to local offline prediction, just follow Mode 1 step by step.<br>
44 | 2. If you want to go through the whole pipeline from model training to model deployment, follow Mode 2 step by step.<br>
45 |
下面将针对以上两个模式的运行方式进行详细说明。
46 | ## Mode 1: offline real-time prediction 47 | ### Step1: data preparation 48 | To try the project out quickly, a small dataset of mobile-phone reviews is used. The data is simple and has three classes: -1 (negative), 0 (neutral) and 1 (positive). A sample of the data is shown below:<br>
49 | ![Dataset sample](https://github.com/Vincent131499/TextClassifier_BERT/raw/master/imgs/dataset_desc.jpg) 50 | PS: in this project the data has already been split into three files: train.tsv, dev.tsv and test.tsv<br>
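For reference, each `.tsv` file is read by `_read_tsv()`: the first row is a header that gets skipped, and every following row is `label<TAB>text` (column 0 is the label, column 1 is the review text). The rows below are made-up examples for illustration only, not taken from the real dataset:

```
label	text_a
1	手机很流畅,拍照效果也不错
0	外观一般,性能还可以
-1	用了两天就卡顿,非常失望
```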
51 | ### Step2: model training 52 | Training command:<br>
53 | ```Bash
54 | bash train.sh
55 | ```
56 | train.sh parameters:
57 | ```Bash
58 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 # path of the pre-trained language model
59 | export DATA_DIR=./dat # path of the dataset
60 | export TRAINED_CLASSIFIER=./output # output path for the trained model
61 | export MODEL_NAME=mobile_0_roberta_base # name of the trained model
62 | ```
63 | Details: training is done by directly fine-tuning BERT; the corresponding program file is run_classifier_serving.py. Fine-tuning BERT for classification is widely documented online,
64 | so it is not repeated here. The main work is to create a task-specific Processor, i.e. SentimentProcessor, and customize its _create_examples() and get_labels() functions, as shown below:
65 | ```Python
66 | class SetimentProcessor(DataProcessor):
67 |   def get_train_examples(self, data_dir):
68 |     """See base class."""
69 |     return self._create_examples(
70 |       self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
71 |
72 |   def get_dev_examples(self, data_dir):
73 |     """See base class."""
74 |     return self._create_examples(
75 |       self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
76 |
77 |   def get_test_examples(self, data_dir):
78 |     """See base class."""
79 |     return self._create_examples(
80 |       self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
81 |
82 |   def get_labels(self):
83 |     """See base class."""
84 |
85 |     """
86 |     if not os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
87 |       with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as fd:
88 |         pickle.dump(self.labels, fd)
89 |     """
90 |     return ["-1", "0", "1"]
91 |
92 |   def _create_examples(self, lines, set_type):
93 |     """Creates examples for the training and dev sets."""
94 |     examples = []
95 |     for (i, line) in enumerate(lines):
96 |       if i == 0:
97 |         continue
98 |       guid = "%s-%s" % (set_type, i)
99 |
100 |       # debug (by xmxoxo)
101 |       # print("read line: No.%d" % i)
102 |
103 |       text_a = tokenization.convert_to_unicode(line[1])
104 |       if set_type == "test":
105 |         label = "0"
106 |       else:
107 |         label = tokenization.convert_to_unicode(line[0])
108 |       examples.append(
109 |         InputExample(guid=guid, text_a=text_a, label=label))
110 |     return examples
111 | ```
112 |
注意,此处作出的一个特别变动之处是在conver_single_example()函数中增加了一段保存label的代码,在训练过程中在保存的模型路径下生成label2id.pkl文件,代码如下所示:
113 | ```Python
114 | #--- save label2id.pkl ---
115 | # write out label2id.pkl here, add by stephen 2019-10-12
116 | output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
117 | if not os.path.exists(output_label2id_file):
118 |     with open(output_label2id_file, 'wb') as w:
119 |         pickle.dump(label_map, w)
120 | #--- Add end ---
121 | ```
122 | ### Step3: model export
123 | Run the following command:
124 | ```Bash
125 | bash export.sh
126 | ```
127 | export.sh parameters:
128 | ```Bash
129 | # the following four values should match the ones set in train.sh
130 | export BERT_BASE_DIR=./chinese_roberta_zh_l12
131 | export DATA_DIR=./dat
132 | export TRAINED_CLASSIFIER=./output
133 | export MODEL_NAME=mobile_0_roberta_base
134 | ```
135 | A directory named with a timestamp will be created under the specified exported directory.<br>
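The timestamped directory is a standard TensorFlow SavedModel, so it can be inspected or served directly. A typical layout looks like this (illustrative, the timestamp is just an example):

```
exported/1571054350/
├── saved_model.pb
└── variables/
    ├── variables.data-00000-of-00001
    └── variables.index
```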
136 | Details: run_classifier.py is mainly designed for one-off runs. If you set do_predict to True it can indeed predict, but the input samples come from a file and the model cannot be kept in memory for serving, so some code changes are needed to achieve two goals:<br>
137 | (1) Allow the model to be loaded into memory, i.e. load once, predict many times.<br>
138 | (2) Allow prediction on samples that do not come from a file, e.g. samples read from standard input.<br>
139 | * Loading the model into memory<br>
140 | Line 859 of run_classifier.py loads the model into the estimator variable, but unfortunately a plain estimator does not natively support loading once and predicting many times. See: https://guillaumegenthial.github.io/serving-tensorflow-estimator.html. 141 | Therefore estimator.export_saved_model() is used to re-export the estimator as a tf.saved_model. 142 | The code references https://github.com/bigboNed3/bert_serving); serving_input_fn() is defined in run_classifier_serving as follows:<br>
143 | ```Python
144 | def serving_input_fn():
145 |     label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
146 |     input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
147 |     input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
148 |     segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
149 |     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
150 |         'label_ids': label_ids,
151 |         'input_ids': input_ids,
152 |         'input_mask': input_mask,
153 |         'segment_ids': segment_ids,
154 |     })()
155 |     return input_fn
156 | ```
157 | Then a do_export option is handled in run_classifier_serving:
158 | ```Python
159 | if do_export:
160 |     estimator._export_to_tpu = False
161 |     estimator.export_savedmodel(Flags.export_dir, serving_input_fn)
162 | ```
163 | ### Step4: offline real-time prediction
164 | Run test_serving.py to perform offline real-time prediction.<br>
165 | The result looks like this:<br>
166 | ![Offline serving demo](https://github.com/Vincent131499/TextClassifier_BERT/raw/master/imgs/serving_offline.jpg)<br>
167 | Details: once the model has been exported, the estimator object from line 859 is no longer needed; the model can be loaded directly from the export directory as follows:<br>
168 | ```Python
169 | predict_fn = tf.contrib.predictor.from_saved_model('/exported/1571054350')
170 | ```
171 | With the predict_fn variable above, prediction can be done directly. Below is a sample that reads a question from standard input and predicts its class:<br>
172 | ```Python
173 | while True:
174 |     question = input("> ")
175 |     predict_example = InputExample("id", question, None, '某固定伪标记')
176 |     feature = convert_single_example(100, predict_example, label_list,
177 |                                      FLAGS.max_seq_length, tokenizer)
178 |
179 |     prediction = predict_fn({
180 |         "input_ids": [feature.input_ids],
181 |         "input_mask": [feature.input_mask],
182 |         "segment_ids": [feature.segment_ids],
183 |         "label_ids": [feature.label_id],
184 |     })
185 |     probabilities = prediction["probabilities"]
186 |     label = label_list[probabilities.argmax()]
187 |     print(label)
188 | ```
189 | ## Mode 2: server-side real-time prediction
190 | First, the overall architecture of this mode:<br>
191 | ![服务端部署架构](https://github.com/Vincent131499/TextClassifier_BERT/raw/master/imgs/serving_deploy_arcticture.jpg) 192 |
架构说明:
193 | BERT model server: loads the model and serves real-time predictions; it uses the bert-base package provided by BERT-BiLSTM-CRF-NER;<br>
194 | API server: calls the real-time prediction service and exposes an API to applications; written with flask;<br>
195 | Application side: the final client; for simplicity no web page was written here, the API is called directly.<br>
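The API server talks to the BERT serving instance through the `bert-base` client. A minimal sketch of that call, mirroring `api/api_service.py` in this repo (the IP and ports are the values used in this project; adjust them to your own deployment):

```Python
from bert_base.client import BertClient

# mode='CLASS' makes the client return predicted labels and scores for a classification server
bc = BertClient(ip='192.168.9.23', port=5575, port_out=5576,
                show_server_config=False, check_version=False,
                check_length=False, mode='CLASS')

rst = bc.encode(['这个手机很好用', '用了两天就卡顿'])
# rst has the shape documented in api_service.py, e.g.
# [{'pred_label': ['1', '-1'], 'score': [0.99, 0.98]}]
print(rst[0]['pred_label'])
```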
196 | ### Step1: data preparation
197 | Same as Step1 of Mode 1.
198 | ### Step2: model training
199 | Same as Step2 of Mode 1.
200 | ### Step3: model conversion
201 | Run the following command:
202 | ```Bash
203 | bash model_convert.sh
204 | ```
205 | A model file in pb format will be generated under $TRAINED_CLASSIFIER/$EXP_NAME<br>
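Before wiring the `.pb` file into a serving process, it can be sanity-checked by loading it directly with TensorFlow 1.x. This is only a sketch based on the node names used by `freeze_graph.py` in this repo (`input_ids`, `input_mask` placeholders and the `pred_prob` output); real inputs would come from the same `convert_single_example` logic used elsewhere here:

```Python
import numpy as np
import tensorflow as tf

# path follows model_convert.sh: $TRAINED_CLASSIFIER/$EXP_NAME/classification_model.pb
pb_path = './output/mobile_0_roberta_base_epoch1/classification_model.pb'

graph_def = tf.GraphDef()
with tf.gfile.GFile(pb_path, 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    input_ids = graph.get_tensor_by_name('input_ids:0')
    input_mask = graph.get_tensor_by_name('input_mask:0')
    pred_prob = graph.get_tensor_by_name('pred_prob:0')
    with tf.Session(graph=graph) as sess:
        # feed one dummy example; 128 = max_seq_len used at conversion time
        probs = sess.run(pred_prob, feed_dict={
            input_ids: np.zeros((1, 128), dtype=np.int32),
            input_mask: np.zeros((1, 128), dtype=np.int32),
        })
        print(probs)  # class probabilities, shape (1, num_labels)
```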
206 | model_convert.sh parameters:
207 | ```Bash
208 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 # pre-trained language model used during training
209 | export TRAINED_CLASSIFIER=./output # output path of the trained model
210 | export EXP_NAME=mobile_0_roberta_base # name under which the trained model was saved
211 |
212 | python freeze_graph.py \
213 |     -bert_model_dir $BERT_BASE_DIR \
214 |     -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
215 |     -max_seq_len 128 # note: max_seq_len here must match the max_seq_length set in train.sh
216 | ```
217 | ### Step4: model deployment
218 | Run the following command:
219 | ```Bash
220 | bash bert_classify_server.sh
221 | ```
222 | Tip: before running this command, make sure the bert-base package is installed:
223 | ```Bash
224 | pip install bert-base
225 | ```
226 | **Note**:<br>
227 | The port and port_out parameters are the ports used by API calls; the defaults are 5555 and 5556. If you plan to deploy several serving instances, make sure each one gets its own ports to avoid conflicts. 228 | Here they were changed to 5575 and 5576.<br>
229 | If the server fails to start, some modules referenced by bert_base/server/http.py may be missing; installing them fixes it:
230 | ```
231 | sudo pip install flask
232 | sudo pip install flask_compress
233 | sudo pip install flask_cors
234 | sudo pip install flask_json
235 | ```
236 | My setup has two GTX 1080 Ti cards, which is where dual GPUs finally pay off: GPU 1 serves predictions while GPU 0 can keep training models.<br>
237 | A successful deployment looks like this:<br>
238 | ![Deployment success](https://github.com/Vincent131499/TextClassifier_BERT/raw/master/imgs/deploy-success.jpg)
239 | ### Step5: application side
240 | Run the following command:
241 | ```Bash
242 | python api/api_service_flask.py
243 | ```
244 | The deployed server can then be reached through the API endpoint (in this project: http://192.168.9.23:8910/predict_online?text=我好开心).<br>
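Besides the browser, the endpoint can be called from any HTTP client, for example with `requests` (host, port and response shape follow `api/api_service_flask.py`):

```Python
import requests

resp = requests.get('http://192.168.9.23:8910/predict_online',
                    params={'text': '我好开心'})
print(resp.json())
# e.g. {"result": [["好评", "我好开心"]]}
```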
245 | Requesting it from a browser:<br>
246 | ![浏览器请求](https://github.com/Vincent131499/TextClassifier_BERT/raw/master/imgs/api_example.jpg) 247 | 248 | ## TODO: 249 | - [ ] fp16 support 250 | - [ ] add LAMB optimizer 251 | - [x] train data shuffle 252 | - [x] do_froze -------------------------------------------------------------------------------- /run_savedModel_infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import tokenization 5 | import tensorflow as tf 6 | import numpy as np 7 | import time 8 | 9 | parser = argparse.ArgumentParser( 10 | description='BERT model saved model case/batch test program, exit with q') 11 | 12 | parser.add_argument('--model', type=str, 13 | default='/Users/huangdongxiao2/CodeRepos/SesameSt/albert_zh/inference/robert_tiny_clue', help='the path for the model') 14 | parser.add_argument('--vocab_file', type=str, 15 | default='/Users/huangdongxiao2/CodeRepos/SesameSt/albert_zh/inference/robert_tiny_clue/vocab.txt') 16 | # default='../ALBERT/albert_base/vocab_chinese.txt') 17 | parser.add_argument('--labels', type=list, default=[ 18 | 'happy', 'anger', 'lost', 'fear', 'sad', 'other', 'anxiety'], help='label list') 19 | parser.add_argument('--max_seq_length', type=int, default=128, 20 | help='the length of sequence for text padding') 21 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids', 22 | help='the input_ids feature name for saved model') 23 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask', 24 | help='the input_mask feature name for saved model') 25 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids', 26 | help='the segment_ids feature name for saved model') 27 | parser.add_argument('--MODE', type=str, default='SINGLE', 28 | help='SINGLE prediction or BATCH prediction') 29 | args_in_use = parser.parse_args() 30 | 31 | 32 | class InputExample(object): 33 | """A single training/test example for simple sequence classification.""" 34 | 35 | def __init__(self, guid, text_a, text_b=None, label=None): 36 | """Constructs a InputExample. 37 | Args: 38 | guid: Unique id for the example. 39 | text_a: string. The untokenized text of the first sequence. For single 40 | sequence tasks, only this sequence must be specified. 41 | text_b: (Optional) string. The untokenized text of the second sequence. 42 | Only must be specified for sequence pair tasks. 43 | label: (Optional) string. The label of the example. This should be 44 | specified for train and dev examples, but not for test examples. 45 | """ 46 | self.guid = guid 47 | self.text_a = text_a 48 | self.text_b = text_b 49 | 50 | 51 | class InputFeatures(object): 52 | """A single set of features of data.""" 53 | 54 | def __init__(self, 55 | input_ids, 56 | input_mask, 57 | segment_ids, 58 | is_real_example=True): 59 | self.input_ids = input_ids 60 | self.input_mask = input_mask 61 | self.segment_ids = segment_ids 62 | self.is_real_example = is_real_example 63 | 64 | 65 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 66 | """Truncates a sequence pair in place to the maximum length.""" 67 | 68 | # This is a simple heuristic which will always truncate the longer sequence 69 | # one token at a time. This makes more sense than truncating an equal percent 70 | # of tokens from each, since if one sequence is very short then each token 71 | # that's truncated likely contains more information than a longer sequence. 
72 | while True: 73 | total_length = len(tokens_a) + len(tokens_b) 74 | if total_length <= max_length: 75 | break 76 | if len(tokens_a) > len(tokens_b): 77 | tokens_a.pop() 78 | else: 79 | tokens_b.pop() 80 | 81 | 82 | def convert_single_example(example, label_list, max_seq_length, 83 | tokenizer): 84 | """Converts a single `InputExample` into a single `InputFeatures`.""" 85 | 86 | label_map = {} 87 | for (i, label) in enumerate(label_list): 88 | label_map[label] = i 89 | 90 | tokens_a = tokenizer.tokenize(example.text_a) 91 | tokens_b = None 92 | if example.text_b: 93 | tokens_b = tokenizer.tokenize(example.text_b) 94 | 95 | if tokens_b: 96 | # Modifies `tokens_a` and `tokens_b` in place so that the total 97 | # length is less than the specified length. 98 | # Account for [CLS], [SEP], [SEP] with "- 3" 99 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 100 | else: 101 | # Account for [CLS] and [SEP] with "- 2" 102 | if len(tokens_a) > max_seq_length - 2: 103 | tokens_a = tokens_a[0:(max_seq_length - 2)] 104 | 105 | # The convention in BERT is: 106 | # (a) For sequence pairs: 107 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 108 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 109 | # (b) For single sequences: 110 | # tokens: [CLS] the dog is hairy . [SEP] 111 | # type_ids: 0 0 0 0 0 0 0 112 | # 113 | # Where "type_ids" are used to indicate whether this is the first 114 | # sequence or the second sequence. The embedding vectors for `type=0` and 115 | # `type=1` were learned during pre-training and are added to the wordpiece 116 | # embedding vector (and position vector). This is not *strictly* necessary 117 | # since the [SEP] token unambiguously separates the sequences, but it makes 118 | # it easier for the model to learn the concept of sequences. 119 | # 120 | # For classification tasks, the first vector (corresponding to [CLS]) is 121 | # used as the "sentence vector". Note that this only makes sense because 122 | # the entire model is fine-tuned. 123 | tokens = [] 124 | segment_ids = [] 125 | tokens.append("[CLS]") 126 | segment_ids.append(0) 127 | for token in tokens_a: 128 | tokens.append(token) 129 | segment_ids.append(0) 130 | tokens.append("[SEP]") 131 | segment_ids.append(0) 132 | 133 | if tokens_b: 134 | for token in tokens_b: 135 | tokens.append(token) 136 | segment_ids.append(1) 137 | tokens.append("[SEP]") 138 | segment_ids.append(1) 139 | 140 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 141 | 142 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 143 | # tokens are attended to. 144 | input_mask = [1] * len(input_ids) 145 | 146 | # Zero-pad up to the sequence length. 
147 | while len(input_ids) < max_seq_length: 148 | input_ids.append(0) 149 | input_mask.append(0) 150 | segment_ids.append(0) 151 | 152 | assert len(input_ids) == max_seq_length 153 | assert len(input_mask) == max_seq_length 154 | assert len(segment_ids) == max_seq_length 155 | 156 | feature = InputFeatures( 157 | input_ids=input_ids, 158 | input_mask=input_mask, 159 | segment_ids=segment_ids, 160 | is_real_example=True) 161 | return feature 162 | 163 | 164 | if __name__ == '__main__': 165 | predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model) 166 | label_list = args_in_use.labels 167 | label_map = {i: label for i, label in enumerate(label_list)} 168 | max_seq_length = args_in_use.max_seq_length 169 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file, 170 | do_lower_case=True) 171 | if args_in_use.MODE == 'SINGLE': 172 | while True: 173 | question = input('(PRESS q to quit)\n> ') 174 | if question == 'q': 175 | break 176 | predict_example = InputExample('id', question, None) 177 | feature = convert_single_example( 178 | predict_example, label_list, max_seq_length, tokenizer) 179 | start_time = time.time() 180 | prediction = predict_fn({ 181 | args_in_use.tensor_input_ids: [feature.input_ids], 182 | args_in_use.tensor_input_mask: [feature.input_mask], 183 | args_in_use.tensor_segment_ids: [feature.segment_ids], 184 | }) 185 | print(f'elapsed time: {time.time()-start_time}s') 186 | print(prediction) 187 | probabilities = prediction["probabilities"] 188 | max_index = np.argmax(probabilities[0]) 189 | print(probabilities, '\n') 190 | print(f'label: {label_map[max_index]}') 191 | 192 | elif args_in_use.MODE == 'BATCH': 193 | questions = [ 194 | '我要投诉的', 195 | '我很不开心', 196 | '我好喜欢你' 197 | ] 198 | features = [] 199 | for i, question in enumerate(questions): 200 | predict_example = InputExample(f'id{i}', question, None) 201 | feature = convert_single_example( 202 | predict_example, label_list, max_seq_length, tokenizer) 203 | features.append(feature) 204 | feed_dict = { 205 | args_in_use.tensor_input_ids: [feature.input_ids for feature in features], 206 | args_in_use.tensor_input_mask: [feature.input_mask for feature in features], 207 | args_in_use.tensor_segment_ids: [feature.segment_ids for feature in features], 208 | } 209 | predictions = predict_fn(feed_dict) 210 | probabilities = predictions["q"] 211 | print(probabilities, '\n') 212 | max_idxs = np.argmax(probabilities, 1) 213 | print( 214 | f'labels: {[label_map[max_index] for max_index in max_idxs]}') 215 | else: 216 | raise ValueError('unsupported mode') 217 | -------------------------------------------------------------------------------- /test_tf_serving.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Description : 使用tf-serving部署,使用requests调用接口 4 | Author : MeteorMan 5 | date: 2020/3/4 6 | """ 7 | 8 | import json 9 | import numpy as np 10 | import requests 11 | import pickle 12 | import os 13 | import tensorflow as tf 14 | import tokenization 15 | 16 | 17 | class InputExample(object): 18 | """A single training/test example for simple sequence classification.""" 19 | 20 | def __init__(self, guid, text_a, text_b=None, label=None): 21 | """Constructs a InputExample. 22 | 23 | Args: 24 | guid: Unique id for the example. 25 | text_a: string. The untokenized text of the first sequence. For single 26 | sequence tasks, only this sequence must be specified. 27 | text_b: (Optional) string. 
The untokenized text of the second sequence. 28 | Only must be specified for sequence pair tasks. 29 | label: (Optional) string. The label of the example. This should be 30 | specified for train and dev examples, but not for test examples. 31 | """ 32 | self.guid = guid 33 | self.text_a = text_a 34 | self.text_b = text_b 35 | self.label = label 36 | 37 | class PaddingInputExample(object): 38 | """Fake example so the num input examples is a multiple of the batch size. 39 | 40 | When running eval/predict on the TPU, we need to pad the number of examples 41 | to be a multiple of the batch size, because the TPU requires a fixed batch 42 | size. The alternative is to drop the last batch, which is bad because it means 43 | the entire output data won't be generated. 44 | 45 | We use this class instead of `None` because treating `None` as padding 46 | battches could cause silent errors. 47 | """ 48 | 49 | 50 | class InputFeatures(object): 51 | """A single set of features of data.""" 52 | 53 | def __init__(self, 54 | input_ids, 55 | input_mask, 56 | segment_ids, 57 | label_id, 58 | is_real_example=True): 59 | self.input_ids = input_ids 60 | self.input_mask = input_mask 61 | self.segment_ids = segment_ids 62 | self.label_id = label_id 63 | self.is_real_example = is_real_example 64 | 65 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 66 | """Truncates a sequence pair in place to the maximum length.""" 67 | 68 | # This is a simple heuristic which will always truncate the longer sequence 69 | # one token at a time. This makes more sense than truncating an equal percent 70 | # of tokens from each, since if one sequence is very short then each token 71 | # that's truncated likely contains more information than a longer sequence. 72 | while True: 73 | total_length = len(tokens_a) + len(tokens_b) 74 | if total_length <= max_length: 75 | break 76 | if len(tokens_a) > len(tokens_b): 77 | tokens_a.pop() 78 | else: 79 | tokens_b.pop() 80 | 81 | 82 | def convert_single_example(ex_index, example, label_list, max_seq_length, 83 | tokenizer): 84 | """Converts a single `InputExample` into a single `InputFeatures`.""" 85 | 86 | if isinstance(example, PaddingInputExample): 87 | return InputFeatures( 88 | input_ids=[0] * max_seq_length, 89 | input_mask=[0] * max_seq_length, 90 | segment_ids=[0] * max_seq_length, 91 | label_id=0, 92 | is_real_example=False) 93 | 94 | label_map = {} 95 | for (i, label) in enumerate(label_list): 96 | label_map[label] = i 97 | 98 | tokens_a = tokenizer.tokenize(example.text_a) 99 | tokens_b = None 100 | if example.text_b: 101 | tokens_b = tokenizer.tokenize(example.text_b) 102 | 103 | if tokens_b: 104 | # Modifies `tokens_a` and `tokens_b` in place so that the total 105 | # length is less than the specified length. 106 | # Account for [CLS], [SEP], [SEP] with "- 3" 107 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 108 | else: 109 | # Account for [CLS] and [SEP] with "- 2" 110 | if len(tokens_a) > max_seq_length - 2: 111 | tokens_a = tokens_a[0:(max_seq_length - 2)] 112 | 113 | # The convention in BERT is: 114 | # (a) For sequence pairs: 115 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 116 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 117 | # (b) For single sequences: 118 | # tokens: [CLS] the dog is hairy . [SEP] 119 | # type_ids: 0 0 0 0 0 0 0 120 | # 121 | # Where "type_ids" are used to indicate whether this is the first 122 | # sequence or the second sequence. 
The embedding vectors for `type=0` and 123 | # `type=1` were learned during pre-training and are added to the wordpiece 124 | # embedding vector (and position vector). This is not *strictly* necessary 125 | # since the [SEP] token unambiguously separates the sequences, but it makes 126 | # it easier for the model to learn the concept of sequences. 127 | # 128 | # For classification tasks, the first vector (corresponding to [CLS]) is 129 | # used as the "sentence vector". Note that this only makes sense because 130 | # the entire model is fine-tuned. 131 | tokens = [] 132 | segment_ids = [] 133 | tokens.append("[CLS]") 134 | segment_ids.append(0) 135 | for token in tokens_a: 136 | tokens.append(token) 137 | segment_ids.append(0) 138 | tokens.append("[SEP]") 139 | segment_ids.append(0) 140 | 141 | if tokens_b: 142 | for token in tokens_b: 143 | tokens.append(token) 144 | segment_ids.append(1) 145 | tokens.append("[SEP]") 146 | segment_ids.append(1) 147 | 148 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 149 | 150 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 151 | # tokens are attended to. 152 | input_mask = [1] * len(input_ids) 153 | 154 | # Zero-pad up to the sequence length. 155 | while len(input_ids) < max_seq_length: 156 | input_ids.append(0) 157 | input_mask.append(0) 158 | segment_ids.append(0) 159 | 160 | assert len(input_ids) == max_seq_length 161 | assert len(input_mask) == max_seq_length 162 | assert len(segment_ids) == max_seq_length 163 | 164 | label_id = label_map[example.label] 165 | if ex_index < 5: 166 | tf.logging.info("*** Example ***") 167 | tf.logging.info("guid: %s" % (example.guid)) 168 | tf.logging.info("tokens: %s" % " ".join( 169 | [tokenization.printable_text(x) for x in tokens])) 170 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 171 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 172 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 173 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 174 | 175 | feature = InputFeatures( 176 | input_ids=input_ids, 177 | input_mask=input_mask, 178 | segment_ids=segment_ids, 179 | label_id=label_id, 180 | is_real_example=True) 181 | return feature 182 | 183 | label_dict = pickle.load(open('./label2id.pkl', 'rb')) 184 | label_list = list(label_dict.keys()) 185 | # id2label = get_suit_dict() 186 | max_seq_length = 64 187 | tokenizer = tokenization.FullTokenizer(vocab_file='./vocab.txt', 188 | do_lower_case=True) 189 | 190 | def predict_offline_single(): 191 | while True: 192 | import time 193 | question = input("> ") 194 | start_time = time.time() 195 | predict_example = InputExample("100", question, None, '婚姻家庭') 196 | feature = convert_single_example(100, predict_example, label_list, max_seq_length, tokenizer) 197 | data = json.dumps({ 198 | "instances": [ 199 | { 200 | "input_ids": feature.input_ids, 201 | "input_mask": feature.input_mask, 202 | "segment_ids": feature.segment_ids, 203 | "label_ids": [feature.label_id] 204 | } 205 | ] 206 | }) 207 | headers = {"content-type": "application/json"} 208 | # json_response = requests.post( 209 | # 'http://127.0.0.1:9001/v1/models/bert_domain_classify:predict', 210 | # data=data, headers=headers) 211 | json_response = requests.post( 212 | 'http://192.168.0.105:9001/v1/models/bert_classify:predict', 213 | data=data, headers=headers) 214 | end_time = time.time() 215 | print('Spend time {}sec'.format(end_time-start_time)) 216 | predictions = 
np.array(json.loads(json_response.text)['predictions']) 217 | # print(np.argmax(predictions, axis=-1)) 218 | label = label_list[predictions.argmax()] 219 | print(label) 220 | 221 | 222 | def predict_offline_batch(): 223 | while True: 224 | import time 225 | question = input("> ") 226 | start_time = time.time() 227 | sents_list = question.split('|&|') 228 | instances_dict = {} 229 | instance_list = [] 230 | for sent in sents_list: 231 | predict_example = InputExample("100", sent, None, '婚姻家庭') 232 | feature = convert_single_example(100, predict_example, label_list, max_seq_length, tokenizer) 233 | instance_list.append( 234 | { 235 | "input_ids": feature.input_ids, 236 | "input_mask": feature.input_mask, 237 | "segment_ids": feature.segment_ids, 238 | "label_ids": [feature.label_id] 239 | } 240 | ) 241 | instances_dict['instances'] = instance_list 242 | 243 | data = json.dumps({ 244 | "instances": instances_dict['instances'] 245 | }) 246 | headers = {"content-type": "application/json"} 247 | # json_response = requests.post( 248 | # 'http://127.0.0.1:9001/v1/models/bert_domain_classify:predict', 249 | # data=data, headers=headers) 250 | json_response = requests.post( 251 | 'http://192.168.0.105:9001/v1/models/bert_classify:predict', 252 | data=data, headers=headers) 253 | end_time = time.time() 254 | print('Spend time {}sec'.format(end_time-start_time)) 255 | predictions = np.array(json.loads(json_response.text)['predictions']) 256 | print(predictions) 257 | # print(np.argmax(predictions, axis=-1)) 258 | for pred in predictions: 259 | 260 | label = label_list[pred.argmax()] 261 | print(label) 262 | 263 | predict_offline_single() 264 | # predict_offline_batch() 265 | -------------------------------------------------------------------------------- /run_pb_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import tokenization 5 | import tensorflow as tf 6 | import numpy as np 7 | import time 8 | parser = argparse.ArgumentParser( 9 | description='BERT model pb model case/batch test program, exit with q') 10 | 11 | parser.add_argument('--model', type=str, 12 | default='./inference/robert_tiny_clue/frozen_model.pb', help='the path for the model') 13 | #TODO: CHECK vocab file 14 | parser.add_argument('--vocab_file', type=str, 15 | # default='../ALBERT/albert_base/vocab_chinese.txt') 16 | default='/Users/huangdongxiao2/CodeRepos/SesameSt/albert_zh/inference/robert_tiny_clue/vocab.txt') 17 | parser.add_argument('--labels', type=list, default=[ 18 | 'happy', 'anger', 'lost', 'fear', 'sad', 'other', 'anxiety'], help='label list') 19 | parser.add_argument('--max_seq_length', type=int, default=128, 20 | help='the length of sequence for text padding') 21 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids:0', 22 | help='the input_ids op_name for graph, format: :') 23 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask:0', 24 | help='the input_mask op_name for graph, format: :') 25 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids:0', 26 | help='the segment_ids op_name for graph, format: :') 27 | parser.add_argument('--tensor_output', type=str, default='loss/pred_prob:0', 28 | help='the output op_name for graph, format: :') 29 | parser.add_argument('--MODE', type=str, default='SINGLE', 30 | help='SINGLE prediction or BATCH prediction') 31 | args_in_use = parser.parse_args() 32 | """ 33 | gpu settting 34 | """ 35 | 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # default: 0 36 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 37 | """ 38 | load pb model and predict 39 | """ 40 | 41 | 42 | class InputExample(object): 43 | """A single training/test example for simple sequence classification.""" 44 | 45 | def __init__(self, guid, text_a, text_b=None): 46 | """Constructs a InputExample. 47 | Args: 48 | guid: Unique id for the example. 49 | text_a: string. The untokenized text of the first sequence. For single 50 | sequence tasks, only this sequence must be specified. 51 | text_b: (Optional) string. The untokenized text of the second sequence. 52 | Only must be specified for sequence pair tasks. 53 | label: (Optional) string. The label of the example. This should be 54 | specified for train and dev examples, but not for test examples. 55 | """ 56 | self.guid = guid 57 | self.text_a = text_a 58 | self.text_b = text_b 59 | 60 | 61 | class InputFeatures(object): 62 | """A single set of features of data.""" 63 | 64 | def __init__(self, 65 | input_ids, 66 | input_mask, 67 | segment_ids, 68 | is_real_example=True): 69 | self.input_ids = input_ids 70 | self.input_mask = input_mask 71 | self.segment_ids = segment_ids 72 | self.is_real_example = is_real_example 73 | 74 | 75 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 76 | """Truncates a sequence pair in place to the maximum length.""" 77 | 78 | # This is a simple heuristic which will always truncate the longer sequence 79 | # one token at a time. This makes more sense than truncating an equal percent 80 | # of tokens from each, since if one sequence is very short then each token 81 | # that's truncated likely contains more information than a longer sequence. 82 | while True: 83 | total_length = len(tokens_a) + len(tokens_b) 84 | if total_length <= max_length: 85 | break 86 | if len(tokens_a) > len(tokens_b): 87 | tokens_a.pop() 88 | else: 89 | tokens_b.pop() 90 | 91 | 92 | def convert_single_example(example, label_list, max_seq_length, 93 | tokenizer): 94 | """Converts a single `InputExample` into a single `InputFeatures`.""" 95 | 96 | label_map = {} 97 | for (i, label) in enumerate(label_list): 98 | label_map[label] = i 99 | 100 | tokens_a = tokenizer.tokenize(example.text_a) 101 | tokens_b = None 102 | if example.text_b: 103 | tokens_b = tokenizer.tokenize(example.text_b) 104 | 105 | if tokens_b: 106 | # Modifies `tokens_a` and `tokens_b` in place so that the total 107 | # length is less than the specified length. 108 | # Account for [CLS], [SEP], [SEP] with "- 3" 109 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 110 | else: 111 | # Account for [CLS] and [SEP] with "- 2" 112 | if len(tokens_a) > max_seq_length - 2: 113 | tokens_a = tokens_a[0:(max_seq_length - 2)] 114 | 115 | # The convention in BERT is: 116 | # (a) For sequence pairs: 117 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 118 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 119 | # (b) For single sequences: 120 | # tokens: [CLS] the dog is hairy . [SEP] 121 | # type_ids: 0 0 0 0 0 0 0 122 | # 123 | # Where "type_ids" are used to indicate whether this is the first 124 | # sequence or the second sequence. The embedding vectors for `type=0` and 125 | # `type=1` were learned during pre-training and are added to the wordpiece 126 | # embedding vector (and position vector). This is not *strictly* necessary 127 | # since the [SEP] token unambiguously separates the sequences, but it makes 128 | # it easier for the model to learn the concept of sequences. 
129 | # 130 | # For classification tasks, the first vector (corresponding to [CLS]) is 131 | # used as the "sentence vector". Note that this only makes sense because 132 | # the entire model is fine-tuned. 133 | tokens = [] 134 | segment_ids = [] 135 | tokens.append("[CLS]") 136 | segment_ids.append(0) 137 | for token in tokens_a: 138 | tokens.append(token) 139 | segment_ids.append(0) 140 | tokens.append("[SEP]") 141 | segment_ids.append(0) 142 | 143 | if tokens_b: 144 | for token in tokens_b: 145 | tokens.append(token) 146 | segment_ids.append(1) 147 | tokens.append("[SEP]") 148 | segment_ids.append(1) 149 | 150 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 151 | 152 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 153 | # tokens are attended to. 154 | input_mask = [1] * len(input_ids) 155 | 156 | # Zero-pad up to the sequence length. 157 | while len(input_ids) < max_seq_length: 158 | input_ids.append(0) 159 | input_mask.append(0) 160 | segment_ids.append(0) 161 | 162 | assert len(input_ids) == max_seq_length 163 | assert len(input_mask) == max_seq_length 164 | assert len(segment_ids) == max_seq_length 165 | 166 | feature = InputFeatures( 167 | input_ids=input_ids, 168 | input_mask=input_mask, 169 | segment_ids=segment_ids, 170 | is_real_example=True) 171 | return feature 172 | 173 | 174 | with tf.Graph().as_default(): 175 | output_graph_def = tf.GraphDef() 176 | label_list = args_in_use.labels 177 | label_map = {i: label for i, label in enumerate(label_list)} 178 | 179 | max_seq_length = args_in_use.max_seq_length 180 | """ 181 | load pb model 182 | """ 183 | with open(args_in_use.model, 'rb') as f: 184 | output_graph_def.ParseFromString(f.read()) 185 | tf.import_graph_def(output_graph_def, name='') 186 | """ 187 | enter a text and predict 188 | """ 189 | with tf.Session() as sess: 190 | tf.global_variables_initializer().run() 191 | input_ids = sess.graph.get_tensor_by_name( 192 | args_in_use.tensor_input_ids) 193 | input_mask = sess.graph.get_tensor_by_name( 194 | args_in_use.tensor_input_mask) 195 | segment_ids = sess.graph.get_tensor_by_name( 196 | args_in_use.tensor_segment_ids) 197 | tokenizer = tokenization.FullTokenizer( 198 | vocab_file=args_in_use.vocab_file, do_lower_case=True) 199 | output = args_in_use.tensor_output 200 | if args_in_use.MODE == 'SINGLE': 201 | while 1: 202 | question = input("enter a sentence:") 203 | if question == 'q' or question == 'quit()': 204 | break 205 | predict_example = InputExample('id', question, None) 206 | feature = convert_single_example( 207 | predict_example, label_list, max_seq_length, tokenizer) 208 | # print(feature.input_ids) 209 | # print(feature.input_mask) 210 | # print(feature.segment_ids) 211 | feed_dict = { 212 | input_ids: [feature.input_ids], 213 | input_mask: [feature.input_mask], 214 | segment_ids: [feature.segment_ids], 215 | } 216 | # works fine 217 | # feed_dict = { 218 | # args_in_use.tensor_input_ids: [feature.input_ids], 219 | # args_in_use.tensor_input_mask: [feature.input_mask], 220 | # args_in_use.tensor_segment_ids: [feature.segment_ids], 221 | # } 222 | start_time = time.time() 223 | y_pred_cls = sess.run(output, feed_dict=feed_dict) 224 | print(f'elapsed time: {time.time()-start_time}s') 225 | max_index = np.argmax(y_pred_cls[0]) 226 | print(" current results ", y_pred_cls) 227 | print(f'label: {label_map[max_index]}') 228 | elif args_in_use.MODE == 'BATCH': 229 | questions = [ 230 | '我要投诉的', 231 | '我很不开心', 232 | '我好喜欢你' 233 | ] 234 | features = [] 235 | for i, question in 
enumerate(questions): 236 | predict_example = InputExample(f'id{i}', question, None) 237 | feature = convert_single_example( 238 | predict_example, label_list, max_seq_length, tokenizer) 239 | features.append(feature) 240 | feed_dict = { 241 | input_ids: [feature.input_ids for feature in features], 242 | input_mask: [feature.input_mask for feature in features], 243 | segment_ids: [feature.segment_ids for feature in features], 244 | } 245 | y_pred_cls = sess.run(output, feed_dict=feed_dict) 246 | max_idxs = np.argmax(y_pred_cls, 1) 247 | print(y_pred_cls) 248 | print( 249 | f'labels: {[label_map[max_index] for max_index in max_idxs]}') 250 | else: 251 | raise ValueError('unsupported mode') 252 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." 
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower 
casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 
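    # A few illustrative codepoints (not an exhaustive check):
    #   ord('中') == 0x4E2D -> inside 0x4E00-0x9FFF, returns True
    #   ord('A')  == 0x41   -> outside every range, returns False
    #   ord('あ') == 0x3042 -> Hiragana, outside the CJK blocks, returns False,
    #                          so kana is left to the normal whitespace/WordPiece path.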
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 
390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /cmrc/cmrc_tool/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." 
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | class FullTokenizer(object): # what bert uses 161 | """Runs end-to-end tokenziation.""" 162 | 163 | def __init__(self, vocab_file, do_lower_case=True): 164 | self.vocab = load_vocab(vocab_file) 165 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 166 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 167 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 168 | 169 | def tokenize(self, text): 170 | split_tokens = [] 171 | for token in self.basic_tokenizer.tokenize(text): 172 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 173 | split_tokens.append(sub_token) 174 | 175 | return split_tokens 176 | 177 | def convert_tokens_to_ids(self, tokens): 178 | return convert_by_vocab(self.vocab, tokens) 179 | 180 | def convert_ids_to_tokens(self, ids): 181 | return convert_by_vocab(self.inv_vocab, ids) 182 | 183 | class BasicTokenizer(object):# tokenize by space 184 | """Runs basic tokenization 
(punctuation splitting, lower casing, etc.).""" 185 | 186 | def __init__(self, do_lower_case=True): 187 | """Constructs a BasicTokenizer. 188 | 189 | Args: 190 | do_lower_case: Whether to lower case the input. 191 | """ 192 | self.do_lower_case = do_lower_case 193 | 194 | def tokenize(self, text): 195 | """Tokenizes a piece of text.""" 196 | text = convert_to_unicode(text) 197 | text = self._clean_text(text) 198 | 199 | # This was added on November 1st, 2018 for the multilingual and Chinese 200 | # models. This is also applied to the English models now, but it doesn't 201 | # matter since the English models were not trained on any Chinese data 202 | # and generally don't have any Chinese data in them (there are Chinese 203 | # characters in the vocabulary because Wikipedia does have some Chinese 204 | # words in the English Wikipedia.). 205 | text = self._tokenize_chinese_chars(text) # 也就是在中文字符的前后加上空格,这样后续的分词流程会把每一个字符当成一个词。 206 | 207 | orig_tokens = whitespace_tokenize(text) 208 | split_tokens = [] 209 | for token in orig_tokens: 210 | if self.do_lower_case: 211 | token = token.lower() 212 | token = self._run_strip_accents(token) 213 | split_tokens.extend(self._run_split_on_punc(token)) 214 | 215 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 216 | return output_tokens 217 | 218 | def _run_strip_accents(self, text): 219 | """ 220 | Strips accents from a piece of text. 221 | 222 | >>> s1 = 'café' 223 | >>> s2 = 'cafe\u0301' 224 | >>> s1, s2 225 | ('café', 'café') 226 | >>> len(s1), len(s2) 227 | (4, 5) 228 | >>> s1 == s2 229 | False 230 | 我们”看到”的é其实可以有两种表示方法,一是用一个codepoint直接表示”é”,另外一种是用”e”再加上特殊的codepoint U+0301两个字符来表示。U+0301是COMBINING ACUTE ACCENT,它跟在e之后就变成了”é”。类似的”a\u0301”显示出来就是”á”。注意:这只是打印出来一模一样而已,但是在计算机内部的表示它们完全不同的,前者é是一个codepoint,值为0xe9,而后者是两个codepoint,分别是0x65和0x301。unicodedata.normalize(“NFD”, text)就会把0xe9变成0x65和0x301,比如下面的测试代码。 231 | 232 | 233 | """ 234 | text = unicodedata.normalize("NFD", text) 235 | output = [] 236 | for char in text: 237 | cat = unicodedata.category(char) 238 | if cat == "Mn": # ACCENT 239 | continue 240 | output.append(char) 241 | return "".join(output) 242 | 243 | def _run_split_on_punc(self, text): 244 | """Splits punctuation on a piece of text.""" 245 | chars = list(text) 246 | i = 0 247 | start_new_word = True 248 | output = [] 249 | """这个函数会对输入字符串用标点进行切分,返回一个list,list的每一个元素都是一个char。比如输入he’s,则输出是[[h,e], [’],[s]]。""" 250 | while i < len(chars): 251 | char = chars[i] 252 | if _is_punctuation(char): 253 | output.append([char]) 254 | start_new_word = True 255 | else: 256 | if start_new_word: 257 | output.append([]) 258 | start_new_word = False 259 | output[-1].append(char) 260 | i += 1 261 | 262 | return ["".join(x) for x in output] 263 | 264 | def _tokenize_chinese_chars(self, text): 265 | """Adds whitespace around any CJK character.""" 266 | output = [] 267 | for char in text: 268 | cp = ord(char) 269 | if self._is_chinese_char(cp): 270 | output.append(" ") 271 | output.append(char) 272 | output.append(" ") 273 | else: 274 | output.append(char) 275 | return "".join(output) 276 | 277 | def _is_chinese_char(self, cp): 278 | """Checks whether CP is the codepoint of a CJK character.""" 279 | # This defines a "chinese character" as anything in the CJK Unicode block: 280 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 281 | # 282 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 283 | # despite its name. 
The modern Korean Hangul alphabet is a different block, 284 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 285 | # space-separated words, so they are not treated specially and handled 286 | # like the all of the other languages. 287 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 288 | (cp >= 0x3400 and cp <= 0x4DBF) or # 289 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 290 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 291 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 292 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 293 | (cp >= 0xF900 and cp <= 0xFAFF) or # 294 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 295 | return True 296 | 297 | return False 298 | 299 | def _clean_text(self, text): 300 | """Performs invalid character removal and whitespace cleanup on text.""" 301 | output = [] 302 | for char in text: 303 | cp = ord(char) 304 | # codepoint == 0 means: nonsense character; codepoint == 0xfffd is shown as � which is used for unrecognized character 305 | if cp == 0 or cp == 0xfffd or _is_control(char): 306 | continue 307 | if _is_whitespace(char): 308 | output.append(" ") 309 | else: 310 | output.append(char) 311 | return "".join(output) 312 | 313 | 314 | class WordpieceTokenizer(object): # tokenize by space and then do it tokenize in a more detailede way to WordPiece level 315 | """Runs WordPiece tokenziation.""" 316 | 317 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 318 | self.vocab = vocab 319 | self.unk_token = unk_token 320 | self.max_input_chars_per_word = max_input_chars_per_word 321 | 322 | def tokenize(self, text): 323 | """Tokenizes a piece of text into its word pieces. 324 | 325 | This uses a greedy longest-match-first algorithm to perform tokenization 326 | using the given vocabulary. 327 | 328 | For example: 329 | input = "unaffable" 330 | 我们用一个例子来看代码的执行过程。比如假设输入是”unaffable”。我们跳到while循环部分,这是start=0,end=len(chars)=9,也就是先看看unaffable在不在词典里,如果在,那么直接作为一个WordPiece,如果不再,那么end-=1,也就是看unaffabl在不在词典里,最终发现”un”在词典里,把un加到结果里。 331 | 332 | 接着start=2,看affable在不在,不在再看affabl,…,最后发现 ##aff 在词典里。注意:##表示这个词是接着前面的,这样使得WordPiece切分是可逆的——我们可以恢复出“真正”的词。 333 | output = ["un", "##aff", "##able"] 334 | 335 | Args: 336 | text: A single token or whitespace separated tokens. This should have 337 | already been passed through `BasicTokenizer. 338 | 339 | Returns: 340 | A list of wordpiece tokens. 341 | """ 342 | 343 | text = convert_to_unicode(text) 344 | 345 | output_tokens = [] 346 | for token in whitespace_tokenize(text): 347 | chars = list(token) 348 | if len(chars) > self.max_input_chars_per_word: 349 | output_tokens.append(self.unk_token) 350 | continue 351 | 352 | is_bad = False 353 | start = 0 354 | sub_tokens = [] 355 | while start < len(chars): 356 | end = len(chars) 357 | cur_substr = None 358 | while start < end: 359 | substr = "".join(chars[start:end]) 360 | if start > 0: 361 | substr = "##" + substr 362 | if substr in self.vocab: 363 | cur_substr = substr 364 | break 365 | end -= 1 366 | if cur_substr is None: 367 | is_bad = True 368 | break 369 | sub_tokens.append(cur_substr) 370 | start = end 371 | 372 | if is_bad: 373 | output_tokens.append(self.unk_token) 374 | else: 375 | output_tokens.extend(sub_tokens) 376 | return output_tokens 377 | 378 | 379 | def _is_whitespace(char): 380 | """Checks whether `chars` is a whitespace character.""" 381 | # \t, \n, and \r are technically contorl characters but we treat them 382 | # as whitespace since they are generally considered as such. 
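  # Illustrative cases: a plain space and '\u00A0' (no-break space, Unicode
  # category "Zs") both return True, while '\u2028' (line separator, category
  # "Zl") and ordinary letters return False.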
383 | if char == " " or char == "\t" or char == "\n" or char == "\r": 384 | return True 385 | cat = unicodedata.category(char) 386 | if cat == "Zs": 387 | return True 388 | return False 389 | 390 | 391 | def _is_control(char): 392 | """Checks whether `chars` is a control character. non printable character""" 393 | # These are technically control characters but we count them as whitespace 394 | # characters. 395 | if char == "\t" or char == "\n" or char == "\r": 396 | return False 397 | cat = unicodedata.category(char) 398 | if cat.startswith("C"): 399 | return True 400 | return False 401 | 402 | 403 | def _is_punctuation(char): 404 | """Checks whether `chars` is a punctuation character.""" 405 | cp = ord(char) 406 | # We treat all non-letter/number ASCII as punctuation. 407 | # Characters such as "^", "$", and "`" are not in the Unicode 408 | # Punctuation class but we treat them as punctuation anyways, for 409 | # consistency. 410 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 411 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 412 | return True 413 | cat = unicodedata.category(char) 414 | if cat.startswith("P"): 415 | return True 416 | return False 417 | -------------------------------------------------------------------------------- /cmrc/cmrc_tool/run_squad_inf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import math 5 | import tokenization 6 | import collections 7 | import tensorflow as tf 8 | import numpy as np 9 | import time 10 | import six 11 | 12 | parser = argparse.ArgumentParser( 13 | description='BERT model saved model case/batch test program, exit with q') 14 | 15 | parser.add_argument('--model', type=str, 16 | default='./1586959006', help='the path for the model') 17 | parser.add_argument('--vocab_file', type=str, 18 | default='./1586959006/vocab.txt') 19 | parser.add_argument('--max_seq_length', type=int, default=384, 20 | help='the length of sequence for text padding') 21 | parser.add_argument('--do_lower_case', type=bool, default=True, 22 | help='Whether to lower case the input text. Should be True for uncased models and False for cased models.') 23 | parser.add_argument('--max_answer_length', type=int, default=128, 24 | help='The maximum length of an answer that can be generated. 
This is needed because the start and end predictions are not conditioned on one another.') 25 | parser.add_argument('--n_best_size', type=int, default=20, 26 | help='"The total number of n-best predictions to generate in the nbest_predictions.json output file.') 27 | parser.add_argument('--doc_stride', type=int, default=128, 28 | help='the length of document stride') 29 | parser.add_argument('--max_query_length', type=int, default=64, 30 | help='the max length of query length') 31 | parser.add_argument('--tensor_start_positions', type=str, default='start_positions', 32 | help='the start_positions feature name for saved model') 33 | parser.add_argument('--tensor_end_positions', type=str, default='end_positions', 34 | help='the end_positions feature name for saved model') 35 | parser.add_argument('--tensor_unique_ids', type=str, default='unique_ids', 36 | help='the unique_ids feature name for saved model') 37 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids', 38 | help='the input_ids feature name for saved model') 39 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask', 40 | help='the input_mask feature name for saved model') 41 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids', 42 | help='the segment_ids feature name for saved model') 43 | parser.add_argument('--MODE', type=str, default='SINGLE', 44 | help='SINGLE prediction or BATCH prediction') 45 | args_in_use = parser.parse_args() 46 | 47 | 48 | class SquadExample(object): 49 | """A single training/test example for simple sequence classification. 50 | 51 | For examples without an answer, the start and end position are -1. 52 | """ 53 | 54 | def __init__(self, 55 | qas_id, 56 | question_text, 57 | doc_tokens, 58 | orig_answer_text=None, 59 | start_position=None, 60 | end_position=None, 61 | is_impossible=False): 62 | self.qas_id = qas_id 63 | self.question_text = question_text 64 | self.doc_tokens = doc_tokens 65 | self.orig_answer_text = orig_answer_text 66 | self.start_position = start_position 67 | self.end_position = end_position 68 | self.is_impossible = is_impossible 69 | 70 | def __str__(self): 71 | return self.__repr__() 72 | 73 | def __repr__(self): 74 | s = "" 75 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 76 | s += ", question_text: %s" % ( 77 | tokenization.printable_text(self.question_text)) 78 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 79 | if self.start_position: 80 | s += ", start_position: %d" % (self.start_position) 81 | if self.start_position: 82 | s += ", end_position: %d" % (self.end_position) 83 | if self.start_position: 84 | s += ", is_impossible: %r" % (self.is_impossible) 85 | return s 86 | 87 | 88 | class InputFeatures(object): 89 | """A single set of features of data.""" 90 | 91 | def __init__(self, 92 | unique_id, 93 | example_index, 94 | doc_span_index, 95 | tokens, 96 | token_to_orig_map, 97 | token_is_max_context, 98 | input_ids, 99 | input_mask, 100 | segment_ids, 101 | start_position=None, 102 | end_position=None, 103 | is_impossible=None): 104 | self.unique_id = unique_id 105 | self.example_index = example_index 106 | self.doc_span_index = doc_span_index 107 | self.tokens = tokens 108 | self.token_to_orig_map = token_to_orig_map 109 | self.token_is_max_context = token_is_max_context 110 | self.input_ids = input_ids 111 | self.input_mask = input_mask 112 | self.segment_ids = segment_ids 113 | self.start_position = start_position 114 | self.end_position = end_position 115 | self.is_impossible = 
is_impossible 116 | 117 | 118 | def _check_is_max_context(doc_spans, cur_span_index, position): 119 | """Check if this is the 'max context' doc span for the token.""" 120 | 121 | # Because of the sliding window approach taken to scoring documents, a single 122 | # token can appear in multiple documents. E.g. 123 | # Doc: the man went to the store and bought a gallon of milk 124 | # Span A: the man went to the 125 | # Span B: to the store and bought 126 | # Span C: and bought a gallon of 127 | # ... 128 | # 129 | # Now the word 'bought' will have two scores from spans B and C. We only 130 | # want to consider the score with "maximum context", which we define as 131 | # the *minimum* of its left and right context (the *sum* of left and 132 | # right context will always be the same, of course). 133 | # 134 | # In the example the maximum context for 'bought' would be span C since 135 | # it has 1 left context and 3 right context, while span B has 4 left context 136 | # and 0 right context. 137 | best_score = None 138 | best_span_index = None 139 | for (span_index, doc_span) in enumerate(doc_spans): 140 | end = doc_span.start + doc_span.length - 1 141 | if position < doc_span.start: 142 | continue 143 | if position > end: 144 | continue 145 | num_left_context = position - doc_span.start 146 | num_right_context = end - position 147 | score = min(num_left_context, num_right_context) + \ 148 | 0.01 * doc_span.length 149 | if best_score is None or score > best_score: 150 | best_score = score 151 | best_span_index = span_index 152 | 153 | return cur_span_index == best_span_index 154 | 155 | 156 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 157 | doc_stride, max_query_length): 158 | """Loads a data file into a list of `InputBatch`s.""" 159 | 160 | unique_id = 1000000000 161 | 162 | features = [] 163 | for (example_index, example) in enumerate(examples): 164 | query_tokens = tokenizer.tokenize(example.question_text) 165 | 166 | if len(query_tokens) > max_query_length: 167 | query_tokens = query_tokens[0:max_query_length] 168 | 169 | # 将原文变成token之后,token和原始文本index的对应关系 170 | tok_to_orig_index = [] 171 | # 将原文变成token之后,原始文本和token的index的对应关系 172 | orig_to_tok_index = [] 173 | all_doc_tokens = [] 174 | for (i, token) in enumerate(example.doc_tokens): 175 | orig_to_tok_index.append(len(all_doc_tokens)) 176 | sub_tokens = tokenizer.tokenize(token) 177 | for sub_token in sub_tokens: 178 | tok_to_orig_index.append(i) 179 | all_doc_tokens.append(sub_token) 180 | 181 | tok_start_position = None 182 | tok_end_position = None 183 | 184 | # The -3 accounts for [CLS], [SEP] and [SEP] 185 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 186 | 187 | # We can have documents that are longer than the maximum sequence length. 188 | # To deal with this we do a sliding window approach, where we take chunks 189 | # of the up to our max length with a stride of `doc_stride`. 
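      # Hypothetical example of the sliding window (numbers chosen only for
      # illustration): with len(all_doc_tokens)=500, max_tokens_for_doc=300 and
      # doc_stride=128, the loop below produces three spans:
      #   DocSpan(start=0,   length=300)
      #   DocSpan(start=128, length=300)
      #   DocSpan(start=256, length=244)
      # A token at, say, position 200 appears in the first two spans;
      # _check_is_max_context later keeps only the span where it has the most
      # surrounding context (here the first span: min(left, right) context is
      # 99 there vs. 72 in the second span).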
190 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 191 | "DocSpan", ["start", "length"]) 192 | doc_spans = [] 193 | start_offset = 0 194 | while start_offset < len(all_doc_tokens): 195 | length = len(all_doc_tokens) - start_offset 196 | if length > max_tokens_for_doc: 197 | length = max_tokens_for_doc 198 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 199 | # 如果length是小于max_tokens_for_doc的话,那么就会直接break 200 | if start_offset + length == len(all_doc_tokens): 201 | break 202 | start_offset += min(length, doc_stride) 203 | 204 | for (doc_span_index, doc_span) in enumerate(doc_spans): 205 | tokens = [] 206 | token_to_orig_map = {} 207 | token_is_max_context = {} 208 | # segment_ids中0代表[CLS]、第一个[SEP]和query_tokens,1代表doc和第二个[SEP] 209 | segment_ids = [] 210 | tokens.append("[CLS]") 211 | segment_ids.append(0) 212 | for token in query_tokens: 213 | tokens.append(token) 214 | segment_ids.append(0) 215 | tokens.append("[SEP]") 216 | segment_ids.append(0) 217 | 218 | for i in range(doc_span.length): 219 | split_token_index = doc_span.start + i 220 | # len(tokens)为[CLS]+query_tokens+[SEP]的大小,应该是doc_tokens第i个token 221 | token_to_orig_map[len( 222 | tokens)] = tok_to_orig_index[split_token_index] 223 | 224 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 225 | split_token_index) 226 | token_is_max_context[len(tokens)] = is_max_context 227 | tokens.append(all_doc_tokens[split_token_index]) 228 | segment_ids.append(1) 229 | tokens.append("[SEP]") 230 | segment_ids.append(1) 231 | 232 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 233 | 234 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 235 | # tokens are attended to. 236 | input_mask = [1] * len(input_ids) 237 | 238 | # Zero-pad up to the sequence length. 239 | while len(input_ids) < max_seq_length: 240 | input_ids.append(0) 241 | input_mask.append(0) 242 | segment_ids.append(0) 243 | 244 | assert len(input_ids) == max_seq_length 245 | assert len(input_mask) == max_seq_length 246 | assert len(segment_ids) == max_seq_length 247 | 248 | start_position = None 249 | end_position = None 250 | 251 | features.append( 252 | InputFeatures( 253 | unique_id=unique_id, 254 | example_index=example_index, 255 | doc_span_index=doc_span_index, 256 | tokens=tokens, # 可能会引入UNK 257 | token_to_orig_map=token_to_orig_map, 258 | token_is_max_context=token_is_max_context, 259 | input_ids=input_ids, 260 | input_mask=input_mask, 261 | segment_ids=segment_ids, 262 | start_position=start_position, 263 | end_position=end_position)) 264 | unique_id += 1 265 | 266 | return features 267 | 268 | 269 | def read_squad_examples(input_data, tokenizer): 270 | """ 271 | https://github.com/eva-n27/BERT-for-Chinese-Question-Answering/blob/master/run_squad.py 272 | Read a SQuAD json file into a list of SquadExample. 273 | 这个函数将input_data[i]["paragraphs"]["context"]变成一个list,词的list 274 | 然后遍历"qas",对于每一个qa,提取 275 | { 276 | qas_id: qa['id'], 277 | question_text: qa["question"], 278 | orig_answer_text: answer["text"], 279 | start_position: start_position, 280 | end_position: end_position 281 | } 282 | """ 283 | import unicodedata 284 | 285 | def is_whitespace(c): 286 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 287 | return True 288 | return False 289 | 290 | def is_control(char): 291 | """Checks whether `chars` is a control character.""" 292 | # These are technically control characters but we count them as whitespace 293 | # characters. 
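        # Editor's note (illustrative only): with this definition
        #     is_control("\n")   -> False   (treated as whitespace, per the comment above)
        #     is_control("\x00") -> True    (Unicode category "Cc")
        # so clean_text() below drops genuine control characters while mapping
        # tabs/newlines to plain spaces.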
294 | if char == "\t" or char == "\n" or char == "\r": 295 | return False 296 | cat = unicodedata.category(char) 297 | if cat.startswith("C"): 298 | return True 299 | return False 300 | 301 | def clean_text(text): 302 | """Performs invalid character removal and whitespace cleanup on text.""" 303 | output = [] 304 | for char in text: 305 | cp = ord(char) 306 | if cp == 0 or cp == 0xfffd or is_control(char): 307 | continue 308 | if is_whitespace(char): 309 | output.append(" ") 310 | else: 311 | output.append(char) 312 | return "".join(output) 313 | 314 | examples = [] 315 | tf.logging.info("*** reading squad examples ***") 316 | for entry in input_data: 317 | for paragraph in entry["paragraphs"]: 318 | paragraph_text = " ".join(tokenization.whitespace_tokenize( 319 | clean_text(paragraph["context"]))) 320 | 321 | for qa in paragraph["qas"]: 322 | doc_tokens = tokenizer.basic_tokenizer.tokenize(paragraph_text) 323 | qas_id = qa["id"] 324 | question_text = qa["question"] 325 | start_position = None 326 | end_position = None 327 | orig_answer_text = None 328 | is_impossible = False 329 | 330 | example = SquadExample( 331 | qas_id=qas_id, 332 | question_text=question_text, 333 | doc_tokens=doc_tokens, 334 | orig_answer_text=orig_answer_text, 335 | start_position=start_position, 336 | end_position=end_position) 337 | examples.append(example) 338 | return examples 339 | 340 | 341 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 342 | """Project the tokenized prediction back to the original text.""" 343 | 344 | # When we created the data, we kept track of the alignment between original 345 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 346 | # now `orig_text` contains the span of our original text corresponding to the 347 | # span that we predicted. 348 | # 349 | # However, `orig_text` may contain extra characters that we don't want in 350 | # our prediction. 351 | # 352 | # For example, let's say: 353 | # pred_text = steve smith 354 | # orig_text = Steve Smith's 355 | # 356 | # We don't want to return `orig_text` because it contains the extra "'s". 357 | # 358 | # We don't want to return `pred_text` because it's already been normalized 359 | # (the SQuAD eval script also does punctuation stripping/lower casing but 360 | # our tokenizer does additional normalization like stripping accent 361 | # characters). 362 | # 363 | # What we really want to return is "Steve Smith". 364 | # 365 | # Therefore, we have to apply a semi-complicated alignment heruistic between 366 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This 367 | # can fail in certain cases in which case we just return `orig_text`. 368 | 369 | def _strip_spaces(text): 370 | ns_chars = [] 371 | ns_to_s_map = collections.OrderedDict() 372 | for (i, c) in enumerate(text): 373 | if c == " ": 374 | continue 375 | ns_to_s_map[len(ns_chars)] = i 376 | ns_chars.append(c) 377 | ns_text = "".join(ns_chars) 378 | return (ns_text, ns_to_s_map) 379 | 380 | # We first tokenize `orig_text`, strip whitespace from the result 381 | # and `pred_text`, and check if they are the same length. If they are 382 | # NOT the same length, the heuristic has failed. If they are the same 383 | # length, we assume the characters are one-to-one aligned. 
384 | 385 | ################################################# 386 | """ 387 | 原本 388 | """ 389 | # tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 390 | # tok_text = "".join(tokenizer.tokenize(orig_text)) 391 | """ 392 | 更新的版本 20200416 393 | """ 394 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file, 395 | do_lower_case=True) 396 | tok_text = "".join(tokenizer.tokenize(orig_text)) 397 | # De-tokenize WordPieces that have been split off. 398 | tok_text = tok_text.replace(" ##", "") 399 | tok_text = tok_text.replace("##", "") 400 | #################################################### 401 | 402 | start_position = tok_text.find(pred_text) 403 | if start_position == -1: 404 | print(f'{pred_text} not in {tok_text}') #UNK 的情况 405 | return orig_text 406 | end_position = start_position + len(pred_text) - 1 407 | 408 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 409 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 410 | 411 | if len(orig_ns_text) != len(tok_ns_text): 412 | return orig_text 413 | 414 | # We then project the characters in `pred_text` back to `orig_text` using 415 | # the character-to-character alignment. 416 | tok_s_to_ns_map = {} 417 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 418 | tok_s_to_ns_map[tok_index] = i 419 | 420 | orig_start_position = None 421 | if start_position in tok_s_to_ns_map: 422 | ns_start_position = tok_s_to_ns_map[start_position] 423 | if ns_start_position in orig_ns_to_s_map: 424 | orig_start_position = orig_ns_to_s_map[ns_start_position] 425 | 426 | if orig_start_position is None: 427 | return orig_text 428 | 429 | orig_end_position = None 430 | if end_position in tok_s_to_ns_map: 431 | ns_end_position = tok_s_to_ns_map[end_position] 432 | if ns_end_position in orig_ns_to_s_map: 433 | orig_end_position = orig_ns_to_s_map[ns_end_position] 434 | 435 | if orig_end_position is None: 436 | return orig_text 437 | 438 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 439 | return output_text 440 | 441 | 442 | def _compute_softmax(scores): 443 | """Compute softmax probability over raw logits.""" 444 | if not scores: 445 | return [] 446 | 447 | max_score = None 448 | for score in scores: 449 | if max_score is None or score > max_score: 450 | max_score = score 451 | 452 | exp_scores = [] 453 | total_sum = 0.0 454 | for score in scores: 455 | x = math.exp(score - max_score) 456 | exp_scores.append(x) 457 | total_sum += x 458 | 459 | probs = [] 460 | for score in exp_scores: 461 | probs.append(score / total_sum) 462 | return probs 463 | 464 | 465 | def _get_best_indexes(logits, n_best_size): 466 | """Get the n-best logits from a list.""" 467 | index_and_score = sorted( 468 | enumerate(logits), key=lambda x: x[1], reverse=True) 469 | 470 | best_indexes = [] 471 | for i in range(len(index_and_score)): 472 | if i >= n_best_size: 473 | break 474 | best_indexes.append(index_and_score[i][0]) 475 | return best_indexes 476 | 477 | 478 | def get_predictions(all_examples, all_features, all_results, n_best_size, 479 | max_answer_length, do_lower_case): 480 | 481 | example_index_to_features = collections.defaultdict(list) 482 | for feature in all_features: 483 | example_index_to_features[feature.example_index].append(feature) 484 | 485 | unique_id_to_result = {} 486 | for result in all_results: 487 | unique_id_to_result[result.unique_id] = result 488 | 489 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 490 | "PrelimPrediction", 491 | ["feature_index", 
"start_index", "end_index", "start_logit", "end_logit"]) 492 | 493 | all_predictions = collections.OrderedDict() 494 | all_nbest_json = collections.OrderedDict() 495 | for (example_index, example) in enumerate(all_examples): 496 | features = example_index_to_features[example_index] 497 | 498 | prelim_predictions = [] 499 | for (feature_index, feature) in enumerate(features): 500 | result = unique_id_to_result[feature.unique_id] 501 | 502 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 503 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 504 | for start_index in start_indexes: 505 | for end_index in end_indexes: 506 | # We could hypothetically create invalid predictions, e.g., predict 507 | # that the start of the span is in the question. We throw out all 508 | # invalid predictions. 509 | if start_index >= len(feature.tokens): 510 | continue 511 | if end_index >= len(feature.tokens): 512 | continue 513 | if start_index not in feature.token_to_orig_map: 514 | continue 515 | if end_index not in feature.token_to_orig_map: 516 | continue 517 | if not feature.token_is_max_context.get(start_index, False): 518 | continue 519 | if end_index < start_index: 520 | continue 521 | length = end_index - start_index + 1 522 | if length > max_answer_length: 523 | continue 524 | prelim_predictions.append( 525 | _PrelimPrediction( 526 | feature_index=feature_index, 527 | start_index=start_index, 528 | end_index=end_index, 529 | start_logit=result.start_logits[start_index], 530 | end_logit=result.end_logits[end_index])) 531 | 532 | prelim_predictions = sorted( 533 | prelim_predictions, 534 | key=lambda x: (x.start_logit + x.end_logit), 535 | reverse=True) 536 | 537 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 538 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 539 | 540 | seen_predictions = {} 541 | nbest = [] 542 | for pred in prelim_predictions: 543 | if len(nbest) >= n_best_size: 544 | break 545 | feature = features[pred.feature_index] 546 | 547 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 548 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 549 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 550 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 551 | tok_text = "".join(tok_tokens) 552 | 553 | # De-tokenize WordPieces that have been split off. 554 | tok_text = tok_text.replace(" ##", "") 555 | tok_text = tok_text.replace("##", "") 556 | 557 | # Clean whitespace 558 | tok_text = tok_text.strip() 559 | tok_text = " ".join(tok_text.split()) 560 | orig_text = "".join(orig_tokens) 561 | 562 | final_text = get_final_text(tok_text, orig_text, do_lower_case) 563 | if final_text in seen_predictions: 564 | continue 565 | 566 | seen_predictions[final_text] = True 567 | nbest.append( 568 | _NbestPrediction( 569 | text=final_text, 570 | start_logit=pred.start_logit, 571 | end_logit=pred.end_logit)) 572 | 573 | # In very rare edge cases we could have no valid predictions. So we 574 | # just create a nonce prediction in this case to avoid failure. 
575 | if not nbest: 576 | nbest.append( 577 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 578 | 579 | assert len(nbest) >= 1 580 | 581 | total_scores = [] 582 | for entry in nbest: 583 | total_scores.append(entry.start_logit + entry.end_logit) 584 | 585 | probs = _compute_softmax(total_scores) 586 | 587 | nbest_json = [] 588 | for (i, entry) in enumerate(nbest): 589 | output = collections.OrderedDict() 590 | output["text"] = entry.text 591 | output["probability"] = probs[i] 592 | output["start_logit"] = entry.start_logit 593 | output["end_logit"] = entry.end_logit 594 | nbest_json.append(output) 595 | 596 | assert len(nbest_json) >= 1 597 | 598 | all_predictions[example.qas_id] = nbest_json[0]["text"] 599 | all_nbest_json[example.qas_id] = nbest_json 600 | return all_predictions, all_nbest_json 601 | 602 | 603 | RawResult = collections.namedtuple("RawResult", 604 | ["unique_id", "start_logits", "end_logits"]) 605 | 606 | 607 | if __name__ == '__main__': 608 | predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model) 609 | max_seq_length = args_in_use.max_seq_length 610 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file, 611 | do_lower_case=True) 612 | if args_in_use.MODE == 'SINGLE': 613 | # while True: 614 | # paragraph = input('(PRESS q to quit)请输入段落\n> ') 615 | # question = input('(PRESS q to quit)请输入问题\n> ') 616 | # if question == 'q': 617 | # break 618 | 619 | """ 620 | DEMO EXAMPLES 621 | """ 622 | # paragraph = "《战国无双3》()是由光荣和ω-force开发的战国无双系列 的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国 志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》 ,丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型 等,请至战国无双系列1.由于乡里大辅先生因故去世,不得不寻找其他声优接手。从 猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任 > 天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),后 来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的> 状况,战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相> 关介绍。(注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容" 623 | # question = "《战国无双3》是由哪两个公司合作开发的?" 624 | # paragraph = "弗赖永广场(Freyung)是奥地利首都维也纳的一个三角形广场,位于内城(第一区)。这个广场最初位于古罗马堡垒Vindabona城墙以外,在12世纪,奥地利公爵亨利二世邀请爱尔兰僧侣来此修建了“苏格兰修道院”(Schottenkloster),因为当时爱尔兰被称为“新苏格兰”。修道院周围的广场也称为苏格兰广场。“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。" 625 | # paragraph = "“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。" 626 | # question = "为什么修道院的名称取自意为自由的古代德语词汇“frey”?" 627 | paragraph = "历时五年四次审议的《电子商务法》已于今年1月1日起实施。《电子商务法》的颁布实施为我国规范当前电子商务市场秩序、维护公平竞争环境、保障参与主体权益、促进电子商务健康快速发展奠定了法律基础。从总体上,应该看到《电子商务法》是一部以促进电子商务发展为立法目标之一的法律,是一部权益法,也是一部促进法。《电子商务法》专门设立了“电子商务促进”章节,明确了国家发展电子商务的重点方向。其中,农村电商和电商扶贫成为促进的重点." 
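            # Editor's note (illustrative invocation; paths are placeholders and the
            # script name assumes this file is cmrc/cmrc_tool/run_squad_inf.py):
            #     python run_squad_inf.py \
            #         --model <exported_saved_model_dir> \
            #         --vocab_file <bert_dir>/vocab.txt \
            #         --MODE SINGLE
            # --model must point at a SavedModel directory accepted by
            # tf.contrib.predictor.from_saved_model above.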
628 | question = "电子商务法的目的" 629 | paragraph2 = """ 630 | 第八条 本法所称对外贸易经营者,是指依法办理工商登记或者其他执业手续,依照本法和其他有关法律、行政法规的规定从事对外贸易经营活动的法人、其他组织或者个人。 631 | """ 632 | question2 = "对外贸易经营者的定义" 633 | # question = "电子商务法的生效日期" # bad case 634 | # question = "电子商务法生效时间" # 一般的 case 635 | # question = "电子商务法实施时间" # good case 636 | input_data = [{ 637 | "paragraphs": [ 638 | { 639 | "context": paragraph, 640 | "qas": [ 641 | { 642 | "question": question, 643 | "id": "RANDOM_QUESTION_ID" 644 | } 645 | ] 646 | }, 647 | { 648 | "context": paragraph2, 649 | "qas": [ 650 | { 651 | "question": question2, 652 | "id": "RANDOM_QUESTION_ID2" 653 | } 654 | ] 655 | }, 656 | ] 657 | }] 658 | predict_examples = read_squad_examples(input_data, tokenizer) 659 | 660 | features = convert_examples_to_features( 661 | examples=predict_examples, 662 | tokenizer=tokenizer, 663 | max_seq_length=args_in_use.max_seq_length, 664 | doc_stride=args_in_use.doc_stride, 665 | max_query_length=args_in_use.max_query_length) 666 | 667 | start_time = time.time() 668 | results = predict_fn({ 669 | args_in_use.tensor_unique_ids: [feature.unique_id for feature in features], 670 | args_in_use.tensor_input_ids: [feature.input_ids for feature in features], 671 | args_in_use.tensor_input_mask: [feature.input_mask for feature in features], 672 | args_in_use.tensor_segment_ids: [feature.segment_ids for feature in features], 673 | }) 674 | print(f'elapsed time: {time.time()-start_time}s') 675 | print(np.shape(results['end_logits'])) 676 | unique_ids = results['unique_ids'] 677 | start_logits_list = results['start_logits'] 678 | end_logits_list = results['end_logits'] 679 | all_results = [] 680 | for unique_id, start_logits, end_logits in zip(unique_ids, start_logits_list, end_logits_list): 681 | # unique_id = int(result["unique_ids"]) 682 | # start_logits = [float(x) for x in result["start_logits"].flat] 683 | # end_logits = [float(x) for x in result["end_logits"].flat] 684 | _raw_result = RawResult( 685 | unique_id=unique_id, 686 | start_logits=start_logits, 687 | end_logits=end_logits) 688 | all_results.append(_raw_result) 689 | all_predictions, all_nbest_json = get_predictions(predict_examples, features, all_results, 690 | args_in_use.n_best_size, args_in_use.max_answer_length, 691 | args_in_use.do_lower_case) 692 | print(all_predictions) 693 | # print(all_nbest_json) 694 | # 分数对不上 695 | 696 | elif args_in_use.MODE == 'BATCH': 697 | pass 698 | # TO BE IMPLEMENTED 699 | else: 700 | raise ValueError('unsupported mode') 701 | -------------------------------------------------------------------------------- /cmrc/cmrc_tool/run_squad_inf_cmrc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import os 4 | import sys 5 | import math 6 | import tokenization 7 | import collections 8 | import tensorflow as tf 9 | import numpy as np 10 | import time 11 | import six 12 | import pdb 13 | 14 | parser = argparse.ArgumentParser( 15 | description='BERT model saved model case/batch test program, exit with q') 16 | 17 | parser.add_argument('--model', type=str, 18 | default='./cmrc_1588284341', help='the path for the model') 19 | parser.add_argument('--vocab_file', type=str, 20 | default='./1586959006/vocab.txt') 21 | parser.add_argument('--max_seq_length', type=int, default=512, 22 | help='the length of sequence for text padding') 23 | parser.add_argument('--do_lower_case', type=bool, default=True, 24 | help='Whether to lower case the input text. 
Should be True for uncased models and False for cased models.') 25 | parser.add_argument('--max_answer_length', type=int, default=312, 26 | help='The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another.') 27 | parser.add_argument('--n_best_size', type=int, default=20, 28 | help='"The total number of n-best predictions to generate in the nbest_predictions.json output file.') 29 | parser.add_argument('--doc_stride', type=int, default=128, 30 | help='the length of document stride') 31 | parser.add_argument('--max_query_length', type=int, default=64, 32 | help='the max length of query length') 33 | parser.add_argument('--tensor_start_positions', type=str, default='start_positions', 34 | help='the start_positions feature name for saved model') 35 | parser.add_argument('--tensor_end_positions', type=str, default='end_positions', 36 | help='the end_positions feature name for saved model') 37 | parser.add_argument('--tensor_unique_ids', type=str, default='unique_ids', 38 | help='the unique_ids feature name for saved model') 39 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids', 40 | help='the input_ids feature name for saved model') 41 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask', 42 | help='the input_mask feature name for saved model') 43 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids', 44 | help='the segment_ids feature name for saved model') 45 | parser.add_argument('--tensor_input_span_mask', type=str, default='input_span_mask', 46 | help='the input_span_mask feature name for saved model') 47 | parser.add_argument('--MODE', type=str, default='SINGLE', 48 | help='SINGLE prediction or BATCH prediction') 49 | args_in_use = parser.parse_args() 50 | 51 | 52 | class SquadExample(object): 53 | """A single training/test example for simple sequence classification. 54 | 55 | For examples without an answer, the start and end position are -1. 
56 | """ 57 | 58 | def __init__(self, 59 | qas_id, 60 | question_text, 61 | doc_tokens, 62 | orig_answer_text=None, 63 | start_position=None, 64 | end_position=None, 65 | is_impossible=False): 66 | self.qas_id = qas_id 67 | self.question_text = question_text 68 | self.doc_tokens = doc_tokens 69 | self.orig_answer_text = orig_answer_text 70 | self.start_position = start_position 71 | self.end_position = end_position 72 | self.is_impossible = is_impossible 73 | 74 | def __str__(self): 75 | return self.__repr__() 76 | 77 | def __repr__(self): 78 | s = "" 79 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 80 | s += ", question_text: %s" % ( 81 | tokenization.printable_text(self.question_text)) 82 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 83 | if self.start_position: 84 | s += ", start_position: %d" % (self.start_position) 85 | if self.start_position: 86 | s += ", end_position: %d" % (self.end_position) 87 | if self.start_position: 88 | s += ", is_impossible: %r" % (self.is_impossible) 89 | return s 90 | 91 | 92 | class InputFeatures(object): 93 | """A single set of features of data.""" 94 | 95 | def __init__(self, 96 | unique_id, 97 | example_index, 98 | doc_span_index, 99 | tokens, 100 | token_to_orig_map, 101 | token_is_max_context, 102 | input_ids, 103 | input_mask, 104 | segment_ids, 105 | input_span_mask, 106 | start_position=None, 107 | end_position=None): 108 | self.unique_id = unique_id 109 | self.example_index = example_index 110 | self.doc_span_index = doc_span_index 111 | self.tokens = tokens 112 | self.token_to_orig_map = token_to_orig_map 113 | self.token_is_max_context = token_is_max_context 114 | self.input_ids = input_ids 115 | self.input_mask = input_mask 116 | self.segment_ids = segment_ids 117 | self.input_span_mask = input_span_mask 118 | self.start_position = start_position 119 | self.end_position = end_position 120 | 121 | 122 | def _check_is_max_context(doc_spans, cur_span_index, position): 123 | """Check if this is the 'max context' doc span for the token.""" 124 | 125 | # Because of the sliding window approach taken to scoring documents, a single 126 | # token can appear in multiple documents. E.g. 127 | # Doc: the man went to the store and bought a gallon of milk 128 | # Span A: the man went to the 129 | # Span B: to the store and bought 130 | # Span C: and bought a gallon of 131 | # ... 132 | # 133 | # Now the word 'bought' will have two scores from spans B and C. We only 134 | # want to consider the score with "maximum context", which we define as 135 | # the *minimum* of its left and right context (the *sum* of left and 136 | # right context will always be the same, of course). 137 | # 138 | # In the example the maximum context for 'bought' would be span C since 139 | # it has 1 left context and 3 right context, while span B has 4 left context 140 | # and 0 right context. 
141 | best_score = None 142 | best_span_index = None 143 | for (span_index, doc_span) in enumerate(doc_spans): 144 | end = doc_span.start + doc_span.length - 1 145 | if position < doc_span.start: 146 | continue 147 | if position > end: 148 | continue 149 | num_left_context = position - doc_span.start 150 | num_right_context = end - position 151 | score = min(num_left_context, num_right_context) + \ 152 | 0.01 * doc_span.length 153 | if best_score is None or score > best_score: 154 | best_score = score 155 | best_span_index = span_index 156 | 157 | return cur_span_index == best_span_index 158 | 159 | 160 | class ChineseFullTokenizer(object): 161 | """Runs end-to-end tokenziation.""" 162 | 163 | def __init__(self, vocab_file, do_lower_case=False): 164 | self.vocab = tokenization.load_vocab(vocab_file) 165 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 166 | self.wordpiece_tokenizer = tokenization.WordpieceTokenizer( 167 | vocab=self.vocab) 168 | self.do_lower_case = do_lower_case 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in customize_tokenizer(text, do_lower_case=self.do_lower_case): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return tokenization.convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return tokenization.convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 186 | doc_stride, max_query_length): 187 | """Loads a data file into a list of `InputBatch`s.""" 188 | 189 | unique_id = 1000000000 190 | tokenizer = ChineseFullTokenizer( 191 | vocab_file=args_in_use.vocab_file, do_lower_case=args_in_use.do_lower_case) 192 | 193 | features = [] 194 | for (example_index, example) in enumerate(examples): 195 | query_tokens = tokenizer.tokenize(example.question_text) 196 | 197 | if len(query_tokens) > max_query_length: 198 | query_tokens = query_tokens[0:max_query_length] 199 | 200 | # 将原文变成token之后,token和原始文本index的对应关系 201 | tok_to_orig_index = [] 202 | # 将原文变成token之后,原始文本和token的index的对应关系 203 | orig_to_tok_index = [] 204 | all_doc_tokens = [] 205 | for (i, token) in enumerate(example.doc_tokens): 206 | orig_to_tok_index.append(len(all_doc_tokens)) 207 | sub_tokens = tokenizer.tokenize(token) 208 | for sub_token in sub_tokens: 209 | tok_to_orig_index.append(i) 210 | all_doc_tokens.append(sub_token) 211 | 212 | tok_start_position = None 213 | tok_end_position = None 214 | 215 | # The -3 accounts for [CLS], [SEP] and [SEP] 216 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 217 | 218 | # We can have documents that are longer than the maximum sequence length. 219 | # To deal with this we do a sliding window approach, where we take chunks 220 | # of the up to our max length with a stride of `doc_stride`. 
221 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 222 | "DocSpan", ["start", "length"]) 223 | doc_spans = [] 224 | start_offset = 0 225 | while start_offset < len(all_doc_tokens): 226 | length = len(all_doc_tokens) - start_offset 227 | if length > max_tokens_for_doc: 228 | length = max_tokens_for_doc 229 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 230 | # 如果length是小于max_tokens_for_doc的话,那么就会直接break 231 | if start_offset + length == len(all_doc_tokens): 232 | break 233 | start_offset += min(length, doc_stride) 234 | 235 | for (doc_span_index, doc_span) in enumerate(doc_spans): 236 | tokens = [] 237 | token_to_orig_map = {} 238 | token_is_max_context = {} 239 | # segment_ids中0代表[CLS]、第一个[SEP]和query_tokens,1代表doc和第二个[SEP] 240 | segment_ids = [] 241 | input_span_mask = [] 242 | tokens.append("[CLS]") 243 | segment_ids.append(0) 244 | input_span_mask.append(1) 245 | for token in query_tokens: 246 | tokens.append(token) 247 | segment_ids.append(0) 248 | input_span_mask.append(0) 249 | tokens.append("[SEP]") 250 | segment_ids.append(0) 251 | input_span_mask.append(0) # TODO:check why 252 | 253 | for i in range(doc_span.length): 254 | split_token_index = doc_span.start + i 255 | # len(tokens)为[CLS]+query_tokens+[SEP]的大小,应该是doc_tokens第i个token 256 | token_to_orig_map[len( 257 | tokens)] = tok_to_orig_index[split_token_index] 258 | 259 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 260 | split_token_index) 261 | token_is_max_context[len(tokens)] = is_max_context 262 | tokens.append(all_doc_tokens[split_token_index]) 263 | segment_ids.append(1) 264 | input_span_mask.append(1) 265 | tokens.append("[SEP]") 266 | segment_ids.append(1) 267 | input_span_mask.append(0) 268 | 269 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 270 | 271 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 272 | # tokens are attended to. 273 | input_mask = [1] * len(input_ids) 274 | 275 | # Zero-pad up to the sequence length. 
276 | while len(input_ids) < max_seq_length: 277 | input_ids.append(0) 278 | input_mask.append(0) 279 | segment_ids.append(0) 280 | input_span_mask.append(0) 281 | 282 | assert len(input_ids) == max_seq_length 283 | assert len(input_mask) == max_seq_length 284 | assert len(segment_ids) == max_seq_length 285 | assert len(input_span_mask) == max_seq_length 286 | 287 | start_position = None 288 | end_position = None 289 | 290 | features.append( 291 | InputFeatures( 292 | unique_id=unique_id, 293 | example_index=example_index, 294 | doc_span_index=doc_span_index, 295 | tokens=tokens, # 可能会引入UNK 296 | token_to_orig_map=token_to_orig_map, 297 | token_is_max_context=token_is_max_context, 298 | input_ids=input_ids, 299 | input_mask=input_mask, 300 | segment_ids=segment_ids, 301 | input_span_mask=input_span_mask, 302 | start_position=start_position, 303 | end_position=end_position)) 304 | unique_id += 1 305 | 306 | return features 307 | 308 | 309 | def customize_tokenizer(text, do_lower_case=False): 310 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 311 | temp_x = "" 312 | text = tokenization.convert_to_unicode(text) 313 | for c in text: 314 | if tokenizer._is_chinese_char(ord(c)) or tokenization._is_punctuation(c) or tokenization._is_whitespace(c) or tokenization._is_control(c): 315 | temp_x += " " + c + " " 316 | else: 317 | temp_x += c 318 | if do_lower_case: 319 | temp_x = temp_x.lower() 320 | return temp_x.split() 321 | 322 | 323 | def read_squad_examples(input_file): 324 | """Read a SQuAD json file into a list of SquadExample.""" 325 | # with tf.gfile.Open(input_file, "r") as reader: 326 | # input_data = json.load(reader)["data"] 327 | 328 | # 329 | examples = [] 330 | for entry in input_data: 331 | for paragraph in entry["paragraphs"]: 332 | paragraph_text = paragraph["context"] 333 | raw_doc_tokens = customize_tokenizer( 334 | paragraph_text, do_lower_case=args_in_use.do_lower_case) 335 | doc_tokens = [] 336 | char_to_word_offset = [] 337 | prev_is_whitespace = True 338 | 339 | k = 0 340 | temp_word = "" 341 | for c in paragraph_text: 342 | # c is whitespace 343 | if tokenization._is_whitespace(c) or not c.split(): 344 | char_to_word_offset.append(k-1) 345 | continue 346 | else: 347 | temp_word += c 348 | char_to_word_offset.append(k) 349 | if args_in_use.do_lower_case: 350 | temp_word = temp_word.lower() 351 | if temp_word == raw_doc_tokens[k]: 352 | doc_tokens.append(temp_word) 353 | temp_word = "" 354 | k += 1 355 | if k != len(raw_doc_tokens): 356 | print(paragraph) 357 | print(doc_tokens) 358 | print(raw_doc_tokens) 359 | assert k == len(raw_doc_tokens) 360 | 361 | for qa in paragraph["qas"]: 362 | qas_id = qa["id"] 363 | question_text = qa["question"] 364 | start_position = None 365 | end_position = None 366 | orig_answer_text = None 367 | 368 | example = SquadExample( 369 | qas_id=qas_id, 370 | question_text=question_text, 371 | doc_tokens=doc_tokens, 372 | orig_answer_text=orig_answer_text, 373 | start_position=start_position, 374 | end_position=end_position) 375 | examples.append(example) 376 | tf.logging.info("**********read_squad_examples complete!**********") 377 | 378 | return examples 379 | 380 | 381 | def get_final_text(pred_text, orig_text, do_lower_case): 382 | """Project the tokenized prediction back to the original text.""" 383 | 384 | # When we created the data, we kept track of the alignment between original 385 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. 
So 386 | # now `orig_text` contains the span of our original text corresponding to the 387 | # span that we predicted. 388 | # 389 | # However, `orig_text` may contain extra characters that we don't want in 390 | # our prediction. 391 | # 392 | # For example, let's say: 393 | # pred_text = steve smith 394 | # orig_text = Steve Smith's 395 | # 396 | # We don't want to return `orig_text` because it contains the extra "'s". 397 | # 398 | # We don't want to return `pred_text` because it's already been normalized 399 | # (the SQuAD eval script also does punctuation stripping/lower casing but 400 | # our tokenizer does additional normalization like stripping accent 401 | # characters). 402 | # 403 | # What we really want to return is "Steve Smith". 404 | # 405 | # Therefore, we have to apply a semi-complicated alignment heruistic between 406 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This 407 | # can fail in certain cases in which case we just return `orig_text`. 408 | 409 | def _strip_spaces(text): 410 | ns_chars = [] 411 | ns_to_s_map = collections.OrderedDict() 412 | for (i, c) in enumerate(text): 413 | if c == " ": 414 | continue 415 | ns_to_s_map[len(ns_chars)] = i 416 | ns_chars.append(c) 417 | ns_text = "".join(ns_chars) 418 | return (ns_text, ns_to_s_map) 419 | 420 | # We first tokenize `orig_text`, strip whitespace from the result 421 | # and `pred_text`, and check if they are the same length. If they are 422 | # NOT the same length, the heuristic has failed. If they are the same 423 | # length, we assume the characters are one-to-one aligned. 424 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 425 | 426 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 427 | 428 | start_position = tok_text.find(pred_text) 429 | if start_position == -1: 430 | tf.logging.info( 431 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 432 | return orig_text 433 | end_position = start_position + len(pred_text) - 1 434 | 435 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 436 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 437 | 438 | if len(orig_ns_text) != len(tok_ns_text): 439 | if args_in_use.verbose_logging: 440 | tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'", 441 | orig_ns_text, tok_ns_text) 442 | return orig_text 443 | 444 | # We then project the characters in `pred_text` back to `orig_text` using 445 | # the character-to-character alignment. 
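    # Editor's note (illustrative only): _strip_spaces("a b c") returns
    #     ns_text = "abc", ns_to_s_map = {0: 0, 1: 2, 2: 4}
    # i.e. a map from positions in the space-free string back into the original text.
    # Inverting the tok-side map below gives tok_s_to_ns_map, which is what lets a
    # match found in `tok_text` be projected back into `orig_text`.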
446 | tok_s_to_ns_map = {} 447 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 448 | tok_s_to_ns_map[tok_index] = i 449 | 450 | orig_start_position = None 451 | if start_position in tok_s_to_ns_map: 452 | ns_start_position = tok_s_to_ns_map[start_position] 453 | if ns_start_position in orig_ns_to_s_map: 454 | orig_start_position = orig_ns_to_s_map[ns_start_position] 455 | 456 | if orig_start_position is None: 457 | if args_in_use.verbose_logging: 458 | tf.logging.info("Couldn't map start position") 459 | return orig_text 460 | 461 | orig_end_position = None 462 | if end_position in tok_s_to_ns_map: 463 | ns_end_position = tok_s_to_ns_map[end_position] 464 | if ns_end_position in orig_ns_to_s_map: 465 | orig_end_position = orig_ns_to_s_map[ns_end_position] 466 | 467 | if orig_end_position is None: 468 | if args_in_use.verbose_logging: 469 | tf.logging.info("Couldn't map end position") 470 | return orig_text 471 | 472 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 473 | return output_text 474 | 475 | 476 | def _compute_softmax(scores): 477 | """Compute softmax probability over raw logits.""" 478 | if not scores: 479 | return [] 480 | 481 | max_score = None 482 | for score in scores: 483 | if max_score is None or score > max_score: 484 | max_score = score 485 | 486 | exp_scores = [] 487 | total_sum = 0.0 488 | for score in scores: 489 | x = math.exp(score - max_score) 490 | exp_scores.append(x) 491 | total_sum += x 492 | 493 | probs = [] 494 | for score in exp_scores: 495 | probs.append(score / total_sum) 496 | return probs 497 | 498 | 499 | def _get_best_indexes(logits, n_best_size): 500 | """Get the n-best logits from a list.""" 501 | index_and_score = sorted( 502 | enumerate(logits), key=lambda x: x[1], reverse=True) 503 | 504 | best_indexes = [] 505 | for i in range(len(index_and_score)): 506 | if i >= n_best_size: 507 | break 508 | best_indexes.append(index_and_score[i][0]) 509 | return best_indexes 510 | 511 | 512 | def get_predictions(all_examples, all_features, all_results, n_best_size, 513 | max_answer_length, do_lower_case): 514 | 515 | example_index_to_features = collections.defaultdict(list) 516 | for feature in all_features: 517 | example_index_to_features[feature.example_index].append(feature) 518 | 519 | unique_id_to_result = {} 520 | for result in all_results: 521 | unique_id_to_result[result.unique_id] = result 522 | 523 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 524 | "PrelimPrediction", 525 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 526 | 527 | all_predictions = collections.OrderedDict() 528 | all_nbest_json = collections.OrderedDict() 529 | for (example_index, example) in enumerate(all_examples): 530 | features = example_index_to_features[example_index] 531 | 532 | prelim_predictions = [] 533 | for (feature_index, feature) in enumerate(features): 534 | result = unique_id_to_result[feature.unique_id] 535 | 536 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 537 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 538 | for start_index in start_indexes: 539 | for end_index in end_indexes: 540 | # We could hypothetically create invalid predictions, e.g., predict 541 | # that the start of the span is in the question. We throw out all 542 | # invalid predictions. 
543 | if start_index >= len(feature.tokens): 544 | continue 545 | if end_index >= len(feature.tokens): 546 | continue 547 | if start_index not in feature.token_to_orig_map: 548 | continue 549 | if end_index not in feature.token_to_orig_map: 550 | continue 551 | if not feature.token_is_max_context.get(start_index, False): 552 | continue 553 | if end_index < start_index: 554 | continue 555 | length = end_index - start_index + 1 556 | if length > max_answer_length: 557 | continue 558 | prelim_predictions.append( 559 | _PrelimPrediction( 560 | feature_index=feature_index, 561 | start_index=start_index, 562 | end_index=end_index, 563 | start_logit=result.start_logits[start_index], 564 | end_logit=result.end_logits[end_index])) 565 | 566 | prelim_predictions = sorted( 567 | prelim_predictions, 568 | key=lambda x: (x.start_logit + x.end_logit), 569 | reverse=True) 570 | 571 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 572 | "NbestPrediction", ["text", "start_logit", "end_logit", "start_index", "end_index"]) 573 | 574 | seen_predictions = {} 575 | nbest = [] 576 | for pred in prelim_predictions: 577 | if len(nbest) >= n_best_size: 578 | break 579 | feature = features[pred.feature_index] 580 | 581 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 582 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 583 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 584 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 585 | tok_text = " ".join(tok_tokens) 586 | 587 | # De-tokenize WordPieces that have been split off. 588 | tok_text = tok_text.replace(" ##", "") 589 | tok_text = tok_text.replace("##", "") 590 | 591 | # Clean whitespace 592 | tok_text = tok_text.strip() 593 | tok_text = " ".join(tok_text.split()) 594 | orig_text = " ".join(orig_tokens) 595 | 596 | final_text = get_final_text(tok_text, orig_text, do_lower_case) 597 | final_text = final_text.replace(' ', '') 598 | if final_text in seen_predictions: 599 | continue 600 | 601 | seen_predictions[final_text] = True 602 | nbest.append( 603 | _NbestPrediction( 604 | text=final_text, 605 | start_logit=pred.start_logit, 606 | end_logit=pred.end_logit, 607 | start_index=pred.start_index, 608 | end_index=pred.end_index)) 609 | 610 | # In very rare edge cases we could have no valid predictions. So we 611 | # just create a nonce prediction in this case to avoid failure. 
612 | if not nbest: 613 | nbest.append( 614 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 615 | 616 | assert len(nbest) >= 1 617 | 618 | total_scores = [] 619 | best_non_null_entry = None 620 | for entry in nbest: 621 | total_scores.append(entry.start_logit + entry.end_logit) 622 | if not best_non_null_entry: 623 | if entry.text: 624 | best_non_null_entry = entry 625 | 626 | probs = _compute_softmax(total_scores) 627 | 628 | nbest_json = [] 629 | for (i, entry) in enumerate(nbest): 630 | output = collections.OrderedDict() 631 | output["text"] = entry.text 632 | output["probability"] = probs[i] 633 | output["start_logit"] = entry.start_logit 634 | output["end_logit"] = entry.end_logit 635 | output["start_index"] = entry.start_index 636 | output["end_index"] = entry.end_index 637 | nbest_json.append(output) 638 | 639 | assert len(nbest_json) >= 1 640 | 641 | all_predictions[example.qas_id] = nbest_json[0]["text"] 642 | all_nbest_json[example.qas_id] = nbest_json 643 | return all_predictions, all_nbest_json 644 | 645 | 646 | RawResult = collections.namedtuple("RawResult", 647 | ["unique_id", "start_logits", "end_logits"]) 648 | 649 | 650 | if __name__ == '__main__': 651 | predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model) 652 | max_seq_length = args_in_use.max_seq_length 653 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file, 654 | do_lower_case=True) 655 | if args_in_use.MODE == 'SINGLE': 656 | # while True: 657 | # paragraph = input('(PRESS q to quit)请输入段落\n> ') 658 | # question = input('(PRESS q to quit)请输入问题\n> ') 659 | # if question == 'q': 660 | # break 661 | 662 | """ 663 | DEMO EXAMPLES 664 | """ 665 | # paragraph = "《战国无双3》()是由光荣和ω-force开发的战国无双系列 的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国 志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》 ,丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型 等,请至战国无双系列1.由于乡里大辅先生因故去世,不得不寻找其他声优接手。从 猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任 > 天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),后 来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的> 状况,战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相> 关介绍。(注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容" 666 | # question = "《战国无双3》是由哪两个公司合作开发的?" 667 | # paragraph = "弗赖永广场(Freyung)是奥地利首都维也纳的一个三角形广场,位于内城(第一区)。这个广场最初位于古罗马堡垒Vindabona城墙以外,在12世纪,奥地利公爵亨利二世邀请爱尔兰僧侣来此修建了“苏格兰修道院”(Schottenkloster),因为当时爱尔兰被称为“新苏格兰”。修道院周围的广场也称为苏格兰广场。“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。" 668 | # paragraph = "“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。" 669 | # question = "为什么修道院的名称取自意为自由的古代德语词汇“frey”?" 670 | paragraph = "历时五年四次审议的《电子商务法》已于今年1月1日起实施。《电子商务法》的颁布实施为我国规范当前电子商务市场秩序、维护公平竞争环境、保障参与主体权益、促进电子商务健康快速发展奠定了法律基础。从总体上,应该看到《电子商务法》是一部以促进电子商务发展为立法目标之一的法律,是一部权益法,也是一部促进法。《电子商务法》专门设立了“电子商务促进”章节,明确了国家发展电子商务的重点方向。其中,农村电商和电商扶贫成为促进的重点." 
671 | question = "电子商务法的目的" 672 | paragraph2 ="阳光板大部分使用的是聚碳酸酯(PC)原料生产,利用空挤压工艺在耐候性脆弱的PC板材上空挤压UV树脂,质量好一点的板面均匀分布有高浓度的UV层,阻挡紫外线的穿过,防止板材黄变,延长板材寿命使产品使用寿命达到10年以上。并且产品具有长期持续透明性的特点。(有单面和双面UV防护)。用途:住宅/商厦采光天幕,工厂厂房 仓库采光顶,体育场馆采光顶,广告牌,通道/停车棚,游泳池/温室覆盖,室内隔断。另本司有隔热保温的PC板材做温棚 遮阳棚 都不错2832217048@qq.com" 673 | question2 = "阳光板雨棚能用几年" 674 | paragraph3 = "藏蓝色,兼于蓝色和黑色之间,既有蓝色的沉静安宁,也有黑色的神秘成熟,既有黑色的收敛效果,又不乏蓝色的洁净长久,虽然不会大热流行,却是可以长久的信任,当藏蓝色与其他颜色相遇,你便会懂得它内在的涵养。藏蓝色+橙色单纯的藏蓝色会给人很严肃的气氛,橙色的点缀让藏蓝色也充满时尚活力。藏蓝色+白色白色是藏蓝色的最佳搭档,两者搭档最容易显得很干净,藏蓝色和白色营造的洗练感,让通勤装永远都不会过时,展现出都市女性的利落感。藏蓝色+粉色藏蓝色和粉色组合散发出成熟优雅的女人味,让粉色显出别样娇嫩。藏蓝色+米色藏蓝色和米色的搭配散发出浓郁的知性气质,稚气的设计细节更显年轻。藏蓝色+红色藏蓝色和红色的搭配更加的沉稳,也更具存在感,如果是面积差不多的服装来搭配,可以用红色的小物点缀来巧妙的平衡。藏蓝色+松石绿藏蓝色搭配柔和的松石绿色给人上品好品质的感觉,用凉鞋和项链来点缀更加具有层次感。藏蓝色+黄色明亮的黄色热情洋溢的融化了藏蓝色的冰冷静谧,细节感的设计更加具有轻松休闲的气氛。藏蓝色+金色推荐单品:藏蓝色" 675 | question3 = "藏蓝色配什么颜色好看" 676 | # question = "电子商务法的生效日期" # bad case 677 | # question = "电子商务法生效时间" # 一般的 case 678 | # question = "电子商务法实施时间" # good case 679 | input_data = [{ 680 | "paragraphs": [ 681 | { 682 | "context": paragraph, 683 | "qas": [ 684 | { 685 | "question": question, 686 | "id": "RANDOM_QUESTION_ID" 687 | } 688 | ] 689 | }, 690 | { 691 | "context": paragraph2, 692 | "qas": [ 693 | { 694 | "question": question2, 695 | "id": "RANDOM_QUESTION_ID2" 696 | } 697 | ] 698 | }, 699 | { 700 | "context": paragraph3, 701 | "qas": [ 702 | { 703 | "question": question3, 704 | "id": "RANDOM_QUESTION_ID3" 705 | } 706 | ] 707 | }, 708 | ] 709 | }] 710 | predict_examples = read_squad_examples(input_data) 711 | 712 | features = convert_examples_to_features( 713 | examples=predict_examples, 714 | tokenizer=tokenizer, 715 | max_seq_length=args_in_use.max_seq_length, 716 | doc_stride=args_in_use.doc_stride, 717 | max_query_length=args_in_use.max_query_length) 718 | 719 | start_time = time.time() 720 | results = predict_fn({ 721 | args_in_use.tensor_unique_ids: [feature.unique_id for feature in features], 722 | args_in_use.tensor_input_ids: [feature.input_ids for feature in features], 723 | args_in_use.tensor_input_mask: [feature.input_mask for feature in features], 724 | args_in_use.tensor_segment_ids: [feature.segment_ids for feature in features], 725 | args_in_use.tensor_input_span_mask: [feature.input_span_mask for feature in features], 726 | }) 727 | print(f'elapsed time: {time.time()-start_time}s') 728 | print(np.shape(results['end_logits'])) 729 | unique_ids = results['unique_ids'] 730 | start_logits_list = results['start_logits'] 731 | end_logits_list = results['end_logits'] 732 | all_results = [] 733 | for unique_id, start_logits, end_logits in zip(unique_ids, start_logits_list, end_logits_list): 734 | # unique_id = int(result["unique_ids"]) 735 | # start_logits = [float(x) for x in result["start_logits"].flat] 736 | # end_logits = [float(x) for x in result["end_logits"].flat] 737 | _raw_result = RawResult( 738 | unique_id=unique_id, 739 | start_logits=start_logits, 740 | end_logits=end_logits) 741 | all_results.append(_raw_result) 742 | all_predictions, all_nbest_json = get_predictions(predict_examples, features, all_results, 743 | args_in_use.n_best_size, args_in_use.max_answer_length, 744 | args_in_use.do_lower_case) 745 | print(all_predictions) 746 | # print(all_nbest_json) 747 | # 分数对不上 748 | 749 | elif args_in_use.MODE == 'BATCH': 750 | pass 751 | # TO BE IMPLEMENTED 752 | else: 753 | raise ValueError('unsupported mode') 754 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # 
coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | class BertConfig(object): 31 | """Configuration for `BertModel`.""" 32 | 33 | def __init__(self, 34 | vocab_size, 35 | hidden_size=768, 36 | num_hidden_layers=12, 37 | num_attention_heads=12, 38 | intermediate_size=3072, 39 | hidden_act="gelu", 40 | hidden_dropout_prob=0.1, 41 | attention_probs_dropout_prob=0.1, 42 | max_position_embeddings=512, 43 | type_vocab_size=16, 44 | initializer_range=0.02): 45 | """Constructs BertConfig. 46 | 47 | Args: 48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 49 | hidden_size: Size of the encoder layers and the pooler layer. 50 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 51 | num_attention_heads: Number of attention heads for each attention layer in 52 | the Transformer encoder. 53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 54 | layer in the Transformer encoder. 55 | hidden_act: The non-linear activation function (function or string) in the 56 | encoder and pooler. 57 | hidden_dropout_prob: The dropout probability for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 65 | `BertModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer for 67 | initializing all weight matrices. 
68 | """ 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.intermediate_size = intermediate_size 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.initializer_range = initializer_range 80 | 81 | @classmethod 82 | def from_dict(cls, json_object): 83 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 84 | config = BertConfig(vocab_size=None) 85 | for (key, value) in six.iteritems(json_object): 86 | config.__dict__[key] = value 87 | return config 88 | 89 | @classmethod 90 | def from_json_file(cls, json_file): 91 | """Constructs a `BertConfig` from a json file of parameters.""" 92 | with tf.gfile.GFile(json_file, "r") as reader: 93 | text = reader.read() 94 | return cls.from_dict(json.loads(text)) 95 | 96 | def to_dict(self): 97 | """Serializes this instance to a Python dictionary.""" 98 | output = copy.deepcopy(self.__dict__) 99 | return output 100 | 101 | def to_json_string(self): 102 | """Serializes this instance to a JSON string.""" 103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 104 | 105 | 106 | class BertModel(object): 107 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 108 | 109 | Example usage: 110 | 111 | ```python 112 | # Already been converted into WordPiece token ids 113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 116 | 117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 118 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 119 | 120 | model = modeling.BertModel(config=config, is_training=True, 121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 122 | 123 | label_embeddings = tf.get_variable(...) 124 | pooled_output = model.get_pooled_output() 125 | logits = tf.matmul(pooled_output, label_embeddings) 126 | ... 127 | ``` 128 | """ 129 | 130 | def __init__(self, 131 | config, 132 | is_training, 133 | input_ids, 134 | input_mask=None, 135 | token_type_ids=None, 136 | use_one_hot_embeddings=True, 137 | scope=None): 138 | """Constructor for BertModel. 139 | 140 | Args: 141 | config: `BertConfig` instance. 142 | is_training: bool. true for training model, false for eval model. Controls 143 | whether dropout will be applied. 144 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 148 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 149 | it is much faster if this is True, on the CPU or GPU, it is faster if 150 | this is False. 151 | scope: (optional) variable scope. Defaults to "bert". 152 | 153 | Raises: 154 | ValueError: The config is invalid or one of the input tensor shapes 155 | is invalid. 
156 | """ 157 | config = copy.deepcopy(config) 158 | if not is_training: 159 | config.hidden_dropout_prob = 0.0 160 | config.attention_probs_dropout_prob = 0.0 161 | 162 | input_shape = get_shape_list(input_ids, expected_rank=2) 163 | batch_size = input_shape[0] 164 | seq_length = input_shape[1] 165 | 166 | if input_mask is None: 167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 168 | 169 | if token_type_ids is None: 170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 171 | 172 | with tf.variable_scope(scope, default_name="bert"): 173 | with tf.variable_scope("embeddings"): 174 | # Perform embedding lookup on the word ids. 175 | (self.embedding_output, self.embedding_table) = embedding_lookup( 176 | input_ids=input_ids, 177 | vocab_size=config.vocab_size, 178 | embedding_size=config.hidden_size, 179 | initializer_range=config.initializer_range, 180 | word_embedding_name="word_embeddings", 181 | use_one_hot_embeddings=use_one_hot_embeddings) 182 | 183 | # Add positional embeddings and token type embeddings, then layer 184 | # normalize and perform dropout. 185 | self.embedding_output = embedding_postprocessor( 186 | input_tensor=self.embedding_output, 187 | use_token_type=True, 188 | token_type_ids=token_type_ids, 189 | token_type_vocab_size=config.type_vocab_size, 190 | token_type_embedding_name="token_type_embeddings", 191 | use_position_embeddings=True, 192 | position_embedding_name="position_embeddings", 193 | initializer_range=config.initializer_range, 194 | max_position_embeddings=config.max_position_embeddings, 195 | dropout_prob=config.hidden_dropout_prob) 196 | 197 | with tf.variable_scope("encoder"): 198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 199 | # mask of shape [batch_size, seq_length, seq_length] which is used 200 | # for the attention scores. 201 | attention_mask = create_attention_mask_from_input_mask( 202 | input_ids, input_mask) 203 | 204 | # Run the stacked transformer. 205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 206 | self.all_encoder_layers = transformer_model( 207 | input_tensor=self.embedding_output, 208 | attention_mask=attention_mask, 209 | hidden_size=config.hidden_size, 210 | num_hidden_layers=config.num_hidden_layers, 211 | num_attention_heads=config.num_attention_heads, 212 | intermediate_size=config.intermediate_size, 213 | intermediate_act_fn=get_activation(config.hidden_act), 214 | hidden_dropout_prob=config.hidden_dropout_prob, 215 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 216 | initializer_range=config.initializer_range, 217 | do_return_all_layers=True) 218 | 219 | self.sequence_output = self.all_encoder_layers[-1] 220 | # The "pooler" converts the encoded sequence tensor of shape 221 | # [batch_size, seq_length, hidden_size] to a tensor of shape 222 | # [batch_size, hidden_size]. This is necessary for segment-level 223 | # (or segment-pair-level) classification tasks where we need a fixed 224 | # dimensional representation of the segment. 225 | with tf.variable_scope("pooler"): 226 | # We "pool" the model by simply taking the hidden state corresponding 227 | # to the first token. 
We assume that this has been pre-trained 228 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 229 | self.pooled_output = tf.layers.dense( 230 | first_token_tensor, 231 | config.hidden_size, 232 | activation=tf.tanh, 233 | kernel_initializer=create_initializer(config.initializer_range)) 234 | 235 | def get_pooled_output(self): 236 | return self.pooled_output 237 | 238 | def get_sequence_output(self): 239 | """Gets final hidden layer of encoder. 240 | 241 | Returns: 242 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 243 | to the final hidden layer of the transformer encoder. 244 | """ 245 | return self.sequence_output 246 | 247 | def get_all_encoder_layers(self): 248 | return self.all_encoder_layers 249 | 250 | def get_embedding_output(self): 251 | """Gets output of the embedding lookup (i.e., input to the transformer). 252 | 253 | Returns: 254 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 255 | to the output of the embedding layer, after summing the word 256 | embeddings with the positional embeddings and the token type embeddings, 257 | then performing layer normalization. This is the input to the transformer. 258 | """ 259 | return self.embedding_output 260 | 261 | def get_embedding_table(self): 262 | return self.embedding_table 263 | 264 | 265 | def gelu(input_tensor): 266 | """Gaussian Error Linear Unit. 267 | 268 | This is a smoother version of the ReLU. 269 | Original paper: https://arxiv.org/abs/1606.08415 270 | 271 | Args: 272 | input_tensor: float Tensor to perform activation. 273 | 274 | Returns: 275 | `input_tensor` with the GELU activation applied. 276 | """ 277 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 278 | return input_tensor * cdf 279 | 280 | 281 | def get_activation(activation_string): 282 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 283 | 284 | Args: 285 | activation_string: String name of the activation function. 286 | 287 | Returns: 288 | A Python function corresponding to the activation function. If 289 | `activation_string` is None, empty, or "linear", this will return None. 290 | If `activation_string` is not a string, it will return `activation_string`. 291 | 292 | Raises: 293 | ValueError: The `activation_string` does not correspond to a known 294 | activation. 295 | """ 296 | 297 | # We assume that anything that's not a string is already an activation 298 | # function, so we just return it.
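  # Illustrative behaviour: get_activation("gelu") returns the gelu function
  # defined above, get_activation("linear"), get_activation("") and
  # get_activation(None) all return None, and a callable such as tf.nn.relu
  # is passed through unchanged.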
299 | if not isinstance(activation_string, six.string_types): 300 | return activation_string 301 | 302 | if not activation_string: 303 | return None 304 | 305 | act = activation_string.lower() 306 | if act == "linear": 307 | return None 308 | elif act == "relu": 309 | return tf.nn.relu 310 | elif act == "gelu": 311 | return gelu 312 | elif act == "tanh": 313 | return tf.tanh 314 | else: 315 | raise ValueError("Unsupported activation: %s" % act) 316 | 317 | 318 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 319 | """Compute the union of the current variables and checkpoint variables.""" 320 | assignment_map = {} 321 | initialized_variable_names = {} 322 | 323 | name_to_variable = collections.OrderedDict() 324 | for var in tvars: 325 | name = var.name 326 | m = re.match("^(.*):\\d+$", name) 327 | if m is not None: 328 | name = m.group(1) 329 | name_to_variable[name] = var 330 | 331 | init_vars = tf.train.list_variables(init_checkpoint) 332 | 333 | assignment_map = collections.OrderedDict() 334 | for x in init_vars: 335 | (name, var) = (x[0], x[1]) 336 | if name not in name_to_variable: 337 | continue 338 | assignment_map[name] = name 339 | initialized_variable_names[name] = 1 340 | initialized_variable_names[name + ":0"] = 1 341 | 342 | return (assignment_map, initialized_variable_names) 343 | 344 | 345 | def dropout(input_tensor, dropout_prob): 346 | """Perform dropout. 347 | 348 | Args: 349 | input_tensor: float Tensor. 350 | dropout_prob: Python float. The probability of dropping out a value (NOT of 351 | *keeping* a dimension as in `tf.nn.dropout`). 352 | 353 | Returns: 354 | A version of `input_tensor` with dropout applied. 355 | """ 356 | if dropout_prob is None or dropout_prob == 0.0: 357 | return input_tensor 358 | 359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 360 | return output 361 | 362 | 363 | def layer_norm(input_tensor, name=None): 364 | """Run layer normalization on the last dimension of the tensor.""" 365 | return tf.contrib.layers.layer_norm( 366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 367 | 368 | 369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 370 | """Runs layer normalization followed by dropout.""" 371 | output_tensor = layer_norm(input_tensor, name) 372 | output_tensor = dropout(output_tensor, dropout_prob) 373 | return output_tensor 374 | 375 | 376 | def create_initializer(initializer_range=0.02): 377 | """Creates a `truncated_normal_initializer` with the given range.""" 378 | return tf.truncated_normal_initializer(stddev=initializer_range) 379 | 380 | 381 | def embedding_lookup(input_ids, 382 | vocab_size, 383 | embedding_size=128, 384 | initializer_range=0.02, 385 | word_embedding_name="word_embeddings", 386 | use_one_hot_embeddings=False): 387 | """Looks up words embeddings for id tensor. 388 | 389 | Args: 390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 391 | ids. 392 | vocab_size: int. Size of the embedding vocabulary. 393 | embedding_size: int. Width of the word embeddings. 394 | initializer_range: float. Embedding initialization range. 395 | word_embedding_name: string. Name of the embedding table. 396 | use_one_hot_embeddings: bool. If True, use one-hot method for word 397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 398 | for TPUs. 399 | 400 | Returns: 401 | float Tensor of shape [batch_size, seq_length, embedding_size]. 
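  Shape example (illustrative): with `input_ids` of shape [2, 3] and
  embedding_size=128, the returned output has shape [2, 3, 128] and the
  returned embedding table has shape [vocab_size, 128].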
402 | """ 403 | # This function assumes that the input is of shape [batch_size, seq_length, 404 | # num_inputs]. 405 | # 406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 407 | # reshape to [batch_size, seq_length, 1]. 408 | if input_ids.shape.ndims == 2: 409 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 410 | 411 | embedding_table = tf.get_variable( 412 | name=word_embedding_name, 413 | shape=[vocab_size, embedding_size], 414 | initializer=create_initializer(initializer_range)) 415 | 416 | if use_one_hot_embeddings: 417 | flat_input_ids = tf.reshape(input_ids, [-1]) 418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 419 | output = tf.matmul(one_hot_input_ids, embedding_table) 420 | else: 421 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 422 | 423 | input_shape = get_shape_list(input_ids) 424 | 425 | output = tf.reshape(output, 426 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 427 | return (output, embedding_table) 428 | 429 | 430 | def embedding_postprocessor(input_tensor, 431 | use_token_type=False, 432 | token_type_ids=None, 433 | token_type_vocab_size=16, 434 | token_type_embedding_name="token_type_embeddings", 435 | use_position_embeddings=True, 436 | position_embedding_name="position_embeddings", 437 | initializer_range=0.02, 438 | max_position_embeddings=512, 439 | dropout_prob=0.1): 440 | """Performs various post-processing on a word embedding tensor. 441 | 442 | Args: 443 | input_tensor: float Tensor of shape [batch_size, seq_length, 444 | embedding_size]. 445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 447 | Must be specified if `use_token_type` is True. 448 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 449 | token_type_embedding_name: string. The name of the embedding table variable 450 | for token type ids. 451 | use_position_embeddings: bool. Whether to add position embeddings for the 452 | position of each token in the sequence. 453 | position_embedding_name: string. The name of the embedding table variable 454 | for positional embeddings. 455 | initializer_range: float. Range of the weight initialization. 456 | max_position_embeddings: int. Maximum sequence length that might ever be 457 | used with this model. This can be longer than the sequence length of 458 | input_tensor, but cannot be shorter. 459 | dropout_prob: float. Dropout probability applied to the final output tensor. 460 | 461 | Returns: 462 | float tensor with same shape as `input_tensor`. 463 | 464 | Raises: 465 | ValueError: One of the tensor shapes or input values is invalid. 466 | """ 467 | input_shape = get_shape_list(input_tensor, expected_rank=3) 468 | batch_size = input_shape[0] 469 | seq_length = input_shape[1] 470 | width = input_shape[2] 471 | 472 | output = input_tensor 473 | 474 | if use_token_type: 475 | if token_type_ids is None: 476 | raise ValueError("`token_type_ids` must be specified if" 477 | "`use_token_type` is True.") 478 | token_type_table = tf.get_variable( 479 | name=token_type_embedding_name, 480 | shape=[token_type_vocab_size, width], 481 | initializer=create_initializer(initializer_range)) 482 | # This vocab will be small so we always do one-hot here, since it is always 483 | # faster for a small vocabulary. 
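    # Shape walk-through (illustrative): with batch_size=2, seq_length=3 and
    # token_type_vocab_size=2, the ids are flattened to [6], one-hot encoded to
    # [6, 2], multiplied with the [2, width] table to give [6, width], and the
    # reshape below restores [2, 3, width] before the addition to `output`.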
484 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 485 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 486 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 487 | token_type_embeddings = tf.reshape(token_type_embeddings, 488 | [batch_size, seq_length, width]) 489 | output += token_type_embeddings 490 | 491 | if use_position_embeddings: 492 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 493 | with tf.control_dependencies([assert_op]): 494 | full_position_embeddings = tf.get_variable( 495 | name=position_embedding_name, 496 | shape=[max_position_embeddings, width], 497 | initializer=create_initializer(initializer_range)) 498 | # Since the position embedding table is a learned variable, we create it 499 | # using a (long) sequence length `max_position_embeddings`. The actual 500 | # sequence length might be shorter than this, for faster training of 501 | # tasks that do not have long sequences. 502 | # 503 | # So `full_position_embeddings` is effectively an embedding table 504 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 505 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 506 | # perform a slice. 507 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 508 | [seq_length, -1]) 509 | num_dims = len(output.shape.as_list()) 510 | 511 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 512 | # we broadcast among the first dimensions, which is typically just 513 | # the batch size. 514 | position_broadcast_shape = [] 515 | for _ in range(num_dims - 2): 516 | position_broadcast_shape.append(1) 517 | position_broadcast_shape.extend([seq_length, width]) 518 | position_embeddings = tf.reshape(position_embeddings, 519 | position_broadcast_shape) 520 | output += position_embeddings 521 | 522 | output = layer_norm_and_dropout(output, dropout_prob) 523 | return output 524 | 525 | 526 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 527 | """Create 3D attention mask from a 2D tensor mask. 528 | 529 | Args: 530 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 531 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 532 | 533 | Returns: 534 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 535 | """ 536 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 537 | batch_size = from_shape[0] 538 | from_seq_length = from_shape[1] 539 | 540 | to_shape = get_shape_list(to_mask, expected_rank=2) 541 | to_seq_length = to_shape[1] 542 | 543 | to_mask = tf.cast( 544 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 545 | 546 | # We don't assume that `from_tensor` is a mask (although it could be). We 547 | # don't actually care if we attend *from* padding tokens (only *to* padding 548 | # tokens), so we create a tensor of all ones. 549 | # 550 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 551 | broadcast_ones = tf.ones( 552 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 553 | 554 | # Here we broadcast along two dimensions to create the mask.
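  # Illustrative example: for a single sequence with to_mask = [1, 1, 0],
  # `broadcast_ones` ([1, F, 1]) times `to_mask` ([1, 1, 3]) gives a [1, F, 3]
  # mask in which every row is [1, 1, 0], i.e. no query position may attend to
  # the padding position, regardless of where it attends from.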
555 | mask = broadcast_ones * to_mask 556 | 557 | return mask 558 | 559 | 560 | def attention_layer(from_tensor, 561 | to_tensor, 562 | attention_mask=None, 563 | num_attention_heads=1, 564 | size_per_head=512, 565 | query_act=None, 566 | key_act=None, 567 | value_act=None, 568 | attention_probs_dropout_prob=0.0, 569 | initializer_range=0.02, 570 | do_return_2d_tensor=False, 571 | batch_size=None, 572 | from_seq_length=None, 573 | to_seq_length=None): 574 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 575 | 576 | This is an implementation of multi-headed attention based on "Attention 577 | Is All You Need". If `from_tensor` and `to_tensor` are the same, then 578 | this is self-attention. Each timestep in `from_tensor` attends to the 579 | corresponding sequence in `to_tensor`, and returns a fixed-width vector. 580 | 581 | This function first projects `from_tensor` into a "query" tensor and 582 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 583 | of tensors of length `num_attention_heads`, where each tensor is of shape 584 | [batch_size, seq_length, size_per_head]. 585 | 586 | Then, the query and key tensors are dot-producted and scaled. These are 587 | softmaxed to obtain attention probabilities. The value tensors are then 588 | interpolated by these probabilities, then concatenated back to a single 589 | tensor and returned. 590 | 591 | In practice, the multi-headed attention is done with transposes and 592 | reshapes rather than actual separate tensors. 593 | 594 | Args: 595 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 596 | from_width]. 597 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 598 | attention_mask: (optional) int32 Tensor of shape [batch_size, 599 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 600 | attention scores will effectively be set to -infinity for any positions in 601 | the mask that are 0, and will be unchanged for positions that are 1. 602 | num_attention_heads: int. Number of attention heads. 603 | size_per_head: int. Size of each attention head. 604 | query_act: (optional) Activation function for the query transform. 605 | key_act: (optional) Activation function for the key transform. 606 | value_act: (optional) Activation function for the value transform. 607 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 608 | attention probabilities. 609 | initializer_range: float. Range of the weight initializer. 610 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 611 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 612 | output will be of shape [batch_size, from_seq_length, num_attention_heads 613 | * size_per_head]. 614 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 615 | of the 3D version of the `from_tensor` and `to_tensor`. 616 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 617 | of the 3D version of the `from_tensor`. 618 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 619 | of the 3D version of the `to_tensor`. 620 | 621 | Returns: 622 | float Tensor of shape [batch_size, from_seq_length, 623 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 624 | true, this will be of shape [batch_size * from_seq_length, 625 | num_attention_heads * size_per_head]). 626 | 627 | Raises: 628 | ValueError: Any of the arguments or tensor shapes are invalid.
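  Shape example (illustrative): with batch_size=2, from_seq_length=to_seq_length=128,
  num_attention_heads=12 and size_per_head=64, the result has shape [2, 128, 768],
  or [256, 768] when `do_return_2d_tensor` is True.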
629 | """ 630 | 631 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 632 | seq_length, width): 633 | output_tensor = tf.reshape( 634 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 635 | 636 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 637 | return output_tensor 638 | 639 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 640 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 641 | 642 | if len(from_shape) != len(to_shape): 643 | raise ValueError( 644 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 645 | 646 | if len(from_shape) == 3: 647 | batch_size = from_shape[0] 648 | from_seq_length = from_shape[1] 649 | to_seq_length = to_shape[1] 650 | elif len(from_shape) == 2: 651 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 652 | raise ValueError( 653 | "When passing in rank 2 tensors to attention_layer, the values " 654 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 655 | "must all be specified.") 656 | 657 | # Scalar dimensions referenced here: 658 | # B = batch size (number of sequences) 659 | # F = `from_tensor` sequence length 660 | # T = `to_tensor` sequence length 661 | # N = `num_attention_heads` 662 | # H = `size_per_head` 663 | 664 | from_tensor_2d = reshape_to_matrix(from_tensor) 665 | to_tensor_2d = reshape_to_matrix(to_tensor) 666 | 667 | # `query_layer` = [B*F, N*H] 668 | query_layer = tf.layers.dense( 669 | from_tensor_2d, 670 | num_attention_heads * size_per_head, 671 | activation=query_act, 672 | name="query", 673 | kernel_initializer=create_initializer(initializer_range)) 674 | 675 | # `key_layer` = [B*T, N*H] 676 | key_layer = tf.layers.dense( 677 | to_tensor_2d, 678 | num_attention_heads * size_per_head, 679 | activation=key_act, 680 | name="key", 681 | kernel_initializer=create_initializer(initializer_range)) 682 | 683 | # `value_layer` = [B*T, N*H] 684 | value_layer = tf.layers.dense( 685 | to_tensor_2d, 686 | num_attention_heads * size_per_head, 687 | activation=value_act, 688 | name="value", 689 | kernel_initializer=create_initializer(initializer_range)) 690 | 691 | # `query_layer` = [B, N, F, H] 692 | query_layer = transpose_for_scores(query_layer, batch_size, 693 | num_attention_heads, from_seq_length, 694 | size_per_head) 695 | 696 | # `key_layer` = [B, N, T, H] 697 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 698 | to_seq_length, size_per_head) 699 | 700 | # Take the dot product between "query" and "key" to get the raw 701 | # attention scores. 702 | # `attention_scores` = [B, N, F, T] 703 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 704 | attention_scores = tf.multiply(attention_scores, 705 | 1.0 / math.sqrt(float(size_per_head))) 706 | 707 | if attention_mask is not None: 708 | # `attention_mask` = [B, 1, F, T] 709 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 710 | 711 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 712 | # masked positions, this operation will create a tensor which is 0.0 for 713 | # positions we want to attend and -10000.0 for masked positions. 714 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 715 | 716 | # Since we are adding it to the raw scores before the softmax, this is 717 | # effectively the same as removing these entirely. 718 | attention_scores += adder 719 | 720 | # Normalize the attention scores to probabilities. 
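  # The softmax below is taken over the last (`to_seq_length`) axis, so every
  # query position receives a probability distribution over the key positions.
  # Masked positions, which had roughly -10000 added to their scores above,
  # end up with probabilities that are effectively zero.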
721 | # `attention_probs` = [B, N, F, T] 722 | attention_probs = tf.nn.softmax(attention_scores) 723 | 724 | # This is actually dropping out entire tokens to attend to, which might 725 | # seem a bit unusual, but is taken from the original Transformer paper. 726 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 727 | 728 | # `value_layer` = [B, T, N, H] 729 | value_layer = tf.reshape( 730 | value_layer, 731 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 732 | 733 | # `value_layer` = [B, N, T, H] 734 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 735 | 736 | # `context_layer` = [B, N, F, H] 737 | context_layer = tf.matmul(attention_probs, value_layer) 738 | 739 | # `context_layer` = [B, F, N, H] 740 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 741 | 742 | if do_return_2d_tensor: 743 | # `context_layer` = [B*F, N*H] 744 | context_layer = tf.reshape( 745 | context_layer, 746 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 747 | else: 748 | # `context_layer` = [B, F, N*H] 749 | context_layer = tf.reshape( 750 | context_layer, 751 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 752 | 753 | return context_layer 754 | 755 | 756 | def transformer_model(input_tensor, 757 | attention_mask=None, 758 | hidden_size=768, 759 | num_hidden_layers=12, 760 | num_attention_heads=12, 761 | intermediate_size=3072, 762 | intermediate_act_fn=gelu, 763 | hidden_dropout_prob=0.1, 764 | attention_probs_dropout_prob=0.1, 765 | initializer_range=0.02, 766 | do_return_all_layers=False): 767 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 768 | 769 | This is almost an exact implementation of the original Transformer encoder. 770 | 771 | See the original paper: 772 | https://arxiv.org/abs/1706.03762 773 | 774 | Also see: 775 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 776 | 777 | Args: 778 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 779 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 780 | seq_length], with 1 for positions that can be attended to and 0 in 781 | positions that should not be. 782 | hidden_size: int. Hidden size of the Transformer. 783 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 784 | num_attention_heads: int. Number of attention heads in the Transformer. 785 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 786 | forward) layer. 787 | intermediate_act_fn: function. The non-linear activation function to apply 788 | to the output of the intermediate/feed-forward layer. 789 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 790 | attention_probs_dropout_prob: float. Dropout probability of the attention 791 | probabilities. 792 | initializer_range: float. Range of the initializer (stddev of truncated 793 | normal). 794 | do_return_all_layers: Whether to also return all layers or just the final 795 | layer. 796 | 797 | Returns: 798 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 799 | hidden layer of the Transformer. 800 | 801 | Raises: 802 | ValueError: A Tensor shape or parameter is invalid. 
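  Shape example (illustrative): with the defaults above (hidden_size=768,
  num_hidden_layers=12), an input of shape [batch_size, seq_length, 768]
  yields an output of the same shape, or a list of 12 such tensors when
  `do_return_all_layers` is True.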
803 | """ 804 | if hidden_size % num_attention_heads != 0: 805 | raise ValueError( 806 | "The hidden size (%d) is not a multiple of the number of attention " 807 | "heads (%d)" % (hidden_size, num_attention_heads)) 808 | 809 | attention_head_size = int(hidden_size / num_attention_heads) 810 | input_shape = get_shape_list(input_tensor, expected_rank=3) 811 | batch_size = input_shape[0] 812 | seq_length = input_shape[1] 813 | input_width = input_shape[2] 814 | 815 | # The Transformer performs sum residuals on all layers so the input needs 816 | # to be the same as the hidden size. 817 | if input_width != hidden_size: 818 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 819 | (input_width, hidden_size)) 820 | 821 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 822 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 823 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 824 | # help the optimizer. 825 | prev_output = reshape_to_matrix(input_tensor) 826 | 827 | all_layer_outputs = [] 828 | for layer_idx in range(num_hidden_layers): 829 | with tf.variable_scope("layer_%d" % layer_idx): 830 | layer_input = prev_output 831 | 832 | with tf.variable_scope("attention"): 833 | attention_heads = [] 834 | with tf.variable_scope("self"): 835 | attention_head = attention_layer( 836 | from_tensor=layer_input, 837 | to_tensor=layer_input, 838 | attention_mask=attention_mask, 839 | num_attention_heads=num_attention_heads, 840 | size_per_head=attention_head_size, 841 | attention_probs_dropout_prob=attention_probs_dropout_prob, 842 | initializer_range=initializer_range, 843 | do_return_2d_tensor=True, 844 | batch_size=batch_size, 845 | from_seq_length=seq_length, 846 | to_seq_length=seq_length) 847 | attention_heads.append(attention_head) 848 | 849 | attention_output = None 850 | if len(attention_heads) == 1: 851 | attention_output = attention_heads[0] 852 | else: 853 | # In the case where we have other sequences, we just concatenate 854 | # them to the self-attention head before the projection. 855 | attention_output = tf.concat(attention_heads, axis=-1) 856 | 857 | # Run a linear projection of `hidden_size` then add a residual 858 | # with `layer_input`. 859 | with tf.variable_scope("output"): 860 | attention_output = tf.layers.dense( 861 | attention_output, 862 | hidden_size, 863 | kernel_initializer=create_initializer(initializer_range)) 864 | attention_output = dropout(attention_output, hidden_dropout_prob) 865 | attention_output = layer_norm(attention_output + layer_input) 866 | 867 | # The activation is only applied to the "intermediate" hidden layer. 868 | with tf.variable_scope("intermediate"): 869 | intermediate_output = tf.layers.dense( 870 | attention_output, 871 | intermediate_size, 872 | activation=intermediate_act_fn, 873 | kernel_initializer=create_initializer(initializer_range)) 874 | 875 | # Down-project back to `hidden_size` then add the residual. 
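      # To summarise the block: self-attention -> dense projection + dropout +
      # residual layer norm -> intermediate dense with `intermediate_act_fn` ->
      # the dense/dropout/layer-norm step below. This is the standard
      # post-layer-norm Transformer encoder layer, repeated `num_hidden_layers`
      # times.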
876 | with tf.variable_scope("output"): 877 | layer_output = tf.layers.dense( 878 | intermediate_output, 879 | hidden_size, 880 | kernel_initializer=create_initializer(initializer_range)) 881 | layer_output = dropout(layer_output, hidden_dropout_prob) 882 | layer_output = layer_norm(layer_output + attention_output) 883 | prev_output = layer_output 884 | all_layer_outputs.append(layer_output) 885 | 886 | if do_return_all_layers: 887 | final_outputs = [] 888 | for layer_output in all_layer_outputs: 889 | final_output = reshape_from_matrix(layer_output, input_shape) 890 | final_outputs.append(final_output) 891 | return final_outputs 892 | else: 893 | final_output = reshape_from_matrix(prev_output, input_shape) 894 | return final_output 895 | 896 | 897 | def get_shape_list(tensor, expected_rank=None, name=None): 898 | """Returns a list of the shape of tensor, preferring static dimensions. 899 | 900 | Args: 901 | tensor: A tf.Tensor object to find the shape of. 902 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 903 | specified and the `tensor` has a different rank, an exception will be 904 | thrown. 905 | name: Optional name of the tensor for the error message. 906 | 907 | Returns: 908 | A list of dimensions of the shape of tensor. All static dimensions will 909 | be returned as Python integers, and dynamic dimensions will be returned 910 | as tf.Tensor scalars. 911 | """ 912 | if name is None: 913 | name = tensor.name 914 | 915 | if expected_rank is not None: 916 | assert_rank(tensor, expected_rank, name) 917 | 918 | shape = tensor.shape.as_list() 919 | 920 | non_static_indexes = [] 921 | for (index, dim) in enumerate(shape): 922 | if dim is None: 923 | non_static_indexes.append(index) 924 | 925 | if not non_static_indexes: 926 | return shape 927 | 928 | dyn_shape = tf.shape(tensor) 929 | for index in non_static_indexes: 930 | shape[index] = dyn_shape[index] 931 | return shape 932 | 933 | 934 | def reshape_to_matrix(input_tensor): 935 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 936 | ndims = input_tensor.shape.ndims 937 | if ndims < 2: 938 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 939 | (input_tensor.shape)) 940 | if ndims == 2: 941 | return input_tensor 942 | 943 | width = input_tensor.shape[-1] 944 | output_tensor = tf.reshape(input_tensor, [-1, width]) 945 | return output_tensor 946 | 947 | 948 | def reshape_from_matrix(output_tensor, orig_shape_list): 949 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 950 | if len(orig_shape_list) == 2: 951 | return output_tensor 952 | 953 | output_shape = get_shape_list(output_tensor) 954 | 955 | orig_dims = orig_shape_list[0:-1] 956 | width = output_shape[-1] 957 | 958 | return tf.reshape(output_tensor, orig_dims + [width]) 959 | 960 | 961 | def assert_rank(tensor, expected_rank, name=None): 962 | """Raises an exception if the tensor rank is not of the expected rank. 963 | 964 | Args: 965 | tensor: A tf.Tensor to check the rank of. 966 | expected_rank: Python integer or list of integers, expected rank. 967 | name: Optional name of the tensor for the error message. 968 | 969 | Raises: 970 | ValueError: If the expected shape doesn't match the actual shape.
971 | """ 972 | if name is None: 973 | name = tensor.name 974 | 975 | expected_rank_dict = {} 976 | if isinstance(expected_rank, six.integer_types): 977 | expected_rank_dict[expected_rank] = True 978 | else: 979 | for x in expected_rank: 980 | expected_rank_dict[x] = True 981 | 982 | actual_rank = tensor.shape.ndims 983 | if actual_rank not in expected_rank_dict: 984 | scope_name = tf.get_variable_scope().name 985 | raise ValueError( 986 | "For the tensor `%s` in scope `%s`, the actual rank " 987 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 988 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 989 | --------------------------------------------------------------------------------