├── output
│   └── README.txt
├── image
│   ├── UbuntuV1_V2.png
│   └── Douban_Ecommerce.png
├── uncased_L-12_H-768_A-12
│   └── README.txt
├── scripts
│   ├── ubuntu_test.sh
│   ├── ubuntu_train.sh
│   └── adaptation.sh
├── __init__.py
├── data
│   └── Ubuntu_V1_Xu
│       ├── README.txt
│       ├── tokenization.py
│       ├── create_finetuning_data.py
│       └── create_adaptation_data.py
├── uncased_L-12_H-768_A-12_adapted
│   └── README.txt
├── compute_metrics.py
├── README.md
├── metrics.py
├── optimization.py
├── test.py
├── tokenization.py
├── train.py
├── adapt_switch.py
└── modeling_switch.py
/output/README.txt:
--------------------------------------------------------------------------------
1 |
2 | Models will be saved here.
3 |
--------------------------------------------------------------------------------
/image/UbuntuV1_V2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/SA-BERT/HEAD/image/UbuntuV1_V2.png
--------------------------------------------------------------------------------
/image/Douban_Ecommerce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/SA-BERT/HEAD/image/Douban_Ecommerce.png
--------------------------------------------------------------------------------
/uncased_L-12_H-768_A-12/README.txt:
--------------------------------------------------------------------------------
1 |
2 | ====== Download the BERT base model ======
3 |
4 | link: https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
5 | Move to path: ./uncased_L-12_H-768_A-12
6 |
--------------------------------------------------------------------------------
/scripts/ubuntu_test.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=3 python -u ../test.py \
3 | --test_dir ../data/Ubuntu_V1_Xu/processed_test.tfrecord \
4 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \
5 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \
6 | --max_seq_length 512 \
7 | --eval_batch_size 50 \
8 | --restore_model_dir ../output/Ubuntu_V1_Xu/1569550213 > log_test.txt 2>&1 &
9 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/scripts/ubuntu_train.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=3 python -u ../train.py \
3 | --task_name fine_tuning \
4 | --train_dir ../data/Ubuntu_V1_Xu/processed_train.tfrecord \
5 | --valid_dir ../data/Ubuntu_V1_Xu/processed_valid.tfrecord \
6 | --output_dir ../output/Ubuntu_V1_Xu \
7 | --do_lower_case True \
8 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \
9 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \
10 | --init_checkpoint ../uncased_L-12_H-768_A-12/bert_model.ckpt \
11 | --max_seq_length 512 \
12 | --do_train True \
13 | --train_batch_size 25 \
14 | --learning_rate 2e-5 \
15 | --num_train_epochs 10 \
16 | --warmup_proportion 0.1 > log_train.txt 2>&1 &
17 |
--------------------------------------------------------------------------------
/scripts/adaptation.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=3 python -u ../adapt_switch.py \
3 | --task_name adaptation \
4 | --sample_num 5000000 \
5 | --mid_save_step 15000 \
6 | --input_file ../data/Ubuntu_V1_Xu/pretrain_data.tfrecord \
7 | --output_dir ../uncased_L-12_H-768_A-12_adapted \
8 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \
9 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \
10 | --init_checkpoint ../uncased_L-12_H-768_A-12/bert_model.ckpt \
11 | --max_seq_length 512 \
12 | --max_predictions_per_seq 25 \
13 | --train_batch_size 20 \
14 | --eval_batch_size 20 \
15 | --learning_rate 5e-5 \
16 | --num_train_epochs 1 \
17 | --warmup_proportion 0.1 > log_adaptation.txt 2>&1 &
18 |
--------------------------------------------------------------------------------
/data/Ubuntu_V1_Xu/README.txt:
--------------------------------------------------------------------------------
1 |
2 | ====== Download the dataset ======
3 |
4 | Take Ubuntu_V1 as an example
5 | link: https://drive.google.com/file/d/1-rNv34hLoZr300JF3v7nuLswM7GRqeNc/view
6 | Move to path: /data/Ubuntu_V1_Xu/Ubuntu_Corpus_V1
7 |
8 | If you use the processed dataset, please cite the following paper:
9 |
10 | @inproceedings{Gu:2019:IMN:3357384.3358140,
11 | author = {Gu, Jia-Chen and
12 | Ling, Zhen-Hua and
13 | Liu, Quan},
14 | title = {Interactive Matching Network for Multi-Turn Response Selection in Retrieval-Based Chatbots},
15 | booktitle = {Proceedings of the 28th ACM International Conference on Information and Knowledge Management},
16 | series = {CIKM '19},
17 | year = {2019},
18 | isbn = {978-1-4503-6976-3},
19 | location = {Beijing, China},
20 | pages = {2321--2324},
21 | url = {http://doi.acm.org/10.1145/3357384.3358140},
22 | doi = {10.1145/3357384.3358140},
23 | acmid = {3358140},
24 | publisher = {ACM},
25 | }
--------------------------------------------------------------------------------
/uncased_L-12_H-768_A-12_adapted/README.txt:
--------------------------------------------------------------------------------
1 |
2 | ====== Download the ADAPTED BERT base model ======
3 |
4 | We provide the model adapted on Ubuntu V1
5 | link: https://drive.google.com/file/d/1M8V018XZbVDo4Xq96pCLFRt6yVzoKtjH/view?usp=sharing
6 | Move to path: ./uncased_L-12_H-768_A-12_adapted
7 |
8 | If you use the adapted model, please cite the following paper:
9 |
10 | @inproceedings{Gu:2020:SABERT:3340531.3412330,
11 | author = {Gu, Jia-Chen and
12 | Li, Tianda and
13 | Liu, Quan and
14 | Ling, Zhen-Hua and
15 | Su, Zhiming and
16 | Wei, Si and
17 | Zhu, Xiaodan
18 | },
19 | title = {Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots},
20 | booktitle = {Proceedings of the 29th ACM International Conference on Information and Knowledge Management},
21 | series = {CIKM '20},
22 | year = {2020},
23 | isbn = {978-1-4503-6859-9},
24 | location = {Virtual Event, Ireland},
25 | url = {http://doi.acm.org/10.1145/3340531.3412330},
26 | doi = {10.1145/3340531.3412330},
27 | acmid = {3412330},
28 | publisher = {ACM},
29 | }
30 |
--------------------------------------------------------------------------------
/compute_metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Load the output_test.txt file written by test.py and compute the metrics
3 | '''
4 |
5 |
6 | import random
7 | from collections import defaultdict
8 | import metrics
9 |
10 |
11 | test_out_filename = "./output/Ubuntu_V1_Xu/1596330255/output_test.txt" # point this at the output_test.txt of the tested model
12 | print("*"*20 + test_out_filename + "*"*20 + "\n")
13 |
14 | with open(test_out_filename, 'r') as f:
15 |
16 | # candidate size = 10
17 | results = defaultdict(list)
18 | lines = f.readlines()
19 | for line in lines[1:]:
20 | line = line.strip().split('\t')
21 | us_id = line[0]
22 | r_id = line[1]
23 | prob_score = float(line[2])
24 | label = float(line[4])
25 | results[us_id].append((r_id, label, prob_score))
26 |
27 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
28 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
29 | total_valid_query = metrics.get_num_valid_query(results)
30 | mvp = metrics.mean_average_precision(results)
31 | mrr = metrics.mean_reciprocal_rank(results)
32 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format(
33 | mvp, mrr, total_valid_query))
34 | top_1_precision = metrics.top_k_precision(results, k=1)
35 | top_2_precision = metrics.top_k_precision(results, k=2)
36 | top_5_precision = metrics.top_k_precision(results, k=5)
37 | print('Recall_10@1: {}\tRecall_10@2: {}\tRecall_10@5: {}\n'.format(
38 | top_1_precision, top_2_precision, top_5_precision))
39 |
40 | # candidate size = 2; results may vary across runs because the negative candidate is sampled randomly
41 | results_bin = defaultdict(list)
42 | for us_id, candidates in results.items():
43 | false_candidates = []
44 | for candidate in candidates:
45 | r_id, label, prob_score = candidate
46 | if label == 1.0:
47 | results_bin[us_id].append(candidate)
48 | if label == 0.0:
49 | false_candidates.append(candidate)
50 | false_candidate = random.sample(false_candidates, 1)
51 | results_bin[us_id].append(false_candidate[0])
52 |
53 | accu, precision, recall, f1, loss = metrics.classification_metrics(results_bin)
54 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
55 | total_valid_query = metrics.get_num_valid_query(results_bin)
56 | mvp = metrics.mean_average_precision(results_bin)
57 | mrr = metrics.mean_reciprocal_rank(results_bin)
58 | top_1_precision = metrics.top_k_precision(results_bin, k=1)
59 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format(
60 | mvp, mrr, total_valid_query))
61 | print('Recall_2@1: {}\n'.format(
62 | top_1_precision))
63 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Speaker-Aware BERT for Multi-Turn Response Selection
2 | This repository contains the source code and pre-trained models for the CIKM 2020 paper [Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots](https://arxiv.org/pdf/2004.03588.pdf) by Gu et al.
3 |
4 |
5 | ## Results
6 |
7 | ![Results on Ubuntu V1 and Ubuntu V2](image/UbuntuV1_V2.png)
8 |
9 | ![Results on Douban and E-commerce](image/Douban_Ecommerce.png)
10 | ## Dependencies
11 | - Python 3.6
12 | - Tensorflow 1.13.1
13 |
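A matching environment can be set up with, for example (assuming a GPU build, since the scripts set ```CUDA_VISIBLE_DEVICES```):
```
pip install tensorflow-gpu==1.13.1
```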
14 |
15 | ## Download
16 | - Download the [BERT base model released by Google Research](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)
17 |   and move it to the path ./uncased_L-12_H-768_A-12 (see the sketch after this list).
18 |
19 | - We also provide the [BERT model adapted on the Ubuntu V1 dataset](https://drive.google.com/file/d/1M8V018XZbVDo4Xq96pCLFRt6yVzoKtjH/view?usp=sharing);
20 |   move it to the path ./uncased_L-12_H-768_A-12_adapted. You only need to fine-tune it to reproduce our results.
21 |
22 | - Download the [Ubuntu V1 dataset](https://drive.google.com/file/d/1-rNv34hLoZr300JF3v7nuLswM7GRqeNc/view)
23 |   and move it to the path ./data/Ubuntu_V1_Xu/Ubuntu_Corpus_V1.
24 |
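The BERT base model is a direct download, so that step can be scripted; a minimal sketch, assuming ```wget``` and ```unzip``` are available (the Google Drive files are easier to fetch through a browser):
```
wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
unzip uncased_L-12_H-768_A-12.zip
```
Unzipping yields ```uncased_L-12_H-768_A-12/``` with ```bert_model.ckpt.*```, ```bert_config.json``` and ```vocab.txt```, which are the paths the scripts below expect.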
25 |
26 | ## Adaptation
27 | Create the adaptation data.
28 | ```
29 | cd data/Ubuntu_V1_Xu/
30 | python create_adaptation_data.py
31 | ```
32 | Run the adaptation process.
33 | ```
34 | cd scripts/
35 | bash adaptation.sh
36 | ```
37 | The adapted model will be saved to the path ```./uncased_L-12_H-768_A-12_adapted```.
38 | Rename the checkpoint files in this folder so that they match the filenames in Google's BERT release, as sketched below.
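A hypothetical sketch of that renaming, assuming the adaptation run left files named like ```model.ckpt_15000.*``` (the actual names depend on ```adapt_switch.py``` and ```mid_save_step```, so check the folder first):
```
cd uncased_L-12_H-768_A-12_adapted/
# e.g. model.ckpt_15000.index -> bert_model.ckpt.index, and likewise for
# the .meta and .data-* files (the source names are an assumption)
for f in model.ckpt_15000.*; do
  mv "$f" "bert_model.ckpt.${f#model.ckpt_15000.}"
done
```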
39 |
40 |
41 | ## Training
42 | Create the fine-tuning data.
43 | ```
44 | cd data/Ubuntu_V1_Xu/
45 | python create_finetuning_data.py
46 | ```
47 | Run the fine-tuning process. To fine-tune the adapted model instead of the original BERT, point ```--init_checkpoint``` in ```ubuntu_train.sh``` at the renamed checkpoint in ```./uncased_L-12_H-768_A-12_adapted```.
48 |
49 | ```
50 | cd scripts/
51 | bash ubuntu_train.sh
52 | ```
53 |
54 | ## Testing
55 | Modify the variable ```restore_model_dir``` in ```ubuntu_test.sh``` so that it points to the trained model folder under ```./output```.
56 | ```
57 | cd scripts/
58 | bash ubuntu_test.sh
59 | ```
60 | A "output_test.txt" file which records scores for each context-response pair will be saved to the path of ```restore_model_dir```.
61 | Modify the variable ```test_out_filename``` in ```compute_metrics.py``` and then run the following command, various metrics will be shown.
62 | ```
63 | python compute_metrics.py
64 | ```
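For reference, ```output_test.txt``` starts with the header written by ```test.py```, and ```compute_metrics.py``` reads the ```query_id```, ```document_id```, ```score``` and ```relevance``` columns. Illustrative rows (scores made up; the real file is tab-separated, padded here for readability):
```
query_id  document_id  score   rank  relevance
0         0            0.9731  1     1.0
0         5            0.0214  2     0.0
```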
65 |
66 |
67 | ## Cite
68 | If you use the source code and pre-trained models, please cite the following paper:
69 | **"Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots"**
70 | Jia-Chen Gu, Tianda Li, Quan Liu, Zhen-Hua Ling, Zhiming Su, Si Wei, Xiaodan Zhu. _CIKM (2020)_
71 |
72 | ```
73 | @inproceedings{Gu:2020:SABERT:3340531.3412330,
74 | author = {Gu, Jia-Chen and
75 | Li, Tianda and
76 | Liu, Quan and
77 | Ling, Zhen-Hua and
78 | Su, Zhiming and
79 | Wei, Si and
80 | Zhu, Xiaodan
81 | },
82 | title = {Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots},
83 | booktitle = {Proceedings of the 29th ACM International Conference on Information and Knowledge Management},
84 | series = {CIKM '20},
85 | year = {2020},
86 | isbn = {978-1-4503-6859-9},
87 | location = {Virtual Event, Ireland},
88 | pages = {2041--2044},
89 | url = {http://doi.acm.org/10.1145/3340531.3412330},
90 | doi = {10.1145/3340531.3412330},
91 | acmid = {3412330},
92 | publisher = {ACM},
93 | }
94 | ```
95 |
96 |
97 | ## Update
98 | Please feel free to open issues if you have any problems.
99 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
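"""Ranking and classification metrics for response selection.

`results` maps each query id to a list of (answer_id, label, score)
tuples. Queries lacking either a positive or a negative candidate are
treated as invalid and skipped by the ranking metrics."""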
1 | import operator
2 | import math
3 |
4 |
5 | def is_valid_query(v):
6 | num_pos = 0
7 | num_neg = 0
8 | for aid, label, score in v:
9 | if label > 0:
10 | num_pos += 1
11 | else:
12 | num_neg += 1
13 | if num_pos > 0 and num_neg > 0:
14 | return True
15 | else:
16 | return False
17 |
18 |
19 | def get_num_valid_query(results):
20 | num_query = 0
21 | for k, v in results.items():
22 | if not is_valid_query(v):
23 | continue
24 | num_query += 1
25 | return num_query
26 |
27 |
28 | def top_1_precision(results):
29 | num_query = 0
30 | top_1_correct = 0.0
31 | for k, v in results.items():
32 | if not is_valid_query(v):
33 | continue
34 | num_query += 1
35 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
36 | aid, label, score = sorted_v[0]
37 | if label > 0:
38 | top_1_correct += 1
39 |
40 | if num_query > 0:
41 | return top_1_correct / num_query
42 | else:
43 | return 0.0
44 |
45 |
46 | def mean_reciprocal_rank(results):
47 | num_query = 0
48 | mrr = 0.0
49 | for k, v in results.items():
50 | if not is_valid_query(v):
51 | continue
52 |
53 | num_query += 1
54 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
55 | for i, rec in enumerate(sorted_v):
56 | aid, label, score = rec
57 | if label > 0:
58 | mrr += 1.0 / (i + 1)
59 | break
60 |
61 | if num_query == 0:
62 | return 0.0
63 | else:
64 | mrr = mrr / num_query
65 | return mrr
66 |
67 |
68 | def mean_average_precision(results):
69 | num_query = 0
70 | mvp = 0.0
71 | for k, v in results.items():
72 | if not is_valid_query(v):
73 | continue
74 |
75 | num_query += 1
76 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
77 | num_relevant_doc = 0.0
78 | avp = 0.0
79 | for i, rec in enumerate(sorted_v):
80 | aid, label, score = rec
81 | if label == 1:
82 | num_relevant_doc += 1
83 | precision = num_relevant_doc / (i + 1)
84 | avp += precision
85 | avp = avp / num_relevant_doc
86 | mvp += avp
87 |
88 | if num_query == 0:
89 | return 0.0
90 | else:
91 | mvp = mvp / num_query
92 | return mvp
93 |
94 |
95 | def classification_metrics(results):
96 | total_num = 0
97 | total_correct = 0
98 | true_positive = 0
99 | positive_correct = 0
100 | predicted_positive = 0
101 |
102 | loss = 0.0
103 | for k, v in results.items():
104 | for rec in v:
105 | total_num += 1
106 | aid, label, score = rec
107 |
108 | if score > 0.5:
109 | predicted_positive += 1
110 |
111 | if label > 0:
112 | true_positive += 1
113 | loss += -math.log(score + 1e-12)
114 | else:
115 | loss += -math.log(1.0 - score + 1e-12)
116 |
117 | if score > 0.5 and label > 0:
118 | total_correct += 1
119 | positive_correct += 1
120 |
121 | if score < 0.5 and label < 0.5:
122 | total_correct += 1
123 |
124 | accuracy = float(total_correct) / total_num
125 | precision = float(positive_correct) / (predicted_positive + 1e-12)
126 | recall = float(positive_correct) / true_positive
127 | F1 = 2.0 * precision * recall / (1e-12 + precision + recall)
128 | return accuracy, precision, recall, F1, loss / total_num
129 |
130 |
131 | def top_k_precision(results, k=1):
132 | num_query = 0
133 | top_1_correct = 0.0
134 | for key, v in results.items():
135 | if not is_valid_query(v):
136 | continue
137 | num_query += 1
138 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
139 | if k == 1:
140 | aid, label, score = sorted_v[0]
141 | if label > 0:
142 | top_1_correct += 1
143 | elif k == 2:
144 | aid1, label1, score1 = sorted_v[0]
145 | aid2, label2, score2 = sorted_v[1]
146 | if label1 > 0 or label2 > 0:
147 | top_1_correct += 1
148 | elif k == 5:
149 | for vv in sorted_v[0:5]:
150 | label = vv[1]
151 | if label > 0:
152 | top_1_correct += 1
153 | break
154 | else:
155 | raise ValueError("k must be 1, 2, or 5")
156 |
157 | if num_query > 0:
158 | return top_1_correct/num_query
159 | else:
160 | return 0.0
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
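# (For example, with init_lr = 2e-5 and num_warmup_steps = 1000, the
# learning rate at global_step = 100 is 100/1000 * 2e-5 = 2e-6.)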
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
98 | """Constructs a AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 | # Instead we want to decay the weights in a manner that doesn't interact
144 | # with the m/v parameters. This is equivalent to adding the square
145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """BERT finetuning runner."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import random
23 | import operator
24 | from time import time
25 | from collections import defaultdict
26 | import tensorflow as tf
27 | import optimization
28 | import tokenization
29 | import modeling_switch as modeling
30 | import metrics
31 |
32 | flags = tf.flags
33 | FLAGS = flags.FLAGS
34 |
35 | ## Required parameters
36 | flags.DEFINE_string("test_dir", 'test.tfrecord',
37 | "The input test data dir. Should contain the .tsv files (or other data files) for the task.")
38 |
39 | flags.DEFINE_string("restore_model_dir", 'output/',
40 | "The output directory where the model checkpoints have been written.")
41 |
42 | flags.DEFINE_string("task_name", 'TestModel',
43 | "The name of the task.")
44 |
45 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
46 | "The config json file corresponding to the pre-trained BERT model. "
47 | "This specifies the model architecture.")
48 |
49 | flags.DEFINE_integer("max_seq_length", 320,
50 | "The maximum total input sequence length after WordPiece tokenization. "
51 | "Sequences longer than this will be truncated, and sequences shorter "
52 | "than this will be padded.")
53 |
54 | flags.DEFINE_bool("do_eval", True,
55 | "Whether to run eval on the dev set.")
56 |
57 | flags.DEFINE_integer("eval_batch_size", 32,
58 | "Total batch size for predict.")
59 |
60 |
61 | def print_configuration_op(FLAGS):
62 | print('My Configurations:')
63 | for name, value in FLAGS.__flags.items():
64 | value=value.value
65 | if type(value) == float:
66 | print(' %s:\t %f'%(name, value))
67 | elif type(value) == int:
68 | print(' %s:\t %d'%(name, value))
69 | elif type(value) == str:
70 | print(' %s:\t %s'%(name, value))
71 | elif type(value) == bool:
72 | print(' %s:\t %s'%(name, value))
73 | else:
74 | print('%s:\t %s' % (name, value))
75 | print('End of configuration')
76 |
77 |
78 | def total_sample(file_name):
79 | sample_nums = 0
80 | for record in tf.python_io.tf_record_iterator(file_name):
81 | sample_nums += 1
82 | return sample_nums
83 |
84 |
85 | def print_weight(name):
86 | with open('valid/weight_log' + name + str(random.randint(0, 100)), 'w') as fw:
87 | variables = tf.trainable_variables()
88 | for variable in variables:
89 | fw.write(str(variable.eval()))
90 | fw.write('\n')
91 |
92 |
93 | def parse_exmp(serial_exmp):
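"""Parses one serialized tf.Example into its feature tensors.

Besides the standard BERT inputs (input_sents, input_mask, segment_ids),
each example carries switch_ids, the per-token feature fed into
modeling_switch below; per the paper, it encodes speaker information.
"""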
94 | input_data = tf.parse_single_example(serial_exmp,
95 | features={
96 | "ques_ids":
97 | tf.FixedLenFeature([], tf.int64),
98 | "ans_ids":
99 | tf.FixedLenFeature([], tf.int64),
100 | "input_sents":
101 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
102 | "input_mask":
103 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
104 | "segment_ids":
105 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
106 | "switch_ids":
107 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
108 | "label_ids":
109 | tf.FixedLenFeature([], tf.float32),
110 | }
111 | )
112 | # tf.Example only supports tf.int64, so cast all int64 features to int32.
113 | for name in list(input_data.keys()):
114 | t = input_data[name]
115 | if t.dtype == tf.int64:
116 | t = tf.to_int32(t)
117 | input_data[name] = t
118 |
119 | ques_ids = input_data["ques_ids"]
120 | ans_ids = input_data['ans_ids']
121 | sents = input_data["input_sents"]
122 | mask = input_data["input_mask"]
123 | segment_ids= input_data["segment_ids"]
124 | switch_ids= input_data["switch_ids"]
125 | labels = input_data['label_ids']
126 | return ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels
127 |
128 |
129 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, switch_ids, labels, ques_ids, ans_ids,
130 | num_labels, use_one_hot_embeddings):
131 | """Creates a classification model."""
132 | model = modeling.BertModel(
133 | config=bert_config,
134 | is_training=is_training,
135 | input_ids=input_ids,
136 | input_mask=input_mask,
137 | token_type_ids=segment_ids,
138 | switch_ids=switch_ids,
139 | use_one_hot_embeddings=use_one_hot_embeddings)
140 |
141 | # In the demo, we are doing a simple classification task on the entire
142 | # segment.
143 | #
144 | # If you want to use the token-level output, use model.get_sequence_output()
145 | # instead.
146 | target_loss_weight = [1.0, 1.0]
147 | target_loss_weight = tf.convert_to_tensor(target_loss_weight)
148 |
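# flagx marks positive examples (label > 0) and flagy negatives (label == 0);
# each example's loss is scaled by the matching entry of target_loss_weight,
# both 1.0 here, so the classes are effectively unweighted.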
149 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32)
150 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32)
151 |
152 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy
153 |
154 | output_layer = model.get_pooled_output()
155 |
156 | hidden_size = output_layer.shape[-1].value
157 |
158 | output_weights = tf.get_variable(
159 | "output_weights", [num_labels, hidden_size],
160 | initializer=tf.truncated_normal_initializer(stddev=0.02))
161 |
162 | output_bias = tf.get_variable(
163 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
164 |
165 | with tf.variable_scope("loss"):
166 | # if is_training:
167 | # # I.e., 0.1 dropout
168 | # output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
169 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training)
170 |
171 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
172 | logits = tf.nn.bias_add(logits, output_bias)
173 |
174 | probabilities = tf.sigmoid(logits, name="prob")
175 | logits = tf.squeeze(logits,[1])
176 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
177 | losses = tf.multiply(losses, all_target_loss)
178 |
179 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
180 |
181 | with tf.name_scope("accuracy"):
182 | correct_prediction = tf.equal(tf.sign(probabilities - 0.5), tf.sign(labels - 0.5))
183 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
184 | #
185 | # one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
186 | #
187 | # per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
188 | # loss = tf.reduce_mean(per_example_loss)
189 |
190 | return mean_loss, logits, probabilities, accuracy, model, output_layer
191 |
192 |
193 | best_score = 0.0
194 | def run_test(dir_path, op_name, sess, training, accuracy, prob, pair_ids, output_layer):
195 | results = defaultdict(list)
196 | num_test = 0
197 | num_correct = 0.0
198 | n_updates = 0
199 | mrr = 0
200 | t0 = time()
201 | try:
202 | while True:
203 | n_updates += 1
204 |
205 | batch_accuracy, predicted_prob, pair_ = sess.run([accuracy, prob, pair_ids], feed_dict={training: False})
206 | question_id, answer_id, label = pair_
207 |
208 | num_test += len(predicted_prob)
209 | num_correct += len(predicted_prob) * batch_accuracy
210 | for i, prob_score in enumerate(predicted_prob):
211 | # question_id, answer_id, label = pair_id[i]
212 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0]))
213 |
214 | if n_updates%2000 == 0:
215 | tf.logging.info("n_update %d , %s: Mins Used: %.2f" %
216 | (n_updates, op_name, (time() - t0) / 60.0))
217 |
218 | except tf.errors.OutOfRangeError:
219 | # calculate top-1 precision
220 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test))
221 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
222 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
223 |
224 | mvp = metrics.mean_average_precision(results)
225 | mrr = metrics.mean_reciprocal_rank(results)
226 | top_1_precision = metrics.top_1_precision(results)
227 | total_valid_query = metrics.get_num_valid_query(results)
228 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(
229 | mvp, mrr, top_1_precision, total_valid_query))
230 |
231 | out_path = os.path.join(dir_path, "output_test.txt")
232 | print("Saving evaluation to {}".format(out_path))
233 | with open(out_path, 'w') as f:
234 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
235 | for us_id, v in results.items():
236 | v.sort(key=operator.itemgetter(2), reverse=True)
237 | for i, rec in enumerate(v):
238 | r_id, label, prob_score = rec
239 | rank = i+1
240 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
241 | return mrr
242 |
243 |
244 | def main(_):
245 | tf.logging.set_verbosity(tf.logging.INFO)
246 |
247 | print_configuration_op(FLAGS)
248 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
249 |
250 | test_data_size = total_sample(FLAGS.test_dir)
251 | tf.logging.info('test data size: {}'.format(test_data_size))
252 |
253 | filenames = tf.placeholder(tf.string, shape=[None])
254 | shuffle_size = tf.placeholder(tf.int64)
255 | dataset = tf.data.TFRecordDataset(filenames)
256 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
257 | dataset = dataset.repeat(1)
258 | # dataset = dataset.shuffle(shuffle_size)
259 | dataset = dataset.batch(FLAGS.eval_batch_size)
260 | iterator = dataset.make_initializable_iterator()
261 | ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels = iterator.get_next()  # one batch of test examples
262 | pair_ids = [ques_ids, ans_ids, labels]
263 |
264 | training = tf.placeholder(tf.bool)
265 | mean_loss, logits, probabilities, accuracy, model, output_layer = create_model(bert_config,
266 | is_training = training,
267 | input_ids = sents,
268 | input_mask = mask,
269 | segment_ids = segment_ids,
270 | switch_ids = switch_ids,
271 | labels = labels,
272 | ques_ids = ques_ids,
273 | ans_ids = ans_ids,
274 | num_labels = 1,
275 | use_one_hot_embeddings = False)
276 |
277 |
278 | config = tf.ConfigProto(allow_soft_placement=True)
279 | config.gpu_options.allow_growth = True
280 |
281 | if FLAGS.do_eval:
282 | with tf.Session(config=config) as sess:
283 | tf.logging.info("*** Restore model ***")
284 |
285 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
286 | variables = tf.trainable_variables()
287 | saver = tf.train.Saver(variables)
288 | saver.restore(sess, ckpt.model_checkpoint_path)
289 |
290 | tf.logging.info('Test begin')
291 | sess.run(iterator.initializer,
292 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
293 | run_test(FLAGS.restore_model_dir, "test", sess, training, accuracy, probabilities, pair_ids, output_layer)
294 |
295 |
296 | if __name__ == "__main__":
297 |
298 | tf.app.run()
299 |
300 |
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenziation."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia.).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 | """Runs WordPiece tokenziation."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 | already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 | # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat in ("Cc", "Cf"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
--------------------------------------------------------------------------------
/data/Ubuntu_V1_Xu/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-training. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenziation."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia.).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 | """Runs WordPiece tokenziation."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 | already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `chars` is a whitespace character."""
364 | # \t, \n, and \r are technically contorl characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `chars` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat in ("Cc", "Cf"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `chars` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyways, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
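401 | # A minimal usage sketch (illustrative only, not part of the training
402 | # pipeline): it assumes the vocab file from the uncased_L-12_H-768_A-12
403 | # checkpoint has been downloaded as described in the README.
404 | if __name__ == "__main__":
405 | tokenizer = FullTokenizer(vocab_file="uncased_L-12_H-768_A-12/vocab.txt")
406 | print(tokenizer.tokenize("unaffable"))  # ["un", "##aff", "##able"] per the docstring example
407 | print(tokenizer.convert_tokens_to_ids(tokenizer.tokenize("unaffable")))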
--------------------------------------------------------------------------------
/data/Ubuntu_V1_Xu/create_finetuning_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import collections
3 | import tokenization
4 | import tensorflow as tf
5 | from tqdm import tqdm
6 |
7 |
8 | tf.flags.DEFINE_string("response_file", "./Ubuntu_Corpus_V1/responses.txt",
9 | "path to response file")
10 | tf.flags.DEFINE_string("train_file", "./Ubuntu_Corpus_V1/train.txt",
11 | "path to train file")
12 | tf.flags.DEFINE_string("valid_file", "./Ubuntu_Corpus_V1/valid.txt",
13 | "path to valid file")
14 | tf.flags.DEFINE_string("test_file", "./Ubuntu_Corpus_V1/test.txt",
15 | "path to test file")
16 |
17 | tf.flags.DEFINE_string("vocab_file", "../../uncased_L-12_H-768_A-12/vocab.txt",
18 | "path to vocab file")
19 | tf.flags.DEFINE_integer("max_seq_length", 512,
20 | "max sequence length of concatenated context and response")
21 | tf.flags.DEFINE_bool("do_lower_case", True,
22 | "whether to lower case the input text")
23 |
24 |
25 |
26 | def print_configuration_op(FLAGS):
27 | print('My Configurations:')
28 | for name, value in FLAGS.__flags.items():
29 | value=value.value
30 | if type(value) == float:
31 | print(' %s:\t %f'%(name, value))
32 | elif type(value) == int:
33 | print(' %s:\t %d'%(name, value))
34 | elif type(value) == str:
35 | print(' %s:\t %s'%(name, value))
36 | elif type(value) == bool:
37 | print(' %s:\t %s'%(name, value))
38 | else:
39 | print('%s:\t %s' % (name, value))
40 | print('End of configuration')
41 |
42 |
43 | def load_responses(fname):
44 | responses={}
45 | with open(fname, 'rt') as f:
46 | for line in f:
47 | line = line.strip()
48 | fields = line.split('\t')
49 | if len(fields) != 2:
50 | print("WRONG LINE: {}".format(line))
51 | r_text = 'unknown'
52 | else:
53 | r_text = fields[1]
54 | responses[fields[0]] = r_text
55 | return responses
56 |
57 |
58 | def load_dataset(fname, responses):
59 |
60 | processed_fname = "processed_" + fname.split("/")[-1]
61 | dataset_size = 0
62 | print("Generating the file of {} ...".format(processed_fname))
63 |
64 | with open(processed_fname, 'w') as fw:
65 | with open(fname, 'rt') as fr:
66 | for line in fr:
67 | line = line.strip()
68 | fields = line.split('\t')
69 |
70 | us_id = fields[0]
71 | context = fields[1]
72 |
73 | if fields[2] != "NA":
74 | pos_ids = fields[2].split('|')
75 | for r_id in pos_ids:
76 | r_utter = responses[r_id]
77 | dataset_size += 1
78 | fw.write("\t".join([str(us_id), context, r_id, r_utter, 'follow']))
79 | fw.write('\n')
80 |
81 | if fields[3] != "NA":
82 | neg_ids = fields[3].split('|')
83 | for r_id in neg_ids:
84 | r_utter = responses[r_id]
85 | dataset_size += 1
86 | fw.write("\t".join([str(us_id), context, r_id, r_utter, 'unfollow']))
87 | fw.write('\n')
88 |
89 | print("{} dataset_size: {}".format(processed_fname, dataset_size))
90 | return processed_fname
91 |
92 |
93 | class InputExample(object):
94 | def __init__(self, guid, ques_ids, text_a, ans_ids, text_b=None, label=None):
95 | """Constructs an InputExample.
96 | Args:
97 | guid: Unique id for the example.
98 | text_a: string. The untokenized text of the first sequence. For single
99 | sequence tasks, only this sequence must be specified.
100 | text_b: (Optional) string. The untokenized text of the second sequence.
101 | Must only be specified for sequence pair tasks.
102 | label: (Optional) string. The label of the example. This should be
103 | specified for train and dev examples, but not for test examples.
104 | """
105 | self.guid = guid
106 | self.ques_ids = ques_ids
107 | self.ans_ids = ans_ids
108 | self.text_a = text_a
109 | self.text_b = text_b
110 | self.label = label
111 |
112 | class InputFeatures(object):
113 | """A single set of features of data."""
114 | def __init__(self, ques_ids, ans_ids, input_sents, input_mask, segment_ids, switch_ids, label_id):
115 | self.ques_ids = ques_ids
116 | self.ans_ids = ans_ids
117 | self.input_sents = input_sents
118 | self.input_mask = input_mask
119 | self.segment_ids = segment_ids
120 | self.switch_ids = switch_ids
121 | self.label_id = label_id
122 |
123 | def read_processed_file(input_file):
124 | lines = []
125 | num_lines = sum(1 for line in open(input_file, 'r'))
126 | with open(input_file, 'r') as f:
127 | for line in tqdm(f, total=num_lines):
128 | concat = []
129 | temp = line.rstrip().split('\t')
130 | concat.append(temp[0]) # context id
131 | concat.append(temp[1]) # context
132 | concat.append(temp[2]) # response id
133 | concat.append(temp[3]) # response
134 | concat.append(temp[4]) # label
135 | lines.append(concat)
136 | return lines
137 |
138 | def create_examples(lines, set_type):
139 | """Creates examples for the training and dev sets."""
140 | examples = []
141 | for (i, line) in enumerate(lines):
142 | guid = "%s-%s" % (set_type, str(i))
143 | ques_ids = line[0]
144 | text_a = tokenization.convert_to_unicode(line[1])
145 | ans_ids = line[2]
146 | text_b = tokenization.convert_to_unicode(line[3])
147 | label = tokenization.convert_to_unicode(line[-1])
148 | examples.append(InputExample(guid=guid, ques_ids=ques_ids, text_a=text_a, ans_ids=ans_ids, text_b=text_b, label=label))
149 | return examples
150 |
151 |
152 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
153 | """Truncates a sequence pair in place to the maximum length."""
154 |
155 | # This is a simple heuristic which will always truncate the longer sequence
156 | # one token at a time. This makes more sense than truncating an equal percent
157 | # of tokens from each, since if one sequence is very short then each token
158 | # that's truncated likely contains more information than a longer sequence.
159 | while True:
160 | total_length = len(tokens_a) + len(tokens_b)
161 | if total_length <= max_length:
162 | break
163 | if len(tokens_a) > len(tokens_b):
164 | tokens_a.pop()
165 | else:
166 | tokens_b.pop()
167 |
168 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
169 | """Loads a data file into a list of `InputBatch`s."""
170 |
171 | label_map = {} # maps label string -> index
172 | for (i, label) in enumerate(label_list): # e.g. ["unfollow", "follow"]
173 | label_map[label] = i
174 |
175 | features = [] # list of InputFeatures
176 | for (ex_index, example) in enumerate(examples):
177 | ques_ids = int(example.ques_ids)
178 | ans_ids = int(example.ans_ids)
179 |
180 | # tokens_a = tokenizer.tokenize(example.text_a) # text_a tokenize
181 | text_a_utters = example.text_a.split(" __EOS__ ")
182 | tokens_a = []
183 | text_a_switch = []
184 | for text_a_utter_idx, text_a_utter in enumerate(text_a_utters):
185 | if text_a_utter_idx%2 == 0:
186 | text_a_switch_flag = 0
187 | else:
188 | text_a_switch_flag = 1
189 | text_a_utter_token = tokenizer.tokenize(text_a_utter + " __EOS__")
190 | tokens_a.extend(text_a_utter_token)
191 | text_a_switch.extend([text_a_switch_flag]*len(text_a_utter_token))
192 | assert len(tokens_a) == len(text_a_switch)
193 |
194 | tokens_b = None
195 | if example.text_b:
196 | tokens_b = tokenizer.tokenize(example.text_b) # text_b tokenize
197 |
198 | if tokens_b: # if has b
199 | # Modifies `tokens_a` and `tokens_b` in place so that the total
200 | # length is less than the specified length.
201 | # Account for [CLS], [SEP], [SEP] with "- 3"
202 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) # truncate
203 | else:
204 | # Account for [CLS] and [SEP] with "- 2"
205 | if len(tokens_a) > max_seq_length - 2:
206 | tokens_a = tokens_a[0:(max_seq_length - 2)]
207 |
208 | # The convention in BERT is:
209 | # (a) For sequence pairs:
210 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
211 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
212 | # (b) For single sequences:
213 | # tokens: [CLS] the dog is hairy . [SEP]
214 | # type_ids: 0 0 0 0 0 0 0
215 | #
216 | # Where "type_ids" are used to indicate whether this is the first
217 | # sequence or the second sequence. The embedding vectors for `type=0` and
218 | # `type=1` were learned during pre-training and are added to the wordpiece
219 | # embedding vector (and position vector). This is not *strictly* necessary
220 | # since the [SEP] token unambiguously separates the sequences, but it makes
221 | # it easier for the model to learn the concept of sequences.
222 | #
223 | # For classification tasks, the first vector (corresponding to [CLS]) is
224 | # used as the "sentence vector". Note that this only makes sense because
225 | # the entire model is fine-tuned.
226 | tokens = []
227 | segment_ids = []
228 | switch_ids = []
229 | tokens.append("[CLS]")
230 | segment_ids.append(0)
231 | switch_ids.append(0)
232 | for token_idx, token in enumerate(tokens_a):
233 | tokens.append(token)
234 | segment_ids.append(0)
235 | switch_ids.append(text_a_switch[token_idx])
236 | tokens.append("[SEP]")
237 | segment_ids.append(0)
238 | switch_ids.append(0)
239 |
240 | if tokens_b:
241 | for token_idx, token in enumerate(tokens_b):
242 | tokens.append(token)
243 | segment_ids.append(1)
244 | switch_ids.append(1)
245 | tokens.append("[SEP]")
246 | segment_ids.append(1)
247 | switch_ids.append(1)
248 |
249 | input_sents = tokenizer.convert_tokens_to_ids(tokens)
250 |
251 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
252 | # tokens are attended to.
253 | input_mask = [1] * len(input_sents) # mask
254 |
255 | # Zero-pad up to the sequence length.
256 | while len(input_sents) < max_seq_length:
257 | input_sents.append(0)
258 | input_mask.append(0)
259 | segment_ids.append(0)
260 | switch_ids.append(0)
261 |
262 | assert len(input_sents) == max_seq_length
263 | assert len(input_mask) == max_seq_length
264 | assert len(segment_ids) == max_seq_length
265 | assert len(switch_ids) == max_seq_length
266 |
267 | label_id = label_map[example.label]
268 |
269 | if ex_index%2000 == 0:
270 | print('convert_{}_examples_to_features'.format(ex_index))
271 |
272 | features.append(
273 | InputFeatures( # object
274 | ques_ids=ques_ids,
275 | ans_ids = ans_ids,
276 | input_sents=input_sents,
277 | input_mask=input_mask,
278 | segment_ids=segment_ids,
279 | switch_ids=switch_ids,
280 | label_id=label_id))
281 |
282 | return features
283 |
284 |
285 | def write_instance_to_example_files(instances, output_files):
286 | writers = []
287 |
288 | for output_file in output_files:
289 | writers.append(tf.python_io.TFRecordWriter(output_file))
290 |
291 | writer_index = 0
292 | total_written = 0
293 | for (inst_index, instance) in enumerate(instances):
294 | features = collections.OrderedDict()
295 | features["ques_ids"] = create_int_feature([instance.ques_ids])
296 | features["ans_ids"] = create_int_feature([instance.ans_ids])
297 | features["input_sents"] = create_int_feature(instance.input_sents)
298 | features["input_mask"] = create_int_feature(instance.input_mask)
299 | features["segment_ids"] = create_int_feature(instance.segment_ids)
300 | features["switch_ids"] = create_int_feature(instance.switch_ids)
301 | features["label_ids"] = create_float_feature([instance.label_id])
302 |
303 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
304 |
305 | writers[writer_index].write(tf_example.SerializeToString())
306 | writer_index = (writer_index + 1) % len(writers)
307 |
308 | total_written += 1
309 |
310 | print("write_{}_instance_to_example_files".format(total_written))
311 |
312 | for feature_name in features.keys():
313 | feature = features[feature_name]
314 | values = []
315 | if feature.int64_list.value:
316 | values = feature.int64_list.value
317 | elif feature.float_list.value:
318 | values = feature.float_list.value
319 | tf.logging.info(
320 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
321 |
322 | for writer in writers:
323 | writer.close()
324 |
325 |
326 | def create_int_feature(values):
327 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
328 | return feature
329 |
330 | def create_float_feature(values):
331 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
332 | return feature
333 |
334 |
335 |
336 | if __name__ == "__main__":
337 |
338 | FLAGS = tf.flags.FLAGS
339 | print_configuration_op(FLAGS)
340 |
341 | responses = load_responses(FLAGS.response_file)
342 | train_filename = load_dataset(FLAGS.train_file, responses)
343 | valid_filename = load_dataset(FLAGS.valid_file, responses)
344 | test_filename = load_dataset(FLAGS.test_file, responses)
345 |
346 | filenames = [train_filename, valid_filename, test_filename]
347 | filetypes = ["train", "valid", "test"]
348 | files = zip(filenames, filetypes)
349 |
350 | label_list = ["unfollow", "follow"]
351 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
352 |
353 | for (filename, filetype) in files:
354 | examples = create_examples(read_processed_file(filename), filetype)
355 | features = convert_examples_to_features(examples, label_list, FLAGS.max_seq_length, tokenizer)
356 | new_filename = filename[:-4] + ".tfrecord"
357 | write_instance_to_example_files(features, [new_filename])
358 | print('Convert {} to {} done'.format(filename, new_filename))
359 |
360 | print("Sub-process(es) done.")
361 |
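362 | # Illustrative helper (not called by the pipeline above): read one example
363 | # back from a generated TFRecord and print its feature keys, as a quick
364 | # sanity check that writing succeeded. Uses the same TF 1.x record APIs as
365 | # the rest of this repository.
366 | def inspect_tfrecord(path):
367 | for record in tf.python_io.tf_record_iterator(path):
368 | example = tf.train.Example.FromString(record)
369 | print(sorted(example.features.feature.keys()))
370 | break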
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """BERT finetuning runner."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import operator
23 | from time import time
24 | from collections import defaultdict
25 | import tensorflow as tf
26 | import optimization
27 | import tokenization
28 | import modeling_switch as modeling
29 | import metrics
30 |
31 | flags = tf.flags
32 | FLAGS = flags.FLAGS
33 |
34 | ## Required parameters
35 | flags.DEFINE_string("train_dir", 'train.tfrecord',
36 | "The input train data dir. Should contain the .tsv files (or other data files) for the task.")
37 |
38 | flags.DEFINE_string("valid_dir", 'valid.tfrecord',
39 | "The input valid data dir. Should contain the .tsv files (or other data files) for the task.")
40 |
41 | flags.DEFINE_string("output_dir", 'output',
42 | "The output directory where the model checkpoints will be written.")
43 |
44 | flags.DEFINE_string("task_name", 'ResponseSelection',
45 | "The name of the task to train.")
46 |
47 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
48 | "The config json file corresponding to the pre-trained BERT model. "
49 | "This specifies the model architecture.")
50 |
51 | flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt',
52 | "The vocabulary file that the BERT model was trained on.")
53 |
54 | flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt',
55 | "Initial checkpoint (usually from a pre-trained BERT model).")
56 |
57 | flags.DEFINE_bool("do_lower_case", True,
58 | "Whether to lower case the input text. Should be True for uncased "
59 | "models and False for cased models.")
60 |
61 | flags.DEFINE_integer("max_seq_length", 320,
62 | "The maximum total input sequence length after WordPiece tokenization. "
63 | "Sequences longer than this will be truncated, and sequences shorter "
64 | "than this will be padded.")
65 |
66 | flags.DEFINE_bool("do_train", True,
67 | "Whether to run training.")
68 |
69 | flags.DEFINE_bool("do_eval", True,
70 | "Whether to run eval on the dev set.")
71 |
72 | flags.DEFINE_bool("do_predict", True,
73 | "Whether to run the model in inference mode on the test set.")
74 |
75 | flags.DEFINE_float("warmup_proportion", 0.1,
76 | "Proportion of training to perform linear learning rate warmup for. "
77 | "E.g., 0.1 = 10% of training.")
78 |
79 | flags.DEFINE_integer("train_batch_size", 12,
80 | "Total batch size for training.")
81 |
82 | flags.DEFINE_integer("eval_batch_size", 12,
83 | "Total batch size for eval.")
84 |
85 | flags.DEFINE_integer("predict_batch_size", 8,
86 | "Total batch size for predict.")
87 |
88 | flags.DEFINE_float("learning_rate", 2e-5,
89 | "The initial learning rate for Adam.")
90 |
91 | flags.DEFINE_integer("num_train_epochs", 5,
92 | "Total number of training epochs to perform.")
93 |
94 |
95 |
96 | def print_configuration_op(FLAGS):
97 | print('My Configurations:')
98 | for name, value in FLAGS.__flags.items():
99 | value=value.value
100 | if type(value) == float:
101 | print(' %s:\t %f'%(name, value))
102 | elif type(value) == int:
103 | print(' %s:\t %d'%(name, value))
104 | elif type(value) == str:
105 | print(' %s:\t %s'%(name, value))
106 | elif type(value) == bool:
107 | print(' %s:\t %s'%(name, value))
108 | else:
109 | print('%s:\t %s' % (name, value))
110 | print('End of configuration')
111 |
112 |
113 | def total_sample(file_name):
114 | sample_nums = 0
115 | for record in tf.python_io.tf_record_iterator(file_name):
116 | sample_nums += 1
117 | return sample_nums
118 |
119 |
120 | def parse_exmp(serial_exmp):
121 | input_data = tf.parse_single_example(serial_exmp,
122 | features={
123 | "ques_ids":
124 | tf.FixedLenFeature([], tf.int64),
125 | "ans_ids":
126 | tf.FixedLenFeature([], tf.int64),
127 | "input_sents":
128 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
129 | "input_mask":
130 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
131 | "segment_ids":
132 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
133 | "switch_ids":
134 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
135 | "label_ids":
136 | tf.FixedLenFeature([], tf.float32),
137 | }
138 | )
139 | # tf.Example only supports tf.int64, so cast all int64 features to int32.
140 | for name in list(input_data.keys()):
141 | t = input_data[name]
142 | if t.dtype == tf.int64:
143 | t = tf.to_int32(t)
144 | input_data[name] = t
145 |
146 | ques_ids = input_data["ques_ids"]
147 | ans_ids = input_data['ans_ids']
148 | sents = input_data["input_sents"]
149 | mask = input_data["input_mask"]
150 | segment_ids = input_data["segment_ids"]
151 | switch_ids = input_data["switch_ids"]
152 | labels = input_data['label_ids']
153 | return ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels
154 |
155 |
156 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, switch_ids, labels, ques_ids, ans_ids,
157 | num_labels, use_one_hot_embeddings):
158 | """Creates a classification model."""
159 | model = modeling.BertModel(
160 | config=bert_config,
161 | is_training=is_training,
162 | input_ids=input_ids,
163 | input_mask=input_mask,
164 | token_type_ids=segment_ids,
165 | switch_ids=switch_ids,
166 | use_one_hot_embeddings=use_one_hot_embeddings)
167 |
168 | # In the demo, we are doing a simple classification task on the entire
169 | # segment.
170 | #
171 | # If you want to use the token-level output, use model.get_sequence_output()
172 | # instead.
173 | target_loss_weight = [1.0, 1.0]
174 | target_loss_weight = tf.convert_to_tensor(target_loss_weight)
175 |
176 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32)
177 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32)
178 |
179 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy
180 |
181 | output_layer = model.get_pooled_output()
182 |
183 | hidden_size = output_layer.shape[-1].value
184 |
185 | output_weights = tf.get_variable(
186 | "output_weights", [num_labels, hidden_size],
187 | initializer=tf.truncated_normal_initializer(stddev=0.02))
188 |
189 | output_bias = tf.get_variable(
190 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
191 |
192 | with tf.variable_scope("loss"):
193 |
194 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training)
195 |
196 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
197 | logits = tf.nn.bias_add(logits, output_bias)
198 |
199 | probabilities = tf.sigmoid(logits, name="prob")
200 | logits = tf.squeeze(logits, [1])
201 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
202 | losses = tf.multiply(losses, all_target_loss)
203 |
204 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
205 |
206 | with tf.name_scope("accuracy"):
207 | correct_prediction = tf.equal(tf.sign(probabilities - 0.5), tf.sign(labels - 0.5))
208 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
209 |
210 | # one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
211 | #
212 | # per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
213 | # loss = tf.reduce_mean(per_example_loss)
214 |
215 | return mean_loss, logits, probabilities, accuracy, model
216 |
217 |
218 | def run_epoch(epoch_no, op_name, sess, training, logits, accuracy, mean_loss, train_opt=tf.constant(0)):
219 | n_updates = 0
220 | t_loss = 0
221 | n_all = 0
222 | t0 = time()
223 | try:
224 | while True:
225 | n_updates += 1
226 | batch_logits, batch_loss, _ , accur= sess.run([logits, mean_loss, train_opt, accuracy], feed_dict={training:True})
227 | n_sample = batch_logits.shape[0]
228 | n_all += n_sample
229 | t_loss += batch_loss * n_sample
230 | if n_updates%2000 == 0:
231 | tf.logging.info("epoch: %i n_update %d , %s: Mins Used: %.2f, Loss: %.4f, Accuarcy: %.2f" %
232 | (epoch_no, n_updates, op_name, (time() - t0) / 60.0, t_loss / n_all, 100 * accur))
233 |
234 | except tf.errors.OutOfRangeError:
235 | tf.logging.info("epoch: %i %s: Mins Used: %.2f, Loss: %.4f, Accuarcy: %.2f" %
236 | (epoch_no, op_name, (time() - t0)/60.0, t_loss / n_all, 100*accur))
237 | pass
238 | return t_loss / n_all
239 |
240 |
241 | best_score = 0.0
242 | def run_test(epoch_no, dir_path, op_name, sess, training, accuracy, prob, pair_ids):
243 | results = defaultdict(list)
244 | num_test = 0
245 | num_correct = 0.0
246 | n_updates = 0
247 | mrr = 0
248 | t0 = time()
249 | try:
250 | while True:
251 | n_updates += 1
252 |
253 | batch_accuracy, predicted_prob, pair_ = sess.run([accuracy, prob, pair_ids], feed_dict={training:False})
254 | question_id, answer_id, label = pair_
255 |
256 | num_test += len(predicted_prob)
257 | num_correct += len(predicted_prob) * batch_accuracy
258 | for i, prob_score in enumerate(predicted_prob):
259 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0]))
260 |
261 | if n_updates%2000 == 0:
262 | tf.logging.info("epoch: %i n_update %d , %s: Mins Used: %.2f" %
263 | (epoch_no, n_updates, op_name, (time() - t0)/60.0 ))
264 |
265 | except tf.errors.OutOfRangeError:
266 |
267 | # calculate classification and ranking metrics
268 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test))
269 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
270 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
271 |
272 | mean_ap = metrics.mean_average_precision(results)
273 | mrr = metrics.mean_reciprocal_rank(results)
274 | top_1_precision = metrics.top_1_precision(results)
275 | total_valid_query = metrics.get_num_valid_query(results)
276 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(
277 | mean_ap, mrr, top_1_precision, total_valid_query))
278 |
279 | out_path = os.path.join(dir_path, "output_epoch_{}.txt".format(epoch_no))
280 | print("Saving evaluation to {}".format(out_path))
281 | with open(out_path, 'w') as f:
282 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
283 | for us_id, v in results.items():
284 | v.sort(key=operator.itemgetter(2), reverse=True)
285 | for i, rec in enumerate(v):
286 | r_id, label, prob_score = rec
287 | rank = i+1
288 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
289 |
290 | global best_score
291 | if op_name == 'valid' and mrr > best_score:
292 | best_score = mrr
293 | saver = tf.train.Saver()
294 | dir_path = os.path.join(dir_path, "epoch {}".format(epoch_no))
295 | if not os.path.exists(dir_path):
296 | os.makedirs(dir_path)
297 | saver.save(sess, dir_path)
298 | tf.logging.info(">> save model!")
299 |
300 | return mrr
301 |
302 |
303 |
304 | def main(_):
305 | tf.logging.set_verbosity(tf.logging.INFO)
306 |
307 | print_configuration_op(FLAGS)
308 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
309 | root_path = FLAGS.output_dir
310 | if not os.path.exists(root_path):
311 | os.makedirs(root_path)
312 |
313 | timestamp = str(int(time()))
314 | root_path = os.path.join(root_path, timestamp)
315 | tf.logging.info('root_path: {}'.format(root_path))
316 | if not os.path.exists(root_path):
317 | os.makedirs(root_path)
318 |
319 | train_data_size = total_sample(FLAGS.train_dir)
320 | tf.logging.info('train data size: {}'.format(train_data_size))
321 | valid_data_size = total_sample(FLAGS.valid_dir)
322 | tf.logging.info('valid data size: {}'.format(valid_data_size))
323 |
324 | num_train_steps = train_data_size // FLAGS.train_batch_size * FLAGS.num_train_epochs
325 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
326 |
327 | filenames = tf.placeholder(tf.string, shape=[None])
328 | shuffle_size = tf.placeholder(tf.int64)
329 | dataset = tf.data.TFRecordDataset(filenames)
330 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
331 | dataset = dataset.repeat(1)
332 | # The shuffle buffer size is fed at run time: 1024 for training, 1 (i.e. no shuffling) for validation.
333 | dataset = dataset.shuffle(shuffle_size)
334 | dataset = dataset.batch(FLAGS.train_batch_size)
335 | iterator = dataset.make_initializable_iterator()
336 | ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels = iterator.get_next() # output dir
337 | pair_ids = [ques_ids, ans_ids, labels]
338 |
339 |
340 | training = tf.placeholder(tf.bool)
341 | mean_loss, logits, probabilities, accuracy, model = create_model(bert_config,
342 | is_training = training,
343 | input_ids = sents,
344 | input_mask = mask,
345 | segment_ids = segment_ids,
346 | switch_ids = switch_ids,
347 | labels = labels,
348 | ques_ids = ques_ids,
349 | ans_ids = ans_ids,
350 | num_labels = 1,
351 | use_one_hot_embeddings = False)
352 |
353 |
354 | # init model with pre-training
355 | tvars = tf.trainable_variables()
356 | if FLAGS.init_checkpoint:
357 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,FLAGS.init_checkpoint)
358 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)
359 |
360 | tf.logging.info("**** Trainable Variables ****")
361 | for var in tvars:
362 | init_string = ""
363 | if var.name in initialized_variable_names:
364 | init_string = ", *INIT_FROM_CKPT*"
365 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
366 | init_string)
367 |
368 |
369 | train_opt = optimization.create_optimizer(mean_loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, False)
370 |
371 | config = tf.ConfigProto(allow_soft_placement=True)
372 | config.gpu_options.allow_growth = True
373 |
374 |
375 | if FLAGS.do_train:
376 | with tf.Session(config=config) as sess:
377 | sess.run(tf.global_variables_initializer())
378 |
379 | for epoch in range(FLAGS.num_train_epochs):
380 | tf.logging.info('Epoch {} training begin'.format(epoch))
381 | sess.run(iterator.initializer,
382 | feed_dict={filenames: [FLAGS.train_dir], shuffle_size: 1024})
383 | run_epoch(epoch, "train", sess, training, logits, accuracy, mean_loss, train_opt)
384 |
385 | tf.logging.info('Valid begin')
386 | sess.run(iterator.initializer,
387 | feed_dict={filenames: [FLAGS.valid_dir], shuffle_size: 1})
388 | run_test(epoch, root_path, "valid", sess, training, accuracy, probabilities, pair_ids)
389 |
390 |
391 |
392 | if __name__ == "__main__":
393 | tf.app.run()
394 |
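395 | # Worked example of the step arithmetic in main() (illustrative numbers):
396 | # with 1,000,000 training samples, train_batch_size=12 and num_train_epochs=5,
397 | # num_train_steps = 1000000 // 12 * 5 = 416665,
398 | # and with warmup_proportion=0.1, num_warmup_steps = int(416665 * 0.1) = 41666.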
--------------------------------------------------------------------------------
/data/Ubuntu_V1_Xu/create_adaptation_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Create masked LM/next sentence masked_lm TF examples for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import random
23 | import tokenization
24 | import numpy as np
25 | import tensorflow as tf
26 | from tqdm import tqdm
27 |
28 | flags = tf.flags
29 | FLAGS = flags.FLAGS
30 |
31 | flags.DEFINE_string("train_file", './Ubuntu_Corpus_V1/train.txt',
32 | "Input raw text file (or comma-separated list of files).")
33 |
34 | flags.DEFINE_string("response_file", './Ubuntu_Corpus_V1/responses.txt',
35 | "Input raw text file (or comma-separated list of files).")
36 |
37 | flags.DEFINE_string("output_file", './pretrain_data.tfrecord',
38 | "Output TF example file (or comma-separated list of files).")
39 |
40 | flags.DEFINE_string("vocab_file", '../../uncased_L-12_H-768_A-12/vocab.txt',
41 | "The vocabulary file that the BERT model was trained on.")
42 |
43 | flags.DEFINE_bool("do_lower_case", True,
44 | "Whether to lower case the input text. Should be True for uncased "
45 | "models and False for cased models.")
46 |
47 | flags.DEFINE_integer("max_seq_length", 512,
48 | "Maximum sequence length.")
49 |
50 | flags.DEFINE_integer("max_predictions_per_seq", 25,
51 | "Maximum number of masked LM predictions per sequence.")
52 |
53 | flags.DEFINE_integer("random_seed", 12345,
54 | "Random seed for data generation.")
55 |
56 | flags.DEFINE_integer("dupe_factor", 10,
57 | "Number of times to duplicate the input data (with different masks).")
58 |
59 | flags.DEFINE_float("masked_lm_prob", 0.15,
60 | "Masked LM probability.")
61 |
62 | flags.DEFINE_float("short_seq_prob", 0.1,
63 | "Probability of creating sequences which are shorter than the maximum length.")
64 |
65 |
66 |
67 | class TrainingInstance(object):
68 | """A single training instance (sentence pair)."""
69 |
70 | def __init__(self, tokens, segment_ids, switch_ids, masked_lm_positions, masked_lm_labels,
71 | is_random_next):
72 | self.tokens = tokens
73 | self.segment_ids = segment_ids
74 | self.switch_ids = switch_ids
75 | self.is_random_next = is_random_next
76 | self.masked_lm_positions = masked_lm_positions
77 | self.masked_lm_labels = masked_lm_labels
78 |
79 | def __str__(self):
80 | s = ""
81 | s += "tokens: %s\n" % (" ".join(
82 | [tokenization.printable_text(x) for x in self.tokens]))
83 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
84 | s += "switch_ids: %s\n" % (" ".join([str(x) for x in self.switch_ids]))
85 | s += "is_random_next: %s\n" % self.is_random_next
86 | s += "masked_lm_positions: %s\n" % (" ".join(
87 | [str(x) for x in self.masked_lm_positions]))
88 | s += "masked_lm_labels: %s\n" % (" ".join(
89 | [tokenization.printable_text(x) for x in self.masked_lm_labels]))
90 | s += "\n"
91 | return s
92 |
93 | def __repr__(self):
94 | return self.__str__()
95 |
96 |
97 | def write_instance_to_example_files(instances, tokenizer, max_seq_length,
98 | max_predictions_per_seq, output_files):
99 | """Create TF example files from `TrainingInstance`s."""
100 | writers = []
101 | for output_file in output_files:
102 | writers.append(tf.python_io.TFRecordWriter(output_file))
103 |
104 | writer_index = 0
105 |
106 | total_written = 0
107 | for (inst_index, instance) in enumerate(instances):
108 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
109 | input_mask = [1] * len(input_ids)
110 | segment_ids = list(instance.segment_ids)
111 | switch_ids = list(instance.switch_ids)
112 | assert len(input_ids) <= max_seq_length
113 |
114 | while len(input_ids) < max_seq_length:
115 | input_ids.append(0)
116 | input_mask.append(0)
117 | segment_ids.append(0)
118 | switch_ids.append(0)
119 |
120 | assert len(input_ids) == max_seq_length
121 | assert len(input_mask) == max_seq_length
122 | assert len(segment_ids) == max_seq_length
123 | assert len(switch_ids) == max_seq_length
124 |
125 | masked_lm_positions = list(instance.masked_lm_positions)
126 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
127 | masked_lm_weights = [1.0] * len(masked_lm_ids)
128 |
129 | while len(masked_lm_positions) < max_predictions_per_seq:
130 | masked_lm_positions.append(0)
131 | masked_lm_ids.append(0)
132 | masked_lm_weights.append(0.0)
133 |
134 | next_sentence_label = 1 if instance.is_random_next else 0
135 |
136 | features = collections.OrderedDict()
137 | features["input_ids"] = create_int_feature(input_ids)
138 | features["input_mask"] = create_int_feature(input_mask)
139 | features["segment_ids"] = create_int_feature(segment_ids)
140 | features["switch_ids"] = create_int_feature(switch_ids)
141 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
142 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
143 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
144 | features["next_sentence_labels"] = create_int_feature([next_sentence_label])
145 |
146 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
147 |
148 | writers[writer_index].write(tf_example.SerializeToString())
149 | writer_index = (writer_index + 1) % len(writers)
150 |
151 | total_written += 1
152 |
153 | if inst_index < 20:
154 | tf.logging.info("*** Example ***")
155 | tf.logging.info("tokens: %s" % " ".join(
156 | [tokenization.printable_text(x) for x in instance.tokens]))
157 |
158 | for feature_name in features.keys():
159 | feature = features[feature_name]
160 | values = []
161 | if feature.int64_list.value:
162 | values = feature.int64_list.value
163 | elif feature.float_list.value:
164 | values = feature.float_list.value
165 | tf.logging.info(
166 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
167 |
168 | for writer in writers:
169 | writer.close()
170 |
171 | tf.logging.info("Wrote %d total instances", total_written)
172 |
173 |
174 | def create_int_feature(values):
175 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
176 | return feature
177 |
178 | def create_float_feature(values):
179 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
180 | return feature
181 |
182 |
183 | def create_training_instances(context, response, switch, tokenizer, max_seq_length,
184 | dupe_factor, short_seq_prob, masked_lm_prob,
185 | max_predictions_per_seq, rng):
186 |
187 | # Input format:
188 | # `context` is a list of utterance lists, `response` a list of response
189 | # strings, and `switch` a list of per-utterance speaker flags, all built
190 | # in main() from the tab-separated Ubuntu files. Segment "A" is the
191 | # concatenated context and segment "B" is the response, so the "next
192 | # sentence prediction" task never spans between conversations.
193 |
194 | sid_r = np.arange(0, len(context))
195 | rng.shuffle(sid_r)
196 |
197 | vocab_words = list(tokenizer.vocab.keys())
198 | instances = []
199 | for _ in tqdm(range(dupe_factor)):
200 | for i in tqdm(range(len(sid_r))):
201 |
202 | sent_a = []
203 | switch_a = []
204 | for j in range(len(context[i])):
205 | utterance_a = context[i][j]
206 | utterance_a = tokenization.convert_to_unicode(utterance_a)
207 | utterance_a = tokenizer.tokenize(utterance_a)
208 | sent_a.extend(utterance_a)
209 | switch_a.extend([switch[i][j]] * len(utterance_a))
210 | assert len(sent_a) == len(switch_a)
211 |
212 | if random.random() < 0.5:
213 | sent_b = response[sid_r[i]]
214 | is_random_next = True
215 | else:
216 | sent_b = response[i]
217 | is_random_next = False
218 |
219 | sent_b = tokenization.convert_to_unicode(sent_b)
220 | sent_b = tokenizer.tokenize(sent_b)
221 | instances.extend(
222 | create_instances_from_document(
223 | sent_a, sent_b, switch_a, is_random_next, max_seq_length, short_seq_prob,
224 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
225 |
226 | rng.shuffle(instances)
227 | return instances
228 |
229 |
230 | def create_instances_from_document(
231 | tokens_a, tokens_b, switch_a, is_random_next, max_seq_length, short_seq_prob,
232 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
233 | """Creates `TrainingInstance`s for a single document."""
234 |
235 | # Account for [CLS], [SEP], [SEP]
236 | max_num_tokens = max_seq_length - 3
237 |
238 | # We *usually* want to fill up the entire sequence since we are padding
239 | # to `max_seq_length` anyways, so short sequences are generally wasted
240 | # computation. Note that `short_seq_prob` is accepted here for
241 | # compatibility with the original BERT data pipeline but is not used:
242 | # the segment "A"/"B" split is fixed by the context/response boundary,
243 | # and `max_seq_length` remains a hard limit enforced by
244 | # `truncate_seq_pair`.
245 |
246 | # We DON'T just concatenate all of the tokens from a document into a long
247 | # sequence and choose an arbitrary split point because this would make the
248 | # next sentence prediction task too easy. Instead, we split the input into
249 | # segments "A" and "B" based on the actual "sentences" provided by the user
250 | # input.
251 | instances = []
252 |
253 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
254 |
255 | assert len(tokens_a) >= 1
256 | assert len(tokens_b) >= 1
257 |
258 | tokens = []
259 | segment_ids = []
260 | switch_ids = []
261 | tokens.append("[CLS]")
262 | segment_ids.append(0)
263 | switch_ids.append(0)
264 | for i, token in enumerate(tokens_a):
265 | tokens.append(token)
266 | segment_ids.append(0)
267 | switch_ids.append(switch_a[i])
268 |
269 | tokens.append("[SEP]")
270 | segment_ids.append(0)
271 | switch_ids.append(0)
272 |
273 | for token in tokens_b:
274 | tokens.append(token)
275 | segment_ids.append(1)
276 | switch_ids.append(1)
277 | tokens.append("[SEP]")
278 | segment_ids.append(1)
279 | switch_ids.append(1)
280 |
281 | (tokens, masked_lm_positions,
282 | masked_lm_labels) = create_masked_lm_predictions(
283 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
284 | instance = TrainingInstance(
285 | tokens=tokens,
286 | segment_ids=segment_ids,
287 | switch_ids=switch_ids,
288 | is_random_next=is_random_next,
289 | masked_lm_positions=masked_lm_positions,
290 | masked_lm_labels=masked_lm_labels)
291 | instances.append(instance)
292 |
293 | return instances
294 |
295 |
296 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
297 | ["index", "label"])
298 |
299 |
300 | def create_masked_lm_predictions(tokens, masked_lm_prob,
301 | max_predictions_per_seq, vocab_words, rng):
302 | """Creates the predictions for the masked LM objective."""
303 |
304 | cand_indexes = []
305 | for (i, token) in enumerate(tokens):
306 | if token == "[CLS]" or token == "[SEP]":
307 | continue
308 | cand_indexes.append(i)
309 |
310 | rng.shuffle(cand_indexes)
311 |
312 | output_tokens = list(tokens)
313 |
314 | num_to_predict = min(max_predictions_per_seq,
315 | max(1, int(round(len(tokens) * masked_lm_prob))))
316 |
317 | masked_lms = []
318 | covered_indexes = set()
319 | for index in cand_indexes:
320 | if len(masked_lms) >= num_to_predict:
321 | break
322 | if index in covered_indexes:
323 | continue
324 | covered_indexes.add(index)
325 |
326 | masked_token = None
327 | # 80% of the time, replace with [MASK]
328 | if rng.random() < 0.8:
329 | masked_token = "[MASK]"
330 | else:
331 | # 10% of the time, keep original
332 | if rng.random() < 0.5:
333 | masked_token = tokens[index]
334 | # 10% of the time, replace with random word
335 | else:
336 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
337 |
338 | output_tokens[index] = masked_token
339 |
340 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
341 |
342 | masked_lms = sorted(masked_lms, key=lambda x: x.index)
343 |
344 | masked_lm_positions = []
345 | masked_lm_labels = []
346 | for p in masked_lms:
347 | masked_lm_positions.append(p.index)
348 | masked_lm_labels.append(p.label)
349 |
350 | return (output_tokens, masked_lm_positions, masked_lm_labels)
351 |
352 |
353 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
354 | """Truncates a pair of sequences to a maximum sequence length."""
355 | while True:
356 | total_length = len(tokens_a) + len(tokens_b)
357 | if total_length <= max_num_tokens:
358 | break
359 |
360 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
361 | assert len(trunc_tokens) >= 1
362 |
363 | # We want to sometimes truncate from the front and sometimes from the
364 | # back to add more randomness and avoid biases.
365 | if rng.random() < 0.5:
366 | del trunc_tokens[0]
367 | else:
368 | trunc_tokens.pop()
369 |
370 | def print_configuration_op(FLAGS):
371 | print('My Configurations:')
372 | for name, value in FLAGS.__flags.items():
373 | value = value.value
374 | if type(value) == float:
375 | print(' %s:\t %f' % (name, value))
376 | elif type(value) == int:
377 | print(' %s:\t %d' % (name, value))
378 | elif type(value) == str:
379 | print(' %s:\t %s' % (name, value))
380 | elif type(value) == bool:
381 | print(' %s:\t %s' % (name, value))
382 | else:
383 | print('%s:\t %s' % (name, value))
384 | print('End of configuration')
385 |
386 | def main(_):
387 | tf.logging.set_verbosity(tf.logging.INFO)
388 | print_configuration_op(FLAGS)
389 |
390 | tokenizer = tokenization.FullTokenizer(
391 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
392 |
393 | # 1. load context-response pairs
394 | response_dict = {}
395 | with open(FLAGS.response_file, 'rt') as f:
396 | for line in f:
397 | line = line.strip()
398 | fields = line.split('\t')
399 | if len(fields) != 2:
400 | print("WRONG LINE: {}".format(line))
401 | r_text = 'unknown'
402 | else:
403 | r_text = fields[1]
404 | response_dict[fields[0]] = r_text
405 |
406 | context = []
407 | response = []
408 | switch = []
409 | with open(FLAGS.train_file, 'rb') as f:
410 | lines = f.readlines()
411 | for index, line in enumerate(lines):
412 | line = line.decode('utf-8').strip()
413 | fields = line.split('\t')
414 | context_i = fields[1]
415 | utterances_i = context_i.split(" __EOS__ ")
416 | # utterances = [utterance + " __EOS__" for utterance in utterances]
417 | new_utterances_i = []
418 | switch_i = []
419 | for j, utterance in enumerate(utterances_i):
420 | new_utterances_i.append(utterance + " __EOS__")
421 | if j%2 == 0:
422 | switch_i.append(0)
423 | else:
424 | switch_i.append(1)
425 | assert len(new_utterances_i) == len(switch_i)
426 |
427 | if fields[2] != "NA":
428 | pos_ids = fields[2].split('|')
429 | for r_id in pos_ids:
430 | context.append(new_utterances_i)
431 |
432 | switch.append(switch_i)
433 |
434 | response_i = response_dict[r_id]
435 | response.append(response_i)
436 |
437 | if index % 10000 == 0:
438 | print('Done:', index)
439 |
440 | tf.logging.info("Reading from input files: {} context-response pairs".format(len(context)))
441 |
442 |
443 | # 2. create training instances
444 | rng = random.Random(FLAGS.random_seed)
445 | instances = create_training_instances(
446 | context, response, switch, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
447 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
448 | rng)
449 |
450 |
451 | # 3. write instance to example files
452 | output_files = [FLAGS.output_file]
453 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
454 | FLAGS.max_predictions_per_seq, output_files)
455 |
456 |
457 | if __name__ == "__main__":
458 | flags.mark_flag_as_required("train_file")
459 | flags.mark_flag_as_required("response_file")
460 | flags.mark_flag_as_required("output_file")
461 | flags.mark_flag_as_required("vocab_file")
462 | tf.app.run()
463 |
464 |
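465 | # Minimal sketch (illustrative only): exercise the masking routine above on
466 | # a toy sequence with a fixed seed, to observe the 80%/10%/10% replacement
467 | # behaviour. The toy vocabulary is a stand-in, not the real BERT vocab.
468 | def demo_masking():
469 | rng = random.Random(12345)
470 | tokens = ["[CLS]", "the", "dog", "is", "hairy", ".", "[SEP]"]
471 | vocab_words = ["the", "dog", "is", "hairy", "."]
472 | return create_masked_lm_predictions(
473 | tokens, masked_lm_prob=0.15, max_predictions_per_seq=2,
474 | vocab_words=vocab_words, rng=rng)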
--------------------------------------------------------------------------------
/adapt_switch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Run masked LM/next sentence masked_lm pre-training for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import modeling_switch as modeling
23 | import optimization
24 | import tensorflow as tf
25 | from time import time
26 | import datetime
27 |
28 | flags = tf.flags
29 | FLAGS = flags.FLAGS
30 |
31 | ## Required parameters
32 | flags.DEFINE_integer("sample_num", '126',
33 | "total sample number")
34 |
35 | flags.DEFINE_integer("mid_save_step", '15000',
36 | "Epoch is so long, mid_save_step 15000 is roughly 3 hours")
37 |
38 | flags.DEFINE_string("input_file", 'output/test.tfrecord',
39 | "The input data dir. Should contain the .tsv files (or other data files) for the task.")
40 |
41 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
42 | "The config json file corresponding to the pre-trained BERT model. "
43 | "This specifies the model architecture.")
44 |
45 | flags.DEFINE_string("task_name", 'adaptation',
46 | "The name of the task to train.")
47 |
48 | flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt',
49 | "The vocabulary file that the BERT model was trained on.")
50 |
51 | flags.DEFINE_string("output_dir", './L-12_H-768_A-12_adapted',
52 | "The output directory where the model checkpoints will be written.")
53 |
54 | ## Other parameters
55 | flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt',
56 | "Initial checkpoint (usually from a pre-trained BERT model).")
57 |
58 | flags.DEFINE_integer("max_seq_length", 320,
59 | "The maximum total input sequence length after WordPiece tokenization. "
60 | "Sequences longer than this will be truncated, and sequences shorter "
61 | "than this will be padded. Must match data generation.")
62 |
63 | flags.DEFINE_integer("max_predictions_per_seq", 10,
64 | "Maximum number of masked LM predictions per sequence. "
65 | "Must match data generation.")
66 |
67 | flags.DEFINE_bool("do_train", True,
68 | "Whether to run training.")
69 |
70 | flags.DEFINE_bool("do_eval", True,
71 | "Whether to run eval on the dev set.")
72 |
73 | flags.DEFINE_integer("train_batch_size", 8,
74 | "Total batch size for training.")
75 |
76 | flags.DEFINE_integer("eval_batch_size", 8,
77 | "Total batch size for eval.")
78 |
79 | flags.DEFINE_float("learning_rate", 5e-5,
80 | "The initial learning rate for Adam.")
81 |
82 | flags.DEFINE_float("warmup_proportion", 0.1,
83 | "Number of warmup steps.")
84 |
85 | flags.DEFINE_integer("num_train_epochs", 10,
86 | "num_train_epochs.")
87 |
88 |
89 |
90 | def model_fn_builder(features, is_training, bert_config, init_checkpoint, learning_rate,
91 | num_train_steps, num_warmup_steps, use_tpu,
92 | use_one_hot_embeddings):
93 | """Returns `model_fn` closure for TPUEstimator."""
94 |
95 | input_ids, input_mask, segment_ids, switch_ids, masked_lm_positions, \
96 | masked_lm_ids, masked_lm_weights, next_sentence_labels = features
97 |
98 | model = modeling.BertModel(
99 | config=bert_config,
100 | is_training=is_training,
101 | input_ids=input_ids,
102 | input_mask=input_mask,
103 | token_type_ids=segment_ids,
104 | switch_ids=switch_ids,
105 | use_one_hot_embeddings=use_one_hot_embeddings)
106 |
107 | (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
108 | bert_config, model.get_sequence_output(), model.get_embedding_table(),
109 | masked_lm_positions, masked_lm_ids, masked_lm_weights)
110 |
111 | (next_sentence_loss, next_sentence_example_loss, next_sentence_log_probs) = get_next_sentence_output(
112 | bert_config, model.get_pooled_output(), next_sentence_labels)
113 |
114 | total_loss = masked_lm_loss + next_sentence_loss
115 |
116 | tvars = tf.trainable_variables()
117 |
118 | if init_checkpoint:
119 | (assignment_map,
120 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
121 |
122 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
123 |
124 |
125 | train_op = optimization.create_optimizer(
126 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
127 |
128 | metrics_dict = metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights,
129 | next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels)
130 |
131 | return train_op, total_loss, metrics_dict, input_ids
132 |
133 |
134 |
135 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights,
136 | next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels):
137 | """Computes the loss and accuracy of the model."""
138 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
139 | [-1, masked_lm_log_probs.shape[-1]]) # [batch_size*max_predictions_per_seq, dim]
140 | masked_lm_predictions = tf.argmax(
141 | masked_lm_log_probs, axis=-1, output_type=tf.int32) # [batch_size*max_predictions_per_seq, ]
142 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
143 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
144 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
145 | masked_lm_accuracy = tf.metrics.accuracy(
146 | labels=masked_lm_ids,
147 | predictions=masked_lm_predictions,
148 | weights=masked_lm_weights)
149 | masked_lm_mean_loss = tf.metrics.mean(
150 | values=masked_lm_example_loss, weights=masked_lm_weights)
151 |
152 | next_sentence_log_probs = tf.reshape(
153 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) # [batch_size, 2]
154 | next_sentence_predictions = tf.argmax(
155 | next_sentence_log_probs, axis=-1, output_type=tf.int32) # [batch_size, ]
156 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
157 | next_sentence_accuracy = tf.metrics.accuracy(
158 | labels=next_sentence_labels, predictions=next_sentence_predictions)
159 | next_sentence_mean_loss = tf.metrics.mean(
160 | values=next_sentence_example_loss)
161 | # next_sentence_mean_loss = tf.reduce_mean(next_sentence_example_loss)
162 |
163 | return {
164 | "masked_lm_accuracy": masked_lm_accuracy,
165 | "masked_lm_loss": masked_lm_mean_loss,
166 | "next_sentence_accuracy": next_sentence_accuracy,
167 | "next_sentence_loss": next_sentence_mean_loss,
168 | }
169 |
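Each entry in the dict returned above is a `(value_tensor, update_op)` pair -- the `tf.metrics` contract under TF 1.x: running the update op folds the current batch into streaming counters (stored as local variables), and fetching the value tensor reads the running average. That is why `main` below unpacks every entry into a value (collected in `evaluate`) and an update op (collected in `eval_op`) and runs both each step. A minimal standalone sketch of the contract (not part of this repo):

```python
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])

# tf.metrics.accuracy returns (value_tensor, update_op).
accuracy, update_op = tf.metrics.accuracy(labels=labels, predictions=predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # metric counters live in local variables
    sess.run(update_op)                         # fold this batch into the counters
    print(sess.run(accuracy))                   # 0.75 -- running accuracy so far
```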
170 |
171 |
172 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
173 | label_ids, label_weights):
174 | """Get loss and log probs for the masked LM."""
175 | input_tensor = gather_indexes(input_tensor, positions) # [batch_size*max_predictions_per_seq, dim]
176 |
177 | with tf.variable_scope("cls/predictions"):
178 | # We apply one more non-linear transformation before the output layer.
179 | # This matrix is not used after pre-training.
180 | with tf.variable_scope("transform"):
181 | input_tensor = tf.layers.dense(
182 | input_tensor,
183 | units=bert_config.hidden_size,
184 | activation=modeling.get_activation(bert_config.hidden_act),
185 | kernel_initializer=modeling.create_initializer(
186 | bert_config.initializer_range))
187 | input_tensor = modeling.layer_norm(input_tensor)
188 |
189 | # The output weights are the same as the input embeddings, but there is
190 | # an output-only bias for each token.
191 | output_bias = tf.get_variable(
192 | "output_bias",
193 | shape=[bert_config.vocab_size],
194 | initializer=tf.zeros_initializer())
195 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
196 | logits = tf.nn.bias_add(logits, output_bias)
197 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size*max_predictions_per_seq, vocab_size]
198 |
199 | label_ids = tf.reshape(label_ids, [-1])
200 | label_weights = tf.reshape(label_weights, [-1])
201 |
202 | one_hot_labels = tf.one_hot(
203 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
204 |
205 | # The `positions` tensor might be zero-padded (if the sequence is too
206 | # short to have the maximum number of predictions). The `label_weights`
207 | # tensor has a value of 1.0 for every real prediction and 0.0 for the
208 | # padding predictions.
209 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size*max_predictions_per_seq, ]
210 |     numerator = tf.reduce_sum(label_weights * per_example_loss)  # scalar
211 | denominator = tf.reduce_sum(label_weights) + 1e-5
212 | loss = numerator / denominator
213 |
214 | return (loss, per_example_loss, log_probs)
215 |
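The weighted average computed above is easy to verify on concrete numbers: padded prediction slots carry weight 0.0, so they contribute nothing to the numerator, and the `1e-5` keeps the division safe if a batch happens to contain no real predictions. A plain-NumPy sketch (not part of this repo):

```python
import numpy as np

# Two sequences x three prediction slots, flattened to six entries;
# the third slot of each sequence is padding (weight 0.0).
per_example_loss = np.array([2.0, 1.0, 9.9, 3.0, 1.5, 7.7])
label_weights    = np.array([1.0, 1.0, 0.0, 1.0, 1.0, 0.0])

numerator = np.sum(label_weights * per_example_loss)  # 7.5 -- padded losses ignored
denominator = np.sum(label_weights) + 1e-5            # ~4.0
print(numerator / denominator)                        # ~1.875, mean over real predictions
```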
216 |
217 | def get_next_sentence_output(bert_config, input_tensor, labels):
218 | """Get loss and log probs for the next sentence prediction."""
219 |
220 | # Simple binary classification. Note that 0 is "next sentence" and 1 is
221 | # "random sentence". This weight matrix is not used after pre-training.
222 | with tf.variable_scope("cls/seq_relationship"):
223 | output_weights = tf.get_variable(
224 | "output_weights",
225 | shape=[2, bert_config.hidden_size],
226 | initializer=modeling.create_initializer(bert_config.initializer_range))
227 | output_bias = tf.get_variable(
228 | "output_bias", shape=[2], initializer=tf.zeros_initializer())
229 |
230 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
231 | logits = tf.nn.bias_add(logits, output_bias)
232 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, 2]
233 | labels = tf.reshape(labels, [-1])
234 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
235 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) # [batch_size, ]
236 |     loss = tf.reduce_mean(per_example_loss)  # scalar
237 | return (loss, per_example_loss, log_probs)
238 |
239 |
240 | def gather_indexes(sequence_tensor, positions):
241 | """Gathers the vectors at the specific positions over a minibatch."""
242 | # sequence_tensor = [batch_size, seq_length, width]
243 | # positions = [batch_size, max_predictions_per_seq]
244 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
245 | batch_size = sequence_shape[0]
246 | seq_length = sequence_shape[1]
247 | width = sequence_shape[2]
248 |
249 | flat_offsets = tf.reshape(
250 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
251 | flat_positions = tf.reshape(positions + flat_offsets, [-1])
252 | flat_sequence_tensor = tf.reshape(sequence_tensor,
253 | [batch_size * seq_length, width])
254 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
255 | return output_tensor
256 |
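The flatten-and-offset trick above replaces a per-row gather with a single 1-D gather: each row's positions are shifted by `row_index * seq_length` before indexing the flattened `[batch_size * seq_length, width]` view. The same arithmetic in NumPy (not part of this repo):

```python
import numpy as np

batch_size, seq_length, width = 2, 4, 3
sequence = np.arange(batch_size * seq_length * width).reshape(
    batch_size, seq_length, width)
positions = np.array([[1, 3], [0, 2]])  # positions to gather, per sequence

flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)  # [[0], [4]]
flat_positions = (positions + flat_offsets).reshape(-1)             # [1, 3, 4, 6]
flat_sequence = sequence.reshape(batch_size * seq_length, width)
gathered = flat_sequence[flat_positions]                            # shape [4, 3]

assert (gathered[0] == sequence[0, 1]).all()  # row 0, position 1
assert (gathered[2] == sequence[1, 0]).all()  # row 1, position 0
```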
257 |
258 | def input_fn_builder(input_files,
259 | max_seq_length,
260 | max_predictions_per_seq,
261 | is_training,
262 | num_cpu_threads=4):
263 |   """Creates an `input_fn` closure to be passed to TPUEstimator. Note: this
264 |   does not parse the `switch_ids` feature; `main` below uses `parse_exmp`."""
265 | def input_fn(params):
266 | """The actual input function."""
267 | batch_size = params["batch_size"]
268 |
269 | name_to_features = {
270 | "input_ids":
271 | tf.FixedLenFeature([max_seq_length], tf.int64),
272 | "input_mask":
273 | tf.FixedLenFeature([max_seq_length], tf.int64),
274 | "segment_ids":
275 | tf.FixedLenFeature([max_seq_length], tf.int64),
276 | "masked_lm_positions":
277 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
278 | "masked_lm_ids":
279 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
280 | "masked_lm_weights":
281 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
282 | "next_sentence_labels":
283 | tf.FixedLenFeature([1], tf.int64),
284 | }
285 |
286 | # For training, we want a lot of parallel reading and shuffling.
287 | # For eval, we want no shuffling and parallel reading doesn't matter.
288 | if is_training:
289 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
290 | d = d.repeat()
291 | d = d.shuffle(buffer_size=len(input_files))
292 |
293 | # `cycle_length` is the number of parallel files that get read.
294 | cycle_length = min(num_cpu_threads, len(input_files))
295 |
296 | # `sloppy` mode means that the interleaving is not exact. This adds
297 | # even more randomness to the training pipeline.
298 | d = d.apply(
299 | tf.contrib.data.parallel_interleave(
300 | tf.data.TFRecordDataset,
301 | sloppy=is_training,
302 | cycle_length=cycle_length))
303 | d = d.shuffle(buffer_size=100)
304 | else:
305 | d = tf.data.TFRecordDataset(input_files)
306 | # Since we evaluate for a fixed number of steps we don't want to encounter
307 | # out-of-range exceptions.
308 | d = d.repeat()
309 |
310 | # We must `drop_remainder` on training because the TPU requires fixed
311 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
312 |     # and we *don't* want to drop the remainder, otherwise we won't cover
313 | # every sample.
314 | d = d.apply(
315 | tf.contrib.data.map_and_batch(
316 | lambda record: _decode_record(record, name_to_features),
317 | batch_size=batch_size,
318 | num_parallel_batches=num_cpu_threads,
319 | drop_remainder=True))
320 | return d
321 |
322 | return input_fn
323 |
324 |
325 | def _decode_record(record, name_to_features):
326 | """Decodes a record to a TensorFlow example."""
327 | example = tf.parse_single_example(record, name_to_features)
328 |
329 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
330 | # So cast all int64 to int32.
331 | for name in list(example.keys()):
332 | t = example[name]
333 | if t.dtype == tf.int64:
334 | t = tf.to_int32(t)
335 | example[name] = t
336 |
337 | return example
338 |
339 |
340 |
341 | def run_epoch(epoch, sess, evaluate, eval_op, input_ids, lm_losses, saver, root_path, save_step, mid_save_step, phase, batch_size=16, train_op=tf.constant(0)):
342 | t_loss = 0
343 | n_all = 0
344 | t0 = time()
345 | t1 = time()
346 |
347 | masked_lm_accuracy = 0.0
348 | masked_lm_mean_loss = 0.0
349 | next_sentence_accuracy = 0.0
350 | next_sentence_mean_loss = 0.0
351 |
352 | step = 0
353 |
354 |     print('epoch {} running ...'.format(epoch))
355 | try:
356 | while True:
357 | step = step + 1
358 |             y, metric_values, batch_loss, _, _ = sess.run([input_ids, evaluate, lm_losses, train_op, eval_op])
359 |             masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss = metric_values
360 |
361 | n_sample = len(y)
362 | n_all += n_sample
363 |
364 | t_loss += batch_loss * n_sample
365 |             # periodically save a mid-epoch checkpoint
366 | # if (step % save_step == 0) or (step % 15000 == 0):
367 | if (step % mid_save_step == 2):
368 | # c_time = str(datetime.datetime.now()).replace(' ', '-').split('.')[0]
369 | c_time = str(int(time()))
370 | save_path = os.path.join(root_path, 'bert_model_{0}_epoch_{1}'.format(c_time, epoch))
371 | if not os.path.exists(save_path):
372 | os.makedirs(save_path)
373 | saver.save(sess, os.path.join(save_path,'bert_model_{}.ckpt'.format(c_time)), global_step = step)
374 |                 print('save model at epoch {} step {}'.format(epoch, step))
375 |                 print('masked_lm_accuracy {:.6f}, masked_lm_mean_loss {:.6f}, next_sentence_accuracy {:.6f}, next_sentence_mean_loss {:.6f}'.format(
376 |                     masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss
377 |                 ))
378 | 
379 |                 print("{} Loss: {:.4f}, {:.2f} seconds used".
380 |                       format(phase, t_loss / n_all, time() - t1))
381 |                 t1 = time()
382 |                 print('Samples seen {} total time {:.2f}'.format(n_all, time() - t0))
383 |
384 | except tf.errors.OutOfRangeError:
385 | print('Epoch {} Done'.format(epoch))
386 | # c_time = str(datetime.datetime.now()).replace(' ', '-').split('.')[0]
387 | c_time = str(int(time()))
388 |         save_path = os.path.join(root_path, 'bert_model_{0}_epoch_{1}'.format(c_time, epoch))
389 | if not os.path.exists(save_path):
390 | os.makedirs(save_path)
391 | saver.save(sess, os.path.join(save_path, 'bert_model_{}.ckpt'.format(c_time)), global_step=step)
392 |         print('save model at end of epoch {}'.format(epoch))
393 |
394 | print(
395 |             'masked_lm_accuracy {:.6f}, masked_lm_mean_loss {:.6f}, next_sentence_accuracy {:.6f}, next_sentence_mean_loss {:.6f}'.format(
396 | masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss
397 | ))
398 |         print("{} Loss: {:.4f}, {:.2f} seconds used".
399 |               format(phase, t_loss / n_all, time() - t1))
400 |         t1 = time()
401 |         print('Samples seen {} total time {:.2f}'.format(n_all, time() - t0))
403 |
404 |
405 |
406 | def parse_exmp(serial_exmp):
407 | input_data = tf.parse_single_example(serial_exmp,
408 | features={
409 | "input_ids":
410 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
411 | "input_mask":
412 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
413 | "segment_ids":
414 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
415 | "switch_ids":
416 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
417 | "masked_lm_positions":
418 | tf.FixedLenFeature([FLAGS.max_predictions_per_seq], tf.int64),
419 | "masked_lm_ids":
420 | tf.FixedLenFeature([FLAGS.max_predictions_per_seq], tf.int64),
421 | "masked_lm_weights":
422 | tf.FixedLenFeature([FLAGS.max_predictions_per_seq], tf.float32),
423 | "next_sentence_labels":
424 | tf.FixedLenFeature([1], tf.int64),
425 | }
426 | )
427 |     # tf.Example only supports tf.int64, but the TPU only supports tf.int32, so cast.
428 | for name in list(input_data.keys()):
429 | t = input_data[name]
430 | if t.dtype == tf.int64:
431 | t = tf.to_int32(t)
432 | input_data[name] = t
433 |
434 | input_ids = input_data["input_ids"]
435 | input_mask = input_data["input_mask"]
436 | segment_ids = input_data["segment_ids"]
437 | switch_ids = input_data["switch_ids"]
438 | m_lp = input_data["masked_lm_positions"]
439 | m_lids = input_data["masked_lm_ids"]
440 | m_lm_w = input_data["masked_lm_weights"]
441 | nsl = input_data["next_sentence_labels"]
442 | return input_ids, input_mask, segment_ids, switch_ids, m_lp, m_lids, m_lm_w, nsl
443 |
444 |
445 | def print_configuration_op(FLAGS):
446 | print('My Configurations:')
447 | #pdb.set_trace()
448 | for name, value in FLAGS.__flags.items():
449 |         value = value.value
450 | if type(value) == float:
451 | print(' %s:\t %f'%(name, value))
452 | elif type(value) == int:
453 | print(' %s:\t %d'%(name, value))
454 | elif type(value) == str:
455 | print(' %s:\t %s'%(name, value))
456 | elif type(value) == bool:
457 | print(' %s:\t %s'%(name, value))
458 | else:
459 | print('%s:\t %s' % (name, value))
460 | print('End of configuration')
461 |
462 |
463 | def main(_):
464 | tf.logging.set_verbosity(tf.logging.INFO)
465 | print_configuration_op(FLAGS)
466 |
467 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
468 | root_path = FLAGS.output_dir
469 | if not os.path.exists(root_path):
470 | os.makedirs(root_path)
471 |
472 | num_train_steps = FLAGS.sample_num // FLAGS.train_batch_size * FLAGS.num_train_epochs
473 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
474 |
475 | buffer_size = 1000
476 | filenames = tf.placeholder(tf.string, shape=[None])
477 | dataset = tf.data.TFRecordDataset(filenames)
478 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
479 | dataset = dataset.repeat(1)
480 | dataset = dataset.shuffle(buffer_size)
481 | dataset = dataset.batch(FLAGS.train_batch_size)
482 | iterator = dataset.make_initializable_iterator()
483 | save_step = FLAGS.sample_num // FLAGS.train_batch_size
484 |
485 | input_ids, input_mask, segment_ids, switch_ids, masked_lm_positions, \
486 | masked_lm_ids, masked_lm_weights, next_sentence_labels = iterator.get_next()
487 | features = [input_ids, input_mask, segment_ids, switch_ids, masked_lm_positions, \
488 | masked_lm_ids, masked_lm_weights, next_sentence_labels]
489 |     train_op, loss, metrics, input_ids = model_fn_builder(
490 |         features,
491 | is_training=True,
492 | bert_config=bert_config,
493 | init_checkpoint=FLAGS.init_checkpoint,
494 | learning_rate=FLAGS.learning_rate,
495 | num_train_steps=num_train_steps,
496 | num_warmup_steps=num_warmup_steps,
497 | use_tpu=False,
498 | use_one_hot_embeddings=False)
499 |
500 |
501 |     masked_lm_accuracy, masked_acc_op = metrics["masked_lm_accuracy"]
502 |     masked_lm_mean_loss, masked_loss_op = metrics["masked_lm_loss"]
503 |     next_sentence_accuracy, next_sentence_op = metrics["next_sentence_accuracy"]
504 |     next_sentence_mean_loss, next_sentence_loss_op = metrics["next_sentence_loss"]
505 |
506 | evaluate = [masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss]
507 | eval_op = [masked_acc_op, masked_loss_op, next_sentence_op, next_sentence_loss_op]
508 |
509 | config = tf.ConfigProto(allow_soft_placement=True)
510 | config.gpu_options.allow_growth = True
511 | saver = tf.train.Saver()
512 | with tf.Session(config=config) as sess:
513 | sess.run(tf.global_variables_initializer())
514 | sess.run(tf.local_variables_initializer())
515 |
516 | for epoch in range(FLAGS.num_train_epochs):
517 | sess.run(iterator.initializer, feed_dict={filenames: [FLAGS.input_file]})
518 | run_epoch(epoch, sess, evaluate, eval_op, input_ids, loss, saver, root_path, save_step,
519 | FLAGS.mid_save_step,'train', batch_size=FLAGS.train_batch_size, train_op=train_op)
520 |
521 |
522 |
523 | if __name__ == "__main__":
524 | tf.app.run()
525 |
526 |
--------------------------------------------------------------------------------
/modeling_switch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The main BERT model and related functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import copy
23 | import json
24 | import math
25 | import re
26 | import numpy as np
27 | import six
28 | import tensorflow as tf
29 |
30 |
31 | class BertConfig(object):
32 | """Configuration for `BertModel`."""
33 |
34 | def __init__(self,
35 | vocab_size,
36 | hidden_size=768,
37 | num_hidden_layers=12,
38 | num_attention_heads=12,
39 | intermediate_size=3072,
40 | hidden_act="gelu",
41 | hidden_dropout_prob=0.1,
42 | attention_probs_dropout_prob=0.1,
43 | max_position_embeddings=512,
44 | type_vocab_size=16,
45 | initializer_range=0.02):
46 | """Constructs BertConfig.
47 |
48 | Args:
49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
50 | hidden_size: Size of the encoder layers and the pooler layer.
51 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
52 | num_attention_heads: Number of attention heads for each attention layer in
53 | the Transformer encoder.
54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
55 | layer in the Transformer encoder.
56 | hidden_act: The non-linear activation function (function or string) in the
57 | encoder and pooler.
58 | hidden_dropout_prob: The dropout probability for all fully connected
59 | layers in the embeddings, encoder, and pooler.
60 | attention_probs_dropout_prob: The dropout ratio for the attention
61 | probabilities.
62 | max_position_embeddings: The maximum sequence length that this model might
63 | ever be used with. Typically set this to something large just in case
64 | (e.g., 512 or 1024 or 2048).
65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
66 | `BertModel`.
67 | initializer_range: The stdev of the truncated_normal_initializer for
68 | initializing all weight matrices.
69 | """
70 | self.vocab_size = vocab_size
71 | self.hidden_size = hidden_size
72 | self.num_hidden_layers = num_hidden_layers
73 | self.num_attention_heads = num_attention_heads
74 | self.hidden_act = hidden_act
75 | self.intermediate_size = intermediate_size
76 | self.hidden_dropout_prob = hidden_dropout_prob
77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
78 | self.max_position_embeddings = max_position_embeddings
79 | self.type_vocab_size = type_vocab_size
80 | self.initializer_range = initializer_range
81 |
82 | @classmethod
83 | def from_dict(cls, json_object):
84 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
85 | config = BertConfig(vocab_size=None)
86 | for (key, value) in six.iteritems(json_object):
87 | config.__dict__[key] = value
88 | return config
89 |
90 | @classmethod
91 | def from_json_file(cls, json_file):
92 | """Constructs a `BertConfig` from a json file of parameters."""
93 | with tf.gfile.GFile(json_file, "r") as reader:
94 | text = reader.read()
95 | return cls.from_dict(json.loads(text))
96 |
97 | def to_dict(self):
98 | """Serializes this instance to a Python dictionary."""
99 | output = copy.deepcopy(self.__dict__)
100 | return output
101 |
102 | def to_json_string(self):
103 | """Serializes this instance to a JSON string."""
104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
105 |
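`from_dict` starts from `BertConfig(vocab_size=None)` and overwrites only the keys the dictionary provides, so omitted fields keep their constructor defaults. A quick usage sketch (not part of this repo):

```python
config = BertConfig.from_dict({
    "vocab_size": 30522,
    "hidden_size": 768,
    "num_hidden_layers": 12,
})
print(config.num_attention_heads)  # 12 -- default kept, key not in the dict
print(config.to_json_string())     # round-trips to a bert_config.json-style string
```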
106 |
107 | class BertModel(object):
108 | """BERT model ("Bidirectional Encoder Representations from Transformers").
109 |
110 | Example usage:
111 |
112 | ```python
113 | # Already been converted into WordPiece token ids
114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
116 |   token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
117 |   switch_ids = tf.constant([[0, 1, 1], [0, 0, 1]])
118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
120 |
121 |   model = modeling.BertModel(config=config, is_training=True, input_ids=input_ids,
122 |     input_mask=input_mask, token_type_ids=token_type_ids, switch_ids=switch_ids)
123 |
124 | label_embeddings = tf.get_variable(...)
125 | pooled_output = model.get_pooled_output()
126 | logits = tf.matmul(pooled_output, label_embeddings)
127 | ...
128 | ```
129 | """
130 |
131 | def __init__(self,
132 | config,
133 | is_training,
134 | input_ids,
135 | input_mask=None,
136 | token_type_ids=None,
137 |                switch_ids=None,  # speaker switch ids in {0, 1}, see embedding_postprocessor
138 | use_one_hot_embeddings=False,
139 | scope=None):
140 | """Constructor for BertModel.
141 |
142 | Args:
143 | config: `BertConfig` instance.
144 | is_training: bool. true for training model, false for eval model. Controls
145 | whether dropout will be applied.
146 | input_ids: int32 Tensor of shape [batch_size, seq_length].
147 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
148 |       token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
149 |       switch_ids: (optional) int32 Tensor of the same shape, values in {0, 1};
150 |         indexes the speaker-aware switch embedding table.
151 |       use_one_hot_embeddings: (optional) bool. Whether to use one-hot word embeddings or tf.embedding_lookup().
152 |       scope: (optional) variable scope. Defaults to "bert".
153 | Raises:
154 | ValueError: The config is invalid or one of the input tensor shapes
155 | is invalid.
156 | """
157 | config = copy.deepcopy(config)
158 |     # Scale the dropout probabilities by is_training (cast to float), so that
159 |     # dropout is disabled at eval time; this also works when is_training is a
160 |     # tensor-valued flag rather than a Python bool.
161 |     config.hidden_dropout_prob = tf.cast(is_training, tf.float32) * config.hidden_dropout_prob
162 |     config.attention_probs_dropout_prob = tf.cast(is_training, tf.float32) * config.attention_probs_dropout_prob
163 |
164 | input_shape = get_shape_list(input_ids, expected_rank=2)
165 | batch_size = input_shape[0]
166 | seq_length = input_shape[1]
167 |
168 | if input_mask is None:
169 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
170 |
171 | if token_type_ids is None:
172 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
173 |
174 | with tf.variable_scope(scope, default_name="bert"):
175 | with tf.variable_scope("embeddings"):
176 | # Perform embedding lookup on the word ids.
177 | (self.embedding_output, self.embedding_table) = embedding_lookup(
178 | input_ids=input_ids,
179 | vocab_size=config.vocab_size,
180 | embedding_size=config.hidden_size,
181 | initializer_range=config.initializer_range,
182 | word_embedding_name="word_embeddings",
183 | use_one_hot_embeddings=use_one_hot_embeddings)
184 |
185 | # Add positional embeddings and token type embeddings, then layer
186 | # normalize and perform dropout.
187 | self.embedding_output = embedding_postprocessor(
188 | input_tensor=self.embedding_output,
189 | use_token_type=True,
190 | token_type_ids=token_type_ids,
191 | token_type_vocab_size=config.type_vocab_size,
192 | token_type_embedding_name="token_type_embeddings",
193 | use_switch=True,
194 | switch_ids=switch_ids,
195 | use_position_embeddings=True,
196 | position_embedding_name="position_embeddings",
197 | initializer_range=config.initializer_range,
198 | max_position_embeddings=config.max_position_embeddings,
199 | dropout_prob=config.hidden_dropout_prob)
200 |
201 | with tf.variable_scope("encoder"):
202 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
203 | # mask of shape [batch_size, seq_length, seq_length] which is used
204 | # for the attention scores.
205 | attention_mask = create_attention_mask_from_input_mask(
206 | input_ids, input_mask)
207 |
208 | # Run the stacked transformer.
209 | # `sequence_output` shape = [batch_size, seq_length, hidden_size].
210 | self.all_encoder_layers = transformer_model(
211 | input_tensor=self.embedding_output,
212 | attention_mask=attention_mask,
213 | hidden_size=config.hidden_size,
214 | num_hidden_layers=config.num_hidden_layers,
215 | num_attention_heads=config.num_attention_heads,
216 | intermediate_size=config.intermediate_size,
217 | intermediate_act_fn=get_activation(config.hidden_act),
218 | hidden_dropout_prob=config.hidden_dropout_prob,
219 | attention_probs_dropout_prob=config.attention_probs_dropout_prob,
220 | initializer_range=config.initializer_range,
221 | do_return_all_layers=True)
222 |
223 | self.sequence_output = self.all_encoder_layers[-1]
224 | # The "pooler" converts the encoded sequence tensor of shape
225 | # [batch_size, seq_length, hidden_size] to a tensor of shape
226 | # [batch_size, hidden_size]. This is necessary for segment-level
227 | # (or segment-pair-level) classification tasks where we need a fixed
228 | # dimensional representation of the segment.
229 | with tf.variable_scope("pooler"):
230 | # We "pool" the model by simply taking the hidden state corresponding
231 | # to the first token. We assume that this has been pre-trained
232 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
233 | self.pooled_output = tf.layers.dense(
234 | first_token_tensor,
235 | config.hidden_size,
236 | activation=tf.tanh,
237 | kernel_initializer=create_initializer(config.initializer_range))
238 |
239 | def get_pooled_output(self):
240 | return self.pooled_output
241 |
242 | def get_sequence_output(self):
243 | """Gets final hidden layer of encoder.
244 |
245 | Returns:
246 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
247 | to the final hidden of the transformer encoder.
248 | """
249 | return self.sequence_output
250 |
251 | def get_all_encoder_layers(self):
252 | return self.all_encoder_layers
253 |
254 | def get_embedding_output(self):
255 | """Gets output of the embedding lookup (i.e., input to the transformer).
256 |
257 | Returns:
258 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
259 | to the output of the embedding layer, after summing the word
260 | embeddings with the positional embeddings and the token type embeddings,
261 | then performing layer normalization. This is the input to the transformer.
262 | """
263 | return self.embedding_output
264 |
265 | def get_embedding_table(self):
266 | return self.embedding_table
267 |
268 |
269 | def gelu(x):
270 | """Gaussian Error Linear Unit.
271 |
272 |   This is a smoother version of the ReLU.
273 | Original paper: https://arxiv.org/abs/1606.08415
274 | Args:
275 | x: float Tensor to perform activation.
276 |
277 | Returns:
278 | `x` with the GELU activation applied.
279 | """
280 | cdf = 0.5 * (1.0 + tf.tanh(
281 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
282 | return x * cdf
283 |
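The tanh expression above is the fast approximation to the exact GELU x * Phi(x), where Phi is the standard normal CDF; the two typically agree to about 1e-3. A quick scalar check in plain Python (not part of this repo):

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi computed via erf.
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # The tanh approximation used in gelu() above.
    return x * 0.5 * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(v, gelu_exact(v), gelu_tanh(v))  # differences are ~1e-3 or smaller
```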
284 |
285 | def get_activation(activation_string):
286 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
287 |
288 | Args:
289 | activation_string: String name of the activation function.
290 |
291 | Returns:
292 | A Python function corresponding to the activation function. If
293 | `activation_string` is None, empty, or "linear", this will return None.
294 | If `activation_string` is not a string, it will return `activation_string`.
295 |
296 | Raises:
297 | ValueError: The `activation_string` does not correspond to a known
298 | activation.
299 | """
300 |
301 |   # We assume that anything that's not a string is already an activation
302 | # function, so we just return it.
303 | if not isinstance(activation_string, six.string_types):
304 | return activation_string
305 |
306 | if not activation_string:
307 | return None
308 |
309 | act = activation_string.lower()
310 | if act == "linear":
311 | return None
312 | elif act == "relu":
313 | return tf.nn.relu
314 | elif act == "gelu":
315 | return gelu
316 | elif act == "tanh":
317 | return tf.tanh
318 | else:
319 | raise ValueError("Unsupported activation: %s" % act)
320 |
321 |
322 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
323 | """Compute the union of the current variables and checkpoint variables."""
325 | initialized_variable_names = {}
326 |
327 | name_to_variable = collections.OrderedDict()
328 | for var in tvars:
329 | name = var.name
330 | m = re.match("^(.*):\\d+$", name)
331 | if m is not None:
332 | name = m.group(1)
333 | name_to_variable[name] = var
334 |
335 | init_vars = tf.train.list_variables(init_checkpoint)
336 |
337 | assignment_map = collections.OrderedDict()
338 | for x in init_vars:
339 | (name, var) = (x[0], x[1])
340 | if name not in name_to_variable:
341 | continue
342 | assignment_map[name] = name
343 | initialized_variable_names[name] = 1
344 | initialized_variable_names[name + ":0"] = 1
345 |
346 | return (assignment_map, initialized_variable_names)
347 |
348 |
349 | def dropout(input_tensor, dropout_prob):
350 | """Perform dropout.
351 |
352 | Args:
353 | input_tensor: float Tensor.
354 | dropout_prob: Python float. The probability of dropping out a value (NOT of
355 | *keeping* a dimension as in `tf.nn.dropout`).
356 |
357 | Returns:
358 | A version of `input_tensor` with dropout applied.
359 | """
360 | if dropout_prob is None or dropout_prob == 0.0:
361 | return input_tensor
362 |
363 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
364 | return output
365 |
366 |
367 | def layer_norm(input_tensor, name=None):
368 | """Run layer normalization on the last dimension of the tensor."""
369 | return tf.contrib.layers.layer_norm(
370 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
371 |
372 |
373 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
374 | """Runs layer normalization followed by dropout."""
375 | output_tensor = layer_norm(input_tensor, name)
376 | output_tensor = dropout(output_tensor, dropout_prob)
377 | return output_tensor
378 |
379 |
380 | def create_initializer(initializer_range=0.02):
381 | """Creates a `truncated_normal_initializer` with the given range."""
382 | return tf.truncated_normal_initializer(stddev=initializer_range)
383 |
384 |
385 | def embedding_lookup(input_ids,
386 | vocab_size,
387 | embedding_size=128,
388 | initializer_range=0.02,
389 | word_embedding_name="word_embeddings",
390 | use_one_hot_embeddings=False):
391 | """Looks up words embeddings for id tensor.
392 |
393 | Args:
394 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
395 | ids.
396 | vocab_size: int. Size of the embedding vocabulary.
397 | embedding_size: int. Width of the word embeddings.
398 | initializer_range: float. Embedding initialization range.
399 | word_embedding_name: string. Name of the embedding table.
400 | use_one_hot_embeddings: bool. If True, use one-hot method for word
401 | embeddings. If False, use `tf.gather()`.
402 |
403 | Returns:
404 | float Tensor of shape [batch_size, seq_length, embedding_size].
405 | """
406 | # This function assumes that the input is of shape [batch_size, seq_length,
407 | # num_inputs].
408 | #
409 | # If the input is a 2D tensor of shape [batch_size, seq_length], we
410 | # reshape to [batch_size, seq_length, 1].
411 | if input_ids.shape.ndims == 2:
412 | input_ids = tf.expand_dims(input_ids, axis=[-1])
413 |
414 | embedding_table = tf.get_variable(
415 | name=word_embedding_name,
416 | shape=[vocab_size, embedding_size],
417 | initializer=create_initializer(initializer_range))
418 |
419 | flat_input_ids = tf.reshape(input_ids, [-1])
420 | if use_one_hot_embeddings:
421 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
422 | output = tf.matmul(one_hot_input_ids, embedding_table)
423 | else:
424 | output = tf.gather(embedding_table, flat_input_ids)
425 |
426 | input_shape = get_shape_list(input_ids)
427 |
428 | output = tf.reshape(output,
429 | input_shape[0:-1] + [input_shape[-1] * embedding_size])
430 | return (output, embedding_table)
431 |
432 |
433 | def embedding_postprocessor(input_tensor,
434 | use_token_type=False,
435 | token_type_ids=None,
436 | token_type_vocab_size=16,
437 | token_type_embedding_name="token_type_embeddings",
438 |                             use_switch=False,   # add speaker switch embeddings
439 |                             switch_ids=None,    # int32 [batch_size, seq_length], values in {0, 1}
440 | use_position_embeddings=True,
441 | position_embedding_name="position_embeddings",
442 | initializer_range=0.02,
443 | max_position_embeddings=512,
444 | dropout_prob=0.1):
445 | """Performs various post-processing on a word embedding tensor.
446 |
447 | Args:
448 | input_tensor: float Tensor of shape [batch_size, seq_length,
449 | embedding_size].
450 | use_token_type: bool. Whether to add embeddings for `token_type_ids`.
451 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
452 | Must be specified if `use_token_type` is True.
453 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
454 | token_type_embedding_name: string. The name of the embedding table variable
455 | for token type ids.
456 | use_position_embeddings: bool. Whether to add position embeddings for the
457 | position of each token in the sequence.
458 | position_embedding_name: string. The name of the embedding table variable
459 | for positional embeddings.
460 | initializer_range: float. Range of the weight initialization.
461 | max_position_embeddings: int. Maximum sequence length that might ever be
462 | used with this model. This can be longer than the sequence length of
463 | input_tensor, but cannot be shorter.
464 | dropout_prob: float. Dropout probability applied to the final output tensor.
465 |
466 | Returns:
467 | float tensor with same shape as `input_tensor`.
468 |
469 | Raises:
470 | ValueError: One of the tensor shapes or input values is invalid.
471 | """
472 | input_shape = get_shape_list(input_tensor, expected_rank=3)
473 | batch_size = input_shape[0]
474 | seq_length = input_shape[1]
475 | width = input_shape[2]
476 |
477 | output = input_tensor
478 |
479 | if use_token_type:
480 | if token_type_ids is None:
481 |       raise ValueError("`token_type_ids` must be specified if "
482 |                        "`use_token_type` is True.")
483 | token_type_table = tf.get_variable(
484 | name=token_type_embedding_name,
485 | shape=[token_type_vocab_size, width],
486 | initializer=create_initializer(initializer_range))
487 | # This vocab will be small so we always do one-hot here, since it is always
488 | # faster for a small vocabulary.
489 | flat_token_type_ids = tf.reshape(token_type_ids, [-1])
490 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
491 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
492 | token_type_embeddings = tf.reshape(token_type_embeddings,
493 | [batch_size, seq_length, width])
494 | output += token_type_embeddings
495 |
496 | if use_switch:
497 | if switch_ids is None:
498 |       raise ValueError("`switch_ids` must be specified if "
499 |                        "`use_switch` is True.")
500 |     switch_type_table = tf.get_variable(
501 |         name="switch_embedding",
502 |         shape=[2, width],
503 |         initializer=create_initializer(initializer_range))
504 |     # Only two switch values, so a one-hot matmul lookup is always used here.
505 | flat_switch_ids = tf.reshape(switch_ids, [-1])
506 | switch_one_hot_ids = tf.one_hot(flat_switch_ids, depth=2)
507 | switch_type_embeddings = tf.matmul(switch_one_hot_ids, switch_type_table)
508 | switch_type_embeddings = tf.reshape(switch_type_embeddings,
509 | [batch_size, seq_length, width])
510 | output += switch_type_embeddings
511 |
512 | if use_position_embeddings:
513 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
514 | with tf.control_dependencies([assert_op]):
515 | full_position_embeddings = tf.get_variable(
516 | name=position_embedding_name,
517 | shape=[max_position_embeddings, width],
518 | initializer=create_initializer(initializer_range))
519 | # Since the position embedding table is a learned variable, we create it
520 | # using a (long) sequence length `max_position_embeddings`. The actual
521 | # sequence length might be shorter than this, for faster training of
522 | # tasks that do not have long sequences.
523 | #
524 | # So `full_position_embeddings` is effectively an embedding table
525 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
526 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
527 | # perform a slice.
528 | position_embeddings = tf.slice(full_position_embeddings, [0, 0],
529 | [seq_length, -1])
530 | num_dims = len(output.shape.as_list())
531 |
532 | # Only the last two dimensions are relevant (`seq_length` and `width`), so
533 | # we broadcast among the first dimensions, which is typically just
534 | # the batch size.
535 | position_broadcast_shape = []
536 | for _ in range(num_dims - 2):
537 | position_broadcast_shape.append(1)
538 | position_broadcast_shape.extend([seq_length, width])
539 | position_embeddings = tf.reshape(position_embeddings,
540 | position_broadcast_shape)
541 | output += position_embeddings
542 |
543 | output = layer_norm_and_dropout(output, dropout_prob)
544 | return output
545 |
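The `use_switch` branch above is the speaker-aware addition in this fork: a two-row embedding table indexed by `switch_ids` and summed into the representation exactly like the token-type embeddings. The actual `switch_ids` tensors come from the data-preparation scripts; the helper below is a hypothetical sketch (names invented here, not from this repo) of one plausible construction for a context whose speakers alternate:

```python
def build_switch_ids(utterance_token_counts):
    """Toy switch_ids for utterances whose speakers alternate 0/1."""
    switch_ids = []
    for turn, n_tokens in enumerate(utterance_token_counts):
        switch_ids.extend([turn % 2] * n_tokens)
    return switch_ids

# Three utterances of 4, 3, and 5 tokens:
print(build_switch_ids([4, 3, 5]))
# [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0]
```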
546 |
547 | def create_attention_mask_from_input_mask(from_tensor, to_mask):
548 | """Create 3D attention mask from a 2D tensor mask.
549 |
550 | Args:
551 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
552 | to_mask: int32 Tensor of shape [batch_size, to_seq_length].
553 |
554 | Returns:
555 | float Tensor of shape [batch_size, from_seq_length, to_seq_length].
556 | """
557 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
558 | batch_size = from_shape[0]
559 | from_seq_length = from_shape[1]
560 |
561 | to_shape = get_shape_list(to_mask, expected_rank=2)
562 | to_seq_length = to_shape[1]
563 |
564 | to_mask = tf.cast(
565 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
566 |
567 |   # We don't assume that `from_tensor` is a mask (although it could be). We
568 |   # don't actually care if we attend *from* padding tokens (only *to* padding
569 |   # tokens), so we create a tensor of all ones.
570 | #
571 | # `broadcast_ones` = [batch_size, from_seq_length, 1]
572 | broadcast_ones = tf.ones(
573 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
574 |
575 | # Here we broadcast along two dimensions to create the mask.
576 | mask = broadcast_ones * to_mask
577 |
578 | return mask
579 |
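The broadcast above expands a per-key padding mask into a full query-by-key mask without tiling: `[batch, from_len, 1]` ones multiplied by the `[batch, 1, to_len]` mask yields `[batch, from_len, to_len]`. A NumPy sketch (not part of this repo):

```python
import numpy as np

batch_size, from_seq_length, to_seq_length = 1, 3, 3
to_mask = np.array([[1, 1, 0]], dtype=np.float32)  # last key position is padding

broadcast_ones = np.ones((batch_size, from_seq_length, 1), dtype=np.float32)
mask = broadcast_ones * to_mask.reshape(batch_size, 1, to_seq_length)

print(mask[0])
# [[1. 1. 0.]
#  [1. 1. 0.]
#  [1. 1. 0.]]  -- every query may attend to the two real keys only
```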
580 |
581 | def attention_layer(from_tensor,
582 | to_tensor,
583 | attention_mask=None,
584 | num_attention_heads=1,
585 | size_per_head=512,
586 | query_act=None,
587 | key_act=None,
588 | value_act=None,
589 | attention_probs_dropout_prob=0.0,
590 | initializer_range=0.02,
591 | do_return_2d_tensor=False,
592 | batch_size=None,
593 | from_seq_length=None,
594 | to_seq_length=None):
595 | """Performs multi-headed attention from `from_tensor` to `to_tensor`.
596 |
597 | This is an implementation of multi-headed attention based on "Attention
598 |   is All You Need". If `from_tensor` and `to_tensor` are the same, then
599 |   this is self-attention. Each timestep in `from_tensor` attends to the
600 |   corresponding sequence in `to_tensor`, and returns a fixed-width vector.
601 |
602 | This function first projects `from_tensor` into a "query" tensor and
603 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list
604 | of tensors of length `num_attention_heads`, where each tensor is of shape
605 | [batch_size, seq_length, size_per_head].
606 |
607 | Then, the query and key tensors are dot-producted and scaled. These are
608 | softmaxed to obtain attention probabilities. The value tensors are then
609 | interpolated by these probabilities, then concatenated back to a single
610 | tensor and returned.
611 |
612 |   In practice, the multi-headed attention is done with transposes and
613 | reshapes rather than actual separate tensors.
614 |
615 | Args:
616 | from_tensor: float Tensor of shape [batch_size, from_seq_length,
617 | from_width].
618 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
619 | attention_mask: (optional) int32 Tensor of shape [batch_size,
620 | from_seq_length, to_seq_length]. The values should be 1 or 0. The
621 | attention scores will effectively be set to -infinity for any positions in
622 | the mask that are 0, and will be unchanged for positions that are 1.
623 | num_attention_heads: int. Number of attention heads.
624 | size_per_head: int. Size of each attention head.
625 | query_act: (optional) Activation function for the query transform.
626 | key_act: (optional) Activation function for the key transform.
627 | value_act: (optional) Activation function for the value transform.
628 | attention_probs_dropout_prob: (optional) float. Dropout probability of the
629 | attention probabilities.
630 | initializer_range: float. Range of the weight initializer.
631 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
632 | * from_seq_length, num_attention_heads * size_per_head]. If False, the
633 | output will be of shape [batch_size, from_seq_length, num_attention_heads
634 | * size_per_head].
635 | batch_size: (Optional) int. If the input is 2D, this might be the batch size
636 | of the 3D version of the `from_tensor` and `to_tensor`.
637 | from_seq_length: (Optional) If the input is 2D, this might be the seq length
638 | of the 3D version of the `from_tensor`.
639 | to_seq_length: (Optional) If the input is 2D, this might be the seq length
640 | of the 3D version of the `to_tensor`.
641 |
642 | Returns:
643 | float Tensor of shape [batch_size, from_seq_length,
644 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
645 | true, this will be of shape [batch_size * from_seq_length,
646 | num_attention_heads * size_per_head]).
647 |
648 | Raises:
649 | ValueError: Any of the arguments or tensor shapes are invalid.
650 | """
651 |
652 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
653 | seq_length, width):
654 | output_tensor = tf.reshape(
655 | input_tensor, [batch_size, seq_length, num_attention_heads, width])
656 |
657 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
658 | return output_tensor
659 |
660 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
661 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])
662 |
663 | if len(from_shape) != len(to_shape):
664 | raise ValueError(
665 | "The rank of `from_tensor` must match the rank of `to_tensor`.")
666 |
667 | if len(from_shape) == 3:
668 | batch_size = from_shape[0]
669 | from_seq_length = from_shape[1]
670 | to_seq_length = to_shape[1]
671 | elif len(from_shape) == 2:
672 | if (batch_size is None or from_seq_length is None or to_seq_length is None):
673 | raise ValueError(
674 | "When passing in rank 2 tensors to attention_layer, the values "
675 | "for `batch_size`, `from_seq_length`, and `to_seq_length` "
676 | "must all be specified.")
677 |
678 | # Scalar dimensions referenced here:
679 | # B = batch size (number of sequences)
680 | # F = `from_tensor` sequence length
681 | # T = `to_tensor` sequence length
682 | # N = `num_attention_heads`
683 | # H = `size_per_head`
684 |
685 | from_tensor_2d = reshape_to_matrix(from_tensor)
686 | to_tensor_2d = reshape_to_matrix(to_tensor)
687 |
688 | # `query_layer` = [B*F, N*H]
689 | query_layer = tf.layers.dense(
690 | from_tensor_2d,
691 | num_attention_heads * size_per_head,
692 | activation=query_act,
693 | name="query",
694 | kernel_initializer=create_initializer(initializer_range))
695 |
696 | # `key_layer` = [B*T, N*H]
697 | key_layer = tf.layers.dense(
698 | to_tensor_2d,
699 | num_attention_heads * size_per_head,
700 | activation=key_act,
701 | name="key",
702 | kernel_initializer=create_initializer(initializer_range))
703 |
704 | # `value_layer` = [B*T, N*H]
705 | value_layer = tf.layers.dense(
706 | to_tensor_2d,
707 | num_attention_heads * size_per_head,
708 | activation=value_act,
709 | name="value",
710 | kernel_initializer=create_initializer(initializer_range))
711 |
712 | # `query_layer` = [B, N, F, H]
713 | query_layer = transpose_for_scores(query_layer, batch_size,
714 | num_attention_heads, from_seq_length,
715 | size_per_head)
716 |
717 | # `key_layer` = [B, N, T, H]
718 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
719 | to_seq_length, size_per_head)
720 |
721 | # Take the dot product between "query" and "key" to get the raw
722 | # attention scores.
723 | # `attention_scores` = [B, N, F, T]
724 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
725 | attention_scores = tf.multiply(attention_scores,
726 | 1.0 / math.sqrt(float(size_per_head)))
727 |
728 | if attention_mask is not None:
729 | # `attention_mask` = [B, 1, F, T]
730 | attention_mask = tf.expand_dims(attention_mask, axis=[1])
731 |
732 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
733 | # masked positions, this operation will create a tensor which is 0.0 for
734 | # positions we want to attend and -10000.0 for masked positions.
735 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
736 |
737 | # Since we are adding it to the raw scores before the softmax, this is
738 | # effectively the same as removing these entirely.
739 | attention_scores += adder
740 |
741 | # Normalize the attention scores to probabilities.
742 | # `attention_probs` = [B, N, F, T]
743 | attention_probs = tf.nn.softmax(attention_scores)
744 |
745 | # This is actually dropping out entire tokens to attend to, which might
746 | # seem a bit unusual, but is taken from the original Transformer paper.
747 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
748 |
749 | # `value_layer` = [B, T, N, H]
750 | value_layer = tf.reshape(
751 | value_layer,
752 | [batch_size, to_seq_length, num_attention_heads, size_per_head])
753 |
754 | # `value_layer` = [B, N, T, H]
755 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3])
756 |
757 | # `context_layer` = [B, N, F, H]
758 | context_layer = tf.matmul(attention_probs, value_layer)
759 |
760 | # `context_layer` = [B, F, N, H]
761 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
762 |
763 | if do_return_2d_tensor:
764 | # `context_layer` = [B*F, N*H]
765 | context_layer = tf.reshape(
766 | context_layer,
767 | [batch_size * from_seq_length, num_attention_heads * size_per_head])
768 | else:
769 | # `context_layer` = [B, F, N*H]
770 | context_layer = tf.reshape(
771 | context_layer,
772 | [batch_size, from_seq_length, num_attention_heads * size_per_head])
773 |
774 | return context_layer
775 |
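The B/N/F/T bookkeeping above is easiest to follow on a toy, single-head scaled dot-product: score, mask, softmax, then mix the values. A NumPy sketch of just that core, using the same -10000.0 masking trick (not part of this repo):

```python
import numpy as np

F, T, H = 3, 4, 8  # query length, key length, size per head
rng = np.random.RandomState(0)
q, k, v = rng.randn(F, H), rng.randn(T, H), rng.randn(T, H)

scores = q @ k.T / np.sqrt(H)                        # [F, T], scaled as above
scores += (1.0 - np.array([1, 1, 1, 0])) * -10000.0  # mask out the last key
probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)  # row-wise softmax
context = probs @ v                                  # [F, H], per-query value mix

print(probs.sum(-1))  # each row sums to 1
print(probs[:, -1])   # ~0 -- the masked key receives no attention
```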
776 |
777 | def transformer_model(input_tensor,
778 | attention_mask=None,
779 | hidden_size=768,
780 | num_hidden_layers=12,
781 | num_attention_heads=12,
782 | intermediate_size=3072,
783 | intermediate_act_fn=gelu,
784 | hidden_dropout_prob=0.1,
785 | attention_probs_dropout_prob=0.1,
786 | initializer_range=0.02,
787 | do_return_all_layers=False):
788 | """Multi-headed, multi-layer Transformer from "Attention is All You Need".
789 |
790 | This is almost an exact implementation of the original Transformer encoder.
791 |
792 | See the original paper:
793 | https://arxiv.org/abs/1706.03762
794 |
795 | Also see:
796 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py
797 |
798 | Args:
799 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
800 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
801 | seq_length], with 1 for positions that can be attended to and 0 in
802 | positions that should not be.
803 | hidden_size: int. Hidden size of the Transformer.
804 | num_hidden_layers: int. Number of layers (blocks) in the Transformer.
805 | num_attention_heads: int. Number of attention heads in the Transformer.
806 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed
807 | forward) layer.
808 | intermediate_act_fn: function. The non-linear activation function to apply
809 | to the output of the intermediate/feed-forward layer.
810 | hidden_dropout_prob: float. Dropout probability for the hidden layers.
811 | attention_probs_dropout_prob: float. Dropout probability of the attention
812 | probabilities.
813 | initializer_range: float. Range of the initializer (stddev of truncated
814 | normal).
815 | do_return_all_layers: Whether to also return all layers or just the final
816 | layer.
817 |
818 | Returns:
819 | float Tensor of shape [batch_size, seq_length, hidden_size], the final
820 | hidden layer of the Transformer.
821 |
822 | Raises:
823 | ValueError: A Tensor shape or parameter is invalid.
824 | """
825 | if hidden_size % num_attention_heads != 0:
826 | raise ValueError(
827 | "The hidden size (%d) is not a multiple of the number of attention "
828 | "heads (%d)" % (hidden_size, num_attention_heads))
829 |
830 | attention_head_size = int(hidden_size / num_attention_heads)
831 | input_shape = get_shape_list(input_tensor, expected_rank=3)
832 | batch_size = input_shape[0]
833 | seq_length = input_shape[1]
834 | input_width = input_shape[2]
835 |
836 | # The Transformer performs sum residuals on all layers so the input needs
837 | # to be the same as the hidden size.
838 | if input_width != hidden_size:
839 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
840 | (input_width, hidden_size))
841 |
842 | # We keep the representation as a 2D tensor to avoid re-shaping it back and
843 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
844 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
845 | # help the optimizer.
846 | prev_output = reshape_to_matrix(input_tensor)
847 |
848 | all_layer_outputs = []
849 | for layer_idx in range(num_hidden_layers):
850 | with tf.variable_scope("layer_%d" % layer_idx):
851 | layer_input = prev_output
852 |
853 | with tf.variable_scope("attention"):
854 | attention_heads = []
855 | with tf.variable_scope("self"):
856 | attention_head = attention_layer(
857 | from_tensor=layer_input,
858 | to_tensor=layer_input,
859 | attention_mask=attention_mask,
860 | num_attention_heads=num_attention_heads,
861 | size_per_head=attention_head_size,
862 | attention_probs_dropout_prob=attention_probs_dropout_prob,
863 | initializer_range=initializer_range,
864 | do_return_2d_tensor=True,
865 | batch_size=batch_size,
866 | from_seq_length=seq_length,
867 | to_seq_length=seq_length)
868 | attention_heads.append(attention_head)
869 |
870 | attention_output = None
871 | if len(attention_heads) == 1:
872 | attention_output = attention_heads[0]
873 | else:
874 | # In the case where we have other sequences, we just concatenate
875 | # them to the self-attention head before the projection.
876 | attention_output = tf.concat(attention_heads, axis=-1)
877 |
878 | # Run a linear projection of `hidden_size` then add a residual
879 | # with `layer_input`.
880 | with tf.variable_scope("output"):
881 | attention_output = tf.layers.dense(
882 | attention_output,
883 | hidden_size,
884 | kernel_initializer=create_initializer(initializer_range))
885 | attention_output = dropout(attention_output, hidden_dropout_prob)
886 | attention_output = layer_norm(attention_output + layer_input)
887 |
888 | # The activation is only applied to the "intermediate" hidden layer.
889 | with tf.variable_scope("intermediate"):
890 | intermediate_output = tf.layers.dense(
891 | attention_output,
892 | intermediate_size,
893 | activation=intermediate_act_fn,
894 | kernel_initializer=create_initializer(initializer_range))
895 |
896 | # Down-project back to `hidden_size` then add the residual.
897 | with tf.variable_scope("output"):
898 | layer_output = tf.layers.dense(
899 | intermediate_output,
900 | hidden_size,
901 | kernel_initializer=create_initializer(initializer_range))
902 | layer_output = dropout(layer_output, hidden_dropout_prob)
903 | layer_output = layer_norm(layer_output + attention_output)
904 | prev_output = layer_output
905 | all_layer_outputs.append(layer_output)
906 |
907 | if do_return_all_layers:
908 | final_outputs = []
909 | for layer_output in all_layer_outputs:
910 | final_output = reshape_from_matrix(layer_output, input_shape)
911 | final_outputs.append(final_output)
912 | return final_outputs
913 | else:
914 | final_output = reshape_from_matrix(prev_output, input_shape)
915 | return final_output
916 |
917 |
918 | def get_shape_list(tensor, expected_rank=None, name=None):
919 | """Returns a list of the shape of tensor, preferring static dimensions.
920 |
921 | Args:
922 | tensor: A tf.Tensor object to find the shape of.
923 |     expected_rank: (optional) int or list of ints. The expected rank of
924 |       `tensor`. If this is specified and `tensor` has a different rank,
925 |       an exception will be thrown.
926 | name: Optional name of the tensor for the error message.
927 |
928 | Returns:
929 | A list of dimensions of the shape of tensor. All static dimensions will
930 | be returned as python integers, and dynamic dimensions will be returned
931 | as tf.Tensor scalars.
932 | """
933 | if name is None:
934 | name = tensor.name
935 |
936 | if expected_rank is not None:
937 | assert_rank(tensor, expected_rank, name)
938 |
939 | shape = tensor.shape.as_list()
940 |
941 | non_static_indexes = []
942 | for (index, dim) in enumerate(shape):
943 | if dim is None:
944 | non_static_indexes.append(index)
945 |
946 | if not non_static_indexes:
947 | return shape
948 |
949 | dyn_shape = tf.shape(tensor)
950 | for index in non_static_indexes:
951 | shape[index] = dyn_shape[index]
952 | return shape
953 |
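Under TF 1.x graph mode a dimension is either statically known (a Python int) or only known at run time (a scalar tensor); the helper above returns whichever is available per axis. A usage sketch, assuming `get_shape_list` is in scope (not part of this repo):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 128])  # batch dim unknown until run time

shape = get_shape_list(x, expected_rank=2)
print(shape[1])  # 128 -- static, returned as a Python int
print(shape[0])  # a scalar tf.Tensor, resolved when the graph runs
```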
954 |
955 | def reshape_to_matrix(input_tensor):
956 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
957 | ndims = input_tensor.shape.ndims
958 | if ndims < 2:
959 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
960 | (input_tensor.shape))
961 | if ndims == 2:
962 | return input_tensor
963 |
964 | width = input_tensor.shape[-1]
965 | output_tensor = tf.reshape(input_tensor, [-1, width])
966 | return output_tensor
967 |
968 |
969 | def reshape_from_matrix(output_tensor, orig_shape_list):
970 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
971 | if len(orig_shape_list) == 2:
972 | return output_tensor
973 |
974 | output_shape = get_shape_list(output_tensor)
975 |
976 | orig_dims = orig_shape_list[0:-1]
977 | width = output_shape[-1]
978 |
979 | return tf.reshape(output_tensor, orig_dims + [width])
980 |
981 |
982 | def assert_rank(tensor, expected_rank, name=None):
983 | """Raises an exception if the tensor rank is not of the expected rank.
984 |
985 | Args:
986 | tensor: A tf.Tensor to check the rank of.
987 | expected_rank: Python integer or list of integers, expected rank.
988 | name: Optional name of the tensor for the error message.
989 |
990 | Raises:
991 | ValueError: If the expected shape doesn't match the actual shape.
992 | """
993 | if name is None:
994 | name = tensor.name
995 |
996 | expected_rank_dict = {}
997 | if isinstance(expected_rank, six.integer_types):
998 | expected_rank_dict[expected_rank] = True
999 | else:
1000 | for x in expected_rank:
1001 | expected_rank_dict[x] = True
1002 |
1003 | actual_rank = tensor.shape.ndims
1004 | if actual_rank not in expected_rank_dict:
1005 | scope_name = tf.get_variable_scope().name
1006 | raise ValueError(
1007 | "For the tensor `%s` in scope `%s`, the actual rank "
1008 | "`%d` (shape = %s) is not equal to the expected rank `%s`" %
1009 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
1010 |
--------------------------------------------------------------------------------