├── output
│   └── README.txt
├── image
│   ├── UbuntuV1_V2.png
│   └── Douban_Ecommerce.png
├── uncased_L-12_H-768_A-12
│   └── README.txt
├── scripts
│   ├── ubuntu_test.sh
│   ├── ubuntu_train.sh
│   └── adaptation.sh
├── __init__.py
├── data
│   └── Ubuntu_V1_Xu
│       ├── README.txt
│       ├── tokenization.py
│       ├── create_finetuning_data.py
│       └── create_adaptation_data.py
├── uncased_L-12_H-768_A-12_adapted
│   └── README.txt
├── compute_metrics.py
├── README.md
├── metrics.py
├── optimization.py
├── test.py
├── tokenization.py
├── train.py
├── adapt_switch.py
└── modeling_switch.py

/output/README.txt:
--------------------------------------------------------------------------------
1 | 
2 | Models will be saved here.
3 | 
--------------------------------------------------------------------------------
/image/UbuntuV1_V2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/SA-BERT/HEAD/image/UbuntuV1_V2.png
--------------------------------------------------------------------------------
/image/Douban_Ecommerce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/SA-BERT/HEAD/image/Douban_Ecommerce.png
--------------------------------------------------------------------------------
/uncased_L-12_H-768_A-12/README.txt:
--------------------------------------------------------------------------------
1 | 
2 | ====== Download the BERT base model ======
3 | 
4 | link: https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
5 | Move to path: ./uncased_L-12_H-768_A-12
6 | 
--------------------------------------------------------------------------------
/scripts/ubuntu_test.sh:
--------------------------------------------------------------------------------
1 | 
2 | CUDA_VISIBLE_DEVICES=3 python -u ../test.py \
3 |   --test_dir ../data/Ubuntu_V1_Xu/processed_test.tfrecord \
4 |   --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \
5 |   --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \
6 |   --max_seq_length 512 \
7 |   --eval_batch_size 50 \
8 |   --restore_model_dir ../output/Ubuntu_V1_Xu/1569550213 > log_test.txt 2>&1 &
9 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 16 | -------------------------------------------------------------------------------- /scripts/ubuntu_train.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=3 python -u ../train.py \ 3 | --task_name fine_tuning \ 4 | --train_dir ../data/Ubuntu_V1_Xu/processed_train.tfrecord \ 5 | --valid_dir ../data/Ubuntu_V1_Xu/processed_valid.tfrecord \ 6 | --output_dir ../output/Ubuntu_V1_Xu \ 7 | --do_lower_case True \ 8 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \ 9 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \ 10 | --init_checkpoint ../uncased_L-12_H-768_A-12/bert_model.ckpt \ 11 | --max_seq_length 512 \ 12 | --do_train True \ 13 | --train_batch_size 25 \ 14 | --learning_rate 2e-5 \ 15 | --num_train_epochs 10 \ 16 | --warmup_proportion 0.1 > log_train.txt 2>&1 & 17 | -------------------------------------------------------------------------------- /scripts/adaptation.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=3 python -u ../adapt_switch.py \ 3 | --task_name adaptation \ 4 | --sample_num 5000000 \ 5 | --mid_save_step 15000 \ 6 | --input_file ../data/Ubuntu_V1_Xu/pretrain_data.tfrecord \ 7 | --output_dir ../uncased_L-12_H-768_A-12_adapted \ 8 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \ 9 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \ 10 | --init_checkpoint ../uncased_L-12_H-768_A-12/bert_model.ckpt \ 11 | --max_seq_length 512 \ 12 | --max_predictions_per_seq 25 \ 13 | --train_batch_size 20 \ 14 | --eval_batch_size 20 \ 15 | --learning_rate 5e-5 \ 16 | --num_train_epochs 1 \ 17 | --warmup_proportion 0.1 > log_adaptation.txt 2>&1 & 18 | -------------------------------------------------------------------------------- /data/Ubuntu_V1_Xu/README.txt: -------------------------------------------------------------------------------- 1 | 2 | ====== Download the dataset ====== 3 | 4 | Take Ubuntu_V1 as an example 5 | link: https://drive.google.com/file/d/1-rNv34hLoZr300JF3v7nuLswM7GRqeNc/view 6 | Move to path: /data/Ubuntu_V1_Xu/Ubuntu_Corpus_V1 7 | 8 | If you use the processed dataset, please cite the following paper: 9 | 10 | @inproceedings{Gu:2019:IMN:3357384.3358140, 11 | author = {Gu, Jia-Chen and 12 | Ling, Zhen-Hua and 13 | Liu, Quan}, 14 | title = {Interactive Matching Network for Multi-Turn Response Selection in Retrieval-Based Chatbots}, 15 | booktitle = {Proceedings of the 28th ACM International Conference on Information and Knowledge Management}, 16 | series = {CIKM '19}, 17 | year = {2019}, 18 | isbn = {978-1-4503-6976-3}, 19 | location = {Beijing, China}, 20 | pages = {2321--2324}, 21 | url = {http://doi.acm.org/10.1145/3357384.3358140}, 22 | doi = {10.1145/3357384.3358140}, 23 | acmid = {3358140}, 24 | publisher = {ACM}, 25 | } -------------------------------------------------------------------------------- /uncased_L-12_H-768_A-12_adapted/README.txt: -------------------------------------------------------------------------------- 1 | 2 | ====== Download the ADAPTED BERT base model ====== 3 | 4 | We provide the model adapted on Ubuntu V1 5 | link: https://drive.google.com/file/d/1M8V018XZbVDo4Xq96pCLFRt6yVzoKtjH/view?usp=sharing 6 | Move to path: ./uncased_L-12_H-768_A-12_adapted 7 | 8 | If you use the adapted model, please cite the following paper: 9 | 10 | @inproceedings{Gu:2020:SABERT:3340531.3412330, 11 | author = {Gu, Jia-Chen and 12 | Li, Tianda and 13 | Liu, Quan and 14 | Ling, Zhen-Hua and 15 | 
Su, Zhiming and
16 |               Wei, Si and
17 |               Zhu, Xiaodan
18 |              },
19 |     title = {Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots},
20 |     booktitle = {Proceedings of the 29th ACM International Conference on Information and Knowledge Management},
21 |     series = {CIKM '20},
22 |     year = {2020},
23 |     isbn = {978-1-4503-6859-9},
24 |     location = {Virtual Event, Ireland},
25 |     url = {http://doi.acm.org/10.1145/3340531.3412330},
26 |     doi = {10.1145/3340531.3412330},
27 |     acmid = {3412330},
28 |     publisher = {ACM},
29 | }
30 | 
--------------------------------------------------------------------------------
/compute_metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Load the output_test.txt file and compute the metrics
3 | '''
4 | 
5 | 
6 | import random
7 | from collections import defaultdict
8 | import metrics
9 | 
10 | 
11 | test_out_filename = "./output/Ubuntu_V1_Xu/1596330255/output_test.txt"  # point this variable at the output_test.txt written by the model being evaluated
12 | print("*"*20 + test_out_filename + "*"*20 + "\n")
13 | 
14 | with open(test_out_filename, 'r') as f:
15 | 
16 |     # candidate size = 10
17 |     results = defaultdict(list)
18 |     lines = f.readlines()
19 |     for line in lines[1:]:
20 |         line = line.strip().split('\t')
21 |         us_id = line[0]
22 |         r_id = line[1]
23 |         prob_score = float(line[2])
24 |         label = float(line[4])
25 |         results[us_id].append((r_id, label, prob_score))
26 | 
27 |     accu, precision, recall, f1, loss = metrics.classification_metrics(results)
28 |     print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
29 |     total_valid_query = metrics.get_num_valid_query(results)
30 |     mvp = metrics.mean_average_precision(results)
31 |     mrr = metrics.mean_reciprocal_rank(results)
32 |     print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format(
33 |         mvp, mrr, total_valid_query))
34 |     top_1_precision = metrics.top_k_precision(results, k=1)
35 |     top_2_precision = metrics.top_k_precision(results, k=2)
36 |     top_5_precision = metrics.top_k_precision(results, k=5)
37 |     print('Recall_10@1: {}\tRecall_10@2: {}\tRecall_10@5: {}\n'.format(
38 |         top_1_precision, top_2_precision, top_5_precision))
39 | 
40 |     # candidate size = 2; results may vary across runs because the negative candidate is sampled randomly
41 |     results_bin = defaultdict(list)
42 |     for us_id, candidates in results.items():
43 |         false_candidates = []
44 |         for candidate in candidates:
45 |             r_id, label, prob_score = candidate
46 |             if label == 1.0:
47 |                 results_bin[us_id].append(candidate)
48 |             if label == 0.0:
49 |                 false_candidates.append(candidate)
50 |         false_candidate = random.sample(false_candidates, 1)
51 |         results_bin[us_id].append(false_candidate[0])
52 | 
53 |     accu, precision, recall, f1, loss = metrics.classification_metrics(results_bin)
54 |     print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
55 |     total_valid_query = metrics.get_num_valid_query(results_bin)
56 |     mvp = metrics.mean_average_precision(results_bin)
57 |     mrr = metrics.mean_reciprocal_rank(results_bin)
58 |     top_1_precision = metrics.top_k_precision(results_bin, k=1)
59 |     print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format(
60 |         mvp, mrr, total_valid_query))
61 |     print('Recall_2@1: {}\n'.format(
62 |         top_1_precision))
63 | 
--------------------------------------------------------------------------------
/README.md:
-------------------------------------------------------------------------------- 1 | # Speaker-Aware BERT for Multi-Turn Response Selection 2 | This repository contains the source code and pre-trained models for the CIKM 2020 paper [Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots](https://arxiv.org/pdf/2004.03588.pdf) by Gu et al.
3 | 
4 | 
5 | ## Results
6 | 
7 | (Result figures: see image/UbuntuV1_V2.png for Ubuntu V1/V2 and image/Douban_Ecommerce.png for Douban/E-commerce.)
8 | 
9 | 
10 | ## Dependencies
11 | Python 3.6
12 | TensorFlow 1.13.1
13 | 
14 | 
15 | ## Download
16 | - Download the [BERT base model released by Google Research](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip),
17 | and move it to the path: ./uncased_L-12_H-768_A-12
18 | 
19 | - We also provide the [BERT model adapted on the Ubuntu V1 dataset](https://drive.google.com/file/d/1M8V018XZbVDo4Xq96pCLFRt6yVzoKtjH/view?usp=sharing);
20 | move it to the path: ./uncased_L-12_H-768_A-12_adapted. You only need to fine-tune it to reproduce our results.
21 | 
22 | - Download the [Ubuntu V1 dataset](https://drive.google.com/file/d/1-rNv34hLoZr300JF3v7nuLswM7GRqeNc/view),
23 | and move it to the path: ./data/Ubuntu_V1_Xu/Ubuntu_Corpus_V1
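
Optionally, you can sanity-check that the downloads landed where the scripts expect before running them. A minimal sketch (the paths are taken from the scripts in ./scripts; the checkpoint filenames are the ones shipped in Google's zip):
```
import os

# Paths the adaptation/training scripts assume, relative to the repository root.
expected = [
    "uncased_L-12_H-768_A-12/vocab.txt",
    "uncased_L-12_H-768_A-12/bert_config.json",
    "uncased_L-12_H-768_A-12/bert_model.ckpt.index",
    "data/Ubuntu_V1_Xu/Ubuntu_Corpus_V1",
]
for path in expected:
    print(("OK      " if os.path.exists(path) else "MISSING ") + path)
```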
24 | 
25 | 
26 | ## Adaptation
27 | Create the adaptation data.
28 | ```
29 | cd data/Ubuntu_V1_Xu/
30 | python create_adaptation_data.py
31 | ```
32 | Run the adaptation process.
33 | ```
34 | cd scripts/
35 | bash adaptation.sh
36 | ```
37 | The adapted model will be saved to the path ```./uncased_L-12_H-768_A-12_adapted```.
38 | Rename the checkpoint files in this folder to match the filenames in Google's BERT release, for example as sketched below.
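
A minimal rename sketch (the source prefix `model.ckpt-15000` is an assumption; substitute whatever checkpoint name the adaptation step actually wrote):
```
import glob
import os

# ASSUMPTION: the adaptation step saved files named model.ckpt-15000.*;
# check the folder and replace the prefix with the real one.
src_prefix = "./uncased_L-12_H-768_A-12_adapted/model.ckpt-15000"
dst_prefix = "./uncased_L-12_H-768_A-12_adapted/bert_model.ckpt"

for path in glob.glob(src_prefix + ".*"):  # .index, .meta, .data-00000-of-00001
    os.rename(path, dst_prefix + path[len(src_prefix):])
```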
39 | 
40 | 
41 | ## Training
42 | Create the fine-tuning data.
43 | ```
44 | cd data/Ubuntu_V1_Xu/
45 | python create_finetuning_data.py
46 | ```
47 | Run the fine-tuning process.
48 | 
49 | ```
50 | cd scripts/
51 | bash ubuntu_train.sh
52 | ```
53 | 
54 | ## Testing
55 | Modify the variable ```restore_model_dir``` in ```ubuntu_test.sh```.
56 | ```
57 | cd scripts/
58 | bash ubuntu_test.sh
59 | ```
60 | An ```output_test.txt``` file that records a score for each context-response pair will be saved to the path of ```restore_model_dir```; its format is illustrated below.
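
The file is tab-separated with one header line (the header comes from test.py; the row below is illustrative only, not real model output):
```
query_id	document_id	score	rank	relevance
0	12	0.9731	1	1.0
```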
61 | Modify the variable ```test_out_filename``` in ```compute_metrics.py``` and then run the following command; various metrics will be shown.
62 | ```
63 | python compute_metrics.py
64 | ```
65 | 
66 | 
67 | ## Cite
68 | If you use the source code and pre-trained models, please cite the following paper:
69 | **"Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots"**
70 | Jia-Chen Gu, Tianda Li, Quan Liu, Zhen-Hua Ling, Zhiming Su, Si Wei, Xiaodan Zhu. _CIKM (2020)_
71 | 
72 | ```
73 | @inproceedings{Gu:2020:SABERT:3340531.3412330,
74 |     author = {Gu, Jia-Chen and
75 |               Li, Tianda and
76 |               Liu, Quan and
77 |               Ling, Zhen-Hua and
78 |               Su, Zhiming and
79 |               Wei, Si and
80 |               Zhu, Xiaodan
81 |              },
82 |     title = {Speaker-Aware BERT for Multi-Turn Response Selection in Retrieval-Based Chatbots},
83 |     booktitle = {Proceedings of the 29th ACM International Conference on Information and Knowledge Management},
84 |     series = {CIKM '20},
85 |     year = {2020},
86 |     isbn = {978-1-4503-6859-9},
87 |     location = {Virtual Event, Ireland},
88 |     pages = {2041--2044},
89 |     url = {http://doi.acm.org/10.1145/3340531.3412330},
90 |     doi = {10.1145/3340531.3412330},
91 |     acmid = {3412330},
92 |     publisher = {ACM},
93 | }
94 | ```
95 | 
96 | 
97 | ## Update
98 | Please feel free to open issues if you have any problems.
99 | 
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import math
3 | 
4 | 
5 | def is_valid_query(v):
6 |     num_pos = 0
7 |     num_neg = 0
8 |     for aid, label, score in v:
9 |         if label > 0:
10 |             num_pos += 1
11 |         else:
12 |             num_neg += 1
13 |     if num_pos > 0 and num_neg > 0:
14 |         return True
15 |     else:
16 |         return False
17 | 
18 | 
19 | def get_num_valid_query(results):
20 |     num_query = 0
21 |     for k, v in results.items():
22 |         if not is_valid_query(v):
23 |             continue
24 |         num_query += 1
25 |     return num_query
26 | 
27 | 
28 | def top_1_precision(results):
29 |     num_query = 0
30 |     top_1_correct = 0.0
31 |     for k, v in results.items():
32 |         if not is_valid_query(v):
33 |             continue
34 |         num_query += 1
35 |         sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
36 |         aid, label, score = sorted_v[0]
37 |         if label > 0:
38 |             top_1_correct += 1
39 | 
40 |     if num_query > 0:
41 |         return top_1_correct / num_query
42 |     else:
43 |         return 0.0
44 | 
45 | 
46 | def mean_reciprocal_rank(results):
47 |     num_query = 0
48 |     mrr = 0.0
49 |     for k, v in results.items():
50 |         if not is_valid_query(v):
51 |             continue
52 | 
53 |         num_query += 1
54 |         sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
55 |         for i, rec in enumerate(sorted_v):
56 |             aid, label, score = rec
57 |             if label > 0:
58 |                 mrr += 1.0 / (i + 1)
59 |                 break
60 | 
61 |     if num_query == 0:
62 |         return 0.0
63 |     else:
64 |         mrr = mrr / num_query
65 |         return mrr
66 | 
67 | 
68 | def mean_average_precision(results):
69 |     num_query = 0
70 |     mvp = 0.0
71 |     for k, v in results.items():
72 |         if not is_valid_query(v):
73 |             continue
74 | 
75 |         num_query += 1
76 |         sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
77 |         num_relevant_doc = 0.0
78 |         avp = 0.0
79 |         for i, rec in enumerate(sorted_v):
80 |             aid, label, score = rec
81 |             if label == 1:
82 |                 num_relevant_doc += 1
83 |                 precision = num_relevant_doc / (i + 1)
84 |                 avp += precision
85 |         avp = avp / num_relevant_doc
86 |         mvp += avp
87 | 
88 |     if num_query == 0:
89 |         return 0.0
90 |     else:
91 |         mvp = mvp / num_query
92 |         return mvp
93 | 
94 | 
95 | def classification_metrics(results):
96 |     total_num = 0
97 |     total_correct = 0
98 |     true_positive = 0
99 |     positive_correct = 0
100 |     predicted_positive = 0
101 | 
102 |     loss = 0.0
103 |     for k, v in results.items():
104 |         for rec in v:
105 |             total_num += 1
106 |             aid, label, score = rec
107 | 
108 |             if score > 0.5:
109 |                 predicted_positive += 1
110 | 
111 |             if label > 0:
112 |                 true_positive += 1
113 |                 loss += -math.log(score + 1e-12)
114 |             else:
115 |                 loss += -math.log(1.0 - score + 1e-12)
116 | 
117 |             if score > 0.5 and label > 0:
118 |                 total_correct += 1
119 |                 positive_correct += 1
120 | 
121 |             if score < 0.5 and label < 0.5:
122 |                 total_correct += 1
123 | 
124 |     accuracy = float(total_correct) / total_num
125 |     precision = float(positive_correct) / (predicted_positive + 1e-12)
126 |     recall = float(positive_correct) / true_positive
127 |     F1 = 2.0 * precision * recall / (1e-12 + precision + recall)
128 |     return accuracy, precision, recall, F1, loss / total_num
129 | 
130 | 
131 | def top_k_precision(results, k=1):
132 |     num_query = 0
133 |     top_1_correct = 0.0
134 |     for key, v in results.items():
135 |         if not is_valid_query(v):
136 |             continue
137 |         num_query += 1
138 |         sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
139 |         if k == 1:
140 |             aid, label, score = sorted_v[0]
141 |             if label > 0:
142 |                 top_1_correct += 1
143 |         elif k == 2:
144 |             aid1, label1, score1 = sorted_v[0]
145 |             aid2, label2, score2 = sorted_v[1]
146 |             if label1 > 0 or label2 > 0:
147 |                 top_1_correct += 1
148 |         elif k == 5:
149 |             for vv in sorted_v[0:5]:
150 |                 label = vv[1]
151 |                 if label > 0:
152 |                     top_1_correct += 1
153 |                     break
154 |         else:
155 |             raise ValueError("top_k_precision only supports k in {1, 2, 5}")
156 | 
157 |     if num_query > 0:
158 |         return top_1_correct/num_query
159 |     else:
160 |         return 0.0
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import re
22 | import tensorflow as tf
23 | 
24 | 
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 |   """Creates an optimizer training op."""
27 |   global_step = tf.train.get_or_create_global_step()
28 | 
29 |   learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 | 
31 |   # Implements linear decay of the learning rate.
32 |   learning_rate = tf.train.polynomial_decay(
33 |       learning_rate,
34 |       global_step,
35 |       num_train_steps,
36 |       end_learning_rate=0.0,
37 |       power=1.0,
38 |       cycle=False)
39 | 
40 |   # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 |   # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 
131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import random 23 | import operator 24 | from time import time 25 | from collections import defaultdict 26 | import tensorflow as tf 27 | import optimization 28 | import tokenization 29 | import modeling_switch as modeling 30 | import metrics 31 | 32 | flags = tf.flags 33 | FLAGS = flags.FLAGS 34 | 35 | ## Required parameters 36 | flags.DEFINE_string("test_dir", 'test.tfrecord', 37 | "The input test data dir. 
Should contain the .tsv files (or other data files) for the task.") 38 | 39 | flags.DEFINE_string("restore_model_dir", 'output/', 40 | "The output directory where the model checkpoints have been written.") 41 | 42 | flags.DEFINE_string("task_name", 'TestModel', 43 | "The name of the task.") 44 | 45 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 46 | "The config json file corresponding to the pre-trained BERT model. " 47 | "This specifies the model architecture.") 48 | 49 | flags.DEFINE_integer("max_seq_length", 320, 50 | "The maximum total input sequence length after WordPiece tokenization. " 51 | "Sequences longer than this will be truncated, and sequences shorter " 52 | "than this will be padded.") 53 | 54 | flags.DEFINE_bool("do_eval", True, 55 | "Whether to run eval on the dev set.") 56 | 57 | flags.DEFINE_integer("eval_batch_size", 32, 58 | "Total batch size for predict.") 59 | 60 | 61 | def print_configuration_op(FLAGS): 62 | print('My Configurations:') 63 | for name, value in FLAGS.__flags.items(): 64 | value=value.value 65 | if type(value) == float: 66 | print(' %s:\t %f'%(name, value)) 67 | elif type(value) == int: 68 | print(' %s:\t %d'%(name, value)) 69 | elif type(value) == str: 70 | print(' %s:\t %s'%(name, value)) 71 | elif type(value) == bool: 72 | print(' %s:\t %s'%(name, value)) 73 | else: 74 | print('%s:\t %s' % (name, value)) 75 | print('End of configuration') 76 | 77 | 78 | def total_sample(file_name): 79 | sample_nums = 0 80 | for record in tf.python_io.tf_record_iterator(file_name): 81 | sample_nums += 1 82 | return sample_nums 83 | 84 | 85 | def print_weight(name): 86 | with open('valid/weight_log' + name + str(random.randint(0, 100)), 'w') as fw: 87 | variables = tf.trainable_variables() 88 | for variable in variables: 89 | fw.write(str(variable.eval())) 90 | fw.write('\n') 91 | 92 | 93 | def parse_exmp(serial_exmp): 94 | input_data = tf.parse_single_example(serial_exmp, 95 | features={ 96 | "ques_ids": 97 | tf.FixedLenFeature([], tf.int64), 98 | "ans_ids": 99 | tf.FixedLenFeature([], tf.int64), 100 | "input_sents": 101 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 102 | "input_mask": 103 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 104 | "segment_ids": 105 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 106 | "switch_ids": 107 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 108 | "label_ids": 109 | tf.FixedLenFeature([], tf.float32), 110 | } 111 | ) 112 | # So cast all int64 to int32. 
113 | for name in list(input_data.keys()): 114 | t = input_data[name] 115 | if t.dtype == tf.int64: 116 | t = tf.to_int32(t) 117 | input_data[name] = t 118 | 119 | ques_ids = input_data["ques_ids"] 120 | ans_ids = input_data['ans_ids'] 121 | sents = input_data["input_sents"] 122 | mask = input_data["input_mask"] 123 | segment_ids= input_data["segment_ids"] 124 | switch_ids= input_data["switch_ids"] 125 | labels = input_data['label_ids'] 126 | return ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels 127 | 128 | 129 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, switch_ids, labels, ques_ids, ans_ids, 130 | num_labels, use_one_hot_embeddings): 131 | """Creates a classification model.""" 132 | model = modeling.BertModel( 133 | config=bert_config, 134 | is_training=is_training, 135 | input_ids=input_ids, 136 | input_mask=input_mask, 137 | token_type_ids=segment_ids, 138 | switch_ids=switch_ids, 139 | use_one_hot_embeddings=use_one_hot_embeddings) 140 | 141 | # In the demo, we are doing a simple classification task on the entire 142 | # segment. 143 | # 144 | # If you want to use the token-level output, use model.get_sequence_output() 145 | # instead. 146 | target_loss_weight = [1.0, 1.0] 147 | target_loss_weight = tf.convert_to_tensor(target_loss_weight) 148 | 149 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32) 150 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32) 151 | 152 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy 153 | 154 | output_layer = model.get_pooled_output() 155 | 156 | hidden_size = output_layer.shape[-1].value 157 | 158 | output_weights = tf.get_variable( 159 | "output_weights", [num_labels, hidden_size], 160 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 161 | 162 | output_bias = tf.get_variable( 163 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 164 | 165 | with tf.variable_scope("loss"): 166 | # if is_training: 167 | # # I.e., 0.1 dropout 168 | # output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 169 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training) 170 | 171 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 172 | logits = tf.nn.bias_add(logits, output_bias) 173 | 174 | probabilities = tf.sigmoid(logits, name="prob") 175 | logits = tf.squeeze(logits,[1]) 176 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) 177 | losses = tf.multiply(losses, all_target_loss) 178 | 179 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 180 | 181 | with tf.name_scope("accuracy"): 182 | correct_prediction = tf.equal(tf.sign(probabilities - 0.5), tf.sign(labels - 0.5)) 183 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 184 | # 185 | # one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 186 | # 187 | # per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 188 | # loss = tf.reduce_mean(per_example_loss) 189 | 190 | return mean_loss, logits, probabilities, accuracy, model, output_layer 191 | 192 | 193 | best_score = 0.0 194 | def run_test(dir_path, op_name, sess, training, accuracy, prob, pair_ids, output_layer): 195 | results = defaultdict(list) 196 | num_test = 0 197 | num_correct = 0.0 198 | n_updates = 0 199 | mrr = 0 200 | t0 = time() 201 | try: 202 | while True: 203 | n_updates += 1 204 | 205 | batch_accuracy, predicted_prob, pair_ = 
sess.run([accuracy, prob, pair_ids], feed_dict={training: False})
206 |             question_id, answer_id, label = pair_
207 | 
208 |             num_test += len(predicted_prob)
209 |             num_correct += len(predicted_prob) * batch_accuracy
210 |             for i, prob_score in enumerate(predicted_prob):
211 |                 # question_id, answer_id, label = pair_id[i]
212 |                 results[question_id[i]].append((answer_id[i], label[i], prob_score[0]))
213 | 
214 |             if n_updates % 2000 == 0:
215 |                 tf.logging.info("n_update %d , %s: Mins Used: %.2f" %
216 |                                 (n_updates, op_name, (time() - t0) / 60.0))
217 | 
218 |     except tf.errors.OutOfRangeError:
219 |         # calculate top-1 precision
220 |         print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test))
221 |         accu, precision, recall, f1, loss = metrics.classification_metrics(results)
222 |         print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
223 | 
224 |         mvp = metrics.mean_average_precision(results)
225 |         mrr = metrics.mean_reciprocal_rank(results)
226 |         top_1_precision = metrics.top_1_precision(results)
227 |         total_valid_query = metrics.get_num_valid_query(results)
228 |         print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(
229 |             mvp, mrr, top_1_precision, total_valid_query))
230 | 
231 |         out_path = os.path.join(dir_path, "output_test.txt")
232 |         print("Saving evaluation to {}".format(out_path))
233 |         with open(out_path, 'w') as f:
234 |             f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
235 |             for us_id, v in results.items():
236 |                 v.sort(key=operator.itemgetter(2), reverse=True)
237 |                 for i, rec in enumerate(v):
238 |                     r_id, label, prob_score = rec
239 |                     rank = i + 1
240 |                     f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
241 |     return mrr
242 | 
243 | 
244 | def main(_):
245 |     tf.logging.set_verbosity(tf.logging.INFO)
246 | 
247 |     print_configuration_op(FLAGS)
248 |     bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
249 | 
250 |     test_data_size = total_sample(FLAGS.test_dir)
251 |     tf.logging.info('test data size: {}'.format(test_data_size))
252 | 
253 |     filenames = tf.placeholder(tf.string, shape=[None])
254 |     shuffle_size = tf.placeholder(tf.int64)
255 |     dataset = tf.data.TFRecordDataset(filenames)
256 |     dataset = dataset.map(parse_exmp)  # Parse the record into tensors.
257 | dataset = dataset.repeat(1) 258 | # dataset = dataset.shuffle(shuffle_size) 259 | dataset = dataset.batch(FLAGS.eval_batch_size) 260 | iterator = dataset.make_initializable_iterator() 261 | ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels = iterator.get_next() # output dir 262 | pair_ids = [ques_ids, ans_ids, labels] 263 | 264 | training = tf.placeholder(tf.bool) 265 | mean_loss, logits, probabilities, accuracy, model, output_layer = create_model(bert_config, 266 | is_training = training, 267 | input_ids = sents, 268 | input_mask = mask, 269 | segment_ids = segment_ids, 270 | switch_ids = switch_ids, 271 | labels = labels, 272 | ques_ids = ques_ids, 273 | ans_ids = ans_ids, 274 | num_labels = 1, 275 | use_one_hot_embeddings = False) 276 | 277 | 278 | config = tf.ConfigProto(allow_soft_placement=True) 279 | config.gpu_options.allow_growth = True 280 | 281 | if FLAGS.do_eval: 282 | with tf.Session(config=config) as sess: 283 | tf.logging.info("*** Restore model ***") 284 | 285 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir) 286 | variables = tf.trainable_variables() 287 | saver = tf.train.Saver(variables) 288 | saver.restore(sess, ckpt.model_checkpoint_path) 289 | 290 | tf.logging.info('Test begin') 291 | sess.run(iterator.initializer, 292 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1}) 293 | run_test(FLAGS.restore_model_dir, "test", sess, training, accuracy, probabilities, pair_ids, output_layer) 294 | 295 | 296 | if __name__ == "__main__": 297 | 298 | tf.app.run() 299 | 300 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 
35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /data/Ubuntu_V1_Xu/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 

    For example:
      input = "unaffable"
      output = ["un", "##aff", "##able"]

    Args:
      text: A single token or whitespace separated tokens. This should have
        already been passed through `BasicTokenizer`.

    Returns:
      A list of wordpiece tokens.
    """

    text = convert_to_unicode(text)

    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue

      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end])
          if start > 0:
            substr = "##" + substr
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens


def _is_whitespace(char):
  """Checks whether `char` is a whitespace character."""
  # \t, \n, and \r are technically control characters but we treat them
  # as whitespace since they are generally considered as such.
  if char == " " or char == "\t" or char == "\n" or char == "\r":
    return True
  cat = unicodedata.category(char)
  if cat == "Zs":
    return True
  return False


def _is_control(char):
  """Checks whether `char` is a control character."""
  # These are technically control characters but we count them as whitespace
  # characters.
  if char == "\t" or char == "\n" or char == "\r":
    return False
  cat = unicodedata.category(char)
  if cat in ("Cc", "Cf"):
    return True
  return False


def _is_punctuation(char):
  """Checks whether `char` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation.
  # Characters such as "^", "$", and "`" are not in the Unicode
  # Punctuation class but we treat them as punctuation anyways, for
  # consistency.
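  # For example, ord("$") == 36 and ord("^") == 94 fall inside the ASCII
  # ranges below, so both are split as punctuation even though their Unicode
  # categories are "Sc" and "Sk" rather than "P*".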
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /data/Ubuntu_V1_Xu/create_finetuning_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import collections 3 | import tokenization 4 | import tensorflow as tf 5 | from tqdm import tqdm 6 | 7 | 8 | tf.flags.DEFINE_string("response_file", "./Ubuntu_Corpus_V1/responses.txt", 9 | "path to response file") 10 | tf.flags.DEFINE_string("train_file", "./Ubuntu_Corpus_V1/train.txt", 11 | "path to train file") 12 | tf.flags.DEFINE_string("valid_file", "./Ubuntu_Corpus_V1/valid.txt", 13 | "path to valid file") 14 | tf.flags.DEFINE_string("test_file", "./Ubuntu_Corpus_V1/test.txt", 15 | "path to test file") 16 | 17 | tf.flags.DEFINE_string("vocab_file", "../../uncased_L-12_H-768_A-12/vocab.txt", 18 | "path to vocab file") 19 | tf.flags.DEFINE_integer("max_seq_length", 512, 20 | "max sequence length of concatenated context and response") 21 | tf.flags.DEFINE_bool("do_lower_case", True, 22 | "whether to lower case the input text") 23 | 24 | 25 | 26 | def print_configuration_op(FLAGS): 27 | print('My Configurations:') 28 | for name, value in FLAGS.__flags.items(): 29 | value=value.value 30 | if type(value) == float: 31 | print(' %s:\t %f'%(name, value)) 32 | elif type(value) == int: 33 | print(' %s:\t %d'%(name, value)) 34 | elif type(value) == str: 35 | print(' %s:\t %s'%(name, value)) 36 | elif type(value) == bool: 37 | print(' %s:\t %s'%(name, value)) 38 | else: 39 | print('%s:\t %s' % (name, value)) 40 | print('End of configuration') 41 | 42 | 43 | def load_responses(fname): 44 | responses={} 45 | with open(fname, 'rt') as f: 46 | for line in f: 47 | line = line.strip() 48 | fields = line.split('\t') 49 | if len(fields) != 2: 50 | print("WRONG LINE: {}".format(line)) 51 | r_text = 'unknown' 52 | else: 53 | r_text = fields[1] 54 | responses[fields[0]] = r_text 55 | return responses 56 | 57 | 58 | def load_dataset(fname, responses): 59 | 60 | processed_fname = "processed_" + fname.split("/")[-1] 61 | dataset_size = 0 62 | print("Generating the file of {} ...".format(processed_fname)) 63 | 64 | with open(processed_fname, 'w') as fw: 65 | with open(fname, 'rt') as fr: 66 | for line in fr: 67 | line = line.strip() 68 | fields = line.split('\t') 69 | 70 | us_id = fields[0] 71 | context = fields[1] 72 | 73 | if fields[2] != "NA": 74 | pos_ids = [id for id in fields[2].split('|')] 75 | for r_id in pos_ids: 76 | r_utter = responses[r_id] 77 | dataset_size += 1 78 | fw.write("\t".join([str(us_id), context, r_id, r_utter, 'follow'])) 79 | fw.write('\n') 80 | 81 | if fields[3] != "NA": 82 | neg_ids = [id for id in fields[3].split('|')] 83 | for r_id in neg_ids: 84 | r_utter = responses[r_id] 85 | dataset_size += 1 86 | fw.write("\t".join([str(us_id), context, r_id, r_utter, 'unfollow'])) 87 | fw.write('\n') 88 | 89 | print("{} dataset_size: {}".format(processed_fname, dataset_size)) 90 | return processed_fname 91 | 92 | 93 | class InputExample(object): 94 | def __init__(self, guid,ques_ids, text_a, ans_ids, text_b=None, label=None): 95 | """Constructs a InputExample. 96 | Args: 97 | guid: Unique id for the example. 98 | text_a: string. The untokenized text of the first sequence. 
For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Must only be specified for sequence pair tasks.
      label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
      ques_ids: Id of the context (question) side of the pair.
      ans_ids: Id of the candidate response (answer) side of the pair.
    """
    self.guid = guid
    self.ques_ids = ques_ids
    self.ans_ids = ans_ids
    self.text_a = text_a
    self.text_b = text_b
    self.label = label


class InputFeatures(object):
  """A single set of features of data."""

  def __init__(self, ques_ids, ans_ids, input_sents, input_mask, segment_ids, switch_ids, label_id):
    self.ques_ids = ques_ids
    self.ans_ids = ans_ids
    self.input_sents = input_sents
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.switch_ids = switch_ids
    self.label_id = label_id


def read_processed_file(input_file):
  lines = []
  num_lines = sum(1 for line in open(input_file, 'r'))
  with open(input_file, 'r') as f:
    for line in tqdm(f, total=num_lines):
      concat = []
      temp = line.rstrip().split('\t')
      concat.append(temp[0])  # context id
      concat.append(temp[1])  # context
      concat.append(temp[2])  # response id
      concat.append(temp[3])  # response
      concat.append(temp[4])  # label
      lines.append(concat)
  return lines


def create_examples(lines, set_type):
  """Creates examples for the training and dev sets."""
  examples = []
  for (i, line) in enumerate(lines):
    guid = "%s-%s" % (set_type, str(i))
    ques_ids = line[0]
    text_a = tokenization.convert_to_unicode(line[1])
    ans_ids = line[2]
    text_b = tokenization.convert_to_unicode(line[3])
    label = tokenization.convert_to_unicode(line[-1])
    examples.append(InputExample(guid=guid, ques_ids=ques_ids, text_a=text_a, ans_ids=ans_ids, text_b=text_b, label=label))
  return examples


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncates a sequence pair in place to the maximum length."""

  # This is a simple heuristic which will always truncate the longer sequence
  # one token at a time. This makes more sense than truncating an equal percent
  # of tokens from each, since if one sequence is very short then each token
  # that's truncated likely contains more information than a longer sequence.
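  # For example, with max_length == 8, a 10-token tokens_a and a 3-token
  # tokens_b: five tokens are popped from tokens_a and none from tokens_b,
  # leaving lengths 5 and 3.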
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
  """Loads a data file into a list of `InputBatch`s."""

  label_map = {}
  for (i, label) in enumerate(label_list):  # e.g. ['unfollow', 'follow']
    label_map[label] = i

  features = []
  for (ex_index, example) in enumerate(examples):
    ques_ids = int(example.ques_ids)
    ans_ids = int(example.ans_ids)

    # Tokenize text_a utterance by utterance so that every token carries a
    # speaker flag that alternates between the two speakers in the context.
    text_a_utters = example.text_a.split(" __EOS__ ")
    tokens_a = []
    text_a_switch = []
    for text_a_utter_idx, text_a_utter in enumerate(text_a_utters):
      if text_a_utter_idx % 2 == 0:
        text_a_switch_flag = 0
      else:
        text_a_switch_flag = 1
      text_a_utter_token = tokenizer.tokenize(text_a_utter + " __EOS__")
      tokens_a.extend(text_a_utter_token)
      text_a_switch.extend([text_a_switch_flag] * len(text_a_utter_token))
    assert len(tokens_a) == len(text_a_switch)

    tokens_b = None
    if example.text_b:
      tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
      # Modifies `tokens_a` and `tokens_b` in place so that the total
      # length is less than the specified length.
      # Account for [CLS], [SEP], [SEP] with "- 3"
      _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
      # Account for [CLS] and [SEP] with "- 2"
      if len(tokens_a) > max_seq_length - 2:
        tokens_a = tokens_a[0:(max_seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #   tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #   type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #   tokens:   [CLS] the dog is hairy . [SEP]
    #   type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
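    # In this speaker-aware variant the same layout gains a third sequence,
    # switch_ids. Simplified illustration (hypothetical turns; wordpiece
    # splitting and the __EOS__ markers are omitted), with context turns
    # "hi there" / "hello" and response "sure":
    #   tokens:      [CLS] hi there hello [SEP] sure [SEP]
    #   segment_ids: 0     0  0     0     0     1    1
    #   switch_ids:  0     0  0     1     0     1    1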
226 | tokens = [] 227 | segment_ids = [] 228 | switch_ids = [] 229 | tokens.append("[CLS]") 230 | segment_ids.append(0) 231 | switch_ids.append(0) 232 | for token_idx, token in enumerate(tokens_a): 233 | tokens.append(token) 234 | segment_ids.append(0) 235 | switch_ids.append(text_a_switch[token_idx]) 236 | tokens.append("[SEP]") 237 | segment_ids.append(0) 238 | switch_ids.append(0) 239 | 240 | if tokens_b: 241 | for token_idx, token in enumerate(tokens_b): 242 | tokens.append(token) 243 | segment_ids.append(1) 244 | switch_ids.append(1) 245 | tokens.append("[SEP]") 246 | segment_ids.append(1) 247 | switch_ids.append(1) 248 | 249 | input_sents = tokenizer.convert_tokens_to_ids(tokens) 250 | 251 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 252 | # tokens are attended to. 253 | input_mask = [1] * len(input_sents) # mask 254 | 255 | # Zero-pad up to the sequence length. 256 | while len(input_sents) < max_seq_length: 257 | input_sents.append(0) 258 | input_mask.append(0) 259 | segment_ids.append(0) 260 | switch_ids.append(0) 261 | 262 | assert len(input_sents) == max_seq_length 263 | assert len(input_mask) == max_seq_length 264 | assert len(segment_ids) == max_seq_length 265 | assert len(switch_ids) == max_seq_length 266 | 267 | label_id = label_map[example.label] 268 | 269 | if ex_index%2000 == 0: 270 | print('convert_{}_examples_to_features'.format(ex_index)) 271 | 272 | features.append( 273 | InputFeatures( # object 274 | ques_ids=ques_ids, 275 | ans_ids = ans_ids, 276 | input_sents=input_sents, 277 | input_mask=input_mask, 278 | segment_ids=segment_ids, 279 | switch_ids=switch_ids, 280 | label_id=label_id)) 281 | 282 | return features 283 | 284 | 285 | def write_instance_to_example_files(instances, output_files): 286 | writers = [] 287 | 288 | for output_file in output_files: 289 | writers.append(tf.python_io.TFRecordWriter(output_file)) 290 | 291 | writer_index = 0 292 | total_written = 0 293 | for (inst_index, instance) in enumerate(instances): 294 | features = collections.OrderedDict() 295 | features["ques_ids"] = create_int_feature([instance.ques_ids]) 296 | features["ans_ids"] = create_int_feature([instance.ans_ids]) 297 | features["input_sents"] = create_int_feature(instance.input_sents) 298 | features["input_mask"] = create_int_feature(instance.input_mask) 299 | features["segment_ids"] = create_int_feature(instance.segment_ids) 300 | features["switch_ids"] = create_int_feature(instance.switch_ids) 301 | features["label_ids"] = create_float_feature([instance.label_id]) 302 | 303 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 304 | 305 | writers[writer_index].write(tf_example.SerializeToString()) 306 | writer_index = (writer_index + 1) % len(writers) 307 | 308 | total_written += 1 309 | 310 | print("write_{}_instance_to_example_files".format(total_written)) 311 | 312 | for feature_name in features.keys(): 313 | feature = features[feature_name] 314 | values = [] 315 | if feature.int64_list.value: 316 | values = feature.int64_list.value 317 | elif feature.float_list.value: 318 | values = feature.float_list.value 319 | tf.logging.info( 320 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 321 | 322 | for writer in writers: 323 | writer.close() 324 | 325 | 326 | def create_int_feature(values): 327 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 328 | return feature 329 | 330 | def create_float_feature(values): 331 | feature = 
tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 332 | return feature 333 | 334 | 335 | 336 | if __name__ == "__main__": 337 | 338 | FLAGS = tf.flags.FLAGS 339 | print_configuration_op(FLAGS) 340 | 341 | responses = load_responses(FLAGS.response_file) 342 | train_filename = load_dataset(FLAGS.train_file, responses) 343 | valid_filename = load_dataset(FLAGS.valid_file, responses) 344 | test_filename = load_dataset(FLAGS.test_file, responses) 345 | 346 | filenames = [train_filename, valid_filename, test_filename] 347 | filetypes = ["train", "valid", "test"] 348 | files = zip(filenames, filetypes) 349 | 350 | label_list = ["unfollow", "follow"] 351 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 352 | 353 | for (filename, filetype) in files: 354 | examples = create_examples(read_processed_file(filename), filetype) 355 | features = convert_examples_to_features(examples, label_list, FLAGS.max_seq_length, tokenizer) 356 | new_filename = filename[:-4] + ".tfrecord" 357 | write_instance_to_example_files(features, [new_filename]) 358 | print('Convert {} to {} done'.format(filename, new_filename)) 359 | 360 | print("Sub-process(es) done.") 361 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import operator 23 | from time import time 24 | from collections import defaultdict 25 | import tensorflow as tf 26 | import optimization 27 | import tokenization 28 | import modeling_switch as modeling 29 | import metrics 30 | 31 | flags = tf.flags 32 | FLAGS = flags.FLAGS 33 | 34 | ## Required parameters 35 | flags.DEFINE_string("train_dir", 'train.tfrecord', 36 | "The input train data dir. Should contain the .tsv files (or other data files) for the task.") 37 | 38 | flags.DEFINE_string("valid_dir", 'valid.tfrecord', 39 | "The input valid data dir. Should contain the .tsv files (or other data files) for the task.") 40 | 41 | flags.DEFINE_string("output_dir", 'output', 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | flags.DEFINE_string("task_name", 'ResponseSelection', 45 | "The name of the task to train.") 46 | 47 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 48 | "The config json file corresponding to the pre-trained BERT model. 
" 49 | "This specifies the model architecture.") 50 | 51 | flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt', 52 | "The vocabulary file that the BERT model was trained on.") 53 | 54 | flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt', 55 | "Initial checkpoint (usually from a pre-trained BERT model).") 56 | 57 | flags.DEFINE_bool("do_lower_case", True, 58 | "Whether to lower case the input text. Should be True for uncased " 59 | "models and False for cased models.") 60 | 61 | flags.DEFINE_integer("max_seq_length", 320, 62 | "The maximum total input sequence length after WordPiece tokenization. " 63 | "Sequences longer than this will be truncated, and sequences shorter " 64 | "than this will be padded.") 65 | 66 | flags.DEFINE_bool("do_train", True, 67 | "Whether to run training.") 68 | 69 | flags.DEFINE_bool("do_eval", True, 70 | "Whether to run eval on the dev set.") 71 | 72 | flags.DEFINE_bool("do_predict", True, 73 | "Whether to run the model in inference mode on the test set.") 74 | 75 | flags.DEFINE_float("warmup_proportion", 0.1, 76 | "Proportion of training to perform linear learning rate warmup for. " 77 | "E.g., 0.1 = 10% of training.") 78 | 79 | flags.DEFINE_integer("train_batch_size", 12, 80 | "Total batch size for training.") 81 | 82 | flags.DEFINE_integer("eval_batch_size", 12, 83 | "Total batch size for eval.") 84 | 85 | flags.DEFINE_integer("predict_batch_size", 8, 86 | "Total batch size for predict.") 87 | 88 | flags.DEFINE_float("learning_rate", 2e-5, 89 | "The initial learning rate for Adam.") 90 | 91 | flags.DEFINE_integer("num_train_epochs", 5, 92 | "Total number of training epochs to perform.") 93 | 94 | 95 | 96 | def print_configuration_op(FLAGS): 97 | print('My Configurations:') 98 | for name, value in FLAGS.__flags.items(): 99 | value=value.value 100 | if type(value) == float: 101 | print(' %s:\t %f'%(name, value)) 102 | elif type(value) == int: 103 | print(' %s:\t %d'%(name, value)) 104 | elif type(value) == str: 105 | print(' %s:\t %s'%(name, value)) 106 | elif type(value) == bool: 107 | print(' %s:\t %s'%(name, value)) 108 | else: 109 | print('%s:\t %s' % (name, value)) 110 | print('End of configuration') 111 | 112 | 113 | def total_sample(file_name): 114 | sample_nums = 0 115 | for record in tf.python_io.tf_record_iterator(file_name): 116 | sample_nums += 1 117 | return sample_nums 118 | 119 | 120 | def parse_exmp(serial_exmp): 121 | input_data = tf.parse_single_example(serial_exmp, 122 | features={ 123 | "ques_ids": 124 | tf.FixedLenFeature([], tf.int64), 125 | "ans_ids": 126 | tf.FixedLenFeature([], tf.int64), 127 | "input_sents": 128 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 129 | "input_mask": 130 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 131 | "segment_ids": 132 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 133 | "switch_ids": 134 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 135 | "label_ids": 136 | tf.FixedLenFeature([], tf.float32), 137 | } 138 | ) 139 | # So cast all int64 to int32. 
140 | for name in list(input_data.keys()): 141 | t = input_data[name] 142 | if t.dtype == tf.int64: 143 | t = tf.to_int32(t) 144 | input_data[name] = t 145 | 146 | ques_ids = input_data["ques_ids"] 147 | ans_ids = input_data['ans_ids'] 148 | sents = input_data["input_sents"] 149 | mask = input_data["input_mask"] 150 | segment_ids= input_data["segment_ids"] 151 | switch_ids= input_data["switch_ids"] 152 | labels = input_data['label_ids'] 153 | return ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels 154 | 155 | 156 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, switch_ids, labels, ques_ids, ans_ids, 157 | num_labels, use_one_hot_embeddings): 158 | """Creates a classification model.""" 159 | model = modeling.BertModel( 160 | config=bert_config, 161 | is_training=is_training, 162 | input_ids=input_ids, 163 | input_mask=input_mask, 164 | token_type_ids=segment_ids, 165 | switch_ids=switch_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | # In the demo, we are doing a simple classification task on the entire 169 | # segment. 170 | # 171 | # If you want to use the token-level output, use model.get_sequence_output() 172 | # instead. 173 | target_loss_weight = [1.0, 1.0] 174 | target_loss_weight = tf.convert_to_tensor(target_loss_weight) 175 | 176 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32) 177 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32) 178 | 179 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy 180 | 181 | output_layer = model.get_pooled_output() 182 | 183 | hidden_size = output_layer.shape[-1].value 184 | 185 | output_weights = tf.get_variable( 186 | "output_weights", [num_labels, hidden_size], 187 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 188 | 189 | output_bias = tf.get_variable( 190 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 191 | 192 | with tf.variable_scope("loss"): 193 | 194 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training) 195 | 196 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 197 | logits = tf.nn.bias_add(logits, output_bias) 198 | 199 | probabilities = tf.sigmoid(logits, name="prob") 200 | logits = tf.squeeze(logits,[1]) 201 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) 202 | losses = tf.multiply(losses, all_target_loss) 203 | 204 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 205 | 206 | with tf.name_scope("accuracy"): 207 | correct_prediction = tf.equal(tf.sign(probabilities - 0.5), tf.sign(labels - 0.5)) 208 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 209 | 210 | # one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 211 | # 212 | # per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 213 | # loss = tf.reduce_mean(per_example_loss) 214 | 215 | return mean_loss, logits, probabilities, accuracy, model 216 | 217 | 218 | def run_epoch(epoch_no, op_name, sess, training, logits, accuracy, mean_loss, train_opt=tf.constant(0)): 219 | n_updates = 0 220 | t_loss = 0 221 | n_all = 0 222 | t0 = time() 223 | try: 224 | while True: 225 | n_updates += 1 226 | batch_logits, batch_loss, _ , accur= sess.run([logits, mean_loss, train_opt, accuracy], feed_dict={training:True}) 227 | n_sample = batch_logits.shape[0] 228 | n_all += n_sample 229 | t_loss += batch_loss * n_sample 230 | if n_updates%2000 == 0: 
        tf.logging.info("epoch: %i n_update %d , %s: Mins Used: %.2f, Loss: %.4f, Accuracy: %.2f" %
                        (epoch_no, n_updates, op_name, (time() - t0) / 60.0, t_loss / n_all, 100 * accur))

  except tf.errors.OutOfRangeError:
    tf.logging.info("epoch: %i %s: Mins Used: %.2f, Loss: %.4f, Accuracy: %.2f" %
                    (epoch_no, op_name, (time() - t0) / 60.0, t_loss / n_all, 100 * accur))
    pass
  return t_loss / n_all


best_score = 0.0
def run_test(epoch_no, dir_path, op_name, sess, training, accuracy, prob, pair_ids):
  results = defaultdict(list)
  num_test = 0
  num_correct = 0.0
  n_updates = 0
  mrr = 0
  t0 = time()
  try:
    while True:
      n_updates += 1

      batch_accuracy, predicted_prob, pair_ = sess.run([accuracy, prob, pair_ids], feed_dict={training: False})
      question_id, answer_id, label = pair_

      num_test += len(predicted_prob)
      num_correct += len(predicted_prob) * batch_accuracy
      for i, prob_score in enumerate(predicted_prob):
        results[question_id[i]].append((answer_id[i], label[i], prob_score[0]))

      if n_updates % 2000 == 0:
        tf.logging.info("epoch: %i n_update %d , %s: Mins Used: %.2f" %
                        (epoch_no, n_updates, op_name, (time() - t0) / 60.0))

  except tf.errors.OutOfRangeError:

    # calculate top-1 precision
    print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test))
    accu, precision, recall, f1, loss = metrics.classification_metrics(results)
    print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))

    mean_ap = metrics.mean_average_precision(results)
    mrr = metrics.mean_reciprocal_rank(results)
    top_1_precision = metrics.top_1_precision(results)
    total_valid_query = metrics.get_num_valid_query(results)
    print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(
        mean_ap, mrr, top_1_precision, total_valid_query))

    out_path = os.path.join(dir_path, "output_epoch_{}.txt".format(epoch_no))
    print("Saving evaluation to {}".format(out_path))
    with open(out_path, 'w') as f:
      f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
      for us_id, v in results.items():
        v.sort(key=operator.itemgetter(2), reverse=True)
        for i, rec in enumerate(v):
          r_id, label, prob_score = rec
          rank = i + 1
          f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))

    global best_score
    if op_name == 'valid' and mrr > best_score:
      best_score = mrr
      saver = tf.train.Saver()
      dir_path = os.path.join(dir_path, "epoch {}".format(epoch_no))
      if not os.path.exists(dir_path):
        os.makedirs(dir_path)
      saver.save(sess, dir_path)
      tf.logging.info(">> save model!")

  return mrr


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  print_configuration_op(FLAGS)
  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
  root_path = FLAGS.output_dir
  if not os.path.exists(root_path):
    os.makedirs(root_path)

  timestamp = str(int(time()))
  root_path = os.path.join(root_path, timestamp)
  tf.logging.info('root_path: {}'.format(root_path))
  if not os.path.exists(root_path):
    os.makedirs(root_path)

  train_data_size = total_sample(FLAGS.train_dir)
tf.logging.info('train data size: {}'.format(train_data_size)) 321 | valid_data_size = total_sample(FLAGS.valid_dir) 322 | tf.logging.info('valid data size: {}'.format(valid_data_size)) 323 | 324 | num_train_steps = train_data_size // FLAGS.train_batch_size * FLAGS.num_train_epochs 325 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 326 | 327 | filenames = tf.placeholder(tf.string, shape=[None]) 328 | shuffle_size = tf.placeholder(tf.int64) 329 | dataset = tf.data.TFRecordDataset(filenames) 330 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 331 | dataset = dataset.repeat(1) 332 | # buffer_size 100 333 | dataset = dataset.shuffle(shuffle_size) 334 | dataset = dataset.batch(FLAGS.train_batch_size) 335 | iterator = dataset.make_initializable_iterator() 336 | ques_ids, ans_ids, sents, mask, segment_ids, switch_ids, labels = iterator.get_next() # output dir 337 | pair_ids = [ques_ids, ans_ids, labels] 338 | 339 | 340 | training = tf.placeholder(tf.bool) 341 | mean_loss, logits, probabilities, accuracy, model = create_model(bert_config, 342 | is_training = training, 343 | input_ids = sents, 344 | input_mask = mask, 345 | segment_ids = segment_ids, 346 | switch_ids = switch_ids, 347 | labels = labels, 348 | ques_ids = ques_ids, 349 | ans_ids = ans_ids, 350 | num_labels = 1, 351 | use_one_hot_embeddings = False) 352 | 353 | 354 | # init model with pre-training 355 | tvars = tf.trainable_variables() 356 | if FLAGS.init_checkpoint: 357 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,FLAGS.init_checkpoint) 358 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map) 359 | 360 | tf.logging.info("**** Trainable Variables ****") 361 | for var in tvars: 362 | init_string = "" 363 | if var.name in initialized_variable_names: 364 | init_string = ", *INIT_FROM_CKPT*" 365 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 366 | init_string) 367 | 368 | 369 | train_opt = optimization.create_optimizer(mean_loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, False) 370 | 371 | config = tf.ConfigProto(allow_soft_placement=True) 372 | config.gpu_options.allow_growth = True 373 | 374 | 375 | if FLAGS.do_train: 376 | with tf.Session(config=config) as sess: 377 | sess.run(tf.global_variables_initializer()) 378 | 379 | for epoch in range(FLAGS.num_train_epochs): 380 | tf.logging.info('Epoch {} training begin'.format(epoch)) 381 | sess.run(iterator.initializer, 382 | feed_dict={filenames: [FLAGS.train_dir], shuffle_size: 1024}) 383 | run_epoch(epoch, "train", sess, training, logits, accuracy, mean_loss, train_opt) 384 | 385 | tf.logging.info('Valid begin') 386 | sess.run(iterator.initializer, 387 | feed_dict={filenames: [FLAGS.valid_dir], shuffle_size: 1}) 388 | run_test(epoch, root_path, "valid", sess, training, accuracy, probabilities, pair_ids) 389 | 390 | 391 | 392 | if __name__ == "__main__": 393 | tf.app.run() 394 | -------------------------------------------------------------------------------- /data/Ubuntu_V1_Xu/create_adaptation_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | import tokenization 24 | import numpy as np 25 | import tensorflow as tf 26 | from tqdm import tqdm 27 | 28 | flags = tf.flags 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("train_file", './Ubuntu_Corpus_V1/train.txt', 32 | "Input raw text file (or comma-separated list of files).") 33 | 34 | flags.DEFINE_string("response_file", './Ubuntu_Corpus_V1/responses.txt', 35 | "Input raw text file (or comma-separated list of files).") 36 | 37 | flags.DEFINE_string("output_file", './pretrain_data.tfrecord', 38 | "Output TF example file (or comma-separated list of files).") 39 | 40 | flags.DEFINE_string("vocab_file", '../../uncased_L-12_H-768_A-12/vocab.txt', 41 | "The vocabulary file that the BERT model was trained on.") 42 | 43 | flags.DEFINE_bool("do_lower_case", True, 44 | "Whether to lower case the input text. Should be True for uncased " 45 | "models and False for cased models.") 46 | 47 | flags.DEFINE_integer("max_seq_length", 512, 48 | "Maximum sequence length.") 49 | 50 | flags.DEFINE_integer("max_predictions_per_seq", 25, 51 | "Maximum number of masked LM predictions per sequence.") 52 | 53 | flags.DEFINE_integer("random_seed", 12345, 54 | "Random seed for data generation.") 55 | 56 | flags.DEFINE_integer("dupe_factor", 10, 57 | "Number of times to duplicate the input data (with different masks).") 58 | 59 | flags.DEFINE_float("masked_lm_prob", 0.15, 60 | "Masked LM probability.") 61 | 62 | flags.DEFINE_float("short_seq_prob", 0.1, 63 | "Probability of creating sequences which are shorter than the maximum length.") 64 | 65 | 66 | 67 | class TrainingInstance(object): 68 | """A single training instance (sentence pair).""" 69 | 70 | def __init__(self, tokens, segment_ids, switch_ids, masked_lm_positions, masked_lm_labels, 71 | is_random_next): 72 | self.tokens = tokens 73 | self.segment_ids = segment_ids 74 | self.switch_ids = switch_ids 75 | self.is_random_next = is_random_next 76 | self.masked_lm_positions = masked_lm_positions 77 | self.masked_lm_labels = masked_lm_labels 78 | 79 | def __str__(self): 80 | s = "" 81 | s += "tokens: %s\n" % (" ".join( 82 | [tokenization.printable_text(x) for x in self.tokens])) 83 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 84 | s += "switch_ids: %s\n" % (" ".join([str(x) for x in self.switch_ids])) 85 | s += "is_random_next: %s\n" % self.is_random_next 86 | s += "masked_lm_positions: %s\n" % (" ".join( 87 | [str(x) for x in self.masked_lm_positions])) 88 | s += "masked_lm_labels: %s\n" % (" ".join( 89 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 90 | s += "\n" 91 | return s 92 | 93 | def __repr__(self): 94 | return self.__str__() 95 | 96 | 97 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 98 | max_predictions_per_seq, output_files): 99 | """Create TF example files from 
`TrainingInstance`s.""" 100 | writers = [] 101 | for output_file in output_files: 102 | writers.append(tf.python_io.TFRecordWriter(output_file)) 103 | 104 | writer_index = 0 105 | 106 | total_written = 0 107 | for (inst_index, instance) in enumerate(instances): 108 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 109 | input_mask = [1] * len(input_ids) 110 | segment_ids = list(instance.segment_ids) 111 | switch_ids = list(instance.switch_ids) 112 | assert len(input_ids) <= max_seq_length 113 | 114 | while len(input_ids) < max_seq_length: 115 | input_ids.append(0) 116 | input_mask.append(0) 117 | segment_ids.append(0) 118 | switch_ids.append(0) 119 | 120 | assert len(input_ids) == max_seq_length 121 | assert len(input_mask) == max_seq_length 122 | assert len(segment_ids) == max_seq_length 123 | assert len(switch_ids) == max_seq_length 124 | 125 | masked_lm_positions = list(instance.masked_lm_positions) 126 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 127 | masked_lm_weights = [1.0] * len(masked_lm_ids) 128 | 129 | while len(masked_lm_positions) < max_predictions_per_seq: 130 | masked_lm_positions.append(0) 131 | masked_lm_ids.append(0) 132 | masked_lm_weights.append(0.0) 133 | 134 | next_sentence_label = 1 if instance.is_random_next else 0 135 | 136 | features = collections.OrderedDict() 137 | features["input_ids"] = create_int_feature(input_ids) 138 | features["input_mask"] = create_int_feature(input_mask) 139 | features["segment_ids"] = create_int_feature(segment_ids) 140 | features["switch_ids"] = create_int_feature(switch_ids) 141 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 142 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 143 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 144 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 145 | 146 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 147 | 148 | writers[writer_index].write(tf_example.SerializeToString()) 149 | writer_index = (writer_index + 1) % len(writers) 150 | 151 | total_written += 1 152 | 153 | if inst_index < 20: 154 | tf.logging.info("*** Example ***") 155 | tf.logging.info("tokens: %s" % " ".join( 156 | [tokenization.printable_text(x) for x in instance.tokens])) 157 | 158 | for feature_name in features.keys(): 159 | feature = features[feature_name] 160 | values = [] 161 | if feature.int64_list.value: 162 | values = feature.int64_list.value 163 | elif feature.float_list.value: 164 | values = feature.float_list.value 165 | tf.logging.info( 166 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 167 | 168 | for writer in writers: 169 | writer.close() 170 | 171 | tf.logging.info("Wrote %d total instances", total_written) 172 | 173 | 174 | def create_int_feature(values): 175 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 176 | return feature 177 | 178 | def create_float_feature(values): 179 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 180 | return feature 181 | 182 | 183 | def create_training_instances(context, response, switch, tokenizer, max_seq_length, 184 | dupe_factor, short_seq_prob, masked_lm_prob, 185 | max_predictions_per_seq, rng): 186 | 187 | # Input file format: 188 | # (1) One sentence per line. These should ideally be actual sentences, not 189 | # entire paragraphs or arbitrary spans of text. 
(Because we use the 190 | # sentence boundaries for the "next sentence prediction" task). 191 | # (2) Blank lines between documents. Document boundaries are needed so 192 | # that the "next sentence prediction" task doesn't span between documents. 193 | 194 | sid_r = np.arange(0, len(context)) 195 | rng.shuffle(sid_r) 196 | 197 | vocab_words = list(tokenizer.vocab.keys()) 198 | instances = [] 199 | for _ in tqdm(range(dupe_factor)): 200 | for i in tqdm(range(len(sid_r))): 201 | 202 | sent_a = [] 203 | switch_a = [] 204 | for j in range(len(context[i])): 205 | utterance_a = context[i][j] 206 | utterance_a = tokenization.convert_to_unicode(utterance_a) 207 | utterance_a = tokenizer.tokenize(utterance_a) 208 | sent_a.extend(utterance_a) 209 | switch_a.extend([switch[i][j]] * len(utterance_a)) 210 | assert len(sent_a) == len(switch_a) 211 | 212 | if random.random() < 0.5: 213 | sent_b = response[sid_r[i]] 214 | is_random_next = True 215 | else: 216 | sent_b = response[i] 217 | is_random_next = False 218 | 219 | sent_b = tokenization.convert_to_unicode(sent_b) 220 | sent_b = tokenizer.tokenize(sent_b) 221 | instances.extend( 222 | create_instances_from_document( 223 | sent_a, sent_b, switch_a, is_random_next, max_seq_length, short_seq_prob, 224 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 225 | 226 | rng.shuffle(instances) 227 | return instances 228 | 229 | 230 | def create_instances_from_document( 231 | tokens_a, tokens_b, switch_a, is_random_next, max_seq_length, short_seq_prob, 232 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 233 | """Creates `TrainingInstance`s for a single document.""" 234 | 235 | # Account for [CLS], [SEP], [SEP] 236 | max_num_tokens = max_seq_length - 3 237 | 238 | # We *usually* want to fill up the entire sequence since we are padding 239 | # to `max_seq_length` anyways, so short sequences are generally wasted 240 | # computation. However, we *sometimes* 241 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 242 | # sequences to minimize the mismatch between pre-training and fine-tuning. 243 | # The `target_seq_length` is just a rough target however, whereas 244 | # `max_seq_length` is a hard limit. 245 | 246 | # We DON'T just concatenate all of the tokens from a document into a long 247 | # sequence and choose an arbitrary split point because this would make the 248 | # next sentence prediction task too easy. Instead, we split the input into 249 | # segments "A" and "B" based on the actual "sentences" provided by the user 250 | # input. 
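  # For example, with max_seq_length == 512, max_num_tokens above is 509: a
  # (context, response) pair keeps at most 509 wordpieces plus the [CLS] and
  # two [SEP] markers.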
251 | instances = [] 252 | 253 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 254 | 255 | assert len(tokens_a) >= 1 256 | assert len(tokens_b) >= 1 257 | 258 | tokens = [] 259 | segment_ids = [] 260 | switch_ids = [] 261 | tokens.append("[CLS]") 262 | segment_ids.append(0) 263 | switch_ids.append(0) 264 | for i, token in enumerate(tokens_a): 265 | tokens.append(token) 266 | segment_ids.append(0) 267 | switch_ids.append(switch_a[i]) 268 | 269 | tokens.append("[SEP]") 270 | segment_ids.append(0) 271 | switch_ids.append(0) 272 | 273 | for token in tokens_b: 274 | tokens.append(token) 275 | segment_ids.append(1) 276 | switch_ids.append(1) 277 | tokens.append("[SEP]") 278 | segment_ids.append(1) 279 | switch_ids.append(1) 280 | 281 | (tokens, masked_lm_positions, 282 | masked_lm_labels) = create_masked_lm_predictions( 283 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 284 | instance = TrainingInstance( 285 | tokens=tokens, 286 | segment_ids=segment_ids, 287 | switch_ids=switch_ids, 288 | is_random_next=is_random_next, 289 | masked_lm_positions=masked_lm_positions, 290 | masked_lm_labels=masked_lm_labels) 291 | instances.append(instance) 292 | 293 | return instances 294 | 295 | 296 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 297 | ["index", "label"]) 298 | 299 | 300 | def create_masked_lm_predictions(tokens, masked_lm_prob, 301 | max_predictions_per_seq, vocab_words, rng): 302 | """Creates the predictions for the masked LM objective.""" 303 | 304 | cand_indexes = [] 305 | for (i, token) in enumerate(tokens): 306 | if token == "[CLS]" or token == "[SEP]": 307 | continue 308 | cand_indexes.append(i) 309 | 310 | rng.shuffle(cand_indexes) 311 | 312 | output_tokens = list(tokens) 313 | 314 | num_to_predict = min(max_predictions_per_seq, 315 | max(1, int(round(len(tokens) * masked_lm_prob)))) 316 | 317 | masked_lms = [] 318 | covered_indexes = set() 319 | for index in cand_indexes: 320 | if len(masked_lms) >= num_to_predict: 321 | break 322 | if index in covered_indexes: 323 | continue 324 | covered_indexes.add(index) 325 | 326 | masked_token = None 327 | # 80% of the time, replace with [MASK] 328 | if rng.random() < 0.8: 329 | masked_token = "[MASK]" 330 | else: 331 | # 10% of the time, keep original 332 | if rng.random() < 0.5: 333 | masked_token = tokens[index] 334 | # 10% of the time, replace with random word 335 | else: 336 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 337 | 338 | output_tokens[index] = masked_token 339 | 340 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 341 | 342 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 343 | 344 | masked_lm_positions = [] 345 | masked_lm_labels = [] 346 | for p in masked_lms: 347 | masked_lm_positions.append(p.index) 348 | masked_lm_labels.append(p.label) 349 | 350 | return (output_tokens, masked_lm_positions, masked_lm_labels) 351 | 352 | 353 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 354 | """Truncates a pair of sequences to a maximum sequence length.""" 355 | while True: 356 | total_length = len(tokens_a) + len(tokens_b) 357 | if total_length <= max_num_tokens: 358 | break 359 | 360 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 361 | assert len(trunc_tokens) >= 1 362 | 363 | # We want to sometimes truncate from the front and sometimes from the 364 | # back to add more randomness and avoid biases. 
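    # For example, a 600-token context paired with a 40-token response under
    # max_num_tokens == 509 loses 131 tokens, all from the context since it
    # stays the longer side, each removed from its front or back with equal
    # probability.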
365 | if rng.random() < 0.5: 366 | del trunc_tokens[0] 367 | else: 368 | trunc_tokens.pop() 369 | 370 | def print_configuration_op(FLAGS): 371 | print('My Configurations:') 372 | for name, value in FLAGS.__flags.items(): 373 | value = value.value 374 | if type(value) == float: 375 | print(' %s:\t %f' % (name, value)) 376 | elif type(value) == int: 377 | print(' %s:\t %d' % (name, value)) 378 | elif type(value) == str: 379 | print(' %s:\t %s' % (name, value)) 380 | elif type(value) == bool: 381 | print(' %s:\t %s' % (name, value)) 382 | else: 383 | print('%s:\t %s' % (name, value)) 384 | print('End of configuration') 385 | 386 | def main(_): 387 | tf.logging.set_verbosity(tf.logging.INFO) 388 | print_configuration_op(FLAGS) 389 | 390 | tokenizer = tokenization.FullTokenizer( 391 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 392 | 393 | # 1. load context-response pairs 394 | response_dict = {} 395 | with open(FLAGS.response_file, 'rt') as f: 396 | for line in f: 397 | line = line.strip() 398 | fields = line.split('\t') 399 | if len(fields) != 2: 400 | print("WRONG LINE: {}".format(line)) 401 | r_text = 'unknown' 402 | else: 403 | r_text = fields[1] 404 | response_dict[fields[0]] = r_text 405 | 406 | context = [] 407 | response = [] 408 | switch = [] 409 | with open(FLAGS.train_file, 'rb') as f: 410 | lines = f.readlines() 411 | for index, line in enumerate(lines): 412 | line = line.decode('utf-8').strip() 413 | fields = line.split('\t') 414 | context_i = fields[1] 415 | utterances_i = context_i.split(" __EOS__ ") 416 | # utterances = [utterance + " __EOS__" for utterance in utterances] 417 | new_utterances_i = [] 418 | switch_i = [] 419 | for j, utterance in enumerate(utterances_i): 420 | new_utterances_i.append(utterance + " __EOS__") 421 | if j%2 == 0: 422 | switch_i.append(0) 423 | else: 424 | switch_i.append(1) 425 | assert len(new_utterances_i) == len(switch_i) 426 | 427 | if fields[2] != "NA": 428 | pos_ids = [id for id in fields[2].split('|')] 429 | for r_id in pos_ids: 430 | context.append(new_utterances_i) 431 | 432 | switch.append(switch_i) 433 | 434 | response_i = response_dict[r_id] 435 | response.append(response_i) 436 | 437 | if index % 10000 == 0: 438 | print('Done:', index) 439 | 440 | tf.logging.info("Reading from input files: {} context-response pairs".format(len(context))) 441 | 442 | 443 | # 2. create training instances 444 | rng = random.Random(FLAGS.random_seed) 445 | instances = create_training_instances( 446 | context, response, switch, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 447 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 448 | rng) 449 | 450 | 451 | # 3. write instance to example files 452 | output_files = [FLAGS.output_file] 453 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 454 | FLAGS.max_predictions_per_seq, output_files) 455 | 456 | 457 | if __name__ == "__main__": 458 | flags.mark_flag_as_required("train_file") 459 | flags.mark_flag_as_required("response_file") 460 | flags.mark_flag_as_required("output_file") 461 | flags.mark_flag_as_required("vocab_file") 462 | tf.app.run() 463 | 464 | -------------------------------------------------------------------------------- /adapt_switch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run masked LM/next sentence masked_lm pre-training for BERT."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import modeling_switch as modeling
import optimization
import tensorflow as tf
from time import time
import datetime

flags = tf.flags
FLAGS = flags.FLAGS

## Required parameters
flags.DEFINE_integer("sample_num", 126,
                     "Total number of training samples.")

flags.DEFINE_integer("mid_save_step", 15000,
                     "Steps between intermediate checkpoints. An epoch is long; "
                     "15000 steps is roughly 3 hours.")

flags.DEFINE_string("input_file", 'output/test.tfrecord',
                    "Input TF example file (or comma-separated list of files).")

flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
                    "The config json file corresponding to the pre-trained BERT model. "
                    "This specifies the model architecture.")

flags.DEFINE_string("task_name", 'adaptation',
                    "The name of the task to train.")

flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt',
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string("output_dir", './L-12_H-768_A-12_adapted',
                    "The output directory where the model checkpoints will be written.")

## Other parameters
flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt',
                    "Initial checkpoint (usually from a pre-trained BERT model).")

flags.DEFINE_integer("max_seq_length", 320,
                     "The maximum total input sequence length after WordPiece tokenization. "
                     "Sequences longer than this will be truncated, and sequences shorter "
                     "than this will be padded. Must match data generation.")

flags.DEFINE_integer("max_predictions_per_seq", 10,
                     "Maximum number of masked LM predictions per sequence. "
                     "Must match data generation.")

flags.DEFINE_bool("do_train", True,
                  "Whether to run training.")

flags.DEFINE_bool("do_eval", True,
                  "Whether to run eval on the dev set.")

flags.DEFINE_integer("train_batch_size", 8,
                     "Total batch size for training.")

flags.DEFINE_integer("eval_batch_size", 8,
                     "Total batch size for eval.")

flags.DEFINE_float("learning_rate", 5e-5,
                   "The initial learning rate for Adam.")

flags.DEFINE_float("warmup_proportion", 0.1,
                   "Proportion of training to perform linear learning rate warmup for.")

flags.DEFINE_integer("num_train_epochs", 10,
                     "Total number of training epochs to perform.")


def model_fn_builder(features, is_training, bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
  """Builds the adaptation graph: BERT plus the masked LM and NSP heads."""

  input_ids, input_mask, segment_ids, switch_ids, masked_lm_positions, \
      masked_lm_ids, masked_lm_weights, next_sentence_labels = features

  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      switch_ids=switch_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
      bert_config, model.get_sequence_output(), model.get_embedding_table(),
      masked_lm_positions, masked_lm_ids, masked_lm_weights)

  (next_sentence_loss, next_sentence_example_loss, next_sentence_log_probs) = get_next_sentence_output(
      bert_config, model.get_pooled_output(), next_sentence_labels)

  total_loss = masked_lm_loss + next_sentence_loss

  tvars = tf.trainable_variables()

  if init_checkpoint:
    (assignment_map,
     initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)

    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  train_op = optimization.create_optimizer(
      total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

  eval_metrics = metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights,
                           next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels)

  return train_op, total_loss, eval_metrics, input_ids


def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights,
              next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels):
  """Computes the loss and accuracy of the model."""
  masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                   [-1, masked_lm_log_probs.shape[-1]])  # [batch_size*max_predictions_per_seq, dim]
  masked_lm_predictions = tf.argmax(
      masked_lm_log_probs, axis=-1, output_type=tf.int32)  # [batch_size*max_predictions_per_seq, ]
  masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
  masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
  masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
  masked_lm_accuracy = tf.metrics.accuracy(
      labels=masked_lm_ids,
      predictions=masked_lm_predictions,
      weights=masked_lm_weights)
  masked_lm_mean_loss = tf.metrics.mean(
      values=masked_lm_example_loss, weights=masked_lm_weights)

  next_sentence_log_probs = tf.reshape(
      next_sentence_log_probs, [-1,
next_sentence_log_probs.shape[-1]]) # [batch_size, 2] 154 | next_sentence_predictions = tf.argmax( 155 | next_sentence_log_probs, axis=-1, output_type=tf.int32) # [batch_size, ] 156 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 157 | next_sentence_accuracy = tf.metrics.accuracy( 158 | labels=next_sentence_labels, predictions=next_sentence_predictions) 159 | next_sentence_mean_loss = tf.metrics.mean( 160 | values=next_sentence_example_loss) 161 | # next_sentence_mean_loss = tf.reduce_mean(next_sentence_example_loss) 162 | 163 | return { 164 | "masked_lm_accuracy": masked_lm_accuracy, 165 | "masked_lm_loss": masked_lm_mean_loss, 166 | "next_sentence_accuracy": next_sentence_accuracy, 167 | "next_sentence_loss": next_sentence_mean_loss, 168 | } 169 | 170 | 171 | 172 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 173 | label_ids, label_weights): 174 | """Get loss and log probs for the masked LM.""" 175 | input_tensor = gather_indexes(input_tensor, positions) # [batch_size*max_predictions_per_seq, dim] 176 | 177 | with tf.variable_scope("cls/predictions"): 178 | # We apply one more non-linear transformation before the output layer. 179 | # This matrix is not used after pre-training. 180 | with tf.variable_scope("transform"): 181 | input_tensor = tf.layers.dense( 182 | input_tensor, 183 | units=bert_config.hidden_size, 184 | activation=modeling.get_activation(bert_config.hidden_act), 185 | kernel_initializer=modeling.create_initializer( 186 | bert_config.initializer_range)) 187 | input_tensor = modeling.layer_norm(input_tensor) 188 | 189 | # The output weights are the same as the input embeddings, but there is 190 | # an output-only bias for each token. 191 | output_bias = tf.get_variable( 192 | "output_bias", 193 | shape=[bert_config.vocab_size], 194 | initializer=tf.zeros_initializer()) 195 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 196 | logits = tf.nn.bias_add(logits, output_bias) 197 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size*max_predictions_per_seq, vocab_size] 198 | 199 | label_ids = tf.reshape(label_ids, [-1]) 200 | label_weights = tf.reshape(label_weights, [-1]) 201 | 202 | one_hot_labels = tf.one_hot( 203 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 204 | 205 | # The `positions` tensor might be zero-padded (if the sequence is too 206 | # short to have the maximum number of predictions). The `label_weights` 207 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 208 | # padding predictions. 209 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size*max_predictions_per_seq, ] 210 | numerator = tf.reduce_sum(label_weights * per_example_loss) # [1, ] 211 | denominator = tf.reduce_sum(label_weights) + 1e-5 212 | loss = numerator / denominator 213 | 214 | return (loss, per_example_loss, log_probs) 215 | 216 | 217 | def get_next_sentence_output(bert_config, input_tensor, labels): 218 | """Get loss and log probs for the next sentence prediction.""" 219 | 220 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 221 | # "random sentence". This weight matrix is not used after pre-training. 
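  # Worked example: for a batch of two with labels [0, 1] and log_probs
  # [[-0.1, -2.4], [-3.0, -0.05]], per_example_loss below is [0.1, 0.05]
  # and loss is their mean, 0.075.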
222 | with tf.variable_scope("cls/seq_relationship"): 223 | output_weights = tf.get_variable( 224 | "output_weights", 225 | shape=[2, bert_config.hidden_size], 226 | initializer=modeling.create_initializer(bert_config.initializer_range)) 227 | output_bias = tf.get_variable( 228 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 229 | 230 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 231 | logits = tf.nn.bias_add(logits, output_bias) 232 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, 2] 233 | labels = tf.reshape(labels, [-1]) 234 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 235 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) # [batch_size, ] 236 | loss = tf.reduce_mean(per_example_loss) # [1, ] 237 | return (loss, per_example_loss, log_probs) 238 | 239 | 240 | def gather_indexes(sequence_tensor, positions): 241 | """Gathers the vectors at the specific positions over a minibatch.""" 242 | # sequence_tensor = [batch_size, seq_length, width] 243 | # positions = [batch_size, max_predictions_per_seq] 244 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 245 | batch_size = sequence_shape[0] 246 | seq_length = sequence_shape[1] 247 | width = sequence_shape[2] 248 | 249 | flat_offsets = tf.reshape( 250 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 251 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 252 | flat_sequence_tensor = tf.reshape(sequence_tensor, 253 | [batch_size * seq_length, width]) 254 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 255 | return output_tensor 256 | 257 | 258 | def input_fn_builder(input_files, 259 | max_seq_length, 260 | max_predictions_per_seq, 261 | is_training, 262 | num_cpu_threads=4): 263 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 264 | 265 | def input_fn(params): 266 | """The actual input function.""" 267 | batch_size = params["batch_size"] 268 | 269 | name_to_features = { 270 | "input_ids": 271 | tf.FixedLenFeature([max_seq_length], tf.int64), 272 | "input_mask": 273 | tf.FixedLenFeature([max_seq_length], tf.int64), 274 | "segment_ids": 275 | tf.FixedLenFeature([max_seq_length], tf.int64), 276 | "masked_lm_positions": 277 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 278 | "masked_lm_ids": 279 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 280 | "masked_lm_weights": 281 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 282 | "next_sentence_labels": 283 | tf.FixedLenFeature([1], tf.int64), 284 | } 285 | 286 | # For training, we want a lot of parallel reading and shuffling. 287 | # For eval, we want no shuffling and parallel reading doesn't matter. 288 | if is_training: 289 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 290 | d = d.repeat() 291 | d = d.shuffle(buffer_size=len(input_files)) 292 | 293 | # `cycle_length` is the number of parallel files that get read. 294 | cycle_length = min(num_cpu_threads, len(input_files)) 295 | 296 | # `sloppy` mode means that the interleaving is not exact. This adds 297 | # even more randomness to the training pipeline. 298 | d = d.apply( 299 | tf.contrib.data.parallel_interleave( 300 | tf.data.TFRecordDataset, 301 | sloppy=is_training, 302 | cycle_length=cycle_length)) 303 | d = d.shuffle(buffer_size=100) 304 | else: 305 | d = tf.data.TFRecordDataset(input_files) 306 | # Since we evaluate for a fixed number of steps we don't want to encounter 307 | # out-of-range exceptions. 
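      # NOTE: this `input_fn_builder` is retained from the original BERT
      # pre-training script and its feature spec does not include the
      # "switch_ids" feature; the `main` entry point of this script builds its
      # pipeline with `parse_exmp` instead, which does parse "switch_ids".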
308 |       d = d.repeat()
309 |
310 |     # We must `drop_remainder` on training because the TPU requires fixed
311 |     # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
312 |     # and we *don't* want to drop the remainder, otherwise we won't cover
313 |     # every sample.
314 |     d = d.apply(
315 |         tf.contrib.data.map_and_batch(
316 |             lambda record: _decode_record(record, name_to_features),
317 |             batch_size=batch_size,
318 |             num_parallel_batches=num_cpu_threads,
319 |             drop_remainder=True))
320 |     return d
321 |
322 |   return input_fn
323 |
324 |
325 | def _decode_record(record, name_to_features):
326 |   """Decodes a record to a TensorFlow example."""
327 |   example = tf.parse_single_example(record, name_to_features)
328 |
329 |   # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
330 |   # So cast all int64 to int32.
331 |   for name in list(example.keys()):
332 |     t = example[name]
333 |     if t.dtype == tf.int64:
334 |       t = tf.to_int32(t)
335 |     example[name] = t
336 |
337 |   return example
338 |
339 |
340 |
341 | def run_epoch(epoch, sess, evaluate, eval_op, input_ids, lm_losses, saver, root_path, save_step, mid_save_step, phase, batch_size=16, train_op=tf.constant(0)):
342 |     t_loss = 0
343 |     n_all = 0
344 |     t0 = time()
345 |     t1 = time()
346 |
347 |     masked_lm_accuracy = 0.0
348 |     masked_lm_mean_loss = 0.0
349 |     next_sentence_accuracy = 0.0
350 |     next_sentence_mean_loss = 0.0
351 |
352 |     step = 0
353 |
354 |     print('running begins ... ')
355 |     try:
356 |         while True:
357 |             step = step + 1
358 |             y, matrix, batch_loss, _, _ = sess.run([input_ids, evaluate, lm_losses, train_op, eval_op])
359 |             masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss = matrix
360 |
361 |             n_sample = len(y)
362 |             n_all += n_sample
363 |
364 |             t_loss += batch_loss * n_sample
365 |             # save every epoch or every 3 hours
366 |             # if (step % save_step == 0) or (step % 15000 == 0):
367 |             if (step % mid_save_step == 2):
368 |                 # c_time = str(datetime.datetime.now()).replace(' ', '-').split('.')[0]
369 |                 c_time = str(int(time()))
370 |                 save_path = os.path.join(root_path, 'bert_model_{0}_epoch_{1}'.format(c_time, epoch))
371 |                 if not os.path.exists(save_path):
372 |                     os.makedirs(save_path)
373 |                 saver.save(sess, os.path.join(save_path, 'bert_model_{}.ckpt'.format(c_time)), global_step=step)
374 |                 print('save model at epoch {}'.format(int(step / save_step)))
375 |                 print('masked_lm_accuracy {:.6f}, masked_lm_mean_loss {:.6f}, next_sentence_accuracy {:.6f}, next_sentence_mean_loss {:.6f}'.format(
376 |                     masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss
377 |                 ))
378 |
379 |                 print("{} Loss: {:.4f}, {:.2f} seconds used".
380 |                       format(phase, t_loss / n_all, time() - t1))
381 |                 t1 = time()
382 |                 print('Samples seen {}, total time {}'.format(n_all, time() - t0))
383 |
384 |     except tf.errors.OutOfRangeError:
385 |         print('Epoch {} Done'.format(epoch))
386 |         # c_time = str(datetime.datetime.now()).replace(' ', '-').split('.')[0]
387 |         c_time = str(int(time()))
388 |         save_path = os.path.join(root_path, 'bert_model_{0}_epoch_{1}'.format(c_time, step / save_step))
389 |         if not os.path.exists(save_path):
390 |             os.makedirs(save_path)
391 |         saver.save(sess, os.path.join(save_path, 'bert_model_{}.ckpt'.format(c_time)), global_step=step)
392 |         print('save model at epoch {}'.format(int(step / save_step)))
393 |
394 |         print(
395 |             'masked_lm_accuracy {:.6f}, masked_lm_mean_loss {:.6f}, next_sentence_accuracy {:.6f}, next_sentence_mean_loss {:.6f}'.format(
396 |                 masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss
397 |             ))
398 |         print("{} Loss: {:.4f}, {:.2f} seconds used".
399 |               format(phase, t_loss / n_all, time() - t1))
400 |         t1 = time()
401 |         print('Samples seen {}, total time {}'.format(n_all, time() - t0))
402 |         pass
403 |
404 |
405 |
406 | def parse_exmp(serial_exmp):
407 |     input_data = tf.parse_single_example(serial_exmp,
408 |                                          features={
409 |                                              "input_ids":
410 |                                                  tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
411 |                                              "input_mask":
412 |                                                  tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
413 |                                              "segment_ids":
414 |                                                  tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
415 |                                              "switch_ids":
416 |                                                  tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
417 |                                              "masked_lm_positions":
418 |                                                  tf.FixedLenFeature([FLAGS.max_predictions_per_seq], tf.int64),
419 |                                              "masked_lm_ids":
420 |                                                  tf.FixedLenFeature([FLAGS.max_predictions_per_seq], tf.int64),
421 |                                              "masked_lm_weights":
422 |                                                  tf.FixedLenFeature([FLAGS.max_predictions_per_seq], tf.float32),
423 |                                              "next_sentence_labels":
424 |                                                  tf.FixedLenFeature([1], tf.int64),
425 |                                          }
426 |                                          )
427 |     # So cast all int64 to int32 (tf.Example only supports tf.int64).
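    # "switch_ids" is the one feature read here beyond the stock BERT
    # pre-training inputs: one integer id per token position, later embedded
    # through the two-row "switch_embedding" table in
    # modeling_switch.embedding_postprocessor.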
428 | for name in list(input_data.keys()): 429 | t = input_data[name] 430 | if t.dtype == tf.int64: 431 | t = tf.to_int32(t) 432 | input_data[name] = t 433 | 434 | input_ids = input_data["input_ids"] 435 | input_mask = input_data["input_mask"] 436 | segment_ids = input_data["segment_ids"] 437 | switch_ids = input_data["switch_ids"] 438 | m_lp = input_data["masked_lm_positions"] 439 | m_lids = input_data["masked_lm_ids"] 440 | m_lm_w = input_data["masked_lm_weights"] 441 | nsl = input_data["next_sentence_labels"] 442 | return input_ids, input_mask, segment_ids, switch_ids, m_lp, m_lids, m_lm_w, nsl 443 | 444 | 445 | def print_configuration_op(FLAGS): 446 | print('My Configurations:') 447 | #pdb.set_trace() 448 | for name, value in FLAGS.__flags.items(): 449 | value=value.value 450 | if type(value) == float: 451 | print(' %s:\t %f'%(name, value)) 452 | elif type(value) == int: 453 | print(' %s:\t %d'%(name, value)) 454 | elif type(value) == str: 455 | print(' %s:\t %s'%(name, value)) 456 | elif type(value) == bool: 457 | print(' %s:\t %s'%(name, value)) 458 | else: 459 | print('%s:\t %s' % (name, value)) 460 | print('End of configuration') 461 | 462 | 463 | def main(_): 464 | tf.logging.set_verbosity(tf.logging.INFO) 465 | print_configuration_op(FLAGS) 466 | 467 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 468 | root_path = FLAGS.output_dir 469 | if not os.path.exists(root_path): 470 | os.makedirs(root_path) 471 | 472 | num_train_steps = FLAGS.sample_num // FLAGS.train_batch_size * FLAGS.num_train_epochs 473 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 474 | 475 | buffer_size = 1000 476 | filenames = tf.placeholder(tf.string, shape=[None]) 477 | dataset = tf.data.TFRecordDataset(filenames) 478 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 
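    # The dataset repeats only once on purpose: `main` re-initializes the
    # iterator for every epoch in the loop below, and `run_epoch` treats the
    # resulting tf.errors.OutOfRangeError as the end-of-epoch signal (saving a
    # checkpoint when it fires).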
479 | dataset = dataset.repeat(1) 480 | dataset = dataset.shuffle(buffer_size) 481 | dataset = dataset.batch(FLAGS.train_batch_size) 482 | iterator = dataset.make_initializable_iterator() 483 | save_step = FLAGS.sample_num // FLAGS.train_batch_size 484 | 485 | input_ids, input_mask, segment_ids, switch_ids, masked_lm_positions, \ 486 | masked_lm_ids, masked_lm_weights, next_sentence_labels = iterator.get_next() 487 | features = [input_ids, input_mask, segment_ids, switch_ids, masked_lm_positions, \ 488 | masked_lm_ids, masked_lm_weights, next_sentence_labels] 489 | train_op, loss, matrix, input_ids = model_fn_builder( 490 | features, # ----model_fn_builder---- 491 | is_training=True, 492 | bert_config=bert_config, 493 | init_checkpoint=FLAGS.init_checkpoint, 494 | learning_rate=FLAGS.learning_rate, 495 | num_train_steps=num_train_steps, 496 | num_warmup_steps=num_warmup_steps, 497 | use_tpu=False, 498 | use_one_hot_embeddings=False) 499 | 500 | 501 | masked_lm_accuracy, masked_acc_op = matrix["masked_lm_accuracy"] 502 | masked_lm_mean_loss, masked_loss_op= matrix["masked_lm_loss"] 503 | next_sentence_accuracy, next_sentence_op = matrix["next_sentence_accuracy"] 504 | next_sentence_mean_loss, next_sentence_loss_op = matrix["next_sentence_loss"] 505 | 506 | evaluate = [masked_lm_accuracy, masked_lm_mean_loss, next_sentence_accuracy, next_sentence_mean_loss] 507 | eval_op = [masked_acc_op, masked_loss_op, next_sentence_op, next_sentence_loss_op] 508 | 509 | config = tf.ConfigProto(allow_soft_placement=True) 510 | config.gpu_options.allow_growth = True 511 | saver = tf.train.Saver() 512 | with tf.Session(config=config) as sess: 513 | sess.run(tf.global_variables_initializer()) 514 | sess.run(tf.local_variables_initializer()) 515 | 516 | for epoch in range(FLAGS.num_train_epochs): 517 | sess.run(iterator.initializer, feed_dict={filenames: [FLAGS.input_file]}) 518 | run_epoch(epoch, sess, evaluate, eval_op, input_ids, loss, saver, root_path, save_step, 519 | FLAGS.mid_save_step,'train', batch_size=FLAGS.train_batch_size, train_op=train_op) 520 | 521 | 522 | 523 | if __name__ == "__main__": 524 | tf.app.run() 525 | 526 | -------------------------------------------------------------------------------- /modeling_switch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
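# NOTE: this file mirrors the original BERT `modeling.py`, with two small
# changes for speaker-aware adaptation: (1) an optional per-token `switch_ids`
# input, embedded through a two-row "switch_embedding" table and summed into
# the input representation in `embedding_postprocessor` (`use_switch`), and
# (2) dropout probabilities derived from `is_training` via `tf.cast` instead
# of being zeroed out by a Python conditional.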
15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import numpy as np 27 | import six 28 | import tensorflow as tf 29 | 30 | 31 | class BertConfig(object): 32 | """Configuration for `BertModel`.""" 33 | 34 | def __init__(self, 35 | vocab_size, 36 | hidden_size=768, 37 | num_hidden_layers=12, 38 | num_attention_heads=12, 39 | intermediate_size=3072, 40 | hidden_act="gelu", 41 | hidden_dropout_prob=0.1, 42 | attention_probs_dropout_prob=0.1, 43 | max_position_embeddings=512, 44 | type_vocab_size=16, 45 | initializer_range=0.02): 46 | """Constructs BertConfig. 47 | 48 | Args: 49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 50 | hidden_size: Size of the encoder layers and the pooler layer. 51 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 52 | num_attention_heads: Number of attention heads for each attention layer in 53 | the Transformer encoder. 54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 55 | layer in the Transformer encoder. 56 | hidden_act: The non-linear activation function (function or string) in the 57 | encoder and pooler. 58 | hidden_dropout_prob: The dropout probability for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | attention_probs_dropout_prob: The dropout ratio for the attention 61 | probabilities. 62 | max_position_embeddings: The maximum sequence length that this model might 63 | ever be used with. Typically set this to something large just in case 64 | (e.g., 512 or 1024 or 2048). 65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 66 | `BertModel`. 67 | initializer_range: The stdev of the truncated_normal_initializer for 68 | initializing all weight matrices. 69 | """ 70 | self.vocab_size = vocab_size 71 | self.hidden_size = hidden_size 72 | self.num_hidden_layers = num_hidden_layers 73 | self.num_attention_heads = num_attention_heads 74 | self.hidden_act = hidden_act 75 | self.intermediate_size = intermediate_size 76 | self.hidden_dropout_prob = hidden_dropout_prob 77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 78 | self.max_position_embeddings = max_position_embeddings 79 | self.type_vocab_size = type_vocab_size 80 | self.initializer_range = initializer_range 81 | 82 | @classmethod 83 | def from_dict(cls, json_object): 84 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 85 | config = BertConfig(vocab_size=None) 86 | for (key, value) in six.iteritems(json_object): 87 | config.__dict__[key] = value 88 | return config 89 | 90 | @classmethod 91 | def from_json_file(cls, json_file): 92 | """Constructs a `BertConfig` from a json file of parameters.""" 93 | with tf.gfile.GFile(json_file, "r") as reader: 94 | text = reader.read() 95 | return cls.from_dict(json.loads(text)) 96 | 97 | def to_dict(self): 98 | """Serializes this instance to a Python dictionary.""" 99 | output = copy.deepcopy(self.__dict__) 100 | return output 101 | 102 | def to_json_string(self): 103 | """Serializes this instance to a JSON string.""" 104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 105 | 106 | 107 | class BertModel(object): 108 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 
109 |
110 |   Example usage:
111 |
112 |   ```python
113 |   # Already been converted into WordPiece token ids
114 |   input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
115 |   input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
116 |   token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
117 |
118 |   config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
119 |     num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
120 |
121 |   model = modeling.BertModel(config=config, is_training=True,
122 |     input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
123 |
124 |   label_embeddings = tf.get_variable(...)
125 |   pooled_output = model.get_pooled_output()
126 |   logits = tf.matmul(pooled_output, label_embeddings)
127 |   ...
128 |   ```
129 |   """
130 |
131 |   def __init__(self,
132 |                config,
133 |                is_training,
134 |                input_ids,
135 |                input_mask=None,
136 |                token_type_ids=None,
137 |                switch_ids=None,
138 |                use_one_hot_embeddings=False,
139 |                scope=None):
140 |     """Constructor for BertModel.
141 |
142 |     Args:
143 |       config: `BertConfig` instance.
144 |       is_training: bool. true for training model, false for eval model. Controls
145 |         whether dropout will be applied.
146 |       input_ids: int32 Tensor of shape [batch_size, seq_length].
147 |       input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
148 |       token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
    |       switch_ids: (optional) int32 Tensor of shape [batch_size, seq_length]
    |         with values in {0, 1}; looked up in the additional switch (speaker)
    |         embedding table in `embedding_postprocessor`.
149 |       use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
150 |         embeddings or tf.embedding_lookup() for the word embeddings.
151 |       scope: (optional) variable scope. Defaults to "bert".
152 |
153 |     Raises:
154 |       ValueError: The config is invalid or one of the input tensor shapes
155 |         is invalid.
156 |     """
157 |     config = copy.deepcopy(config)
158 |     # if not is_training:
159 |     #   config.hidden_dropout_prob = 0.0
160 |     #   config.attention_probs_dropout_prob = 0.0
161 |     config.hidden_dropout_prob = tf.cast(is_training, tf.float32) * config.hidden_dropout_prob
162 |     config.attention_probs_dropout_prob = tf.cast(is_training, tf.float32) * config.attention_probs_dropout_prob
163 |
164 |     input_shape = get_shape_list(input_ids, expected_rank=2)
165 |     batch_size = input_shape[0]
166 |     seq_length = input_shape[1]
167 |
168 |     if input_mask is None:
169 |       input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)
170 |
171 |     if token_type_ids is None:
172 |       token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
173 |
174 |     with tf.variable_scope(scope, default_name="bert"):
175 |       with tf.variable_scope("embeddings"):
176 |         # Perform embedding lookup on the word ids.
177 |         (self.embedding_output, self.embedding_table) = embedding_lookup(
178 |             input_ids=input_ids,
179 |             vocab_size=config.vocab_size,
180 |             embedding_size=config.hidden_size,
181 |             initializer_range=config.initializer_range,
182 |             word_embedding_name="word_embeddings",
183 |             use_one_hot_embeddings=use_one_hot_embeddings)
184 |
185 |         # Add positional embeddings and token type embeddings, then layer
186 |         # normalize and perform dropout.
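        # In this adapted model, the call below also adds the switch (speaker)
        # embeddings: `use_switch=True`, with the per-token `switch_ids` passed
        # to the constructor.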
187 | self.embedding_output = embedding_postprocessor( 188 | input_tensor=self.embedding_output, 189 | use_token_type=True, 190 | token_type_ids=token_type_ids, 191 | token_type_vocab_size=config.type_vocab_size, 192 | token_type_embedding_name="token_type_embeddings", 193 | use_switch=True, 194 | switch_ids=switch_ids, 195 | use_position_embeddings=True, 196 | position_embedding_name="position_embeddings", 197 | initializer_range=config.initializer_range, 198 | max_position_embeddings=config.max_position_embeddings, 199 | dropout_prob=config.hidden_dropout_prob) 200 | 201 | with tf.variable_scope("encoder"): 202 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 203 | # mask of shape [batch_size, seq_length, seq_length] which is used 204 | # for the attention scores. 205 | attention_mask = create_attention_mask_from_input_mask( 206 | input_ids, input_mask) 207 | 208 | # Run the stacked transformer. 209 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 210 | self.all_encoder_layers = transformer_model( 211 | input_tensor=self.embedding_output, 212 | attention_mask=attention_mask, 213 | hidden_size=config.hidden_size, 214 | num_hidden_layers=config.num_hidden_layers, 215 | num_attention_heads=config.num_attention_heads, 216 | intermediate_size=config.intermediate_size, 217 | intermediate_act_fn=get_activation(config.hidden_act), 218 | hidden_dropout_prob=config.hidden_dropout_prob, 219 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 220 | initializer_range=config.initializer_range, 221 | do_return_all_layers=True) 222 | 223 | self.sequence_output = self.all_encoder_layers[-1] 224 | # The "pooler" converts the encoded sequence tensor of shape 225 | # [batch_size, seq_length, hidden_size] to a tensor of shape 226 | # [batch_size, hidden_size]. This is necessary for segment-level 227 | # (or segment-pair-level) classification tasks where we need a fixed 228 | # dimensional representation of the segment. 229 | with tf.variable_scope("pooler"): 230 | # We "pool" the model by simply taking the hidden state corresponding 231 | # to the first token. We assume that this has been pre-trained 232 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 233 | self.pooled_output = tf.layers.dense( 234 | first_token_tensor, 235 | config.hidden_size, 236 | activation=tf.tanh, 237 | kernel_initializer=create_initializer(config.initializer_range)) 238 | 239 | def get_pooled_output(self): 240 | return self.pooled_output 241 | 242 | def get_sequence_output(self): 243 | """Gets final hidden layer of encoder. 244 | 245 | Returns: 246 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 247 | to the final hidden of the transformer encoder. 248 | """ 249 | return self.sequence_output 250 | 251 | def get_all_encoder_layers(self): 252 | return self.all_encoder_layers 253 | 254 | def get_embedding_output(self): 255 | """Gets output of the embedding lookup (i.e., input to the transformer). 256 | 257 | Returns: 258 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 259 | to the output of the embedding layer, after summing the word 260 | embeddings with the positional embeddings and the token type embeddings, 261 | then performing layer normalization. This is the input to the transformer. 262 | """ 263 | return self.embedding_output 264 | 265 | def get_embedding_table(self): 266 | return self.embedding_table 267 | 268 | 269 | def gelu(x): 270 | """Gaussian Error Linear Unit. 
271 |
272 |   This is a smoother version of the RELU.
273 |   Original paper: https://arxiv.org/abs/1606.08415
274 |   Args:
275 |     x: float Tensor to perform activation.
276 |
277 |   Returns:
278 |     `x` with the GELU activation applied.
279 |   """
280 |   cdf = 0.5 * (1.0 + tf.tanh(
281 |       (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
282 |   return x * cdf
283 |
284 |
285 | def get_activation(activation_string):
286 |   """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.
287 |
288 |   Args:
289 |     activation_string: String name of the activation function.
290 |
291 |   Returns:
292 |     A Python function corresponding to the activation function. If
293 |     `activation_string` is None, empty, or "linear", this will return None.
294 |     If `activation_string` is not a string, it will return `activation_string`.
295 |
296 |   Raises:
297 |     ValueError: The `activation_string` does not correspond to a known
298 |       activation.
299 |   """
300 |
301 |   # We assume that anything that's not a string is already an activation
302 |   # function, so we just return it.
303 |   if not isinstance(activation_string, six.string_types):
304 |     return activation_string
305 |
306 |   if not activation_string:
307 |     return None
308 |
309 |   act = activation_string.lower()
310 |   if act == "linear":
311 |     return None
312 |   elif act == "relu":
313 |     return tf.nn.relu
314 |   elif act == "gelu":
315 |     return gelu
316 |   elif act == "tanh":
317 |     return tf.tanh
318 |   else:
319 |     raise ValueError("Unsupported activation: %s" % act)
320 |
321 |
322 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
323 |   """Compute the union of the current variables and checkpoint variables."""
324 |   assignment_map = {}
325 |   initialized_variable_names = {}
326 |
327 |   name_to_variable = collections.OrderedDict()
328 |   for var in tvars:
329 |     name = var.name
330 |     m = re.match("^(.*):\\d+$", name)
331 |     if m is not None:
332 |       name = m.group(1)
333 |     name_to_variable[name] = var
334 |
335 |   init_vars = tf.train.list_variables(init_checkpoint)
336 |
337 |   assignment_map = collections.OrderedDict()
338 |   for x in init_vars:
339 |     (name, var) = (x[0], x[1])
340 |     if name not in name_to_variable:
341 |       continue
342 |     assignment_map[name] = name
343 |     initialized_variable_names[name] = 1
344 |     initialized_variable_names[name + ":0"] = 1
345 |
346 |   return (assignment_map, initialized_variable_names)
347 |
348 |
349 | def dropout(input_tensor, dropout_prob):
350 |   """Perform dropout.
351 |
352 |   Args:
353 |     input_tensor: float Tensor.
354 |     dropout_prob: Python float. The probability of dropping out a value (NOT of
355 |       *keeping* a dimension as in `tf.nn.dropout`).
356 |
357 |   Returns:
358 |     A version of `input_tensor` with dropout applied.
359 | """ 360 | if dropout_prob is None or dropout_prob == 0.0: 361 | return input_tensor 362 | 363 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 364 | return output 365 | 366 | 367 | def layer_norm(input_tensor, name=None): 368 | """Run layer normalization on the last dimension of the tensor.""" 369 | return tf.contrib.layers.layer_norm( 370 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 371 | 372 | 373 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 374 | """Runs layer normalization followed by dropout.""" 375 | output_tensor = layer_norm(input_tensor, name) 376 | output_tensor = dropout(output_tensor, dropout_prob) 377 | return output_tensor 378 | 379 | 380 | def create_initializer(initializer_range=0.02): 381 | """Creates a `truncated_normal_initializer` with the given range.""" 382 | return tf.truncated_normal_initializer(stddev=initializer_range) 383 | 384 | 385 | def embedding_lookup(input_ids, 386 | vocab_size, 387 | embedding_size=128, 388 | initializer_range=0.02, 389 | word_embedding_name="word_embeddings", 390 | use_one_hot_embeddings=False): 391 | """Looks up words embeddings for id tensor. 392 | 393 | Args: 394 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 395 | ids. 396 | vocab_size: int. Size of the embedding vocabulary. 397 | embedding_size: int. Width of the word embeddings. 398 | initializer_range: float. Embedding initialization range. 399 | word_embedding_name: string. Name of the embedding table. 400 | use_one_hot_embeddings: bool. If True, use one-hot method for word 401 | embeddings. If False, use `tf.gather()`. 402 | 403 | Returns: 404 | float Tensor of shape [batch_size, seq_length, embedding_size]. 405 | """ 406 | # This function assumes that the input is of shape [batch_size, seq_length, 407 | # num_inputs]. 408 | # 409 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 410 | # reshape to [batch_size, seq_length, 1]. 411 | if input_ids.shape.ndims == 2: 412 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 413 | 414 | embedding_table = tf.get_variable( 415 | name=word_embedding_name, 416 | shape=[vocab_size, embedding_size], 417 | initializer=create_initializer(initializer_range)) 418 | 419 | flat_input_ids = tf.reshape(input_ids, [-1]) 420 | if use_one_hot_embeddings: 421 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 422 | output = tf.matmul(one_hot_input_ids, embedding_table) 423 | else: 424 | output = tf.gather(embedding_table, flat_input_ids) 425 | 426 | input_shape = get_shape_list(input_ids) 427 | 428 | output = tf.reshape(output, 429 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 430 | return (output, embedding_table) 431 | 432 | 433 | def embedding_postprocessor(input_tensor, 434 | use_token_type=False, 435 | token_type_ids=None, 436 | token_type_vocab_size=16, 437 | token_type_embedding_name="token_type_embeddings", 438 | use_switch=False, 439 | switch_ids=None, 440 | use_position_embeddings=True, 441 | position_embedding_name="position_embeddings", 442 | initializer_range=0.02, 443 | max_position_embeddings=512, 444 | dropout_prob=0.1): 445 | """Performs various post-processing on a word embedding tensor. 446 | 447 | Args: 448 | input_tensor: float Tensor of shape [batch_size, seq_length, 449 | embedding_size]. 450 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 451 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 
      Must be specified if `use_token_type` is True.
453 |     token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
454 |     token_type_embedding_name: string. The name of the embedding table variable
455 |       for token type ids.
    |     use_switch: bool. Whether to add switch (speaker) embeddings for
    |       `switch_ids`.
    |     switch_ids: (optional) int32 Tensor of shape [batch_size, seq_length]
    |       with values in {0, 1}. Must be specified if `use_switch` is True.
456 |     use_position_embeddings: bool. Whether to add position embeddings for the
457 |       position of each token in the sequence.
458 |     position_embedding_name: string. The name of the embedding table variable
459 |       for positional embeddings.
460 |     initializer_range: float. Range of the weight initialization.
461 |     max_position_embeddings: int. Maximum sequence length that might ever be
462 |       used with this model. This can be longer than the sequence length of
463 |       input_tensor, but cannot be shorter.
464 |     dropout_prob: float. Dropout probability applied to the final output tensor.
465 |
466 |   Returns:
467 |     float tensor with same shape as `input_tensor`.
468 |
469 |   Raises:
470 |     ValueError: One of the tensor shapes or input values is invalid.
471 |   """
472 |   input_shape = get_shape_list(input_tensor, expected_rank=3)
473 |   batch_size = input_shape[0]
474 |   seq_length = input_shape[1]
475 |   width = input_shape[2]
476 |
477 |   output = input_tensor
478 |
479 |   if use_token_type:
480 |     if token_type_ids is None:
481 |       raise ValueError("`token_type_ids` must be specified if "
482 |                        "`use_token_type` is True.")
483 |     token_type_table = tf.get_variable(
484 |         name=token_type_embedding_name,
485 |         shape=[token_type_vocab_size, width],
486 |         initializer=create_initializer(initializer_range))
487 |     # This vocab will be small so we always do one-hot here, since it is always
488 |     # faster for a small vocabulary.
489 |     flat_token_type_ids = tf.reshape(token_type_ids, [-1])
490 |     one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
491 |     token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
492 |     token_type_embeddings = tf.reshape(token_type_embeddings,
493 |                                        [batch_size, seq_length, width])
494 |     output += token_type_embeddings
495 |
496 |   if use_switch:
497 |     if switch_ids is None:
498 |       raise ValueError("`switch_ids` must be specified if "
499 |                        "`use_switch` is True.")
500 |     switch_type_table = tf.get_variable(
501 |         name='switch_embedding',
502 |         shape=[2, width],
503 |         initializer=create_initializer(initializer_range))
504 |     #
505 |     flat_switch_ids = tf.reshape(switch_ids, [-1])
506 |     switch_one_hot_ids = tf.one_hot(flat_switch_ids, depth=2)
507 |     switch_type_embeddings = tf.matmul(switch_one_hot_ids, switch_type_table)
508 |     switch_type_embeddings = tf.reshape(switch_type_embeddings,
509 |                                         [batch_size, seq_length, width])
510 |     output += switch_type_embeddings
511 |
512 |   if use_position_embeddings:
513 |     assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
514 |     with tf.control_dependencies([assert_op]):
515 |       full_position_embeddings = tf.get_variable(
516 |           name=position_embedding_name,
517 |           shape=[max_position_embeddings, width],
518 |           initializer=create_initializer(initializer_range))
519 |       # Since the position embedding table is a learned variable, we create it
520 |       # using a (long) sequence length `max_position_embeddings`. The actual
521 |       # sequence length might be shorter than this, for faster training of
522 |       # tasks that do not have long sequences.
523 |       #
524 |       # So `full_position_embeddings` is effectively an embedding table
525 |       # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
526 |       # sequence has positions [0, 1, 2, ..., seq_length-1], so we can just
527 |       # perform a slice.
528 |       position_embeddings = tf.slice(full_position_embeddings, [0, 0],
529 |                                      [seq_length, -1])
530 |       num_dims = len(output.shape.as_list())
531 |
532 |       # Only the last two dimensions are relevant (`seq_length` and `width`), so
533 |       # we broadcast among the first dimensions, which is typically just
534 |       # the batch size.
535 |       position_broadcast_shape = []
536 |       for _ in range(num_dims - 2):
537 |         position_broadcast_shape.append(1)
538 |       position_broadcast_shape.extend([seq_length, width])
539 |       position_embeddings = tf.reshape(position_embeddings,
540 |                                        position_broadcast_shape)
541 |       output += position_embeddings
542 |
543 |   output = layer_norm_and_dropout(output, dropout_prob)
544 |   return output
545 |
546 |
547 | def create_attention_mask_from_input_mask(from_tensor, to_mask):
548 |   """Create 3D attention mask from a 2D tensor mask.
549 |
550 |   Args:
551 |     from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
552 |     to_mask: int32 Tensor of shape [batch_size, to_seq_length].
553 |
554 |   Returns:
555 |     float Tensor of shape [batch_size, from_seq_length, to_seq_length].
556 |   """
557 |   from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
558 |   batch_size = from_shape[0]
559 |   from_seq_length = from_shape[1]
560 |
561 |   to_shape = get_shape_list(to_mask, expected_rank=2)
562 |   to_seq_length = to_shape[1]
563 |
564 |   to_mask = tf.cast(
565 |       tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)
566 |
567 |   # We don't assume that `from_tensor` is a mask (although it could be). We
568 |   # don't actually care if we attend *from* padding tokens (only *to* padding
569 |   # tokens) so we create a tensor of all ones.
570 |   #
571 |   # `broadcast_ones` = [batch_size, from_seq_length, 1]
572 |   broadcast_ones = tf.ones(
573 |       shape=[batch_size, from_seq_length, 1], dtype=tf.float32)
574 |
575 |   # Here we broadcast along two dimensions to create the mask.
576 |   mask = broadcast_ones * to_mask
577 |
578 |   return mask
579 |
580 |
581 | def attention_layer(from_tensor,
582 |                     to_tensor,
583 |                     attention_mask=None,
584 |                     num_attention_heads=1,
585 |                     size_per_head=512,
586 |                     query_act=None,
587 |                     key_act=None,
588 |                     value_act=None,
589 |                     attention_probs_dropout_prob=0.0,
590 |                     initializer_range=0.02,
591 |                     do_return_2d_tensor=False,
592 |                     batch_size=None,
593 |                     from_seq_length=None,
594 |                     to_seq_length=None):
595 |   """Performs multi-headed attention from `from_tensor` to `to_tensor`.
596 |
597 |   This is an implementation of multi-headed attention based on "Attention
598 |   Is All You Need". If `from_tensor` and `to_tensor` are the same, then
599 |   this is self-attention. Each timestep in `from_tensor` attends to the
600 |   corresponding sequence in `to_tensor`, and returns a fixed-width vector.
601 |
602 |   This function first projects `from_tensor` into a "query" tensor and
603 |   `to_tensor` into "key" and "value" tensors. These are (effectively) a list
604 |   of tensors of length `num_attention_heads`, where each tensor is of shape
605 |   [batch_size, seq_length, size_per_head].
606 |
607 |   Then, the query and key tensors are dot-producted and scaled. These are
608 |   softmaxed to obtain attention probabilities. The value tensors are then
609 |   interpolated by these probabilities, then concatenated back to a single
610 |   tensor and returned.
611 |
612 |   In practice, the multi-headed attention is done with transposes and
613 |   reshapes rather than actual separate tensors.
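
  For example, with batch size B, sequence length S, N attention heads and
  size H per head, the scaled query/key dot product below yields attention
  scores of shape [B, N, S, S]; these are softmaxed over the last axis and
  applied to the value tensor of shape [B, N, S, H].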
614 | 615 | Args: 616 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 617 | from_width]. 618 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 619 | attention_mask: (optional) int32 Tensor of shape [batch_size, 620 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 621 | attention scores will effectively be set to -infinity for any positions in 622 | the mask that are 0, and will be unchanged for positions that are 1. 623 | num_attention_heads: int. Number of attention heads. 624 | size_per_head: int. Size of each attention head. 625 | query_act: (optional) Activation function for the query transform. 626 | key_act: (optional) Activation function for the key transform. 627 | value_act: (optional) Activation function for the value transform. 628 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 629 | attention probabilities. 630 | initializer_range: float. Range of the weight initializer. 631 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 632 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 633 | output will be of shape [batch_size, from_seq_length, num_attention_heads 634 | * size_per_head]. 635 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 636 | of the 3D version of the `from_tensor` and `to_tensor`. 637 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 638 | of the 3D version of the `from_tensor`. 639 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 640 | of the 3D version of the `to_tensor`. 641 | 642 | Returns: 643 | float Tensor of shape [batch_size, from_seq_length, 644 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 645 | true, this will be of shape [batch_size * from_seq_length, 646 | num_attention_heads * size_per_head]). 647 | 648 | Raises: 649 | ValueError: Any of the arguments or tensor shapes are invalid. 
650 | """ 651 | 652 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 653 | seq_length, width): 654 | output_tensor = tf.reshape( 655 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 656 | 657 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 658 | return output_tensor 659 | 660 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 661 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 662 | 663 | if len(from_shape) != len(to_shape): 664 | raise ValueError( 665 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 666 | 667 | if len(from_shape) == 3: 668 | batch_size = from_shape[0] 669 | from_seq_length = from_shape[1] 670 | to_seq_length = to_shape[1] 671 | elif len(from_shape) == 2: 672 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 673 | raise ValueError( 674 | "When passing in rank 2 tensors to attention_layer, the values " 675 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 676 | "must all be specified.") 677 | 678 | # Scalar dimensions referenced here: 679 | # B = batch size (number of sequences) 680 | # F = `from_tensor` sequence length 681 | # T = `to_tensor` sequence length 682 | # N = `num_attention_heads` 683 | # H = `size_per_head` 684 | 685 | from_tensor_2d = reshape_to_matrix(from_tensor) 686 | to_tensor_2d = reshape_to_matrix(to_tensor) 687 | 688 | # `query_layer` = [B*F, N*H] 689 | query_layer = tf.layers.dense( 690 | from_tensor_2d, 691 | num_attention_heads * size_per_head, 692 | activation=query_act, 693 | name="query", 694 | kernel_initializer=create_initializer(initializer_range)) 695 | 696 | # `key_layer` = [B*T, N*H] 697 | key_layer = tf.layers.dense( 698 | to_tensor_2d, 699 | num_attention_heads * size_per_head, 700 | activation=key_act, 701 | name="key", 702 | kernel_initializer=create_initializer(initializer_range)) 703 | 704 | # `value_layer` = [B*T, N*H] 705 | value_layer = tf.layers.dense( 706 | to_tensor_2d, 707 | num_attention_heads * size_per_head, 708 | activation=value_act, 709 | name="value", 710 | kernel_initializer=create_initializer(initializer_range)) 711 | 712 | # `query_layer` = [B, N, F, H] 713 | query_layer = transpose_for_scores(query_layer, batch_size, 714 | num_attention_heads, from_seq_length, 715 | size_per_head) 716 | 717 | # `key_layer` = [B, N, T, H] 718 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 719 | to_seq_length, size_per_head) 720 | 721 | # Take the dot product between "query" and "key" to get the raw 722 | # attention scores. 723 | # `attention_scores` = [B, N, F, T] 724 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 725 | attention_scores = tf.multiply(attention_scores, 726 | 1.0 / math.sqrt(float(size_per_head))) 727 | 728 | if attention_mask is not None: 729 | # `attention_mask` = [B, 1, F, T] 730 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 731 | 732 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 733 | # masked positions, this operation will create a tensor which is 0.0 for 734 | # positions we want to attend and -10000.0 for masked positions. 735 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 736 | 737 | # Since we are adding it to the raw scores before the softmax, this is 738 | # effectively the same as removing these entirely. 739 | attention_scores += adder 740 | 741 | # Normalize the attention scores to probabilities. 
742 | # `attention_probs` = [B, N, F, T] 743 | attention_probs = tf.nn.softmax(attention_scores) 744 | 745 | # This is actually dropping out entire tokens to attend to, which might 746 | # seem a bit unusual, but is taken from the original Transformer paper. 747 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 748 | 749 | # `value_layer` = [B, T, N, H] 750 | value_layer = tf.reshape( 751 | value_layer, 752 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 753 | 754 | # `value_layer` = [B, N, T, H] 755 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 756 | 757 | # `context_layer` = [B, N, F, H] 758 | context_layer = tf.matmul(attention_probs, value_layer) 759 | 760 | # `context_layer` = [B, F, N, H] 761 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 762 | 763 | if do_return_2d_tensor: 764 | # `context_layer` = [B*F, N*H] 765 | context_layer = tf.reshape( 766 | context_layer, 767 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 768 | else: 769 | # `context_layer` = [B, F, N*H] 770 | context_layer = tf.reshape( 771 | context_layer, 772 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 773 | 774 | return context_layer 775 | 776 | 777 | def transformer_model(input_tensor, 778 | attention_mask=None, 779 | hidden_size=768, 780 | num_hidden_layers=12, 781 | num_attention_heads=12, 782 | intermediate_size=3072, 783 | intermediate_act_fn=gelu, 784 | hidden_dropout_prob=0.1, 785 | attention_probs_dropout_prob=0.1, 786 | initializer_range=0.02, 787 | do_return_all_layers=False): 788 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 789 | 790 | This is almost an exact implementation of the original Transformer encoder. 791 | 792 | See the original paper: 793 | https://arxiv.org/abs/1706.03762 794 | 795 | Also see: 796 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 797 | 798 | Args: 799 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 800 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 801 | seq_length], with 1 for positions that can be attended to and 0 in 802 | positions that should not be. 803 | hidden_size: int. Hidden size of the Transformer. 804 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 805 | num_attention_heads: int. Number of attention heads in the Transformer. 806 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 807 | forward) layer. 808 | intermediate_act_fn: function. The non-linear activation function to apply 809 | to the output of the intermediate/feed-forward layer. 810 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 811 | attention_probs_dropout_prob: float. Dropout probability of the attention 812 | probabilities. 813 | initializer_range: float. Range of the initializer (stddev of truncated 814 | normal). 815 | do_return_all_layers: Whether to also return all layers or just the final 816 | layer. 817 | 818 | Returns: 819 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 820 | hidden layer of the Transformer. 821 | 822 | Raises: 823 | ValueError: A Tensor shape or parameter is invalid. 
824 | """ 825 | if hidden_size % num_attention_heads != 0: 826 | raise ValueError( 827 | "The hidden size (%d) is not a multiple of the number of attention " 828 | "heads (%d)" % (hidden_size, num_attention_heads)) 829 | 830 | attention_head_size = int(hidden_size / num_attention_heads) 831 | input_shape = get_shape_list(input_tensor, expected_rank=3) 832 | batch_size = input_shape[0] 833 | seq_length = input_shape[1] 834 | input_width = input_shape[2] 835 | 836 | # The Transformer performs sum residuals on all layers so the input needs 837 | # to be the same as the hidden size. 838 | if input_width != hidden_size: 839 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 840 | (input_width, hidden_size)) 841 | 842 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 843 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 844 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 845 | # help the optimizer. 846 | prev_output = reshape_to_matrix(input_tensor) 847 | 848 | all_layer_outputs = [] 849 | for layer_idx in range(num_hidden_layers): 850 | with tf.variable_scope("layer_%d" % layer_idx): 851 | layer_input = prev_output 852 | 853 | with tf.variable_scope("attention"): 854 | attention_heads = [] 855 | with tf.variable_scope("self"): 856 | attention_head = attention_layer( 857 | from_tensor=layer_input, 858 | to_tensor=layer_input, 859 | attention_mask=attention_mask, 860 | num_attention_heads=num_attention_heads, 861 | size_per_head=attention_head_size, 862 | attention_probs_dropout_prob=attention_probs_dropout_prob, 863 | initializer_range=initializer_range, 864 | do_return_2d_tensor=True, 865 | batch_size=batch_size, 866 | from_seq_length=seq_length, 867 | to_seq_length=seq_length) 868 | attention_heads.append(attention_head) 869 | 870 | attention_output = None 871 | if len(attention_heads) == 1: 872 | attention_output = attention_heads[0] 873 | else: 874 | # In the case where we have other sequences, we just concatenate 875 | # them to the self-attention head before the projection. 876 | attention_output = tf.concat(attention_heads, axis=-1) 877 | 878 | # Run a linear projection of `hidden_size` then add a residual 879 | # with `layer_input`. 880 | with tf.variable_scope("output"): 881 | attention_output = tf.layers.dense( 882 | attention_output, 883 | hidden_size, 884 | kernel_initializer=create_initializer(initializer_range)) 885 | attention_output = dropout(attention_output, hidden_dropout_prob) 886 | attention_output = layer_norm(attention_output + layer_input) 887 | 888 | # The activation is only applied to the "intermediate" hidden layer. 889 | with tf.variable_scope("intermediate"): 890 | intermediate_output = tf.layers.dense( 891 | attention_output, 892 | intermediate_size, 893 | activation=intermediate_act_fn, 894 | kernel_initializer=create_initializer(initializer_range)) 895 | 896 | # Down-project back to `hidden_size` then add the residual. 
897 |       with tf.variable_scope("output"):
898 |         layer_output = tf.layers.dense(
899 |             intermediate_output,
900 |             hidden_size,
901 |             kernel_initializer=create_initializer(initializer_range))
902 |         layer_output = dropout(layer_output, hidden_dropout_prob)
903 |         layer_output = layer_norm(layer_output + attention_output)
904 |         prev_output = layer_output
905 |         all_layer_outputs.append(layer_output)
906 |
907 |   if do_return_all_layers:
908 |     final_outputs = []
909 |     for layer_output in all_layer_outputs:
910 |       final_output = reshape_from_matrix(layer_output, input_shape)
911 |       final_outputs.append(final_output)
912 |     return final_outputs
913 |   else:
914 |     final_output = reshape_from_matrix(prev_output, input_shape)
915 |     return final_output
916 |
917 |
918 | def get_shape_list(tensor, expected_rank=None, name=None):
919 |   """Returns a list of the shape of tensor, preferring static dimensions.
920 |
921 |   Args:
922 |     tensor: A tf.Tensor object to find the shape of.
923 |     expected_rank: (optional) int. The expected rank of `tensor`. If this is
924 |       specified and the `tensor` has a different rank, an exception will be
925 |       thrown.
926 |     name: Optional name of the tensor for the error message.
927 |
928 |   Returns:
929 |     A list of dimensions of the shape of tensor. All static dimensions will
930 |     be returned as python integers, and dynamic dimensions will be returned
931 |     as tf.Tensor scalars.
932 |   """
933 |   if name is None:
934 |     name = tensor.name
935 |
936 |   if expected_rank is not None:
937 |     assert_rank(tensor, expected_rank, name)
938 |
939 |   shape = tensor.shape.as_list()
940 |
941 |   non_static_indexes = []
942 |   for (index, dim) in enumerate(shape):
943 |     if dim is None:
944 |       non_static_indexes.append(index)
945 |
946 |   if not non_static_indexes:
947 |     return shape
948 |
949 |   dyn_shape = tf.shape(tensor)
950 |   for index in non_static_indexes:
951 |     shape[index] = dyn_shape[index]
952 |   return shape
953 |
954 |
955 | def reshape_to_matrix(input_tensor):
956 |   """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
957 |   ndims = input_tensor.shape.ndims
958 |   if ndims < 2:
959 |     raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
960 |                      (input_tensor.shape))
961 |   if ndims == 2:
962 |     return input_tensor
963 |
964 |   width = input_tensor.shape[-1]
965 |   output_tensor = tf.reshape(input_tensor, [-1, width])
966 |   return output_tensor
967 |
968 |
969 | def reshape_from_matrix(output_tensor, orig_shape_list):
970 |   """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
971 |   if len(orig_shape_list) == 2:
972 |     return output_tensor
973 |
974 |   output_shape = get_shape_list(output_tensor)
975 |
976 |   orig_dims = orig_shape_list[0:-1]
977 |   width = output_shape[-1]
978 |
979 |   return tf.reshape(output_tensor, orig_dims + [width])
980 |
981 |
982 | def assert_rank(tensor, expected_rank, name=None):
983 |   """Raises an exception if the tensor rank is not of the expected rank.
984 |
985 |   Args:
986 |     tensor: A tf.Tensor to check the rank of.
987 |     expected_rank: Python integer or list of integers, expected rank.
988 |     name: Optional name of the tensor for the error message.
989 |
990 |   Raises:
991 |     ValueError: If the expected shape doesn't match the actual shape.
992 | """ 993 | if name is None: 994 | name = tensor.name 995 | 996 | expected_rank_dict = {} 997 | if isinstance(expected_rank, six.integer_types): 998 | expected_rank_dict[expected_rank] = True 999 | else: 1000 | for x in expected_rank: 1001 | expected_rank_dict[x] = True 1002 | 1003 | actual_rank = tensor.shape.ndims 1004 | if actual_rank not in expected_rank_dict: 1005 | scope_name = tf.get_variable_scope().name 1006 | raise ValueError( 1007 | "For the tensor `%s` in scope `%s`, the actual rank " 1008 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 1009 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 1010 | --------------------------------------------------------------------------------