├── imgs
│   ├── api_example.jpg
│   ├── dataset_desc.jpg
│   ├── deploy-success.jpg
│   ├── serving_offline.jpg
│   └── serving_deploy_arcticture.jpg
├── requirements.txt
├── model_convert.sh
├── bert_classify_server.sh
├── cmrc
│   ├── run_cmrc.sh
│   ├── cmrc_eval.py
│   └── cmrc_tool
│       ├── tokenization.py
│       ├── run_squad_inf.py
│       └── run_squad_inf_cmrc.py
├── train.sh
├── export.sh
├── train_roberta_tiny_clue_command_gpu.sh
├── api
│   ├── api_service.py
│   └── api_service_flask.py
├── .gitignore
├── optimization.py
├── test_serving.py
├── freeze_graph.py
├── README.md
├── run_savedModel_infer.py
├── test_tf_serving.py
├── run_pb_inference.py
├── tokenization.py
└── modeling.py
/imgs/api_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/api_example.jpg
--------------------------------------------------------------------------------
/imgs/dataset_desc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/dataset_desc.jpg
--------------------------------------------------------------------------------
/imgs/deploy-success.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/deploy-success.jpg
--------------------------------------------------------------------------------
/imgs/serving_offline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/serving_offline.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow >= 1.11.0 # CPU Version of TensorFlow.
2 | # tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow.
3 |
--------------------------------------------------------------------------------
/imgs/serving_deploy_arcticture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dongxiaohuang/TextClassifier_Transformer/HEAD/imgs/serving_deploy_arcticture.jpg
--------------------------------------------------------------------------------
/model_convert.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #description: Convert model from ckpt format to pb format
3 | # If a label2id.pkl file exists in the model directory ($TRAINED_CLASSIFIER), the num_labels argument can be omitted here
4 |
5 | export BERT_BASE_DIR=./chinese_roberta_zh_l12
6 | export TRAINED_CLASSIFIER=./output
7 | export EXP_NAME=mobile_0_roberta_base_epoch1
8 |
9 | python freeze_graph.py \
10 | -bert_model_dir $BERT_BASE_DIR \
11 | -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
12 | -max_seq_len 128
--------------------------------------------------------------------------------
/bert_classify_server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #chkconfig: 2345 80 90
3 | #description: start the BERT classification server
4 |
5 | echo 'start BERT classify server...'
6 | rm -rf tmp*
7 |
8 | export BERT_BASE_DIR=./chinese_roberta_zh_l12
9 | export TRAINED_CLASSIFIER=./output
10 | export EXP_NAME=mobile_0_roberta_base_epoch1
11 |
12 | bert-base-serving-start \
13 | -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
14 | -bert_model_dir $BERT_BASE_DIR \
15 | -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
16 | -mode CLASS \
17 | -max_seq_len 128 \
18 | -port 5575 \
19 | -port_out 5576 \
20 | -device_map 1
21 |
--------------------------------------------------------------------------------
/cmrc/run_cmrc.sh:
--------------------------------------------------------------------------------
1 | export SQUAD_DIR=./CMRC_DIR
2 | export BERT_BASE_DIR=/export/huangdongxiao/huangdongxiao/AI_QA/models/RoBERTa-tiny-clue
3 | python run_cmrc.py \
4 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
5 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
6 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
7 | --do_train=False \
8 | --do_export=True \
9 | --do_predict=True \
10 | --train_file=$SQUAD_DIR/train.json \
11 | --predict_file=$SQUAD_DIR/dev.json \
12 | --train_batch_size=16 \
13 | --learning_rate=3e-5 \
14 | --num_train_epochs=8.0 \
15 | --max_seq_length=384 \
16 | --doc_stride=128 \
17 | --output_dir=./output/cmrc \
18 | --version_2_with_negative=False
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #description: BERT fine-tuning
3 |
4 | export BERT_BASE_DIR=./chinese_roberta_zh_l12
5 | export DATA_DIR=./dat
6 | export TRAINED_CLASSIFIER=./output
7 | export MODEL_NAME=mobile_0_roberta_base_epoch1
8 |
9 | python run_classifier_serving.py \
10 | --task_name=setiment \
11 | --do_train=true \
12 | --do_eval=true \
13 | --do_predict=False \
14 | --data_dir=$DATA_DIR \
15 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
16 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
17 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
18 | --max_seq_length=128 \
19 | --train_batch_size=32 \
20 | --learning_rate=2e-5 \
21 | --num_train_epochs=1.0 \
22 | --output_dir=$TRAINED_CLASSIFIER/$MODEL_NAME
--------------------------------------------------------------------------------
/export.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #description: BERT fine-tuning
3 |
4 | export BERT_BASE_DIR=./chinese_roberta_zh_l12
5 | export DATA_DIR=./dat
6 | export TRAINED_CLASSIFIER=./output
7 | export MODEL_NAME=mobile_0_roberta_base
8 |
9 | python run_classifier_serving.py \
10 | --task_name=setiment \
11 | --do_train=False \
12 | --do_eval=False \
13 | --do_predict=True \
14 | --data_dir=$DATA_DIR \
15 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
16 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
17 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
18 | --max_seq_length=128 \
19 | --train_batch_size=32 \
20 | --learning_rate=2e-5 \
21 | --num_train_epochs=1.0 \
22 | --output_dir=$TRAINED_CLASSIFIER/$MODEL_NAME \
23 | --do_export=True \
24 | --export_dir=exported
--------------------------------------------------------------------------------
/train_roberta_tiny_clue_command_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #description: BERT fine-tuning
3 |
4 | export BERT_BASE_DIR=/export/huangdongxiao/huangdongxiao/AI_QA/models/RoBERTa-tiny-clue
5 | export DATA_DIR=/export/huangdongxiao/huangdongxiao/AI_QA/ALBERT/1-data/command
6 | export TRAINED_CLASSIFIER=./output
7 | export MODEL_NAME=roberta_tiny_clue_command_gpu
8 |
9 | export CUDA_VISIBLE_DEVICES=2
10 | python run_classifier_serving_gpu.py \
11 | --task_name=command \
12 | --do_train=false \
13 | --do_eval=false \
14 | --do_predict=true \
15 | --do_export=false \
16 | --do_frozen=true \
17 | --data_dir=$DATA_DIR \
18 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
19 | --test_file=test \
20 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
21 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
22 | --max_seq_length=128 \
23 | --train_batch_size=32 \
24 | --learning_rate=1e-4 \
25 | --num_train_epochs=6.0 \
26 | --output_dir=$TRAINED_CLASSIFIER/$MODEL_NAME
27 |
--------------------------------------------------------------------------------
/api/api_service.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name:     api_service
4 | Description : API client that queries the prediction server and returns labels
5 | Author : 逸轩
6 | date: 2019/10/12
7 |
8 | """
9 |
10 | import json
11 | import re
12 | import time
13 | from bert_base.client import BertClient
14 |
15 | bc = BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False, mode='CLASS')
16 | print('BertClient connected')
17 |
18 | # split the input text into sentences
19 | def cut_sent(txt):
20 | # preprocess: collapse runs of spaces/tabs
21 | txt = re.sub('([ \t]+)', r" ", txt) # blank word
22 | txt = txt.rstrip() # drop any trailing newlines at the end of the paragraph
23 | nlist = txt.split(";")
24 | nnlist = [x for x in nlist if x.strip() != ''] # filter out empty lines
25 | return nnlist
26 |
27 |
28 | # run prediction on a list of sentences
29 | def class_pred(list_text):
30 | # split the text into sentences
31 | # list_text = cut_sent(text)
32 | print("total setance: %d" % (len(list_text)))
33 | # with BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False,
34 | # check_length=False, mode='CLASS') as bc:
35 | start_t = time.perf_counter()
36 | rst = bc.encode(list_text)
37 | # print('result:', rst)
38 | print('time used:{}s'.format(time.perf_counter() - start_t))
39 | # the returned structure is:
40 | # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}]
41 | # extract the predicted labels
42 | pred_label = rst[0]["pred_label"]
43 | result_txt = [[pred_label[i], list_text[i]] for i in range(len(pred_label))]
44 | return result_txt
45 |
46 | if __name__ == '__main__':
47 | while True:
48 | text = input('Enter sentences (separate multiple sentences with ";"): ')
49 | list_text = cut_sent(text)
50 | # print(list_text)
51 | result = class_pred(list_text)
52 | print(result)
53 |
--------------------------------------------------------------------------------
/api/api_service_flask.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: api_service_flask
4 | Description : calls the prediction server and exposes a query API via Flask
5 | Author : 逸轩
6 | date: 2019/10/12
7 |
8 | """
9 |
10 | import json
11 | import re
12 | import time
13 | from bert_base.client import BertClient
14 | from flask import Flask, request
15 | from flask_cors import CORS
16 |
17 | flaskAPP = Flask(import_name=__name__)
18 | CORS(flaskAPP, supports_credentials=True)
19 |
20 | bc = BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False, mode='CLASS')
21 | print('BertClient connected')
22 |
23 | # split the input text into sentences
24 | def cut_sent(txt):
25 | # preprocess: collapse runs of spaces/tabs
26 | txt = re.sub('([ \t]+)', r" ", txt) # blank word
27 | txt = txt.rstrip() # drop any trailing newlines at the end of the paragraph
28 | nlist = txt.split("\n")
29 | nnlist = [x for x in nlist if x.strip() != ''] # filter out empty lines
30 | return nnlist
31 |
32 |
33 | # run prediction on a list of sentences
34 | def class_pred(list_text):
35 | # split the text into sentences
36 | # list_text = cut_sent(text)
37 | print("total setance: %d" % (len(list_text)))
38 | # with BertClient(ip='192.168.9.23', port=5575, port_out=5576, show_server_config=False, check_version=False,
39 | # check_length=False, mode='CLASS') as bc:
40 | start_t = time.perf_counter()
41 | rst = bc.encode(list_text)
42 | print('result:', rst)
43 | print('time used:{}'.format(time.perf_counter() - start_t))
44 | # the returned structure is:
45 | # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}]
46 | # extract the predicted labels
47 | pred_label = rst[0]["pred_label"]
48 | result_txt = [[pred_label[i], list_text[i]] for i in range(len(pred_label))]
49 | return result_txt
50 |
51 | @flaskAPP.route('/predict_online', methods=['GET', 'POST'])
52 | def predict_online():
53 | text = request.args.get('text')
54 | print('Fields received by the server:')
55 | print('text:', text)
56 | print('==============================')
57 | lstseg = cut_sent(text)
58 | print('-' * 30)
59 | print('Result, %d sentences in total:' % (len(lstseg)))
60 | # for x in lstseg:
61 | # print("第%d句:【 %s】" % (lstseg.index(x), x))
62 | print('-' * 30)
63 | # if request.method == 'POST' or 1:
64 | # res['result'] = class_pred(lstseg)
65 | result = class_pred(lstseg)
66 | new_res_list = []
67 | for term in result:
68 | if term[0] == '1':
69 | label = '好评'
70 | elif term[0] == '0':
71 | label = '中评'
72 | elif term[0] == '-1':
73 | label = '差评'
74 | new_res_list.append([label, term[1]])
75 | new_res = {'result': new_res_list}
76 | # print('result:%s' % str(res))
77 | print('result:%s' % str(new_res))
78 | # return jsonify(res)
79 | return json.dumps(new_res, ensure_ascii=False)
80 |
81 | flaskAPP.run(host='192.168.9.23', port=8910, debug=True)
82 |
83 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __MACOSX/
3 | *.tar.gz
4 | settings.json
5 | .vscode/
6 | *jd*
7 | *JD*
8 | *.data
9 | checkpoint
10 | *.index
11 | *.pickle
12 | *.pickle
13 | *.pkl
14 | *backup
15 | *.out
16 | *.csv
17 | *.meta
18 | *.data-*
19 | .idea
20 | *.pyc
21 | # corpus
22 | # resources
23 | .DS_Store
24 | *.zip
25 | *.txt
26 | __pycache__/
27 | *.py[cod]
28 | *$py.class
29 |
30 | # C extensions
31 | *.so
32 |
33 | # Distribution / packaging
34 | .Python
35 | build/
36 | develop-eggs/
37 | dist/
38 | downloads/
39 | eggs/
40 | .eggs/
41 | lib/
42 | lib64/
43 | parts/
44 | sdist/
45 | var/
46 | wheels/
47 | *.egg-info/
48 | .installed.cfg
49 | *.egg
50 | MANIFEST
51 |
52 | # PyInstaller
53 | # Usually these files are written by a python script from a template
54 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
55 | *.manifest
56 | *.spec
57 |
58 | # Installer logs
59 | pip-log.txt
60 | pip-delete-this-directory.txt
61 |
62 | # Unit test / coverage reports
63 | htmlcov/
64 | .tox/
65 | .coverage
66 | .coverage.*
67 | .cache
68 | nosetests.xml
69 | coverage.xml
70 | *.cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 |
83 | # Flask stuff:
84 | instance/
85 | .webassets-cache
86 |
87 | # Scrapy stuff:
88 | .scrapy
89 |
90 | # Sphinx documentation
91 | docs/_build/
92 |
93 | # PyBuilder
94 | target/
95 |
96 | # Jupyter Notebook
97 | .ipynb_checkpoints
98 |
99 | # pyenv
100 | .python-version
101 |
102 | # celery beat schedule file
103 | celerybeat-schedule
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 |
130 | # project related
131 | # Byte-compiled / optimized / DLL files
132 | __MACOSX/
133 | *.tar.gz
134 | settings.json
135 | .vscode/
136 | *jd*
137 | *JD*
138 | *.data
139 | checkpoint
140 | *.index
141 | *.pickle
142 | *.pickle
143 | *.pkl
144 | *backup
145 | *.out
146 | *.csv
147 | *.meta
148 | *.data-*
149 | .idea
150 | *.pyc
151 | # corpus
152 | # resources
153 | .DS_Store
154 | *.zip
155 | *.txt
156 | __pycache__/
157 | *.py[cod]
158 | *$py.class
159 |
160 | # C extensions
161 | *.so
162 |
163 | # Distribution / packaging
164 | .Python
165 | build/
166 | develop-eggs/
167 | dist/
168 | downloads/
169 | eggs/
170 | .eggs/
171 | lib/
172 | lib64/
173 | parts/
174 | sdist/
175 | var/
176 | wheels/
177 | *.egg-info/
178 | .installed.cfg
179 | *.egg
180 | MANIFEST
181 |
182 | # PyInstaller
183 | # Usually these files are written by a python script from a template
184 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
185 | *.manifest
186 | *.spec
187 |
188 | # Installer logs
189 | pip-log.txt
190 | pip-delete-this-directory.txt
191 |
192 | # Unit test / coverage reports
193 | htmlcov/
194 | .tox/
195 | .coverage
196 | .coverage.*
197 | .cache
198 | nosetests.xml
199 | coverage.xml
200 | *.cover
201 | .hypothesis/
202 | .pytest_cache/
203 |
204 | # Translations
205 | *.mo
206 | *.pot
207 |
208 | # Django stuff:
209 | *.log
210 | local_settings.py
211 | db.sqlite3
212 |
213 | # Flask stuff:
214 | instance/
215 | .webassets-cache
216 |
217 | # Scrapy stuff:
218 | .scrapy
219 |
220 | # Sphinx documentation
221 | docs/_build/
222 |
223 | # PyBuilder
224 | target/
225 |
226 | # Jupyter Notebook
227 | .ipynb_checkpoints
228 |
229 | # pyenv
230 | .python-version
231 |
232 | # celery beat schedule file
233 | celerybeat-schedule
234 |
235 | # SageMath parsed files
236 | *.sage.py
237 |
238 | # Environments
239 | .env
240 | .venv
241 | env/
242 | venv/
243 | ENV/
244 | env.bak/
245 | venv.bak/
246 |
247 | # Spyder project settings
248 | .spyderproject
249 | .spyproject
250 |
251 | # Rope project settings
252 | .ropeproject
253 |
254 | # mkdocs documentation
255 | /site
256 |
257 | # mypy
258 | .mypy_cache/
259 |
260 | # project related
261 | cmrc_1588284341
262 | 1*/
263 |
--------------------------------------------------------------------------------
/cmrc/cmrc_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Evaluation script for CMRC 2018
4 | version: v5 - special
5 | Note:
6 | v5 - special: Evaluate on SQuAD-style CMRC 2018 Datasets
7 | v5: formatted output, add usage description
8 | v4: fixed segmentation issues
9 | '''
10 | from __future__ import print_function
11 | from collections import Counter, OrderedDict
12 | import string
13 | import re
14 | import argparse
15 | import json
16 | import sys
17 | if sys.version[0] == '2':
18 | reload(sys)
19 | sys.setdefaultencoding("utf-8")
20 | import nltk
21 | nltk.download('punkt')
22 | import pdb
23 |
24 | # segment mixed Chinese and English text
25 | def mixed_segmentation(in_str, rm_punc=False):
26 | in_str = str(in_str).lower().strip()
27 | segs_out = []
28 | temp_str = ""
29 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
30 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
31 | '「','」','(',')','-','~','『','』']
32 | for char in in_str:
33 | if rm_punc and char in sp_char:
34 | continue
35 | if re.search(u'[\u4e00-\u9fa5]', char) or char in sp_char: # match any Chinese character
36 | if temp_str != "":
37 | ss = nltk.word_tokenize(temp_str)
38 | segs_out.extend(ss)
39 | temp_str = ""
40 | segs_out.append(char)
41 | else:
42 | temp_str += char
43 |
44 | #handling last part
45 | if temp_str != "":
46 | ss = nltk.word_tokenize(temp_str)
47 | segs_out.extend(ss)
48 |
49 | return segs_out
50 |
51 |
52 | # remove punctuation
53 | def remove_punctuation(in_str):
54 | in_str = str(in_str).lower().strip()
55 | sp_char = ['-',':','_','*','^','/','\\','~','`','+','=',
56 | ',','。',':','?','!','“','”',';','’','《','》','……','·','、',
57 | '「','」','(',')','-','~','『','』']
58 | out_segs = []
59 | for char in in_str:
60 | if char in sp_char:
61 | continue
62 | else:
63 | out_segs.append(char)
64 | return ''.join(out_segs)
65 |
66 |
67 | # find the longest common substring
68 | def find_lcs(s1, s2):
69 | m = [[0 for i in range(len(s2)+1)] for j in range(len(s1)+1)]
70 | mmax = 0
71 | p = 0
72 | for i in range(len(s1)):
73 | for j in range(len(s2)):
74 | if s1[i] == s2[j]:
75 | m[i+1][j+1] = m[i][j]+1
76 | if m[i+1][j+1] > mmax:
77 | mmax=m[i+1][j+1]
78 | p=i+1
79 | return s1[p-mmax:p], mmax
80 |
81 | #
82 | def evaluate(ground_truth_file, prediction_file):
83 | f1 = 0
84 | em = 0
85 | total_count = 0
86 | skip_count = 0
87 | for instance in ground_truth_file["data"]:
88 | #context_id = instance['context_id'].strip()
89 | #context_text = instance['context_text'].strip()
90 | for para in instance["paragraphs"]:
91 | for qas in para['qas']:
92 | total_count += 1
93 | query_id = qas['id'].strip()
94 | query_text = qas['question'].strip()
95 | answers = [x["text"] for x in qas['answers']]
96 |
97 | if query_id not in prediction_file:
98 | sys.stderr.write('Unanswered question: {}\n'.format(query_id))
99 | skip_count += 1
100 | continue
101 |
102 | prediction = str(prediction_file[query_id])
103 | f1 += calc_f1_score(answers, prediction)
104 | em += calc_em_score(answers, prediction)
105 |
106 | f1_score = 100.0 * f1 / total_count
107 | em_score = 100.0 * em / total_count
108 | return f1_score, em_score, total_count, skip_count
109 |
110 |
111 | def calc_f1_score(answers, prediction):
112 | f1_scores = []
113 | for ans in answers:
114 | ans_segs = mixed_segmentation(ans, rm_punc=True)
115 | prediction_segs = mixed_segmentation(prediction, rm_punc=True)
116 | lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
117 | if lcs_len == 0:
118 | f1_scores.append(0)
119 | continue
120 | precision = 1.0*lcs_len/len(prediction_segs)
121 | recall = 1.0*lcs_len/len(ans_segs)
122 | f1 = (2*precision*recall)/(precision+recall)
123 | f1_scores.append(f1)
124 | return max(f1_scores)
125 |
126 |
127 | def calc_em_score(answers, prediction):
128 | em = 0
129 | for ans in answers:
130 | ans_ = remove_punctuation(ans)
131 | prediction_ = remove_punctuation(prediction)
132 | if ans_ == prediction_:
133 | em = 1
134 | break
135 | return em
136 |
137 | if __name__ == '__main__':
138 | parser = argparse.ArgumentParser(description='Evaluation Script for CMRC 2018')
139 | parser.add_argument('--dataset_file', type=str, default="/Users/huangdongxiao2/CodeRepos/SesameSt/bert/output/cmrc/dev.json" , help='Official dataset file')
140 | parser.add_argument('--prediction_file', type=str, default="/Users/huangdongxiao2/CodeRepos/SesameSt/bert/output/cmrc/predictions.json" ,help='Your prediction File')
141 | args = parser.parse_args()
142 | ground_truth_file = json.load(open(args.dataset_file, 'rb'))
143 | prediction_file = json.load(open(args.prediction_file, 'rb'))
144 | F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
145 | AVG = (EM+F1)*0.5
146 | output_result = OrderedDict()
147 | output_result['AVERAGE'] = '%.3f' % AVG
148 | output_result['F1'] = '%.3f' % F1
149 | output_result['EM'] = '%.3f' % EM
150 | output_result['TOTAL'] = TOTAL
151 | output_result['SKIP'] = SKIP
152 | output_result['FILE'] = args.prediction_file
153 | print(json.dumps(output_result))
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
98 | """Constructs a AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 | # Instead we want to decay the weights in a manner that doesn't interact
144 | # with the m/v parameters. This is equivalent to adding the square
145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
--------------------------------------------------------------------------------
/test_serving.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: test_serving
4 | Description :
5 | Author : 逸轩
6 | date: 2019/10/12
7 |
8 | """
9 |
10 | import tensorflow as tf
11 | import tokenization
12 |
13 | class InputExample(object):
14 | """A single training/test example for simple sequence classification."""
15 |
16 | def __init__(self, guid, text_a, text_b=None, label=None):
17 | """Constructs a InputExample.
18 |
19 | Args:
20 | guid: Unique id for the example.
21 | text_a: string. The untokenized text of the first sequence. For single
22 | sequence tasks, only this sequence must be specified.
23 | text_b: (Optional) string. The untokenized text of the second sequence.
24 | Only must be specified for sequence pair tasks.
25 | label: (Optional) string. The label of the example. This should be
26 | specified for train and dev examples, but not for test examples.
27 | """
28 | self.guid = guid
29 | self.text_a = text_a
30 | self.text_b = text_b
31 | self.label = label
32 |
33 | class PaddingInputExample(object):
34 | """Fake example so the num input examples is a multiple of the batch size.
35 |
36 | When running eval/predict on the TPU, we need to pad the number of examples
37 | to be a multiple of the batch size, because the TPU requires a fixed batch
38 | size. The alternative is to drop the last batch, which is bad because it means
39 | the entire output data won't be generated.
40 |
41 | We use this class instead of `None` because treating `None` as padding
42 | batches could cause silent errors.
43 | """
44 |
45 |
46 | class InputFeatures(object):
47 | """A single set of features of data."""
48 |
49 | def __init__(self,
50 | input_ids,
51 | input_mask,
52 | segment_ids,
53 | label_id,
54 | is_real_example=True):
55 | self.input_ids = input_ids
56 | self.input_mask = input_mask
57 | self.segment_ids = segment_ids
58 | self.label_id = label_id
59 | self.is_real_example = is_real_example
60 |
61 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
62 | """Truncates a sequence pair in place to the maximum length."""
63 |
64 | # This is a simple heuristic which will always truncate the longer sequence
65 | # one token at a time. This makes more sense than truncating an equal percent
66 | # of tokens from each, since if one sequence is very short then each token
67 | # that's truncated likely contains more information than a longer sequence.
68 | while True:
69 | total_length = len(tokens_a) + len(tokens_b)
70 | if total_length <= max_length:
71 | break
72 | if len(tokens_a) > len(tokens_b):
73 | tokens_a.pop()
74 | else:
75 | tokens_b.pop()
76 |
77 |
78 | def convert_single_example(ex_index, example, label_list, max_seq_length,
79 | tokenizer):
80 | """Converts a single `InputExample` into a single `InputFeatures`."""
81 |
82 | if isinstance(example, PaddingInputExample):
83 | return InputFeatures(
84 | input_ids=[0] * max_seq_length,
85 | input_mask=[0] * max_seq_length,
86 | segment_ids=[0] * max_seq_length,
87 | label_id=0,
88 | is_real_example=False)
89 |
90 | label_map = {}
91 | for (i, label) in enumerate(label_list):
92 | label_map[label] = i
93 |
94 | tokens_a = tokenizer.tokenize(example.text_a)
95 | tokens_b = None
96 | if example.text_b:
97 | tokens_b = tokenizer.tokenize(example.text_b)
98 |
99 | if tokens_b:
100 | # Modifies `tokens_a` and `tokens_b` in place so that the total
101 | # length is less than the specified length.
102 | # Account for [CLS], [SEP], [SEP] with "- 3"
103 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
104 | else:
105 | # Account for [CLS] and [SEP] with "- 2"
106 | if len(tokens_a) > max_seq_length - 2:
107 | tokens_a = tokens_a[0:(max_seq_length - 2)]
108 |
109 | # The convention in BERT is:
110 | # (a) For sequence pairs:
111 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
112 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
113 | # (b) For single sequences:
114 | # tokens: [CLS] the dog is hairy . [SEP]
115 | # type_ids: 0 0 0 0 0 0 0
116 | #
117 | # Where "type_ids" are used to indicate whether this is the first
118 | # sequence or the second sequence. The embedding vectors for `type=0` and
119 | # `type=1` were learned during pre-training and are added to the wordpiece
120 | # embedding vector (and position vector). This is not *strictly* necessary
121 | # since the [SEP] token unambiguously separates the sequences, but it makes
122 | # it easier for the model to learn the concept of sequences.
123 | #
124 | # For classification tasks, the first vector (corresponding to [CLS]) is
125 | # used as the "sentence vector". Note that this only makes sense because
126 | # the entire model is fine-tuned.
127 | tokens = []
128 | segment_ids = []
129 | tokens.append("[CLS]")
130 | segment_ids.append(0)
131 | for token in tokens_a:
132 | tokens.append(token)
133 | segment_ids.append(0)
134 | tokens.append("[SEP]")
135 | segment_ids.append(0)
136 |
137 | if tokens_b:
138 | for token in tokens_b:
139 | tokens.append(token)
140 | segment_ids.append(1)
141 | tokens.append("[SEP]")
142 | segment_ids.append(1)
143 |
144 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
145 |
146 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
147 | # tokens are attended to.
148 | input_mask = [1] * len(input_ids)
149 |
150 | # Zero-pad up to the sequence length.
151 | while len(input_ids) < max_seq_length:
152 | input_ids.append(0)
153 | input_mask.append(0)
154 | segment_ids.append(0)
155 |
156 | assert len(input_ids) == max_seq_length
157 | assert len(input_mask) == max_seq_length
158 | assert len(segment_ids) == max_seq_length
159 |
160 | # debug xmxoxo 2019/3/13
161 | # print(ex_index,example.text_a)
162 |
163 | label_id = label_map[example.label]
164 | if ex_index < 5:
165 | tf.logging.info("*** Example ***")
166 | tf.logging.info("guid: %s" % (example.guid))
167 | tf.logging.info("tokens: %s" % " ".join(
168 | [tokenization.printable_text(x) for x in tokens]))
169 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
170 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
171 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
172 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
173 |
174 | feature = InputFeatures(
175 | input_ids=input_ids,
176 | input_mask=input_mask,
177 | segment_ids=segment_ids,
178 | label_id=label_id,
179 | is_real_example=True)
180 | return feature
181 |
182 | if __name__ == '__main__':
183 | predict_fn = tf.contrib.predictor.from_saved_model('exported/1571054350')
184 | label_list = ["-1", "0", "1"]
185 | max_seq_length = 128
186 | tokenizer = tokenization.FullTokenizer(vocab_file='chinese_roberta_zh_l12/vocab.txt', do_lower_case=True)
187 | print('Model loaded! Listening >>>')
188 | while True:
189 | question = input("> ")
190 | predict_example = InputExample("id", question, None, '0')
191 | feature = convert_single_example(100, predict_example, label_list,
192 | max_seq_length, tokenizer)
193 |
194 | prediction = predict_fn({
195 | "input_ids": [feature.input_ids],
196 | "input_mask": [feature.input_mask],
197 | "segment_ids": [feature.segment_ids],
198 | "label_ids": [feature.label_id],
199 | })
200 | probabilities = prediction["probabilities"]
201 | label = label_list[probabilities.argmax()]
202 | print(label)
--------------------------------------------------------------------------------
/freeze_graph.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name:     freeze_graph
4 | Description : tool to convert a BERT model checkpoint (ckpt) into a frozen pb graph
5 | Author : stephen
6 | date: 2019/10/12
7 |
8 | """
9 |
10 | import os
11 | from termcolor import colored
12 | import modeling
13 | import logging
14 | import tensorflow as tf
15 | import argparse
16 | import pickle
17 |
18 |
19 | def set_logger(context, verbose=False):
20 | if os.name == 'nt': # for Windows
21 | return NTLogger(context, verbose)
22 |
23 | logger = logging.getLogger(context)
24 | logger.setLevel(logging.DEBUG if verbose else logging.INFO)
25 | formatter = logging.Formatter(
26 | '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s', datefmt=
27 | '%m-%d %H:%M:%S')
28 | console_handler = logging.StreamHandler()
29 | console_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
30 | console_handler.setFormatter(formatter)
31 | logger.handlers = []
32 | logger.addHandler(console_handler)
33 | return logger
34 |
35 |
36 | class NTLogger:
37 | def __init__(self, context, verbose):
38 | self.context = context
39 | self.verbose = verbose
40 |
41 | def info(self, msg, **kwargs):
42 | print('I:%s:%s' % (self.context, msg), flush=True)
43 |
44 | def debug(self, msg, **kwargs):
45 | if self.verbose:
46 | print('D:%s:%s' % (self.context, msg), flush=True)
47 |
48 | def error(self, msg, **kwargs):
49 | print('E:%s:%s' % (self.context, msg), flush=True)
50 |
51 | def warning(self, msg, **kwargs):
52 | print('W:%s:%s' % (self.context, msg), flush=True)
53 |
54 | def create_classification_model(bert_config, is_training, input_ids, input_mask, segment_ids, labels, num_labels):
55 | """
56 |
57 | :param bert_config:
58 | :param is_training:
59 | :param input_ids:
60 | :param input_mask:
61 | :param segment_ids:
62 | :param labels:
63 | :param num_labels:
64 | :param use_one_hot_embedding:
65 | :return:
66 | """
67 |
68 | #import tensorflow as tf
69 | #import modeling
70 |
71 | # build the BERT representation from the input tensors
72 | model = modeling.BertModel(
73 | config=bert_config,
74 | is_training=is_training,
75 | input_ids=input_ids,
76 | input_mask=input_mask,
77 | token_type_ids=segment_ids,
78 | )
79 |
80 | embedding_layer = model.get_sequence_output()
81 | output_layer = model.get_pooled_output()
82 | hidden_size = output_layer.shape[-1].value
83 |
84 | # predict = CNN_Classification(embedding_chars=embedding_layer,
85 | # labels=labels,
86 | # num_tags=num_labels,
87 | # sequence_length=FLAGS.max_seq_length,
88 | # embedding_dims=embedding_layer.shape[-1].value,
89 | # vocab_size=0,
90 | # filter_sizes=[3, 4, 5],
91 | # num_filters=3,
92 | # dropout_keep_prob=FLAGS.dropout_keep_prob,
93 | # l2_reg_lambda=0.001)
94 | # loss, predictions, probabilities = predict.add_cnn_layer()
95 |
96 | output_weights = tf.get_variable(
97 | "output_weights", [num_labels, hidden_size],
98 | initializer=tf.truncated_normal_initializer(stddev=0.02))
99 |
100 | output_bias = tf.get_variable(
101 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
102 |
103 | with tf.variable_scope("loss"):
104 | if is_training:
105 | # I.e., 0.1 dropout
106 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
107 |
108 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
109 | logits = tf.nn.bias_add(logits, output_bias)
110 | probabilities = tf.nn.softmax(logits, axis=-1)
111 | log_probs = tf.nn.log_softmax(logits, axis=-1)
112 |
113 | if labels is not None:
114 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
115 |
116 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
117 | loss = tf.reduce_mean(per_example_loss)
118 | else:
119 | loss, per_example_loss = None, None
120 | return (loss, per_example_loss, logits, probabilities)
121 |
122 |
123 | def init_predict_var(path):
124 | num_labels = 2
125 | label2id = None
126 | id2label = None
127 | label2id_file = os.path.join(path, 'label2id.pkl')
128 | if os.path.exists(label2id_file):
129 | with open(label2id_file, 'rb') as rf:
130 | label2id = pickle.load(rf)
131 | id2label = {value: key for key, value in label2id.items()}
132 | num_labels = len(label2id.items())
133 | print('num_labels:%d' % num_labels)
134 | else:
135 | print('Can\'t find %s' % label2id_file)
136 | return num_labels, label2id, id2label
137 |
138 |
139 |
140 | def optimize_class_model(args, logger=None):
141 | """
142 | Load the Chinese classification model and convert it to a pb file
143 | :param args:
144 | :param num_labels:
145 | :param logger:
146 | :return:
147 | """
148 |
149 | if not logger:
150 | logger = set_logger(colored('CLASSIFICATION_MODEL, Loading...', 'cyan'), args.verbose)
151 | pass
152 | try:
153 | # If the pb file already exists, return its path; otherwise convert the model to pb and return the path where it is stored
154 | if args.model_pb_dir is None:
155 | tmp_dir = args.model_dir
156 | else:
157 | tmp_dir = args.model_pb_dir
158 |
159 | pb_file = os.path.join(tmp_dir, 'classification_model.pb')
160 | if os.path.exists(pb_file):
161 | print('pb_file exists', pb_file)
162 | return pb_file
163 |
164 | # read num_labels from label2id.pkl so the num_labels argument can be omitted; 2019/4/17
165 | num_labels = args.num_labels
166 | if not num_labels:
167 | num_labels, label2id, id2label = init_predict_var(tmp_dir)
168 |
169 | graph = tf.Graph()
170 | with graph.as_default():
171 | with tf.Session() as sess:
172 | input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
173 | input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')
174 |
175 | bert_config = modeling.BertConfig.from_json_file(os.path.join(args.bert_model_dir, 'bert_config.json'))
176 |
177 | loss, per_example_loss, logits, probabilities = create_classification_model(bert_config=bert_config, is_training=False,
178 | input_ids=input_ids, input_mask=input_mask, segment_ids=None, labels=None, num_labels=num_labels)
179 |
180 | # pred_ids = tf.argmax(probabilities, axis=-1, output_type=tf.int32, name='pred_ids')
181 | # pred_ids = tf.identity(pred_ids, 'pred_ids')
182 |
183 | probabilities = tf.identity(probabilities, 'pred_prob')
184 | saver = tf.train.Saver()
185 |
186 | with tf.Session() as sess:
187 | sess.run(tf.global_variables_initializer())
188 | latest_checkpoint = tf.train.latest_checkpoint(args.model_dir)
189 | logger.info('loading... %s ' % latest_checkpoint )
190 | saver.restore(sess,latest_checkpoint )
191 | logger.info('freeze...')
192 | from tensorflow.python.framework import graph_util
193 | tmp_g = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['pred_prob'])
194 | logger.info('predict cut finished !!!')
195 |
196 | # write the serialized binary graph to a file
197 | logger.info('write graph to a tmp file: %s' % pb_file)
198 | with tf.gfile.GFile(pb_file, 'wb') as f:
199 | f.write(tmp_g.SerializeToString())
200 | return pb_file
201 | except Exception as e:
202 | logger.error('fail to optimize the graph! %s' % e, exc_info=True)
203 |
204 |
205 |
206 |
207 | if __name__ == '__main__':
208 | # pass
209 |
210 | """
211 | bert_model_dir="/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12"
212 | model_dir="/mnt/sda1/transdat/bert-demo/bert/output/demo7"
213 | model_pb_dir=model_dir
214 | max_seq_len=128
215 | num_labels=2
216 | """
217 |
218 |
219 | parser = argparse.ArgumentParser(description='Trans ckpt file to .pb file')
220 | parser.add_argument('-bert_model_dir', type=str, required=True,
221 | help='chinese google bert model path')
222 | parser.add_argument('-model_dir', type=str, required=True,
223 | help='directory of a pretrained BERT model')
224 | parser.add_argument('-model_pb_dir', type=str, default=None,
225 | help='directory of a pretrained BERT model,default = model_dir')
226 | parser.add_argument('-max_seq_len', type=int, default=128,
227 | help='maximum length of a sequence,default:128')
228 | parser.add_argument('-num_labels', type=int, default=None,
229 | help='length of all labels,default=2')
230 | parser.add_argument('-verbose', action='store_true', default=False,
231 | help='turn on tensorflow logging for debug')
232 |
233 | args = parser.parse_args()
234 |
235 | optimize_class_model(args, logger=None)
236 |
237 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TextClassifier_Transformer
2 | A text classifier built on top of Google's open-source BERT (fine-tuning based). It can freely load well-known pre-trained language models such as BERT,
3 | RoBERTa, ALBERT and their wwm variants, and it is also compatible with ERNIE 1.0.
4 | The project supports two prediction modes:
5 | (1) offline real-time prediction
6 | (2) server-side real-time prediction
7 |
8 | ## Changelog
9 | 2020-03-25:
10 | (1) Renamed the project from 'TextClassifier_BERT' to 'TextClassifier_Transformer';
11 | (2) Added two pre-trained models: ELECTRA and ALBERT.
12 | **Note: when using ALBERT, replace this project's modeling.py with the modeling.py from the ALBERT project before running.**
13 | 2020-03-04:
14 | Added a tf-serving deployment mechanism; see [This Blog](https://Vincent131499.github.io/2020/02/28/以BERT分类为例阐述模型部署关键技术) for the implementation details
15 |
16 | ## Requirements
17 | * Python3.6+
18 | * Tensorflow1.10+/Tensorflow-gpu1.10+
19 |
20 | Download links for well-known pre-trained language models (Baidu's open-source ERNIE model has already been converted to TF format):
21 | Bert-base: https://pan.baidu.com/s/18h9zgvnlU5ztwaQNnzBXTg extraction code: 9r1z
22 | Roberta: https://pan.baidu.com/s/1FBpok7U9ekYJRu1a8NSM-Q extraction code: i50r
23 | Bert-wwm: https://pan.baidu.com/s/1lhoJCT_LkboC1_1YXk1ItQ extraction code: ejt7
24 | ERNIE1.0: https://pan.baidu.com/s/1S6MI8rQyQ4U7dLszyb73Yw extraction code: gc6f
25 | ELECTRA-Tiny: https://pan.baidu.com/s/11QaL7A4YSCYq4YlGyU1_vA extraction code: 27jb
26 | AlBert-base: https://pan.baidu.com/s/1U7Zx73ngci2Oqp3SLaVOaw extraction code: uijw
27 |
28 |
29 | ## Project overview
30 | There are two main run modes:
31 | Mode 1: offline real-time prediction
32 | step 1: data preparation
33 | step 2: model training
34 | step 3: model export
35 | step 4: offline real-time prediction
36 | Mode 2: server-side real-time prediction
37 | step 1: data preparation
38 | step 2: model training
39 | step 3: model conversion
40 | step 4: service deployment
41 | step 5: application client
42 | ### Notes
43 | 1. If you only want to go through the flow from model training to local offline prediction, simply follow Mode 1 step by step.
44 | 2. If you want to go through the whole pipeline from model training to model deployment, follow Mode 2 step by step.
45 | The two modes are described in detail below.
46 | ## Mode 1: offline real-time prediction
47 | ### Step 1: data preparation
48 | To try the project out quickly, a small dataset of mobile-phone reviews is used. The data is simple and has three classes: -1 (差评, negative), 0 (中评, neutral), 1 (好评, positive). A sample is shown below:
49 | 
50 | PS: in this project the data has already been split into three files: train.tsv, dev.tsv and test.tsv
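For reference, here is a minimal sketch of peeking at one of these files, assuming the two-column layout that SetimentProcessor._create_examples() expects (label in column 0, review text in column 1, with a header row); the file path is illustrative:
```Python
import csv

# Minimal sketch: print the first few rows of the training split.
# Assumes the column layout used by SetimentProcessor._create_examples():
# column 0 = label (-1/0/1), column 1 = review text, first row = header.
with open("dat/train.tsv", encoding="utf-8") as f:
    reader = csv.reader(f, delimiter="\t")
    for i, row in enumerate(reader):
        if i == 0:  # skip the header row
            continue
        label, text = row[0], row[1]
        print(label, text)
        if i >= 5:
            break
```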
51 | ### Step 2: model training
52 | Training command:
53 | ```Bash
54 | bash train.sh
55 | ```
56 | Description of the train.sh parameters:
57 | ```Bash
58 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 # path to the pre-trained language model
59 | export DATA_DIR=./dat # path to the dataset
60 | export TRAINED_CLASSIFIER=./output # output path for the trained model
61 | export MODEL_NAME=mobile_0_roberta_base # name of the trained model
62 | ```
63 | Details: the model is trained by fine-tuning BERT directly; the corresponding script is run_classifier_serving.py. Fine-tuning BERT for training is widely covered online,
64 | so it is not repeated here. The main work is to create a task-specific Processor, namely SetimentProcessor, and customize its _create_examples() and get_labels() functions, as shown below:
65 | ```Python
66 | class SetimentProcessor(DataProcessor):
67 | def get_train_examples(self, data_dir):
68 | """See base class."""
69 | return self._create_examples(
70 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
71 |
72 | def get_dev_examples(self, data_dir):
73 | """See base class."""
74 | return self._create_examples(
75 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
76 |
77 | def get_test_examples(self, data_dir):
78 | """See base class."""
79 | return self._create_examples(
80 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
81 |
82 | def get_labels(self):
83 | """See base class."""
84 |
85 | """
86 | if not os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
87 | with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as fd:
88 | pickle.dump(self.labels, fd)
89 | """
90 | return ["-1", "0", "1"]
91 |
92 | def _create_examples(self, lines, set_type):
93 | """Creates examples for the training and dev sets."""
94 | examples = []
95 | for (i, line) in enumerate(lines):
96 | if i == 0:
97 | continue
98 | guid = "%s-%s" % (set_type, i)
99 |
100 | #debug (by xmxoxo)
101 | #print("read line: No.%d" % i)
102 |
103 | text_a = tokenization.convert_to_unicode(line[1])
104 | if set_type == "test":
105 | label = "0"
106 | else:
107 | label = tokenization.convert_to_unicode(line[0])
108 | examples.append(
109 | InputExample(guid=guid, text_a=text_a, label=label))
110 | return examples
111 | ```
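For the --task_name=setiment flag used in train.sh to resolve to this class, the processor also needs to be registered in the task-name dictionary inside main(); a minimal sketch, assuming run_classifier_serving.py keeps the processors dict structure of Google's original run_classifier.py:
```Python
# Minimal sketch: map the lower-cased task name to the custom processor
# (an assumed excerpt of main() in run_classifier_serving.py).
processors = {
    "setiment": SetimentProcessor,
}
```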
112 | Note that one special change is made here: a snippet that saves the labels is added to convert_single_example(), so that a label2id.pkl file is generated in the model output directory during training. The code is shown below:
113 | ```Python
114 | #--- save label2id.pkl ---
115 | # write label2id.pkl here, add by stephen 2019-10-12
116 | output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
117 | if not os.path.exists(output_label2id_file):
118 | with open(output_label2id_file,'wb') as w:
119 | pickle.dump(label_map,w)
120 | #--- Add end ---
121 | ```
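At conversion or inference time this pickle can be read back to rebuild the id-to-label mapping (freeze_graph.py does the same thing in init_predict_var()); a minimal sketch, with an illustrative output directory:
```Python
import os
import pickle

# Minimal sketch: rebuild id2label from the label2id.pkl written during training.
output_dir = "./output/mobile_0_roberta_base"  # illustrative; use your output_dir
with open(os.path.join(output_dir, "label2id.pkl"), "rb") as f:
    label2id = pickle.load(f)
id2label = {idx: label for label, idx in label2id.items()}
print(id2label)  # e.g. {0: '-1', 1: '0', 2: '1'}
```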
122 | ### Step 3: model export
123 | Run the following command:
124 | ```Bash
125 | bash export.sh
126 | ```
127 | Description of the export.sh parameters:
128 | ```Bash
129 | # the following four parameters should match the values set in train.sh
130 | export BERT_BASE_DIR=./chinese_roberta_zh_l12
131 | export DATA_DIR=./dat
132 | export TRAINED_CLASSIFIER=./output
133 | export MODEL_NAME=mobile_0_roberta_base
134 | ```
135 | This generates a model directory named with a timestamp under the specified exported directory.
136 | Details: run_classifier.py is mainly designed for one-off runs. Setting the do_predict flag to True does allow prediction, but the input samples are file-based and the model cannot be kept in memory for serving, so some code changes are needed to achieve two goals:
137 | (1) allow the model to be loaded into memory, i.e. load once, call many times;
138 | (2) allow prediction on samples that do not come from files, e.g. samples read from standard input.
139 | * Loading the model into memory
140 | Line 859 of run_classifier.py loads the model into the estimator variable, but unfortunately an estimator does not natively support loading once and predicting many times. See: https://guillaumegenthial.github.io/serving-tensorflow-estimator.html.
141 | Therefore estimator.export_saved_model() is used to re-export the estimator as a tf.saved_model.
142 | The code references https://github.com/bigboNed3/bert_serving; a serving_input_fn() function is defined in run_classifier_serving as follows:
143 | ```Python
144 | def serving_input_fn():
145 | label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
146 | input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
147 | input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
148 | segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
149 | input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
150 | 'label_ids': label_ids,
151 | 'input_ids': input_ids,
152 | 'input_mask': input_mask,
153 | 'segment_ids': segment_ids,
154 | })
155 | return input_fn
156 | ```
157 | Then the do_export option is handled in run_classifier_serving:
158 | ```Python
159 | if do_export:
160 | estimator._export_to_tpu = False
161 | estimator.export_savedmodel(FLAGS.export_dir, serving_input_fn)
162 | ```
163 | ### Step 4: offline real-time prediction
164 | Run test_serving.py to perform offline real-time prediction.
165 | Example output:
166 | 
167 | Details: once the model is exported, the estimator object from line 859 is no longer needed; the model can be loaded directly from the export directory just created, as follows:
168 | ```Python
169 | predict_fn = tf.contrib.predictor.from_saved_model('/exported/1571054350')
170 | ```
171 | With the predict_fn variable above, predictions can be made directly. Below is sample code that reads samples from standard input and predicts their class:
172 | ```Python
173 | while True:
174 | question = input("> ")
175 | predict_example = InputExample("id", question, None, '某固定伪标记')
176 | feature = convert_single_example(100, predict_example, label_list,
177 | FLAGS.max_seq_length, tokenizer)
178 |
179 | prediction = predict_fn({
180 | "input_ids":[feature.input_ids],
181 | "input_mask":[feature.input_mask],
182 | "segment_ids":[feature.segment_ids],
183 | "label_ids":[feature.label_id],
184 | })
185 | probabilities = prediction["probabilities"]
186 | label = label_list[probabilities.argmax()]
187 | print(label)
188 | ```
189 | ## Mode 2: server-side real-time prediction
190 | First, the basic architecture of this mode:
191 | 
192 | Architecture description:
193 | BERT model server: loads the model and serves real-time predictions; it uses the bert-base package provided by BERT-BiLSTM-CRF-NER;
194 | API server: calls the real-time prediction service and exposes an API for applications; written with Flask;
195 | Application client: the final consumer; for simplicity no web page was written here, the API is called directly.
196 | ### Step 1: data preparation
197 | Same as Step 1 of Mode 1.
198 | ### Step 2: model training
199 | Same as Step 2 of Mode 1.
200 | ### Step 3: model conversion
201 | Run the following command:
202 | ```Bash
203 | bash model_convert.sh
204 | ```
205 | This generates a pb-format model file under $TRAINED_CLASSIFIER/$EXP_NAME.
206 | Description of the model_convert.sh parameters:
207 | ```Bash
208 | export BERT_BASE_DIR=./chinese_roberta_zh_l12 # path to the pre-trained language model used during training
209 | export TRAINED_CLASSIFIER=./output # output path of the trained model
210 | export EXP_NAME=mobile_0_roberta_base # name of the saved model
211 |
212 | python freeze_graph.py \
213 | -bert_model_dir $BERT_BASE_DIR \
214 | -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
215 | -max_seq_len 128 # note: max_seq_len here should match the max_seq_length value set in train.sh
216 | ```
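The repository's run_pb_inference.py is the script intended for querying this frozen graph; the snippet below is only a minimal sketch of the idea, assuming TensorFlow 1.x and the tensor names created in freeze_graph.py (input_ids, input_mask, pred_prob). The pb path and the zero-filled inputs are illustrative; real inputs would come from convert_single_example():
```Python
import numpy as np
import tensorflow as tf

# Minimal sketch (TF 1.x): load the frozen classification_model.pb produced by
# freeze_graph.py and evaluate its softmax output tensor.
pb_file = "./output/mobile_0_roberta_base/classification_model.pb"  # illustrative
graph_def = tf.GraphDef()
with tf.gfile.GFile(pb_file, "rb") as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
    input_ids = graph.get_tensor_by_name("input_ids:0")
    input_mask = graph.get_tensor_by_name("input_mask:0")
    pred_prob = graph.get_tensor_by_name("pred_prob:0")
    with tf.Session(graph=graph) as sess:
        # Zero arrays keep the sketch self-contained; in practice these come
        # from convert_single_example() with max_seq_len = 128.
        ids = np.zeros((1, 128), dtype=np.int32)
        mask = np.zeros((1, 128), dtype=np.int32)
        print(sess.run(pred_prob, feed_dict={input_ids: ids, input_mask: mask}))
```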
217 | ### Step 4: model deployment
218 | Run the following command:
219 | ```Bash
220 | bash bert_classify_server.sh
221 | ```
222 | Tip: make sure the bert-base package is installed before running this command; install it with:
223 | ```Bash
224 | pip install bert-base
225 | ```
226 | **Note**:
227 | The port and port_out parameters are the port numbers used for API calls; the defaults are 5555 and 5556. If you plan to deploy several model service instances, be sure to set your own port numbers to avoid conflicts.
228 | Here they are changed to 5575 and 5576.
229 | If it fails to start with errors, some modules referenced in bert_base/server/http.py may be missing; installing them fixes it:
230 | ```
231 | sudo pip install flask
232 | sudo pip install flask_compress
233 | sudo pip install flask_cors
234 | sudo pip install flask_json
235 | ```
236 | My setup has two GTX 1080 Ti cards, which is where dual GPUs finally pay off: GPU 1 serves predictions while GPU 0 can keep training models.
237 | A successful deployment looks like this:
238 | 
239 | ### Step 5: application client
240 | Run the following command:
241 | ```Bash
242 | python api/api_service_flask.py
243 | ```
244 | The deployed server can then be accessed through the API endpoint (in this project, http://192.168.9.23:8910/predict_online?text=我好开心).
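For example, a minimal sketch of calling this endpoint programmatically, assuming the requests package and the host/port hard-coded in api/api_service_flask.py:
```Python
import requests

# Minimal sketch: query the Flask API started by api/api_service_flask.py.
resp = requests.get(
    "http://192.168.9.23:8910/predict_online",
    params={"text": "我好开心"},
)
print(resp.text)  # e.g. {"result": [["好评", "我好开心"]]}
```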
245 | Requesting it from a browser:
246 | 
247 |
248 | ## TODO:
249 | - [ ] fp16 support
250 | - [ ] add LAMB optimizer
251 | - [x] train data shuffle
252 | - [x] do_froze
--------------------------------------------------------------------------------
/run_savedModel_infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import tokenization
5 | import tensorflow as tf
6 | import numpy as np
7 | import time
8 |
9 | parser = argparse.ArgumentParser(
10 | description='Single-sentence or batch inference test for a BERT SavedModel; enter q to exit')
11 |
12 | parser.add_argument('--model', type=str,
13 | default='/Users/huangdongxiao2/CodeRepos/SesameSt/albert_zh/inference/robert_tiny_clue', help='the path for the model')
14 | parser.add_argument('--vocab_file', type=str,
15 | default='/Users/huangdongxiao2/CodeRepos/SesameSt/albert_zh/inference/robert_tiny_clue/vocab.txt')
16 | # default='../ALBERT/albert_base/vocab_chinese.txt')
17 | parser.add_argument('--labels', type=list, default=[
18 | 'happy', 'anger', 'lost', 'fear', 'sad', 'other', 'anxiety'], help='label list')
19 | parser.add_argument('--max_seq_length', type=int, default=128,
20 | help='the length of sequence for text padding')
21 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids',
22 | help='the input_ids feature name for saved model')
23 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask',
24 | help='the input_mask feature name for saved model')
25 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids',
26 | help='the segment_ids feature name for saved model')
27 | parser.add_argument('--MODE', type=str, default='SINGLE',
28 | help='SINGLE prediction or BATCH prediction')
29 | args_in_use = parser.parse_args()
30 |
31 |
32 | class InputExample(object):
33 | """A single training/test example for simple sequence classification."""
34 |
35 | def __init__(self, guid, text_a, text_b=None, label=None):
36 | """Constructs a InputExample.
37 | Args:
38 | guid: Unique id for the example.
39 | text_a: string. The untokenized text of the first sequence. For single
40 | sequence tasks, only this sequence must be specified.
41 | text_b: (Optional) string. The untokenized text of the second sequence.
42 | Only must be specified for sequence pair tasks.
43 | label: (Optional) string. The label of the example. This should be
44 | specified for train and dev examples, but not for test examples.
45 | """
46 | self.guid = guid
47 | self.text_a = text_a
48 | self.text_b = text_b
49 |
50 |
51 | class InputFeatures(object):
52 | """A single set of features of data."""
53 |
54 | def __init__(self,
55 | input_ids,
56 | input_mask,
57 | segment_ids,
58 | is_real_example=True):
59 | self.input_ids = input_ids
60 | self.input_mask = input_mask
61 | self.segment_ids = segment_ids
62 | self.is_real_example = is_real_example
63 |
64 |
65 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
66 | """Truncates a sequence pair in place to the maximum length."""
67 |
68 | # This is a simple heuristic which will always truncate the longer sequence
69 | # one token at a time. This makes more sense than truncating an equal percent
70 | # of tokens from each, since if one sequence is very short then each token
71 | # that's truncated likely contains more information than a longer sequence.
72 | while True:
73 | total_length = len(tokens_a) + len(tokens_b)
74 | if total_length <= max_length:
75 | break
76 | if len(tokens_a) > len(tokens_b):
77 | tokens_a.pop()
78 | else:
79 | tokens_b.pop()
80 |
81 |
82 | def convert_single_example(example, label_list, max_seq_length,
83 | tokenizer):
84 | """Converts a single `InputExample` into a single `InputFeatures`."""
85 |
86 | label_map = {}
87 | for (i, label) in enumerate(label_list):
88 | label_map[label] = i
89 |
90 | tokens_a = tokenizer.tokenize(example.text_a)
91 | tokens_b = None
92 | if example.text_b:
93 | tokens_b = tokenizer.tokenize(example.text_b)
94 |
95 | if tokens_b:
96 | # Modifies `tokens_a` and `tokens_b` in place so that the total
97 | # length is less than the specified length.
98 | # Account for [CLS], [SEP], [SEP] with "- 3"
99 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
100 | else:
101 | # Account for [CLS] and [SEP] with "- 2"
102 | if len(tokens_a) > max_seq_length - 2:
103 | tokens_a = tokens_a[0:(max_seq_length - 2)]
104 |
105 | # The convention in BERT is:
106 | # (a) For sequence pairs:
107 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
108 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
109 | # (b) For single sequences:
110 | # tokens: [CLS] the dog is hairy . [SEP]
111 | # type_ids: 0 0 0 0 0 0 0
112 | #
113 | # Where "type_ids" are used to indicate whether this is the first
114 | # sequence or the second sequence. The embedding vectors for `type=0` and
115 | # `type=1` were learned during pre-training and are added to the wordpiece
116 | # embedding vector (and position vector). This is not *strictly* necessary
117 | # since the [SEP] token unambiguously separates the sequences, but it makes
118 | # it easier for the model to learn the concept of sequences.
119 | #
120 | # For classification tasks, the first vector (corresponding to [CLS]) is
121 | # used as the "sentence vector". Note that this only makes sense because
122 | # the entire model is fine-tuned.
123 | tokens = []
124 | segment_ids = []
125 | tokens.append("[CLS]")
126 | segment_ids.append(0)
127 | for token in tokens_a:
128 | tokens.append(token)
129 | segment_ids.append(0)
130 | tokens.append("[SEP]")
131 | segment_ids.append(0)
132 |
133 | if tokens_b:
134 | for token in tokens_b:
135 | tokens.append(token)
136 | segment_ids.append(1)
137 | tokens.append("[SEP]")
138 | segment_ids.append(1)
139 |
140 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
141 |
142 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
143 | # tokens are attended to.
144 | input_mask = [1] * len(input_ids)
145 |
146 | # Zero-pad up to the sequence length.
147 | while len(input_ids) < max_seq_length:
148 | input_ids.append(0)
149 | input_mask.append(0)
150 | segment_ids.append(0)
151 |
152 | assert len(input_ids) == max_seq_length
153 | assert len(input_mask) == max_seq_length
154 | assert len(segment_ids) == max_seq_length
155 |
156 | feature = InputFeatures(
157 | input_ids=input_ids,
158 | input_mask=input_mask,
159 | segment_ids=segment_ids,
160 | is_real_example=True)
161 | return feature
162 |
163 |
164 | if __name__ == '__main__':
165 | predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model)
166 | label_list = args_in_use.labels
167 | label_map = {i: label for i, label in enumerate(label_list)}
168 | max_seq_length = args_in_use.max_seq_length
169 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file,
170 | do_lower_case=True)
171 | if args_in_use.MODE == 'SINGLE':
172 | while True:
173 | question = input('(PRESS q to quit)\n> ')
174 | if question == 'q':
175 | break
176 | predict_example = InputExample('id', question, None)
177 | feature = convert_single_example(
178 | predict_example, label_list, max_seq_length, tokenizer)
179 | start_time = time.time()
180 | prediction = predict_fn({
181 | args_in_use.tensor_input_ids: [feature.input_ids],
182 | args_in_use.tensor_input_mask: [feature.input_mask],
183 | args_in_use.tensor_segment_ids: [feature.segment_ids],
184 | })
185 | print(f'elapsed time: {time.time()-start_time}s')
186 | print(prediction)
187 | probabilities = prediction["probabilities"]
188 | max_index = np.argmax(probabilities[0])
189 | print(probabilities, '\n')
190 | print(f'label: {label_map[max_index]}')
191 |
192 | elif args_in_use.MODE == 'BATCH':
193 | questions = [
194 | '我要投诉的',
195 | '我很不开心',
196 | '我好喜欢你'
197 | ]
198 | features = []
199 | for i, question in enumerate(questions):
200 | predict_example = InputExample(f'id{i}', question, None)
201 | feature = convert_single_example(
202 | predict_example, label_list, max_seq_length, tokenizer)
203 | features.append(feature)
204 | feed_dict = {
205 | args_in_use.tensor_input_ids: [feature.input_ids for feature in features],
206 | args_in_use.tensor_input_mask: [feature.input_mask for feature in features],
207 | args_in_use.tensor_segment_ids: [feature.segment_ids for feature in features],
208 | }
209 | predictions = predict_fn(feed_dict)
210 | probabilities = predictions["q"]
211 | print(probabilities, '\n')
212 | max_idxs = np.argmax(probabilities, 1)
213 | print(
214 | f'labels: {[label_map[max_index] for max_index in max_idxs]}')
215 | else:
216 | raise ValueError('unsupported mode')
217 |
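218 | # Illustrative invocation (the paths and label set above are placeholders for your own
219 | # SavedModel export; adjust as needed):
220 | #   python run_savedModel_infer.py --model ./inference/robert_tiny_clue \
221 | #       --vocab_file ./inference/robert_tiny_clue/vocab.txt --MODE BATCH
222 | # In BATCH mode predictions["probabilities"] is expected to have shape
223 | # [num_questions, num_labels], so np.argmax(..., 1) yields one label per question.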
--------------------------------------------------------------------------------
/test_tf_serving.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Description : deploy the model with TensorFlow Serving and call its REST API via requests
4 | Author : MeteorMan
5 | date: 2020/3/4
6 | """
7 |
8 | import json
9 | import numpy as np
10 | import requests
11 | import pickle
12 | import os
13 | import tensorflow as tf
14 | import tokenization
15 |
16 |
17 | class InputExample(object):
18 | """A single training/test example for simple sequence classification."""
19 |
20 | def __init__(self, guid, text_a, text_b=None, label=None):
21 | """Constructs a InputExample.
22 |
23 | Args:
24 | guid: Unique id for the example.
25 | text_a: string. The untokenized text of the first sequence. For single
26 | sequence tasks, only this sequence must be specified.
27 | text_b: (Optional) string. The untokenized text of the second sequence.
28 | Only must be specified for sequence pair tasks.
29 | label: (Optional) string. The label of the example. This should be
30 | specified for train and dev examples, but not for test examples.
31 | """
32 | self.guid = guid
33 | self.text_a = text_a
34 | self.text_b = text_b
35 | self.label = label
36 |
37 | class PaddingInputExample(object):
38 | """Fake example so the num input examples is a multiple of the batch size.
39 |
40 | When running eval/predict on the TPU, we need to pad the number of examples
41 | to be a multiple of the batch size, because the TPU requires a fixed batch
42 | size. The alternative is to drop the last batch, which is bad because it means
43 | the entire output data won't be generated.
44 |
45 | We use this class instead of `None` because treating `None` as padding
46 | batches could cause silent errors.
47 | """
48 |
49 |
50 | class InputFeatures(object):
51 | """A single set of features of data."""
52 |
53 | def __init__(self,
54 | input_ids,
55 | input_mask,
56 | segment_ids,
57 | label_id,
58 | is_real_example=True):
59 | self.input_ids = input_ids
60 | self.input_mask = input_mask
61 | self.segment_ids = segment_ids
62 | self.label_id = label_id
63 | self.is_real_example = is_real_example
64 |
65 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
66 | """Truncates a sequence pair in place to the maximum length."""
67 |
68 | # This is a simple heuristic which will always truncate the longer sequence
69 | # one token at a time. This makes more sense than truncating an equal percent
70 | # of tokens from each, since if one sequence is very short then each token
71 | # that's truncated likely contains more information than a longer sequence.
72 | while True:
73 | total_length = len(tokens_a) + len(tokens_b)
74 | if total_length <= max_length:
75 | break
76 | if len(tokens_a) > len(tokens_b):
77 | tokens_a.pop()
78 | else:
79 | tokens_b.pop()
80 |
81 |
82 | def convert_single_example(ex_index, example, label_list, max_seq_length,
83 | tokenizer):
84 | """Converts a single `InputExample` into a single `InputFeatures`."""
85 |
86 | if isinstance(example, PaddingInputExample):
87 | return InputFeatures(
88 | input_ids=[0] * max_seq_length,
89 | input_mask=[0] * max_seq_length,
90 | segment_ids=[0] * max_seq_length,
91 | label_id=0,
92 | is_real_example=False)
93 |
94 | label_map = {}
95 | for (i, label) in enumerate(label_list):
96 | label_map[label] = i
97 |
98 | tokens_a = tokenizer.tokenize(example.text_a)
99 | tokens_b = None
100 | if example.text_b:
101 | tokens_b = tokenizer.tokenize(example.text_b)
102 |
103 | if tokens_b:
104 | # Modifies `tokens_a` and `tokens_b` in place so that the total
105 | # length is less than the specified length.
106 | # Account for [CLS], [SEP], [SEP] with "- 3"
107 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
108 | else:
109 | # Account for [CLS] and [SEP] with "- 2"
110 | if len(tokens_a) > max_seq_length - 2:
111 | tokens_a = tokens_a[0:(max_seq_length - 2)]
112 |
113 | # The convention in BERT is:
114 | # (a) For sequence pairs:
115 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
116 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
117 | # (b) For single sequences:
118 | # tokens: [CLS] the dog is hairy . [SEP]
119 | # type_ids: 0 0 0 0 0 0 0
120 | #
121 | # Where "type_ids" are used to indicate whether this is the first
122 | # sequence or the second sequence. The embedding vectors for `type=0` and
123 | # `type=1` were learned during pre-training and are added to the wordpiece
124 | # embedding vector (and position vector). This is not *strictly* necessary
125 | # since the [SEP] token unambiguously separates the sequences, but it makes
126 | # it easier for the model to learn the concept of sequences.
127 | #
128 | # For classification tasks, the first vector (corresponding to [CLS]) is
129 | # used as the "sentence vector". Note that this only makes sense because
130 | # the entire model is fine-tuned.
131 | tokens = []
132 | segment_ids = []
133 | tokens.append("[CLS]")
134 | segment_ids.append(0)
135 | for token in tokens_a:
136 | tokens.append(token)
137 | segment_ids.append(0)
138 | tokens.append("[SEP]")
139 | segment_ids.append(0)
140 |
141 | if tokens_b:
142 | for token in tokens_b:
143 | tokens.append(token)
144 | segment_ids.append(1)
145 | tokens.append("[SEP]")
146 | segment_ids.append(1)
147 |
148 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
149 |
150 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
151 | # tokens are attended to.
152 | input_mask = [1] * len(input_ids)
153 |
154 | # Zero-pad up to the sequence length.
155 | while len(input_ids) < max_seq_length:
156 | input_ids.append(0)
157 | input_mask.append(0)
158 | segment_ids.append(0)
159 |
160 | assert len(input_ids) == max_seq_length
161 | assert len(input_mask) == max_seq_length
162 | assert len(segment_ids) == max_seq_length
163 |
164 | label_id = label_map[example.label]
165 | if ex_index < 5:
166 | tf.logging.info("*** Example ***")
167 | tf.logging.info("guid: %s" % (example.guid))
168 | tf.logging.info("tokens: %s" % " ".join(
169 | [tokenization.printable_text(x) for x in tokens]))
170 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
171 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
172 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
173 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
174 |
175 | feature = InputFeatures(
176 | input_ids=input_ids,
177 | input_mask=input_mask,
178 | segment_ids=segment_ids,
179 | label_id=label_id,
180 | is_real_example=True)
181 | return feature
182 |
183 | label_dict = pickle.load(open('./label2id.pkl', 'rb'))
184 | label_list = list(label_dict.keys())
185 | # id2label = get_suit_dict()
186 | max_seq_length = 64
187 | tokenizer = tokenization.FullTokenizer(vocab_file='./vocab.txt',
188 | do_lower_case=True)
189 |
190 | def predict_offline_single():
191 | while True:
192 | import time
193 | question = input("> ")
194 | start_time = time.time()
195 | predict_example = InputExample("100", question, None, '婚姻家庭')
196 | feature = convert_single_example(100, predict_example, label_list, max_seq_length, tokenizer)
197 | data = json.dumps({
198 | "instances": [
199 | {
200 | "input_ids": feature.input_ids,
201 | "input_mask": feature.input_mask,
202 | "segment_ids": feature.segment_ids,
203 | "label_ids": [feature.label_id]
204 | }
205 | ]
206 | })
207 | headers = {"content-type": "application/json"}
208 | # json_response = requests.post(
209 | # 'http://127.0.0.1:9001/v1/models/bert_domain_classify:predict',
210 | # data=data, headers=headers)
211 | json_response = requests.post(
212 | 'http://192.168.0.105:9001/v1/models/bert_classify:predict',
213 | data=data, headers=headers)
214 | end_time = time.time()
215 | print('Spend time {}sec'.format(end_time-start_time))
216 | predictions = np.array(json.loads(json_response.text)['predictions'])
217 | # print(np.argmax(predictions, axis=-1))
218 | label = label_list[predictions.argmax()]
219 | print(label)
220 |
221 |
222 | def predict_offline_batch():
223 | while True:
224 | import time
225 | question = input("> ")
226 | start_time = time.time()
227 | sents_list = question.split('|&|')
228 | instances_dict = {}
229 | instance_list = []
230 | for sent in sents_list:
231 | predict_example = InputExample("100", sent, None, '婚姻家庭')
232 | feature = convert_single_example(100, predict_example, label_list, max_seq_length, tokenizer)
233 | instance_list.append(
234 | {
235 | "input_ids": feature.input_ids,
236 | "input_mask": feature.input_mask,
237 | "segment_ids": feature.segment_ids,
238 | "label_ids": [feature.label_id]
239 | }
240 | )
241 | instances_dict['instances'] = instance_list
242 |
243 | data = json.dumps({
244 | "instances": instances_dict['instances']
245 | })
246 | headers = {"content-type": "application/json"}
247 | # json_response = requests.post(
248 | # 'http://127.0.0.1:9001/v1/models/bert_domain_classify:predict',
249 | # data=data, headers=headers)
250 | json_response = requests.post(
251 | 'http://192.168.0.105:9001/v1/models/bert_classify:predict',
252 | data=data, headers=headers)
253 | end_time = time.time()
254 | print('Spend time {}sec'.format(end_time-start_time))
255 | predictions = np.array(json.loads(json_response.text)['predictions'])
256 | print(predictions)
257 | # print(np.argmax(predictions, axis=-1))
258 | for pred in predictions:
259 |
260 | label = label_list[pred.argmax()]
261 | print(label)
262 |
263 | predict_offline_single()
264 | # predict_offline_batch()
265 |
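266 | # Sketch of the TF Serving round trip this script performs (host, port and model name
267 | # must match however tensorflow_model_server was started; the values above are examples):
268 | #   request  body: {"instances": [{"input_ids": [...], "input_mask": [...],
269 | #                                  "segment_ids": [...], "label_ids": [0]}]}
270 | #   response body: {"predictions": [[p_0, p_1, ..., p_(num_labels-1)]]}
271 | # The predicted label is then label_list[np.argmax(predictions[i])] for each instance.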
--------------------------------------------------------------------------------
/run_pb_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import tokenization
5 | import tensorflow as tf
6 | import numpy as np
7 | import time
8 | parser = argparse.ArgumentParser(
9 | description='BERT model pb model case/batch test program, exit with q')
10 |
11 | parser.add_argument('--model', type=str,
12 | default='./inference/robert_tiny_clue/frozen_model.pb', help='the path for the model')
13 | #TODO: CHECK vocab file
14 | parser.add_argument('--vocab_file', type=str,
15 | # default='../ALBERT/albert_base/vocab_chinese.txt')
16 | default='/Users/huangdongxiao2/CodeRepos/SesameSt/albert_zh/inference/robert_tiny_clue/vocab.txt')
17 | parser.add_argument('--labels', type=list, default=[
18 | 'happy', 'anger', 'lost', 'fear', 'sad', 'other', 'anxiety'], help='label list')
19 | parser.add_argument('--max_seq_length', type=int, default=128,
20 | help='the length of sequence for text padding')
21 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids:0',
22 | help='the input_ids op name in the graph, format: op_name:output_index')
23 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask:0',
24 | help='the input_mask op name in the graph, format: op_name:output_index')
25 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids:0',
26 | help='the segment_ids op name in the graph, format: op_name:output_index')
27 | parser.add_argument('--tensor_output', type=str, default='loss/pred_prob:0',
28 | help='the output op name in the graph, format: op_name:output_index')
29 | parser.add_argument('--MODE', type=str, default='SINGLE',
30 | help='SINGLE prediction or BATCH prediction')
31 | args_in_use = parser.parse_args()
32 | """
33 | gpu setting
34 | """
35 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # default: 0
36 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
37 | """
38 | load pb model and predict
39 | """
40 |
41 |
42 | class InputExample(object):
43 | """A single training/test example for simple sequence classification."""
44 |
45 | def __init__(self, guid, text_a, text_b=None):
46 | """Constructs a InputExample.
47 | Args:
48 | guid: Unique id for the example.
49 | text_a: string. The untokenized text of the first sequence. For single
50 | sequence tasks, only this sequence must be specified.
51 | text_b: (Optional) string. The untokenized text of the second sequence.
52 | Only must be specified for sequence pair tasks.
53 | label: (Optional) string. The label of the example. This should be
54 | specified for train and dev examples, but not for test examples.
55 | """
56 | self.guid = guid
57 | self.text_a = text_a
58 | self.text_b = text_b
59 |
60 |
61 | class InputFeatures(object):
62 | """A single set of features of data."""
63 |
64 | def __init__(self,
65 | input_ids,
66 | input_mask,
67 | segment_ids,
68 | is_real_example=True):
69 | self.input_ids = input_ids
70 | self.input_mask = input_mask
71 | self.segment_ids = segment_ids
72 | self.is_real_example = is_real_example
73 |
74 |
75 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
76 | """Truncates a sequence pair in place to the maximum length."""
77 |
78 | # This is a simple heuristic which will always truncate the longer sequence
79 | # one token at a time. This makes more sense than truncating an equal percent
80 | # of tokens from each, since if one sequence is very short then each token
81 | # that's truncated likely contains more information than a longer sequence.
82 | while True:
83 | total_length = len(tokens_a) + len(tokens_b)
84 | if total_length <= max_length:
85 | break
86 | if len(tokens_a) > len(tokens_b):
87 | tokens_a.pop()
88 | else:
89 | tokens_b.pop()
90 |
91 |
92 | def convert_single_example(example, label_list, max_seq_length,
93 | tokenizer):
94 | """Converts a single `InputExample` into a single `InputFeatures`."""
95 |
96 | label_map = {}
97 | for (i, label) in enumerate(label_list):
98 | label_map[label] = i
99 |
100 | tokens_a = tokenizer.tokenize(example.text_a)
101 | tokens_b = None
102 | if example.text_b:
103 | tokens_b = tokenizer.tokenize(example.text_b)
104 |
105 | if tokens_b:
106 | # Modifies `tokens_a` and `tokens_b` in place so that the total
107 | # length is less than the specified length.
108 | # Account for [CLS], [SEP], [SEP] with "- 3"
109 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
110 | else:
111 | # Account for [CLS] and [SEP] with "- 2"
112 | if len(tokens_a) > max_seq_length - 2:
113 | tokens_a = tokens_a[0:(max_seq_length - 2)]
114 |
115 | # The convention in BERT is:
116 | # (a) For sequence pairs:
117 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
118 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
119 | # (b) For single sequences:
120 | # tokens: [CLS] the dog is hairy . [SEP]
121 | # type_ids: 0 0 0 0 0 0 0
122 | #
123 | # Where "type_ids" are used to indicate whether this is the first
124 | # sequence or the second sequence. The embedding vectors for `type=0` and
125 | # `type=1` were learned during pre-training and are added to the wordpiece
126 | # embedding vector (and position vector). This is not *strictly* necessary
127 | # since the [SEP] token unambiguously separates the sequences, but it makes
128 | # it easier for the model to learn the concept of sequences.
129 | #
130 | # For classification tasks, the first vector (corresponding to [CLS]) is
131 | # used as the "sentence vector". Note that this only makes sense because
132 | # the entire model is fine-tuned.
133 | tokens = []
134 | segment_ids = []
135 | tokens.append("[CLS]")
136 | segment_ids.append(0)
137 | for token in tokens_a:
138 | tokens.append(token)
139 | segment_ids.append(0)
140 | tokens.append("[SEP]")
141 | segment_ids.append(0)
142 |
143 | if tokens_b:
144 | for token in tokens_b:
145 | tokens.append(token)
146 | segment_ids.append(1)
147 | tokens.append("[SEP]")
148 | segment_ids.append(1)
149 |
150 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
151 |
152 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
153 | # tokens are attended to.
154 | input_mask = [1] * len(input_ids)
155 |
156 | # Zero-pad up to the sequence length.
157 | while len(input_ids) < max_seq_length:
158 | input_ids.append(0)
159 | input_mask.append(0)
160 | segment_ids.append(0)
161 |
162 | assert len(input_ids) == max_seq_length
163 | assert len(input_mask) == max_seq_length
164 | assert len(segment_ids) == max_seq_length
165 |
166 | feature = InputFeatures(
167 | input_ids=input_ids,
168 | input_mask=input_mask,
169 | segment_ids=segment_ids,
170 | is_real_example=True)
171 | return feature
172 |
173 |
174 | with tf.Graph().as_default():
175 | output_graph_def = tf.GraphDef()
176 | label_list = args_in_use.labels
177 | label_map = {i: label for i, label in enumerate(label_list)}
178 |
179 | max_seq_length = args_in_use.max_seq_length
180 | """
181 | load pb model
182 | """
183 | with open(args_in_use.model, 'rb') as f:
184 | output_graph_def.ParseFromString(f.read())
185 | tf.import_graph_def(output_graph_def, name='')
186 | """
187 | enter a text and predict
188 | """
189 | with tf.Session() as sess:
190 | tf.global_variables_initializer().run()
191 | input_ids = sess.graph.get_tensor_by_name(
192 | args_in_use.tensor_input_ids)
193 | input_mask = sess.graph.get_tensor_by_name(
194 | args_in_use.tensor_input_mask)
195 | segment_ids = sess.graph.get_tensor_by_name(
196 | args_in_use.tensor_segment_ids)
197 | tokenizer = tokenization.FullTokenizer(
198 | vocab_file=args_in_use.vocab_file, do_lower_case=True)
199 | output = args_in_use.tensor_output
200 | if args_in_use.MODE == 'SINGLE':
201 | while 1:
202 | question = input("enter a sentence:")
203 | if question == 'q' or question == 'quit()':
204 | break
205 | predict_example = InputExample('id', question, None)
206 | feature = convert_single_example(
207 | predict_example, label_list, max_seq_length, tokenizer)
208 | # print(feature.input_ids)
209 | # print(feature.input_mask)
210 | # print(feature.segment_ids)
211 | feed_dict = {
212 | input_ids: [feature.input_ids],
213 | input_mask: [feature.input_mask],
214 | segment_ids: [feature.segment_ids],
215 | }
216 | # alternatively, feeding by tensor-name strings also works:
217 | # feed_dict = {
218 | # args_in_use.tensor_input_ids: [feature.input_ids],
219 | # args_in_use.tensor_input_mask: [feature.input_mask],
220 | # args_in_use.tensor_segment_ids: [feature.segment_ids],
221 | # }
222 | start_time = time.time()
223 | y_pred_cls = sess.run(output, feed_dict=feed_dict)
224 | print(f'elapsed time: {time.time()-start_time}s')
225 | max_index = np.argmax(y_pred_cls[0])
226 | print(" current results ", y_pred_cls)
227 | print(f'label: {label_map[max_index]}')
228 | elif args_in_use.MODE == 'BATCH':
229 | questions = [
230 | '我要投诉的',
231 | '我很不开心',
232 | '我好喜欢你'
233 | ]
234 | features = []
235 | for i, question in enumerate(questions):
236 | predict_example = InputExample(f'id{i}', question, None)
237 | feature = convert_single_example(
238 | predict_example, label_list, max_seq_length, tokenizer)
239 | features.append(feature)
240 | feed_dict = {
241 | input_ids: [feature.input_ids for feature in features],
242 | input_mask: [feature.input_mask for feature in features],
243 | segment_ids: [feature.segment_ids for feature in features],
244 | }
245 | y_pred_cls = sess.run(output, feed_dict=feed_dict)
246 | max_idxs = np.argmax(y_pred_cls, 1)
247 | print(y_pred_cls)
248 | print(
249 | f'labels: {[label_map[max_index] for max_index in max_idxs]}')
250 | else:
251 | raise ValueError('unsupported mode')
252 |
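253 | # The --tensor_* arguments are graph tensor names of the form op_name:output_index
254 | # (e.g. 'input_ids:0'). If a frozen graph uses different names, one debugging sketch
255 | # (not part of this script) for listing candidate op names is:
256 | #   node_names = [n.name for n in output_graph_def.node]
257 | #   print([name for name in node_names if 'input' in name or 'pred' in name])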
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenziation."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia.).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 | """Runs WordPiece tokenziation."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 | already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 | # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat.startswith("C"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
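401 | # Minimal usage sketch (vocab path is illustrative; with a Chinese BERT/RoBERTa vocab
402 | # each CJK character is expected to map to its own wordpiece):
403 | #   tokenizer = FullTokenizer(vocab_file='./vocab.txt', do_lower_case=True)
404 | #   tokens = tokenizer.tokenize('我很不开心')        # -> ['我', '很', '不', '开', '心']
405 | #   ids = tokenizer.convert_tokens_to_ids(tokens)   # ids are row indices into vocab.txt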
--------------------------------------------------------------------------------
/cmrc/cmrc_tool/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 | class FullTokenizer(object):  # the end-to-end tokenizer BERT uses (basic + wordpiece)
161 | """Runs end-to-end tokenziation."""
162 |
163 | def __init__(self, vocab_file, do_lower_case=True):
164 | self.vocab = load_vocab(vocab_file)
165 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
166 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
167 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
168 |
169 | def tokenize(self, text):
170 | split_tokens = []
171 | for token in self.basic_tokenizer.tokenize(text):
172 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
173 | split_tokens.append(sub_token)
174 |
175 | return split_tokens
176 |
177 | def convert_tokens_to_ids(self, tokens):
178 | return convert_by_vocab(self.vocab, tokens)
179 |
180 | def convert_ids_to_tokens(self, ids):
181 | return convert_by_vocab(self.inv_vocab, ids)
182 |
183 | class BasicTokenizer(object):  # cleaning, lower casing, CJK padding and punctuation splitting
184 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
185 |
186 | def __init__(self, do_lower_case=True):
187 | """Constructs a BasicTokenizer.
188 |
189 | Args:
190 | do_lower_case: Whether to lower case the input.
191 | """
192 | self.do_lower_case = do_lower_case
193 |
194 | def tokenize(self, text):
195 | """Tokenizes a piece of text."""
196 | text = convert_to_unicode(text)
197 | text = self._clean_text(text)
198 |
199 | # This was added on November 1st, 2018 for the multilingual and Chinese
200 | # models. This is also applied to the English models now, but it doesn't
201 | # matter since the English models were not trained on any Chinese data
202 | # and generally don't have any Chinese data in them (there are Chinese
203 | # characters in the vocabulary because Wikipedia does have some Chinese
204 | # words in the English Wikipedia.).
205 | text = self._tokenize_chinese_chars(text)  # i.e. pad each Chinese character with spaces so the later whitespace split treats every character as its own token.
206 |
207 | orig_tokens = whitespace_tokenize(text)
208 | split_tokens = []
209 | for token in orig_tokens:
210 | if self.do_lower_case:
211 | token = token.lower()
212 | token = self._run_strip_accents(token)
213 | split_tokens.extend(self._run_split_on_punc(token))
214 |
215 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
216 | return output_tokens
217 |
218 | def _run_strip_accents(self, text):
219 | """
220 | Strips accents from a piece of text.
221 |
222 | >>> s1 = 'café'
223 | >>> s2 = 'cafe\u0301'
224 | >>> s1, s2
225 | ('café', 'café')
226 | >>> len(s1), len(s2)
227 | (4, 5)
228 | >>> s1 == s2
229 | False
230 | The "é" we see can be represented in two different ways: either as a single codepoint "é", or as "e" followed by the combining codepoint U+0301 (COMBINING ACUTE ACCENT), which renders after "e" as "é". Similarly "a\u0301" is displayed as "á". Note that the two forms only look identical when printed; internally they are completely different: the former "é" is one codepoint, 0xE9, while the latter is two codepoints, 0x65 and 0x301. unicodedata.normalize("NFD", text) decomposes 0xE9 into 0x65 and 0x301, as the example above shows.
231 |
232 |
233 | """
234 | text = unicodedata.normalize("NFD", text)
235 | output = []
236 | for char in text:
237 | cat = unicodedata.category(char)
238 | if cat == "Mn": # ACCENT
239 | continue
240 | output.append(char)
241 | return "".join(output)
242 |
243 | def _run_split_on_punc(self, text):
244 | """Splits punctuation on a piece of text."""
245 | chars = list(text)
246 | i = 0
247 | start_new_word = True
248 | output = []
249 | """这个函数会对输入字符串用标点进行切分,返回一个list,list的每一个元素都是一个char。比如输入he’s,则输出是[[h,e], [’],[s]]。"""
250 | while i < len(chars):
251 | char = chars[i]
252 | if _is_punctuation(char):
253 | output.append([char])
254 | start_new_word = True
255 | else:
256 | if start_new_word:
257 | output.append([])
258 | start_new_word = False
259 | output[-1].append(char)
260 | i += 1
261 |
262 | return ["".join(x) for x in output]
263 |
264 | def _tokenize_chinese_chars(self, text):
265 | """Adds whitespace around any CJK character."""
266 | output = []
267 | for char in text:
268 | cp = ord(char)
269 | if self._is_chinese_char(cp):
270 | output.append(" ")
271 | output.append(char)
272 | output.append(" ")
273 | else:
274 | output.append(char)
275 | return "".join(output)
276 |
277 | def _is_chinese_char(self, cp):
278 | """Checks whether CP is the codepoint of a CJK character."""
279 | # This defines a "chinese character" as anything in the CJK Unicode block:
280 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
281 | #
282 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
283 | # despite its name. The modern Korean Hangul alphabet is a different block,
284 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
285 | # space-separated words, so they are not treated specially and handled
286 | # like all of the other languages.
287 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
288 | (cp >= 0x3400 and cp <= 0x4DBF) or #
289 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
290 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
291 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
292 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
293 | (cp >= 0xF900 and cp <= 0xFAFF) or #
294 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
295 | return True
296 |
297 | return False
298 |
299 | def _clean_text(self, text):
300 | """Performs invalid character removal and whitespace cleanup on text."""
301 | output = []
302 | for char in text:
303 | cp = ord(char)
304 | # codepoint 0 is an invalid character; codepoint 0xFFFD renders as � (the replacement character for unrecognized input)
305 | if cp == 0 or cp == 0xfffd or _is_control(char):
306 | continue
307 | if _is_whitespace(char):
308 | output.append(" ")
309 | else:
310 | output.append(char)
311 | return "".join(output)
312 |
313 |
314 | class WordpieceTokenizer(object):  # splits whitespace-separated tokens further into WordPiece sub-tokens
315 | """Runs WordPiece tokenziation."""
316 |
317 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
318 | self.vocab = vocab
319 | self.unk_token = unk_token
320 | self.max_input_chars_per_word = max_input_chars_per_word
321 |
322 | def tokenize(self, text):
323 | """Tokenizes a piece of text into its word pieces.
324 |
325 | This uses a greedy longest-match-first algorithm to perform tokenization
326 | using the given vocabulary.
327 |
328 | For example:
329 | input = "unaffable"
330 | Walking through an example: suppose the input is "unaffable". In the while loop, start=0 and end=len(chars)=9, so we first check whether "unaffable" is in the vocab; if it is, it becomes a WordPiece directly, otherwise end -= 1 and we check "unaffabl", and so on, until we find that "un" is in the vocab and append it to the result.
331 |
332 | Then start=2 and we check "affable", then "affabl", ..., until "##aff" is found in the vocab. Note: "##" marks a piece that continues the previous one, which keeps the WordPiece split reversible, so the "real" word can be recovered.
333 | output = ["un", "##aff", "##able"]
334 |
335 | Args:
336 | text: A single token or whitespace separated tokens. This should have
337 | already been passed through `BasicTokenizer`.
338 |
339 | Returns:
340 | A list of wordpiece tokens.
341 | """
342 |
343 | text = convert_to_unicode(text)
344 |
345 | output_tokens = []
346 | for token in whitespace_tokenize(text):
347 | chars = list(token)
348 | if len(chars) > self.max_input_chars_per_word:
349 | output_tokens.append(self.unk_token)
350 | continue
351 |
352 | is_bad = False
353 | start = 0
354 | sub_tokens = []
355 | while start < len(chars):
356 | end = len(chars)
357 | cur_substr = None
358 | while start < end:
359 | substr = "".join(chars[start:end])
360 | if start > 0:
361 | substr = "##" + substr
362 | if substr in self.vocab:
363 | cur_substr = substr
364 | break
365 | end -= 1
366 | if cur_substr is None:
367 | is_bad = True
368 | break
369 | sub_tokens.append(cur_substr)
370 | start = end
371 |
372 | if is_bad:
373 | output_tokens.append(self.unk_token)
374 | else:
375 | output_tokens.extend(sub_tokens)
376 | return output_tokens
377 |
378 |
379 | def _is_whitespace(char):
380 | """Checks whether `chars` is a whitespace character."""
381 | # \t, \n, and \r are technically control characters but we treat them
382 | # as whitespace since they are generally considered as such.
383 | if char == " " or char == "\t" or char == "\n" or char == "\r":
384 | return True
385 | cat = unicodedata.category(char)
386 | if cat == "Zs":
387 | return True
388 | return False
389 |
390 |
391 | def _is_control(char):
392 | """Checks whether `chars` is a control character. non printable character"""
393 | # These are technically control characters but we count them as whitespace
394 | # characters.
395 | if char == "\t" or char == "\n" or char == "\r":
396 | return False
397 | cat = unicodedata.category(char)
398 | if cat.startswith("C"):
399 | return True
400 | return False
401 |
402 |
403 | def _is_punctuation(char):
404 | """Checks whether `chars` is a punctuation character."""
405 | cp = ord(char)
406 | # We treat all non-letter/number ASCII as punctuation.
407 | # Characters such as "^", "$", and "`" are not in the Unicode
408 | # Punctuation class but we treat them as punctuation anyways, for
409 | # consistency.
410 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
411 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
412 | return True
413 | cat = unicodedata.category(char)
414 | if cat.startswith("P"):
415 | return True
416 | return False
417 |
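418 | # Quick illustration of the CJK handling described above (BasicTokenizer needs no vocab):
419 | #   BasicTokenizer(do_lower_case=True).tokenize('我very开心')
420 | #   -> ['我', 'very', '开', '心']   # every Chinese character becomes its own token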
--------------------------------------------------------------------------------
/cmrc/cmrc_tool/run_squad_inf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import math
5 | import tokenization
6 | import collections
7 | import tensorflow as tf
8 | import numpy as np
9 | import time
10 | import six
11 |
12 | parser = argparse.ArgumentParser(
13 | description='BERT model saved model case/batch test program, exit with q')
14 |
15 | parser.add_argument('--model', type=str,
16 | default='./1586959006', help='the path for the model')
17 | parser.add_argument('--vocab_file', type=str,
18 | default='./1586959006/vocab.txt')
19 | parser.add_argument('--max_seq_length', type=int, default=384,
20 | help='the length of sequence for text padding')
21 | parser.add_argument('--do_lower_case', type=bool, default=True,
22 | help='Whether to lower case the input text. Should be True for uncased models and False for cased models.')
23 | parser.add_argument('--max_answer_length', type=int, default=128,
24 | help='The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another.')
25 | parser.add_argument('--n_best_size', type=int, default=20,
26 | help='"The total number of n-best predictions to generate in the nbest_predictions.json output file.')
27 | parser.add_argument('--doc_stride', type=int, default=128,
28 | help='the length of document stride')
29 | parser.add_argument('--max_query_length', type=int, default=64,
30 | help='the max length of query length')
31 | parser.add_argument('--tensor_start_positions', type=str, default='start_positions',
32 | help='the start_positions feature name for saved model')
33 | parser.add_argument('--tensor_end_positions', type=str, default='end_positions',
34 | help='the end_positions feature name for saved model')
35 | parser.add_argument('--tensor_unique_ids', type=str, default='unique_ids',
36 | help='the unique_ids feature name for saved model')
37 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids',
38 | help='the input_ids feature name for saved model')
39 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask',
40 | help='the input_mask feature name for saved model')
41 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids',
42 | help='the segment_ids feature name for saved model')
43 | parser.add_argument('--MODE', type=str, default='SINGLE',
44 | help='SINGLE prediction or BATCH prediction')
45 | args_in_use = parser.parse_args()
46 |
47 |
48 | class SquadExample(object):
49 | """A single training/test example for simple sequence classification.
50 |
51 | For examples without an answer, the start and end position are -1.
52 | """
53 |
54 | def __init__(self,
55 | qas_id,
56 | question_text,
57 | doc_tokens,
58 | orig_answer_text=None,
59 | start_position=None,
60 | end_position=None,
61 | is_impossible=False):
62 | self.qas_id = qas_id
63 | self.question_text = question_text
64 | self.doc_tokens = doc_tokens
65 | self.orig_answer_text = orig_answer_text
66 | self.start_position = start_position
67 | self.end_position = end_position
68 | self.is_impossible = is_impossible
69 |
70 | def __str__(self):
71 | return self.__repr__()
72 |
73 | def __repr__(self):
74 | s = ""
75 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
76 | s += ", question_text: %s" % (
77 | tokenization.printable_text(self.question_text))
78 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
79 | if self.start_position:
80 | s += ", start_position: %d" % (self.start_position)
81 | if self.start_position:
82 | s += ", end_position: %d" % (self.end_position)
83 | if self.start_position:
84 | s += ", is_impossible: %r" % (self.is_impossible)
85 | return s
86 |
87 |
88 | class InputFeatures(object):
89 | """A single set of features of data."""
90 |
91 | def __init__(self,
92 | unique_id,
93 | example_index,
94 | doc_span_index,
95 | tokens,
96 | token_to_orig_map,
97 | token_is_max_context,
98 | input_ids,
99 | input_mask,
100 | segment_ids,
101 | start_position=None,
102 | end_position=None,
103 | is_impossible=None):
104 | self.unique_id = unique_id
105 | self.example_index = example_index
106 | self.doc_span_index = doc_span_index
107 | self.tokens = tokens
108 | self.token_to_orig_map = token_to_orig_map
109 | self.token_is_max_context = token_is_max_context
110 | self.input_ids = input_ids
111 | self.input_mask = input_mask
112 | self.segment_ids = segment_ids
113 | self.start_position = start_position
114 | self.end_position = end_position
115 | self.is_impossible = is_impossible
116 |
117 |
118 | def _check_is_max_context(doc_spans, cur_span_index, position):
119 | """Check if this is the 'max context' doc span for the token."""
120 |
121 | # Because of the sliding window approach taken to scoring documents, a single
122 | # token can appear in multiple documents. E.g.
123 | # Doc: the man went to the store and bought a gallon of milk
124 | # Span A: the man went to the
125 | # Span B: to the store and bought
126 | # Span C: and bought a gallon of
127 | # ...
128 | #
129 | # Now the word 'bought' will have two scores from spans B and C. We only
130 | # want to consider the score with "maximum context", which we define as
131 | # the *minimum* of its left and right context (the *sum* of left and
132 | # right context will always be the same, of course).
133 | #
134 | # In the example the maximum context for 'bought' would be span C since
135 | # it has 1 left context and 3 right context, while span B has 4 left context
136 | # and 0 right context.
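    | # (Editorial worked example, not part of the original comment.) With span B =
    | # (start=3, length=5) and span C = (start=6, length=5), the token 'bought' at
    | # position 7 scores min(4, 0) + 0.01 * 5 = 0.05 in span B and
    | # min(1, 3) + 0.01 * 5 = 1.05 in span C, so span C is its max-context span.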
137 | best_score = None
138 | best_span_index = None
139 | for (span_index, doc_span) in enumerate(doc_spans):
140 | end = doc_span.start + doc_span.length - 1
141 | if position < doc_span.start:
142 | continue
143 | if position > end:
144 | continue
145 | num_left_context = position - doc_span.start
146 | num_right_context = end - position
147 | score = min(num_left_context, num_right_context) + \
148 | 0.01 * doc_span.length
149 | if best_score is None or score > best_score:
150 | best_score = score
151 | best_span_index = span_index
152 |
153 | return cur_span_index == best_span_index
154 |
155 |
156 | def convert_examples_to_features(examples, tokenizer, max_seq_length,
157 | doc_stride, max_query_length):
158 | """Loads a data file into a list of `InputBatch`s."""
159 |
160 | unique_id = 1000000000
161 |
162 | features = []
163 | for (example_index, example) in enumerate(examples):
164 | query_tokens = tokenizer.tokenize(example.question_text)
165 |
166 | if len(query_tokens) > max_query_length:
167 | query_tokens = query_tokens[0:max_query_length]
168 |
169 | # After tokenizing the original text: mapping from token index to original-text (word) index
170 | tok_to_orig_index = []
171 | # After tokenizing the original text: mapping from original-text (word) index to token index
172 | orig_to_tok_index = []
173 | all_doc_tokens = []
174 | for (i, token) in enumerate(example.doc_tokens):
175 | orig_to_tok_index.append(len(all_doc_tokens))
176 | sub_tokens = tokenizer.tokenize(token)
177 | for sub_token in sub_tokens:
178 | tok_to_orig_index.append(i)
179 | all_doc_tokens.append(sub_token)
180 |
181 | tok_start_position = None
182 | tok_end_position = None
183 |
184 | # The -3 accounts for [CLS], [SEP] and [SEP]
185 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
186 |
187 | # We can have documents that are longer than the maximum sequence length.
188 | # To deal with this we do a sliding window approach, where we take chunks
189 | # of up to our max length with a stride of `doc_stride`.
190 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
191 | "DocSpan", ["start", "length"])
192 | doc_spans = []
193 | start_offset = 0
194 | while start_offset < len(all_doc_tokens):
195 | length = len(all_doc_tokens) - start_offset
196 | if length > max_tokens_for_doc:
197 | length = max_tokens_for_doc
198 | doc_spans.append(_DocSpan(start=start_offset, length=length))
199 | # If length is smaller than max_tokens_for_doc, this is the last chunk and the loop breaks here
200 | if start_offset + length == len(all_doc_tokens):
201 | break
202 | start_offset += min(length, doc_stride)
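    | # (Editorial trace with assumed numbers.) For len(all_doc_tokens) = 600,
    | # max_tokens_for_doc = 381 and doc_stride = 128, the loop above yields the
    | # spans (0, 381), (128, 381) and (256, 344); the last span reaches the end
    | # of the document, so the loop breaks.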
203 |
204 | for (doc_span_index, doc_span) in enumerate(doc_spans):
205 | tokens = []
206 | token_to_orig_map = {}
207 | token_is_max_context = {}
208 | # In segment_ids, 0 marks [CLS], the query_tokens and the first [SEP]; 1 marks the doc tokens and the second [SEP]
209 | segment_ids = []
210 | tokens.append("[CLS]")
211 | segment_ids.append(0)
212 | for token in query_tokens:
213 | tokens.append(token)
214 | segment_ids.append(0)
215 | tokens.append("[SEP]")
216 | segment_ids.append(0)
217 |
218 | for i in range(doc_span.length):
219 | split_token_index = doc_span.start + i
220 | # len(tokens) is the size of [CLS] + query_tokens + [SEP]; this position corresponds to the i-th doc token
221 | token_to_orig_map[len(
222 | tokens)] = tok_to_orig_index[split_token_index]
223 |
224 | is_max_context = _check_is_max_context(doc_spans, doc_span_index,
225 | split_token_index)
226 | token_is_max_context[len(tokens)] = is_max_context
227 | tokens.append(all_doc_tokens[split_token_index])
228 | segment_ids.append(1)
229 | tokens.append("[SEP]")
230 | segment_ids.append(1)
231 |
232 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
233 |
234 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
235 | # tokens are attended to.
236 | input_mask = [1] * len(input_ids)
237 |
238 | # Zero-pad up to the sequence length.
239 | while len(input_ids) < max_seq_length:
240 | input_ids.append(0)
241 | input_mask.append(0)
242 | segment_ids.append(0)
243 |
244 | assert len(input_ids) == max_seq_length
245 | assert len(input_mask) == max_seq_length
246 | assert len(segment_ids) == max_seq_length
247 |
248 | start_position = None
249 | end_position = None
250 |
251 | features.append(
252 | InputFeatures(
253 | unique_id=unique_id,
254 | example_index=example_index,
255 | doc_span_index=doc_span_index,
256 | tokens=tokens, # may introduce [UNK] tokens
257 | token_to_orig_map=token_to_orig_map,
258 | token_is_max_context=token_is_max_context,
259 | input_ids=input_ids,
260 | input_mask=input_mask,
261 | segment_ids=segment_ids,
262 | start_position=start_position,
263 | end_position=end_position))
264 | unique_id += 1
265 |
266 | return features
267 |
268 |
269 | def read_squad_examples(input_data, tokenizer):
270 | """
271 | https://github.com/eva-n27/BERT-for-Chinese-Question-Answering/blob/master/run_squad.py
272 | Read SQuAD-style json data into a list of SquadExample.
273 | This function turns input_data[i]["paragraphs"]["context"] into a list of words,
274 | then iterates over "qas" and, for each qa, extracts
275 | {
276 | qas_id: qa['id'],
277 | question_text: qa["question"],
278 | orig_answer_text: answer["text"],
279 | start_position: start_position,
280 | end_position: end_position
281 | }
282 | """
283 | import unicodedata
284 |
285 | def is_whitespace(c):
286 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
287 | return True
288 | return False
289 |
290 | def is_control(char):
291 | """Checks whether `chars` is a control character."""
292 | # These are technically control characters but we count them as whitespace
293 | # characters.
294 | if char == "\t" or char == "\n" or char == "\r":
295 | return False
296 | cat = unicodedata.category(char)
297 | if cat.startswith("C"):
298 | return True
299 | return False
300 |
301 | def clean_text(text):
302 | """Performs invalid character removal and whitespace cleanup on text."""
303 | output = []
304 | for char in text:
305 | cp = ord(char)
306 | if cp == 0 or cp == 0xfffd or is_control(char):
307 | continue
308 | if is_whitespace(char):
309 | output.append(" ")
310 | else:
311 | output.append(char)
312 | return "".join(output)
313 |
314 | examples = []
315 | tf.logging.info("*** reading squad examples ***")
316 | for entry in input_data:
317 | for paragraph in entry["paragraphs"]:
318 | paragraph_text = " ".join(tokenization.whitespace_tokenize(
319 | clean_text(paragraph["context"])))
320 |
321 | for qa in paragraph["qas"]:
322 | doc_tokens = tokenizer.basic_tokenizer.tokenize(paragraph_text)
323 | qas_id = qa["id"]
324 | question_text = qa["question"]
325 | start_position = None
326 | end_position = None
327 | orig_answer_text = None
328 | is_impossible = False
329 |
330 | example = SquadExample(
331 | qas_id=qas_id,
332 | question_text=question_text,
333 | doc_tokens=doc_tokens,
334 | orig_answer_text=orig_answer_text,
335 | start_position=start_position,
336 | end_position=end_position)
337 | examples.append(example)
338 | return examples
339 |
340 |
341 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
342 | """Project the tokenized prediction back to the original text."""
343 |
344 | # When we created the data, we kept track of the alignment between original
345 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
346 | # now `orig_text` contains the span of our original text corresponding to the
347 | # span that we predicted.
348 | #
349 | # However, `orig_text` may contain extra characters that we don't want in
350 | # our prediction.
351 | #
352 | # For example, let's say:
353 | # pred_text = steve smith
354 | # orig_text = Steve Smith's
355 | #
356 | # We don't want to return `orig_text` because it contains the extra "'s".
357 | #
358 | # We don't want to return `pred_text` because it's already been normalized
359 | # (the SQuAD eval script also does punctuation stripping/lower casing but
360 | # our tokenizer does additional normalization like stripping accent
361 | # characters).
362 | #
363 | # What we really want to return is "Steve Smith".
364 | #
365 | # Therefore, we have to apply a semi-complicated alignment heuristic between
366 | # `pred_text` and `orig_text` to get a character-to-character alignment. This
367 | # can fail in certain cases in which case we just return `orig_text`.
368 |
369 | def _strip_spaces(text):
370 | ns_chars = []
371 | ns_to_s_map = collections.OrderedDict()
372 | for (i, c) in enumerate(text):
373 | if c == " ":
374 | continue
375 | ns_to_s_map[len(ns_chars)] = i
376 | ns_chars.append(c)
377 | ns_text = "".join(ns_chars)
378 | return (ns_text, ns_to_s_map)
379 |
380 | # We first tokenize `orig_text`, strip whitespace from the result
381 | # and `pred_text`, and check if they are the same length. If they are
382 | # NOT the same length, the heuristic has failed. If they are the same
383 | # length, we assume the characters are one-to-one aligned.
384 |
385 | #################################################
386 | """
387 | Original version (commented out below)
388 | """
389 | # tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
390 | # tok_text = "".join(tokenizer.tokenize(orig_text))
391 | """
392 | Updated version, 2020-04-16
393 | """
394 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file,
395 | do_lower_case=True)
396 | tok_text = "".join(tokenizer.tokenize(orig_text))
397 | # De-tokenize WordPieces that have been split off.
398 | tok_text = tok_text.replace(" ##", "")
399 | tok_text = tok_text.replace("##", "")
400 | ####################################################
401 |
402 | start_position = tok_text.find(pred_text)
403 | if start_position == -1:
404 | print(f'{pred_text} not in {tok_text}')  # the [UNK] case
405 | return orig_text
406 | end_position = start_position + len(pred_text) - 1
407 |
408 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
409 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
410 |
411 | if len(orig_ns_text) != len(tok_ns_text):
412 | return orig_text
413 |
414 | # We then project the characters in `pred_text` back to `orig_text` using
415 | # the character-to-character alignment.
416 | tok_s_to_ns_map = {}
417 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
418 | tok_s_to_ns_map[tok_index] = i
419 |
420 | orig_start_position = None
421 | if start_position in tok_s_to_ns_map:
422 | ns_start_position = tok_s_to_ns_map[start_position]
423 | if ns_start_position in orig_ns_to_s_map:
424 | orig_start_position = orig_ns_to_s_map[ns_start_position]
425 |
426 | if orig_start_position is None:
427 | return orig_text
428 |
429 | orig_end_position = None
430 | if end_position in tok_s_to_ns_map:
431 | ns_end_position = tok_s_to_ns_map[end_position]
432 | if ns_end_position in orig_ns_to_s_map:
433 | orig_end_position = orig_ns_to_s_map[ns_end_position]
434 |
435 | if orig_end_position is None:
436 | return orig_text
437 |
438 | output_text = orig_text[orig_start_position:(orig_end_position + 1)]
439 | return output_text
440 |
441 |
442 | def _compute_softmax(scores):
443 | """Compute softmax probability over raw logits."""
444 | if not scores:
445 | return []
446 |
447 | max_score = None
448 | for score in scores:
449 | if max_score is None or score > max_score:
450 | max_score = score
451 |
452 | exp_scores = []
453 | total_sum = 0.0
454 | for score in scores:
455 | x = math.exp(score - max_score)
456 | exp_scores.append(x)
457 | total_sum += x
458 |
459 | probs = []
460 | for score in exp_scores:
461 | probs.append(score / total_sum)
462 | return probs
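    | # (Editorial sanity check.) _compute_softmax([1.0, 2.0]) ~= [0.269, 0.731];
    | # subtracting max_score first is only for numerical stability and does not
    | # change the resulting probabilities.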
463 |
464 |
465 | def _get_best_indexes(logits, n_best_size):
466 | """Get the n-best logits from a list."""
467 | index_and_score = sorted(
468 | enumerate(logits), key=lambda x: x[1], reverse=True)
469 |
470 | best_indexes = []
471 | for i in range(len(index_and_score)):
472 | if i >= n_best_size:
473 | break
474 | best_indexes.append(index_and_score[i][0])
475 | return best_indexes
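    | # (Editorial example.) _get_best_indexes([0.1, 2.0, -1.0, 0.5], n_best_size=2)
    | # returns [1, 3]: the positions of the two largest logits, highest first.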
476 |
477 |
478 | def get_predictions(all_examples, all_features, all_results, n_best_size,
479 | max_answer_length, do_lower_case):
480 |
481 | example_index_to_features = collections.defaultdict(list)
482 | for feature in all_features:
483 | example_index_to_features[feature.example_index].append(feature)
484 |
485 | unique_id_to_result = {}
486 | for result in all_results:
487 | unique_id_to_result[result.unique_id] = result
488 |
489 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
490 | "PrelimPrediction",
491 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
492 |
493 | all_predictions = collections.OrderedDict()
494 | all_nbest_json = collections.OrderedDict()
495 | for (example_index, example) in enumerate(all_examples):
496 | features = example_index_to_features[example_index]
497 |
498 | prelim_predictions = []
499 | for (feature_index, feature) in enumerate(features):
500 | result = unique_id_to_result[feature.unique_id]
501 |
502 | start_indexes = _get_best_indexes(result.start_logits, n_best_size)
503 | end_indexes = _get_best_indexes(result.end_logits, n_best_size)
504 | for start_index in start_indexes:
505 | for end_index in end_indexes:
506 | # We could hypothetically create invalid predictions, e.g., predict
507 | # that the start of the span is in the question. We throw out all
508 | # invalid predictions.
509 | if start_index >= len(feature.tokens):
510 | continue
511 | if end_index >= len(feature.tokens):
512 | continue
513 | if start_index not in feature.token_to_orig_map:
514 | continue
515 | if end_index not in feature.token_to_orig_map:
516 | continue
517 | if not feature.token_is_max_context.get(start_index, False):
518 | continue
519 | if end_index < start_index:
520 | continue
521 | length = end_index - start_index + 1
522 | if length > max_answer_length:
523 | continue
524 | prelim_predictions.append(
525 | _PrelimPrediction(
526 | feature_index=feature_index,
527 | start_index=start_index,
528 | end_index=end_index,
529 | start_logit=result.start_logits[start_index],
530 | end_logit=result.end_logits[end_index]))
531 |
532 | prelim_predictions = sorted(
533 | prelim_predictions,
534 | key=lambda x: (x.start_logit + x.end_logit),
535 | reverse=True)
536 |
537 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
538 | "NbestPrediction", ["text", "start_logit", "end_logit"])
539 |
540 | seen_predictions = {}
541 | nbest = []
542 | for pred in prelim_predictions:
543 | if len(nbest) >= n_best_size:
544 | break
545 | feature = features[pred.feature_index]
546 |
547 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
548 | orig_doc_start = feature.token_to_orig_map[pred.start_index]
549 | orig_doc_end = feature.token_to_orig_map[pred.end_index]
550 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
551 | tok_text = "".join(tok_tokens)
552 |
553 | # De-tokenize WordPieces that have been split off.
554 | tok_text = tok_text.replace(" ##", "")
555 | tok_text = tok_text.replace("##", "")
556 |
557 | # Clean whitespace
558 | tok_text = tok_text.strip()
559 | tok_text = " ".join(tok_text.split())
560 | orig_text = "".join(orig_tokens)
561 |
562 | final_text = get_final_text(tok_text, orig_text, do_lower_case)
563 | if final_text in seen_predictions:
564 | continue
565 |
566 | seen_predictions[final_text] = True
567 | nbest.append(
568 | _NbestPrediction(
569 | text=final_text,
570 | start_logit=pred.start_logit,
571 | end_logit=pred.end_logit))
572 |
573 | # In very rare edge cases we could have no valid predictions. So we
574 | # just create a nonce prediction in this case to avoid failure.
575 | if not nbest:
576 | nbest.append(
577 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
578 |
579 | assert len(nbest) >= 1
580 |
581 | total_scores = []
582 | for entry in nbest:
583 | total_scores.append(entry.start_logit + entry.end_logit)
584 |
585 | probs = _compute_softmax(total_scores)
586 |
587 | nbest_json = []
588 | for (i, entry) in enumerate(nbest):
589 | output = collections.OrderedDict()
590 | output["text"] = entry.text
591 | output["probability"] = probs[i]
592 | output["start_logit"] = entry.start_logit
593 | output["end_logit"] = entry.end_logit
594 | nbest_json.append(output)
595 |
596 | assert len(nbest_json) >= 1
597 |
598 | all_predictions[example.qas_id] = nbest_json[0]["text"]
599 | all_nbest_json[example.qas_id] = nbest_json
600 | return all_predictions, all_nbest_json
601 |
602 |
603 | RawResult = collections.namedtuple("RawResult",
604 | ["unique_id", "start_logits", "end_logits"])
605 |
606 |
607 | if __name__ == '__main__':
608 | predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model)
609 | max_seq_length = args_in_use.max_seq_length
610 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file,
611 | do_lower_case=True)
612 | if args_in_use.MODE == 'SINGLE':
613 | # while True:
614 | # paragraph = input('(PRESS q to quit)请输入段落\n> ')
615 | # question = input('(PRESS q to quit)请输入问题\n> ')
616 | # if question == 'q':
617 | # break
618 |
619 | """
620 | DEMO EXAMPLES
621 | """
622 | # paragraph = "《战国无双3》()是由光荣和ω-force开发的战国无双系列 的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国 志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》 ,丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型 等,请至战国无双系列1.由于乡里大辅先生因故去世,不得不寻找其他声优接手。从 猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任 > 天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),后 来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的> 状况,战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相> 关介绍。(注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容"
623 | # question = "《战国无双3》是由哪两个公司合作开发的?"
624 | # paragraph = "弗赖永广场(Freyung)是奥地利首都维也纳的一个三角形广场,位于内城(第一区)。这个广场最初位于古罗马堡垒Vindabona城墙以外,在12世纪,奥地利公爵亨利二世邀请爱尔兰僧侣来此修建了“苏格兰修道院”(Schottenkloster),因为当时爱尔兰被称为“新苏格兰”。修道院周围的广场也称为苏格兰广场。“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。"
625 | # paragraph = "“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。"
626 | # question = "为什么修道院的名称取自意为自由的古代德语词汇“frey”?"
627 | paragraph = "历时五年四次审议的《电子商务法》已于今年1月1日起实施。《电子商务法》的颁布实施为我国规范当前电子商务市场秩序、维护公平竞争环境、保障参与主体权益、促进电子商务健康快速发展奠定了法律基础。从总体上,应该看到《电子商务法》是一部以促进电子商务发展为立法目标之一的法律,是一部权益法,也是一部促进法。《电子商务法》专门设立了“电子商务促进”章节,明确了国家发展电子商务的重点方向。其中,农村电商和电商扶贫成为促进的重点."
628 | question = "电子商务法的目的"
629 | paragraph2 = """
630 | 第八条 本法所称对外贸易经营者,是指依法办理工商登记或者其他执业手续,依照本法和其他有关法律、行政法规的规定从事对外贸易经营活动的法人、其他组织或者个人。
631 | """
632 | question2 = "对外贸易经营者的定义"
633 | # question = "电子商务法的生效日期" # bad case
634 | # question = "电子商务法生效时间" # so-so case
635 | # question = "电子商务法实施时间" # good case
636 | input_data = [{
637 | "paragraphs": [
638 | {
639 | "context": paragraph,
640 | "qas": [
641 | {
642 | "question": question,
643 | "id": "RANDOM_QUESTION_ID"
644 | }
645 | ]
646 | },
647 | {
648 | "context": paragraph2,
649 | "qas": [
650 | {
651 | "question": question2,
652 | "id": "RANDOM_QUESTION_ID2"
653 | }
654 | ]
655 | },
656 | ]
657 | }]
658 | predict_examples = read_squad_examples(input_data, tokenizer)
659 |
660 | features = convert_examples_to_features(
661 | examples=predict_examples,
662 | tokenizer=tokenizer,
663 | max_seq_length=args_in_use.max_seq_length,
664 | doc_stride=args_in_use.doc_stride,
665 | max_query_length=args_in_use.max_query_length)
666 |
667 | start_time = time.time()
668 | results = predict_fn({
669 | args_in_use.tensor_unique_ids: [feature.unique_id for feature in features],
670 | args_in_use.tensor_input_ids: [feature.input_ids for feature in features],
671 | args_in_use.tensor_input_mask: [feature.input_mask for feature in features],
672 | args_in_use.tensor_segment_ids: [feature.segment_ids for feature in features],
673 | })
674 | print(f'elapsed time: {time.time()-start_time}s')
675 | print(np.shape(results['end_logits']))
676 | unique_ids = results['unique_ids']
677 | start_logits_list = results['start_logits']
678 | end_logits_list = results['end_logits']
679 | all_results = []
680 | for unique_id, start_logits, end_logits in zip(unique_ids, start_logits_list, end_logits_list):
681 | # unique_id = int(result["unique_ids"])
682 | # start_logits = [float(x) for x in result["start_logits"].flat]
683 | # end_logits = [float(x) for x in result["end_logits"].flat]
684 | _raw_result = RawResult(
685 | unique_id=unique_id,
686 | start_logits=start_logits,
687 | end_logits=end_logits)
688 | all_results.append(_raw_result)
689 | all_predictions, all_nbest_json = get_predictions(predict_examples, features, all_results,
690 | args_in_use.n_best_size, args_in_use.max_answer_length,
691 | args_in_use.do_lower_case)
692 | print(all_predictions)
693 | # print(all_nbest_json)
694 | # note: the scores do not match up
695 |
696 | elif args_in_use.MODE == 'BATCH':
697 | pass
698 | # TO BE IMPLEMENTED
699 | else:
700 | raise ValueError('unsupported mode')
701 |
--------------------------------------------------------------------------------
/cmrc/cmrc_tool/run_squad_inf_cmrc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import os
4 | import sys
5 | import math
6 | import tokenization
7 | import collections
8 | import tensorflow as tf
9 | import numpy as np
10 | import time
11 | import six
12 | import pdb
13 |
14 | parser = argparse.ArgumentParser(
15 | description='BERT model saved model case/batch test program, exit with q')
16 |
17 | parser.add_argument('--model', type=str,
18 | default='./cmrc_1588284341', help='the path for the model')
19 | parser.add_argument('--vocab_file', type=str,
20 | default='./1586959006/vocab.txt')
21 | parser.add_argument('--max_seq_length', type=int, default=512,
22 | help='the length of sequence for text padding')
23 | parser.add_argument('--do_lower_case', type=bool, default=True,
24 | help='Whether to lower case the input text. Should be True for uncased models and False for cased models.')
25 | parser.add_argument('--max_answer_length', type=int, default=312,
26 | help='The maximum length of an answer that can be generated. This is needed because the start and end predictions are not conditioned on one another.')
27 | parser.add_argument('--n_best_size', type=int, default=20,
28 | help='"The total number of n-best predictions to generate in the nbest_predictions.json output file.')
29 | parser.add_argument('--doc_stride', type=int, default=128,
30 | help='the length of document stride')
31 | parser.add_argument('--max_query_length', type=int, default=64,
32 | help='the max length of query length')
33 | parser.add_argument('--tensor_start_positions', type=str, default='start_positions',
34 | help='the start_positions feature name for saved model')
35 | parser.add_argument('--tensor_end_positions', type=str, default='end_positions',
36 | help='the end_positions feature name for saved model')
37 | parser.add_argument('--tensor_unique_ids', type=str, default='unique_ids',
38 | help='the unique_ids feature name for saved model')
39 | parser.add_argument('--tensor_input_ids', type=str, default='input_ids',
40 | help='the input_ids feature name for saved model')
41 | parser.add_argument('--tensor_input_mask', type=str, default='input_mask',
42 | help='the input_mask feature name for saved model')
43 | parser.add_argument('--tensor_segment_ids', type=str, default='segment_ids',
44 | help='the segment_ids feature name for saved model')
45 | parser.add_argument('--tensor_input_span_mask', type=str, default='input_span_mask',
46 | help='the input_span_mask feature name for saved model')
47 | parser.add_argument('--MODE', type=str, default='SINGLE',
48 | help='SINGLE prediction or BATCH prediction')
49 | args_in_use = parser.parse_args()
50 |
51 |
52 | class SquadExample(object):
53 | """A single training/test example for simple sequence classification.
54 |
55 | For examples without an answer, the start and end position are -1.
56 | """
57 |
58 | def __init__(self,
59 | qas_id,
60 | question_text,
61 | doc_tokens,
62 | orig_answer_text=None,
63 | start_position=None,
64 | end_position=None,
65 | is_impossible=False):
66 | self.qas_id = qas_id
67 | self.question_text = question_text
68 | self.doc_tokens = doc_tokens
69 | self.orig_answer_text = orig_answer_text
70 | self.start_position = start_position
71 | self.end_position = end_position
72 | self.is_impossible = is_impossible
73 |
74 | def __str__(self):
75 | return self.__repr__()
76 |
77 | def __repr__(self):
78 | s = ""
79 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
80 | s += ", question_text: %s" % (
81 | tokenization.printable_text(self.question_text))
82 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
83 | if self.start_position:
84 | s += ", start_position: %d" % (self.start_position)
85 | if self.end_position:
86 | s += ", end_position: %d" % (self.end_position)
87 | if self.is_impossible:
88 | s += ", is_impossible: %r" % (self.is_impossible)
89 | return s
90 |
91 |
92 | class InputFeatures(object):
93 | """A single set of features of data."""
94 |
95 | def __init__(self,
96 | unique_id,
97 | example_index,
98 | doc_span_index,
99 | tokens,
100 | token_to_orig_map,
101 | token_is_max_context,
102 | input_ids,
103 | input_mask,
104 | segment_ids,
105 | input_span_mask,
106 | start_position=None,
107 | end_position=None):
108 | self.unique_id = unique_id
109 | self.example_index = example_index
110 | self.doc_span_index = doc_span_index
111 | self.tokens = tokens
112 | self.token_to_orig_map = token_to_orig_map
113 | self.token_is_max_context = token_is_max_context
114 | self.input_ids = input_ids
115 | self.input_mask = input_mask
116 | self.segment_ids = segment_ids
117 | self.input_span_mask = input_span_mask
118 | self.start_position = start_position
119 | self.end_position = end_position
120 |
121 |
122 | def _check_is_max_context(doc_spans, cur_span_index, position):
123 | """Check if this is the 'max context' doc span for the token."""
124 |
125 | # Because of the sliding window approach taken to scoring documents, a single
126 | # token can appear in multiple documents. E.g.
127 | # Doc: the man went to the store and bought a gallon of milk
128 | # Span A: the man went to the
129 | # Span B: to the store and bought
130 | # Span C: and bought a gallon of
131 | # ...
132 | #
133 | # Now the word 'bought' will have two scores from spans B and C. We only
134 | # want to consider the score with "maximum context", which we define as
135 | # the *minimum* of its left and right context (the *sum* of left and
136 | # right context will always be the same, of course).
137 | #
138 | # In the example the maximum context for 'bought' would be span C since
139 | # it has 1 left context and 3 right context, while span B has 4 left context
140 | # and 0 right context.
141 | best_score = None
142 | best_span_index = None
143 | for (span_index, doc_span) in enumerate(doc_spans):
144 | end = doc_span.start + doc_span.length - 1
145 | if position < doc_span.start:
146 | continue
147 | if position > end:
148 | continue
149 | num_left_context = position - doc_span.start
150 | num_right_context = end - position
151 | score = min(num_left_context, num_right_context) + \
152 | 0.01 * doc_span.length
153 | if best_score is None or score > best_score:
154 | best_score = score
155 | best_span_index = span_index
156 |
157 | return cur_span_index == best_span_index
158 |
159 |
160 | class ChineseFullTokenizer(object):
161 | """Runs end-to-end tokenziation."""
162 |
163 | def __init__(self, vocab_file, do_lower_case=False):
164 | self.vocab = tokenization.load_vocab(vocab_file)
165 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
166 | self.wordpiece_tokenizer = tokenization.WordpieceTokenizer(
167 | vocab=self.vocab)
168 | self.do_lower_case = do_lower_case
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in customize_tokenizer(text, do_lower_case=self.do_lower_case):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return tokenization.convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return tokenization.convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | def convert_examples_to_features(examples, tokenizer, max_seq_length,
186 | doc_stride, max_query_length):
187 | """Loads a data file into a list of `InputBatch`s."""
188 |
189 | unique_id = 1000000000
190 | tokenizer = ChineseFullTokenizer(
191 | vocab_file=args_in_use.vocab_file, do_lower_case=args_in_use.do_lower_case)
192 |
193 | features = []
194 | for (example_index, example) in enumerate(examples):
195 | query_tokens = tokenizer.tokenize(example.question_text)
196 |
197 | if len(query_tokens) > max_query_length:
198 | query_tokens = query_tokens[0:max_query_length]
199 |
200 | # After tokenizing the original text: mapping from token index to original-text (word) index
201 | tok_to_orig_index = []
202 | # After tokenizing the original text: mapping from original-text (word) index to token index
203 | orig_to_tok_index = []
204 | all_doc_tokens = []
205 | for (i, token) in enumerate(example.doc_tokens):
206 | orig_to_tok_index.append(len(all_doc_tokens))
207 | sub_tokens = tokenizer.tokenize(token)
208 | for sub_token in sub_tokens:
209 | tok_to_orig_index.append(i)
210 | all_doc_tokens.append(sub_token)
211 |
212 | tok_start_position = None
213 | tok_end_position = None
214 |
215 | # The -3 accounts for [CLS], [SEP] and [SEP]
216 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
217 |
218 | # We can have documents that are longer than the maximum sequence length.
219 | # To deal with this we do a sliding window approach, where we take chunks
220 | # of up to our max length with a stride of `doc_stride`.
221 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
222 | "DocSpan", ["start", "length"])
223 | doc_spans = []
224 | start_offset = 0
225 | while start_offset < len(all_doc_tokens):
226 | length = len(all_doc_tokens) - start_offset
227 | if length > max_tokens_for_doc:
228 | length = max_tokens_for_doc
229 | doc_spans.append(_DocSpan(start=start_offset, length=length))
230 | # If length is smaller than max_tokens_for_doc, this is the last chunk and the loop breaks here
231 | if start_offset + length == len(all_doc_tokens):
232 | break
233 | start_offset += min(length, doc_stride)
234 |
235 | for (doc_span_index, doc_span) in enumerate(doc_spans):
236 | tokens = []
237 | token_to_orig_map = {}
238 | token_is_max_context = {}
239 | # In segment_ids, 0 marks [CLS], the query_tokens and the first [SEP]; 1 marks the doc tokens and the second [SEP]
240 | segment_ids = []
241 | input_span_mask = []
242 | tokens.append("[CLS]")
243 | segment_ids.append(0)
244 | input_span_mask.append(1)
245 | for token in query_tokens:
246 | tokens.append(token)
247 | segment_ids.append(0)
248 | input_span_mask.append(0)
249 | tokens.append("[SEP]")
250 | segment_ids.append(0)
251 | input_span_mask.append(0) # TODO:check why
252 |
253 | for i in range(doc_span.length):
254 | split_token_index = doc_span.start + i
255 | # len(tokens) is the size of [CLS] + query_tokens + [SEP]; this position corresponds to the i-th doc token
256 | token_to_orig_map[len(
257 | tokens)] = tok_to_orig_index[split_token_index]
258 |
259 | is_max_context = _check_is_max_context(doc_spans, doc_span_index,
260 | split_token_index)
261 | token_is_max_context[len(tokens)] = is_max_context
262 | tokens.append(all_doc_tokens[split_token_index])
263 | segment_ids.append(1)
264 | input_span_mask.append(1)
265 | tokens.append("[SEP]")
266 | segment_ids.append(1)
267 | input_span_mask.append(0)
268 |
269 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
270 |
271 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
272 | # tokens are attended to.
273 | input_mask = [1] * len(input_ids)
274 |
275 | # Zero-pad up to the sequence length.
276 | while len(input_ids) < max_seq_length:
277 | input_ids.append(0)
278 | input_mask.append(0)
279 | segment_ids.append(0)
280 | input_span_mask.append(0)
281 |
282 | assert len(input_ids) == max_seq_length
283 | assert len(input_mask) == max_seq_length
284 | assert len(segment_ids) == max_seq_length
285 | assert len(input_span_mask) == max_seq_length
286 |
287 | start_position = None
288 | end_position = None
289 |
290 | features.append(
291 | InputFeatures(
292 | unique_id=unique_id,
293 | example_index=example_index,
294 | doc_span_index=doc_span_index,
295 | tokens=tokens, # may introduce [UNK] tokens
296 | token_to_orig_map=token_to_orig_map,
297 | token_is_max_context=token_is_max_context,
298 | input_ids=input_ids,
299 | input_mask=input_mask,
300 | segment_ids=segment_ids,
301 | input_span_mask=input_span_mask,
302 | start_position=start_position,
303 | end_position=end_position))
304 | unique_id += 1
305 |
306 | return features
307 |
308 |
309 | def customize_tokenizer(text, do_lower_case=False):
310 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
311 | temp_x = ""
312 | text = tokenization.convert_to_unicode(text)
313 | for c in text:
314 | if tokenizer._is_chinese_char(ord(c)) or tokenization._is_punctuation(c) or tokenization._is_whitespace(c) or tokenization._is_control(c):
315 | temp_x += " " + c + " "
316 | else:
317 | temp_x += c
318 | if do_lower_case:
319 | temp_x = temp_x.lower()
320 | return temp_x.split()
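    | # (Editorial example.) customize_tokenizer("我爱NLP!", do_lower_case=True)
    | # returns ['我', '爱', 'nlp', '!']: CJK characters and punctuation become
    | # single tokens, while runs of ASCII letters are kept together.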
321 |
322 |
323 | def read_squad_examples(input_data):
324 | """Read SQuAD-style json data into a list of SquadExample."""
325 | # with tf.gfile.Open(input_file, "r") as reader:
326 | # input_data = json.load(reader)["data"]
327 |
328 | #
329 | examples = []
330 | for entry in input_data:
331 | for paragraph in entry["paragraphs"]:
332 | paragraph_text = paragraph["context"]
333 | raw_doc_tokens = customize_tokenizer(
334 | paragraph_text, do_lower_case=args_in_use.do_lower_case)
335 | doc_tokens = []
336 | char_to_word_offset = []
337 | prev_is_whitespace = True
338 |
339 | k = 0
340 | temp_word = ""
341 | for c in paragraph_text:
342 | # c is whitespace
343 | if tokenization._is_whitespace(c) or not c.split():
344 | char_to_word_offset.append(k-1)
345 | continue
346 | else:
347 | temp_word += c
348 | char_to_word_offset.append(k)
349 | if args_in_use.do_lower_case:
350 | temp_word = temp_word.lower()
351 | if temp_word == raw_doc_tokens[k]:
352 | doc_tokens.append(temp_word)
353 | temp_word = ""
354 | k += 1
355 | if k != len(raw_doc_tokens):
356 | print(paragraph)
357 | print(doc_tokens)
358 | print(raw_doc_tokens)
359 | assert k == len(raw_doc_tokens)
360 |
361 | for qa in paragraph["qas"]:
362 | qas_id = qa["id"]
363 | question_text = qa["question"]
364 | start_position = None
365 | end_position = None
366 | orig_answer_text = None
367 |
368 | example = SquadExample(
369 | qas_id=qas_id,
370 | question_text=question_text,
371 | doc_tokens=doc_tokens,
372 | orig_answer_text=orig_answer_text,
373 | start_position=start_position,
374 | end_position=end_position)
375 | examples.append(example)
376 | tf.logging.info("**********read_squad_examples complete!**********")
377 |
378 | return examples
379 |
380 |
381 | def get_final_text(pred_text, orig_text, do_lower_case):
382 | """Project the tokenized prediction back to the original text."""
383 |
384 | # When we created the data, we kept track of the alignment between original
385 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
386 | # now `orig_text` contains the span of our original text corresponding to the
387 | # span that we predicted.
388 | #
389 | # However, `orig_text` may contain extra characters that we don't want in
390 | # our prediction.
391 | #
392 | # For example, let's say:
393 | # pred_text = steve smith
394 | # orig_text = Steve Smith's
395 | #
396 | # We don't want to return `orig_text` because it contains the extra "'s".
397 | #
398 | # We don't want to return `pred_text` because it's already been normalized
399 | # (the SQuAD eval script also does punctuation stripping/lower casing but
400 | # our tokenizer does additional normalization like stripping accent
401 | # characters).
402 | #
403 | # What we really want to return is "Steve Smith".
404 | #
405 | # Therefore, we have to apply a semi-complicated alignment heuristic between
406 | # `pred_text` and `orig_text` to get a character-to-character alignment. This
407 | # can fail in certain cases in which case we just return `orig_text`.
408 |
409 | def _strip_spaces(text):
410 | ns_chars = []
411 | ns_to_s_map = collections.OrderedDict()
412 | for (i, c) in enumerate(text):
413 | if c == " ":
414 | continue
415 | ns_to_s_map[len(ns_chars)] = i
416 | ns_chars.append(c)
417 | ns_text = "".join(ns_chars)
418 | return (ns_text, ns_to_s_map)
419 |
420 | # We first tokenize `orig_text`, strip whitespace from the result
421 | # and `pred_text`, and check if they are the same length. If they are
422 | # NOT the same length, the heuristic has failed. If they are the same
423 | # length, we assume the characters are one-to-one aligned.
424 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
425 |
426 | tok_text = " ".join(tokenizer.tokenize(orig_text))
427 |
428 | start_position = tok_text.find(pred_text)
429 | if start_position == -1:
430 | tf.logging.info(
431 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
432 | return orig_text
433 | end_position = start_position + len(pred_text) - 1
434 |
435 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
436 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
437 |
438 | if len(orig_ns_text) != len(tok_ns_text):
439 | if getattr(args_in_use, 'verbose_logging', False):  # no --verbose_logging flag is defined in the parser above
440 | tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
441 | orig_ns_text, tok_ns_text)
442 | return orig_text
443 |
444 | # We then project the characters in `pred_text` back to `orig_text` using
445 | # the character-to-character alignment.
446 | tok_s_to_ns_map = {}
447 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
448 | tok_s_to_ns_map[tok_index] = i
449 |
450 | orig_start_position = None
451 | if start_position in tok_s_to_ns_map:
452 | ns_start_position = tok_s_to_ns_map[start_position]
453 | if ns_start_position in orig_ns_to_s_map:
454 | orig_start_position = orig_ns_to_s_map[ns_start_position]
455 |
456 | if orig_start_position is None:
457 | if getattr(args_in_use, 'verbose_logging', False):
458 | tf.logging.info("Couldn't map start position")
459 | return orig_text
460 |
461 | orig_end_position = None
462 | if end_position in tok_s_to_ns_map:
463 | ns_end_position = tok_s_to_ns_map[end_position]
464 | if ns_end_position in orig_ns_to_s_map:
465 | orig_end_position = orig_ns_to_s_map[ns_end_position]
466 |
467 | if orig_end_position is None:
468 | if getattr(args_in_use, 'verbose_logging', False):
469 | tf.logging.info("Couldn't map end position")
470 | return orig_text
471 |
472 | output_text = orig_text[orig_start_position:(orig_end_position + 1)]
473 | return output_text
474 |
475 |
476 | def _compute_softmax(scores):
477 | """Compute softmax probability over raw logits."""
478 | if not scores:
479 | return []
480 |
481 | max_score = None
482 | for score in scores:
483 | if max_score is None or score > max_score:
484 | max_score = score
485 |
486 | exp_scores = []
487 | total_sum = 0.0
488 | for score in scores:
489 | x = math.exp(score - max_score)
490 | exp_scores.append(x)
491 | total_sum += x
492 |
493 | probs = []
494 | for score in exp_scores:
495 | probs.append(score / total_sum)
496 | return probs
497 |
498 |
499 | def _get_best_indexes(logits, n_best_size):
500 | """Get the n-best logits from a list."""
501 | index_and_score = sorted(
502 | enumerate(logits), key=lambda x: x[1], reverse=True)
503 |
504 | best_indexes = []
505 | for i in range(len(index_and_score)):
506 | if i >= n_best_size:
507 | break
508 | best_indexes.append(index_and_score[i][0])
509 | return best_indexes
510 |
511 |
512 | def get_predictions(all_examples, all_features, all_results, n_best_size,
513 | max_answer_length, do_lower_case):
514 |
515 | example_index_to_features = collections.defaultdict(list)
516 | for feature in all_features:
517 | example_index_to_features[feature.example_index].append(feature)
518 |
519 | unique_id_to_result = {}
520 | for result in all_results:
521 | unique_id_to_result[result.unique_id] = result
522 |
523 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
524 | "PrelimPrediction",
525 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
526 |
527 | all_predictions = collections.OrderedDict()
528 | all_nbest_json = collections.OrderedDict()
529 | for (example_index, example) in enumerate(all_examples):
530 | features = example_index_to_features[example_index]
531 |
532 | prelim_predictions = []
533 | for (feature_index, feature) in enumerate(features):
534 | result = unique_id_to_result[feature.unique_id]
535 |
536 | start_indexes = _get_best_indexes(result.start_logits, n_best_size)
537 | end_indexes = _get_best_indexes(result.end_logits, n_best_size)
538 | for start_index in start_indexes:
539 | for end_index in end_indexes:
540 | # We could hypothetically create invalid predictions, e.g., predict
541 | # that the start of the span is in the question. We throw out all
542 | # invalid predictions.
543 | if start_index >= len(feature.tokens):
544 | continue
545 | if end_index >= len(feature.tokens):
546 | continue
547 | if start_index not in feature.token_to_orig_map:
548 | continue
549 | if end_index not in feature.token_to_orig_map:
550 | continue
551 | if not feature.token_is_max_context.get(start_index, False):
552 | continue
553 | if end_index < start_index:
554 | continue
555 | length = end_index - start_index + 1
556 | if length > max_answer_length:
557 | continue
558 | prelim_predictions.append(
559 | _PrelimPrediction(
560 | feature_index=feature_index,
561 | start_index=start_index,
562 | end_index=end_index,
563 | start_logit=result.start_logits[start_index],
564 | end_logit=result.end_logits[end_index]))
565 |
566 | prelim_predictions = sorted(
567 | prelim_predictions,
568 | key=lambda x: (x.start_logit + x.end_logit),
569 | reverse=True)
570 |
571 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
572 | "NbestPrediction", ["text", "start_logit", "end_logit", "start_index", "end_index"])
573 |
574 | seen_predictions = {}
575 | nbest = []
576 | for pred in prelim_predictions:
577 | if len(nbest) >= n_best_size:
578 | break
579 | feature = features[pred.feature_index]
580 |
581 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
582 | orig_doc_start = feature.token_to_orig_map[pred.start_index]
583 | orig_doc_end = feature.token_to_orig_map[pred.end_index]
584 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
585 | tok_text = " ".join(tok_tokens)
586 |
587 | # De-tokenize WordPieces that have been split off.
588 | tok_text = tok_text.replace(" ##", "")
589 | tok_text = tok_text.replace("##", "")
590 |
591 | # Clean whitespace
592 | tok_text = tok_text.strip()
593 | tok_text = " ".join(tok_text.split())
594 | orig_text = " ".join(orig_tokens)
595 |
596 | final_text = get_final_text(tok_text, orig_text, do_lower_case)
597 | final_text = final_text.replace(' ', '')
598 | if final_text in seen_predictions:
599 | continue
600 |
601 | seen_predictions[final_text] = True
602 | nbest.append(
603 | _NbestPrediction(
604 | text=final_text,
605 | start_logit=pred.start_logit,
606 | end_logit=pred.end_logit,
607 | start_index=pred.start_index,
608 | end_index=pred.end_index))
609 |
610 | # In very rare edge cases we could have no valid predictions. So we
611 | # just create a nonce prediction in this case to avoid failure.
612 | if not nbest:
613 | nbest.append(
614 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
615 |
616 | assert len(nbest) >= 1
617 |
618 | total_scores = []
619 | best_non_null_entry = None
620 | for entry in nbest:
621 | total_scores.append(entry.start_logit + entry.end_logit)
622 | if not best_non_null_entry:
623 | if entry.text:
624 | best_non_null_entry = entry
625 |
626 | probs = _compute_softmax(total_scores)
627 |
628 | nbest_json = []
629 | for (i, entry) in enumerate(nbest):
630 | output = collections.OrderedDict()
631 | output["text"] = entry.text
632 | output["probability"] = probs[i]
633 | output["start_logit"] = entry.start_logit
634 | output["end_logit"] = entry.end_logit
635 | output["start_index"] = entry.start_index
636 | output["end_index"] = entry.end_index
637 | nbest_json.append(output)
638 |
639 | assert len(nbest_json) >= 1
640 |
641 | all_predictions[example.qas_id] = nbest_json[0]["text"]
642 | all_nbest_json[example.qas_id] = nbest_json
643 | return all_predictions, all_nbest_json
644 |
645 |
646 | RawResult = collections.namedtuple("RawResult",
647 | ["unique_id", "start_logits", "end_logits"])
648 |
649 |
650 | if __name__ == '__main__':
651 | predict_fn = tf.contrib.predictor.from_saved_model(args_in_use.model)
652 | max_seq_length = args_in_use.max_seq_length
653 | tokenizer = tokenization.FullTokenizer(vocab_file=args_in_use.vocab_file,
654 | do_lower_case=True)
655 | if args_in_use.MODE == 'SINGLE':
656 | # while True:
657 | # paragraph = input('(PRESS q to quit)请输入段落\n> ')
658 | # question = input('(PRESS q to quit)请输入问题\n> ')
659 | # if question == 'q':
660 | # break
661 |
662 | """
663 | DEMO EXAMPLES
664 | """
665 | # paragraph = "《战国无双3》()是由光荣和ω-force开发的战国无双系列 的正统第三续作。本作以三大故事为主轴,分别是以武田信玄等人为主的《关东三国 志》,织田信长等人为主的《战国三杰》,石田三成等人为主的《关原的年轻武者》 ,丰富游戏内的剧情。此部份专门介绍角色,欲知武器情报、奥义字或擅长攻击类型 等,请至战国无双系列1.由于乡里大辅先生因故去世,不得不寻找其他声优接手。从 猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任 > 天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图(不含村雨城),后 来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多,部分地图会有兼用的> 状况,战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主,以下是相> 关介绍。(注:前方加☆者为猛将传新增关卡及地图。)合并本篇和猛将传的内容"
666 | # question = "《战国无双3》是由哪两个公司合作开发的?"
667 | # paragraph = "弗赖永广场(Freyung)是奥地利首都维也纳的一个三角形广场,位于内城(第一区)。这个广场最初位于古罗马堡垒Vindabona城墙以外,在12世纪,奥地利公爵亨利二世邀请爱尔兰僧侣来此修建了“苏格兰修道院”(Schottenkloster),因为当时爱尔兰被称为“新苏格兰”。修道院周围的广场也称为苏格兰广场。“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。"
668 | # paragraph = "“弗赖永”这个名称源于古代德语词汇“frey”,意为自由。这是因为修道院拥有不受公爵管理的特权,还能保护逃亡者。1773年,其侧又新建小修道院,因其形状被称为鞋盒房子(Schubladkastenhaus)。弗赖永广场成为重要的市场,各种各样的街头艺术家以表演为生。其中一种表演是“维也纳小丑”(Wiener Hanswurst)。由于霍夫堡皇宫距此不远,在17世纪和18世纪,许多奥地利贵族在广场上和附近的Herrengasse兴建他们的城市住所。1856年,拆除了弗赖永广场和毗邻的Am Hof广场之间的房屋,以加宽街道。19世纪后期,银行和其他金融机构也迁来此处设立总部。"
669 | # question = "为什么修道院的名称取自意为自由的古代德语词汇“frey”?"
670 | paragraph = "历时五年四次审议的《电子商务法》已于今年1月1日起实施。《电子商务法》的颁布实施为我国规范当前电子商务市场秩序、维护公平竞争环境、保障参与主体权益、促进电子商务健康快速发展奠定了法律基础。从总体上,应该看到《电子商务法》是一部以促进电子商务发展为立法目标之一的法律,是一部权益法,也是一部促进法。《电子商务法》专门设立了“电子商务促进”章节,明确了国家发展电子商务的重点方向。其中,农村电商和电商扶贫成为促进的重点."
671 | question = "电子商务法的目的"
672 | paragraph2 ="阳光板大部分使用的是聚碳酸酯(PC)原料生产,利用空挤压工艺在耐候性脆弱的PC板材上空挤压UV树脂,质量好一点的板面均匀分布有高浓度的UV层,阻挡紫外线的穿过,防止板材黄变,延长板材寿命使产品使用寿命达到10年以上。并且产品具有长期持续透明性的特点。(有单面和双面UV防护)。用途:住宅/商厦采光天幕,工厂厂房 仓库采光顶,体育场馆采光顶,广告牌,通道/停车棚,游泳池/温室覆盖,室内隔断。另本司有隔热保温的PC板材做温棚 遮阳棚 都不错2832217048@qq.com"
673 | question2 = "阳光板雨棚能用几年"
674 | paragraph3 = "藏蓝色,兼于蓝色和黑色之间,既有蓝色的沉静安宁,也有黑色的神秘成熟,既有黑色的收敛效果,又不乏蓝色的洁净长久,虽然不会大热流行,却是可以长久的信任,当藏蓝色与其他颜色相遇,你便会懂得它内在的涵养。藏蓝色+橙色单纯的藏蓝色会给人很严肃的气氛,橙色的点缀让藏蓝色也充满时尚活力。藏蓝色+白色白色是藏蓝色的最佳搭档,两者搭档最容易显得很干净,藏蓝色和白色营造的洗练感,让通勤装永远都不会过时,展现出都市女性的利落感。藏蓝色+粉色藏蓝色和粉色组合散发出成熟优雅的女人味,让粉色显出别样娇嫩。藏蓝色+米色藏蓝色和米色的搭配散发出浓郁的知性气质,稚气的设计细节更显年轻。藏蓝色+红色藏蓝色和红色的搭配更加的沉稳,也更具存在感,如果是面积差不多的服装来搭配,可以用红色的小物点缀来巧妙的平衡。藏蓝色+松石绿藏蓝色搭配柔和的松石绿色给人上品好品质的感觉,用凉鞋和项链来点缀更加具有层次感。藏蓝色+黄色明亮的黄色热情洋溢的融化了藏蓝色的冰冷静谧,细节感的设计更加具有轻松休闲的气氛。藏蓝色+金色推荐单品:藏蓝色"
675 | question3 = "藏蓝色配什么颜色好看"
676 | # question = "电子商务法的生效日期" # bad case
677 | # question = "电子商务法生效时间" # so-so case
678 | # question = "电子商务法实施时间" # good case
679 | input_data = [{
680 | "paragraphs": [
681 | {
682 | "context": paragraph,
683 | "qas": [
684 | {
685 | "question": question,
686 | "id": "RANDOM_QUESTION_ID"
687 | }
688 | ]
689 | },
690 | {
691 | "context": paragraph2,
692 | "qas": [
693 | {
694 | "question": question2,
695 | "id": "RANDOM_QUESTION_ID2"
696 | }
697 | ]
698 | },
699 | {
700 | "context": paragraph3,
701 | "qas": [
702 | {
703 | "question": question3,
704 | "id": "RANDOM_QUESTION_ID3"
705 | }
706 | ]
707 | },
708 | ]
709 | }]
710 | predict_examples = read_squad_examples(input_data)
711 |
712 | features = convert_examples_to_features(
713 | examples=predict_examples,
714 | tokenizer=tokenizer,
715 | max_seq_length=args_in_use.max_seq_length,
716 | doc_stride=args_in_use.doc_stride,
717 | max_query_length=args_in_use.max_query_length)
718 |
719 | start_time = time.time()
720 | results = predict_fn({
721 | args_in_use.tensor_unique_ids: [feature.unique_id for feature in features],
722 | args_in_use.tensor_input_ids: [feature.input_ids for feature in features],
723 | args_in_use.tensor_input_mask: [feature.input_mask for feature in features],
724 | args_in_use.tensor_segment_ids: [feature.segment_ids for feature in features],
725 | args_in_use.tensor_input_span_mask: [feature.input_span_mask for feature in features],
726 | })
727 | print(f'elapsed time: {time.time()-start_time}s')
728 | print(np.shape(results['end_logits']))
729 | unique_ids = results['unique_ids']
730 | start_logits_list = results['start_logits']
731 | end_logits_list = results['end_logits']
732 | all_results = []
733 | for unique_id, start_logits, end_logits in zip(unique_ids, start_logits_list, end_logits_list):
734 | # unique_id = int(result["unique_ids"])
735 | # start_logits = [float(x) for x in result["start_logits"].flat]
736 | # end_logits = [float(x) for x in result["end_logits"].flat]
737 | _raw_result = RawResult(
738 | unique_id=unique_id,
739 | start_logits=start_logits,
740 | end_logits=end_logits)
741 | all_results.append(_raw_result)
742 | all_predictions, all_nbest_json = get_predictions(predict_examples, features, all_results,
743 | args_in_use.n_best_size, args_in_use.max_answer_length,
744 | args_in_use.do_lower_case)
745 | print(all_predictions)
746 | # print(all_nbest_json)
747 | # note: the scores do not match up
748 |
749 | elif args_in_use.MODE == 'BATCH':
750 | pass
751 | # TO BE IMPLEMENTED
752 | else:
753 | raise ValueError('unsupported mode')
754 |
--------------------------------------------------------------------------------
/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The main BERT model and related functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import copy
23 | import json
24 | import math
25 | import re
26 | import six
27 | import tensorflow as tf
28 |
29 |
30 | class BertConfig(object):
31 | """Configuration for `BertModel`."""
32 |
33 | def __init__(self,
34 | vocab_size,
35 | hidden_size=768,
36 | num_hidden_layers=12,
37 | num_attention_heads=12,
38 | intermediate_size=3072,
39 | hidden_act="gelu",
40 | hidden_dropout_prob=0.1,
41 | attention_probs_dropout_prob=0.1,
42 | max_position_embeddings=512,
43 | type_vocab_size=16,
44 | initializer_range=0.02):
45 | """Constructs BertConfig.
46 |
47 | Args:
48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
49 | hidden_size: Size of the encoder layers and the pooler layer.
50 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
51 | num_attention_heads: Number of attention heads for each attention layer in
52 | the Transformer encoder.
53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
54 | layer in the Transformer encoder.
55 | hidden_act: The non-linear activation function (function or string) in the
56 | encoder and pooler.
57 | hidden_dropout_prob: The dropout probability for all fully connected
58 | layers in the embeddings, encoder, and pooler.
59 | attention_probs_dropout_prob: The dropout ratio for the attention
60 | probabilities.
61 | max_position_embeddings: The maximum sequence length that this model might
62 | ever be used with. Typically set this to something large just in case
63 | (e.g., 512 or 1024 or 2048).
64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
65 | `BertModel`.
66 | initializer_range: The stdev of the truncated_normal_initializer for
67 | initializing all weight matrices.
68 | """
69 | self.vocab_size = vocab_size
70 | self.hidden_size = hidden_size
71 | self.num_hidden_layers = num_hidden_layers
72 | self.num_attention_heads = num_attention_heads
73 | self.hidden_act = hidden_act
74 | self.intermediate_size = intermediate_size
75 | self.hidden_dropout_prob = hidden_dropout_prob
76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
77 | self.max_position_embeddings = max_position_embeddings
78 | self.type_vocab_size = type_vocab_size
79 | self.initializer_range = initializer_range
80 |
81 | @classmethod
82 | def from_dict(cls, json_object):
83 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
84 | config = BertConfig(vocab_size=None)
85 | for (key, value) in six.iteritems(json_object):
86 | config.__dict__[key] = value
87 | return config
88 |
89 | @classmethod
90 | def from_json_file(cls, json_file):
91 | """Constructs a `BertConfig` from a json file of parameters."""
92 | with tf.gfile.GFile(json_file, "r") as reader:
93 | text = reader.read()
94 | return cls.from_dict(json.loads(text))
95 |
96 | def to_dict(self):
97 | """Serializes this instance to a Python dictionary."""
98 | output = copy.deepcopy(self.__dict__)
99 | return output
100 |
101 | def to_json_string(self):
102 | """Serializes this instance to a JSON string."""
103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
104 |
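    | # (Editorial usage sketch; the path below is a placeholder.) A config is
    | # typically loaded from a checkpoint's bert_config.json and round-trips
    | # through a plain dict:
    | #   config = BertConfig.from_json_file("/path/to/bert_config.json")
    | #   assert BertConfig.from_dict(config.to_dict()).to_dict() == config.to_dict()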
105 |
106 | class BertModel(object):
107 | """BERT model ("Bidirectional Encoder Representations from Transformers").
108 |
109 | Example usage:
110 |
111 | ```python
112 | # Already been converted into WordPiece token ids
113 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
114 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
115 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
116 |
117 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
118 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
119 |
120 | model = modeling.BertModel(config=config, is_training=True,
121 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
122 |
123 | label_embeddings = tf.get_variable(...)
124 | pooled_output = model.get_pooled_output()
125 | logits = tf.matmul(pooled_output, label_embeddings)
126 | ...
127 | ```
128 | """
129 |
130 | def __init__(self,
131 | config,
132 | is_training,
133 | input_ids,
134 | input_mask=None,
135 | token_type_ids=None,
136 | use_one_hot_embeddings=True,
137 | scope=None):
138 | """Constructor for BertModel.
139 |
140 | Args:
141 | config: `BertConfig` instance.
142 | is_training: bool. true for training model, false for eval model. Controls
143 | whether dropout will be applied.
144 | input_ids: int32 Tensor of shape [batch_size, seq_length].
145 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
146 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
147 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
148 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
149 | it is much faster if this is True, on the CPU or GPU, it is faster if
150 | this is False.
151 | scope: (optional) variable scope. Defaults to "bert".
152 |
153 | Raises:
154 | ValueError: The config is invalid or one of the input tensor shapes
155 | is invalid.
156 | """
157 | config = copy.deepcopy(config)
158 | if not is_training:
159 | config.hidden_dropout_prob = 0.0
160 | config.attention_probs_dropout_prob = 0.0
161 |
162 | input_shape = get_shape_list(input_ids, expected_rank=2)
163 | batch_size = input_shape[0]
164 | seq_length = input_shape[1]
165 |
166 | if input_mask is None:
167 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
168 |
169 | if token_type_ids is None:
170 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
171 |
172 | with tf.variable_scope(scope, default_name="bert"):
173 | with tf.variable_scope("embeddings"):
174 | # Perform embedding lookup on the word ids.
175 | (self.embedding_output, self.embedding_table) = embedding_lookup(
176 | input_ids=input_ids,
177 | vocab_size=config.vocab_size,
178 | embedding_size=config.hidden_size,
179 | initializer_range=config.initializer_range,
180 | word_embedding_name="word_embeddings",
181 | use_one_hot_embeddings=use_one_hot_embeddings)
182 |
183 | # Add positional embeddings and token type embeddings, then layer
184 | # normalize and perform dropout.
185 | self.embedding_output = embedding_postprocessor(
186 | input_tensor=self.embedding_output,
187 | use_token_type=True,
188 | token_type_ids=token_type_ids,
189 | token_type_vocab_size=config.type_vocab_size,
190 | token_type_embedding_name="token_type_embeddings",
191 | use_position_embeddings=True,
192 | position_embedding_name="position_embeddings",
193 | initializer_range=config.initializer_range,
194 | max_position_embeddings=config.max_position_embeddings,
195 | dropout_prob=config.hidden_dropout_prob)
196 |
197 | with tf.variable_scope("encoder"):
198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
199 | # mask of shape [batch_size, seq_length, seq_length] which is used
200 | # for the attention scores.
201 | attention_mask = create_attention_mask_from_input_mask(
202 | input_ids, input_mask)
203 |
204 | # Run the stacked transformer.
205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size].
206 | self.all_encoder_layers = transformer_model(
207 | input_tensor=self.embedding_output,
208 | attention_mask=attention_mask,
209 | hidden_size=config.hidden_size,
210 | num_hidden_layers=config.num_hidden_layers,
211 | num_attention_heads=config.num_attention_heads,
212 | intermediate_size=config.intermediate_size,
213 | intermediate_act_fn=get_activation(config.hidden_act),
214 | hidden_dropout_prob=config.hidden_dropout_prob,
215 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
216 | initializer_range=config.initializer_range,
217 | do_return_all_layers=True)
218 |
219 | self.sequence_output = self.all_encoder_layers[-1]
220 | # The "pooler" converts the encoded sequence tensor of shape
221 | # [batch_size, seq_length, hidden_size] to a tensor of shape
222 | # [batch_size, hidden_size]. This is necessary for segment-level
223 | # (or segment-pair-level) classification tasks where we need a fixed
224 | # dimensional representation of the segment.
225 | with tf.variable_scope("pooler"):
226 | # We "pool" the model by simply taking the hidden state corresponding
227 |         # to the first token. We assume that this has been pre-trained.
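    |         # (In BERT's input format the first position corresponds to the "[CLS]" token.)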
228 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
229 | self.pooled_output = tf.layers.dense(
230 | first_token_tensor,
231 | config.hidden_size,
232 | activation=tf.tanh,
233 | kernel_initializer=create_initializer(config.initializer_range))
234 |
235 | def get_pooled_output(self):
236 | return self.pooled_output
237 |
238 | def get_sequence_output(self):
239 | """Gets final hidden layer of encoder.
240 |
241 | Returns:
242 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
243 |       to the final hidden layer of the transformer encoder.
244 | """
245 | return self.sequence_output
246 |
247 | def get_all_encoder_layers(self):
248 | return self.all_encoder_layers
249 |
250 | def get_embedding_output(self):
251 | """Gets output of the embedding lookup (i.e., input to the transformer).
252 |
253 | Returns:
254 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
255 | to the output of the embedding layer, after summing the word
256 | embeddings with the positional embeddings and the token type embeddings,
257 | then performing layer normalization. This is the input to the transformer.
258 | """
259 | return self.embedding_output
260 |
261 | def get_embedding_table(self):
262 | return self.embedding_table
263 |
264 |
265 | def gelu(input_tensor):
266 | """Gaussian Error Linear Unit.
267 |
268 | This is a smoother version of the RELU.
269 | Original paper: https://arxiv.org/abs/1606.08415
270 |
271 | Args:
272 |     input_tensor: float Tensor to apply the activation to.
273 |
274 | Returns:
275 | `input_tensor` with the GELU activation applied.
276 | """
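    |   # Exact GELU: x * Phi(x), where Phi is the standard normal CDF; this uses
    |   # tf.erf rather than the tanh approximation.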
277 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
278 | return input_tensor * cdf
279 |
280 |
281 | def get_activation(activation_string):
282 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
283 |
284 | Args:
285 | activation_string: String name of the activation function.
286 |
287 | Returns:
288 | A Python function corresponding to the activation function. If
289 | `activation_string` is None, empty, or "linear", this will return None.
290 | If `activation_string` is not a string, it will return `activation_string`.
291 |
292 | Raises:
293 | ValueError: The `activation_string` does not correspond to a known
294 | activation.
295 | """
296 |
297 |   # We assume that anything that's not a string is already an activation
298 | # function, so we just return it.
299 | if not isinstance(activation_string, six.string_types):
300 | return activation_string
301 |
302 | if not activation_string:
303 | return None
304 |
305 | act = activation_string.lower()
306 | if act == "linear":
307 | return None
308 | elif act == "relu":
309 | return tf.nn.relu
310 | elif act == "gelu":
311 | return gelu
312 | elif act == "tanh":
313 | return tf.tanh
314 | else:
315 | raise ValueError("Unsupported activation: %s" % act)
316 |
317 |
318 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
319 | """Compute the union of the current variables and checkpoint variables."""
320 | assignment_map = {}
321 | initialized_variable_names = {}
322 |
323 | name_to_variable = collections.OrderedDict()
324 | for var in tvars:
325 | name = var.name
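    |     # Checkpoint variable names have no ":<output_index>" suffix, so strip it here
    |     # (e.g. "bert/embeddings/word_embeddings:0" -> "bert/embeddings/word_embeddings").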
326 | m = re.match("^(.*):\\d+$", name)
327 | if m is not None:
328 | name = m.group(1)
329 | name_to_variable[name] = var
330 |
331 | init_vars = tf.train.list_variables(init_checkpoint)
332 |
333 | assignment_map = collections.OrderedDict()
334 | for x in init_vars:
335 | (name, var) = (x[0], x[1])
336 | if name not in name_to_variable:
337 | continue
338 | assignment_map[name] = name
339 | initialized_variable_names[name] = 1
340 | initialized_variable_names[name + ":0"] = 1
341 |
342 | return (assignment_map, initialized_variable_names)
343 |
344 |
345 | def dropout(input_tensor, dropout_prob):
346 | """Perform dropout.
347 |
348 | Args:
349 | input_tensor: float Tensor.
350 | dropout_prob: Python float. The probability of dropping out a value (NOT of
351 | *keeping* a dimension as in `tf.nn.dropout`).
352 |
353 | Returns:
354 | A version of `input_tensor` with dropout applied.
355 | """
356 | if dropout_prob is None or dropout_prob == 0.0:
357 | return input_tensor
358 |
359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
360 | return output
361 |
362 |
363 | def layer_norm(input_tensor, name=None):
364 | """Run layer normalization on the last dimension of the tensor."""
365 | return tf.contrib.layers.layer_norm(
366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
367 |
368 |
369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
370 | """Runs layer normalization followed by dropout."""
371 | output_tensor = layer_norm(input_tensor, name)
372 | output_tensor = dropout(output_tensor, dropout_prob)
373 | return output_tensor
374 |
375 |
376 | def create_initializer(initializer_range=0.02):
377 | """Creates a `truncated_normal_initializer` with the given range."""
378 | return tf.truncated_normal_initializer(stddev=initializer_range)
379 |
380 |
381 | def embedding_lookup(input_ids,
382 | vocab_size,
383 | embedding_size=128,
384 | initializer_range=0.02,
385 | word_embedding_name="word_embeddings",
386 | use_one_hot_embeddings=False):
387 |   """Looks up word embeddings for an id tensor.
388 |
389 | Args:
390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
391 | ids.
392 | vocab_size: int. Size of the embedding vocabulary.
393 | embedding_size: int. Width of the word embeddings.
394 | initializer_range: float. Embedding initialization range.
395 | word_embedding_name: string. Name of the embedding table.
396 | use_one_hot_embeddings: bool. If True, use one-hot method for word
397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
398 | for TPUs.
399 |
400 | Returns:
401 | float Tensor of shape [batch_size, seq_length, embedding_size].
402 | """
403 | # This function assumes that the input is of shape [batch_size, seq_length,
404 | # num_inputs].
405 | #
406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we
407 | # reshape to [batch_size, seq_length, 1].
408 | if input_ids.shape.ndims == 2:
409 | input_ids = tf.expand_dims(input_ids, axis=[-1])
410 |
411 | embedding_table = tf.get_variable(
412 | name=word_embedding_name,
413 | shape=[vocab_size, embedding_size],
414 | initializer=create_initializer(initializer_range))
415 |
416 | if use_one_hot_embeddings:
417 | flat_input_ids = tf.reshape(input_ids, [-1])
418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
419 | output = tf.matmul(one_hot_input_ids, embedding_table)
420 | else:
421 | output = tf.nn.embedding_lookup(embedding_table, input_ids)
422 |
423 | input_shape = get_shape_list(input_ids)
424 |
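    |   # `input_ids` was expanded to [batch_size, seq_length, 1] above, so this
    |   # reshape collapses the extra axis into [batch_size, seq_length, embedding_size].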
425 | output = tf.reshape(output,
426 | input_shape[0:-1] + [input_shape[-1] * embedding_size])
427 | return (output, embedding_table)
428 |
429 |
430 | def embedding_postprocessor(input_tensor,
431 | use_token_type=False,
432 | token_type_ids=None,
433 | token_type_vocab_size=16,
434 | token_type_embedding_name="token_type_embeddings",
435 | use_position_embeddings=True,
436 | position_embedding_name="position_embeddings",
437 | initializer_range=0.02,
438 | max_position_embeddings=512,
439 | dropout_prob=0.1):
440 | """Performs various post-processing on a word embedding tensor.
441 |
442 | Args:
443 | input_tensor: float Tensor of shape [batch_size, seq_length,
444 | embedding_size].
445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`.
446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
447 | Must be specified if `use_token_type` is True.
448 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
449 | token_type_embedding_name: string. The name of the embedding table variable
450 | for token type ids.
451 | use_position_embeddings: bool. Whether to add position embeddings for the
452 | position of each token in the sequence.
453 | position_embedding_name: string. The name of the embedding table variable
454 | for positional embeddings.
455 | initializer_range: float. Range of the weight initialization.
456 | max_position_embeddings: int. Maximum sequence length that might ever be
457 | used with this model. This can be longer than the sequence length of
458 | input_tensor, but cannot be shorter.
459 | dropout_prob: float. Dropout probability applied to the final output tensor.
460 |
461 | Returns:
462 | float tensor with same shape as `input_tensor`.
463 |
464 | Raises:
465 | ValueError: One of the tensor shapes or input values is invalid.
466 | """
467 | input_shape = get_shape_list(input_tensor, expected_rank=3)
468 | batch_size = input_shape[0]
469 | seq_length = input_shape[1]
470 | width = input_shape[2]
471 |
472 | output = input_tensor
473 |
474 | if use_token_type:
475 | if token_type_ids is None:
476 |       raise ValueError("`token_type_ids` must be specified if "
477 | "`use_token_type` is True.")
478 | token_type_table = tf.get_variable(
479 | name=token_type_embedding_name,
480 | shape=[token_type_vocab_size, width],
481 | initializer=create_initializer(initializer_range))
482 | # This vocab will be small so we always do one-hot here, since it is always
483 | # faster for a small vocabulary.
484 | flat_token_type_ids = tf.reshape(token_type_ids, [-1])
485 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
486 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
487 | token_type_embeddings = tf.reshape(token_type_embeddings,
488 | [batch_size, seq_length, width])
489 | output += token_type_embeddings
490 |
491 | if use_position_embeddings:
492 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
493 | with tf.control_dependencies([assert_op]):
494 | full_position_embeddings = tf.get_variable(
495 | name=position_embedding_name,
496 | shape=[max_position_embeddings, width],
497 | initializer=create_initializer(initializer_range))
498 | # Since the position embedding table is a learned variable, we create it
499 | # using a (long) sequence length `max_position_embeddings`. The actual
500 | # sequence length might be shorter than this, for faster training of
501 | # tasks that do not have long sequences.
502 | #
503 | # So `full_position_embeddings` is effectively an embedding table
504 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
505 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
506 | # perform a slice.
507 | position_embeddings = tf.slice(full_position_embeddings, [0, 0],
508 | [seq_length, -1])
509 | num_dims = len(output.shape.as_list())
510 |
511 | # Only the last two dimensions are relevant (`seq_length` and `width`), so
512 | # we broadcast among the first dimensions, which is typically just
513 | # the batch size.
514 | position_broadcast_shape = []
515 | for _ in range(num_dims - 2):
516 | position_broadcast_shape.append(1)
517 | position_broadcast_shape.extend([seq_length, width])
518 | position_embeddings = tf.reshape(position_embeddings,
519 | position_broadcast_shape)
520 | output += position_embeddings
521 |
522 | output = layer_norm_and_dropout(output, dropout_prob)
523 | return output
524 |
525 |
526 | def create_attention_mask_from_input_mask(from_tensor, to_mask):
527 | """Create 3D attention mask from a 2D tensor mask.
528 |
529 | Args:
530 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
531 | to_mask: int32 Tensor of shape [batch_size, to_seq_length].
532 |
533 | Returns:
534 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
535 | """
536 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
537 | batch_size = from_shape[0]
538 | from_seq_length = from_shape[1]
539 |
540 | to_shape = get_shape_list(to_mask, expected_rank=2)
541 | to_seq_length = to_shape[1]
542 |
543 | to_mask = tf.cast(
544 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
545 |
546 |   # We don't assume that `from_tensor` is a mask (although it could be). We
547 |   # don't actually care if we attend *from* padding tokens (only *to* padding
548 |   # tokens), so we create a tensor of all ones.
549 | #
550 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
551 | broadcast_ones = tf.ones(
552 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
553 |
554 | # Here we broadcast along two dimensions to create the mask.
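    |   # For example, a batch entry with to_mask = [1, 1, 0] yields a
    |   # [from_seq_length, to_seq_length] mask whose every row is [1., 1., 0.].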
555 | mask = broadcast_ones * to_mask
556 |
557 | return mask
558 |
559 |
560 | def attention_layer(from_tensor,
561 | to_tensor,
562 | attention_mask=None,
563 | num_attention_heads=1,
564 | size_per_head=512,
565 | query_act=None,
566 | key_act=None,
567 | value_act=None,
568 | attention_probs_dropout_prob=0.0,
569 | initializer_range=0.02,
570 | do_return_2d_tensor=False,
571 | batch_size=None,
572 | from_seq_length=None,
573 | to_seq_length=None):
574 | """Performs multi-headed attention from `from_tensor` to `to_tensor`.
575 |
576 | This is an implementation of multi-headed attention based on "Attention
577 |   Is All You Need". If `from_tensor` and `to_tensor` are the same, then
578 |   this is self-attention. Each timestep in `from_tensor` attends to the
579 |   corresponding sequence in `to_tensor`, and returns a fixed-width vector.
580 |
581 | This function first projects `from_tensor` into a "query" tensor and
582 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list
583 | of tensors of length `num_attention_heads`, where each tensor is of shape
584 | [batch_size, seq_length, size_per_head].
585 |
586 | Then, the query and key tensors are dot-producted and scaled. These are
587 | softmaxed to obtain attention probabilities. The value tensors are then
588 | interpolated by these probabilities, then concatenated back to a single
589 | tensor and returned.
590 |
591 |   In practice, the multi-headed attention is done with transposes and
592 | reshapes rather than actual separate tensors.
593 |
594 | Args:
595 | from_tensor: float Tensor of shape [batch_size, from_seq_length,
596 | from_width].
597 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
598 | attention_mask: (optional) int32 Tensor of shape [batch_size,
599 | from_seq_length, to_seq_length]. The values should be 1 or 0. The
600 | attention scores will effectively be set to -infinity for any positions in
601 | the mask that are 0, and will be unchanged for positions that are 1.
602 | num_attention_heads: int. Number of attention heads.
603 | size_per_head: int. Size of each attention head.
604 | query_act: (optional) Activation function for the query transform.
605 | key_act: (optional) Activation function for the key transform.
606 | value_act: (optional) Activation function for the value transform.
607 | attention_probs_dropout_prob: (optional) float. Dropout probability of the
608 | attention probabilities.
609 | initializer_range: float. Range of the weight initializer.
610 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
611 | * from_seq_length, num_attention_heads * size_per_head]. If False, the
612 | output will be of shape [batch_size, from_seq_length, num_attention_heads
613 | * size_per_head].
614 | batch_size: (Optional) int. If the input is 2D, this might be the batch size
615 | of the 3D version of the `from_tensor` and `to_tensor`.
616 | from_seq_length: (Optional) If the input is 2D, this might be the seq length
617 | of the 3D version of the `from_tensor`.
618 | to_seq_length: (Optional) If the input is 2D, this might be the seq length
619 | of the 3D version of the `to_tensor`.
620 |
621 | Returns:
622 | float Tensor of shape [batch_size, from_seq_length,
623 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
624 | true, this will be of shape [batch_size * from_seq_length,
625 | num_attention_heads * size_per_head]).
626 |
627 | Raises:
628 | ValueError: Any of the arguments or tensor shapes are invalid.
629 | """
630 |
631 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
632 | seq_length, width):
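    |     # [batch_size * seq_length, num_attention_heads * width]
    |     #   -> [batch_size, seq_length, num_attention_heads, width]
    |     #   -> [batch_size, num_attention_heads, seq_length, width]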
633 | output_tensor = tf.reshape(
634 | input_tensor, [batch_size, seq_length, num_attention_heads, width])
635 |
636 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
637 | return output_tensor
638 |
639 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
640 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
641 |
642 | if len(from_shape) != len(to_shape):
643 | raise ValueError(
644 | "The rank of `from_tensor` must match the rank of `to_tensor`.")
645 |
646 | if len(from_shape) == 3:
647 | batch_size = from_shape[0]
648 | from_seq_length = from_shape[1]
649 | to_seq_length = to_shape[1]
650 | elif len(from_shape) == 2:
651 | if (batch_size is None or from_seq_length is None or to_seq_length is None):
652 | raise ValueError(
653 | "When passing in rank 2 tensors to attention_layer, the values "
654 | "for `batch_size`, `from_seq_length`, and `to_seq_length` "
655 | "must all be specified.")
656 |
657 | # Scalar dimensions referenced here:
658 | # B = batch size (number of sequences)
659 | # F = `from_tensor` sequence length
660 | # T = `to_tensor` sequence length
661 | # N = `num_attention_heads`
662 | # H = `size_per_head`
663 |
664 | from_tensor_2d = reshape_to_matrix(from_tensor)
665 | to_tensor_2d = reshape_to_matrix(to_tensor)
666 |
667 | # `query_layer` = [B*F, N*H]
668 | query_layer = tf.layers.dense(
669 | from_tensor_2d,
670 | num_attention_heads * size_per_head,
671 | activation=query_act,
672 | name="query",
673 | kernel_initializer=create_initializer(initializer_range))
674 |
675 | # `key_layer` = [B*T, N*H]
676 | key_layer = tf.layers.dense(
677 | to_tensor_2d,
678 | num_attention_heads * size_per_head,
679 | activation=key_act,
680 | name="key",
681 | kernel_initializer=create_initializer(initializer_range))
682 |
683 | # `value_layer` = [B*T, N*H]
684 | value_layer = tf.layers.dense(
685 | to_tensor_2d,
686 | num_attention_heads * size_per_head,
687 | activation=value_act,
688 | name="value",
689 | kernel_initializer=create_initializer(initializer_range))
690 |
691 | # `query_layer` = [B, N, F, H]
692 | query_layer = transpose_for_scores(query_layer, batch_size,
693 | num_attention_heads, from_seq_length,
694 | size_per_head)
695 |
696 | # `key_layer` = [B, N, T, H]
697 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
698 | to_seq_length, size_per_head)
699 |
700 | # Take the dot product between "query" and "key" to get the raw
701 | # attention scores.
702 | # `attention_scores` = [B, N, F, T]
703 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
704 | attention_scores = tf.multiply(attention_scores,
705 | 1.0 / math.sqrt(float(size_per_head)))
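    |   # i.e. the scaled dot-product scores Q·K^T / sqrt(size_per_head) from
    |   # "Attention Is All You Need".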
706 |
707 | if attention_mask is not None:
708 | # `attention_mask` = [B, 1, F, T]
709 | attention_mask = tf.expand_dims(attention_mask, axis=[1])
710 |
711 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
712 | # masked positions, this operation will create a tensor which is 0.0 for
713 | # positions we want to attend and -10000.0 for masked positions.
714 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
715 |
716 | # Since we are adding it to the raw scores before the softmax, this is
717 | # effectively the same as removing these entirely.
718 | attention_scores += adder
719 |
720 | # Normalize the attention scores to probabilities.
721 | # `attention_probs` = [B, N, F, T]
722 | attention_probs = tf.nn.softmax(attention_scores)
723 |
724 | # This is actually dropping out entire tokens to attend to, which might
725 | # seem a bit unusual, but is taken from the original Transformer paper.
726 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
727 |
728 | # `value_layer` = [B, T, N, H]
729 | value_layer = tf.reshape(
730 | value_layer,
731 | [batch_size, to_seq_length, num_attention_heads, size_per_head])
732 |
733 | # `value_layer` = [B, N, T, H]
734 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
735 |
736 | # `context_layer` = [B, N, F, H]
737 | context_layer = tf.matmul(attention_probs, value_layer)
738 |
739 | # `context_layer` = [B, F, N, H]
740 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
741 |
742 | if do_return_2d_tensor:
743 | # `context_layer` = [B*F, N*H]
744 | context_layer = tf.reshape(
745 | context_layer,
746 | [batch_size * from_seq_length, num_attention_heads * size_per_head])
747 | else:
748 | # `context_layer` = [B, F, N*H]
749 | context_layer = tf.reshape(
750 | context_layer,
751 | [batch_size, from_seq_length, num_attention_heads * size_per_head])
752 |
753 | return context_layer
754 |
755 |
756 | def transformer_model(input_tensor,
757 | attention_mask=None,
758 | hidden_size=768,
759 | num_hidden_layers=12,
760 | num_attention_heads=12,
761 | intermediate_size=3072,
762 | intermediate_act_fn=gelu,
763 | hidden_dropout_prob=0.1,
764 | attention_probs_dropout_prob=0.1,
765 | initializer_range=0.02,
766 | do_return_all_layers=False):
767 |   """Multi-headed, multi-layer Transformer from "Attention Is All You Need".
768 |
769 | This is almost an exact implementation of the original Transformer encoder.
770 |
771 | See the original paper:
772 | https://arxiv.org/abs/1706.03762
773 |
774 | Also see:
775 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
776 |
777 | Args:
778 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
779 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
780 | seq_length], with 1 for positions that can be attended to and 0 in
781 | positions that should not be.
782 | hidden_size: int. Hidden size of the Transformer.
783 | num_hidden_layers: int. Number of layers (blocks) in the Transformer.
784 | num_attention_heads: int. Number of attention heads in the Transformer.
785 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed
786 | forward) layer.
787 | intermediate_act_fn: function. The non-linear activation function to apply
788 | to the output of the intermediate/feed-forward layer.
789 | hidden_dropout_prob: float. Dropout probability for the hidden layers.
790 | attention_probs_dropout_prob: float. Dropout probability of the attention
791 | probabilities.
792 | initializer_range: float. Range of the initializer (stddev of truncated
793 | normal).
794 | do_return_all_layers: Whether to also return all layers or just the final
795 | layer.
796 |
797 | Returns:
798 | float Tensor of shape [batch_size, seq_length, hidden_size], the final
799 | hidden layer of the Transformer.
800 |
801 | Raises:
802 | ValueError: A Tensor shape or parameter is invalid.
803 | """
804 | if hidden_size % num_attention_heads != 0:
805 | raise ValueError(
806 | "The hidden size (%d) is not a multiple of the number of attention "
807 | "heads (%d)" % (hidden_size, num_attention_heads))
808 |
809 | attention_head_size = int(hidden_size / num_attention_heads)
810 | input_shape = get_shape_list(input_tensor, expected_rank=3)
811 | batch_size = input_shape[0]
812 | seq_length = input_shape[1]
813 | input_width = input_shape[2]
814 |
815 |   # The Transformer adds residual (sum) connections on all layers, so the
816 |   # input width needs to be the same as the hidden size.
817 | if input_width != hidden_size:
818 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
819 | (input_width, hidden_size))
820 |
821 | # We keep the representation as a 2D tensor to avoid re-shaping it back and
822 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
823 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
824 | # help the optimizer.
825 | prev_output = reshape_to_matrix(input_tensor)
826 |
827 | all_layer_outputs = []
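    |   # Each block below is: multi-head self-attention -> dense projection +
    |   # residual + layer norm -> feed-forward "intermediate" layer -> dense
    |   # projection + residual + layer norm.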
828 | for layer_idx in range(num_hidden_layers):
829 | with tf.variable_scope("layer_%d" % layer_idx):
830 | layer_input = prev_output
831 |
832 | with tf.variable_scope("attention"):
833 | attention_heads = []
834 | with tf.variable_scope("self"):
835 | attention_head = attention_layer(
836 | from_tensor=layer_input,
837 | to_tensor=layer_input,
838 | attention_mask=attention_mask,
839 | num_attention_heads=num_attention_heads,
840 | size_per_head=attention_head_size,
841 | attention_probs_dropout_prob=attention_probs_dropout_prob,
842 | initializer_range=initializer_range,
843 | do_return_2d_tensor=True,
844 | batch_size=batch_size,
845 | from_seq_length=seq_length,
846 | to_seq_length=seq_length)
847 | attention_heads.append(attention_head)
848 |
849 | attention_output = None
850 | if len(attention_heads) == 1:
851 | attention_output = attention_heads[0]
852 | else:
853 | # In the case where we have other sequences, we just concatenate
854 | # them to the self-attention head before the projection.
855 | attention_output = tf.concat(attention_heads, axis=-1)
856 |
857 | # Run a linear projection of `hidden_size` then add a residual
858 | # with `layer_input`.
859 | with tf.variable_scope("output"):
860 | attention_output = tf.layers.dense(
861 | attention_output,
862 | hidden_size,
863 | kernel_initializer=create_initializer(initializer_range))
864 | attention_output = dropout(attention_output, hidden_dropout_prob)
865 | attention_output = layer_norm(attention_output + layer_input)
866 |
867 | # The activation is only applied to the "intermediate" hidden layer.
868 | with tf.variable_scope("intermediate"):
869 | intermediate_output = tf.layers.dense(
870 | attention_output,
871 | intermediate_size,
872 | activation=intermediate_act_fn,
873 | kernel_initializer=create_initializer(initializer_range))
874 |
875 | # Down-project back to `hidden_size` then add the residual.
876 | with tf.variable_scope("output"):
877 | layer_output = tf.layers.dense(
878 | intermediate_output,
879 | hidden_size,
880 | kernel_initializer=create_initializer(initializer_range))
881 | layer_output = dropout(layer_output, hidden_dropout_prob)
882 | layer_output = layer_norm(layer_output + attention_output)
883 | prev_output = layer_output
884 | all_layer_outputs.append(layer_output)
885 |
886 | if do_return_all_layers:
887 | final_outputs = []
888 | for layer_output in all_layer_outputs:
889 | final_output = reshape_from_matrix(layer_output, input_shape)
890 | final_outputs.append(final_output)
891 | return final_outputs
892 | else:
893 | final_output = reshape_from_matrix(prev_output, input_shape)
894 | return final_output
895 |
896 |
897 | def get_shape_list(tensor, expected_rank=None, name=None):
898 | """Returns a list of the shape of tensor, preferring static dimensions.
899 |
900 | Args:
901 | tensor: A tf.Tensor object to find the shape of.
902 | expected_rank: (optional) int. The expected rank of `tensor`. If this is
903 |       specified and the `tensor` has a different rank, an exception will be
904 | thrown.
905 | name: Optional name of the tensor for the error message.
906 |
907 | Returns:
908 | A list of dimensions of the shape of tensor. All static dimensions will
909 | be returned as python integers, and dynamic dimensions will be returned
910 | as tf.Tensor scalars.
911 | """
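    |   # Example: a tensor with static shape (None, 128) comes back as
    |   # [tf.shape(tensor)[0], 128] -- a dynamic scalar plus a Python int.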
912 | if name is None:
913 | name = tensor.name
914 |
915 | if expected_rank is not None:
916 | assert_rank(tensor, expected_rank, name)
917 |
918 | shape = tensor.shape.as_list()
919 |
920 | non_static_indexes = []
921 | for (index, dim) in enumerate(shape):
922 | if dim is None:
923 | non_static_indexes.append(index)
924 |
925 | if not non_static_indexes:
926 | return shape
927 |
928 | dyn_shape = tf.shape(tensor)
929 | for index in non_static_indexes:
930 | shape[index] = dyn_shape[index]
931 | return shape
932 |
933 |
934 | def reshape_to_matrix(input_tensor):
935 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
936 | ndims = input_tensor.shape.ndims
937 | if ndims < 2:
938 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
939 | (input_tensor.shape))
940 | if ndims == 2:
941 | return input_tensor
942 |
943 | width = input_tensor.shape[-1]
944 | output_tensor = tf.reshape(input_tensor, [-1, width])
945 | return output_tensor
946 |
947 |
948 | def reshape_from_matrix(output_tensor, orig_shape_list):
949 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
950 | if len(orig_shape_list) == 2:
951 | return output_tensor
952 |
953 | output_shape = get_shape_list(output_tensor)
954 |
955 | orig_dims = orig_shape_list[0:-1]
956 | width = output_shape[-1]
957 |
958 | return tf.reshape(output_tensor, orig_dims + [width])
959 |
960 |
961 | def assert_rank(tensor, expected_rank, name=None):
962 | """Raises an exception if the tensor rank is not of the expected rank.
963 |
964 | Args:
965 | tensor: A tf.Tensor to check the rank of.
966 | expected_rank: Python integer or list of integers, expected rank.
967 | name: Optional name of the tensor for the error message.
968 |
969 | Raises:
970 | ValueError: If the expected shape doesn't match the actual shape.
971 | """
972 | if name is None:
973 | name = tensor.name
974 |
975 | expected_rank_dict = {}
976 | if isinstance(expected_rank, six.integer_types):
977 | expected_rank_dict[expected_rank] = True
978 | else:
979 | for x in expected_rank:
980 | expected_rank_dict[x] = True
981 |
982 | actual_rank = tensor.shape.ndims
983 | if actual_rank not in expected_rank_dict:
984 | scope_name = tf.get_variable_scope().name
985 | raise ValueError(
986 | "For the tensor `%s` in scope `%s`, the actual rank "
987 | "`%d` (shape = %s) is not equal to the expected rank `%s`" %
988 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
989 |
--------------------------------------------------------------------------------