├── .gitignore
├── .idea
│   ├── bert.iml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── __init__.py
├── atec_analysis.py
├── atec_rnn_similar.py
├── bert_as_feature.py
├── create_pretraining_data.py
├── data
│   └── atec
│       ├── dev.csv
│       └── train.csv
├── extract_features.py
├── general_utils.py
├── get_started
│   ├── custom_estimator.py
│   ├── estimator_test.py
│   ├── iris_data.py
│   └── premade_estimator.py
├── modeling.py
├── modeling_test.py
├── optimization.py
├── optimization_test.py
├── run_classifier.py
├── run_classifier_predict_online.py
├── run_pretraining.py
├── test.py
├── tokenization.py
└── tokenization_test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Initially taken from Github's Python gitignore file
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # atec_code
9 | atec*.py
10 | !atec_rnn_similar.py
11 |
12 |
13 | # weight
14 | weight/*
15 | !weight/atec/
16 | weight/atec/*
17 |
18 | # data
19 | data/atec/*
20 | !data/atec/train.csv
21 | !data/atec/dev.csv
22 |
23 |
24 | # output
25 | output/
26 |
27 | # Distribution / packaging
28 | .Python
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 | MANIFEST
45 |
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .nox/
60 | .coverage
61 | .coverage.*
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 | *.cover
66 | .hypothesis/
67 | .pytest_cache/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | .python-version
100 |
101 | # celery beat schedule file
102 | celerybeat-schedule
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [TOC]
2 |
3 | # Use BERT as a feature extractor
4 | 1. How do you call BERT to encode an input sentence as vectors?
5 | 2. To use BERT as a bottom-layer feature extractor in your own code, do you really need as much code as the official run_classifier.py example?
6 | # Environment
7 |
8 | ```
9 | mac:
10 | tf==1.4.0
11 | python=2.7
12 |
13 | windows:
14 | tf==1.12
15 | python=3.5
16 | ```
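
A possible way to set up the Windows configuration above, assuming pip is available (the Mac setup is analogous with the versions listed there):

```
pip install tensorflow==1.12
```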
17 |
18 | # Entry point
19 |
20 | Call the pre-trained model to run prediction on a sentence.
21 | Entry script: bert_as_feature.py
22 | Set data_root to the directory containing the model files.
23 | Pre-trained model: chinese_L-12_H-768_A-12
24 | The core calling code:
25 | ```python
26 | # graph
27 | input_ids = tf.placeholder(tf.int32, shape=[None, None], name='input_ids')
28 | input_mask = tf.placeholder(tf.int32, shape=[None, None], name='input_masks')
29 | segment_ids = tf.placeholder(tf.int32, shape=[None, None], name='segment_ids')
30 |
31 | # Initialize BERT
32 | model = modeling.BertModel(
33 |     config=bert_config,
34 |     is_training=False,
35 |     input_ids=input_ids,
36 |     input_mask=input_mask,
37 |     token_type_ids=segment_ids,
38 |     use_one_hot_embeddings=False)
39 |
40 | # Load the pre-trained BERT weights
41 | tvars = tf.trainable_variables()
42 | (assignment, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
43 | tf.train.init_from_checkpoint(init_checkpoint, assignment)
44 | # Get the last and second-to-last encoder layers.
45 | encoder_last_layer = model.get_sequence_output()
46 | encoder_last2_layer = model.all_encoder_layers[-2]
47 |
48 | with tf.Session() as sess:
49 |     sess.run(tf.global_variables_initializer())
50 |
51 |     token = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
52 |     query = u'Jack,请回答1988, UNwant\u00E9d,running'
53 |     split_tokens = token.tokenize(query)
54 |     word_ids = token.convert_tokens_to_ids(split_tokens)
55 |     word_mask = [1] * len(word_ids)
56 |     word_segment_ids = [0] * len(word_ids)
57 |     fd = {input_ids: [word_ids], input_mask: [word_mask], segment_ids: [word_segment_ids]}
58 |     last, last2 = sess.run([encoder_last_layer, encoder_last2_layer], feed_dict=fd)
59 |     print('last shape:{}, last2 shape: {}'.format(last.shape, last2.shape))
60 | ```
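
To encode several queries in one pass, the three placeholders accept a padded batch; the mask is 1 for real tokens and 0 for padding, as noted in `extract_features.py`. A minimal sketch reusing the graph above (`make_feed` is a hypothetical helper, not part of the repo):

```python
def make_feed(batch_word_ids):
    """Pad id sequences to a common length and build the feed dict (hypothetical helper)."""
    max_len = max(len(ids) for ids in batch_word_ids)
    ids, mask, seg = [], [], []
    for w in batch_word_ids:
        pad = max_len - len(w)
        ids.append(w + [0] * pad)            # 0 is the padding id
        mask.append([1] * len(w) + [0] * pad)  # 1 = real token, 0 = padding
        seg.append([0] * max_len)            # single sentence -> all segment 0
    return {input_ids: ids, input_mask: mask, segment_ids: seg}

# usage: last_batch = sess.run(encoder_last_layer, feed_dict=make_feed([ids1, ids2]))
```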
61 |
62 | Full code: [bert_as_feature.py](https://github.com/InsaneLife/bert/blob/master/bert_as_feature.py)
63 |
64 | Repository: https://github.com/InsaneLife/bert
65 |
66 | Chinese model download: **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
67 |
68 | # Results
69 |
70 | Shapes of the last and second-to-last layers:
71 | last shape:(1, 14, 768), last2 shape: (1, 14, 768)
72 |
73 | ```
74 | # last value
75 | [[ 0.8200665 1.7532703 -0.3771637 ... -0.63692784 -0.17133102
76 | 0.01075665]
77 | [ 0.79148203 -0.08384223 -0.51832616 ... 0.8080162 1.9931345
78 | 1.072408 ]
79 | [-0.02546642 2.2759912 -0.6004753 ... -0.88577884 3.1459959
80 | -0.03815675]
81 | ...
82 | [-0.15581022 1.154014 -0.96733016 ... -0.47922543 0.51068854
83 | 0.29749477]
84 | [ 0.38253042 0.09779643 -0.39919692 ... 0.98277044 0.6780443
85 | -0.52883977]
86 | [ 0.20359193 -0.42314947 0.51891303 ... -0.23625426 0.666618
87 | 0.30184716]]
88 | ```
89 |
90 |
91 |
92 | # Preprocessing
93 |
94 | `tokenization.py` preprocesses the input sentence and contains two main classes: `BasicTokenizer` and `FullTokenizer`.
95 |
96 | `BasicTokenizer` splits the text into individual characters, recognizes English words as units, and keeps consecutive digits together, for example:
97 |
98 | ```
99 | query: 'Jack,请回答1988, UNwant\u00E9d,running'
100 | token: ['jack', ',', '请', '回', '答', '1988', ',', 'unwanted', ',', 'running']
101 | ```
102 |
103 | `FullTokenizer` additionally applies n-gram-style (WordPiece) matching to English tokens, splitting English words into subwords, e.g. running becomes runn and ##ing; it mainly targets English.
104 |
105 | ```
106 | query: 'UNwant\u00E9d,running'
107 | token: ["un", "##want", "##ed", ",", "runn", "##ing"]
108 | ```
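
Both outputs above can be reproduced directly; a minimal sketch (the vocab path is a placeholder, and the exact WordPiece splits depend on the vocabulary file used):

```python
import tokenization

# Character/word-level splitting only, no vocabulary needed.
basic = tokenization.BasicTokenizer(do_lower_case=True)
print(basic.tokenize(u'Jack,请回答1988, UNwant\u00E9d,running'))
# ['jack', ',', '请', '回', '答', '1988', ',', 'unwanted', ',', 'running']

# Basic splitting followed by WordPiece subword segmentation.
full = tokenization.FullTokenizer(vocab_file='vocab.txt', do_lower_case=True)  # placeholder path
print(full.tokenize(u'UNwant\u00E9d,running'))
# e.g. ['un', '##want', '##ed', ',', 'runn', '##ing'] with the English vocab
```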
109 |
110 | For Chinese data, and especially for NER, keeping digits and English words as whole tokens produces many UNK tokens, so they need to be split apart. The desired result:
111 |
112 | ```
113 | query: 'Jack,请回答1988'
114 | token: ['j', 'a', 'c', 'k', ',', '请', '回', '答', '1', '9', '8', '8']
115 | ```
116 |
117 | The concrete change:
118 |
119 | ```python
120 | class CharTokenizer(object):
121 |     """Runs end-to-end tokenization."""
122 |     def __init__(self, vocab_file, do_lower_case=True):
123 |         self.vocab = load_vocab(vocab_file)
124 |         self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
125 |         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
126 |
127 |     def tokenize(self, text):
128 |         split_tokens = []
129 |         for token in self.basic_tokenizer.tokenize(text):
130 |             for sub_token in token:
131 |                 split_tokens.append(sub_token)
132 |         return split_tokens
133 |
134 |     def convert_tokens_to_ids(self, tokens):
135 |         return convert_tokens_to_ids(self.vocab, tokens)
136 | ```
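
For reference, a minimal usage sketch of the modified tokenizer (the vocab path is a placeholder for the file shipped with the Chinese checkpoint):

```python
import tokenization

# Iterating over each basic token yields single characters,
# so digits and English words are split apart as desired.
token = tokenization.CharTokenizer(vocab_file='weight/chinese_L-12_H-768_A-12/vocab.txt')  # placeholder path
tokens = token.tokenize(u'Jack,请回答1988')
print(tokens)
# ['j', 'a', 'c', 'k', ',', '请', '回', '答', '1', '9', '8', '8']
print(token.convert_tokens_to_ids(tokens))
```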
137 |
138 |
139 |
140 |
141 |
142 |
143 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/atec_analysis.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import pandas as pd
3 | from sklearn import model_selection
4 |
5 | data_path = "data/atec/all.csv"
6 |
7 | all_data = pd.read_csv(data_path, sep='\t', names=['line', 'seq1', 'seq2', 'label'], header=None)  # header=-1 is invalid; use header=None for headerless files
8 | print(all_data.shape)
9 | print(all_data['seq1'].str.len().max()) # 97
10 | print(all_data['seq2'].str.len().max()) # 112
11 | print(all_data['seq2'].str.cat(all_data['seq1']).str.len().max()) # 166
12 |
13 | x = all_data[['line', 'seq1', 'seq2']]
14 | y = all_data['label']
15 | x_train, x_test, y_train, y_test = model_selection.train_test_split(x, y, test_size=0.3, random_state=2019)
16 |
17 | x_train['label'] = y_train
18 | x_train.to_csv('data/atec/train.csv', index=False, sep='\t', encoding='utf-8', header=False)
19 |
20 | x_test['label'] = y_test
21 | x_test.to_csv('data/atec/dev.csv', index=False, sep='\t', encoding='utf-8', header=False)
22 |
--------------------------------------------------------------------------------
/atec_rnn_similar.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import tensorflow as tf
4 | import tokenization, modeling, optimization
5 | import os
6 | import collections
7 | from general_utils import Progbar
8 | # import tensorflow.keras.utils.Progbar as Progbar
9 | from atec_rnn_config import Config
10 | import sklearn as sk
11 | from sklearn.metrics import roc_auc_score
12 | from sklearn import metrics
13 | import shutil
14 | import time
15 | import random
16 |
17 |
18 | class SimilarModel(object):
19 | def __init__(self, config, train_size):
20 | self.learning_rate = config.learning_rate
21 | self.num_labels = config.n_classes # num of class
22 | self.logger = config.logger
23 | self.config = config
24 | self.num_train_steps = int(train_size / config.train_batch_size * config.nepochs)
25 | self.num_warmup_steps = int(self.num_train_steps * config.warmup_proportion)
26 |
27 | def initialize_session(self):
28 | self.logger.info("Initializing tf session")
29 | global_config = tf.ConfigProto()
30 | global_config.gpu_options.allow_growth = True
31 | self.sess = tf.Session(config=global_config)
32 | self.sess.run(tf.global_variables_initializer())
33 | self.sess.run(tf.tables_initializer())
34 | self.saver = tf.train.Saver()
35 |
36 | def restore_session(self, dir_model):
37 | self.logger.info("Reloading the latest trained model...")
38 | if not os.path.exists(dir_model):
39 | os.makedirs(dir_model)
40 | self.saver.restore(self.sess, dir_model)
41 |
42 | def save_session(self):
43 | if not os.path.exists(self.config.dir_model):
44 | os.makedirs(self.config.dir_model)
45 | self.saver.save(self.sess, self.config.dir_model)
46 |
47 | def close_session(self):
48 | self.sess.close()
49 |
50 | def add_summary(self):
51 | self.merged = tf.summary.merge_all()
52 | self.file_writer = tf.summary.FileWriter(self.config.log_dir, self.sess.graph)
53 |
54 | # input_ids, input_mask, segment_ids, labels,
55 | def add_placeholder(self):
56 | self.text1_word_ids = tf.placeholder(tf.int32, shape=[None, None], name='text1_word_ids')
57 | self.text1_char_ids = tf.placeholder(dtype=tf.int32, shape=[None, None, None], name="text1_char_ids")
58 | self.text1_lengths = tf.placeholder(tf.int32, shape=[None], name='text1_lengths')
59 | self.text1_word_lengths = tf.placeholder(tf.int32, shape=[None, None], name="text1_word_lengths")
60 | self.text2_word_ids = tf.placeholder(tf.int32, shape=[None, None], name='text2_word_ids')
61 | self.text2_char_ids = tf.placeholder(dtype=tf.int32, shape=[None, None, None], name="text2_char_ids")
62 | self.text2_lengths = tf.placeholder(tf.int32, shape=[None], name='text2_lengths')
63 | self.text2_word_lengths = tf.placeholder(tf.int32, shape=[None, None], name="text2_word_lengths")
64 | self.labels = tf.placeholder(tf.int32, shape=[None], name='labels')
65 | self.dropout = tf.placeholder(dtype=tf.float32, shape=[], name="dropout")
66 |
67 | def add_embedding_layer(self):
68 | with tf.variable_scope('word_embedding_layer'):
69 | if self.config.embeddings is None:
70 | _word_embeddings = tf.get_variable(name='_word_embeddings', dtype=tf.float32,  # embedding tables must be float, not int32
71 | shape=[self.config.nwords, self.config.dim_word])
72 | else:
73 | _word_embeddings = tf.Variable(self.config.embeddings, name='_word_embeddings',
74 | dtype=tf.float32)  # shape is taken from the pretrained matrix
75 | text1_word_embeddings = tf.nn.embedding_lookup(_word_embeddings, self.text1_word_ids)
76 | text2_word_embeddings = tf.nn.embedding_lookup(_word_embeddings, self.text2_word_ids)
77 |
78 | with tf.variable_scope('char_embedding_layer'):
79 | if self.config.use_chars:
80 | _char_embeddings = tf.get_variable(name='_char_embeddings', dtype=tf.float32,  # float, not int32
81 | shape=[self.config.nchars, self.config.dim_char])
82 | text1_char_embeddings = tf.nn.embedding_lookup(_char_embeddings, self.text1_char_ids,
83 | name="char_embeddings")
84 | text2_char_embeddings = tf.nn.embedding_lookup(_char_embeddings, self.text2_char_ids,
85 | name="char_embeddings")
86 | # bs, sentence_len, word_len
87 | s1, s2 = tf.shape(text1_char_embeddings), tf.shape(text2_char_embeddings)
88 | text1_char_embeddings = tf.reshape(text1_char_embeddings,  # reshape the looked-up embeddings, not the embedding table
89 | shape=[s1[0] * s1[1], s1[-2], self.config.dim_char])
90 | text2_char_embeddings = tf.reshape(text2_char_embeddings,
91 | shape=[s2[0] * s2[1], s2[-2], self.config.dim_char])
92 | text1_word_lengths = tf.reshape(self.text1_word_lengths, shape=[s1[0] * s1[1]])
93 | text2_word_lengths = tf.reshape(self.text2_word_lengths, shape=[s2[0] * s2[1]])
94 | # bi Rnn
95 | cell_fw = tf.nn.rnn_cell.GRUCell(self.config.hidden_size_char, reuse=tf.AUTO_REUSE)
96 | stacked_gru_fw = tf.nn.rnn_cell.MultiRNNCell([cell_fw], state_is_tuple=True)
97 | cell_bw = tf.nn.rnn_cell.GRUCell(self.config.hidden_size_char, reuse=tf.AUTO_REUSE)
98 | stacked_gru_bw = tf.nn.rnn_cell.MultiRNNCell([cell_bw], state_is_tuple=True)
99 | _, ((text1_output_fw,), (text1_output_bw,)) = tf.nn.bidirectional_dynamic_rnn(stacked_gru_fw, stacked_gru_bw,  # MultiRNNCell states are 1-tuples; unpack the single GRU state
100 | text1_char_embeddings,
101 | sequence_length=text1_word_lengths,
102 | dtype=tf.float32)
103 | _, ((text2_output_fw,), (text2_output_bw,)) = tf.nn.bidirectional_dynamic_rnn(stacked_gru_fw, stacked_gru_bw,
104 | text2_char_embeddings,
105 | sequence_length=text2_word_lengths,
106 | dtype=tf.float32)
107 | text1_output = tf.concat([text1_output_fw, text1_output_bw], axis=-1)
108 | text1_output = tf.reshape(text1_output, shape=[s1[0], s1[1], 2 * self.config.hidden_size_char])
109 | text1_word_embeddings = tf.concat([text1_word_embeddings, text1_output], axis=-1)
110 | self.text1_word_embeddings = tf.nn.dropout(text1_word_embeddings, self.dropout)
111 | text2_output = tf.concat([text2_output_fw, text2_output_bw], axis=-1)
112 | text2_output = tf.reshape(text2_output, shape=[s2[0], s2[1], 2 * self.config.hidden_size_char])
113 | text2_word_embeddings = tf.concat([text2_word_embeddings, text2_output], axis=-1)
114 | self.text2_word_embeddings = tf.nn.dropout(text2_word_embeddings, self.dropout)
115 |
116 | def add_simlar_layer(self):
117 | with tf.variable_scope("bi_rnn"):
118 | # bi Rnn
119 | cell_fw = tf.nn.rnn_cell.GRUCell(self.config.hidden_size_gru, reuse=tf.AUTO_REUSE)
120 | stacked_gru_fw = tf.nn.rnn_cell.MultiRNNCell([cell_fw], state_is_tuple=True)
121 | cell_bw = tf.nn.rnn_cell.GRUCell(self.config.hidden_size_gru, reuse=tf.AUTO_REUSE)
122 | stacked_gru_bw = tf.nn.rnn_cell.MultiRNNCell([cell_bw], state_is_tuple=True)
123 | text1_state_output, ((text1_final_fw,), (text1_final_bw,)) = tf.nn.bidirectional_dynamic_rnn(stacked_gru_fw,
124 | stacked_gru_bw,
125 | self.text1_word_embeddings,
126 | sequence_length=self.text1_lengths,
127 | dtype=tf.float32)
128 | self.text1_state_output = tf.concat(text1_state_output, axis=-1)
129 | self.text1_final_state = tf.concat([text1_final_fw, text1_final_bw], axis=-1)  # concatenate fw/bw final states
130 | text2_state_output, ((text2_final_fw,), (text2_final_bw,)) = tf.nn.bidirectional_dynamic_rnn(stacked_gru_fw, stacked_gru_bw,
131 | self.text2_word_embeddings,
132 | sequence_length=self.text2_lengths,
133 | dtype=tf.float32)
134 | self.text2_state_output = tf.concat(text2_state_output, axis=-1)
135 | self.text2_final_state = tf.concat([text2_final_fw, text2_final_bw], axis=-1)
136 |
137 | # todo: add attention layer
138 |
139 | with tf.variable_scope("cosine_similar"):
140 | # Cosine similarity
141 | # text1_norm = sqrt(sum(each x^2))
142 | text1_norm = tf.sqrt(tf.reduce_sum(tf.square(self.text1_final_state), 1, True))
143 | text2_norm = tf.sqrt(tf.reduce_sum(tf.square(self.text2_final_state), 1, True))
144 |
145 | prod = tf.reduce_sum(tf.multiply(self.text1_final_state, self.text2_final_state), 1, True)
146 | norm_prod = tf.multiply(text1_norm, text2_norm)
147 |
148 | # cos_sim_raw = query * doc / (||query|| * ||doc||), [bs]
149 | cos_sim_raw = tf.truediv(prod, norm_prod)
150 | # gamma = 20
151 | self.cos_sim = cos_sim_raw
152 | with tf.variable_scope("manhattan_distance"):
153 | self.diff = tf.reduce_sum(tf.abs(tf.subtract(self.text1_final_state, self.text2_final_state)), axis=1)  # distance between the two encodings, not text1 with itself
154 | self.similarity = tf.exp(-1.0 * self.diff)
155 | # MSE
156 | with tf.variable_scope("loss"):
157 | diff = tf.subtract(self.similarity, tf.cast(self.labels, tf.float32) - 1.0) / 4.0  # labels are int32; cast before the arithmetic
158 | self.loss = tf.square(diff) # (batch_size,)
159 | self.cost = tf.reduce_mean(self.loss) # (1,)
--------------------------------------------------------------------------------
/bert_as_feature.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | import tensorflow as tf
4 | import tokenization, modeling
5 | import os
6 |
7 | os.environ["CUDA_VISIBLE_DEVICES"] = '2'
8 |
9 | # Configuration
10 | data_root = 'weight/chinese_L-12_H-768_A-12/'
11 | bert_config_file = data_root + 'bert_config.json'
12 | bert_config = modeling.BertConfig.from_json_file(bert_config_file)
13 | init_checkpoint = data_root + 'bert_model.ckpt'
14 | bert_vocab_file = data_root + 'vocab.txt'
15 | bert_vocab_En_file = 'weight/uncased_L-12_H-768_A-12/vocab.txt'
16 |
17 | # test
18 | token = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
19 | split_tokens = token.tokenize('龘,Jack,请回答1988')
20 | word_ids = token.convert_tokens_to_ids(split_tokens)
21 | word_mask = [1] * len(word_ids)
22 | word_segment_ids = [0] * len(word_ids)
23 |
24 | # graph
25 | input_ids = tf.placeholder(tf.int32, shape=[None, None], name='input_ids')
26 | input_mask = tf.placeholder(tf.int32, shape=[None, None], name='input_masks')
27 | segment_ids = tf.placeholder(tf.int32, shape=[None, None], name='segment_ids')
28 |
29 | # Initialize BERT
30 | model = modeling.BertModel(
31 | config=bert_config,
32 | is_training=False,
33 | input_ids=input_ids,
34 | input_mask=input_mask,
35 | token_type_ids=segment_ids,
36 | use_one_hot_embeddings=False)
37 |
38 | # Load the pre-trained BERT weights
39 | tvars = tf.trainable_variables()
40 | (assignment, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
41 | tf.train.init_from_checkpoint(init_checkpoint, assignment)
42 | # Get the last and second-to-last encoder layers.
43 | encoder_last_layer = model.get_sequence_output()
44 | encoder_last2_layer = model.all_encoder_layers[-2]
45 |
46 | with tf.Session() as sess:
47 | sess.run(tf.global_variables_initializer())
48 |
49 | token = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
50 | query = u'Jack,请回答1988, UNwant\u00E9d,running'
51 | split_tokens = token.tokenize(query)
52 | word_ids = token.convert_tokens_to_ids(split_tokens)
53 | word_mask = [1] * len(word_ids)
54 | word_segment_ids = [0] * len(word_ids)
55 | fd = {input_ids: [word_ids], input_mask: [word_mask], segment_ids: [word_segment_ids]}
56 | last, last2 = sess.run([encoder_last_layer, encoder_last2_layer], feed_dict=fd)
57 | print('last shape:{}, last2 shape: {}'.format(last.shape, last2.shape))
59 |
--------------------------------------------------------------------------------
/create_pretraining_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Create masked LM/next sentence masked_lm TF examples for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import random
23 |
24 | import tokenization
25 | import tensorflow as tf
26 |
27 | flags = tf.flags
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | flags.DEFINE_string("input_file", None,
32 | "Input raw text file (or comma-separated list of files).")
33 |
34 | flags.DEFINE_string(
35 | "output_file", None,
36 | "Output TF example file (or comma-separated list of files).")
37 |
38 | flags.DEFINE_string("vocab_file", None,
39 | "The vocabulary file that the BERT model was trained on.")
40 |
41 | flags.DEFINE_bool(
42 | "do_lower_case", True,
43 | "Whether to lower case the input text. Should be True for uncased "
44 | "models and False for cased models.")
45 |
46 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
47 |
48 | flags.DEFINE_integer("max_predictions_per_seq", 20,
49 | "Maximum number of masked LM predictions per sequence.")
50 |
51 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
52 |
53 | flags.DEFINE_integer(
54 | "dupe_factor", 10,
55 | "Number of times to duplicate the input data (with different masks).")
56 |
57 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
58 |
59 | flags.DEFINE_float(
60 | "short_seq_prob", 0.1,
61 | "Probability of creating sequences which are shorter than the "
62 | "maximum length.")
63 |
64 |
65 | class TrainingInstance(object):
66 | """A single training instance (sentence pair)."""
67 |
68 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
69 | is_random_next):
70 | self.tokens = tokens
71 | self.segment_ids = segment_ids
72 | self.is_random_next = is_random_next
73 | self.masked_lm_positions = masked_lm_positions
74 | self.masked_lm_labels = masked_lm_labels
75 |
76 | def __str__(self):
77 | s = ""
78 | s += "tokens: %s\n" % (" ".join(
79 | [tokenization.printable_text(x) for x in self.tokens]))
80 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
81 | s += "is_random_next: %s\n" % self.is_random_next
82 | s += "masked_lm_positions: %s\n" % (" ".join(
83 | [str(x) for x in self.masked_lm_positions]))
84 | s += "masked_lm_labels: %s\n" % (" ".join(
85 | [tokenization.printable_text(x) for x in self.masked_lm_labels]))
86 | s += "\n"
87 | return s
88 |
89 | def __repr__(self):
90 | return self.__str__()
91 |
92 |
93 | def write_instance_to_example_files(instances, tokenizer, max_seq_length,
94 | max_predictions_per_seq, output_files):
95 | """Create TF example files from `TrainingInstance`s."""
96 | writers = []
97 | for output_file in output_files:
98 | writers.append(tf.python_io.TFRecordWriter(output_file))
99 |
100 | writer_index = 0
101 |
102 | total_written = 0
103 | for (inst_index, instance) in enumerate(instances):
104 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
105 | input_mask = [1] * len(input_ids)
106 | segment_ids = list(instance.segment_ids)
107 | assert len(input_ids) <= max_seq_length
108 |
109 | while len(input_ids) < max_seq_length:
110 | input_ids.append(0)
111 | input_mask.append(0)
112 | segment_ids.append(0)
113 |
114 | assert len(input_ids) == max_seq_length
115 | assert len(input_mask) == max_seq_length
116 | assert len(segment_ids) == max_seq_length
117 |
118 | masked_lm_positions = list(instance.masked_lm_positions)
119 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
120 | masked_lm_weights = [1.0] * len(masked_lm_ids)
121 |
122 | while len(masked_lm_positions) < max_predictions_per_seq:
123 | masked_lm_positions.append(0)
124 | masked_lm_ids.append(0)
125 | masked_lm_weights.append(0.0)
126 |
127 | next_sentence_label = 1 if instance.is_random_next else 0
128 |
129 | features = collections.OrderedDict()
130 | features["input_ids"] = create_int_feature(input_ids)
131 | features["input_mask"] = create_int_feature(input_mask)
132 | features["segment_ids"] = create_int_feature(segment_ids)
133 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
134 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
135 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
136 | features["next_sentence_labels"] = create_int_feature([next_sentence_label])
137 |
138 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
139 |
140 | writers[writer_index].write(tf_example.SerializeToString())
141 | writer_index = (writer_index + 1) % len(writers)
142 |
143 | total_written += 1
144 |
145 | if inst_index < 20:
146 | tf.logging.info("*** Example ***")
147 | tf.logging.info("tokens: %s" % " ".join(
148 | [tokenization.printable_text(x) for x in instance.tokens]))
149 |
150 | for feature_name in features.keys():
151 | feature = features[feature_name]
152 | values = []
153 | if feature.int64_list.value:
154 | values = feature.int64_list.value
155 | elif feature.float_list.value:
156 | values = feature.float_list.value
157 | tf.logging.info(
158 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
159 |
160 | for writer in writers:
161 | writer.close()
162 |
163 | tf.logging.info("Wrote %d total instances", total_written)
164 |
165 |
166 | def create_int_feature(values):
167 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
168 | return feature
169 |
170 |
171 | def create_float_feature(values):
172 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
173 | return feature
174 |
175 |
176 | def create_training_instances(input_files, tokenizer, max_seq_length,
177 | dupe_factor, short_seq_prob, masked_lm_prob,
178 | max_predictions_per_seq, rng):
179 | """Create `TrainingInstance`s from raw text."""
180 | all_documents = [[]]
181 |
182 | # Input file format:
183 | # (1) One sentence per line. These should ideally be actual sentences, not
184 | # entire paragraphs or arbitrary spans of text. (Because we use the
185 | # sentence boundaries for the "next sentence prediction" task).
186 | # (2) Blank lines between documents. Document boundaries are needed so
187 | # that the "next sentence prediction" task doesn't span between documents.
188 | for input_file in input_files:
189 | with tf.gfile.GFile(input_file, "r") as reader:
190 | while True:
191 | line = tokenization.convert_to_unicode(reader.readline())
192 | if not line:
193 | break
194 | line = line.strip()
195 |
196 | # Empty lines are used as document delimiters
197 | if not line:
198 | all_documents.append([])
199 | tokens = tokenizer.tokenize(line)
200 | if tokens:
201 | all_documents[-1].append(tokens)
202 |
203 | # Remove empty documents
204 | all_documents = [x for x in all_documents if x]
205 | rng.shuffle(all_documents)
206 |
207 | vocab_words = list(tokenizer.vocab.keys())
208 | instances = []
209 | for _ in range(dupe_factor):
210 | for document_index in range(len(all_documents)):
211 | instances.extend(
212 | create_instances_from_document(
213 | all_documents, document_index, max_seq_length, short_seq_prob,
214 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
215 |
216 | rng.shuffle(instances)
217 | return instances
218 |
219 |
220 | def create_instances_from_document(
221 | all_documents, document_index, max_seq_length, short_seq_prob,
222 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
223 | """Creates `TrainingInstance`s for a single document."""
224 | document = all_documents[document_index]
225 |
226 | # Account for [CLS], [SEP], [SEP]
227 | max_num_tokens = max_seq_length - 3
228 |
229 | # We *usually* want to fill up the entire sequence since we are padding
230 | # to `max_seq_length` anyways, so short sequences are generally wasted
231 | # computation. However, we *sometimes*
232 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
233 | # sequences to minimize the mismatch between pre-training and fine-tuning.
234 | # The `target_seq_length` is just a rough target however, whereas
235 | # `max_seq_length` is a hard limit.
236 | target_seq_length = max_num_tokens
237 | if rng.random() < short_seq_prob:
238 | target_seq_length = rng.randint(2, max_num_tokens)
239 |
240 | # We DON'T just concatenate all of the tokens from a document into a long
241 | # sequence and choose an arbitrary split point because this would make the
242 | # next sentence prediction task too easy. Instead, we split the input into
243 | # segments "A" and "B" based on the actual "sentences" provided by the user
244 | # input.
245 | instances = []
246 | current_chunk = []
247 | current_length = 0
248 | i = 0
249 | while i < len(document):
250 | segment = document[i]
251 | current_chunk.append(segment)
252 | current_length += len(segment)
253 | if i == len(document) - 1 or current_length >= target_seq_length:
254 | if current_chunk:
255 | # `a_end` is how many segments from `current_chunk` go into the `A`
256 | # (first) sentence.
257 | a_end = 1
258 | if len(current_chunk) >= 2:
259 | a_end = rng.randint(1, len(current_chunk) - 1)
260 |
261 | tokens_a = []
262 | for j in range(a_end):
263 | tokens_a.extend(current_chunk[j])
264 |
265 | tokens_b = []
266 | # Random next
267 | is_random_next = False
268 | if len(current_chunk) == 1 or rng.random() < 0.5:
269 | is_random_next = True
270 | target_b_length = target_seq_length - len(tokens_a)
271 |
272 | # This should rarely go for more than one iteration for large
273 | # corpora. However, just to be careful, we try to make sure that
274 | # the random document is not the same as the document
275 | # we're processing.
276 | for _ in range(10):
277 | random_document_index = rng.randint(0, len(all_documents) - 1)
278 | if random_document_index != document_index:
279 | break
280 |
281 | random_document = all_documents[random_document_index]
282 | random_start = rng.randint(0, len(random_document) - 1)
283 | for j in range(random_start, len(random_document)):
284 | tokens_b.extend(random_document[j])
285 | if len(tokens_b) >= target_b_length:
286 | break
287 | # We didn't actually use these segments so we "put them back" so
288 | # they don't go to waste.
289 | num_unused_segments = len(current_chunk) - a_end
290 | i -= num_unused_segments
291 | # Actual next
292 | else:
293 | is_random_next = False
294 | for j in range(a_end, len(current_chunk)):
295 | tokens_b.extend(current_chunk[j])
296 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
297 |
298 | assert len(tokens_a) >= 1
299 | assert len(tokens_b) >= 1
300 |
301 | tokens = []
302 | segment_ids = []
303 | tokens.append("[CLS]")
304 | segment_ids.append(0)
305 | for token in tokens_a:
306 | tokens.append(token)
307 | segment_ids.append(0)
308 |
309 | tokens.append("[SEP]")
310 | segment_ids.append(0)
311 |
312 | for token in tokens_b:
313 | tokens.append(token)
314 | segment_ids.append(1)
315 | tokens.append("[SEP]")
316 | segment_ids.append(1)
317 |
318 | (tokens, masked_lm_positions,
319 | masked_lm_labels) = create_masked_lm_predictions(
320 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
321 | instance = TrainingInstance(
322 | tokens=tokens,
323 | segment_ids=segment_ids,
324 | is_random_next=is_random_next,
325 | masked_lm_positions=masked_lm_positions,
326 | masked_lm_labels=masked_lm_labels)
327 | instances.append(instance)
328 | current_chunk = []
329 | current_length = 0
330 | i += 1
331 |
332 | return instances
333 |
334 |
335 | def create_masked_lm_predictions(tokens, masked_lm_prob,
336 | max_predictions_per_seq, vocab_words, rng):
337 | """Creates the predictions for the masked LM objective."""
338 |
339 | cand_indexes = []
340 | for (i, token) in enumerate(tokens):
341 | if token == "[CLS]" or token == "[SEP]":
342 | continue
343 | cand_indexes.append(i)
344 |
345 | rng.shuffle(cand_indexes)
346 |
347 | output_tokens = list(tokens)
348 |
349 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name
350 |
351 | num_to_predict = min(max_predictions_per_seq,
352 | max(1, int(round(len(tokens) * masked_lm_prob))))
353 |
354 | masked_lms = []
355 | covered_indexes = set()
356 | for index in cand_indexes:
357 | if len(masked_lms) >= num_to_predict:
358 | break
359 | if index in covered_indexes:
360 | continue
361 | covered_indexes.add(index)
362 |
363 | masked_token = None
364 | # 80% of the time, replace with [MASK]
365 | if rng.random() < 0.8:
366 | masked_token = "[MASK]"
367 | else:
368 | # 10% of the time, keep original
369 | if rng.random() < 0.5:
370 | masked_token = tokens[index]
371 | # 10% of the time, replace with random word
372 | else:
373 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
374 |
375 | output_tokens[index] = masked_token
376 |
377 | masked_lms.append(masked_lm(index=index, label=tokens[index]))
378 |
379 | masked_lms = sorted(masked_lms, key=lambda x: x.index)
380 |
381 | masked_lm_positions = []
382 | masked_lm_labels = []
383 | for p in masked_lms:
384 | masked_lm_positions.append(p.index)
385 | masked_lm_labels.append(p.label)
386 |
387 | return (output_tokens, masked_lm_positions, masked_lm_labels)
388 |
389 |
390 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
391 | """Truncates a pair of sequences to a maximum sequence length."""
392 | while True:
393 | total_length = len(tokens_a) + len(tokens_b)
394 | if total_length <= max_num_tokens:
395 | break
396 |
397 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
398 | assert len(trunc_tokens) >= 1
399 |
400 | # We want to sometimes truncate from the front and sometimes from the
401 | # back to add more randomness and avoid biases.
402 | if rng.random() < 0.5:
403 | del trunc_tokens[0]
404 | else:
405 | trunc_tokens.pop()
406 |
407 |
408 | def main(_):
409 | tf.logging.set_verbosity(tf.logging.INFO)
410 |
411 | tokenizer = tokenization.FullTokenizer(
412 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
413 |
414 | input_files = []
415 | for input_pattern in FLAGS.input_file.split(","):
416 | input_files.extend(tf.gfile.Glob(input_pattern))
417 |
418 | tf.logging.info("*** Reading from input files ***")
419 | for input_file in input_files:
420 | tf.logging.info(" %s", input_file)
421 |
422 | rng = random.Random(FLAGS.random_seed)
423 | instances = create_training_instances(
424 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
425 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
426 | rng)
427 |
428 | output_files = FLAGS.output_file.split(",")
429 | tf.logging.info("*** Writing to output files ***")
430 | for output_file in output_files:
431 | tf.logging.info(" %s", output_file)
432 |
433 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
434 | FLAGS.max_predictions_per_seq, output_files)
435 |
436 |
437 | if __name__ == "__main__":
438 | flags.mark_flag_as_required("input_file")
439 | flags.mark_flag_as_required("output_file")
440 | flags.mark_flag_as_required("vocab_file")
441 | tf.app.run()
442 |
--------------------------------------------------------------------------------
/extract_features.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Extract pre-computed feature vectors from BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import codecs
22 | import collections
23 | import json
24 | import re
25 |
26 | import modeling
27 | import tokenization
28 | import tensorflow as tf
29 |
30 | flags = tf.flags
31 |
32 | FLAGS = flags.FLAGS
33 |
34 | flags.DEFINE_string("input_file", None, "")
35 |
36 | flags.DEFINE_string("output_file", None, "")
37 |
38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "")
39 |
40 | flags.DEFINE_string(
41 | "bert_config_file", None,
42 | "The config json file corresponding to the pre-trained BERT model. "
43 | "This specifies the model architecture.")
44 |
45 | flags.DEFINE_integer(
46 | "max_seq_length", 128,
47 | "The maximum total input sequence length after WordPiece tokenization. "
48 | "Sequences longer than this will be truncated, and sequences shorter "
49 | "than this will be padded.")
50 |
51 | flags.DEFINE_string(
52 | "init_checkpoint", None,
53 | "Initial checkpoint (usually from a pre-trained BERT model).")
54 |
55 | flags.DEFINE_string("vocab_file", None,
56 | "The vocabulary file that the BERT model was trained on.")
57 |
58 | flags.DEFINE_bool(
59 | "do_lower_case", True,
60 | "Whether to lower case the input text. Should be True for uncased "
61 | "models and False for cased models.")
62 |
63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.")
64 |
65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
66 |
67 | flags.DEFINE_string("master", None,
68 | "If using a TPU, the address of the master.")
69 |
70 | flags.DEFINE_integer(
71 | "num_tpu_cores", 8,
72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
73 |
74 | flags.DEFINE_bool(
75 | "use_one_hot_embeddings", False,
76 | "If True, tf.one_hot will be used for embedding lookups, otherwise "
77 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True "
78 | "since it is much faster.")
79 |
80 |
81 | class InputExample(object):
82 |
83 | def __init__(self, unique_id, text_a, text_b):
84 | self.unique_id = unique_id
85 | self.text_a = text_a
86 | self.text_b = text_b
87 |
88 |
89 | class InputFeatures(object):
90 | """A single set of features of data."""
91 |
92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
93 | self.unique_id = unique_id
94 | self.tokens = tokens
95 | self.input_ids = input_ids
96 | self.input_mask = input_mask
97 | self.input_type_ids = input_type_ids
98 |
99 |
100 | def input_fn_builder(features, seq_length):
101 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
102 |
103 | all_unique_ids = []
104 | all_input_ids = []
105 | all_input_mask = []
106 | all_input_type_ids = []
107 |
108 | for feature in features:
109 | all_unique_ids.append(feature.unique_id)
110 | all_input_ids.append(feature.input_ids)
111 | all_input_mask.append(feature.input_mask)
112 | all_input_type_ids.append(feature.input_type_ids)
113 |
114 | def input_fn(params):
115 | """The actual input function."""
116 | batch_size = params["batch_size"]
117 |
118 | num_examples = len(features)
119 |
120 | # This is for demo purposes and does NOT scale to large data sets. We do
121 | # not use Dataset.from_generator() because that uses tf.py_func which is
122 | # not TPU compatible. The right way to load data is with TFRecordReader.
123 | d = tf.data.Dataset.from_tensor_slices({
124 | "unique_ids":
125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
126 | "input_ids":
127 | tf.constant(
128 | all_input_ids, shape=[num_examples, seq_length],
129 | dtype=tf.int32),
130 | "input_mask":
131 | tf.constant(
132 | all_input_mask,
133 | shape=[num_examples, seq_length],
134 | dtype=tf.int32),
135 | "input_type_ids":
136 | tf.constant(
137 | all_input_type_ids,
138 | shape=[num_examples, seq_length],
139 | dtype=tf.int32),
140 | })
141 |
142 | d = d.batch(batch_size=batch_size, drop_remainder=False)
143 | return d
144 |
145 | return input_fn
146 |
147 |
148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
149 | use_one_hot_embeddings):
150 | """Returns `model_fn` closure for TPUEstimator."""
151 |
152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
153 | """The `model_fn` for TPUEstimator."""
154 |
155 | unique_ids = features["unique_ids"]
156 | input_ids = features["input_ids"]
157 | input_mask = features["input_mask"]
158 | input_type_ids = features["input_type_ids"]
159 |
160 | model = modeling.BertModel(
161 | config=bert_config,
162 | is_training=False,
163 | input_ids=input_ids,
164 | input_mask=input_mask,
165 | token_type_ids=input_type_ids,
166 | use_one_hot_embeddings=use_one_hot_embeddings)
167 |
168 | if mode != tf.estimator.ModeKeys.PREDICT:
169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode))
170 |
171 | tvars = tf.trainable_variables()
172 | scaffold_fn = None
173 | (assignment_map, _) = modeling.get_assignment_map_from_checkpoint(
174 | tvars, init_checkpoint)
175 | if use_tpu:
176 |
177 | def tpu_scaffold():
178 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
179 | return tf.train.Scaffold()
180 |
181 | scaffold_fn = tpu_scaffold
182 | else:
183 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
184 |
185 | all_layers = model.get_all_encoder_layers()
186 |
187 | predictions = {
188 | "unique_id": unique_ids,
189 | }
190 |
191 | for (i, layer_index) in enumerate(layer_indexes):
192 | predictions["layer_output_%d" % i] = all_layers[layer_index]
193 |
194 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
195 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
196 | return output_spec
197 |
198 | return model_fn
199 |
200 |
201 | def convert_examples_to_features(examples, seq_length, tokenizer):
202 | """Loads a data file into a list of `InputBatch`s."""
203 |
204 | features = []
205 | for (ex_index, example) in enumerate(examples):
206 | tokens_a = tokenizer.tokenize(example.text_a)
207 |
208 | tokens_b = None
209 | if example.text_b:
210 | tokens_b = tokenizer.tokenize(example.text_b)
211 |
212 | if tokens_b:
213 | # Modifies `tokens_a` and `tokens_b` in place so that the total
214 | # length is less than the specified length.
215 | # Account for [CLS], [SEP], [SEP] with "- 3"
216 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
217 | else:
218 | # Account for [CLS] and [SEP] with "- 2"
219 | if len(tokens_a) > seq_length - 2:
220 | tokens_a = tokens_a[0:(seq_length - 2)]
221 |
222 | # The convention in BERT is:
223 | # (a) For sequence pairs:
224 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
225 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
226 | # (b) For single sequences:
227 | # tokens: [CLS] the dog is hairy . [SEP]
228 | # type_ids: 0 0 0 0 0 0 0
229 | #
230 | # Where "type_ids" are used to indicate whether this is the first
231 | # sequence or the second sequence. The embedding vectors for `type=0` and
232 | # `type=1` were learned during pre-training and are added to the wordpiece
233 | # embedding vector (and position vector). This is not *strictly* necessary
234 | # since the [SEP] token unambiguously separates the sequences, but it makes
235 | # it easier for the model to learn the concept of sequences.
236 | #
237 | # For classification tasks, the first vector (corresponding to [CLS]) is
238 | # used as the "sentence vector". Note that this only makes sense because
239 | # the entire model is fine-tuned.
240 | tokens = []
241 | input_type_ids = []
242 | tokens.append("[CLS]")
243 | input_type_ids.append(0)
244 | for token in tokens_a:
245 | tokens.append(token)
246 | input_type_ids.append(0)
247 | tokens.append("[SEP]")
248 | input_type_ids.append(0)
249 |
250 | if tokens_b:
251 | for token in tokens_b:
252 | tokens.append(token)
253 | input_type_ids.append(1)
254 | tokens.append("[SEP]")
255 | input_type_ids.append(1)
256 |
257 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
258 |
259 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
260 | # tokens are attended to.
261 | input_mask = [1] * len(input_ids)
262 |
263 | # Zero-pad up to the sequence length.
264 | while len(input_ids) < seq_length:
265 | input_ids.append(0)
266 | input_mask.append(0)
267 | input_type_ids.append(0)
268 |
269 | assert len(input_ids) == seq_length
270 | assert len(input_mask) == seq_length
271 | assert len(input_type_ids) == seq_length
272 |
273 | if ex_index < 5:
274 | tf.logging.info("*** Example ***")
275 | tf.logging.info("unique_id: %s" % (example.unique_id))
276 | tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens]))
277 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
278 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
279 | tf.logging.info(
280 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
281 |
282 | features.append(
283 | InputFeatures(
284 | unique_id=example.unique_id,
285 | tokens=tokens,
286 | input_ids=input_ids,
287 | input_mask=input_mask,
288 | input_type_ids=input_type_ids))
289 | return features
290 |
291 |
292 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
293 | """Truncates a sequence pair in place to the maximum length."""
294 |
295 | # This is a simple heuristic which will always truncate the longer sequence
296 | # one token at a time. This makes more sense than truncating an equal percent
297 | # of tokens from each, since if one sequence is very short then each token
298 | # that's truncated likely contains more information than a longer sequence.
299 | while True:
300 | total_length = len(tokens_a) + len(tokens_b)
301 | if total_length <= max_length:
302 | break
303 | if len(tokens_a) > len(tokens_b):
304 | tokens_a.pop()
305 | else:
306 | tokens_b.pop()
307 |
308 |
309 | def read_examples(input_file):
310 | """Read a list of `InputExample`s from an input file."""
311 | examples = []
312 | unique_id = 0
313 | with tf.gfile.GFile(input_file, "r") as reader:
314 | while True:
315 | line = tokenization.convert_to_unicode(reader.readline())
316 | if not line:
317 | break
318 | line = line.strip()
319 | text_a = None
320 | text_b = None
321 | m = re.match(r"^(.*) \|\|\| (.*)$", line)
322 | if m is None:
323 | text_a = line
324 | else:
325 | text_a = m.group(1)
326 | text_b = m.group(2)
327 | examples.append(
328 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
329 | unique_id += 1
330 | return examples
331 |
332 |
333 | def main(_):
334 | tf.logging.set_verbosity(tf.logging.INFO)
335 |
336 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")]
337 |
338 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
339 |
340 | tokenizer = tokenization.FullTokenizer(
341 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
342 |
343 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
344 | run_config = tf.contrib.tpu.RunConfig(
345 | master=FLAGS.master,
346 | tpu_config=tf.contrib.tpu.TPUConfig(
347 | num_shards=FLAGS.num_tpu_cores,
348 | per_host_input_for_training=is_per_host))
349 |
350 | examples = read_examples(FLAGS.input_file)
351 |
352 | features = convert_examples_to_features(
353 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)
354 |
355 | unique_id_to_feature = {}
356 | for feature in features:
357 | unique_id_to_feature[feature.unique_id] = feature
358 |
359 | model_fn = model_fn_builder(
360 | bert_config=bert_config,
361 | init_checkpoint=FLAGS.init_checkpoint,
362 | layer_indexes=layer_indexes,
363 | use_tpu=FLAGS.use_tpu,
364 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)
365 |
366 | # If TPU is not available, this will fall back to normal Estimator on CPU
367 | # or GPU.
368 | estimator = tf.contrib.tpu.TPUEstimator(
369 | use_tpu=FLAGS.use_tpu,
370 | model_fn=model_fn,
371 | config=run_config,
372 | predict_batch_size=FLAGS.batch_size)
373 |
374 | input_fn = input_fn_builder(
375 | features=features, seq_length=FLAGS.max_seq_length)
376 |
377 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
378 | "w")) as writer:
379 | for result in estimator.predict(input_fn, yield_single_examples=True):
380 | unique_id = int(result["unique_id"])
381 | feature = unique_id_to_feature[unique_id]
382 | output_json = collections.OrderedDict()
383 | output_json["linex_index"] = unique_id
384 | all_features = []
385 | for (i, token) in enumerate(feature.tokens):
386 | all_layers = []
387 | for (j, layer_index) in enumerate(layer_indexes):
388 | layer_output = result["layer_output_%d" % j]
389 | layers = collections.OrderedDict()
390 | layers["index"] = layer_index
391 | layers["values"] = [
392 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat
393 | ]
394 | all_layers.append(layers)
395 | features = collections.OrderedDict()
396 | features["token"] = token
397 | features["layers"] = all_layers
398 | all_features.append(features)
399 | output_json["features"] = all_features
400 | writer.write(json.dumps(output_json) + "\n")
401 |
402 |
403 | if __name__ == "__main__":
404 | flags.mark_flag_as_required("input_file")
405 | flags.mark_flag_as_required("vocab_file")
406 | flags.mark_flag_as_required("bert_config_file")
407 | flags.mark_flag_as_required("init_checkpoint")
408 | flags.mark_flag_as_required("output_file")
409 | tf.app.run()
410 |
--------------------------------------------------------------------------------
/general_utils.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 |
3 |
4 | import logging
5 | import sys
6 | import time
7 | import numpy as np
8 |
9 |
10 | def get_logger(filename):
11 | logger = logging.getLogger('logger')
12 | logger.setLevel(logging.DEBUG)
13 | logging.basicConfig(format='%(message)s', level=logging.DEBUG)
14 | handler = logging.FileHandler(filename)
15 | handler.setLevel(logging.DEBUG)
16 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
17 | logging.getLogger().addHandler(handler)
18 |
19 | return logger
20 |
21 |
22 | class Progbar(object):
23 | """Progbar class copied from keras (https://github.com/fchollet/keras/)
24 |
25 | Displays a progress bar.
26 | Small edit : added strict arg to update
27 | # Arguments
28 | target: Total number of steps expected.
29 | interval: Minimum visual progress update interval (in seconds).
30 | """
31 |
32 | def __init__(self, target, width=30, verbose=1):
33 | self.width = width
34 | self.target = target
35 | self.sum_values = {}
36 | self.unique_values = []
37 | self.start = time.time()
38 | self.total_width = 0
39 | self.seen_so_far = 0
40 | self.verbose = verbose
41 |
42 | def update(self, current, values=[], exact=[], strict=[]):
43 | """
44 | Updates the progress bar.
45 | # Arguments
46 | current: Index of current step.
47 | values: List of tuples (name, value_for_last_step).
48 | The progress bar will display averages for these values.
49 |             exact: List of tuples (name, value). Shown directly as floats.
50 |             strict: List of tuples (name, value). Shown verbatim (any type).
51 | """
52 |
53 | for k, v in values:
54 | if k not in self.sum_values:
55 | self.sum_values[k] = [v * (current - self.seen_so_far),
56 | current - self.seen_so_far]
57 | self.unique_values.append(k)
58 | else:
59 | self.sum_values[k][0] += v * (current - self.seen_so_far)
60 | self.sum_values[k][1] += (current - self.seen_so_far)
61 | for k, v in exact:
62 | if k not in self.sum_values:
63 | self.unique_values.append(k)
64 |             self.sum_values[k] = [v, 1]
65 |
66 | for k, v in strict:
67 | if k not in self.sum_values:
68 | self.unique_values.append(k)
69 | self.sum_values[k] = v
70 |
71 | self.seen_so_far = current
72 |
73 | now = time.time()
74 | if self.verbose == 1:
75 | prev_total_width = self.total_width
76 | sys.stdout.write("\b" * prev_total_width)
77 | sys.stdout.write("\r")
78 |
79 | numdigits = int(np.floor(np.log10(self.target))) + 1
80 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
81 | bar = barstr % (current, self.target)
82 | prog = float(current)/self.target
83 | prog_width = int(self.width*prog)
84 | if prog_width > 0:
85 | bar += ('='*(prog_width-1))
86 | if current < self.target:
87 | bar += '>'
88 | else:
89 | bar += '='
90 | bar += ('.'*(self.width-prog_width))
91 | bar += ']'
92 | sys.stdout.write(bar)
93 | self.total_width = len(bar)
94 |
95 | if current:
96 | time_per_unit = (now - self.start) / current
97 | else:
98 | time_per_unit = 0
99 | eta = time_per_unit*(self.target - current)
100 | info = ''
101 | if current < self.target:
102 | info += ' - ETA: %ds' % eta
103 | else:
104 | info += ' - %ds' % (now - self.start)
105 | for k in self.unique_values:
106 | if type(self.sum_values[k]) is list:
107 | info += ' - %s: %.4f' % (k,
108 | self.sum_values[k][0] / max(1, self.sum_values[k][1]))
109 | else:
110 | info += ' - %s: %s' % (k, self.sum_values[k])
111 |
112 | self.total_width += len(info)
113 | if prev_total_width > self.total_width:
114 | info += ((prev_total_width-self.total_width) * " ")
115 |
116 | sys.stdout.write(info)
117 | sys.stdout.flush()
118 |
119 | if current >= self.target:
120 | sys.stdout.write("\n")
121 |
122 | if self.verbose == 2:
123 | if current >= self.target:
124 | info = '%ds' % (now - self.start)
125 | for k in self.unique_values:
126 | info += ' - %s: %.4f' % (k,
127 | self.sum_values[k][0] / max(1, self.sum_values[k][1]))
128 | sys.stdout.write(info + "\n")
129 |
130 |     def add(self, n, values=()):
131 | self.update(self.seen_so_far+n, values)
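132 |
133 |
134 | if __name__ == "__main__":
135 |     # Minimal usage sketch (not part of the original module): drive the bar
136 |     # for 50 fake steps, reporting a running average for a "loss" value.
137 |     bar = Progbar(target=50)
138 |     for step in range(1, 51):
139 |         bar.update(step, values=[("loss", 1.0 / step)])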
--------------------------------------------------------------------------------
/get_started/custom_estimator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """An Example of a custom Estimator for the Iris dataset."""
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import tensorflow as tf
21 |
22 | import iris_data
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--batch_size', default=100, type=int, help='batch size')
26 | parser.add_argument('--train_steps', default=1000, type=int,
27 | help='number of training steps')
28 |
29 | def my_model(features, labels, mode, params):
30 | """DNN with three hidden layers, and dropout of 0.1 probability."""
31 | # Create three fully connected layers each layer having a dropout
32 | # probability of 0.1.
33 | net = tf.feature_column.input_layer(features, params['feature_columns'])
34 | for units in params['hidden_units']:
35 | net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
36 |
37 | # Compute logits (1 per class).
38 | logits = tf.layers.dense(net, params['n_classes'], activation=None)
39 |
40 | # Compute predictions.
41 | predicted_classes = tf.argmax(logits, 1)
42 | if mode == tf.estimator.ModeKeys.PREDICT:
43 | predictions = {
44 | 'class_ids': predicted_classes[:, tf.newaxis],
45 | 'probabilities': tf.nn.softmax(logits),
46 | 'logits': logits,
47 | }
48 | return tf.estimator.EstimatorSpec(mode, predictions=predictions)
49 |
50 | # Compute loss.
51 | loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
52 |
53 | # Compute evaluation metrics.
54 | accuracy = tf.metrics.accuracy(labels=labels,
55 | predictions=predicted_classes,
56 | name='acc_op')
57 | metrics = {'accuracy': accuracy}
58 | tf.summary.scalar('accuracy', accuracy[1])
59 |
60 | if mode == tf.estimator.ModeKeys.EVAL:
61 | return tf.estimator.EstimatorSpec(
62 | mode, loss=loss, eval_metric_ops=metrics)
63 |
64 | # Create training op.
65 | assert mode == tf.estimator.ModeKeys.TRAIN
66 |
67 | optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
68 | train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
69 | return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
70 |
71 |
72 | def main(argv):
73 | args = parser.parse_args(argv[1:])
74 |
75 | # Fetch the data
76 | (train_x, train_y), (test_x, test_y) = iris_data.load_data()
77 |
78 | # Feature columns describe how to use the input.
79 | my_feature_columns = []
80 | for key in train_x.keys():
81 | my_feature_columns.append(tf.feature_column.numeric_column(key=key))
82 |
83 | # Build 2 hidden layer DNN with 10, 10 units respectively.
84 | classifier = tf.estimator.Estimator(
85 | model_fn=my_model,
86 | params={
87 | 'feature_columns': my_feature_columns,
88 | # Two hidden layers of 10 nodes each.
89 | 'hidden_units': [10, 10],
90 | # The model must choose between 3 classes.
91 | 'n_classes': 3,
92 | })
93 |
94 | # Train the Model.
95 | classifier.train(
96 | input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
97 | steps=args.train_steps)
98 |
99 | # Evaluate the model.
100 | eval_result = classifier.evaluate(
101 | input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
102 |
103 | print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
104 |
105 | # Generate predictions from the model
106 | expected = ['Setosa', 'Versicolor', 'Virginica']
107 | predict_x = {
108 | 'SepalLength': [5.1, 5.9, 6.9],
109 | 'SepalWidth': [3.3, 3.0, 3.1],
110 | 'PetalLength': [1.7, 4.2, 5.4],
111 | 'PetalWidth': [0.5, 1.5, 2.1],
112 | }
113 |
114 | predictions = classifier.predict(
115 | input_fn=lambda:iris_data.eval_input_fn(predict_x,
116 | labels=None,
117 | batch_size=args.batch_size))
118 |
119 | for pred_dict, expec in zip(predictions, expected):
120 | template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')
121 |
122 | class_id = pred_dict['class_ids'][0]
123 | probability = pred_dict['probabilities'][class_id]
124 |
125 | print(template.format(iris_data.SPECIES[class_id],
126 | 100 * probability, expec))
127 |
128 |
129 | if __name__ == '__main__':
130 | tf.logging.set_verbosity(tf.logging.INFO)
131 | tf.app.run(main)
132 |
--------------------------------------------------------------------------------
/get_started/estimator_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """A simple smoke test that runs these examples for 1 training iteraton."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 | import pandas as pd
23 |
24 | from six.moves import StringIO
25 |
26 | import iris_data
27 | import custom_estimator
28 | import premade_estimator
29 |
30 | FOUR_LINES = "\n".join([
31 | "1,52.40, 2823,152,2",
32 | "164, 99.80,176.60,66.20,1",
33 | "176,2824, 136,3.19,0",
34 | "2,177.30,66.30, 53.10,1",])
35 |
36 | def four_lines_data():
37 | text = StringIO(FOUR_LINES)
38 |
39 | df = pd.read_csv(text, names=iris_data.CSV_COLUMN_NAMES)
40 |
41 | xy = (df, df.pop("Species"))
42 | return xy, xy
43 |
44 |
45 | class RegressionTest(tf.test.TestCase):
46 | """Test the regression examples in this directory."""
47 |
48 |   @tf.test.mock.patch.dict(iris_data.__dict__,
49 | {"load_data": four_lines_data})
50 | def test_premade_estimator(self):
51 | premade_estimator.main([None, "--train_steps=1"])
52 |
53 |   @tf.test.mock.patch.dict(iris_data.__dict__,
54 | {"load_data": four_lines_data})
55 | def test_custom_estimator(self):
56 | custom_estimator.main([None, "--train_steps=1"])
57 |
58 | if __name__ == "__main__":
59 | tf.test.main()
60 |
--------------------------------------------------------------------------------
/get_started/iris_data.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import tensorflow as tf
3 |
4 | TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
5 | TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
6 |
7 | CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
8 | 'PetalLength', 'PetalWidth', 'Species']
9 | SPECIES = ['Setosa', 'Versicolor', 'Virginica']
10 |
11 | def maybe_download():
12 | train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
13 | test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)
14 |
15 | return train_path, test_path
16 |
17 | def load_data(y_name='Species'):
18 | """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
19 | train_path, test_path = maybe_download()
20 |
21 | train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
22 | train_x, train_y = train, train.pop(y_name)
23 |
24 | test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
25 | test_x, test_y = test, test.pop(y_name)
26 |
27 | return (train_x, train_y), (test_x, test_y)
28 |
29 |
30 | def train_input_fn(features, labels, batch_size):
31 | """An input function for training"""
32 | # Convert the inputs to a Dataset.
33 | dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
34 |
35 | # Shuffle, repeat, and batch the examples.
36 | dataset = dataset.shuffle(1000).repeat().batch(batch_size)
37 |
38 | # Return the read end of the pipeline.
39 | return dataset.make_one_shot_iterator().get_next()
40 |
41 |
42 | def eval_input_fn(features, labels, batch_size):
43 | """An input function for evaluation or prediction"""
44 | features=dict(features)
45 | if labels is None:
46 | # No labels, use only features.
47 | inputs = features
48 | else:
49 | inputs = (features, labels)
50 |
51 | # Convert the inputs to a Dataset.
52 | dataset = tf.data.Dataset.from_tensor_slices(inputs)
53 |
54 | # Batch the examples
55 | assert batch_size is not None, "batch_size must not be None"
56 | dataset = dataset.batch(batch_size)
57 |
58 | # Return the read end of the pipeline.
59 | return dataset.make_one_shot_iterator().get_next()
60 |
61 |
62 | # The remainder of this file contains a simple example of a csv parser,
63 | # implemented using the `Dataset` class.
64 |
65 | # `tf.decode_csv` sets the types of the outputs to match the examples given in
66 | # the `record_defaults` argument.
67 | CSV_TYPES = [[0.0], [0.0], [0.0], [0.0], [0]]
68 |
69 | def _parse_line(line):
70 | # Decode the line into its fields
71 | fields = tf.decode_csv(line, record_defaults=CSV_TYPES)
72 |
73 | # Pack the result into a dictionary
74 | features = dict(zip(CSV_COLUMN_NAMES, fields))
75 |
76 | # Separate the label from the features
77 | label = features.pop('Species')
78 |
79 | return features, label
80 |
81 |
82 | def csv_input_fn(csv_path, batch_size):
83 | # Create a dataset containing the text lines.
84 | dataset = tf.data.TextLineDataset(csv_path).skip(1)
85 |
86 | # Parse each line.
87 | dataset = dataset.map(_parse_line)
88 |
89 | # Shuffle, repeat, and batch the examples.
90 | dataset = dataset.shuffle(1000).repeat().batch(batch_size)
91 |
92 | # Return the read end of the pipeline.
93 | return dataset.make_one_shot_iterator().get_next()
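94 |
95 |
96 | if __name__ == "__main__":
97 |     # Minimal usage sketch (not part of the original module): download the
98 |     # CSVs and pull one parsed batch through the csv_input_fn pipeline.
99 |     train_path, _ = maybe_download()
100 |     batch = csv_input_fn(train_path, batch_size=4)
101 |     with tf.Session() as sess:
102 |         print(sess.run(batch))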
--------------------------------------------------------------------------------
/get_started/premade_estimator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """An Example of a DNNClassifier for the Iris dataset."""
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import argparse
20 | import tensorflow as tf
21 |
22 | import iris_data
23 |
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--batch_size', default=100, type=int, help='batch size')
27 | parser.add_argument('--train_steps', default=1000, type=int,
28 | help='number of training steps')
29 |
30 | def main(argv):
31 | args = parser.parse_args(argv[1:])
32 |
33 | # Fetch the data
34 | (train_x, train_y), (test_x, test_y) = iris_data.load_data()
35 |
36 | # Feature columns describe how to use the input.
37 | my_feature_columns = []
38 | for key in train_x.keys():
39 | my_feature_columns.append(tf.feature_column.numeric_column(key=key))
40 |
41 | # Build 2 hidden layer DNN with 10, 10 units respectively.
42 | classifier = tf.estimator.DNNClassifier(
43 | feature_columns=my_feature_columns,
44 | # Two hidden layers of 10 nodes each.
45 | hidden_units=[10, 10],
46 | # The model must choose between 3 classes.
47 | n_classes=3)
48 |
49 | # Train the Model.
50 | classifier.train(
51 | input_fn=lambda:iris_data.train_input_fn(train_x, train_y,
52 | args.batch_size),
53 | steps=args.train_steps)
54 |
55 | # Evaluate the model.
56 | eval_result = classifier.evaluate(
57 | input_fn=lambda:iris_data.eval_input_fn(test_x, test_y,
58 | args.batch_size))
59 |
60 | print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
61 |
62 | # Generate predictions from the model
63 | expected = ['Setosa', 'Versicolor', 'Virginica']
64 | predict_x = {
65 | 'SepalLength': [5.1, 5.9, 6.9],
66 | 'SepalWidth': [3.3, 3.0, 3.1],
67 | 'PetalLength': [1.7, 4.2, 5.4],
68 | 'PetalWidth': [0.5, 1.5, 2.1],
69 | }
70 |
71 | predictions = classifier.predict(
72 | input_fn=lambda:iris_data.eval_input_fn(predict_x,
73 | labels=None,
74 | batch_size=args.batch_size))
75 |
76 | for pred_dict, expec in zip(predictions, expected):
77 | template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"')
78 |
79 | class_id = pred_dict['class_ids'][0]
80 | probability = pred_dict['probabilities'][class_id]
81 |
82 | print(template.format(iris_data.SPECIES[class_id],
83 | 100 * probability, expec))
84 |
85 |
86 | if __name__ == '__main__':
87 | tf.logging.set_verbosity(tf.logging.INFO)
88 | tf.app.run(main)
89 |
--------------------------------------------------------------------------------
/modeling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import collections
20 | import json
21 | import random
22 | import re
23 |
24 | import modeling
25 | import six
26 | import tensorflow as tf
27 |
28 |
29 | class BertModelTest(tf.test.TestCase):
30 |
31 | class BertModelTester(object):
32 |
33 | def __init__(self,
34 | parent,
35 | batch_size=13,
36 | seq_length=7,
37 | is_training=True,
38 | use_input_mask=True,
39 | use_token_type_ids=True,
40 | vocab_size=99,
41 | hidden_size=32,
42 | num_hidden_layers=5,
43 | num_attention_heads=4,
44 | intermediate_size=37,
45 | hidden_act="gelu",
46 | hidden_dropout_prob=0.1,
47 | attention_probs_dropout_prob=0.1,
48 | max_position_embeddings=512,
49 | type_vocab_size=16,
50 | initializer_range=0.02,
51 | scope=None):
52 | self.parent = parent
53 | self.batch_size = batch_size
54 | self.seq_length = seq_length
55 | self.is_training = is_training
56 | self.use_input_mask = use_input_mask
57 | self.use_token_type_ids = use_token_type_ids
58 | self.vocab_size = vocab_size
59 | self.hidden_size = hidden_size
60 | self.num_hidden_layers = num_hidden_layers
61 | self.num_attention_heads = num_attention_heads
62 | self.intermediate_size = intermediate_size
63 | self.hidden_act = hidden_act
64 | self.hidden_dropout_prob = hidden_dropout_prob
65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
66 | self.max_position_embeddings = max_position_embeddings
67 | self.type_vocab_size = type_vocab_size
68 | self.initializer_range = initializer_range
69 | self.scope = scope
70 |
71 | def create_model(self):
72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length],
73 | self.vocab_size)
74 |
75 | input_mask = None
76 | if self.use_input_mask:
77 | input_mask = BertModelTest.ids_tensor(
78 | [self.batch_size, self.seq_length], vocab_size=2)
79 |
80 | token_type_ids = None
81 | if self.use_token_type_ids:
82 | token_type_ids = BertModelTest.ids_tensor(
83 | [self.batch_size, self.seq_length], self.type_vocab_size)
84 |
85 | config = modeling.BertConfig(
86 | vocab_size=self.vocab_size,
87 | hidden_size=self.hidden_size,
88 | num_hidden_layers=self.num_hidden_layers,
89 | num_attention_heads=self.num_attention_heads,
90 | intermediate_size=self.intermediate_size,
91 | hidden_act=self.hidden_act,
92 | hidden_dropout_prob=self.hidden_dropout_prob,
93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob,
94 | max_position_embeddings=self.max_position_embeddings,
95 | type_vocab_size=self.type_vocab_size,
96 | initializer_range=self.initializer_range)
97 |
98 | model = modeling.BertModel(
99 | config=config,
100 | is_training=self.is_training,
101 | input_ids=input_ids,
102 | input_mask=input_mask,
103 | token_type_ids=token_type_ids,
104 | scope=self.scope)
105 |
106 | outputs = {
107 | "embedding_output": model.get_embedding_output(),
108 | "sequence_output": model.get_sequence_output(),
109 | "pooled_output": model.get_pooled_output(),
110 | "all_encoder_layers": model.get_all_encoder_layers(),
111 | }
112 | return outputs
113 |
114 | def check_output(self, result):
115 | self.parent.assertAllEqual(
116 | result["embedding_output"].shape,
117 | [self.batch_size, self.seq_length, self.hidden_size])
118 |
119 | self.parent.assertAllEqual(
120 | result["sequence_output"].shape,
121 | [self.batch_size, self.seq_length, self.hidden_size])
122 |
123 | self.parent.assertAllEqual(result["pooled_output"].shape,
124 | [self.batch_size, self.hidden_size])
125 |
126 | def test_default(self):
127 | self.run_tester(BertModelTest.BertModelTester(self))
128 |
129 | def test_config_to_json_string(self):
130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37)
131 | obj = json.loads(config.to_json_string())
132 | self.assertEqual(obj["vocab_size"], 99)
133 | self.assertEqual(obj["hidden_size"], 37)
134 |
135 | def run_tester(self, tester):
136 | with self.test_session() as sess:
137 | ops = tester.create_model()
138 | init_op = tf.group(tf.global_variables_initializer(),
139 | tf.local_variables_initializer())
140 | sess.run(init_op)
141 | output_result = sess.run(ops)
142 | tester.check_output(output_result)
143 |
144 | self.assert_all_tensors_reachable(sess, [init_op, ops])
145 |
146 | @classmethod
147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
148 | """Creates a random int32 tensor of the shape within the vocab size."""
149 | if rng is None:
150 | rng = random.Random()
151 |
152 | total_dims = 1
153 | for dim in shape:
154 | total_dims *= dim
155 |
156 | values = []
157 | for _ in range(total_dims):
158 | values.append(rng.randint(0, vocab_size - 1))
159 |
160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
161 |
162 | def assert_all_tensors_reachable(self, sess, outputs):
163 | """Checks that all the tensors in the graph are reachable from outputs."""
164 | graph = sess.graph
165 |
166 | ignore_strings = [
167 | "^.*/dilation_rate$",
168 | "^.*/Tensordot/concat$",
169 | "^.*/Tensordot/concat/axis$",
170 | "^testing/.*$",
171 | ]
172 |
173 | ignore_regexes = [re.compile(x) for x in ignore_strings]
174 |
175 | unreachable = self.get_unreachable_ops(graph, outputs)
176 | filtered_unreachable = []
177 | for x in unreachable:
178 | do_ignore = False
179 | for r in ignore_regexes:
180 | m = r.match(x.name)
181 | if m is not None:
182 | do_ignore = True
183 | if do_ignore:
184 | continue
185 | filtered_unreachable.append(x)
186 | unreachable = filtered_unreachable
187 |
188 | self.assertEqual(
189 | len(unreachable), 0, "The following ops are unreachable: %s" %
190 | (" ".join([x.name for x in unreachable])))
191 |
192 | @classmethod
193 | def get_unreachable_ops(cls, graph, outputs):
194 | """Finds all of the tensors in graph that are unreachable from outputs."""
195 | outputs = cls.flatten_recursive(outputs)
196 | output_to_op = collections.defaultdict(list)
197 | op_to_all = collections.defaultdict(list)
198 | assign_out_to_in = collections.defaultdict(list)
199 |
200 | for op in graph.get_operations():
201 | for x in op.inputs:
202 | op_to_all[op.name].append(x.name)
203 | for y in op.outputs:
204 | output_to_op[y.name].append(op.name)
205 | op_to_all[op.name].append(y.name)
206 | if str(op.type) == "Assign":
207 | for y in op.outputs:
208 | for x in op.inputs:
209 | assign_out_to_in[y.name].append(x.name)
210 |
211 | assign_groups = collections.defaultdict(list)
212 | for out_name in assign_out_to_in.keys():
213 | name_group = assign_out_to_in[out_name]
214 | for n1 in name_group:
215 | assign_groups[n1].append(out_name)
216 | for n2 in name_group:
217 | if n1 != n2:
218 | assign_groups[n1].append(n2)
219 |
220 | seen_tensors = {}
221 | stack = [x.name for x in outputs]
222 | while stack:
223 | name = stack.pop()
224 | if name in seen_tensors:
225 | continue
226 | seen_tensors[name] = True
227 |
228 | if name in output_to_op:
229 | for op_name in output_to_op[name]:
230 | if op_name in op_to_all:
231 | for input_name in op_to_all[op_name]:
232 | if input_name not in stack:
233 | stack.append(input_name)
234 |
235 | expanded_names = []
236 | if name in assign_groups:
237 | for assign_name in assign_groups[name]:
238 | expanded_names.append(assign_name)
239 |
240 | for expanded_name in expanded_names:
241 | if expanded_name not in stack:
242 | stack.append(expanded_name)
243 |
244 | unreachable_ops = []
245 | for op in graph.get_operations():
246 | is_unreachable = False
247 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
248 | for name in all_names:
249 | if name not in seen_tensors:
250 | is_unreachable = True
251 | if is_unreachable:
252 | unreachable_ops.append(op)
253 | return unreachable_ops
254 |
255 | @classmethod
256 | def flatten_recursive(cls, item):
257 | """Flattens (potentially nested) a tuple/dictionary/list to a list."""
258 | output = []
259 | if isinstance(item, list):
260 | output.extend(item)
261 | elif isinstance(item, tuple):
262 | output.extend(list(item))
263 | elif isinstance(item, dict):
264 | for (_, v) in six.iteritems(item):
265 | output.append(v)
266 | else:
267 | return [item]
268 |
269 | flat_output = []
270 | for x in output:
271 | flat_output.extend(cls.flatten_recursive(x))
272 | return flat_output
273 |
274 |
275 | if __name__ == "__main__":
276 | tf.test.main()
277 |
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 |   new_global_step = global_step + 1  # apply_gradients() above does not step the counter.
80 | train_op = tf.group(train_op, global_step.assign(new_global_step))
81 | return train_op
82 |
83 |
84 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
85 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
86 |
87 | def __init__(self,
88 | learning_rate,
89 | weight_decay_rate=0.0,
90 | beta_1=0.9,
91 | beta_2=0.999,
92 | epsilon=1e-6,
93 | exclude_from_weight_decay=None,
94 | name="AdamWeightDecayOptimizer"):
95 | """Constructs a AdamWeightDecayOptimizer."""
96 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
97 |
98 | self.learning_rate = learning_rate
99 | self.weight_decay_rate = weight_decay_rate
100 | self.beta_1 = beta_1
101 | self.beta_2 = beta_2
102 | self.epsilon = epsilon
103 | self.exclude_from_weight_decay = exclude_from_weight_decay
104 |
105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
106 | """See base class."""
107 | assignments = []
108 | for (grad, param) in grads_and_vars:
109 | if grad is None or param is None:
110 | continue
111 |
112 | param_name = self._get_variable_name(param.name)
113 |
114 | m = tf.get_variable(
115 | name=param_name + "/adam_m",
116 | shape=param.shape.as_list(),
117 | dtype=tf.float32,
118 | trainable=False,
119 | initializer=tf.zeros_initializer())
120 | v = tf.get_variable(
121 | name=param_name + "/adam_v",
122 | shape=param.shape.as_list(),
123 | dtype=tf.float32,
124 | trainable=False,
125 | initializer=tf.zeros_initializer())
126 |
127 | # Standard Adam update.
128 | next_m = (
129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
130 | next_v = (
131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
132 | tf.square(grad)))
133 |
134 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
135 |
136 | # Just adding the square of the weights to the loss function is *not*
137 | # the correct way of using L2 regularization/weight decay with Adam,
138 | # since that will interact with the m and v parameters in strange ways.
139 | #
140 |       # Instead we want to decay the weights in a manner that doesn't interact
141 | # with the m/v parameters. This is equivalent to adding the square
142 | # of the weights to the loss with plain (non-momentum) SGD.
143 | if self._do_use_weight_decay(param_name):
144 | update += self.weight_decay_rate * param
145 |
146 | update_with_lr = self.learning_rate * update
147 |
148 | next_param = param - update_with_lr
149 |
150 | assignments.extend(
151 | [param.assign(next_param),
152 | m.assign(next_m),
153 | v.assign(next_v)])
154 | return tf.group(*assignments, name=name)
155 |
156 | def _do_use_weight_decay(self, param_name):
157 | """Whether to use L2 weight decay for `param_name`."""
158 | if not self.weight_decay_rate:
159 | return False
160 | if self.exclude_from_weight_decay:
161 | for r in self.exclude_from_weight_decay:
162 | if re.search(r, param_name) is not None:
163 | return False
164 | return True
165 |
166 | def _get_variable_name(self, param_name):
167 | """Get the variable name from the tensor name."""
168 | m = re.match("^(.*):\\d+$", param_name)
169 | if m is not None:
170 | param_name = m.group(1)
171 | return param_name
172 |
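173 |
174 | if __name__ == "__main__":
175 |   # Minimal usage sketch (not part of the original module): fit a small
176 |   # weight vector with the warmup + linear-decay schedule defined above.
177 |   w = tf.get_variable("w", shape=[3], initializer=tf.zeros_initializer())
178 |   loss = tf.reduce_mean(tf.square(tf.constant([0.4, 0.2, -0.5]) - w))
179 |   train_op = create_optimizer(
180 |       loss, init_lr=0.1, num_train_steps=100, num_warmup_steps=10,
181 |       use_tpu=False)
182 |   with tf.Session() as sess:
183 |     sess.run(tf.global_variables_initializer())
184 |     for _ in range(100):
185 |       sess.run(train_op)
186 |     print(sess.run(w))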
--------------------------------------------------------------------------------
/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import optimization
20 | import tensorflow as tf
21 |
22 |
23 | class OptimizationTest(tf.test.TestCase):
24 |
25 | def test_adam(self):
26 | with self.test_session() as sess:
27 | w = tf.get_variable(
28 | "w",
29 | shape=[3],
30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
31 | x = tf.constant([0.4, 0.2, -0.5])
32 | loss = tf.reduce_mean(tf.square(x - w))
33 | tvars = tf.trainable_variables()
34 | grads = tf.gradients(loss, tvars)
35 | global_step = tf.train.get_or_create_global_step()
36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
38 | init_op = tf.group(tf.global_variables_initializer(),
39 | tf.local_variables_initializer())
40 | sess.run(init_op)
41 | for _ in range(100):
42 | sess.run(train_op)
43 | w_np = sess.run(w)
44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
45 |
46 |
47 | if __name__ == "__main__":
48 | tf.test.main()
49 |
--------------------------------------------------------------------------------
/run_classifier.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """BERT finetuning runner."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import csv
23 | import os
24 | import modeling
25 | import optimization
26 | import tokenization
27 | import tensorflow as tf
28 |
29 | flags = tf.flags
30 |
31 | FLAGS = flags.FLAGS
32 |
33 | ## Required parameters
34 | flags.DEFINE_string(
35 | "data_dir", None,
36 | "The input data dir. Should contain the .tsv files (or other data files) "
37 | "for the task.")
38 |
39 | flags.DEFINE_string(
40 | "bert_config_file", None,
41 | "The config json file corresponding to the pre-trained BERT model. "
42 | "This specifies the model architecture.")
43 |
44 | flags.DEFINE_string("task_name", None, "The name of the task to train.")
45 |
46 | flags.DEFINE_string("vocab_file", None,
47 | "The vocabulary file that the BERT model was trained on.")
48 |
49 | flags.DEFINE_string(
50 | "output_dir", None,
51 | "The output directory where the model checkpoints will be written.")
52 |
53 | ## Other parameters
54 |
55 | flags.DEFINE_string(
56 | "init_checkpoint", None,
57 | "Initial checkpoint (usually from a pre-trained BERT model).")
58 |
59 | flags.DEFINE_bool(
60 | "do_lower_case", True,
61 | "Whether to lower case the input text. Should be True for uncased "
62 | "models and False for cased models.")
63 |
64 | flags.DEFINE_integer(
65 | "max_seq_length", 128,
66 | "The maximum total input sequence length after WordPiece tokenization. "
67 | "Sequences longer than this will be truncated, and sequences shorter "
68 | "than this will be padded.")
69 |
70 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
71 |
72 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
73 |
74 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
75 |
76 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
77 |
78 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
79 |
80 | flags.DEFINE_float("num_train_epochs", 3.0,
81 | "Total number of training epochs to perform.")
82 |
83 | flags.DEFINE_float(
84 | "warmup_proportion", 0.1,
85 | "Proportion of training to perform linear learning rate warmup for. "
86 | "E.g., 0.1 = 10% of training.")
87 |
88 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
89 | "How often to save the model checkpoint.")
90 |
91 | flags.DEFINE_integer("iterations_per_loop", 1000,
92 | "How many steps to make in each estimator call.")
93 |
94 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
95 |
96 | tf.flags.DEFINE_string(
97 | "tpu_name", None,
98 | "The Cloud TPU to use for training. This should be either the name "
99 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
100 | "url.")
101 |
102 | tf.flags.DEFINE_string(
103 | "tpu_zone", None,
104 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
105 | "specified, we will attempt to automatically detect the GCE project from "
106 | "metadata.")
107 |
108 | tf.flags.DEFINE_string(
109 | "gcp_project", None,
110 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
111 | "specified, we will attempt to automatically detect the GCE project from "
112 | "metadata.")
113 |
114 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
115 |
116 | flags.DEFINE_integer(
117 | "num_tpu_cores", 8,
118 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
119 |
120 |
121 | class InputExample(object):
122 | """A single training/test example for simple sequence classification."""
123 |
124 | def __init__(self, guid, text_a, text_b=None, label=None):
125 | """Constructs a InputExample.
126 |
127 | Args:
128 | guid: Unique id for the example.
129 | text_a: string. The untokenized text of the first sequence. For single
130 | sequence tasks, only this sequence must be specified.
131 | text_b: (Optional) string. The untokenized text of the second sequence.
132 |         Must be specified only for sequence pair tasks.
133 | label: (Optional) string. The label of the example. This should be
134 | specified for train and dev examples, but not for test examples.
135 | """
136 | self.guid = guid
137 | self.text_a = text_a
138 | self.text_b = text_b
139 | self.label = label
140 |
141 |
142 | class InputFeatures(object):
143 | """A single set of features of data."""
144 |
145 | def __init__(self, input_ids, input_mask, segment_ids, label_id):
146 | self.input_ids = input_ids
147 | self.input_mask = input_mask
148 | self.segment_ids = segment_ids
149 | self.label_id = label_id
150 |
151 |
152 | class DataProcessor(object):
153 | """Base class for data converters for sequence classification data sets."""
154 |
155 | def get_train_examples(self, data_dir):
156 | """Gets a collection of `InputExample`s for the train set."""
157 | raise NotImplementedError()
158 |
159 | def get_dev_examples(self, data_dir):
160 | """Gets a collection of `InputExample`s for the dev set."""
161 | raise NotImplementedError()
162 |
163 | def get_labels(self):
164 | """Gets the list of labels for this data set."""
165 | raise NotImplementedError()
166 |
167 | @classmethod
168 | def _read_tsv(cls, input_file, quotechar=None):
169 | """Reads a tab separated value file."""
170 | with tf.gfile.Open(input_file, "r") as f:
171 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
172 | lines = []
173 | for line in reader:
174 | lines.append(line)
175 | return lines
176 |
177 |
178 | class XnliProcessor(DataProcessor):
179 | """Processor for the XNLI data set."""
180 |
181 | def __init__(self):
182 | self.language = "zh"
183 |
184 | def get_train_examples(self, data_dir):
185 | """See base class."""
186 | lines = self._read_tsv(
187 | os.path.join(data_dir, "multinli",
188 | "multinli.train.%s.tsv" % self.language))
189 | examples = []
190 | for (i, line) in enumerate(lines):
191 | if i == 0:
192 | continue
193 | guid = "train-%d" % (i)
194 | text_a = tokenization.convert_to_unicode(line[0])
195 | text_b = tokenization.convert_to_unicode(line[1])
196 | label = tokenization.convert_to_unicode(line[2])
197 | if label == tokenization.convert_to_unicode("contradictory"):
198 | label = tokenization.convert_to_unicode("contradiction")
199 | examples.append(
200 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
201 | return examples
202 |
203 | def get_dev_examples(self, data_dir):
204 | """See base class."""
205 | lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
206 | examples = []
207 | for (i, line) in enumerate(lines):
208 | if i == 0:
209 | continue
210 | guid = "dev-%d" % (i)
211 | language = tokenization.convert_to_unicode(line[0])
212 | if language != tokenization.convert_to_unicode(self.language):
213 | continue
214 | text_a = tokenization.convert_to_unicode(line[6])
215 | text_b = tokenization.convert_to_unicode(line[7])
216 | label = tokenization.convert_to_unicode(line[1])
217 | examples.append(
218 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
219 | return examples
220 |
221 | def get_labels(self):
222 | """See base class."""
223 | return ["contradiction", "entailment", "neutral"]
224 |
225 |
226 | class MnliProcessor(DataProcessor):
227 | """Processor for the MultiNLI data set (GLUE version)."""
228 |
229 | def get_train_examples(self, data_dir):
230 | """See base class."""
231 | return self._create_examples(
232 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
233 |
234 | def get_dev_examples(self, data_dir):
235 | """See base class."""
236 | return self._create_examples(
237 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
238 | "dev_matched")
239 |
240 | def get_labels(self):
241 | """See base class."""
242 | return ["contradiction", "entailment", "neutral"]
243 |
244 | def _create_examples(self, lines, set_type):
245 | """Creates examples for the training and dev sets."""
246 | examples = []
247 | for (i, line) in enumerate(lines):
248 | if i == 0:
249 | continue
250 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
251 | text_a = tokenization.convert_to_unicode(line[8])
252 | text_b = tokenization.convert_to_unicode(line[9])
253 | label = tokenization.convert_to_unicode(line[-1])
254 | examples.append(
255 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
256 | return examples
257 |
258 |
259 | class MrpcProcessor(DataProcessor):
260 | """Processor for the MRPC data set (GLUE version)."""
261 |
262 | def get_train_examples(self, data_dir):
263 | """See base class."""
264 | return self._create_examples(
265 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
266 |
267 | def get_dev_examples(self, data_dir):
268 | """See base class."""
269 | return self._create_examples(
270 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
271 |
272 | def get_labels(self):
273 | """See base class."""
274 | return ["0", "1"]
275 |
276 | def _create_examples(self, lines, set_type):
277 | """Creates examples for the training and dev sets."""
278 | examples = []
279 | for (i, line) in enumerate(lines):
280 | if i == 0:
281 | continue
282 | guid = "%s-%s" % (set_type, i)
283 | text_a = tokenization.convert_to_unicode(line[3])
284 | text_b = tokenization.convert_to_unicode(line[4])
285 | label = tokenization.convert_to_unicode(line[0])
286 | examples.append(
287 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
288 | return examples
289 |
290 |
291 | class ColaProcessor(DataProcessor):
292 | """Processor for the CoLA data set (GLUE version)."""
293 |
294 | def get_train_examples(self, data_dir):
295 | """See base class."""
296 | return self._create_examples(
297 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
298 |
299 | def get_dev_examples(self, data_dir):
300 | """See base class."""
301 | return self._create_examples(
302 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
303 |
304 | def get_labels(self):
305 | """See base class."""
306 | return ["0", "1"]
307 |
308 | def _create_examples(self, lines, set_type):
309 | """Creates examples for the training and dev sets."""
310 | examples = []
311 | for (i, line) in enumerate(lines):
312 | guid = "%s-%s" % (set_type, i)
313 | text_a = tokenization.convert_to_unicode(line[3])
314 | label = tokenization.convert_to_unicode(line[1])
315 | examples.append(
316 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
317 | return examples
318 |
319 |
320 | def convert_single_example(ex_index, example, label_list, max_seq_length,
321 | tokenizer):
322 | """Converts a single `InputExample` into a single `InputFeatures`."""
323 | label_map = {}
324 | for (i, label) in enumerate(label_list):
325 | label_map[label] = i
326 |
327 | tokens_a = tokenizer.tokenize(example.text_a)
328 | tokens_b = None
329 | if example.text_b:
330 | tokens_b = tokenizer.tokenize(example.text_b)
331 |
332 | if tokens_b:
333 | # Modifies `tokens_a` and `tokens_b` in place so that the total
334 | # length is less than the specified length.
335 | # Account for [CLS], [SEP], [SEP] with "- 3"
336 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
337 | else:
338 | # Account for [CLS] and [SEP] with "- 2"
339 | if len(tokens_a) > max_seq_length - 2:
340 | tokens_a = tokens_a[0:(max_seq_length - 2)]
341 |
342 | # The convention in BERT is:
343 | # (a) For sequence pairs:
344 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
345 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
346 | # (b) For single sequences:
347 | # tokens: [CLS] the dog is hairy . [SEP]
348 | # type_ids: 0 0 0 0 0 0 0
349 | #
350 | # Where "type_ids" are used to indicate whether this is the first
351 | # sequence or the second sequence. The embedding vectors for `type=0` and
352 | # `type=1` were learned during pre-training and are added to the wordpiece
353 | # embedding vector (and position vector). This is not *strictly* necessary
354 | # since the [SEP] token unambiguously separates the sequences, but it makes
355 | # it easier for the model to learn the concept of sequences.
356 | #
357 | # For classification tasks, the first vector (corresponding to [CLS]) is
358 |   # used as the "sentence vector". Note that this only makes sense because
359 | # the entire model is fine-tuned.
360 | tokens = []
361 | segment_ids = []
362 | tokens.append("[CLS]")
363 | segment_ids.append(0)
364 | for token in tokens_a:
365 | tokens.append(token)
366 | segment_ids.append(0)
367 | tokens.append("[SEP]")
368 | segment_ids.append(0)
369 |
370 | if tokens_b:
371 | for token in tokens_b:
372 | tokens.append(token)
373 | segment_ids.append(1)
374 | tokens.append("[SEP]")
375 | segment_ids.append(1)
376 |
377 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
378 |
379 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
380 | # tokens are attended to.
381 | input_mask = [1] * len(input_ids)
382 |
383 | # Zero-pad up to the sequence length.
384 | while len(input_ids) < max_seq_length:
385 | input_ids.append(0)
386 | input_mask.append(0)
387 | segment_ids.append(0)
388 |
389 | assert len(input_ids) == max_seq_length
390 | assert len(input_mask) == max_seq_length
391 | assert len(segment_ids) == max_seq_length
392 |
393 | label_id = label_map[example.label]
394 | if ex_index < 5:
395 | tf.logging.info("*** Example ***")
396 | tf.logging.info("guid: %s" % (example.guid))
397 | tf.logging.info("tokens: %s" % " ".join(
398 | [tokenization.printable_text(x) for x in tokens]))
399 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
400 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
401 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
402 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
403 |
404 | feature = InputFeatures(
405 | input_ids=input_ids,
406 | input_mask=input_mask,
407 | segment_ids=segment_ids,
408 | label_id=label_id)
409 | return feature
410 |
411 |
412 | def filed_based_convert_examples_to_features(
413 | examples, label_list, max_seq_length, tokenizer, output_file):
414 | """Convert a set of `InputExample`s to a TFRecord file."""
415 |
416 | writer = tf.python_io.TFRecordWriter(output_file)
417 |
418 | for (ex_index, example) in enumerate(examples):
419 | if ex_index % 10000 == 0:
420 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
421 |
422 | feature = convert_single_example(ex_index, example, label_list,
423 | max_seq_length, tokenizer)
424 |
425 | def create_int_feature(values):
426 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
427 | return f
428 |
429 | features = collections.OrderedDict()
430 | features["input_ids"] = create_int_feature(feature.input_ids)
431 | features["input_mask"] = create_int_feature(feature.input_mask)
432 | features["segment_ids"] = create_int_feature(feature.segment_ids)
433 | features["label_ids"] = create_int_feature([feature.label_id])
434 |
435 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
436 | writer.write(tf_example.SerializeToString())
437 |
438 |
439 | def file_based_input_fn_builder(input_file, seq_length, is_training,
440 | drop_remainder):
441 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
442 |
443 | name_to_features = {
444 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
445 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
446 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
447 | "label_ids": tf.FixedLenFeature([], tf.int64),
448 | }
449 |
450 | def _decode_record(record, name_to_features):
451 | """Decodes a record to a TensorFlow example."""
452 | example = tf.parse_single_example(record, name_to_features)
453 |
454 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
455 | # So cast all int64 to int32.
456 | for name in list(example.keys()):
457 | t = example[name]
458 | if t.dtype == tf.int64:
459 | t = tf.to_int32(t)
460 | example[name] = t
461 |
462 | return example
463 |
464 | def input_fn(params):
465 | """The actual input function."""
466 | batch_size = params["batch_size"]
467 |
468 | # For training, we want a lot of parallel reading and shuffling.
469 | # For eval, we want no shuffling and parallel reading doesn't matter.
470 | d = tf.data.TFRecordDataset(input_file)
471 | if is_training:
472 | d = d.repeat()
473 | d = d.shuffle(buffer_size=100)
474 |
475 | d = d.apply(
476 | tf.contrib.data.map_and_batch(
477 | lambda record: _decode_record(record, name_to_features),
478 | batch_size=batch_size,
479 | drop_remainder=drop_remainder))
480 |
481 | return d
482 |
483 | return input_fn
484 |
485 |
486 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
487 | """Truncates a sequence pair in place to the maximum length."""
488 |
489 | # This is a simple heuristic which will always truncate the longer sequence
490 | # one token at a time. This makes more sense than truncating an equal percent
491 | # of tokens from each, since if one sequence is very short then each token
492 | # that's truncated likely contains more information than a longer sequence.
493 | while True:
494 | total_length = len(tokens_a) + len(tokens_b)
495 | if total_length <= max_length:
496 | break
497 | if len(tokens_a) > len(tokens_b):
498 | tokens_a.pop()
499 | else:
500 | tokens_b.pop()
501 |
502 |
503 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
504 | labels, num_labels, use_one_hot_embeddings):
505 | """Creates a classification model."""
506 | model = modeling.BertModel(
507 | config=bert_config,
508 | is_training=is_training,
509 | input_ids=input_ids,
510 | input_mask=input_mask,
511 | token_type_ids=segment_ids,
512 | use_one_hot_embeddings=use_one_hot_embeddings)
513 |
514 | # In the demo, we are doing a simple classification task on the entire
515 | # segment.
516 | #
517 | # If you want to use the token-level output, use model.get_sequence_output()
518 | # instead.
519 | output_layer = model.get_pooled_output()
520 |
521 | hidden_size = output_layer.shape[-1].value
522 |
523 | output_weights = tf.get_variable(
524 | "output_weights", [num_labels, hidden_size],
525 | initializer=tf.truncated_normal_initializer(stddev=0.02))
526 |
527 | output_bias = tf.get_variable(
528 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
529 |
530 | with tf.variable_scope("loss"):
531 | if is_training:
532 | # I.e., 0.1 dropout
533 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
534 |
535 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
536 | logits = tf.nn.bias_add(logits, output_bias)
537 | log_probs = tf.nn.log_softmax(logits, axis=-1)
538 |
539 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
540 |
541 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
542 | loss = tf.reduce_mean(per_example_loss)
543 |
544 | return (loss, per_example_loss, logits)
545 |
546 |
547 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
548 | num_train_steps, num_warmup_steps, use_tpu,
549 | use_one_hot_embeddings):
550 | """Returns `model_fn` closure for TPUEstimator."""
551 |
552 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
553 | """The `model_fn` for TPUEstimator."""
554 |
555 | tf.logging.info("*** Features ***")
556 | for name in sorted(features.keys()):
557 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
558 |
559 | input_ids = features["input_ids"]
560 | input_mask = features["input_mask"]
561 | segment_ids = features["segment_ids"]
562 | label_ids = features["label_ids"]
563 |
564 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
565 |
566 | (total_loss, per_example_loss, logits) = create_model(
567 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
568 | num_labels, use_one_hot_embeddings)
569 |
570 | tvars = tf.trainable_variables()
571 |
572 | scaffold_fn = None
573 | if init_checkpoint:
574 | (assignment_map, initialized_variable_names
575 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
576 | if use_tpu:
577 |
578 | def tpu_scaffold():
579 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
580 | return tf.train.Scaffold()
581 |
582 | scaffold_fn = tpu_scaffold
583 | else:
584 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
585 |
586 | tf.logging.info("**** Trainable Variables ****")
587 | for var in tvars:
588 | init_string = ""
589 | if var.name in initialized_variable_names:
590 | init_string = ", *INIT_FROM_CKPT*"
591 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
592 | init_string)
593 |
594 | output_spec = None
595 | if mode == tf.estimator.ModeKeys.TRAIN:
596 |
597 | train_op = optimization.create_optimizer(
598 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
599 |
600 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
601 | mode=mode,
602 | loss=total_loss,
603 | train_op=train_op,
604 | scaffold_fn=scaffold_fn)
605 | elif mode == tf.estimator.ModeKeys.EVAL:
606 |
607 | def metric_fn(per_example_loss, label_ids, logits):
608 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
609 | accuracy = tf.metrics.accuracy(label_ids, predictions)
610 | loss = tf.metrics.mean(per_example_loss)
611 | return {
612 | "eval_accuracy": accuracy,
613 | "eval_loss": loss,
614 | }
615 |
616 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
617 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
618 | mode=mode,
619 | loss=total_loss,
620 | eval_metrics=eval_metrics,
621 | scaffold_fn=scaffold_fn)
622 | else:
623 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
624 |
625 | return output_spec
626 |
627 | return model_fn
628 |
629 |
630 | # This function is not used by this file but is still used by the Colab and
631 | # people who depend on it.
632 | def input_fn_builder(features, seq_length, is_training, drop_remainder):
633 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
634 |
635 | all_input_ids = []
636 | all_input_mask = []
637 | all_segment_ids = []
638 | all_label_ids = []
639 |
640 | for feature in features:
641 | all_input_ids.append(feature.input_ids)
642 | all_input_mask.append(feature.input_mask)
643 | all_segment_ids.append(feature.segment_ids)
644 | all_label_ids.append(feature.label_id)
645 |
646 | def input_fn(params):
647 | """The actual input function."""
648 | batch_size = params["batch_size"]
649 |
650 | num_examples = len(features)
651 |
652 | # This is for demo purposes and does NOT scale to large data sets. We do
653 | # not use Dataset.from_generator() because that uses tf.py_func which is
654 | # not TPU compatible. The right way to load data is with TFRecordReader.
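# (In this file, file_based_input_fn_builder is that TFRecord-based path.)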
655 | d = tf.data.Dataset.from_tensor_slices({
656 | "input_ids":
657 | tf.constant(
658 | all_input_ids, shape=[num_examples, seq_length],
659 | dtype=tf.int32),
660 | "input_mask":
661 | tf.constant(
662 | all_input_mask,
663 | shape=[num_examples, seq_length],
664 | dtype=tf.int32),
665 | "segment_ids":
666 | tf.constant(
667 | all_segment_ids,
668 | shape=[num_examples, seq_length],
669 | dtype=tf.int32),
670 | "label_ids":
671 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32),
672 | })
673 |
674 | if is_training:
675 | d = d.repeat()
676 | d = d.shuffle(buffer_size=100)
677 |
678 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
679 | return d
680 |
681 | return input_fn
682 |
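# A minimal usage sketch for the in-memory builder above (the `examples` and
# `estimator` variables are illustrative, not defined at this point in the file):
#
#   features = convert_examples_to_features(
#       examples, label_list, FLAGS.max_seq_length, tokenizer)
#   predict_input_fn = input_fn_builder(
#       features, FLAGS.max_seq_length, is_training=False, drop_remainder=False)
#   result = estimator.evaluate(input_fn=predict_input_fn)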
683 |
684 | # This function is not used by this file but is still used by the Colab and
685 | # people who depend on it.
686 | def convert_examples_to_features(examples, label_list, max_seq_length,
687 | tokenizer):
688 | """Convert a set of `InputExample`s to a list of `InputFeatures`."""
689 |
690 | features = []
691 | for (ex_index, example) in enumerate(examples):
692 | if ex_index % 10000 == 0:
693 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
694 |
695 | feature = convert_single_example(ex_index, example, label_list,
696 | max_seq_length, tokenizer)
697 |
698 | features.append(feature)
699 | return features
700 |
701 |
702 | def main(_):
703 | tf.logging.set_verbosity(tf.logging.INFO)
704 |
705 | processors = {
706 | "cola": ColaProcessor,
707 | "mnli": MnliProcessor,
708 | "mrpc": MrpcProcessor,
709 | "xnli": XnliProcessor,
710 | }
711 |
712 | if not FLAGS.do_train and not FLAGS.do_eval:
713 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
714 |
715 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
716 |
717 | if FLAGS.max_seq_length > bert_config.max_position_embeddings:
718 | raise ValueError(
719 | "Cannot use sequence length %d because the BERT model "
720 | "was only trained up to sequence length %d" %
721 | (FLAGS.max_seq_length, bert_config.max_position_embeddings))
722 |
723 | tf.gfile.MakeDirs(FLAGS.output_dir)
724 |
725 | task_name = FLAGS.task_name.lower()
726 |
727 | if task_name not in processors:
728 | raise ValueError("Task not found: %s" % (task_name))
729 |
730 | processor = processors[task_name]()
731 |
732 | label_list = processor.get_labels()
733 |
734 | tokenizer = tokenization.FullTokenizer(
735 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
736 |
737 | tpu_cluster_resolver = None
738 | if FLAGS.use_tpu and FLAGS.tpu_name:
739 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
740 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
741 |
742 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
743 | run_config = tf.contrib.tpu.RunConfig(
744 | cluster=tpu_cluster_resolver,
745 | master=FLAGS.master,
746 | model_dir=FLAGS.output_dir,
747 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
748 | tpu_config=tf.contrib.tpu.TPUConfig(
749 | iterations_per_loop=FLAGS.iterations_per_loop,
750 | num_shards=FLAGS.num_tpu_cores,
751 | per_host_input_for_training=is_per_host))
752 |
753 | train_examples = None
754 | num_train_steps = None
755 | num_warmup_steps = None
756 | if FLAGS.do_train:
757 | train_examples = processor.get_train_examples(FLAGS.data_dir)
758 | num_train_steps = int(
759 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
760 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
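# Worked example: 100,000 train examples, train_batch_size=32 and 3 epochs
# give num_train_steps = int(100000 / 32 * 3.0) = 9375; with
# warmup_proportion=0.1 that yields num_warmup_steps = 937.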
761 |
762 | model_fn = model_fn_builder(
763 | bert_config=bert_config,
764 | num_labels=len(label_list),
765 | init_checkpoint=FLAGS.init_checkpoint,
766 | learning_rate=FLAGS.learning_rate,
767 | num_train_steps=num_train_steps,
768 | num_warmup_steps=num_warmup_steps,
769 | use_tpu=FLAGS.use_tpu,
770 | use_one_hot_embeddings=FLAGS.use_tpu)
771 |
772 | # If TPU is not available, this will fall back to normal Estimator on CPU
773 | # or GPU.
774 | estimator = tf.contrib.tpu.TPUEstimator(
775 | use_tpu=FLAGS.use_tpu,
776 | model_fn=model_fn,
777 | config=run_config,
778 | train_batch_size=FLAGS.train_batch_size,
779 | eval_batch_size=FLAGS.eval_batch_size)
780 |
781 | if FLAGS.do_train:
782 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
783 | filed_based_convert_examples_to_features(
784 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
785 | tf.logging.info("***** Running training *****")
786 | tf.logging.info(" Num examples = %d", len(train_examples))
787 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
788 | tf.logging.info(" Num steps = %d", num_train_steps)
789 | train_input_fn = file_based_input_fn_builder(
790 | input_file=train_file,
791 | seq_length=FLAGS.max_seq_length,
792 | is_training=True,
793 | drop_remainder=True)
794 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
795 |
796 | if FLAGS.do_eval:
797 | eval_examples = processor.get_dev_examples(FLAGS.data_dir)
798 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
799 | filed_based_convert_examples_to_features(
800 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)
801 |
802 | tf.logging.info("***** Running evaluation *****")
803 | tf.logging.info(" Num examples = %d", len(eval_examples))
804 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
805 |
806 | # This tells the estimator to run through the entire set.
807 | eval_steps = None
808 | # However, if running eval on the TPU, you will need to specify the
809 | # number of steps.
810 | if FLAGS.use_tpu:
811 | # Eval will be slightly WRONG on the TPU because it will truncate
812 | # the last batch.
813 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
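# e.g. 1,003 eval examples with eval_batch_size=8 give eval_steps=125,
# so the trailing 3 examples are silently dropped on the TPU.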
814 |
815 | eval_drop_remainder = True if FLAGS.use_tpu else False
816 | eval_input_fn = file_based_input_fn_builder(
817 | input_file=eval_file,
818 | seq_length=FLAGS.max_seq_length,
819 | is_training=False,
820 | drop_remainder=eval_drop_remainder)
821 |
822 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
823 |
824 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
825 | with tf.gfile.GFile(output_eval_file, "w") as writer:
826 | tf.logging.info("***** Eval results *****")
827 | for key in sorted(result.keys()):
828 | tf.logging.info(" %s = %s", key, str(result[key]))
829 | writer.write("%s = %s\n" % (key, str(result[key])))
830 |
831 |
832 | if __name__ == "__main__":
833 | flags.mark_flag_as_required("data_dir")
834 | flags.mark_flag_as_required("task_name")
835 | flags.mark_flag_as_required("vocab_file")
836 | flags.mark_flag_as_required("bert_config_file")
837 | flags.mark_flag_as_required("output_dir")
838 | tf.app.run()
839 |
--------------------------------------------------------------------------------
/run_classifier_predict_online.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """BERT finetuning runner of classification for online prediction. input is a list. output is a label."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import csv
22 | import os
23 | import modeling
24 | import tokenization
25 | import tensorflow as tf
26 | import numpy as np
27 |
28 | flags = tf.flags
29 |
30 | FLAGS = flags.FLAGS
31 |
32 | ## Required parameters
33 | BERT_BASE_DIR = 'weight/chinese_L-12_H-768_A-12/'
34 | flags.DEFINE_string("bert_config_file", BERT_BASE_DIR + "bert_config.json",
35 | "The config json file corresponding to the pre-trained BERT model. "
36 | "This specifies the model architecture.")
37 |
38 | flags.DEFINE_string("task_name", "sentence_pair", "The name of the task to train.")
39 |
40 | flags.DEFINE_string("vocab_file", BERT_BASE_DIR + "vocab.txt",
41 | "The vocabulary file that the BERT model was trained on.")
42 |
43 | flags.DEFINE_string("init_checkpoint", BERT_BASE_DIR, # model.ckpt-66870--> /model.ckpt-66870
44 | "Initial checkpoint (usually from a pre-trained BERT model).")
45 |
46 | flags.DEFINE_integer("max_seq_length", 512,
47 | "The maximum total input sequence length after WordPiece tokenization. "
48 | "Sequences longer than this will be truncated, and sequences shorter "
49 | "than this will be padded.")
50 |
51 | flags.DEFINE_bool(
52 | "do_lower_case", True,
53 | "Whether to lower case the input text. Should be True for uncased "
54 | "models and False for cased models.")
55 |
56 |
57 | class InputExample(object):
58 | """A single training/test example for simple sequence classification."""
59 |
60 | def __init__(self, guid, text_a, text_b=None, label=None):
61 | """Constructs a InputExample.
62 | Args:
63 | guid: Unique id for the example.
64 | text_a: string. The untokenized text of the first sequence. For single
65 | sequence tasks, only this sequence must be specified.
66 | text_b: (Optional) string. The untokenized text of the second sequence.
67 | Only must be specified for sequence pair tasks.
68 | label: (Optional) string. The label of the example. This should be
69 | specified for train and dev examples, but not for test examples.
70 | """
71 | self.guid = guid
72 | self.text_a = text_a
73 | self.text_b = text_b
74 | self.label = label
75 |
76 |
77 | class InputFeatures(object):
78 | """A single set of features of data."""
79 |
80 | def __init__(self, input_ids, input_mask, segment_ids, label_id):
81 | self.input_ids = input_ids
82 | self.input_mask = input_mask
83 | self.segment_ids = segment_ids
84 | self.label_id = label_id
85 |
86 |
87 | class DataProcessor(object):
88 | """Base class for data converters for sequence classification data sets."""
89 |
90 | def get_train_examples(self, data_dir):
91 | """Gets a collection of `InputExample`s for the train set."""
92 | raise NotImplementedError()
93 |
94 | def get_dev_examples(self, data_dir):
95 | """Gets a collection of `InputExample`s for the dev set."""
96 | raise NotImplementedError()
97 |
98 | def get_test_examples(self, data_dir):
99 | """Gets a collection of `InputExample`s for prediction."""
100 | raise NotImplementedError()
101 |
102 | def get_labels(self):
103 | """Gets the list of labels for this data set."""
104 | raise NotImplementedError()
105 |
106 | @classmethod
107 | def _read_tsv(cls, input_file, quotechar=None):
108 | """Reads a tab separated value file."""
109 | with tf.gfile.Open(input_file, "r") as f:
110 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
111 | lines = []
112 | for line in reader:
113 | lines.append(line)
114 | return lines
115 |
116 |
117 | class SentencePairClassificationProcessor(DataProcessor):
118 | """Processor for the internal data set. sentence pair classification"""
119 |
120 | def __init__(self):
121 | self.language = "zh"
122 |
123 | # def get_train_examples(self, data_dir):
124 | # """See base class."""
125 | # return self._create_examples(
126 | # self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
127 |
128 | # def get_dev_examples(self, data_dir):
129 | # """See base class."""
130 | # return self._create_examples(
131 | # self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
132 |
133 | # def get_test_examples(self, data_dir):
134 | # """See base class."""
135 | # return self._create_examples(
136 | # self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
137 |
138 | def get_labels(self):
139 | """See base class."""
140 | return ["0", "1"]
141 |
142 | # def _create_examples(self, lines, set_type):
143 | """Creates examples for the training and dev sets."""
144 | # examples = []
145 | # for (i, line) in enumerate(lines):
146 | # if i == 0:
147 | # continue
148 | # guid = "%s-%s" % (set_type, i)
149 | # label = tokenization.convert_to_unicode(line[0])
150 | # text_a = tokenization.convert_to_unicode(line[1])
151 | # text_b = tokenization.convert_to_unicode(line[2])
152 | # examples.append(
153 | # InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
154 | # return examples
155 |
156 |
157 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer):
158 | """Converts a single `InputExample` into a single `InputFeatures`."""
159 | label_map = {}
160 | for (i, label) in enumerate(label_list):
161 | label_map[label] = i
162 |
163 | tokens_a = tokenizer.tokenize(example.text_a)
164 | tokens_b = None
165 | if example.text_b:
166 | tokens_b = tokenizer.tokenize(example.text_b)
167 |
168 | if tokens_b:
169 | # Modifies `tokens_a` and `tokens_b` in place so that the total
170 | # length is less than the specified length.
171 | # Account for [CLS], [SEP], [SEP] with "- 3"
172 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
173 | else:
174 | # Account for [CLS] and [SEP] with "- 2"
175 | if len(tokens_a) > max_seq_length - 2:
176 | tokens_a = tokens_a[0:(max_seq_length - 2)]
177 |
178 | # The convention in BERT is:
179 | # (a) For sequence pairs:
180 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
181 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
182 | # (b) For single sequences:
183 | # tokens: [CLS] the dog is hairy . [SEP]
184 | # type_ids: 0 0 0 0 0 0 0
185 | #
186 | # Where "type_ids" are used to indicate whether this is the first
187 | # sequence or the second sequence. The embedding vectors for `type=0` and
188 | # `type=1` were learned during pre-training and are added to the wordpiece
189 | # embedding vector (and position vector). This is not *strictly* necessary
190 | # since the [SEP] token unambiguously separates the sequences, but it makes
191 | # it easier for the model to learn the concept of sequences.
192 | #
193 | # For classification tasks, the first vector (corresponding to [CLS]) is
194 | # used as the "sentence vector". Note that this only makes sense because
195 | # the entire model is fine-tuned.
196 | tokens = []
197 | segment_ids = []
198 | tokens.append("[CLS]")
199 | segment_ids.append(0)
200 | for token in tokens_a:
201 | tokens.append(token)
202 | segment_ids.append(0)
203 | tokens.append("[SEP]")
204 | segment_ids.append(0)
205 |
206 | if tokens_b:
207 | for token in tokens_b:
208 | tokens.append(token)
209 | segment_ids.append(1)
210 | tokens.append("[SEP]")
211 | segment_ids.append(1)
212 |
213 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
214 |
215 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
216 | # tokens are attended to.
217 | input_mask = [1] * len(input_ids)
218 |
219 | # Zero-pad up to the sequence length.
220 | while len(input_ids) < max_seq_length:
221 | input_ids.append(0)
222 | input_mask.append(0)
223 | segment_ids.append(0)
224 |
225 | assert len(input_ids) == max_seq_length
226 | assert len(input_mask) == max_seq_length
227 | assert len(segment_ids) == max_seq_length
228 |
229 | label_id = label_map[example.label]
230 | if ex_index < 5:
231 | tf.logging.info("*** Example ***")
232 | tf.logging.info("guid: %s" % (example.guid))
233 | tf.logging.info("tokens: %s" % " ".join(
234 | [tokenization.printable_text(x) for x in tokens]))
235 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
236 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
237 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
238 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id))
239 |
240 | feature = InputFeatures(
241 | input_ids=input_ids,
242 | input_mask=input_mask,
243 | segment_ids=segment_ids,
244 | label_id=label_id)
245 | return feature
246 |
247 |
248 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
249 | """Truncates a sequence pair in place to the maximum length."""
250 |
251 | # This is a simple heuristic which will always truncate the longer sequence
252 | # one token at a time. This makes more sense than truncating an equal percent
253 | # of tokens from each, since if one sequence is very short then each token
254 | # that's truncated likely contains more information than a longer sequence.
255 | while True:
256 | total_length = len(tokens_a) + len(tokens_b)
257 | if total_length <= max_length:
258 | break
259 | if len(tokens_a) > len(tokens_b):
260 | tokens_a.pop()
261 | else:
262 | tokens_b.pop()
263 |
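# Worked example with max_length=8: if tokens_a has 6 tokens and tokens_b
# has 4 (total 10), the loop pops from tokens_a twice (6 -> 5 -> 4) until
# 4 + 4 <= 8, so the longer side absorbs all of the truncation.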
264 |
265 | def create_int_feature(values):
266 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
267 | return f
268 |
269 |
270 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
271 | labels, num_labels, use_one_hot_embeddings):
272 | """Creates a classification model."""
273 | model = modeling.BertModel(
274 | config=bert_config,
275 | is_training=is_training,
276 | input_ids=input_ids,
277 | input_mask=input_mask,
278 | token_type_ids=segment_ids,
279 | use_one_hot_embeddings=use_one_hot_embeddings)
280 |
281 | # In the demo, we are doing a simple classification task on the entire
282 | # segment.
283 | #
284 | # If you want to use the token-level output, use model.get_sequence_output()
285 | # instead.
286 | output_layer = model.get_pooled_output()
287 |
288 | hidden_size = output_layer.shape[-1].value
289 |
290 | output_weights = tf.get_variable(
291 | "output_weights", [num_labels, hidden_size],
292 | initializer=tf.truncated_normal_initializer(stddev=0.02))
293 |
294 | output_bias = tf.get_variable(
295 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
296 |
297 | with tf.variable_scope("loss"):
298 | if is_training:
299 | # I.e., 0.1 dropout
300 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
301 |
302 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
303 | logits = tf.nn.bias_add(logits, output_bias)
304 | probabilities = tf.nn.softmax(logits, axis=-1)
305 | log_probs = tf.nn.log_softmax(logits, axis=-1)
306 |
307 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
308 |
309 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
310 | loss = tf.reduce_mean(per_example_loss)
311 |
312 | return (loss, per_example_loss, logits, probabilities, model)
313 |
314 |
315 | tf.logging.set_verbosity(tf.logging.INFO)
316 | processors = {
317 | "sentence_pair": SentencePairClassificationProcessor,
318 | }
319 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
320 | task_name = FLAGS.task_name.lower()
321 | print("task_name:", task_name)
322 | processor = processors[task_name]()
323 | label_list = processor.get_labels()
324 | # lines_dev=processor.get_dev_examples("./TEXT_DIR")
325 | index2label = {i: label_list[i] for i in range(len(label_list))}
326 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
327 |
328 |
329 | def main(_):
330 | pass
331 |
332 |
333 | # Initialize the model and session.
334 | # This code lives at module level so it runs only once; predict_online can then be invoked repeatedly without re-initializing.
335 | is_training = False
336 | use_one_hot_embeddings = False
337 | batch_size = 1
338 | num_labels = len(label_list)
339 | gpu_config = tf.ConfigProto()
340 | gpu_config.gpu_options.allow_growth = True
341 | sess = tf.Session(config=gpu_config)
342 | model = None
343 | global graph
344 | input_ids_p, input_mask_p, label_ids_p, segment_ids_p = None, None, None, None
345 | if not os.path.exists(FLAGS.init_checkpoint + "checkpoint"):
346 | raise Exception("failed to get checkpoint. going to return ")
347 |
348 | graph = tf.get_default_graph()
349 | with graph.as_default():
350 | print("going to restore checkpoint")
351 | # sess.run(tf.global_variables_initializer())
352 | input_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_ids")
353 | input_mask_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="input_mask")
354 | label_ids_p = tf.placeholder(tf.int32, [batch_size], name="label_ids")
355 | segment_ids_p = tf.placeholder(tf.int32, [batch_size, FLAGS.max_seq_length], name="segment_ids")  # rank 2, matching the other inputs
356 | total_loss, per_example_loss, logits, probabilities, model = create_model(
357 | bert_config, is_training, input_ids_p, input_mask_p, segment_ids_p,
358 | label_ids_p, num_labels, use_one_hot_embeddings)
359 | saver = tf.train.Saver()
360 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.init_checkpoint))
361 |
362 |
363 | def predict_online(line):
364 | """
365 | do online prediction. each time make prediction for one instance.
366 | you can change to a batch if you want.
367 | :param line: a list. element is: [dummy_label,text_a,text_b]
368 | :return:
369 | """
370 | label = line[0]  # tokenization.convert_to_unicode(line[0])
371 | # The label must be compatible with the format defined in the processor.
372 | text_a = line[1] # tokenization.convert_to_unicode(line[1])
373 | text_b = line[2] # tokenization.convert_to_unicode(line[2])
374 | example = InputExample(guid=0, text_a=text_a, text_b=text_b, label=label)
375 | feature = convert_single_example(0, example, label_list, FLAGS.max_seq_length, tokenizer)
376 | input_ids = np.reshape([feature.input_ids], (1, FLAGS.max_seq_length))
377 | input_mask = np.reshape([feature.input_mask], (1, FLAGS.max_seq_length))
378 | segment_ids = np.reshape([feature.segment_ids], (1, FLAGS.max_seq_length))
379 | label_ids = [feature.label_id]
380 |
381 | global graph
382 | with graph.as_default():
383 | feed_dict = {input_ids_p: input_ids, input_mask_p: input_mask, segment_ids_p: segment_ids,
384 | label_ids_p: label_ids}
385 | possibility = sess.run([probabilities], feed_dict)
386 | possibility = possibility[0][0]  # probability vector for the single example in the batch
387 | label_index = np.argmax(possibility)
388 | label_predict = index2label[label_index]
389 | # print("label_predict:",label_predict,";possibility:",possibility)
390 | return label_predict, possibility
391 |
392 |
393 | if __name__ == "__main__":
394 | example = ['0',
395 | '\u5165\u804c\u4e00\u5e74\u534a\u672a\u7b7e\u52b3\u52a8\u5408\u540c\u5c0f\u83f2\u6bd5\u4e1a\u4e8e\u67d0\u62a4\u6821\uff0c\u548c\u5176\u4ed6\u7684\u9ad8\u6821\u6bd5\u4e1a\u751f\u4e00\u6837\uff0c\u5979\u4e5f\u5f00\u59cb\u7740\u624b\u627e\u5de5\u4f5c\u3002\u5f88\u5feb\uff0c\u4e00\u5bb6\u6c11\u529e\u533b\u9662\u901a\u8fc7\u67d0\u62db\u8058\u7f51\u7ad9\u627e\u5230\u5c0f\u83f2\uff0c\u901a\u8fc7\u9762\u8bd5\u540e\uff0c\u5c0f\u83f2\u4fbf\u5f00\u59cb\u4e86\u81ea\u5df1\u7684\u804c\u573a\u751f\u6daf\u3002\u8f6c\u773c\u6bd5\u4e1a\u5de5\u4f5c\u8fd1\u4e00\u5e74\uff0c\u533b\u9662\u4ecd\u8fdf\u8fdf\u4e0d\u4e0e\u5176\u7b7e\u8ba2\u52b3\u52a8\u5408\u540c\uff0c\u5c0f\u83f2\u4e0e\u5355\u4f4d\u591a\u6b21\u6c9f\u901a\u534f\u5546\u672a\u679c\uff0c\u65e0\u5948\u5c06\u533b\u9662\u8bc9\u81f3\u6cd5\u9662\u000d\u000a',
396 | '\u652f\u4ed8\u5de5\u8d44']
397 | result = predict_online(example)
398 | print("result:", result)
399 |
--------------------------------------------------------------------------------
/run_pretraining.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Run masked LM/next sentence masked_lm pre-training for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import modeling
23 | import optimization
24 | import tensorflow as tf
25 |
26 | flags = tf.flags
27 |
28 | FLAGS = flags.FLAGS
29 |
30 | ## Required parameters
31 | flags.DEFINE_string(
32 | "bert_config_file", None,
33 | "The config json file corresponding to the pre-trained BERT model. "
34 | "This specifies the model architecture.")
35 |
36 | flags.DEFINE_string(
37 | "input_file", None,
38 | "Input TF example files (can be a glob or comma separated).")
39 |
40 | flags.DEFINE_string(
41 | "output_dir", None,
42 | "The output directory where the model checkpoints will be written.")
43 |
44 | ## Other parameters
45 | flags.DEFINE_string(
46 | "init_checkpoint", None,
47 | "Initial checkpoint (usually from a pre-trained BERT model).")
48 |
49 | flags.DEFINE_integer(
50 | "max_seq_length", 128,
51 | "The maximum total input sequence length after WordPiece tokenization. "
52 | "Sequences longer than this will be truncated, and sequences shorter "
53 | "than this will be padded. Must match data generation.")
54 |
55 | flags.DEFINE_integer(
56 | "max_predictions_per_seq", 20,
57 | "Maximum number of masked LM predictions per sequence. "
58 | "Must match data generation.")
59 |
60 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
61 |
62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
63 |
64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
65 |
66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
67 |
68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
69 |
70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
71 |
72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
73 |
74 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
75 | "How often to save the model checkpoint.")
76 |
77 | flags.DEFINE_integer("iterations_per_loop", 1000,
78 | "How many steps to make in each estimator call.")
79 |
80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
81 |
82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
83 |
84 | tf.flags.DEFINE_string(
85 | "tpu_name", None,
86 | "The Cloud TPU to use for training. This should be either the name "
87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
88 | "url.")
89 |
90 | tf.flags.DEFINE_string(
91 | "tpu_zone", None,
92 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
93 | "specified, we will attempt to automatically detect the GCE project from "
94 | "metadata.")
95 |
96 | tf.flags.DEFINE_string(
97 | "gcp_project", None,
98 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
99 | "specified, we will attempt to automatically detect the GCE project from "
100 | "metadata.")
101 |
102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
103 |
104 | flags.DEFINE_integer(
105 | "num_tpu_cores", 8,
106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
107 |
108 |
109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate,
110 | num_train_steps, num_warmup_steps, use_tpu,
111 | use_one_hot_embeddings):
112 | """Returns `model_fn` closure for TPUEstimator."""
113 |
114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
115 | """The `model_fn` for TPUEstimator."""
116 |
117 | tf.logging.info("*** Features ***")
118 | for name in sorted(features.keys()):
119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
120 |
121 | input_ids = features["input_ids"]
122 | input_mask = features["input_mask"]
123 | segment_ids = features["segment_ids"]
124 | masked_lm_positions = features["masked_lm_positions"]
125 | masked_lm_ids = features["masked_lm_ids"]
126 | masked_lm_weights = features["masked_lm_weights"]
127 | next_sentence_labels = features["next_sentence_labels"]
128 |
129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
130 |
131 | model = modeling.BertModel(
132 | config=bert_config,
133 | is_training=is_training,
134 | input_ids=input_ids,
135 | input_mask=input_mask,
136 | token_type_ids=segment_ids,
137 | use_one_hot_embeddings=use_one_hot_embeddings)
138 |
139 | (masked_lm_loss,
140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
141 | bert_config, model.get_sequence_output(), model.get_embedding_table(),
142 | masked_lm_positions, masked_lm_ids, masked_lm_weights)
143 |
144 | (next_sentence_loss, next_sentence_example_loss,
145 | next_sentence_log_probs) = get_next_sentence_output(
146 | bert_config, model.get_pooled_output(), next_sentence_labels)
147 |
148 | total_loss = masked_lm_loss + next_sentence_loss
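# The pre-training objective is simply the unweighted sum of the masked-LM
# loss and the next-sentence-prediction loss.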
149 |
150 | tvars = tf.trainable_variables()
151 |
152 | initialized_variable_names = {}
153 | scaffold_fn = None
154 | if init_checkpoint:
155 | (assignment_map, initialized_variable_names
156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
157 | if use_tpu:
158 |
159 | def tpu_scaffold():
160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
161 | return tf.train.Scaffold()
162 |
163 | scaffold_fn = tpu_scaffold
164 | else:
165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
166 |
167 | tf.logging.info("**** Trainable Variables ****")
168 | for var in tvars:
169 | init_string = ""
170 | if var.name in initialized_variable_names:
171 | init_string = ", *INIT_FROM_CKPT*"
172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
173 | init_string)
174 |
175 | output_spec = None
176 | if mode == tf.estimator.ModeKeys.TRAIN:
177 | train_op = optimization.create_optimizer(
178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
179 |
180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
181 | mode=mode,
182 | loss=total_loss,
183 | train_op=train_op,
184 | scaffold_fn=scaffold_fn)
185 | elif mode == tf.estimator.ModeKeys.EVAL:
186 |
187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
188 | masked_lm_weights, next_sentence_example_loss,
189 | next_sentence_log_probs, next_sentence_labels):
190 | """Computes the loss and accuracy of the model."""
191 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
192 | [-1, masked_lm_log_probs.shape[-1]])
193 | masked_lm_predictions = tf.argmax(
194 | masked_lm_log_probs, axis=-1, output_type=tf.int32)
195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
197 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
198 | masked_lm_accuracy = tf.metrics.accuracy(
199 | labels=masked_lm_ids,
200 | predictions=masked_lm_predictions,
201 | weights=masked_lm_weights)
202 | masked_lm_mean_loss = tf.metrics.mean(
203 | values=masked_lm_example_loss, weights=masked_lm_weights)
204 |
205 | next_sentence_log_probs = tf.reshape(
206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
207 | next_sentence_predictions = tf.argmax(
208 | next_sentence_log_probs, axis=-1, output_type=tf.int32)
209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
210 | next_sentence_accuracy = tf.metrics.accuracy(
211 | labels=next_sentence_labels, predictions=next_sentence_predictions)
212 | next_sentence_mean_loss = tf.metrics.mean(
213 | values=next_sentence_example_loss)
214 |
215 | return {
216 | "masked_lm_accuracy": masked_lm_accuracy,
217 | "masked_lm_loss": masked_lm_mean_loss,
218 | "next_sentence_accuracy": next_sentence_accuracy,
219 | "next_sentence_loss": next_sentence_mean_loss,
220 | }
221 |
222 | eval_metrics = (metric_fn, [
223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
224 | masked_lm_weights, next_sentence_example_loss,
225 | next_sentence_log_probs, next_sentence_labels
226 | ])
227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
228 | mode=mode,
229 | loss=total_loss,
230 | eval_metrics=eval_metrics,
231 | scaffold_fn=scaffold_fn)
232 | else:
233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
234 |
235 | return output_spec
236 |
237 | return model_fn
238 |
239 |
240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
241 | label_ids, label_weights):
242 | """Get loss and log probs for the masked LM."""
243 | input_tensor = gather_indexes(input_tensor, positions)
244 |
245 | with tf.variable_scope("cls/predictions"):
246 | # We apply one more non-linear transformation before the output layer.
247 | # This matrix is not used after pre-training.
248 | with tf.variable_scope("transform"):
249 | input_tensor = tf.layers.dense(
250 | input_tensor,
251 | units=bert_config.hidden_size,
252 | activation=modeling.get_activation(bert_config.hidden_act),
253 | kernel_initializer=modeling.create_initializer(
254 | bert_config.initializer_range))
255 | input_tensor = modeling.layer_norm(input_tensor)
256 |
257 | # The output weights are the same as the input embeddings, but there is
258 | # an output-only bias for each token.
259 | output_bias = tf.get_variable(
260 | "output_bias",
261 | shape=[bert_config.vocab_size],
262 | initializer=tf.zeros_initializer())
263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
264 | logits = tf.nn.bias_add(logits, output_bias)
265 | log_probs = tf.nn.log_softmax(logits, axis=-1)
266 |
267 | label_ids = tf.reshape(label_ids, [-1])
268 | label_weights = tf.reshape(label_weights, [-1])
269 |
270 | one_hot_labels = tf.one_hot(
271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
272 |
273 | # The `positions` tensor might be zero-padded (if the sequence is too
274 | # short to have the maximum number of predictions). The `label_weights`
275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the
276 | # padding predictions.
277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
278 | numerator = tf.reduce_sum(label_weights * per_example_loss)
279 | denominator = tf.reduce_sum(label_weights) + 1e-5
280 | loss = numerator / denominator
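# Worked example: with max_predictions_per_seq=20 but only 15 real masked
# positions, label_weights holds 15 ones and 5 zeros, so the loss averages
# over the 15 real predictions; the 1e-5 guards against an all-padding batch.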
281 |
282 | return (loss, per_example_loss, log_probs)
283 |
284 |
285 | def get_next_sentence_output(bert_config, input_tensor, labels):
286 | """Get loss and log probs for the next sentence prediction."""
287 |
288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is
289 | # "random sentence". This weight matrix is not used after pre-training.
290 | with tf.variable_scope("cls/seq_relationship"):
291 | output_weights = tf.get_variable(
292 | "output_weights",
293 | shape=[2, bert_config.hidden_size],
294 | initializer=modeling.create_initializer(bert_config.initializer_range))
295 | output_bias = tf.get_variable(
296 | "output_bias", shape=[2], initializer=tf.zeros_initializer())
297 |
298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
299 | logits = tf.nn.bias_add(logits, output_bias)
300 | log_probs = tf.nn.log_softmax(logits, axis=-1)
301 | labels = tf.reshape(labels, [-1])
302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
304 | loss = tf.reduce_mean(per_example_loss)
305 | return (loss, per_example_loss, log_probs)
306 |
307 |
308 | def gather_indexes(sequence_tensor, positions):
309 | """Gathers the vectors at the specific positions over a minibatch."""
310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
311 | batch_size = sequence_shape[0]
312 | seq_length = sequence_shape[1]
313 | width = sequence_shape[2]
314 |
315 | flat_offsets = tf.reshape(
316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
317 | flat_positions = tf.reshape(positions + flat_offsets, [-1])
318 | flat_sequence_tensor = tf.reshape(sequence_tensor,
319 | [batch_size * seq_length, width])
320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
321 | return output_tensor
322 |
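# Worked example: batch_size=2, seq_length=4, positions=[[1, 3], [0, 2]].
# flat_offsets=[[0], [4]], so flat_positions=[1, 3, 4, 6], which pick rows
# 1 and 3 of example 0 and rows 0 and 2 of example 1 from the flattened
# [batch_size * seq_length, width] tensor.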
323 |
324 | def input_fn_builder(input_files,
325 | max_seq_length,
326 | max_predictions_per_seq,
327 | is_training,
328 | num_cpu_threads=4):
329 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
330 |
331 | def input_fn(params):
332 | """The actual input function."""
333 | batch_size = params["batch_size"]
334 |
335 | name_to_features = {
336 | "input_ids":
337 | tf.FixedLenFeature([max_seq_length], tf.int64),
338 | "input_mask":
339 | tf.FixedLenFeature([max_seq_length], tf.int64),
340 | "segment_ids":
341 | tf.FixedLenFeature([max_seq_length], tf.int64),
342 | "masked_lm_positions":
343 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
344 | "masked_lm_ids":
345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
346 | "masked_lm_weights":
347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
348 | "next_sentence_labels":
349 | tf.FixedLenFeature([1], tf.int64),
350 | }
351 |
352 | # For training, we want a lot of parallel reading and shuffling.
353 | # For eval, we want no shuffling and parallel reading doesn't matter.
354 | if is_training:
355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
356 | d = d.repeat()
357 | d = d.shuffle(buffer_size=len(input_files))
358 |
359 | # `cycle_length` is the number of parallel files that get read.
360 | cycle_length = min(num_cpu_threads, len(input_files))
361 |
362 | # `sloppy` mode means that the interleaving is not exact. This adds
363 | # even more randomness to the training pipeline.
364 | d = d.apply(
365 | tf.contrib.data.parallel_interleave(
366 | tf.data.TFRecordDataset,
367 | sloppy=is_training,
368 | cycle_length=cycle_length))
369 | d = d.shuffle(buffer_size=100)
370 | else:
371 | d = tf.data.TFRecordDataset(input_files)
372 | # Since we evaluate for a fixed number of steps we don't want to encounter
373 | # out-of-range exceptions.
374 | d = d.repeat()
375 |
376 | # We must `drop_remainder` on training because the TPU requires fixed
377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
378 | # and we *don't* want to drop the remainder; otherwise we won't cover
379 | # every sample.
380 | d = d.apply(
381 | tf.contrib.data.map_and_batch(
382 | lambda record: _decode_record(record, name_to_features),
383 | batch_size=batch_size,
384 | num_parallel_batches=num_cpu_threads,
385 | drop_remainder=True))
386 | return d
387 |
388 | return input_fn
389 |
390 |
391 | def _decode_record(record, name_to_features):
392 | """Decodes a record to a TensorFlow example."""
393 | example = tf.parse_single_example(record, name_to_features)
394 |
395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
396 | # So cast all int64 to int32.
397 | for name in list(example.keys()):
398 | t = example[name]
399 | if t.dtype == tf.int64:
400 | t = tf.to_int32(t)
401 | example[name] = t
402 |
403 | return example
404 |
405 |
406 | def main(_):
407 | tf.logging.set_verbosity(tf.logging.INFO)
408 |
409 | if not FLAGS.do_train and not FLAGS.do_eval:
410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
411 |
412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
413 |
414 | tf.gfile.MakeDirs(FLAGS.output_dir)
415 |
416 | input_files = []
417 | for input_pattern in FLAGS.input_file.split(","):
418 | input_files.extend(tf.gfile.Glob(input_pattern))
419 |
420 | tf.logging.info("*** Input Files ***")
421 | for input_file in input_files:
422 | tf.logging.info(" %s" % input_file)
423 |
424 | tpu_cluster_resolver = None
425 | if FLAGS.use_tpu and FLAGS.tpu_name:
426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
428 |
429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
430 | run_config = tf.contrib.tpu.RunConfig(
431 | cluster=tpu_cluster_resolver,
432 | master=FLAGS.master,
433 | model_dir=FLAGS.output_dir,
434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
435 | tpu_config=tf.contrib.tpu.TPUConfig(
436 | iterations_per_loop=FLAGS.iterations_per_loop,
437 | num_shards=FLAGS.num_tpu_cores,
438 | per_host_input_for_training=is_per_host))
439 |
440 | model_fn = model_fn_builder(
441 | bert_config=bert_config,
442 | init_checkpoint=FLAGS.init_checkpoint,
443 | learning_rate=FLAGS.learning_rate,
444 | num_train_steps=FLAGS.num_train_steps,
445 | num_warmup_steps=FLAGS.num_warmup_steps,
446 | use_tpu=FLAGS.use_tpu,
447 | use_one_hot_embeddings=FLAGS.use_tpu)
448 |
449 | # If TPU is not available, this will fall back to normal Estimator on CPU
450 | # or GPU.
451 | estimator = tf.contrib.tpu.TPUEstimator(
452 | use_tpu=FLAGS.use_tpu,
453 | model_fn=model_fn,
454 | config=run_config,
455 | train_batch_size=FLAGS.train_batch_size,
456 | eval_batch_size=FLAGS.eval_batch_size)
457 |
458 | if FLAGS.do_train:
459 | tf.logging.info("***** Running training *****")
460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
461 | train_input_fn = input_fn_builder(
462 | input_files=input_files,
463 | max_seq_length=FLAGS.max_seq_length,
464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq,
465 | is_training=True)
466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
467 |
468 | if FLAGS.do_eval:
469 | tf.logging.info("***** Running evaluation *****")
470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
471 |
472 | eval_input_fn = input_fn_builder(
473 | input_files=input_files,
474 | max_seq_length=FLAGS.max_seq_length,
475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq,
476 | is_training=False)
477 |
478 | result = estimator.evaluate(
479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
480 |
481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
482 | with tf.gfile.GFile(output_eval_file, "w") as writer:
483 | tf.logging.info("***** Eval results *****")
484 | for key in sorted(result.keys()):
485 | tf.logging.info(" %s = %s", key, str(result[key]))
486 | writer.write("%s = %s\n" % (key, str(result[key])))
487 |
488 |
489 | if __name__ == "__main__":
490 | flags.mark_flag_as_required("input_file")
491 | flags.mark_flag_as_required("bert_config_file")
492 | flags.mark_flag_as_required("output_dir")
493 | tf.app.run()
494 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 |
3 |
4 | """Script to illustrate usage of tf.estimator.Estimator in TF v1.3"""
5 |
6 | import tensorflow as tf
7 |
8 |
9 | from tensorflow.examples.tutorials.mnist import input_data as mnist_data
10 |
11 | from tensorflow.contrib import slim
12 |
13 | from tensorflow.contrib.learn import ModeKeys
14 |
15 | from tensorflow.contrib.learn import learn_runner
16 |
17 |
18 |
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import unicodedata
23 | import six
24 | import tensorflow as tf
25 |
26 |
27 | def convert_to_unicode(text):
28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
29 | if six.PY3:
30 | if isinstance(text, str):
31 | return text
32 | elif isinstance(text, bytes):
33 | return text.decode("utf-8", "ignore")
34 | else:
35 | raise ValueError("Unsupported string type: %s" % (type(text)))
36 | elif six.PY2:
37 | if isinstance(text, str):
38 | return text.decode("utf-8", "ignore")
39 | elif isinstance(text, unicode):
40 | return text
41 | else:
42 | raise ValueError("Unsupported string type: %s" % (type(text)))
43 | else:
44 | raise ValueError("Not running on Python2 or Python 3?")
45 |
46 |
47 | def printable_text(text):
48 | """Returns text encoded in a way suitable for print or `tf.logging`."""
49 |
50 | # These functions want `str` for both Python2 and Python3, but in one case
51 | # it's a Unicode string and in the other it's a byte string.
52 | if six.PY3:
53 | if isinstance(text, str):
54 | return text
55 | elif isinstance(text, bytes):
56 | return text.decode("utf-8", "ignore")
57 | else:
58 | raise ValueError("Unsupported string type: %s" % (type(text)))
59 | elif six.PY2:
60 | if isinstance(text, str):
61 | return text
62 | elif isinstance(text, unicode):
63 | return text.encode("utf-8")
64 | else:
65 | raise ValueError("Unsupported string type: %s" % (type(text)))
66 | else:
67 | raise ValueError("Not running on Python2 or Python 3?")
68 |
69 |
70 | def load_vocab(vocab_file):
71 | """Loads a vocabulary file into a dictionary."""
72 | vocab = collections.OrderedDict()
73 | index = 0
74 | with tf.gfile.GFile(vocab_file, "r") as reader:
75 | while True:
76 | token = convert_to_unicode(reader.readline())
77 | if not token:
78 | break
79 | token = token.strip()
80 | vocab[token] = index
81 | index += 1
82 | return vocab
83 |
84 |
85 | def convert_tokens_to_ids(vocab, tokens, unk_token="[UNK]"):
86 | """Converts a sequence of tokens into ids using the vocab."""
87 | ids = []
88 | for token in tokens:
89 | if token in vocab:
90 | ids.append(vocab[token])
91 | else:
92 | ids.append(vocab[unk_token])
93 | return ids
94 |
95 |
96 | def whitespace_tokenize(text):
97 | """Runs basic whitespace cleaning and splitting on a peice of text."""
98 | text = text.strip()
99 | if not text:
100 | return []
101 | tokens = text.split()
102 | return tokens
103 |
104 |
105 | class FullTokenizer(object):
106 | """Runs end-to-end tokenziation."""
107 |
108 | def __init__(self, vocab_file, do_lower_case=True):
109 | self.vocab = load_vocab(vocab_file)
110 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
111 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
112 |
113 | def tokenize(self, text):
114 | split_tokens = []
115 | for token in self.basic_tokenizer.tokenize(text):
116 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
117 | split_tokens.append(sub_token)
118 |
119 | return split_tokens
120 |
121 | def convert_tokens_to_ids(self, tokens):
122 | return convert_tokens_to_ids(self.vocab, tokens)
123 |
124 |
125 | class CharTokenizer(object):
126 | """Runs end-to-end tokenziation."""
127 |
128 | def __init__(self, vocab_file, do_lower_case=True):
129 | self.vocab = load_vocab(vocab_file)
130 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
131 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
132 |
133 | def tokenize(self, text):
134 | split_tokens = []
135 | for token in self.basic_tokenizer.tokenize(text):
136 | for sub_token in token:
137 | split_tokens.append(sub_token)
138 |
139 | return split_tokens
140 |
141 | def convert_tokens_to_ids(self, tokens):
142 | return convert_tokens_to_ids(self.vocab, tokens)
143 |
144 |
145 | class BasicTokenizer(object):
146 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
147 |
148 | def __init__(self, do_lower_case=True):
149 | """Constructs a BasicTokenizer.
150 |
151 | Args:
152 | do_lower_case: Whether to lower case the input.
153 | """
154 | self.do_lower_case = do_lower_case
155 |
156 | def tokenize(self, text):
157 | """Tokenizes a piece of text."""
158 | text = convert_to_unicode(text)
159 | text = self._clean_text(text)
160 |
161 | # This was added on November 1st, 2018 for the multilingual and Chinese
162 | # models. This is also applied to the English models now, but it doesn't
163 | # matter since the English models were not trained on any Chinese data
164 | # and generally don't have any Chinese data in them (there are Chinese
165 | # characters in the vocabulary because Wikipedia does have some Chinese
166 | # words in the English Wikipedia).
167 | text = self._tokenize_chinese_chars(text)
168 |
169 | orig_tokens = whitespace_tokenize(text)
170 | split_tokens = []
171 | for token in orig_tokens:
172 | if self.do_lower_case:
173 | token = token.lower()
174 | token = self._run_strip_accents(token)
175 | split_tokens.extend(self._run_split_on_punc(token))
176 |
177 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
178 | return output_tokens
179 |
180 | def _run_strip_accents(self, text):
181 | """Strips accents from a piece of text."""
182 | text = unicodedata.normalize("NFD", text)
183 | output = []
184 | for char in text:
185 | cat = unicodedata.category(char)
186 | if cat == "Mn":
187 | continue
188 | output.append(char)
189 | return "".join(output)
190 |
191 | def _run_split_on_punc(self, text):
192 | """Splits punctuation on a piece of text."""
193 | chars = list(text)
194 | i = 0
195 | start_new_word = True
196 | output = []
197 | while i < len(chars):
198 | char = chars[i]
199 | if _is_punctuation(char):
200 | output.append([char])
201 | start_new_word = True
202 | else:
203 | if start_new_word:
204 | output.append([])
205 | start_new_word = False
206 | output[-1].append(char)
207 | i += 1
208 |
209 | return ["".join(x) for x in output]
210 |
211 | def _tokenize_chinese_chars(self, text):
212 | """Adds whitespace around any CJK character."""
213 | output = []
214 | for char in text:
215 | cp = ord(char)
216 | if self._is_chinese_char(cp):
217 | output.append(" ")
218 | output.append(char)
219 | output.append(" ")
220 | else:
221 | output.append(char)
222 | return "".join(output)
223 |
224 | def _is_chinese_char(self, cp):
225 | """Checks whether CP is the codepoint of a CJK character."""
226 | # This defines a "chinese character" as anything in the CJK Unicode block:
227 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
228 | #
229 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
230 | # despite its name. The modern Korean Hangul alphabet is a different block,
231 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
232 | # space-separated words, so they are not treated specially and handled
233 | # like all of the other languages.
234 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
235 | (cp >= 0x3400 and cp <= 0x4DBF) or #
236 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
237 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
238 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
239 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
240 | (cp >= 0xF900 and cp <= 0xFAFF) or #
241 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
242 | return True
243 |
244 | return False
245 |
246 | def _clean_text(self, text):
247 | """Performs invalid character removal and whitespace cleanup on text."""
248 | output = []
249 | for char in text:
250 | cp = ord(char)
251 | if cp == 0 or cp == 0xfffd or _is_control(char):
252 | continue
253 | if _is_whitespace(char):
254 | output.append(" ")
255 | else:
256 | output.append(char)
257 | return "".join(output)
258 |
259 |
260 | class WordpieceTokenizer(object):
261 | """Runs WordPiece tokenziation."""
262 |
263 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
264 | self.vocab = vocab
265 | self.unk_token = unk_token
266 | self.max_input_chars_per_word = max_input_chars_per_word
267 |
268 | def tokenize(self, text):
269 | """Tokenizes a piece of text into its word pieces.
270 |
271 | This uses a greedy longest-match-first algorithm to perform tokenization
272 | using the given vocabulary.
273 |
274 | For example:
275 | input = "unaffable"
276 | output = ["un", "##aff", "##able"]
277 |
278 | Args:
279 | text: A single token or whitespace separated tokens. This should have
280 | already been passed through `BasicTokenizer`.
281 |
282 | Returns:
283 | A list of wordpiece tokens.
284 | """
285 |
286 | text = convert_to_unicode(text)
287 |
288 | output_tokens = []
289 | for token in whitespace_tokenize(text):
290 | chars = list(token)
291 | if len(chars) > self.max_input_chars_per_word:
292 | output_tokens.append(self.unk_token)
293 | continue
294 |
295 | is_bad = False
296 | start = 0
297 | sub_tokens = []
298 | while start < len(chars):
299 | end = len(chars)
300 | cur_substr = None
301 | while start < end:
302 | substr = "".join(chars[start:end])
303 | if start > 0:
304 | substr = "##" + substr
305 | if substr in self.vocab:
306 | cur_substr = substr
307 | break
308 | end -= 1
309 | if cur_substr is None:
310 | is_bad = True
311 | break
312 | sub_tokens.append(cur_substr)
313 | start = end
314 |
315 | if is_bad:
316 | output_tokens.append(self.unk_token)
317 | else:
318 | output_tokens.extend(sub_tokens)
319 | return output_tokens
320 |
321 |
322 | def _is_whitespace(char):
323 | """Checks whether `chars` is a whitespace character."""
324 | # \t, \n, and \r are technically control characters but we treat them
325 | # as whitespace since they are generally considered as such.
326 | if char == " " or char == "\t" or char == "\n" or char == "\r":
327 | return True
328 | cat = unicodedata.category(char)
329 | if cat == "Zs":
330 | return True
331 | return False
332 |
333 |
334 | def _is_control(char):
335 | """Checks whether `chars` is a control character."""
336 | # These are technically control characters but we count them as whitespace
337 | # characters.
338 | if char == "\t" or char == "\n" or char == "\r":
339 | return False
340 | cat = unicodedata.category(char)
341 | if cat.startswith("C"):
342 | return True
343 | return False
344 |
345 |
346 | def _is_punctuation(char):
347 | """Checks whether `chars` is a punctuation character."""
348 | cp = ord(char)
349 | # We treat all non-letter/number ASCII as punctuation.
350 | # Characters such as "^", "$", and "`" are not in the Unicode
351 | # Punctuation class but we treat them as punctuation anyways, for
352 | # consistency.
353 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
354 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
355 | return True
356 | cat = unicodedata.category(char)
357 | if cat.startswith("P"):
358 | return True
359 | return False
360 |
--------------------------------------------------------------------------------
/tokenization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 | import tempfile
21 |
22 | import tokenization
23 | import tensorflow as tf
24 |
25 |
26 | class TokenizationTest(tf.test.TestCase):
27 |
28 | def test_full_tokenizer(self):
29 | vocab_tokens = [
30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
31 | "##ing", ","
32 | ]
33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens]).encode("utf-8"))
35 |
36 | vocab_file = vocab_writer.name
37 |
38 | tokenizer = tokenization.FullTokenizer(vocab_file)
39 | os.unlink(vocab_file)
40 |
41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
42 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
43 |
44 | self.assertAllEqual(
45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
46 |
47 | def test_chinese(self):
48 | tokenizer = tokenization.BasicTokenizer()
49 |
50 | self.assertAllEqual(
51 | tokenizer.tokenize(u"ah\u535A\u63A8zz"),
52 | [u"ah", u"\u535A", u"\u63A8", u"zz"])
53 |
54 | def test_basic_tokenizer_lower(self):
55 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
56 |
57 | self.assertAllEqual(
58 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
59 | ["hello", "!", "how", "are", "you", "?"])
60 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
61 |
62 | def test_basic_tokenizer_no_lower(self):
63 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False)
64 |
65 | self.assertAllEqual(
66 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
67 | ["HeLLo", "!", "how", "Are", "yoU", "?"])
68 |
69 | def test_wordpiece_tokenizer(self):
70 | vocab_tokens = [
71 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
72 | "##ing"
73 | ]
74 |
75 | vocab = {}
76 | for (i, token) in enumerate(vocab_tokens):
77 | vocab[token] = i
78 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)
79 |
80 | self.assertAllEqual(tokenizer.tokenize(""), [])
81 |
82 | self.assertAllEqual(
83 | tokenizer.tokenize("unwanted running"),
84 | ["un", "##want", "##ed", "runn", "##ing"])
85 |
86 | self.assertAllEqual(
87 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
88 |
89 | def test_convert_tokens_to_ids(self):
90 | vocab_tokens = [
91 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
92 | "##ing"
93 | ]
94 |
95 | vocab = {}
96 | for (i, token) in enumerate(vocab_tokens):
97 | vocab[token] = i
98 |
99 | self.assertAllEqual(
100 | tokenization.convert_tokens_to_ids(
101 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
102 |
103 | def test_is_whitespace(self):
104 | self.assertTrue(tokenization._is_whitespace(u" "))
105 | self.assertTrue(tokenization._is_whitespace(u"\t"))
106 | self.assertTrue(tokenization._is_whitespace(u"\r"))
107 | self.assertTrue(tokenization._is_whitespace(u"\n"))
108 | self.assertTrue(tokenization._is_whitespace(u"\u00A0"))
109 |
110 | self.assertFalse(tokenization._is_whitespace(u"A"))
111 | self.assertFalse(tokenization._is_whitespace(u"-"))
112 |
113 | def test_is_control(self):
114 | self.assertTrue(tokenization._is_control(u"\u0005"))
115 |
116 | self.assertFalse(tokenization._is_control(u"A"))
117 | self.assertFalse(tokenization._is_control(u" "))
118 | self.assertFalse(tokenization._is_control(u"\t"))
119 | self.assertFalse(tokenization._is_control(u"\r"))
120 |
121 | def test_is_punctuation(self):
122 | self.assertTrue(tokenization._is_punctuation(u"-"))
123 | self.assertTrue(tokenization._is_punctuation(u"$"))
124 | self.assertTrue(tokenization._is_punctuation(u"`"))
125 | self.assertTrue(tokenization._is_punctuation(u"."))
126 |
127 | self.assertFalse(tokenization._is_punctuation(u"A"))
128 | self.assertFalse(tokenization._is_punctuation(u" "))
129 |
130 |
131 | if __name__ == "__main__":
132 | tf.test.main()
133 |
--------------------------------------------------------------------------------
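The _is_whitespace / _is_control / _is_punctuation helpers, and the tests
above, hinge on unicodedata.category. A quick illustrative check (standard
library only; the specific characters mirror the test cases):

    import unicodedata

    # u"\u00A0" (no-break space) is category "Zs", so _is_whitespace is True.
    print(unicodedata.category(u"\u00A0"))  # Zs

    # "$" is "Sc" (currency symbol), not "P*", yet _is_punctuation returns
    # True because code point 36 falls in the ASCII range 33-47.
    print(unicodedata.category(u"$"))  # Sc

    # "\t" is "Cc" like u"\u0005", but _is_control special-cases it as
    # whitespace and returns False.
    print(unicodedata.category(u"\t"))  # Cc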