├── README.md ├── __init__.py ├── compute_metrics.py ├── create_finetuning_data_ar.py ├── create_finetuning_data_ar_gift.py ├── create_finetuning_data_rs.py ├── create_finetuning_data_rs_gift.py ├── create_finetuning_data_si.py ├── create_finetuning_data_si_gift.py ├── create_pretraining_data.py ├── data ├── emnlp2016 │ ├── README.md │ └── data_preprocess.py └── ijcai2019 │ └── README.md ├── image ├── result_addressee_recognition.png ├── result_addressee_recognition_gift.png ├── result_response_selection.png ├── result_response_selection_gift.png ├── result_speaker_identification.png └── result_speaker_identification_gift.png ├── metrics.py ├── modeling_speaker.py ├── modeling_speaker_gift.py ├── optimization.py ├── run_finetuning_ar.py ├── run_finetuning_ar_gift.py ├── run_finetuning_rs.py ├── run_finetuning_rs_gift.py ├── run_finetuning_si.py ├── run_finetuning_si_gift.py ├── run_pretraining.py ├── run_testing_ar.py ├── run_testing_ar_gift.py ├── run_testing_rs.py ├── run_testing_rs_gift.py ├── run_testing_si.py ├── run_testing_si_gift.py ├── scripts ├── run_finetuning.sh ├── run_finetuning_gift.sh ├── run_pretraining.sh ├── run_testing.sh └── run_testing_gift.sh ├── tokenization.py └── uncased_L-12_H-768_A-12 └── README.md /README.md: -------------------------------------------------------------------------------- 1 | # MPC-BERT & GIFT for Multi-Party Conversation Understanding 2 | This repository contains the source codes for the following papers: 3 | - [GIFT: Graph-Induced Fine-Tuning for Multi-Party Conversation Understanding](https://aclanthology.org/2023.acl-long.651.pdf).
4 | Jia-Chen Gu, Zhen-Hua Ling, Quan Liu, Cong Liu, Guoping Hu<br>
5 | _ACL 2023_
6 | 7 | - [MPC-BERT: A Pre-Trained Language Model for Multi-Party Conversation Understanding](https://aclanthology.org/2021.acl-long.285.pdf).
8 | Jia-Chen Gu, Chongyang Tao, Zhen-Hua Ling, Can Xu, Xiubo Geng, Daxin Jiang
9 | _ACL 2021_
10 | 11 | 12 | ## Introduction of MPC-BERT 13 | Recently, various neural models for multi-party conversation (MPC) have achieved impressive improvements on a variety of tasks such as addressee recognition, speaker identification and response prediction. 14 | However, existing methods for MPC usually represent interlocutors and utterances individually and ignore the inherently complicated structure of MPCs, which provides crucial interlocutor and utterance semantics and can enhance conversation understanding. 15 | To this end, we present MPC-BERT, a pre-trained model for MPC understanding that considers learning who says what to whom in a unified model with several elaborately designed self-supervised tasks. 16 | In particular, these tasks can be generally categorized into (1) interlocutor structure modeling, including reply-to utterance recognition, identical speaker searching and pointer consistency distinction, and (2) utterance semantics modeling, including masked shared utterance restoration and shared node detection. 17 | We evaluate MPC-BERT on three downstream tasks: addressee recognition, speaker identification and response selection. 18 | Experimental results show that MPC-BERT outperforms previous methods by large margins and achieves new state-of-the-art performance on all three downstream tasks on two benchmarks. 19 | 20 | <img src="image/result_addressee_recognition.png">
21 | 22 | <img src="image/result_speaker_identification.png">
23 | 24 | <img src="image/result_response_selection.png">
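To make the model input concrete, below is a minimal sketch (the utterances and speaker ids are hypothetical, not taken from the datasets) of how the data-creation scripts in this repository flatten an MPC into a single sequence: each utterance is prefixed with its own [CLS] token, every token carries the id of its speaker, and the positions of the [CLS] tokens are recorded so that utterance-level representations can be gathered later.
```
ctx = [["hello", "bob"], ["hi", "alice"]]   # tokenized context utterances (hypothetical)
ctx_spk = [1, 2]                            # speaker id of each utterance
rsp, rsp_spk = ["how", "are", "you"], 1     # response utterance and its speaker

tokens, speaker_ids, cls_positions = [], [], []
for utr_tokens, spk in zip(ctx + [rsp], ctx_spk + [rsp_spk]):
    cls_positions.append(len(tokens))       # remember where this utterance starts
    for token in ["[CLS]"] + utr_tokens:    # one [CLS] per utterance
        tokens.append(token)
        speaker_ids.append(spk)             # every token carries its speaker id
tokens.append("[SEP]"); speaker_ids.append(rsp_spk)

print(tokens)         # ['[CLS]', 'hello', 'bob', '[CLS]', 'hi', 'alice', '[CLS]', 'how', 'are', 'you', '[SEP]']
print(cls_positions)  # [0, 3, 6]
```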
25 | 26 | 27 | ## Introduction of GIFT 28 | Addressing the issue of who says what to whom in multi-party conversations (MPCs) has recently attracted a lot of research attention. However, existing methods for MPC understanding typically embed interlocutors and utterances into sequential information flows, or make only superficial use of the inherent graph structures in MPCs. To this end, we present a plug-and-play and lightweight method named graph-induced fine-tuning (GIFT), which can adapt various Transformer-based pre-trained language models (PLMs) for universal MPC understanding. In detail, the full and equivalent connections among utterances in a regular Transformer ignore the sparse but distinctive dependency of one utterance on another in MPCs. To distinguish different relationships between utterances, four types of edges are designed to integrate graph-induced signals into the attention mechanism, refining PLMs originally designed for processing sequential texts. We evaluate GIFT by implementing it in three PLMs and test the performance on three downstream tasks: addressee recognition, speaker identification and response selection. Experimental results show that GIFT significantly improves the performance of all three PLMs on the three downstream tasks and two benchmarks with only 4 additional parameters per encoding layer, achieving new state-of-the-art performance on MPC understanding. 29 | 30 | <img src="image/result_addressee_recognition_gift.png">
31 | 32 | <img src="image/result_speaker_identification_gift.png">
33 | 34 | <img src="image/result_response_selection_gift.png">
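As a rough illustration of the mechanism, here is a minimal sketch (a simplification under stated assumptions, not the code in ```modeling_speaker_gift.py```; the tensor shapes and the multiplicative way the weights enter the attention logits are illustrative choices): every token pair is labeled with one of four edge types derived from the reply structure, each edge type owns a single learnable scalar per encoding layer, and these scalars reweight the attention scores.
```
import tensorflow as tf  # TF 1.x, matching the dependency below

def graph_induced_attention_scores(scores, edge_types):
    # scores:     [batch, seq_len, seq_len] raw attention logits
    # edge_types: [batch, seq_len, seq_len] ids in {0: none, 1: reply,
    #             2: replied-by, 3: reply-to-itself}, as built by the
    #             reply_mask in create_finetuning_data_*_gift.py
    edge_weights = tf.get_variable(
        "edge_weights", shape=[4], initializer=tf.ones_initializer())
    # Only 4 additional parameters per layer: one scalar per edge type.
    return scores * tf.gather(edge_weights, edge_types)
```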
35 | 36 | 37 | ## Dependencies 38 | Python 3.6
39 | TensorFlow 1.13.1 40 | 41 | 42 | ## Download 43 | - Download the [BERT model released by Google Research](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip), 44 | and move it to the path ./uncased_L-12_H-768_A-12<br>
45 | 46 | - We also release the [pre-trained MPC-BERT model](https://drive.google.com/file/d/1krmSuy83IQ0XXYyS9KfurnclmprRHgx_/view?usp=sharing); 47 | move it to the path ./uncased_L-12_H-768_A-12_MPCBERT. You only need to fine-tune it to reproduce our results.<br>
48 | 49 | - Download the [Hu et al. (2019) dataset](https://drive.google.com/file/d/1qSw9X22oGGbuRtfaOAf3Z7ficn6mZgi9/view?usp=sharing) used in our paper, 50 | and move it to the path ```./data/ijcai2019/```<br>
51 | 52 | - Download the [Ouchi and Tsuboi (2016) dataset](https://drive.google.com/file/d/1nMiH6dGZfWBoOGbIvyBJp8oxhD8PWSNc/view?usp=sharing) used in our paper, 53 | and move it to the path ```./data/emnlp2016/```<br>
54 | Unzip the dataset and run the following commands.
55 | ``` 56 | cd data/emnlp2016/ 57 | python data_preprocess.py 58 | ``` 59 | 60 | 61 | ## Pre-training 62 | Create the pre-training data. 63 | ``` 64 | python create_pretraining_data.py 65 | ``` 66 | Run the pre-training process. 67 | ``` 68 | cd scripts/ 69 | bash run_pretraining.sh 70 | ``` 71 | The pre-trained model will be saved to the path ```./uncased_L-12_H-768_A-12_MPCBERT```.<br>
72 | Modify the filenames in this folder to match those in Google's BERT release. 73 | 74 | 75 | ## Regular Fine-tuning and Testing 76 | Take the task of addressee recognition as an example.<br>
77 | Create the fine-tuning data. 78 | ``` 79 | python create_finetuning_data_ar.py 80 | ``` 81 | Run the fine-tuning process. 82 | ``` 83 | cd scripts/ 84 | bash run_finetuning.sh 85 | ``` 86 | 87 | Modify the variable ```restore_model_dir``` in ```run_testing.sh```.<br>
88 | Run the testing process. 89 | ``` 90 | cd scripts/ 91 | bash run_testing.sh 92 | ``` 93 | 94 | 95 | ## GIFT Fine-tuning and Testing 96 | Take the task of addressee recognition as an example.<br>
97 | Create the fine-tuning data. 98 | ``` 99 | python create_finetuning_data_ar_gift.py 100 | ``` 101 | Run the fine-tuning process. 102 | ``` 103 | cd scripts/ 104 | bash run_finetuning_gift.sh 105 | ``` 106 | 107 | Modify the variable ```restore_model_dir``` in ```run_testing_gift.sh```.<br>
108 | Run the testing process. 109 | ``` 110 | cd scripts/ 111 | bash run_testing_gift.sh 112 | ``` 113 | 114 | 115 | ## Downstream Tasks 116 | Replace these scripts and their corresponding data when evaluating on other downstream tasks. 117 | ``` 118 | create_finetuning_data_{ar, si, rs}_gift.py 119 | run_finetuning_{ar, si, rs}_gift.py 120 | run_testing_{ar, si, rs}_gift.py 121 | ``` 122 | Specifically for the task of response selection, an ```output_test.txt``` file recording a score for each context-response pair will be saved to the path ```restore_model_dir``` after testing.<br>
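For reference, ```compute_metrics.py``` expects ```output_test.txt``` to be tab-separated with one header line: it reads the context id from column 0, the response id from column 1, the predicted score from column 2 and the ground-truth label from column 4. A minimal sketch of a compatible reader, mirroring that script:
```
from collections import defaultdict

results = defaultdict(list)
with open("output_test.txt", 'r') as f:
    for line in f.readlines()[1:]:      # skip the header line
        cols = line.strip().split('\t')
        # group (response id, label, predicted score) by context id
        results[cols[0]].append((cols[1], float(cols[4]), float(cols[2])))
```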
123 | Modify the variable ```test_out_filename``` in ```compute_metrics.py``` and then run ```python compute_metrics.py```; various metrics will be shown. 124 | 125 | 126 | ## Cite 127 | If you find our work helpful or use the code, please cite the following papers: 128 | 129 | ``` 130 | @inproceedings{gu-etal-2023-gift, 131 | title = "{GIFT}: Graph-Induced Fine-Tuning for Multi-Party Conversation Understanding", 132 | author = "Gu, Jia-Chen and 133 | Ling, Zhen-Hua and 134 | Liu, Quan and 135 | Liu, Cong and 136 | Hu, Guoping", 137 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 138 | month = jul, 139 | year = "2023", 140 | address = "Toronto, Canada", 141 | publisher = "Association for Computational Linguistics", 142 | url = "https://aclanthology.org/2023.acl-long.651", 143 | pages = "11645--11658", 144 | } 145 | ``` 146 | 147 | ``` 148 | @inproceedings{gu-etal-2021-mpc, 149 | title = "{MPC}-{BERT}: A Pre-Trained Language Model for Multi-Party Conversation Understanding", 150 | author = "Gu, Jia-Chen and 151 | Tao, Chongyang and 152 | Ling, Zhen-Hua and 153 | Xu, Can and 154 | Geng, Xiubo and 155 | Jiang, Daxin", 156 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)", 157 | month = aug, 158 | year = "2021", 159 | address = "Online", 160 | publisher = "Association for Computational Linguistics", 161 | url = "https://aclanthology.org/2021.acl-long.285", 162 | pages = "3682--3692", 163 | } 164 | ``` 165 | 166 | 167 | ## Acknowledgments 168 | Thanks to Wenpeng Hu and Zhangming Chan for providing the processed Hu et al. (2019) dataset used in their [paper](https://www.ijcai.org/proceedings/2019/0696.pdf).<br>
169 | Thanks to Ran Le for providing the processed Ouchi and Tsuboi (2016) dataset used in their [paper](https://www.aclweb.org/anthology/D19-1199.pdf).<br>
170 | Thanks to Prasan Yapa for providing a [TF 2.0 version of MPC-BERT](https://github.com/CyraxSector/MPC-BERT-2.0). 171 | 172 | 173 | ## Update 174 | Please keep an eye on this repository if you are interested in our work. 175 | Feel free to contact us (gujc@ustc.edu.cn) or open issues. 176 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /compute_metrics.py: -------------------------------------------------------------------------------- 1 | """Load the output_test.txt file and compute the metrics""" 2 | 3 | 4 | import random 5 | from collections import defaultdict 6 | import metrics 7 | 8 | 9 | test_out_filename = "./output/ijcai2019/PATH_TO_TEST_MODEL/output_test.txt" 10 | print("*"*20 + test_out_filename + "*"*20 + "\n") 11 | 12 | with open(test_out_filename, 'r') as f: 13 | 14 | # candidate size = 10 15 | results = defaultdict(list) 16 | lines = f.readlines() 17 | for line in lines[1:]: 18 | line = line.strip().split('\t') 19 | us_id = line[0] 20 | r_id = line[1] 21 | prob_score = float(line[2]) 22 | label = float(line[4]) 23 | results[us_id].append((r_id, label, prob_score)) 24 | 25 | accu, precision, recall, f1, loss = metrics.classification_metrics(results) 26 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss)) 27 | total_valid_query = metrics.get_num_valid_query(results) 28 | mvp = metrics.mean_average_precision(results) 29 | mrr = metrics.mean_reciprocal_rank(results) 30 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format( 31 | mvp, mrr, total_valid_query)) 32 | top_1_precision = metrics.top_k_precision(results, k=1) 33 | top_2_precision = metrics.top_k_precision(results, k=2) 34 | top_5_precision = metrics.top_k_precision(results, k=5) 35 | print('Recall_10@1: {}\tRecall_10@2: {}\tRecall_10@5: {}\n'.format( 36 | top_1_precision, top_2_precision, top_5_precision)) 37 | 38 | # candidate size = 2; the result of Recall_2@1 may vary across runs because the negative candidate is sampled randomly 39 | results_bin = defaultdict(list) 40 | for us_id, candidates in results.items(): 41 | false_candidates = [] 42 | for candidate in candidates: 43 | r_id, label, prob_score = candidate 44 | if label == 1.0: 45 | results_bin[us_id].append(candidate) 46 | if label == 0.0: 47 | false_candidates.append(candidate) 48 | false_candidate = random.sample(false_candidates, 1) 49 | results_bin[us_id].append(false_candidate[0]) 50 | 51 | accu, precision, recall, f1, loss = metrics.classification_metrics(results_bin) 52 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss)) 53 |
total_valid_query = metrics.get_num_valid_query(results_bin) 54 | mvp = metrics.mean_average_precision(results_bin) 55 | mrr = metrics.mean_reciprocal_rank(results_bin) 56 | top_1_precision = metrics.top_k_precision(results_bin, k=1) 57 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format( 58 | mvp, mrr, total_valid_query)) 59 | print('Recall_2@1: {}\n'.format( 60 | top_1_precision)) 61 | -------------------------------------------------------------------------------- /create_finetuning_data_ar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json 3 | import random 4 | import numpy as np 5 | import collections 6 | from tqdm import tqdm 7 | import tokenization 8 | import tensorflow as tf 9 | 10 | 11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """ 12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json", 13 | "path to train file") 14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json", 15 | "path to valid file") 16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json", 17 | "path to test file") 18 | tf.flags.DEFINE_integer("max_seq_length", 230, 19 | "max sequence length of concatenated context and response") 20 | tf.flags.DEFINE_integer("max_utr_num", 7, 21 | "Maximum utterance number.") 22 | 23 | """ 24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016. 25 | released the original dataset, which is composed of 3 experimental settings according to conversation lengths. 26 | 27 | In our experiments, we used the version processed and used in 28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019.
29 | """ 30 | 31 | # Length-5 32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json", 33 | # "path to train file") 34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json", 35 | # "path to valid file") 36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json", 37 | # "path to test file") 38 | # tf.flags.DEFINE_integer("max_seq_length", 120, 39 | # "max sequence length of concatenated context and response") 40 | # tf.flags.DEFINE_integer("max_utr_num", 5, 41 | # "Maximum utterance number.") 42 | 43 | # Length-10 44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json", 45 | # "path to train file") 46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json", 47 | # "path to valid file") 48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json", 49 | # "path to test file") 50 | # tf.flags.DEFINE_integer("max_seq_length", 220, 51 | # "max sequence length of concatenated context and response") 52 | # tf.flags.DEFINE_integer("max_utr_num", 10, 53 | # "Maximum utterance number.") 54 | 55 | # Length-15 56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json", 57 | # "path to train file") 58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json", 59 | # "path to valid file") 60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json", 61 | # "path to test file") 62 | # tf.flags.DEFINE_integer("max_seq_length", 320, 63 | # "max sequence length of concatenated context and response") 64 | # tf.flags.DEFINE_integer("max_utr_num", 15, 65 | # "Maximum utterance number.") 66 | 67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt", 68 | "path to vocab file") 69 | tf.flags.DEFINE_bool("do_lower_case", True, 70 | "whether to lower case the input text") 71 | 72 | 73 | 74 | def print_configuration_op(FLAGS): 75 | print('My Configurations:') 76 | for name, value in FLAGS.__flags.items(): 77 | value=value.value 78 | if type(value) == float: 79 | print(' %s:\t %f'%(name, value)) 80 | elif type(value) == int: 81 | print(' %s:\t %d'%(name, value)) 82 | elif type(value) == str: 83 | print(' %s:\t %s'%(name, value)) 84 | elif type(value) == bool: 85 | print(' %s:\t %s'%(name, value)) 86 | else: 87 | print('%s:\t %s' % (name, value)) 88 | print('End of configuration') 89 | 90 | 91 | def load_dataset(fname): 92 | dataset = [] 93 | with open(fname, 'r') as f: 94 | for line in f: 95 | data = json.loads(line) 96 | ctx = data['context'] 97 | ctx_spk = data['ctx_spk'] 98 | ctx_adr = data['ctx_adr'] 99 | rsp = data['answer'] 100 | rsp_spk = data['ans_spk'] 101 | rsp_adr = data['ans_adr'] 102 | 103 | integrate_ctx = ctx + [rsp] 104 | integrate_ctx_spk = ctx_spk + [rsp_spk] 105 | integrate_ctx_adr = ctx_adr + [rsp_adr] 106 | assert len(integrate_ctx) == len(integrate_ctx_spk) 107 | assert len(integrate_ctx) == len(integrate_ctx_adr) 108 | 109 | label = [] 110 | for utr_id_adr, utr_adr in enumerate(integrate_ctx_adr): 111 | 112 | label_utr = [0 for _ in range(len(integrate_ctx))] 113 | for cand_utr_id_spk, cand_utr_spk in enumerate(integrate_ctx_spk[:utr_id_adr]): # consider only the preceding utterances 114 | if cand_utr_spk == utr_adr: 115 | label_utr[cand_utr_id_spk] = 1 116 | label.append(label_utr) 117 | 118 | dataset.append((ctx, ctx_spk, rsp, rsp_spk, label)) 119 | 120 | print("dataset_size: {}".format(len(dataset))) 121 | return dataset 122 | 123 | 124 | class InputExample(object): 125 | def __init__(self, guid, ctx, ctx_spk, rsp, rsp_spk, 
label): 126 | """Constructs a InputExample.""" 127 | self.guid = guid 128 | self.ctx = ctx 129 | self.ctx_spk = ctx_spk 130 | self.rsp = rsp 131 | self.rsp_spk = rsp_spk 132 | self.label = label 133 | 134 | 135 | def create_examples(lines, set_type): 136 | """Creates examples for datasets.""" 137 | examples = [] 138 | for (i, line) in enumerate(lines): 139 | guid = "%s-%s" % (set_type, str(i)) 140 | ctx = [tokenization.convert_to_unicode(utr) for utr in line[0]] 141 | ctx_spk = line[1] 142 | rsp = tokenization.convert_to_unicode(line[2]) 143 | rsp_spk = line[3] 144 | label = line[-1] 145 | examples.append(InputExample(guid=guid, ctx=ctx, ctx_spk=ctx_spk, rsp=rsp, rsp_spk=rsp_spk, label=label)) 146 | return examples 147 | 148 | 149 | def truncate_seq_pair(ctx_tokens, rsp_tokens, max_length): 150 | """Truncates a sequence pair in place to the maximum length.""" 151 | while True: 152 | utr_lens = [len(utr_tokens) for utr_tokens in ctx_tokens] 153 | total_length = sum(utr_lens) + len(rsp_tokens) 154 | if total_length <= max_length: 155 | break 156 | 157 | # truncate the longest utterance or response 158 | if sum(utr_lens) > len(rsp_tokens): 159 | trunc_tokens = ctx_tokens[np.argmax(np.array(utr_lens))] 160 | else: 161 | trunc_tokens = rsp_tokens 162 | assert len(trunc_tokens) >= 1 163 | 164 | if random.random() < 0.5: 165 | del trunc_tokens[0] 166 | else: 167 | trunc_tokens.pop() 168 | 169 | 170 | class InputFeatures(object): 171 | """A single set of features of data.""" 172 | def __init__(self, input_sents, input_mask, segment_ids, speaker_ids, cls_positions, label_id, label_weights): 173 | self.input_sents = input_sents 174 | self.input_mask = input_mask 175 | self.segment_ids = segment_ids 176 | self.speaker_ids = speaker_ids 177 | self.cls_positions = cls_positions 178 | self.label_id = label_id 179 | self.label_weights = label_weights 180 | 181 | 182 | def convert_examples_to_features(examples, max_seq_length, max_utr_num, tokenizer): 183 | """Loads a data file into a list of `InputBatch`s.""" 184 | 185 | features = [] 186 | for example in tqdm(examples, total=len(examples)): 187 | 188 | ctx_tokens = [] 189 | for utr in example.ctx: 190 | utr_tokens = tokenizer.tokenize(utr) 191 | ctx_tokens.append(utr_tokens) 192 | assert len(ctx_tokens) == len(example.ctx_spk) 193 | 194 | rsp_tokens = tokenizer.tokenize(example.rsp) 195 | 196 | # [CLS]s for context, [CLS] for response, [SEP] 197 | max_num_tokens = max_seq_length - len(ctx_tokens) - 1 - 1 198 | truncate_seq_pair(ctx_tokens, rsp_tokens, max_num_tokens) 199 | 200 | tokens = [] 201 | segment_ids = [] 202 | speaker_ids = [] 203 | cls_positions = [] 204 | 205 | # utterances 206 | for i in range(len(ctx_tokens)): 207 | utr_tokens = ctx_tokens[i] 208 | utr_spk = example.ctx_spk[i] 209 | 210 | cls_positions.append(len(tokens)) 211 | tokens.append("[CLS]") 212 | segment_ids.append(0) 213 | speaker_ids.append(utr_spk) 214 | 215 | for token in utr_tokens: 216 | tokens.append(token) 217 | segment_ids.append(0) 218 | speaker_ids.append(utr_spk) 219 | 220 | # response 221 | cls_positions.append(len(tokens)) 222 | tokens.append("[CLS]") 223 | segment_ids.append(0) 224 | speaker_ids.append(example.rsp_spk) 225 | 226 | for token in rsp_tokens: 227 | tokens.append(token) 228 | segment_ids.append(0) 229 | speaker_ids.append(example.rsp_spk) 230 | 231 | tokens.append("[SEP]") 232 | segment_ids.append(0) 233 | speaker_ids.append(example.rsp_spk) 234 | 235 | 236 | input_sents = tokenizer.convert_tokens_to_ids(tokens) 237 | input_mask = [1] * len(input_sents) 
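# Pad every token-level feature below up to max_seq_length; padded positions
# get input_mask 0 so that attention ignores them.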
238 | assert len(input_sents) <= max_seq_length 239 | while len(input_sents) < max_seq_length: 240 | input_sents.append(0) 241 | input_mask.append(0) 242 | segment_ids.append(0) 243 | speaker_ids.append(0) 244 | assert len(input_sents) == max_seq_length 245 | assert len(input_mask) == max_seq_length 246 | assert len(segment_ids) == max_seq_length 247 | assert len(speaker_ids) == max_seq_length 248 | 249 | assert len(cls_positions) <= max_utr_num 250 | while len(cls_positions) < max_utr_num: 251 | cls_positions.append(0) 252 | assert len(cls_positions) == max_utr_num 253 | 254 | label_id = [] 255 | for label_utr in example.label: 256 | assert len(label_utr) <= max_utr_num 257 | while len(label_utr) < max_utr_num: 258 | label_utr.append(0) 259 | assert len(label_utr) == max_utr_num 260 | label_id.append(label_utr) 261 | 262 | assert len(label_id) <= max_utr_num 263 | while len(label_id) < max_utr_num: 264 | label_id.append([0] * max_utr_num) 265 | assert len(label_id) == max_utr_num 266 | 267 | label_id_flat = [] 268 | label_weights = [] 269 | for label_utr in label_id: 270 | label_id_flat.extend(label_utr) 271 | if sum(label_utr) > 0: 272 | label_weights.append(1.0) 273 | else: 274 | label_weights.append(0.0) 275 | assert len(label_id_flat) == max_utr_num * max_utr_num 276 | assert len(label_weights) == max_utr_num 277 | 278 | features.append( 279 | InputFeatures( 280 | input_sents=input_sents, 281 | input_mask=input_mask, 282 | segment_ids=segment_ids, 283 | speaker_ids=speaker_ids, 284 | cls_positions=cls_positions, 285 | label_id=label_id_flat, 286 | label_weights=label_weights)) 287 | 288 | return features 289 | 290 | 291 | def write_instance_to_example_files(instances, output_files): 292 | writers = [] 293 | 294 | for output_file in output_files: 295 | writers.append(tf.python_io.TFRecordWriter(output_file)) 296 | 297 | writer_index = 0 298 | total_written = 0 299 | for (inst_index, instance) in enumerate(instances): 300 | features = collections.OrderedDict() 301 | features["input_sents"] = create_int_feature(instance.input_sents) 302 | features["input_mask"] = create_int_feature(instance.input_mask) 303 | features["segment_ids"] = create_int_feature(instance.segment_ids) 304 | features["speaker_ids"] = create_int_feature(instance.speaker_ids) 305 | features["cls_positions"] = create_int_feature(instance.cls_positions) 306 | features["label_ids"] = create_int_feature(instance.label_id) 307 | features["label_weights"] = create_float_feature(instance.label_weights) 308 | 309 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 310 | 311 | writers[writer_index].write(tf_example.SerializeToString()) 312 | writer_index = (writer_index + 1) % len(writers) 313 | 314 | total_written += 1 315 | 316 | print("write_{}_instance_to_example_files".format(total_written)) 317 | 318 | for feature_name in features.keys(): 319 | feature = features[feature_name] 320 | values = [] 321 | if feature.int64_list.value: 322 | values = feature.int64_list.value 323 | elif feature.float_list.value: 324 | values = feature.float_list.value 325 | tf.logging.info( 326 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 327 | 328 | for writer in writers: 329 | writer.close() 330 | 331 | 332 | def create_int_feature(values): 333 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 334 | return feature 335 | 336 | def create_float_feature(values): 337 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 338 | return feature 339 | 
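# A worked example of the addressee-recognition label encoding built above
# (hypothetical values): with ctx_spk = [1, 2] and the third utterance
# addressed to speaker 1, the row for that utterance marks every *preceding*
# utterance spoken by speaker 1, e.g. label_utr = [1, 0, 0, ...].
# Rows are padded to max_utr_num, stacked into a max_utr_num x max_utr_num
# matrix, and flattened into label_ids; label_weights flags rows that contain
# at least one positive entry so that empty rows are excluded from the loss.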
340 | 341 | 342 | if __name__ == "__main__": 343 | 344 | FLAGS = tf.flags.FLAGS 345 | print_configuration_op(FLAGS) 346 | 347 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 348 | 349 | filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file] 350 | filetypes = ["train", "valid", "test"] 351 | for (filename, filetype) in zip(filenames, filetypes): 352 | dataset = load_dataset(filename) 353 | examples = create_examples(dataset, filetype) 354 | features = convert_examples_to_features(examples, FLAGS.max_seq_length, FLAGS.max_utr_num, tokenizer) 355 | new_filename = filename[:-5] + "_ar.tfrecord" 356 | write_instance_to_example_files(features, [new_filename]) 357 | print('Convert {} to {} done'.format(filename, new_filename)) 358 | 359 | print("Sub-process(es) done.") 360 | -------------------------------------------------------------------------------- /create_finetuning_data_rs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json 3 | import random 4 | import numpy as np 5 | import collections 6 | from tqdm import tqdm 7 | import tokenization 8 | import tensorflow as tf 9 | 10 | 11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """ 12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json", 13 | "path to train file") 14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json", 15 | "path to valid file") 16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json", 17 | "path to test file") 18 | tf.flags.DEFINE_integer("max_seq_length", 230, 19 | "max sequence length of concatenated context and response") 20 | tf.flags.DEFINE_integer("max_utr_num", 7, 21 | "Maximum utterance number.") 22 | 23 | """ 24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016. 25 | released the original dataset, which is composed of 3 experimental settings according to conversation lengths. 26 | 27 | In our experiments, we used the version processed and used in 28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019.
29 | """ 30 | 31 | # Length-5 32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json", 33 | # "path to train file") 34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json", 35 | # "path to valid file") 36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json", 37 | # "path to test file") 38 | # tf.flags.DEFINE_integer("max_seq_length", 120, 39 | # "max sequence length of concatenated context and response") 40 | # tf.flags.DEFINE_integer("max_utr_num", 5, 41 | # "Maximum utterance number.") 42 | 43 | # Length-10 44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json", 45 | # "path to train file") 46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json", 47 | # "path to valid file") 48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json", 49 | # "path to test file") 50 | # tf.flags.DEFINE_integer("max_seq_length", 220, 51 | # "max sequence length of concatenated context and response") 52 | # tf.flags.DEFINE_integer("max_utr_num", 10, 53 | # "Maximum utterance number.") 54 | 55 | # Length-15 56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json", 57 | # "path to train file") 58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json", 59 | # "path to valid file") 60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json", 61 | # "path to test file") 62 | # tf.flags.DEFINE_integer("max_seq_length", 320, 63 | # "max sequence length of concatenated context and response") 64 | # tf.flags.DEFINE_integer("max_utr_num", 15, 65 | # "Maximum utterance number.") 66 | 67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt", 68 | "path to vocab file") 69 | tf.flags.DEFINE_bool("do_lower_case", True, 70 | "whether to lower case the input text") 71 | 72 | 73 | 74 | def print_configuration_op(FLAGS): 75 | print('My Configurations:') 76 | for name, value in FLAGS.__flags.items(): 77 | value=value.value 78 | if type(value) == float: 79 | print(' %s:\t %f'%(name, value)) 80 | elif type(value) == int: 81 | print(' %s:\t %d'%(name, value)) 82 | elif type(value) == str: 83 | print(' %s:\t %s'%(name, value)) 84 | elif type(value) == bool: 85 | print(' %s:\t %s'%(name, value)) 86 | else: 87 | print('%s:\t %s' % (name, value)) 88 | print('End of configuration') 89 | 90 | 91 | def load_dataset(fname, n_negative): 92 | ctx_list = [] 93 | ctx_spk_list = [] 94 | rsp_list = [] 95 | rsp_spk_list = [] 96 | with open(fname, 'r') as f: 97 | for line in f: 98 | data = json.loads(line) 99 | ctx_list.append(data['context']) 100 | ctx_spk_list.append(data['ctx_spk']) 101 | rsp_list.append(data['answer']) 102 | rsp_spk_list.append(data['ans_spk']) 103 | print("matched context-response pairs: {}".format(len(ctx_list))) 104 | 105 | dataset = [] 106 | index_list = list(range(len(ctx_list))) 107 | for i in range(len(ctx_list)): 108 | ctx = ctx_list[i] 109 | ctx_spk = ctx_spk_list[i] 110 | 111 | # positive 112 | rsp = rsp_list[i] 113 | rsp_spk = rsp_spk_list[i] 114 | dataset.append((i, ctx, ctx_spk, i, rsp, rsp_spk, 'follow')) 115 | 116 | # negative 117 | negatives = random.sample(index_list, n_negative) 118 | while i in negatives: 119 | negatives = random.sample(index_list, n_negative) 120 | assert i not in negatives 121 | for n_id in negatives: 122 | dataset.append((i, ctx, ctx_spk, n_id, rsp_list[n_id], rsp_spk, 'unfollow')) 123 | 124 | print("dataset_size: {}".format(len(dataset))) 125 | return dataset 126 | 127 | 128 | class 
InputExample(object): 129 | def __init__(self, guid, ctx_id, ctx, ctx_spk, rsp_id, rsp, rsp_spk, label): 130 | """Constructs a InputExample.""" 131 | self.guid = guid 132 | self.ctx_id = ctx_id 133 | self.ctx = ctx 134 | self.ctx_spk = ctx_spk 135 | self.rsp_id = rsp_id 136 | self.rsp = rsp 137 | self.rsp_spk = rsp_spk 138 | self.label = label 139 | 140 | 141 | def create_examples(lines, set_type): 142 | """Creates examples for datasets.""" 143 | examples = [] 144 | for (i, line) in enumerate(lines): 145 | guid = "%s-%s" % (set_type, str(i)) 146 | ctx_id = line[0] 147 | ctx = [tokenization.convert_to_unicode(utr) for utr in line[1]] 148 | ctx_spk = line[2] 149 | rsp_id = line[3] 150 | rsp = tokenization.convert_to_unicode(line[4]) 151 | rsp_spk = line[5] 152 | label = tokenization.convert_to_unicode(line[-1]) 153 | examples.append(InputExample(guid=guid, ctx_id=ctx_id, ctx=ctx, ctx_spk=ctx_spk, 154 | rsp_id=rsp_id, rsp=rsp, rsp_spk=rsp_spk, label=label)) 155 | return examples 156 | 157 | 158 | def truncate_seq_pair(tokens_a, tokens_b, max_length): 159 | """Truncates a sequence pair in place to the maximum length.""" 160 | while True: 161 | total_length = len(tokens_a) + len(tokens_b) 162 | if total_length <= max_length: 163 | break 164 | 165 | if len(tokens_a) > len(tokens_b): 166 | trunc_tokens = tokens_a 167 | else: 168 | trunc_tokens = tokens_b 169 | 170 | if random.random() < 0.5: 171 | del trunc_tokens[0] 172 | else: 173 | trunc_tokens.pop() 174 | 175 | 176 | class InputFeatures(object): 177 | """A single set of features of data.""" 178 | def __init__(self, ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, label_id): 179 | self.ctx_id = ctx_id 180 | self.rsp_id = rsp_id 181 | self.input_sents = input_sents 182 | self.input_mask = input_mask 183 | self.segment_ids = segment_ids 184 | self.speaker_ids = speaker_ids 185 | self.label_id = label_id 186 | 187 | 188 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 189 | """Loads a data file into a list of `InputBatch`s.""" 190 | 191 | label_map = {} 192 | for (i, label) in enumerate(label_list): # ['0', '1'] 193 | label_map[label] = i 194 | 195 | features = [] 196 | for example in tqdm(examples, total=len(examples)): 197 | ctx_id = int(example.ctx_id) 198 | rsp_id = int(example.rsp_id) 199 | 200 | ctx_tokens = [] 201 | ctx_spk = [] 202 | for i in range(len(example.ctx)): 203 | utr = example.ctx[i] 204 | utr_spk = example.ctx_spk[i] 205 | utr_tokens = tokenizer.tokenize(utr) 206 | ctx_tokens.extend(utr_tokens) 207 | ctx_spk.extend([utr_spk]*len(utr_tokens)) 208 | assert len(ctx_tokens) == len(ctx_spk) 209 | 210 | rsp_tokens = tokenizer.tokenize(example.rsp) 211 | 212 | # Account for [CLS], [SEP], [SEP] with "- 3" 213 | truncate_seq_pair(ctx_tokens, rsp_tokens, max_seq_length - 3) 214 | 215 | 216 | tokens = [] 217 | segment_ids = [] 218 | speaker_ids = [] 219 | tokens.append("[CLS]") 220 | segment_ids.append(0) 221 | speaker_ids.append(0) 222 | for token_idx, token in enumerate(ctx_tokens): 223 | tokens.append(token) 224 | segment_ids.append(0) 225 | speaker_ids.append(ctx_spk[token_idx]) 226 | tokens.append("[SEP]") 227 | segment_ids.append(0) 228 | speaker_ids.append(0) 229 | 230 | for token_idx, token in enumerate(rsp_tokens): 231 | tokens.append(token) 232 | segment_ids.append(1) 233 | # speaker_ids.append(0) # 0 for mask speaker 234 | speaker_ids.append(example.rsp_spk) 235 | tokens.append("[SEP]") 236 | segment_ids.append(1) 237 | speaker_ids.append(example.rsp_spk) 238 | 239 | 
input_sents = tokenizer.convert_tokens_to_ids(tokens) 240 | input_mask = [1] * len(input_sents) 241 | assert len(input_sents) <= max_seq_length 242 | while len(input_sents) < max_seq_length: 243 | input_sents.append(0) 244 | input_mask.append(0) 245 | segment_ids.append(0) 246 | speaker_ids.append(0) 247 | assert len(input_sents) == max_seq_length 248 | assert len(input_mask) == max_seq_length 249 | assert len(segment_ids) == max_seq_length 250 | assert len(speaker_ids) == max_seq_length 251 | 252 | label_id = label_map[example.label] 253 | 254 | features.append( 255 | InputFeatures( 256 | ctx_id=ctx_id, 257 | rsp_id = rsp_id, 258 | input_sents=input_sents, 259 | input_mask=input_mask, 260 | segment_ids=segment_ids, 261 | speaker_ids=speaker_ids, 262 | label_id=label_id)) 263 | 264 | return features 265 | 266 | 267 | def write_instance_to_example_files(instances, output_files): 268 | writers = [] 269 | 270 | for output_file in output_files: 271 | writers.append(tf.python_io.TFRecordWriter(output_file)) 272 | 273 | writer_index = 0 274 | total_written = 0 275 | for (inst_index, instance) in enumerate(instances): 276 | features = collections.OrderedDict() 277 | features["ctx_id"] = create_int_feature([instance.ctx_id]) 278 | features["rsp_id"] = create_int_feature([instance.rsp_id]) 279 | features["input_sents"] = create_int_feature(instance.input_sents) 280 | features["input_mask"] = create_int_feature(instance.input_mask) 281 | features["segment_ids"] = create_int_feature(instance.segment_ids) 282 | features["speaker_ids"] = create_int_feature(instance.speaker_ids) 283 | features["label_ids"] = create_float_feature([instance.label_id]) 284 | 285 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 286 | 287 | writers[writer_index].write(tf_example.SerializeToString()) 288 | writer_index = (writer_index + 1) % len(writers) 289 | 290 | total_written += 1 291 | 292 | print("write_{}_instance_to_example_files".format(total_written)) 293 | 294 | for feature_name in features.keys(): 295 | feature = features[feature_name] 296 | values = [] 297 | if feature.int64_list.value: 298 | values = feature.int64_list.value 299 | elif feature.float_list.value: 300 | values = feature.float_list.value 301 | tf.logging.info( 302 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 303 | 304 | for writer in writers: 305 | writer.close() 306 | 307 | 308 | def create_int_feature(values): 309 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 310 | return feature 311 | 312 | def create_float_feature(values): 313 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 314 | return feature 315 | 316 | 317 | 318 | if __name__ == "__main__": 319 | 320 | FLAGS = tf.flags.FLAGS 321 | print_configuration_op(FLAGS) 322 | 323 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 324 | label_list = ["unfollow", "follow"] 325 | 326 | filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file] 327 | filetypes = ["train", "valid", "test"] 328 | file_n_negative = [1, 9, 9] 329 | 330 | for (filename, filetype, n_negative) in zip(filenames, filetypes, file_n_negative): 331 | dataset = load_dataset(filename, n_negative) 332 | examples = create_examples(dataset, filetype) 333 | features = convert_examples_to_features(examples, label_list, FLAGS.max_seq_length, tokenizer) 334 | new_filename = filename[:-5] + "_rs.tfrecord" 335 | write_instance_to_example_files(features, [new_filename]) 
336 | print('Convert {} to {} done'.format(filename, new_filename)) 337 | 338 | print("Sub-process(es) done.") 339 | -------------------------------------------------------------------------------- /create_finetuning_data_si.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json 3 | import random 4 | import numpy as np 5 | import collections 6 | from tqdm import tqdm 7 | import tokenization 8 | import tensorflow as tf 9 | 10 | 11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """ 12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json", 13 | "path to train file") 14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json", 15 | "path to valid file") 16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json", 17 | "path to test file") 18 | tf.flags.DEFINE_integer("max_seq_length", 230, 19 | "max sequence length of concatenated context and response") 20 | tf.flags.DEFINE_integer("max_utr_num", 7, 21 | "Maximum utterance number.") 22 | 23 | """ 24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016. 25 | released the original dataset, which is composed of 3 experimental settings according to conversation lengths. 26 | 27 | In our experiments, we used the version processed and used in 28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019. 29 | """ 30 | 31 | # Length-5 32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json", 33 | # "path to train file") 34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json", 35 | # "path to valid file") 36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json", 37 | # "path to test file") 38 | # tf.flags.DEFINE_integer("max_seq_length", 120, 39 | # "max sequence length of concatenated context and response") 40 | # tf.flags.DEFINE_integer("max_utr_num", 5, 41 | # "Maximum utterance number.") 42 | 43 | # Length-10 44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json", 45 | # "path to train file") 46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json", 47 | # "path to valid file") 48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json", 49 | # "path to test file") 50 | # tf.flags.DEFINE_integer("max_seq_length", 220, 51 | # "max sequence length of concatenated context and response") 52 | # tf.flags.DEFINE_integer("max_utr_num", 10, 53 | # "Maximum utterance number.") 54 | 55 | # Length-15 56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json", 57 | # "path to train file") 58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json", 59 | # "path to valid file") 60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json", 61 | # "path to test file") 62 | # tf.flags.DEFINE_integer("max_seq_length", 320, 63 | # "max sequence length of concatenated context and response") 64 | # tf.flags.DEFINE_integer("max_utr_num", 15, 65 | # "Maximum utterance number.") 66 | 67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt", 68 | "path to vocab file") 69 | tf.flags.DEFINE_bool("do_lower_case", True, 70 | "whether to lower case the input text") 71 | 72 | 73 | 74 | def print_configuration_op(FLAGS): 75 | print('My Configurations:') 76 | for name, value in FLAGS.__flags.items(): 77 | value=value.value 78 | if type(value) == float: 79 | print(' %s:\t 
%f'%(name, value)) 80 | elif type(value) == int: 81 | print(' %s:\t %d'%(name, value)) 82 | elif type(value) == str: 83 | print(' %s:\t %s'%(name, value)) 84 | elif type(value) == bool: 85 | print(' %s:\t %s'%(name, value)) 86 | else: 87 | print('%s:\t %s' % (name, value)) 88 | print('End of configuration') 89 | 90 | 91 | def load_dataset(fname): 92 | dataset = [] 93 | with open(fname, 'r') as f: 94 | for line in f: 95 | data = json.loads(line) 96 | ctx = data['context'] 97 | ctx_spk = data['ctx_spk'] 98 | rsp = data['answer'] 99 | rsp_spk = data['ans_spk'] 100 | assert len(ctx) ==len(ctx_spk) 101 | 102 | utrs_same_spk_with_rsp_spk = [] 103 | for utr_id, utr_spk in enumerate(ctx_spk): 104 | if utr_spk == rsp_spk: 105 | utrs_same_spk_with_rsp_spk.append(utr_id) 106 | 107 | if len(utrs_same_spk_with_rsp_spk) == 0: 108 | continue 109 | 110 | label = [0 for _ in range(len(ctx))] 111 | for utr_id in utrs_same_spk_with_rsp_spk: 112 | label[utr_id] = 1 113 | 114 | dataset.append((ctx, ctx_spk, rsp, rsp_spk, label)) 115 | 116 | print("dataset_size: {}".format(len(dataset))) 117 | return dataset 118 | 119 | 120 | class InputExample(object): 121 | def __init__(self, guid, ctx, ctx_spk, rsp, rsp_spk, label): 122 | """Constructs a InputExample.""" 123 | self.guid = guid 124 | self.ctx = ctx 125 | self.ctx_spk = ctx_spk 126 | self.rsp = rsp 127 | self.rsp_spk = rsp_spk 128 | self.label = label 129 | 130 | 131 | def create_examples(lines, set_type): 132 | """Creates examples for datasets.""" 133 | examples = [] 134 | for (i, line) in enumerate(lines): 135 | guid = "%s-%s" % (set_type, str(i)) 136 | ctx = [tokenization.convert_to_unicode(utr) for utr in line[0]] 137 | ctx_spk = line[1] 138 | rsp = tokenization.convert_to_unicode(line[2]) 139 | rsp_spk = line[3] 140 | label = line[-1] 141 | examples.append(InputExample(guid=guid, ctx=ctx, ctx_spk=ctx_spk, rsp=rsp, rsp_spk=rsp_spk, label=label)) 142 | return examples 143 | 144 | 145 | def truncate_seq_pair(ctx_tokens, rsp_tokens, max_length): 146 | """Truncates a sequence pair in place to the maximum length.""" 147 | while True: 148 | utr_lens = [len(utr_tokens) for utr_tokens in ctx_tokens] 149 | total_length = sum(utr_lens) + len(rsp_tokens) 150 | if total_length <= max_length: 151 | break 152 | 153 | # truncate the longest utterance or response 154 | if sum(utr_lens) > len(rsp_tokens): 155 | trunc_tokens = ctx_tokens[np.argmax(np.array(utr_lens))] 156 | else: 157 | trunc_tokens = rsp_tokens 158 | assert len(trunc_tokens) >= 1 159 | 160 | if random.random() < 0.5: 161 | del trunc_tokens[0] 162 | else: 163 | trunc_tokens.pop() 164 | 165 | 166 | class InputFeatures(object): 167 | """A single set of features of data.""" 168 | def __init__(self, input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, label_id): 169 | self.input_sents = input_sents 170 | self.input_mask = input_mask 171 | self.segment_ids = segment_ids 172 | self.speaker_ids = speaker_ids 173 | self.cls_positions = cls_positions 174 | self.rsp_position = rsp_position 175 | self.label_id = label_id 176 | 177 | 178 | def convert_examples_to_features(examples, max_seq_length, max_utr_num, tokenizer): 179 | """Loads a data file into a list of `InputBatch`s.""" 180 | 181 | features = [] 182 | for example in tqdm(examples, total=len(examples)): 183 | 184 | ctx_tokens = [] 185 | for utr in example.ctx: 186 | utr_tokens = tokenizer.tokenize(utr) 187 | ctx_tokens.append(utr_tokens) 188 | assert len(ctx_tokens) == len(example.ctx_spk) 189 | 190 | rsp_tokens = 
tokenizer.tokenize(example.rsp) 191 | 192 | # [CLS]s for context, [CLS] for response, [SEP] 193 | max_num_tokens = max_seq_length - len(ctx_tokens) - 1 - 1 194 | truncate_seq_pair(ctx_tokens, rsp_tokens, max_num_tokens) 195 | 196 | 197 | tokens = [] 198 | segment_ids = [] 199 | speaker_ids = [] 200 | cls_positions = [] 201 | rsp_position = [] 202 | 203 | # utterances 204 | for i in range(len(ctx_tokens)): 205 | utr_tokens = ctx_tokens[i] 206 | utr_spk = example.ctx_spk[i] 207 | 208 | cls_positions.append(len(tokens)) 209 | tokens.append("[CLS]") 210 | segment_ids.append(0) 211 | speaker_ids.append(utr_spk) 212 | 213 | for token in utr_tokens: 214 | tokens.append(token) 215 | segment_ids.append(0) 216 | speaker_ids.append(utr_spk) 217 | 218 | # response 219 | rsp_position.append(len(cls_positions)) 220 | cls_positions.append(len(tokens)) 221 | tokens.append("[CLS]") 222 | segment_ids.append(0) 223 | # speaker_ids.append(example.rsp_spk) 224 | speaker_ids.append(0) # 0 for mask 225 | 226 | for token in rsp_tokens: 227 | tokens.append(token) 228 | segment_ids.append(0) 229 | # speaker_ids.append(example.rsp_spk) 230 | speaker_ids.append(0) 231 | 232 | tokens.append("[SEP]") 233 | segment_ids.append(0) 234 | # speaker_ids.append(example.rsp_spk) 235 | speaker_ids.append(0) 236 | 237 | 238 | input_sents = tokenizer.convert_tokens_to_ids(tokens) 239 | input_mask = [1] * len(input_sents) 240 | assert len(input_sents) <= max_seq_length 241 | while len(input_sents) < max_seq_length: 242 | input_sents.append(0) 243 | input_mask.append(0) 244 | segment_ids.append(0) 245 | speaker_ids.append(0) 246 | assert len(input_sents) == max_seq_length 247 | assert len(input_mask) == max_seq_length 248 | assert len(segment_ids) == max_seq_length 249 | assert len(speaker_ids) == max_seq_length 250 | 251 | assert len(cls_positions) <= max_utr_num 252 | while len(cls_positions) < max_utr_num: 253 | cls_positions.append(0) 254 | assert len(cls_positions) == max_utr_num 255 | 256 | label_id = example.label 257 | assert len(label_id) <= max_utr_num 258 | while len(label_id) < max_utr_num: 259 | label_id.append(0) 260 | assert len(label_id) == max_utr_num 261 | 262 | features.append( 263 | InputFeatures( 264 | input_sents=input_sents, 265 | input_mask=input_mask, 266 | segment_ids=segment_ids, 267 | speaker_ids=speaker_ids, 268 | cls_positions=cls_positions, 269 | rsp_position=rsp_position, 270 | label_id=label_id)) 271 | 272 | return features 273 | 274 | 275 | def write_instance_to_example_files(instances, output_files): 276 | writers = [] 277 | 278 | for output_file in output_files: 279 | writers.append(tf.python_io.TFRecordWriter(output_file)) 280 | 281 | writer_index = 0 282 | total_written = 0 283 | for (inst_index, instance) in enumerate(instances): 284 | features = collections.OrderedDict() 285 | features["input_sents"] = create_int_feature(instance.input_sents) 286 | features["input_mask"] = create_int_feature(instance.input_mask) 287 | features["segment_ids"] = create_int_feature(instance.segment_ids) 288 | features["speaker_ids"] = create_int_feature(instance.speaker_ids) 289 | features["cls_positions"] = create_int_feature(instance.cls_positions) 290 | features["rsp_position"] = create_int_feature(instance.rsp_position) 291 | features["label_ids"] = create_int_feature(instance.label_id) 292 | 293 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 294 | 295 | writers[writer_index].write(tf_example.SerializeToString()) 296 | writer_index = (writer_index + 1) % len(writers) 297 | 298 
| total_written += 1 299 | 300 | print("write_{}_instance_to_example_files".format(total_written)) 301 | 302 | for feature_name in features.keys(): 303 | feature = features[feature_name] 304 | values = [] 305 | if feature.int64_list.value: 306 | values = feature.int64_list.value 307 | elif feature.float_list.value: 308 | values = feature.float_list.value 309 | tf.logging.info( 310 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 311 | 312 | for writer in writers: 313 | writer.close() 314 | 315 | 316 | def create_int_feature(values): 317 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 318 | return feature 319 | 320 | def create_float_feature(values): 321 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 322 | return feature 323 | 324 | 325 | 326 | if __name__ == "__main__": 327 | 328 | FLAGS = tf.flags.FLAGS 329 | print_configuration_op(FLAGS) 330 | 331 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 332 | 333 | filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file] 334 | filetypes = ["train", "valid", "test"] 335 | for (filename, filetype) in zip(filenames, filetypes): 336 | dataset = load_dataset(filename) 337 | examples = create_examples(dataset, filetype) 338 | features = convert_examples_to_features(examples, FLAGS.max_seq_length, FLAGS.max_utr_num, tokenizer) 339 | new_filename = filename[:-5] + "_si.tfrecord" 340 | write_instance_to_example_files(features, [new_filename]) 341 | print('Convert {} to {} done'.format(filename, new_filename)) 342 | 343 | print("Sub-process(es) done.") 344 | -------------------------------------------------------------------------------- /create_finetuning_data_si_gift.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import json 3 | import random 4 | import numpy as np 5 | import collections 6 | from tqdm import tqdm 7 | import tokenization 8 | import tensorflow as tf 9 | 10 | 11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """ 12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json", 13 | "path to train file") 14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json", 15 | "path to valid file") 16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json", 17 | "path to test file") 18 | tf.flags.DEFINE_integer("max_seq_length", 230, 19 | "max sequence length of concatenated context and response") 20 | tf.flags.DEFINE_integer("max_utr_num", 7, 21 | "Maximum utterance number.") 22 | 23 | """ 24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016. 25 | released the original dataset, which is composed of 3 experimental settings according to conversation lengths. 26 | 27 | In our experiments, we used the version processed and used in 28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019.
29 | """ 30 | 31 | # Length-5 32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json", 33 | # "path to train file") 34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json", 35 | # "path to valid file") 36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json", 37 | # "path to test file") 38 | # tf.flags.DEFINE_integer("max_seq_length", 120, 39 | # "max sequence length of concatenated context and response") 40 | # tf.flags.DEFINE_integer("max_utr_num", 5, 41 | # "Maximum utterance number.") 42 | 43 | # Length-10 44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json", 45 | # "path to train file") 46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json", 47 | # "path to valid file") 48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json", 49 | # "path to test file") 50 | # tf.flags.DEFINE_integer("max_seq_length", 220, 51 | # "max sequence length of concatenated context and response") 52 | # tf.flags.DEFINE_integer("max_utr_num", 10, 53 | # "Maximum utterance number.") 54 | 55 | # Length-15 56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json", 57 | # "path to train file") 58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json", 59 | # "path to valid file") 60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json", 61 | # "path to test file") 62 | # tf.flags.DEFINE_integer("max_seq_length", 320, 63 | # "max sequence length of concatenated context and response") 64 | # tf.flags.DEFINE_integer("max_utr_num", 15, 65 | # "Maximum utterance number.") 66 | 67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt", 68 | "path to vocab file") 69 | tf.flags.DEFINE_bool("do_lower_case", True, 70 | "whether to lower case the input text") 71 | 72 | 73 | 74 | def print_configuration_op(FLAGS): 75 | print('My Configurations:') 76 | for name, value in FLAGS.__flags.items(): 77 | value=value.value 78 | if type(value) == float: 79 | print(' %s:\t %f'%(name, value)) 80 | elif type(value) == int: 81 | print(' %s:\t %d'%(name, value)) 82 | elif type(value) == str: 83 | print(' %s:\t %s'%(name, value)) 84 | elif type(value) == bool: 85 | print(' %s:\t %s'%(name, value)) 86 | else: 87 | print('%s:\t %s' % (name, value)) 88 | print('End of configuration') 89 | 90 | 91 | def load_dataset(fname): 92 | dataset = [] 93 | with open(fname, 'r') as f: 94 | for line in f: 95 | data = json.loads(line) 96 | ctx = data['context'] 97 | ctx_spk = data['ctx_spk'] 98 | ctx_adr = data['ctx_adr'] 99 | rsp = data['answer'] 100 | rsp_spk = data['ans_spk'] 101 | rsp_adr = data['ans_adr'] 102 | ctx_relation = data['relation_at'] 103 | rsp_relation = data['ans_idx'] 104 | 105 | utrs_same_spk_with_rsp_spk = [] 106 | for utr_id, utr_spk in enumerate(ctx_spk): 107 | if utr_spk == rsp_spk: 108 | utrs_same_spk_with_rsp_spk.append(utr_id) 109 | 110 | if len(utrs_same_spk_with_rsp_spk) == 0: 111 | continue 112 | 113 | label = [0 for _ in range(len(ctx))] 114 | for utr_id in utrs_same_spk_with_rsp_spk: 115 | label[utr_id] = 1 116 | 117 | # construct the reply mask 118 | integrate_ctx = ctx + [rsp] 119 | integrate_ctx_spk = ctx_spk + [rsp_spk] 120 | integrate_ctx_adr = ctx_adr + [rsp_adr] 121 | assert len(integrate_ctx) == len(integrate_ctx_spk) 122 | assert len(integrate_ctx) == len(integrate_ctx_adr) 123 | integrate_ctx_relation = ctx_relation + [[len(ctx), rsp_relation]] 124 | 125 | reply_mask = [[0 for _ in range(len(integrate_ctx))] for _ in 
range(len(integrate_ctx))] 126 | for relation in integrate_ctx_relation: 127 | tgt, src = relation 128 | reply_mask[tgt][src] = 1 # reply 129 | reply_mask[src][tgt] = 2 # replied_by 130 | for diagonal in range(len(integrate_ctx)): 131 | reply_mask[diagonal][diagonal] = 3 # reply_to_itself 132 | 133 | dataset.append((ctx, ctx_spk, rsp, rsp_spk, reply_mask, label)) 134 | 135 | print("dataset_size: {}".format(len(dataset))) 136 | return dataset 137 | 138 | 139 | class InputExample(object): 140 | def __init__(self, guid, ctx, ctx_spk, rsp, rsp_spk, label, reply_mask): 141 | """Constructs a InputExample.""" 142 | self.guid = guid 143 | self.ctx = ctx 144 | self.ctx_spk = ctx_spk 145 | self.rsp = rsp 146 | self.rsp_spk = rsp_spk 147 | self.reply_mask = reply_mask 148 | self.label = label 149 | 150 | 151 | def create_examples(lines, set_type): 152 | """Creates examples for datasets.""" 153 | examples = [] 154 | for (i, line) in enumerate(lines): 155 | guid = "%s-%s" % (set_type, str(i)) 156 | ctx = [tokenization.convert_to_unicode(utr) for utr in line[0]] 157 | ctx_spk = line[1] 158 | rsp = tokenization.convert_to_unicode(line[2]) 159 | rsp_spk = line[3] 160 | reply_mask = line[4] 161 | label = line[-1] 162 | examples.append(InputExample(guid=guid, ctx=ctx, ctx_spk=ctx_spk, rsp=rsp, rsp_spk=rsp_spk, reply_mask=reply_mask, label=label)) 163 | return examples 164 | 165 | 166 | def truncate_seq_pair(ctx_tokens, rsp_tokens, max_length): 167 | """Truncates a sequence pair in place to the maximum length.""" 168 | while True: 169 | utr_lens = [len(utr_tokens) for utr_tokens in ctx_tokens] 170 | total_length = sum(utr_lens) + len(rsp_tokens) 171 | if total_length <= max_length: 172 | break 173 | 174 | # truncate the longest utterance or response 175 | if sum(utr_lens) > len(rsp_tokens): 176 | trunc_tokens = ctx_tokens[np.argmax(np.array(utr_lens))] 177 | else: 178 | trunc_tokens = rsp_tokens 179 | assert len(trunc_tokens) >= 1 180 | 181 | if random.random() < 0.5: 182 | del trunc_tokens[0] 183 | else: 184 | trunc_tokens.pop() 185 | 186 | 187 | class InputFeatures(object): 188 | """A single set of features of data.""" 189 | def __init__(self, input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, reply_mask_utr2word_flatten, utr_lens, label_id): 190 | self.input_sents = input_sents 191 | self.input_mask = input_mask 192 | self.segment_ids = segment_ids 193 | self.speaker_ids = speaker_ids 194 | self.cls_positions = cls_positions 195 | self.rsp_position = rsp_position 196 | self.reply_mask_utr2word_flatten = reply_mask_utr2word_flatten 197 | self.utr_lens = utr_lens 198 | self.label_id = label_id 199 | 200 | 201 | def convert_examples_to_features(examples, max_seq_length, max_utr_num, tokenizer): 202 | """Loads a data file into a list of `InputBatch`s.""" 203 | 204 | features = [] 205 | for example in tqdm(examples, total=len(examples)): 206 | 207 | ctx_tokens = [] 208 | for utr in example.ctx: 209 | utr_tokens = tokenizer.tokenize(utr) 210 | ctx_tokens.append(utr_tokens) 211 | assert len(ctx_tokens) == len(example.ctx_spk) 212 | 213 | rsp_tokens = tokenizer.tokenize(example.rsp) 214 | 215 | # [CLS]s for context, [CLS] for response, [SEP] 216 | max_num_tokens = max_seq_length - len(ctx_tokens) - 1 - 1 217 | truncate_seq_pair(ctx_tokens, rsp_tokens, max_num_tokens) 218 | 219 | 220 | tokens = [] 221 | segment_ids = [] 222 | speaker_ids = [] 223 | cls_positions = [] 224 | rsp_position = [] 225 | reply_mask_utr2word = [[] for _ in range(len(example.reply_mask))] 226 | utr_lens = [] 227 
| 228 | # utterances 229 | for i in range(len(ctx_tokens)): 230 | utr_tokens = ctx_tokens[i] 231 | utr_spk = example.ctx_spk[i] 232 | 233 | utr_lens.append(len(utr_tokens) + 1) # +1 for [CLS] 234 | cls_positions.append(len(tokens)) 235 | tokens.append("[CLS]") 236 | segment_ids.append(0) 237 | speaker_ids.append(utr_spk) 238 | for j in range(len(example.reply_mask)): 239 | reply_mask_utr2word[j].append(example.reply_mask[j][i]) 240 | 241 | for token in utr_tokens: 242 | tokens.append(token) 243 | segment_ids.append(0) 244 | speaker_ids.append(utr_spk) 245 | for j in range(len(example.reply_mask)): 246 | reply_mask_utr2word[j].append(example.reply_mask[j][i]) 247 | 248 | # response 249 | utr_lens.append(len(rsp_tokens) + 2) # +2 for [CLS] and [SEP] 250 | rsp_position.append(len(cls_positions)) 251 | cls_positions.append(len(tokens)) 252 | tokens.append("[CLS]") 253 | segment_ids.append(0) 254 | # speaker_ids.append(example.rsp_spk) 255 | speaker_ids.append(0) # 0 for mask 256 | for j in range(len(example.reply_mask)): 257 | reply_mask_utr2word[j].append(example.reply_mask[j][-1]) 258 | 259 | for token in rsp_tokens: 260 | tokens.append(token) 261 | segment_ids.append(0) 262 | # speaker_ids.append(example.rsp_spk) 263 | speaker_ids.append(0) 264 | for j in range(len(example.reply_mask)): 265 | reply_mask_utr2word[j].append(example.reply_mask[j][-1]) 266 | 267 | tokens.append("[SEP]") 268 | segment_ids.append(0) 269 | # speaker_ids.append(example.rsp_spk) 270 | speaker_ids.append(0) 271 | for j in range(len(example.reply_mask)): 272 | reply_mask_utr2word[j].append(example.reply_mask[j][-1]) 273 | 274 | assert len(utr_lens) == len(reply_mask_utr2word) 275 | for i in range(len(reply_mask_utr2word)): 276 | assert len(reply_mask_utr2word[i]) <= max_seq_length 277 | while len(reply_mask_utr2word[i]) < max_seq_length: 278 | reply_mask_utr2word[i].append(0) 279 | assert len(reply_mask_utr2word[i]) == max_seq_length 280 | 281 | assert len(reply_mask_utr2word) <= max_utr_num 282 | while len(reply_mask_utr2word) < max_utr_num: 283 | reply_mask_utr2word.append([0]*max_seq_length) 284 | assert len(reply_mask_utr2word) == max_utr_num 285 | 286 | reply_mask_utr2word_flatten = [] 287 | for x in reply_mask_utr2word: 288 | reply_mask_utr2word_flatten.extend(x) 289 | assert len(reply_mask_utr2word_flatten) == max_seq_length * max_utr_num 290 | 291 | assert len(utr_lens) <= max_utr_num 292 | while len(utr_lens) < max_utr_num: 293 | utr_lens.append(0) 294 | assert len(utr_lens) == max_utr_num 295 | 296 | 297 | input_sents = tokenizer.convert_tokens_to_ids(tokens) 298 | input_mask = [1] * len(input_sents) 299 | assert len(input_sents) <= max_seq_length 300 | while len(input_sents) < max_seq_length: 301 | input_sents.append(0) 302 | input_mask.append(0) 303 | segment_ids.append(0) 304 | speaker_ids.append(0) 305 | assert len(input_sents) == max_seq_length 306 | assert len(input_mask) == max_seq_length 307 | assert len(segment_ids) == max_seq_length 308 | assert len(speaker_ids) == max_seq_length 309 | 310 | assert len(cls_positions) <= max_utr_num 311 | while len(cls_positions) < max_utr_num: 312 | cls_positions.append(0) 313 | assert len(cls_positions) == max_utr_num 314 | 315 | label_id = example.label 316 | assert len(label_id) <= max_utr_num 317 | while len(label_id) < max_utr_num: 318 | label_id.append(0) 319 | assert len(label_id) == max_utr_num 320 | 321 | features.append( 322 | InputFeatures( 323 | input_sents=input_sents, 324 | input_mask=input_mask, 325 | segment_ids=segment_ids, 326 | 
speaker_ids=speaker_ids, 327 | cls_positions=cls_positions, 328 | rsp_position=rsp_position, 329 | reply_mask_utr2word_flatten=reply_mask_utr2word_flatten, 330 | utr_lens=utr_lens, 331 | label_id=label_id)) 332 | 333 | return features 334 | 335 | 336 | def write_instance_to_example_files(instances, output_files): 337 | writers = [] 338 | 339 | for output_file in output_files: 340 | writers.append(tf.python_io.TFRecordWriter(output_file)) 341 | 342 | writer_index = 0 343 | total_written = 0 344 | for (inst_index, instance) in enumerate(instances): 345 | features = collections.OrderedDict() 346 | features["input_sents"] = create_int_feature(instance.input_sents) 347 | features["input_mask"] = create_int_feature(instance.input_mask) 348 | features["segment_ids"] = create_int_feature(instance.segment_ids) 349 | features["speaker_ids"] = create_int_feature(instance.speaker_ids) 350 | features["cls_positions"] = create_int_feature(instance.cls_positions) 351 | features["rsp_position"] = create_int_feature(instance.rsp_position) 352 | features["reply_mask_utr2word_flatten"] = create_int_feature(instance.reply_mask_utr2word_flatten) 353 | features["utr_lens"] = create_int_feature(instance.utr_lens) 354 | features["label_ids"] = create_int_feature(instance.label_id) 355 | 356 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 357 | 358 | writers[writer_index].write(tf_example.SerializeToString()) 359 | writer_index = (writer_index + 1) % len(writers) 360 | 361 | total_written += 1 362 | 363 | print("write_{}_instance_to_example_files".format(total_written)) 364 | 365 | for feature_name in features.keys(): 366 | feature = features[feature_name] 367 | values = [] 368 | if feature.int64_list.value: 369 | values = feature.int64_list.value 370 | elif feature.float_list.value: 371 | values = feature.float_list.value 372 | tf.logging.info( 373 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 374 | 375 | for writer in writers: 376 | writer.close() 377 | 378 | 379 | def create_int_feature(values): 380 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 381 | return feature 382 | 383 | def create_float_feature(values): 384 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 385 | return feature 386 | 387 | 388 | 389 | if __name__ == "__main__": 390 | 391 | FLAGS = tf.flags.FLAGS 392 | print_configuration_op(FLAGS) 393 | 394 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 395 | 396 | filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file] 397 | filetypes = ["train", "valid", "test"] 398 | for (filename, filetype) in zip(filenames, filetypes): 399 | dataset = load_dataset(filename) 400 | examples = create_examples(dataset, filetype) 401 | features = convert_examples_to_features(examples, FLAGS.max_seq_length, FLAGS.max_utr_num, tokenizer) 402 | new_filename = filename[:-5] + "_si_gift.tfrecord" 403 | write_instance_to_example_files(features, [new_filename]) 404 | print('Convert {} to {} done'.format(filename, new_filename)) 405 | 406 | print("Sub-process(es) done.") 407 | -------------------------------------------------------------------------------- /data/emnlp2016/README.md: -------------------------------------------------------------------------------- 1 | The files of 5.json, 10.json and 15.json are placed in this folder.
2 | 3 | Run the following command to derive the processed datasets.
4 | ``` 5 | python data_preprocess.py 6 | ``` 7 | -------------------------------------------------------------------------------- /data/emnlp2016/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | 4 | 5 | def get_relation(ctx_spk_index, ctx_adr_index, rsp_spk_index, rsp_adr_index): 6 | 7 | ctx_relation = [] 8 | ctx_spk_2_utr = {} 9 | for utr_id, (utr_spk, utr_adr) in enumerate(zip(ctx_spk_index, ctx_adr_index)): 10 | if utr_spk in ctx_spk_2_utr: 11 | ctx_spk_2_utr[utr_spk].append(utr_id) 12 | else: 13 | ctx_spk_2_utr[utr_spk] = [utr_id] 14 | 15 | if utr_adr == -1: 16 | continue 17 | 18 | if utr_adr in ctx_spk_2_utr: 19 | src = ctx_spk_2_utr[utr_adr][-1] 20 | tgt = utr_id 21 | ctx_relation.append([tgt, src]) 22 | 23 | if rsp_adr_index in ctx_spk_2_utr: 24 | rsp_idx = ctx_spk_2_utr[rsp_adr_index][-1] 25 | else: 26 | rsp_idx = -1 27 | 28 | return ctx_relation, rsp_idx 29 | 30 | 31 | # main 32 | for dialogue_len in [5, 10, 15]: 33 | print("Processing the dataset of conversation length: {} ...".format(dialogue_len)) 34 | 35 | with open("{}.json".format(dialogue_len), "r") as fin: 36 | data = json.load(fin) 37 | 38 | for split, dialogues in data.items(): 39 | with open("{}_{}.json".format(dialogue_len, split), "w") as fout: 40 | 41 | for dialogue in tqdm(dialogues, total=len(dialogues)): 42 | assert len(dialogue) == dialogue_len 43 | 44 | user_index = {'-': -1} 45 | 46 | # context 47 | ctx = [] 48 | ctx_spk = [] 49 | ctx_adr = [] 50 | for utterance in dialogue[:-1]: 51 | assert len(utterance) == 3 52 | utr_spk = utterance[0] 53 | utr = utterance[1] 54 | utr_adr = utterance[2] 55 | 56 | if utr_spk not in user_index: 57 | user_index[utr_spk] = len(user_index) 58 | if utr_adr not in user_index: 59 | user_index[utr_adr] = len(user_index) 60 | 61 | ctx.append(utr) 62 | ctx_spk.append(user_index[utr_spk]) 63 | ctx_adr.append(user_index[utr_adr]) 64 | 65 | # response 66 | response = dialogue[-1] 67 | rsp_spk = response[0] 68 | rsp = response[1] 69 | rsp_adr = response[2] 70 | 71 | if rsp_spk not in user_index: 72 | user_index[rsp_spk] = len(user_index) 73 | assert rsp_adr in user_index 74 | assert rsp_adr != '-' 75 | # if rsp_adr not in user_index: 76 | # user_index[rsp_adr] = len(user_index) 77 | 78 | rsp_spk = user_index[rsp_spk] 79 | rsp_adr = user_index[rsp_adr] 80 | 81 | ctx_relation, rsp_idx = get_relation(ctx_spk, ctx_adr, rsp_spk, rsp_adr) 82 | 83 | d = {} 84 | d['context'] = ctx 85 | d['relation_at'] = ctx_relation 86 | d['ctx_spk'] = ctx_spk 87 | d['ctx_adr'] = ctx_adr 88 | 89 | d['answer'] = rsp 90 | d['ans_idx'] = rsp_idx 91 | d['ans_spk'] = rsp_spk 92 | d['ans_adr'] = rsp_adr 93 | 94 | json_str = json.dumps(d) # indent=2 95 | fout.write(json_str +'\n') 96 | -------------------------------------------------------------------------------- /data/ijcai2019/README.md: -------------------------------------------------------------------------------- 1 | The files of train.json, dev.json and test.json are placed in this folder.
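For reference, `get_relation` in `data_preprocess.py` above links each utterance to the most recent preceding utterance of its addressee. A minimal sketch of that rule on a hypothetical three-utterance context (speaker/addressee ids already mapped through `user_index`):
```
ctx_spk = [1, 2, 1]    # speakers of utterances u0, u1, u2
ctx_adr = [-1, 1, 2]   # u0 addresses nobody; u1 addresses speaker 1; u2 addresses speaker 2

spk_2_utr = {}
ctx_relation = []
for utr_id, (spk, adr) in enumerate(zip(ctx_spk, ctx_adr)):
    spk_2_utr.setdefault(spk, []).append(utr_id)
    if adr in spk_2_utr:  # reply to the addressee's latest utterance so far
        ctx_relation.append([utr_id, spk_2_utr[adr][-1]])

print(ctx_relation)  # [[1, 0], [2, 1]]: u1 replies to u0, u2 replies to u1
```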
2 | -------------------------------------------------------------------------------- /image/result_addressee_recognition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_addressee_recognition.png -------------------------------------------------------------------------------- /image/result_addressee_recognition_gift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_addressee_recognition_gift.png -------------------------------------------------------------------------------- /image/result_response_selection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_response_selection.png -------------------------------------------------------------------------------- /image/result_response_selection_gift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_response_selection_gift.png -------------------------------------------------------------------------------- /image/result_speaker_identification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_speaker_identification.png -------------------------------------------------------------------------------- /image/result_speaker_identification_gift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_speaker_identification_gift.png -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import math 3 | 4 | 5 | def is_valid_query(v): 6 | num_pos = 0 7 | num_neg = 0 8 | for aid, label, score in v: 9 | if label > 0: 10 | num_pos += 1 11 | else: 12 | num_neg += 1 13 | if num_pos > 0 and num_neg > 0: 14 | return True 15 | else: 16 | return False 17 | 18 | 19 | def get_num_valid_query(results): 20 | num_query = 0 21 | for k, v in results.items(): 22 | if not is_valid_query(v): 23 | continue 24 | num_query += 1 25 | return num_query 26 | 27 | 28 | def top_1_precision(results): 29 | num_query = 0 30 | top_1_correct = 0.0 31 | for k, v in results.items(): 32 | if not is_valid_query(v): 33 | continue 34 | num_query += 1 35 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True) 36 | aid, label, score = sorted_v[0] 37 | if label > 0: 38 | top_1_correct += 1 39 | 40 | if num_query > 0: 41 | return top_1_correct / num_query 42 | else: 43 | return 0.0 44 | 45 | 46 | def mean_reciprocal_rank(results): 47 | num_query = 0 48 | mrr = 0.0 49 | for k, v in results.items(): 50 | if not is_valid_query(v): 51 | continue 52 | 53 | num_query += 1 54 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True) 55 | for i, rec in enumerate(sorted_v): 56 | aid, label, score = rec 57 | if label > 0: 58 | mrr += 1.0 / (i + 1) 59 | break 60 | 61 | if num_query == 0: 62 | return 
0.0 63 | else: 64 | mrr = mrr / num_query 65 | return mrr 66 | 67 | 68 | def mean_average_precision(results): 69 | num_query = 0 70 | mean_ap = 0.0 71 | for k, v in results.items(): 72 | if not is_valid_query(v): 73 | continue 74 | 75 | num_query += 1 76 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True) 77 | num_relevant_doc = 0.0 78 | avp = 0.0 79 | for i, rec in enumerate(sorted_v): 80 | aid, label, score = rec 81 | if label == 1: 82 | num_relevant_doc += 1 83 | precision = num_relevant_doc / (i + 1) 84 | avp += precision 85 | avp = avp / num_relevant_doc 86 | mean_ap += avp 87 | 88 | if num_query == 0: 89 | return 0.0 90 | else: 91 | mean_ap = mean_ap / num_query 92 | return mean_ap 93 | 94 | 95 | def classification_metrics(results): 96 | total_num = 0 97 | total_correct = 0 98 | true_positive = 0 99 | positive_correct = 0 100 | predicted_positive = 0 101 | 102 | loss = 0.0 103 | for k, v in results.items(): 104 | for rec in v: 105 | total_num += 1 106 | aid, label, score = rec 107 | 108 | if score > 0.5: 109 | predicted_positive += 1 110 | 111 | if label > 0: 112 | true_positive += 1 113 | loss += -math.log(score + 1e-12) 114 | else: 115 | loss += -math.log(1.0 - score + 1e-12) 116 | 117 | if score > 0.5 and label > 0: 118 | total_correct += 1 119 | positive_correct += 1 120 | 121 | if score < 0.5 and label < 0.5: 122 | total_correct += 1 123 | 124 | accuracy = float(total_correct) / total_num 125 | precision = float(positive_correct) / (predicted_positive + 1e-12) 126 | recall = float(positive_correct) / true_positive 127 | F1 = 2.0 * precision * recall / (1e-12 + precision + recall) 128 | return accuracy, precision, recall, F1, loss / total_num 129 | 130 | 131 | def top_k_precision(results, k=1): 132 | num_query = 0 133 | top_k_correct = 0.0 134 | for key, v in results.items(): 135 | if not is_valid_query(v): 136 | continue 137 | num_query += 1 138 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True) 139 | if k == 1: 140 | aid, label, score = sorted_v[0] 141 | if label > 0: 142 | top_k_correct += 1 143 | elif k == 2: 144 | aid1, label1, score1 = sorted_v[0] 145 | aid2, label2, score2 = sorted_v[1] 146 | if label1 > 0 or label2 > 0: 147 | top_k_correct += 1 148 | elif k == 5: 149 | for vv in sorted_v[0:5]: 150 | label = vv[1] 151 | if label > 0: 152 | top_k_correct += 1 153 | break 154 | else: 155 | raise ValueError("k must be 1, 2, or 5") 156 | 157 | if num_query > 0: 158 | return top_k_correct / num_query 159 | else: 160 | return 0.0 -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
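# A plain-Python sketch of the learning-rate schedule that create_optimizer
# below builds in-graph (assuming init_lr=2e-5, num_warmup_steps=1000,
# num_train_steps=10000; power=1.0 makes polynomial_decay linear):
#   lr(step) = init_lr * step / 1000          if step < 1000   (warmup)
#   lr(step) = init_lr * (1 - step / 10000)   otherwise        (linear decay)
# Note the two branches meet only approximately: at step 1000 the warmup branch
# reaches init_lr while the decay branch gives 0.9 * init_lr.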
15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 
82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs an AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want to decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD.
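# For a single scalar parameter, one optimizer step below reduces to (a sketch,
# using the constructor defaults beta_1=0.9, beta_2=0.999, epsilon=1e-6):
#   next_m = 0.9 * m + 0.1 * grad
#   next_v = 0.999 * v + 0.001 * grad ** 2
#   update = next_m / (sqrt(next_v) + 1e-6) + weight_decay_rate * param
#   param  = param - learning_rate * update
# Unlike standard Adam, no bias correction is applied to m/v here.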
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /run_finetuning_rs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """MPC-BERT finetuning runner on the downstream task of response selection.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import operator 10 | from time import time 11 | from collections import defaultdict 12 | import tensorflow as tf 13 | import optimization 14 | import tokenization 15 | import modeling_speaker as modeling 16 | import metrics 17 | 18 | flags = tf.flags 19 | FLAGS = flags.FLAGS 20 | 21 | flags.DEFINE_string("train_dir", 'train.tfrecord', 22 | "The input train data dir. Should contain the .tsv files (or other data files) for the task.") 23 | 24 | flags.DEFINE_string("valid_dir", 'valid.tfrecord', 25 | "The input valid data dir. Should contain the .tsv files (or other data files) for the task.") 26 | 27 | flags.DEFINE_string("output_dir", 'output', 28 | "The output directory where the model checkpoints will be written.") 29 | 30 | flags.DEFINE_string("task_name", 'ResponseSelection', 31 | "The name of the task to train.") 32 | 33 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 34 | "The config json file corresponding to the pre-trained BERT model. " 35 | "This specifies the model architecture.") 36 | 37 | flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt', 38 | "The vocabulary file that the BERT model was trained on.") 39 | 40 | flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt', 41 | "Initial checkpoint (usually from a pre-trained BERT model).") 42 | 43 | flags.DEFINE_bool("do_lower_case", True, 44 | "Whether to lower case the input text. Should be True for uncased " 45 | "models and False for cased models.") 46 | 47 | flags.DEFINE_integer("max_seq_length", 320, 48 | "The maximum total input sequence length after WordPiece tokenization. " 49 | "Sequences longer than this will be truncated, and sequences shorter " 50 | "than this will be padded.") 51 | 52 | flags.DEFINE_integer("max_utr_num", 7, 53 | "Maximum utterance number.") 54 | 55 | flags.DEFINE_bool("do_train", True, 56 | "Whether to run training.") 57 | 58 | flags.DEFINE_float("warmup_proportion", 0.1, 59 | "Proportion of training to perform linear learning rate warmup for. 
" 60 | "E.g., 0.1 = 10% of training.") 61 | 62 | flags.DEFINE_integer("train_batch_size", 12, 63 | "Total batch size for training.") 64 | 65 | flags.DEFINE_float("learning_rate", 2e-5, 66 | "The initial learning rate for Adam.") 67 | 68 | flags.DEFINE_integer("num_train_epochs", 5, 69 | "Total number of training epochs to perform.") 70 | 71 | 72 | 73 | def print_configuration_op(FLAGS): 74 | print('My Configurations:') 75 | for name, value in FLAGS.__flags.items(): 76 | value=value.value 77 | if type(value) == float: 78 | print(' %s:\t %f'%(name, value)) 79 | elif type(value) == int: 80 | print(' %s:\t %d'%(name, value)) 81 | elif type(value) == str: 82 | print(' %s:\t %s'%(name, value)) 83 | elif type(value) == bool: 84 | print(' %s:\t %s'%(name, value)) 85 | else: 86 | print('%s:\t %s' % (name, value)) 87 | print('End of configuration') 88 | 89 | 90 | def count_data_size(file_name): 91 | sample_nums = 0 92 | for record in tf.python_io.tf_record_iterator(file_name): 93 | sample_nums += 1 94 | return sample_nums 95 | 96 | 97 | def parse_exmp(serial_exmp): 98 | input_data = tf.parse_single_example(serial_exmp, 99 | features={ 100 | "ctx_id": 101 | tf.FixedLenFeature([], tf.int64), 102 | "rsp_id": 103 | tf.FixedLenFeature([], tf.int64), 104 | "input_sents": 105 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 106 | "input_mask": 107 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 108 | "segment_ids": 109 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 110 | "speaker_ids": 111 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 112 | "label_ids": 113 | tf.FixedLenFeature([], tf.float32), 114 | } 115 | ) 116 | # So cast all int64 to int32. 117 | for name in list(input_data.keys()): 118 | t = input_data[name] 119 | if t.dtype == tf.int64: 120 | t = tf.to_int32(t) 121 | input_data[name] = t 122 | 123 | ctx_id = input_data["ctx_id"] 124 | rsp_id = input_data['rsp_id'] 125 | input_sents = input_data["input_sents"] 126 | input_mask = input_data["input_mask"] 127 | segment_ids= input_data["segment_ids"] 128 | speaker_ids= input_data["speaker_ids"] 129 | labels = input_data['label_ids'] 130 | return ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels 131 | 132 | 133 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, labels, ctx_id, rsp_id, 134 | num_labels, use_one_hot_embeddings): 135 | """Creates a classification model.""" 136 | model = modeling.BertModel( 137 | config=bert_config, 138 | is_training=is_training, 139 | input_ids=input_ids, 140 | input_mask=input_mask, 141 | token_type_ids=segment_ids, 142 | speaker_ids=speaker_ids, 143 | use_one_hot_embeddings=use_one_hot_embeddings) 144 | 145 | target_loss_weight = [1.0, 1.0] 146 | target_loss_weight = tf.convert_to_tensor(target_loss_weight) 147 | 148 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32) 149 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32) 150 | 151 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy 152 | 153 | output_layer = model.get_pooled_output() 154 | hidden_size = output_layer.shape[-1].value 155 | 156 | output_weights = tf.get_variable( 157 | "output_weights", [num_labels, hidden_size], 158 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 159 | output_bias = tf.get_variable( 160 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 161 | 162 | with tf.variable_scope("loss"): 163 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training) 164 | 
165 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 166 | logits = tf.nn.bias_add(logits, output_bias) 167 | 168 | probabilities = tf.sigmoid(logits, name="prob") 169 | logits = tf.squeeze(logits, [1]) 170 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) 171 | losses = tf.multiply(losses, all_target_loss) 172 | 173 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 174 | 175 | with tf.name_scope("accuracy"): 176 | correct_prediction = tf.equal(tf.sign(probabilities - 0.5), tf.sign(labels - 0.5)) 177 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 178 | 179 | return mean_loss, logits, probabilities, accuracy 180 | 181 | 182 | def run_epoch(epoch, op_name, sess, training, logits, accuracy, mean_loss, train_opt): 183 | 184 | step = 0 185 | t0 = time() 186 | 187 | try: 188 | while True: 189 | step += 1 190 | batch_logits, batch_loss, _, batch_accuracy = sess.run([logits, mean_loss, train_opt, accuracy], feed_dict={training:True}) 191 | 192 | if step % 1000 == 0: 193 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" % 194 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy)) 195 | 196 | except tf.errors.OutOfRangeError: 197 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" % 198 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy)) 199 | pass 200 | 201 | 202 | best_score = 0.0 203 | def run_test(epoch, op_name, sess, training, accuracy, prob, pair_ids, saver, dir_path): 204 | 205 | step = 0 206 | t0 = time() 207 | num_test = 0 208 | num_correct = 0.0 209 | mrr = 0 210 | results = defaultdict(list) 211 | 212 | try: 213 | while True: 214 | step += 1 215 | batch_accuracy, predicted_prob, pair_ = sess.run([accuracy, prob, pair_ids], feed_dict={training:False}) 216 | question_id, answer_id, label = pair_ 217 | 218 | num_test += len(predicted_prob) 219 | num_correct += len(predicted_prob) * batch_accuracy 220 | for i, prob_score in enumerate(predicted_prob): 221 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0])) 222 | 223 | if step % 1000 == 0: 224 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f" % (epoch, step, (time() - t0)/60.0 )) 225 | 226 | except tf.errors.OutOfRangeError: 227 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, num_correct / num_test)) 228 | accu, precision, recall, f1, loss = metrics.classification_metrics(results) 229 | print('Accuracy: {}, Precision: {}, Recall: {}, F1: {}, Loss: {}'.format(accu, precision, recall, f1, loss)) 230 | 231 | mean_ap = metrics.mean_average_precision(results) 232 | mrr = metrics.mean_reciprocal_rank(results) 233 | top_1_precision = metrics.top_1_precision(results) 234 | total_valid_query = metrics.get_num_valid_query(results) 235 | print('MAP (mean average precision): {}, MRR (mean reciprocal rank): {}, Top-1 precision: {}, Num_query: {}'.format( 236 | mean_ap, mrr, top_1_precision, total_valid_query)) 237 | 238 | out_path = os.path.join(dir_path, "output_epoch_{}.txt".format(epoch)) 239 | print("Saving evaluation to {}".format(out_path)) 240 | with open(out_path, 'w') as f: 241 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n") 242 | for us_id, v in results.items(): 243 | v.sort(key=operator.itemgetter(2), reverse=True) 244 | for i, rec in enumerate(v): 245 | r_id, label, prob_score = rec 246 | rank = i + 1 247
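# One ranked candidate per line: query id, candidate id, score, rank, gold label.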
f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label)) 248 | 249 | global best_score 250 | if op_name == 'valid' and mrr > best_score: 251 | best_score = mrr 252 | dir_path = os.path.join(dir_path, "epoch_{}".format(epoch)) 253 | saver.save(sess, dir_path) 254 | tf.logging.info(">> Save model!") 255 | 256 | return mrr 257 | 258 | 259 | 260 | def main(_): 261 | tf.logging.set_verbosity(tf.logging.INFO) 262 | print_configuration_op(FLAGS) 263 | 264 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 265 | 266 | root_path = FLAGS.output_dir 267 | if not os.path.exists(root_path): 268 | os.makedirs(root_path) 269 | timestamp = str(int(time())) 270 | root_path = os.path.join(root_path, timestamp) 271 | tf.logging.info('root_path: {}'.format(root_path)) 272 | if not os.path.exists(root_path): 273 | os.makedirs(root_path) 274 | 275 | train_data_size = count_data_size(FLAGS.train_dir) 276 | tf.logging.info('train data size: {}'.format(train_data_size)) 277 | valid_data_size = count_data_size(FLAGS.valid_dir) 278 | tf.logging.info('valid data size: {}'.format(valid_data_size)) 279 | 280 | num_train_steps = train_data_size // FLAGS.train_batch_size * FLAGS.num_train_epochs 281 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 282 | 283 | filenames = tf.placeholder(tf.string, shape=[None]) 284 | shuffle_size = tf.placeholder(tf.int64) 285 | dataset = tf.data.TFRecordDataset(filenames) 286 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 287 | dataset = dataset.repeat(1) 288 | # buffer_size 100 289 | dataset = dataset.shuffle(shuffle_size) 290 | dataset = dataset.batch(FLAGS.train_batch_size) 291 | iterator = dataset.make_initializable_iterator() 292 | ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels = iterator.get_next() 293 | pair_ids = [ctx_id, rsp_id, labels] 294 | 295 | 296 | training = tf.placeholder(tf.bool) 297 | mean_loss, logits, probabilities, accuracy = create_model(bert_config = bert_config, 298 | is_training = training, 299 | input_ids = input_sents, 300 | input_mask = input_mask, 301 | segment_ids = segment_ids, 302 | speaker_ids = speaker_ids, 303 | labels = labels, 304 | ctx_id = ctx_id, 305 | rsp_id = rsp_id, 306 | num_labels = 1, 307 | use_one_hot_embeddings = False) 308 | 309 | # init model with pre-training 310 | tvars = tf.trainable_variables() 311 | if FLAGS.init_checkpoint: 312 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,FLAGS.init_checkpoint) 313 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map) 314 | 315 | tf.logging.info("**** Trainable Variables ****") 316 | for var in tvars: 317 | init_string = "" 318 | if var.name in initialized_variable_names: 319 | init_string = ", *INIT_FROM_CKPT*" 320 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string) 321 | 322 | train_opt = optimization.create_optimizer(mean_loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, False) 323 | 324 | config = tf.ConfigProto(allow_soft_placement=True) 325 | config.gpu_options.allow_growth = True 326 | saver = tf.train.Saver() 327 | 328 | if FLAGS.do_train: 329 | with tf.Session(config=config) as sess: 330 | sess.run(tf.global_variables_initializer()) 331 | 332 | for epoch in range(FLAGS.num_train_epochs): 333 | tf.logging.info('Train begin epoch {}'.format(epoch)) 334 | sess.run(iterator.initializer, 335 | feed_dict={filenames: [FLAGS.train_dir], shuffle_size: 1024}) 336 | 
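# One pass over the shuffled training TFRecord, then a full validation sweep;
# run_test checkpoints the model only when validation MRR improves.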
run_epoch(epoch, "train", sess, training, logits, accuracy, mean_loss, train_opt) 337 | 338 | tf.logging.info('Valid begin') 339 | sess.run(iterator.initializer, 340 | feed_dict={filenames: [FLAGS.valid_dir], shuffle_size: 1}) 341 | run_test(epoch, "valid", sess, training, accuracy, probabilities, pair_ids, saver, root_path) 342 | 343 | 344 | if __name__ == "__main__": 345 | tf.app.run() 346 | -------------------------------------------------------------------------------- /run_finetuning_si.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """MPC-BERT finetuning runner on the downstream task of speaker identification.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import operator 10 | from time import time 11 | from collections import defaultdict 12 | import tensorflow as tf 13 | import optimization 14 | import tokenization 15 | import modeling_speaker as modeling 16 | 17 | flags = tf.flags 18 | FLAGS = flags.FLAGS 19 | 20 | flags.DEFINE_string("train_dir", 'train.tfrecord', 21 | "The input train data dir. Should contain the .tsv files (or other data files) for the task.") 22 | 23 | flags.DEFINE_string("valid_dir", 'valid.tfrecord', 24 | "The input valid data dir. Should contain the .tsv files (or other data files) for the task.") 25 | 26 | flags.DEFINE_string("output_dir", 'output', 27 | "The output directory where the model checkpoints will be written.") 28 | 29 | flags.DEFINE_string("task_name", 'SpeakerIdentification', 30 | "The name of the task to train.") 31 | 32 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt', 37 | "The vocabulary file that the BERT model was trained on.") 38 | 39 | flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt', 40 | "Initial checkpoint (usually from a pre-trained BERT model).") 41 | 42 | flags.DEFINE_bool("do_lower_case", True, 43 | "Whether to lower case the input text. Should be True for uncased " 44 | "models and False for cased models.") 45 | 46 | flags.DEFINE_integer("max_seq_length", 320, 47 | "The maximum total input sequence length after WordPiece tokenization. " 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_integer("max_utr_num", 7, 52 | "Maximum utterance number.") 53 | 54 | flags.DEFINE_bool("do_train", True, 55 | "Whether to run training.") 56 | 57 | flags.DEFINE_float("warmup_proportion", 0.1, 58 | "Proportion of training to perform linear learning rate warmup for. 
" 59 | "E.g., 0.1 = 10% of training.") 60 | 61 | flags.DEFINE_integer("train_batch_size", 12, 62 | "Total batch size for training.") 63 | 64 | flags.DEFINE_float("learning_rate", 2e-5, 65 | "The initial learning rate for Adam.") 66 | 67 | flags.DEFINE_integer("num_train_epochs", 5, 68 | "Total number of training epochs to perform.") 69 | 70 | 71 | 72 | def print_configuration_op(FLAGS): 73 | print('My Configurations:') 74 | for name, value in FLAGS.__flags.items(): 75 | value=value.value 76 | if type(value) == float: 77 | print(' %s:\t %f'%(name, value)) 78 | elif type(value) == int: 79 | print(' %s:\t %d'%(name, value)) 80 | elif type(value) == str: 81 | print(' %s:\t %s'%(name, value)) 82 | elif type(value) == bool: 83 | print(' %s:\t %s'%(name, value)) 84 | else: 85 | print('%s:\t %s' % (name, value)) 86 | print('End of configuration') 87 | 88 | 89 | def count_data_size(file_name): 90 | sample_nums = 0 91 | for record in tf.python_io.tf_record_iterator(file_name): 92 | sample_nums += 1 93 | return sample_nums 94 | 95 | 96 | def parse_exmp(serial_exmp): 97 | input_data = tf.parse_single_example(serial_exmp, 98 | features={ 99 | "input_sents": 100 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 101 | "input_mask": 102 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 103 | "segment_ids": 104 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 105 | "speaker_ids": 106 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 107 | "cls_positions": 108 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 109 | "rsp_position": 110 | tf.FixedLenFeature([1], tf.int64), 111 | "label_ids": 112 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 113 | } 114 | ) 115 | # So cast all int64 to int32. 116 | for name in list(input_data.keys()): 117 | t = input_data[name] 118 | if t.dtype == tf.int64: 119 | t = tf.to_int32(t) 120 | input_data[name] = t 121 | 122 | input_sents = input_data["input_sents"] 123 | input_mask = input_data["input_mask"] 124 | segment_ids= input_data["segment_ids"] 125 | speaker_ids= input_data["speaker_ids"] 126 | cls_positions= input_data["cls_positions"] 127 | rsp_position= input_data["rsp_position"] 128 | labels = input_data['label_ids'] 129 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels 130 | 131 | 132 | def gather_indexes(sequence_tensor, positions): 133 | """Gathers the vectors at the specific positions over a minibatch.""" 134 | # sequence_tensor = [batch_size, seq_length, width] 135 | # positions = [batch_size, max_utr_num] 136 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 137 | batch_size = sequence_shape[0] 138 | seq_length = sequence_shape[1] 139 | width = sequence_shape[2] 140 | 141 | flat_offsets = tf.reshape( 142 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1] 143 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ] 144 | flat_sequence_tensor = tf.reshape(sequence_tensor, 145 | [batch_size * seq_length, width]) 146 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width] 147 | return output_tensor 148 | 149 | 150 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, 151 | num_labels, use_one_hot_embeddings): 152 | """Creates a classification model.""" 153 | model = modeling.BertModel( 154 | config=bert_config, 155 | is_training=is_training, 156 | input_ids=input_ids, 157 | 
input_mask=input_mask, 158 | token_type_ids=segment_ids, 159 | speaker_ids=speaker_ids, 160 | use_one_hot_embeddings=use_one_hot_embeddings) 161 | 162 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim] 163 | 164 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2) 165 | width = input_shape[-1] 166 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2) 167 | max_utr_num = positions_shape[-1] 168 | 169 | with tf.variable_scope("cls/speaker_restore"): 170 | # We apply one more non-linear transformation before the output layer. 171 | with tf.variable_scope("transform"): 172 | input_tensor = tf.layers.dense( 173 | input_tensor, 174 | units=bert_config.hidden_size, 175 | activation=modeling.get_activation(bert_config.hidden_act), 176 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range)) 177 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim] 178 | 179 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim] 180 | 181 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim] 182 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim] 183 | 184 | output_weights = tf.get_variable( 185 | "output_weights", 186 | shape=[width, width], 187 | initializer=modeling.create_initializer(bert_config.initializer_range)) 188 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights), 189 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num] 190 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num] 191 | 192 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 193 | logits = logits * mask + -1e9 * (1-mask) 194 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num] 195 | 196 | # loss 197 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num] 198 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ] 199 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss") 200 | 201 | # accuracy 202 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ] 203 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 204 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ] 205 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 206 | 207 | return mean_loss, logits, log_probs, accuracy 208 | 209 | 210 | def run_epoch(epoch, op_name, sess, training, logits, accuracy, mean_loss, train_opt): 211 | 212 | step = 0 213 | t0 = time() 214 | 215 | try: 216 | while True: 217 | step += 1 218 | batch_logits, batch_loss, _, batch_accuracy = sess.run([logits, mean_loss, train_opt, accuracy], feed_dict={training: True}) 219 | 220 | if step % 1000 == 0: 221 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" % 222 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy)) 223 | 224 | except tf.errors.OutOfRangeError: 225 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" % 226 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy)) 227 | pass 228 | 229 | 230 | best_score = 0.0 231 | def run_test(epoch, op_name, sess, training, prob, accuracy, saver, dir_path): 232 | 233 | 
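# Runs one evaluation pass over the current dataset; on the validation split,
# a checkpoint is saved whenever accuracy beats the best seen so far.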
step = 0 234 | t0 = time() 235 | num_test = 0 236 | num_correct = 0.0 237 | test_accuracy = 0 238 | 239 | try: 240 | while True: 241 | step += 1 242 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False}) 243 | 244 | num_test += len(predicted_prob) 245 | num_correct += len(predicted_prob) * batch_accuracy 246 | 247 | if step % 100 == 0: 248 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f" % (epoch, step, (time() - t0)/60.0 )) 249 | 250 | except tf.errors.OutOfRangeError: 251 | test_accuracy = num_correct / num_test 252 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy)) 253 | 254 | global best_score 255 | if op_name == 'valid' and test_accuracy > best_score: 256 | best_score = test_accuracy 257 | dir_path = os.path.join(dir_path, "epoch_{}".format(epoch)) 258 | saver.save(sess, dir_path) 259 | tf.logging.info(">> Save model!") 260 | 261 | return test_accuracy 262 | 263 | 264 | 265 | def main(_): 266 | tf.logging.set_verbosity(tf.logging.INFO) 267 | print_configuration_op(FLAGS) 268 | 269 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 270 | 271 | root_path = FLAGS.output_dir 272 | if not os.path.exists(root_path): 273 | os.makedirs(root_path) 274 | timestamp = str(int(time())) 275 | root_path = os.path.join(root_path, timestamp) 276 | tf.logging.info('root_path: {}'.format(root_path)) 277 | if not os.path.exists(root_path): 278 | os.makedirs(root_path) 279 | 280 | train_data_size = count_data_size(FLAGS.train_dir) 281 | tf.logging.info('train data size: {}'.format(train_data_size)) 282 | valid_data_size = count_data_size(FLAGS.valid_dir) 283 | tf.logging.info('valid data size: {}'.format(valid_data_size)) 284 | 285 | num_train_steps = train_data_size // FLAGS.train_batch_size * FLAGS.num_train_epochs 286 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 287 | 288 | filenames = tf.placeholder(tf.string, shape=[None]) 289 | shuffle_size = tf.placeholder(tf.int64) 290 | dataset = tf.data.TFRecordDataset(filenames) 291 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 
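# repeat(1): each run of iterator.initializer yields exactly one epoch.
# shuffle_size is fed as 1024 for training and 1 (no shuffling) for validation.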
292 | dataset = dataset.repeat(1) 293 | # buffer_size 100 294 | dataset = dataset.shuffle(shuffle_size) 295 | dataset = dataset.batch(FLAGS.train_batch_size) 296 | iterator = dataset.make_initializable_iterator() 297 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels = iterator.get_next() 298 | 299 | 300 | training = tf.placeholder(tf.bool) 301 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config, 302 | is_training = training, 303 | input_ids = input_sents, 304 | input_mask = input_mask, 305 | segment_ids = segment_ids, 306 | speaker_ids = speaker_ids, 307 | cls_positions = cls_positions, 308 | rsp_position = rsp_position, 309 | labels = labels, 310 | num_labels = 1, 311 | use_one_hot_embeddings = False) 312 | 313 | # init model with pre-training 314 | tvars = tf.trainable_variables() 315 | if FLAGS.init_checkpoint: 316 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, FLAGS.init_checkpoint) 317 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map) 318 | 319 | tf.logging.info("**** Trainable Variables ****") 320 | for var in tvars: 321 | init_string = "" 322 | if var.name in initialized_variable_names: 323 | init_string = ", *INIT_FROM_CKPT*" 324 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string) 325 | 326 | train_opt = optimization.create_optimizer(mean_loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, False) 327 | 328 | config = tf.ConfigProto(allow_soft_placement=True) 329 | config.gpu_options.allow_growth = True 330 | saver = tf.train.Saver() 331 | 332 | if FLAGS.do_train: 333 | with tf.Session(config=config) as sess: 334 | sess.run(tf.global_variables_initializer()) 335 | 336 | for epoch in range(FLAGS.num_train_epochs): 337 | tf.logging.info('Train begin epoch {}'.format(epoch)) 338 | sess.run(iterator.initializer, 339 | feed_dict={filenames: [FLAGS.train_dir], shuffle_size: 1024}) 340 | run_epoch(epoch, "train", sess, training, logits, accuracy, mean_loss, train_opt) 341 | 342 | tf.logging.info('Valid begin') 343 | sess.run(iterator.initializer, 344 | feed_dict={filenames: [FLAGS.valid_dir], shuffle_size: 1}) 345 | run_test(epoch, "valid", sess, training, log_probs, accuracy, saver, root_path) 346 | 347 | 348 | if __name__ == "__main__": 349 | tf.app.run() 350 | -------------------------------------------------------------------------------- /run_testing_ar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """MPC-BERT testing runner on the downstream task of addressee recognition.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import operator 10 | from time import time 11 | from collections import defaultdict 12 | import tensorflow as tf 13 | import optimization 14 | import tokenization 15 | import modeling_speaker as modeling 16 | 17 | flags = tf.flags 18 | FLAGS = flags.FLAGS 19 | 20 | flags.DEFINE_string("task_name", 'Testing', 21 | "The name of the task.") 22 | 23 | flags.DEFINE_string("test_dir", 'test.tfrecord', 24 | "The input test data dir. 
Should contain the .tsv files (or other data files) for the task.") 25 | 26 | flags.DEFINE_string("restore_model_dir", 'output/', 27 | "The output directory where the model checkpoints have been written.") 28 | 29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 30 | "The config json file corresponding to the pre-trained BERT model. " 31 | "This specifies the model architecture.") 32 | 33 | flags.DEFINE_bool("do_eval", True, 34 | "Whether to run eval on the dev set.") 35 | 36 | flags.DEFINE_integer("eval_batch_size", 32, 37 | "Total batch size for predict.") 38 | 39 | flags.DEFINE_integer("max_seq_length", 320, 40 | "The maximum total input sequence length after WordPiece tokenization. " 41 | "Sequences longer than this will be truncated, and sequences shorter " 42 | "than this will be padded.") 43 | 44 | flags.DEFINE_integer("max_utr_num", 7, 45 | "Maximum utterance number.") 46 | 47 | 48 | def print_configuration_op(FLAGS): 49 | print('My Configurations:') 50 | for name, value in FLAGS.__flags.items(): 51 | value=value.value 52 | if type(value) == float: 53 | print(' %s:\t %f'%(name, value)) 54 | elif type(value) == int: 55 | print(' %s:\t %d'%(name, value)) 56 | elif type(value) == str: 57 | print(' %s:\t %s'%(name, value)) 58 | elif type(value) == bool: 59 | print(' %s:\t %s'%(name, value)) 60 | else: 61 | print('%s:\t %s' % (name, value)) 62 | print('End of configuration') 63 | 64 | 65 | def count_data_size(file_name): 66 | sample_nums = 0 67 | for record in tf.python_io.tf_record_iterator(file_name): 68 | sample_nums += 1 69 | return sample_nums 70 | 71 | 72 | def parse_exmp(serial_exmp): 73 | input_data = tf.parse_single_example(serial_exmp, 74 | features={ 75 | "input_sents": 76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 77 | "input_mask": 78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 79 | "segment_ids": 80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 81 | "speaker_ids": 82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 83 | "cls_positions": 84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 85 | "label_ids": 86 | tf.FixedLenFeature([FLAGS.max_utr_num*FLAGS.max_utr_num], tf.int64), 87 | "label_weights": 88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.float32), 89 | } 90 | ) 91 | # So cast all int64 to int32. 
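# (tf.train.Example stores integer features as int64, while the graph ops
# below expect int32.)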
92 | for name in list(input_data.keys()): 93 | t = input_data[name] 94 | if t.dtype == tf.int64: 95 | t = tf.to_int32(t) 96 | input_data[name] = t 97 | 98 | input_sents = input_data["input_sents"] 99 | input_mask = input_data["input_mask"] 100 | segment_ids= input_data["segment_ids"] 101 | speaker_ids= input_data["speaker_ids"] 102 | cls_positions= input_data["cls_positions"] 103 | labels = input_data['label_ids'] 104 | label_weights = input_data['label_weights'] 105 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, labels, label_weights 106 | 107 | 108 | def gather_indexes(sequence_tensor, positions): 109 | """Gathers the vectors at the specific positions over a minibatch.""" 110 | # sequence_tensor = [batch_size, seq_length, width] 111 | # positions = [batch_size, max_utr_num] 112 | 113 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 114 | batch_size = sequence_shape[0] 115 | seq_length = sequence_shape[1] 116 | width = sequence_shape[2] 117 | 118 | flat_offsets = tf.reshape( 119 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1] 120 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ] 121 | flat_sequence_tensor = tf.reshape(sequence_tensor, 122 | [batch_size * seq_length, width]) 123 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width] 124 | return output_tensor 125 | 126 | 127 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, labels, label_weights, 128 | use_one_hot_embeddings): 129 | """Creates a classification model.""" 130 | model = modeling.BertModel( 131 | config=bert_config, 132 | is_training=is_training, 133 | input_ids=input_ids, 134 | input_mask=input_mask, 135 | token_type_ids=segment_ids, 136 | speaker_ids=speaker_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2) 140 | max_utr_num = positions_shape[-1] 141 | 142 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim] 143 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2) 144 | width = input_shape[-1] 145 | 146 | with tf.variable_scope("cls/addressee_recognize"): 147 | # We apply one more non-linear transformation before the output layer. 148 | with tf.variable_scope("transform"): 149 | input_tensor = tf.layers.dense( 150 | input_tensor, 151 | units=bert_config.hidden_size, 152 | activation=modeling.get_activation(bert_config.hidden_act), 153 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range)) 154 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim] 155 | 156 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim] 157 | output_weights = tf.get_variable( 158 | "output_weights", 159 | shape=[width, width], 160 | initializer=modeling.create_initializer(bert_config.initializer_range)) 161 | logits = tf.matmul(tf.einsum('aij,jk->aik', input_tensor, output_weights), 162 | input_tensor, transpose_b=True) # [batch_size, max_utr_num, max_utr_num] 163 | 164 | # mask = [[0. 0. 0. 0. 0.] 165 | # [1. 0. 0. 0. 0.] 166 | # [1. 1. 0. 0. 0.] 167 | # [1. 1. 1. 0. 0.] 168 | # [1. 1. 1. 1. 
0.]] 169 | mask = tf.matrix_band_part(tf.ones((max_utr_num, max_utr_num)), -1, 0) - tf.matrix_band_part(tf.ones((max_utr_num, max_utr_num)), 0, 0) 170 | logits = logits * mask + -1e9 * (1-mask) # [batch_size, max_utr_num, max_utr_num] 171 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num, max_utr_num] 172 | one_hot_labels = tf.reshape(labels, [-1, max_utr_num, max_utr_num]) 173 | one_hot_labels = tf.cast(one_hot_labels, "float") 174 | 175 | # loss 176 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, max_utr_num] 177 | numerator = tf.reduce_sum(label_weights * per_example_loss) # [1, ] 178 | denominator = tf.reduce_sum(label_weights) + 1e-5 # [1, ] 179 | mean_loss = numerator / denominator 180 | 181 | # accuracy 182 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, max_utr_num] 183 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num, max_utr_num] 184 | correct_prediction_utr = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, max_utr_num] 185 | numerator_utr = tf.reduce_sum(label_weights * correct_prediction_utr) # [1, ] 186 | accuracy_utr = numerator_utr / denominator 187 | 188 | correct_prediction_sess = tf.equal(tf.reduce_sum(label_weights * correct_prediction_utr, -1), 189 | tf.reduce_sum(label_weights, -1)) # [batch_size, ] 190 | accuracy_sess = tf.reduce_mean(tf.cast(correct_prediction_sess, "float"), name="accuracy") 191 | 192 | return mean_loss, logits, log_probs, accuracy_utr, accuracy_sess, tf.reduce_sum(label_weights) 193 | 194 | 195 | def run_test(sess, training, prob, accuracy_utr, accuracy_sess, num_utr): 196 | 197 | step = 0 198 | t0 = time() 199 | num_test_utr = 0 200 | num_correct_utr = 0.0 201 | test_accuracy_utr = 0 202 | num_test_sess = 0 203 | num_correct_sess = 0.0 204 | test_accuracy_sess = 0 205 | 206 | try: 207 | while True: 208 | step += 1 209 | batch_accuracy_utr, batch_accuracy_sess, batch_predicted_prob, batch_num_utr = sess.run( 210 | [accuracy_utr, accuracy_sess, prob, num_utr], feed_dict={training: False}) 211 | 212 | num_test_utr += int(batch_num_utr) 213 | num_correct_utr += batch_num_utr * batch_accuracy_utr 214 | 215 | num_test_sess += len(batch_predicted_prob) 216 | num_correct_sess += len(batch_predicted_prob) * batch_accuracy_sess 217 | 218 | if step % 100 == 0: 219 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0)) 220 | 221 | except tf.errors.OutOfRangeError: 222 | test_accuracy_utr = num_correct_utr / num_test_utr 223 | print('num_test_utterance: {}, test_accuracy_utr: {}'.format(num_test_utr, test_accuracy_utr)) 224 | test_accuracy_sess = num_correct_sess / num_test_sess 225 | print('num_test_session: {}, test_accuracy_sess: {}'.format(num_test_sess, test_accuracy_sess)) 226 | 227 | return test_accuracy_sess 228 | 229 | 230 | def main(_): 231 | tf.logging.set_verbosity(tf.logging.INFO) 232 | print_configuration_op(FLAGS) 233 | 234 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 235 | 236 | test_data_size = count_data_size(FLAGS.test_dir) 237 | tf.logging.info('test data size: {}'.format(test_data_size)) 238 | 239 | filenames = tf.placeholder(tf.string, shape=[None]) 240 | shuffle_size = tf.placeholder(tf.int64) 241 | dataset = tf.data.TFRecordDataset(filenames) 242 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 
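# A minimal sketch (hypothetical helper, not part of this script) checking the
# addressee mask built in create_model above: tf.matrix_band_part keeps the
# lower triangle, and subtracting the diagonal leaves a strictly lower-triangular
# matrix, so utterance i can only be scored against earlier utterances j < i.
def _check_addressee_mask_sketch(max_utr_num=5):
    import numpy as np
    ones = tf.ones((max_utr_num, max_utr_num))
    mask = tf.matrix_band_part(ones, -1, 0) - tf.matrix_band_part(ones, 0, 0)
    with tf.Session() as sess:
        m = sess.run(mask)
    assert np.array_equal(m, np.tril(np.ones((max_utr_num, max_utr_num)), k=-1))
    return m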
243 | dataset = dataset.repeat(1)
244 | # dataset = dataset.shuffle(shuffle_size)
245 | dataset = dataset.batch(FLAGS.eval_batch_size)
246 | iterator = dataset.make_initializable_iterator()
247 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, labels, label_weights = iterator.get_next()
248 |
249 | training = tf.placeholder(tf.bool)
250 | mean_loss, logits, log_probs, accuracy_utr, accuracy_sess, num_utr = create_model(bert_config = bert_config,
251 | is_training = training,
252 | input_ids = input_sents,
253 | input_mask = input_mask,
254 | segment_ids = segment_ids,
255 | speaker_ids = speaker_ids,
256 | cls_positions = cls_positions,
257 | labels = labels,
258 | label_weights = label_weights,
259 | use_one_hot_embeddings = False)
260 |
261 | config = tf.ConfigProto(allow_soft_placement=True)
262 | config.gpu_options.allow_growth = True
263 |
264 | if FLAGS.do_eval:
265 | with tf.Session(config=config) as sess:
266 | tf.logging.info("*** Restore model ***")
267 |
268 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
269 | variables = tf.trainable_variables()
270 | saver = tf.train.Saver(variables)
271 | saver.restore(sess, ckpt.model_checkpoint_path)
272 |
273 | tf.logging.info('Test begin')
274 | sess.run(iterator.initializer,
275 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
276 | run_test(sess, training, log_probs, accuracy_utr, accuracy_sess, num_utr)
277 |
278 |
279 | if __name__ == "__main__":
280 | tf.app.run()
281 | -------------------------------------------------------------------------------- /run_testing_ar_gift.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8
2 | """MPC-BERT testing runner on the downstream task of addressee recognition."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker_gift as modeling
16 |
17 | flags = tf.flags
18 | FLAGS = flags.FLAGS
19 |
20 | flags.DEFINE_string("task_name", 'Testing',
21 | "The name of the task.")
22 |
23 | flags.DEFINE_string("test_dir", 'test.tfrecord',
24 | "The input test data dir. Should contain the .tsv files (or other data files) for the task.")
25 |
26 | flags.DEFINE_string("restore_model_dir", 'output/',
27 | "The output directory where the model checkpoints have been written.")
28 |
29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
30 | "The config json file corresponding to the pre-trained BERT model. "
31 | "This specifies the model architecture.")
32 |
33 | flags.DEFINE_bool("do_eval", True,
34 | "Whether to run eval on the dev set.")
35 |
36 | flags.DEFINE_integer("eval_batch_size", 32,
37 | "Total batch size for predict.")
38 |
39 | flags.DEFINE_integer("max_seq_length", 320,
40 | "The maximum total input sequence length after WordPiece tokenization. 
" 41 | "Sequences longer than this will be truncated, and sequences shorter " 42 | "than this will be padded.") 43 | 44 | flags.DEFINE_integer("max_utr_num", 7, 45 | "Maximum utterance number.") 46 | 47 | 48 | def print_configuration_op(FLAGS): 49 | print('My Configurations:') 50 | for name, value in FLAGS.__flags.items(): 51 | value=value.value 52 | if type(value) == float: 53 | print(' %s:\t %f'%(name, value)) 54 | elif type(value) == int: 55 | print(' %s:\t %d'%(name, value)) 56 | elif type(value) == str: 57 | print(' %s:\t %s'%(name, value)) 58 | elif type(value) == bool: 59 | print(' %s:\t %s'%(name, value)) 60 | else: 61 | print('%s:\t %s' % (name, value)) 62 | print('End of configuration') 63 | 64 | 65 | def count_data_size(file_name): 66 | sample_nums = 0 67 | for record in tf.python_io.tf_record_iterator(file_name): 68 | sample_nums += 1 69 | return sample_nums 70 | 71 | 72 | def parse_exmp(serial_exmp): 73 | input_data = tf.parse_single_example(serial_exmp, 74 | features={ 75 | "input_sents": 76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 77 | "input_mask": 78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 79 | "segment_ids": 80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 81 | "speaker_ids": 82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 83 | "cls_positions": 84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 85 | "rsp_position": 86 | tf.FixedLenFeature([1], tf.int64), 87 | "label_ids": 88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 89 | "reply_mask_utr2word_flatten": 90 | tf.FixedLenFeature([FLAGS.max_utr_num * FLAGS.max_seq_length], tf.int64), 91 | "utr_lens": 92 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 93 | } 94 | ) 95 | # So cast all int64 to int32. 96 | for name in list(input_data.keys()): 97 | t = input_data[name] 98 | if t.dtype == tf.int64: 99 | t = tf.to_int32(t) 100 | input_data[name] = t 101 | 102 | input_sents = input_data["input_sents"] 103 | input_mask = input_data["input_mask"] 104 | segment_ids= input_data["segment_ids"] 105 | speaker_ids= input_data["speaker_ids"] 106 | cls_positions= input_data["cls_positions"] 107 | rsp_position= input_data["rsp_position"] 108 | reply_mask_utr2word = tf.reshape(input_data['reply_mask_utr2word_flatten'], [FLAGS.max_utr_num, FLAGS.max_seq_length]) 109 | utr_lens = input_data['utr_lens'] 110 | labels = input_data['label_ids'] 111 | 112 | reply_mask_word2word = [] 113 | for i in range(FLAGS.max_utr_num): 114 | reply_mask_utr2word_i = reply_mask_utr2word[i] # [max_seq_length, ] 115 | utr_len_i = utr_lens[i] # [1, ] 116 | reply_mask_utr2word_i_tiled = tf.tile(tf.expand_dims(reply_mask_utr2word_i, 0), [utr_len_i, 1]) # [utr_len, max_seq_length] 117 | reply_mask_word2word.append(reply_mask_utr2word_i_tiled) 118 | 119 | reply_mask_pad = tf.zeros(shape=[FLAGS.max_seq_length - tf.reduce_sum(utr_lens), FLAGS.max_seq_length], dtype=tf.int32) 120 | reply_mask_word2word.append(reply_mask_pad) 121 | reply_mask = tf.concat(reply_mask_word2word, 0) # [max_seq_length, max_seq_length] 122 | print("* Loading reply mask ...") 123 | 124 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask 125 | 126 | 127 | def gather_indexes(sequence_tensor, positions): 128 | """Gathers the vectors at the specific positions over a minibatch.""" 129 | # sequence_tensor = [batch_size, seq_length, width] 130 | # positions = [batch_size, max_utr_num] 131 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 132 | batch_size 
= sequence_shape[0] 133 | seq_length = sequence_shape[1] 134 | width = sequence_shape[2] 135 | 136 | flat_offsets = tf.reshape( 137 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1] 138 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ] 139 | flat_sequence_tensor = tf.reshape(sequence_tensor, 140 | [batch_size * seq_length, width]) 141 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width] 142 | return output_tensor 143 | 144 | 145 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask, 146 | num_labels, use_one_hot_embeddings): 147 | """Creates a classification model.""" 148 | model = modeling.BertModel( 149 | config=bert_config, 150 | is_training=is_training, 151 | input_ids=input_ids, 152 | input_mask=input_mask, 153 | token_type_ids=segment_ids, 154 | speaker_ids=speaker_ids, 155 | reply_mask=reply_mask, 156 | use_one_hot_embeddings=use_one_hot_embeddings) 157 | 158 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim] 159 | 160 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2) 161 | width = input_shape[-1] 162 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2) 163 | max_utr_num = positions_shape[-1] 164 | 165 | with tf.variable_scope("cls/addressee_recognize"): 166 | # We apply one more non-linear transformation before the output layer. 167 | with tf.variable_scope("transform"): 168 | input_tensor = tf.layers.dense( 169 | input_tensor, 170 | units=bert_config.hidden_size, 171 | activation=modeling.get_activation(bert_config.hidden_act), 172 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range)) 173 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim] 174 | 175 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim] 176 | 177 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim] 178 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim] 179 | 180 | output_weights = tf.get_variable( 181 | "output_weights", 182 | shape=[width, width], 183 | initializer=modeling.create_initializer(bert_config.initializer_range)) 184 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights), 185 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num] 186 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num] 187 | 188 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 189 | logits = logits * mask + -1e9 * (1-mask) 190 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num] 191 | 192 | # loss 193 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num] 194 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ] 195 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss") 196 | 197 | # accuracy 198 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ] 199 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 200 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ] 201 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 202 
| 203 | return mean_loss, logits, log_probs, accuracy 204 | 205 | 206 | def run_test(sess, training, prob, accuracy): 207 | 208 | step = 0 209 | t0 = time() 210 | num_test = 0 211 | num_correct = 0.0 212 | test_accuracy = 0 213 | 214 | try: 215 | while True: 216 | step += 1 217 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False}) 218 | 219 | num_test += len(predicted_prob) 220 | num_correct += len(predicted_prob) * batch_accuracy 221 | 222 | if step % 100 == 0: 223 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0)) 224 | 225 | except tf.errors.OutOfRangeError: 226 | test_accuracy = num_correct / num_test 227 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy)) 228 | 229 | return test_accuracy 230 | 231 | 232 | def main(_): 233 | tf.logging.set_verbosity(tf.logging.INFO) 234 | print_configuration_op(FLAGS) 235 | 236 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 237 | 238 | test_data_size = count_data_size(FLAGS.test_dir) 239 | tf.logging.info('test data size: {}'.format(test_data_size)) 240 | 241 | filenames = tf.placeholder(tf.string, shape=[None]) 242 | shuffle_size = tf.placeholder(tf.int64) 243 | dataset = tf.data.TFRecordDataset(filenames) 244 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 245 | dataset = dataset.repeat(1) 246 | # dataset = dataset.shuffle(shuffle_size) 247 | dataset = dataset.batch(FLAGS.eval_batch_size) 248 | iterator = dataset.make_initializable_iterator() 249 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask = iterator.get_next() 250 | 251 | training = tf.placeholder(tf.bool) 252 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config, 253 | is_training = training, 254 | input_ids = input_sents, 255 | input_mask = input_mask, 256 | segment_ids = segment_ids, 257 | speaker_ids = speaker_ids, 258 | cls_positions = cls_positions, 259 | rsp_position = rsp_position, 260 | reply_mask = reply_mask, 261 | labels = labels, 262 | num_labels = 1, 263 | use_one_hot_embeddings = False) 264 | 265 | config = tf.ConfigProto(allow_soft_placement=True) 266 | config.gpu_options.allow_growth = True 267 | 268 | if FLAGS.do_eval: 269 | with tf.Session(config=config) as sess: 270 | tf.logging.info("*** Restore model ***") 271 | 272 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir) 273 | variables = tf.trainable_variables() 274 | saver = tf.train.Saver(variables) 275 | saver.restore(sess, ckpt.model_checkpoint_path) 276 | 277 | tf.logging.info('Test begin') 278 | sess.run(iterator.initializer, 279 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1}) 280 | run_test(sess, training, log_probs, accuracy) 281 | 282 | 283 | if __name__ == "__main__": 284 | tf.app.run() 285 | -------------------------------------------------------------------------------- /run_testing_rs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """MPC-BERT testing runner on the downstream task of response selection.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import operator 10 | from time import time 11 | from collections import defaultdict 12 | import tensorflow as tf 13 | import optimization 14 | import tokenization 15 | import modeling_speaker as modeling 16 | import metrics 17 | 18 | flags = tf.flags 19 | FLAGS = 
flags.FLAGS 20 | 21 | flags.DEFINE_string("task_name", 'Testing', 22 | "The name of the task.") 23 | 24 | flags.DEFINE_string("test_dir", 'test.tfrecord', 25 | "The input test data dir. Should contain the .tsv files (or other data files) for the task.") 26 | 27 | flags.DEFINE_string("restore_model_dir", 'output/', 28 | "The output directory where the model checkpoints have been written.") 29 | 30 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 31 | "The config json file corresponding to the pre-trained BERT model. " 32 | "This specifies the model architecture.") 33 | 34 | flags.DEFINE_bool("do_eval", True, 35 | "Whether to run eval on the dev set.") 36 | 37 | flags.DEFINE_integer("eval_batch_size", 32, 38 | "Total batch size for predict.") 39 | 40 | flags.DEFINE_integer("max_seq_length", 320, 41 | "The maximum total input sequence length after WordPiece tokenization. " 42 | "Sequences longer than this will be truncated, and sequences shorter " 43 | "than this will be padded.") 44 | 45 | 46 | def print_configuration_op(FLAGS): 47 | print('My Configurations:') 48 | for name, value in FLAGS.__flags.items(): 49 | value=value.value 50 | if type(value) == float: 51 | print(' %s:\t %f'%(name, value)) 52 | elif type(value) == int: 53 | print(' %s:\t %d'%(name, value)) 54 | elif type(value) == str: 55 | print(' %s:\t %s'%(name, value)) 56 | elif type(value) == bool: 57 | print(' %s:\t %s'%(name, value)) 58 | else: 59 | print('%s:\t %s' % (name, value)) 60 | print('End of configuration') 61 | 62 | 63 | def count_data_size(file_name): 64 | sample_nums = 0 65 | for record in tf.python_io.tf_record_iterator(file_name): 66 | sample_nums += 1 67 | return sample_nums 68 | 69 | 70 | def parse_exmp(serial_exmp): 71 | input_data = tf.parse_single_example(serial_exmp, 72 | features={ 73 | "ctx_id": 74 | tf.FixedLenFeature([], tf.int64), 75 | "rsp_id": 76 | tf.FixedLenFeature([], tf.int64), 77 | "input_sents": 78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 79 | "input_mask": 80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 81 | "segment_ids": 82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 83 | "speaker_ids": 84 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 85 | "label_ids": 86 | tf.FixedLenFeature([], tf.float32), 87 | } 88 | ) 89 | # So cast all int64 to int32. 
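# A minimal sketch (hypothetical helper, not part of this script) of the
# pointwise scoring head that create_model builds below: the pooled [CLS]
# vector is projected to a single logit and squashed with a sigmoid, and
# run_test later ranks candidate responses by that probability. hidden_size
# is assumed to be 768 here.
def _rs_scoring_head_sketch(pooled_output, hidden_size=768):
    output_weights = tf.get_variable(
        "rs_output_weights", [1, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "rs_output_bias", [1], initializer=tf.zeros_initializer())
    logits = tf.nn.bias_add(
        tf.matmul(pooled_output, output_weights, transpose_b=True), output_bias)
    return tf.sigmoid(logits)  # [batch_size, 1] matching probabilities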
90 | for name in list(input_data.keys()): 91 | t = input_data[name] 92 | if t.dtype == tf.int64: 93 | t = tf.to_int32(t) 94 | input_data[name] = t 95 | 96 | ctx_id = input_data["ctx_id"] 97 | rsp_id = input_data['rsp_id'] 98 | input_sents = input_data["input_sents"] 99 | input_mask = input_data["input_mask"] 100 | segment_ids= input_data["segment_ids"] 101 | speaker_ids= input_data["speaker_ids"] 102 | labels = input_data['label_ids'] 103 | return ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels 104 | 105 | 106 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, labels, ctx_id, rsp_id, 107 | num_labels, use_one_hot_embeddings): 108 | """Creates a classification model.""" 109 | model = modeling.BertModel( 110 | config=bert_config, 111 | is_training=is_training, 112 | input_ids=input_ids, 113 | input_mask=input_mask, 114 | token_type_ids=segment_ids, 115 | speaker_ids=speaker_ids, 116 | use_one_hot_embeddings=use_one_hot_embeddings) 117 | 118 | target_loss_weight = [1.0, 1.0] 119 | target_loss_weight = tf.convert_to_tensor(target_loss_weight) 120 | 121 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32) 122 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32) 123 | 124 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy 125 | 126 | output_layer = model.get_pooled_output() 127 | 128 | hidden_size = output_layer.shape[-1].value 129 | 130 | output_weights = tf.get_variable( 131 | "output_weights", [num_labels, hidden_size], 132 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 133 | 134 | output_bias = tf.get_variable( 135 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 136 | 137 | with tf.variable_scope("loss"): 138 | # if is_training: 139 | # output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 140 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training) 141 | 142 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 143 | logits = tf.nn.bias_add(logits, output_bias) 144 | 145 | probabilities = tf.sigmoid(logits, name="prob") 146 | logits = tf.squeeze(logits,[1]) 147 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) 148 | losses = tf.multiply(losses, all_target_loss) 149 | 150 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 151 | 152 | with tf.name_scope("accuracy"): 153 | correct_prediction = tf.equal(tf.sign(probabilities - 0.5), tf.sign(labels - 0.5)) 154 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 155 | 156 | return mean_loss, logits, probabilities, accuracy 157 | 158 | 159 | best_score = 0.0 160 | def run_test(dir_path, op_name, sess, training, accuracy, prob, pair_ids): 161 | 162 | step = 0 163 | t0 = time() 164 | num_test = 0 165 | num_correct = 0.0 166 | mrr = 0 167 | results = defaultdict(list) 168 | 169 | try: 170 | while True: 171 | step += 1 172 | batch_accuracy, predicted_prob, batch_pair_ids = sess.run([accuracy, prob, pair_ids], feed_dict={training: False}) 173 | question_id, answer_id, label = batch_pair_ids 174 | 175 | num_test += len(predicted_prob) 176 | num_correct += len(predicted_prob) * batch_accuracy 177 | for i, prob_score in enumerate(predicted_prob): 178 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0])) 179 | 180 | if step % 100 == 0: 181 | tf.logging.info("n_update %d , %s: Mins Used: %.2f" % 182 | (step, op_name, (time() - t0) 
/ 60.0)) 183 |
184 | except tf.errors.OutOfRangeError:
185 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test))
186 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
187 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
188 |
189 | mvp = metrics.mean_average_precision(results)
190 | mrr = metrics.mean_reciprocal_rank(results)
191 | top_1_precision = metrics.top_1_precision(results)
192 | total_valid_query = metrics.get_num_valid_query(results)
193 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(
194 | mvp, mrr, top_1_precision, total_valid_query))
195 |
196 | out_path = os.path.join(dir_path, "output_test.txt")
197 | print("Saving evaluation to {}".format(out_path))
198 | with open(out_path, 'w') as f: 
222 | dataset = dataset.repeat(1) 223 | # dataset = dataset.shuffle(shuffle_size) 224 | dataset = dataset.batch(FLAGS.eval_batch_size) 225 | iterator = dataset.make_initializable_iterator() 226 | ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels = iterator.get_next() 227 | pair_ids = [ctx_id, rsp_id, labels] 228 | 229 | training = tf.placeholder(tf.bool) 230 | mean_loss, logits, probabilities, accuracy = create_model(bert_config = bert_config, 231 | is_training = training, 232 | input_ids = input_sents, 233 | input_mask = input_mask, 234 | segment_ids = segment_ids, 235 | speaker_ids = speaker_ids, 236 | labels = labels, 237 | ctx_id = ctx_id, 238 | rsp_id = rsp_id, 239 | num_labels = 1, 240 | use_one_hot_embeddings = False) 241 | 242 | 243 | config = tf.ConfigProto(allow_soft_placement=True) 244 | config.gpu_options.allow_growth = True 245 | 246 | if FLAGS.do_eval: 247 | with tf.Session(config=config) as sess: 248 | tf.logging.info("*** Restore model ***") 249 | 250 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir) 251 | variables = tf.trainable_variables() 252 | saver = tf.train.Saver(variables) 253 | saver.restore(sess, ckpt.model_checkpoint_path) 254 | 255 | tf.logging.info('Test begin') 256 | sess.run(iterator.initializer, 257 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1}) 258 | run_test(FLAGS.restore_model_dir, "test", sess, training, accuracy, probabilities, pair_ids) 259 | 260 | 261 | if __name__ == "__main__": 262 | tf.app.run() 263 | 264 | -------------------------------------------------------------------------------- /run_testing_rs_gift.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """MPC-BERT testing runner on the downstream task of response selection.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import operator 10 | from time import time 11 | from collections import defaultdict 12 | import tensorflow as tf 13 | import optimization 14 | import tokenization 15 | import modeling_speaker_gift as modeling 16 | import metrics 17 | 18 | flags = tf.flags 19 | FLAGS = flags.FLAGS 20 | 21 | flags.DEFINE_string("task_name", 'Testing', 22 | "The name of the task.") 23 | 24 | flags.DEFINE_string("test_dir", 'test.tfrecord', 25 | "The input test data dir. Should contain the .tsv files (or other data files) for the task.") 26 | 27 | flags.DEFINE_string("restore_model_dir", 'output/', 28 | "The output directory where the model checkpoints have been written.") 29 | 30 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 31 | "The config json file corresponding to the pre-trained BERT model. " 32 | "This specifies the model architecture.") 33 | 34 | flags.DEFINE_bool("do_eval", True, 35 | "Whether to run eval on the dev set.") 36 | 37 | flags.DEFINE_integer("eval_batch_size", 32, 38 | "Total batch size for predict.") 39 | 40 | flags.DEFINE_integer("max_seq_length", 320, 41 | "The maximum total input sequence length after WordPiece tokenization. 
" 42 | "Sequences longer than this will be truncated, and sequences shorter " 43 | "than this will be padded.") 44 | 45 | flags.DEFINE_integer("max_utr_num", 7, 46 | "Maximum utterance number.") 47 | 48 | 49 | def print_configuration_op(FLAGS): 50 | print('My Configurations:') 51 | for name, value in FLAGS.__flags.items(): 52 | value=value.value 53 | if type(value) == float: 54 | print(' %s:\t %f'%(name, value)) 55 | elif type(value) == int: 56 | print(' %s:\t %d'%(name, value)) 57 | elif type(value) == str: 58 | print(' %s:\t %s'%(name, value)) 59 | elif type(value) == bool: 60 | print(' %s:\t %s'%(name, value)) 61 | else: 62 | print('%s:\t %s' % (name, value)) 63 | print('End of configuration') 64 | 65 | 66 | def count_data_size(file_name): 67 | sample_nums = 0 68 | for record in tf.python_io.tf_record_iterator(file_name): 69 | sample_nums += 1 70 | return sample_nums 71 | 72 | 73 | def parse_exmp(serial_exmp): 74 | input_data = tf.parse_single_example(serial_exmp, 75 | features={ 76 | "ctx_id": 77 | tf.FixedLenFeature([], tf.int64), 78 | "rsp_id": 79 | tf.FixedLenFeature([], tf.int64), 80 | "input_sents": 81 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 82 | "input_mask": 83 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 84 | "segment_ids": 85 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 86 | "speaker_ids": 87 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 88 | "reply_mask_utr2word_flatten": 89 | tf.FixedLenFeature([FLAGS.max_utr_num * FLAGS.max_seq_length], tf.int64), 90 | "utr_lens": 91 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 92 | "label_ids": 93 | tf.FixedLenFeature([], tf.float32), 94 | } 95 | ) 96 | # So cast all int64 to int32. 97 | for name in list(input_data.keys()): 98 | t = input_data[name] 99 | if t.dtype == tf.int64: 100 | t = tf.to_int32(t) 101 | input_data[name] = t 102 | 103 | ctx_id = input_data["ctx_id"] 104 | rsp_id = input_data['rsp_id'] 105 | input_sents = input_data["input_sents"] 106 | input_mask = input_data["input_mask"] 107 | segment_ids= input_data["segment_ids"] 108 | speaker_ids= input_data["speaker_ids"] 109 | reply_mask_utr2word = tf.reshape(input_data['reply_mask_utr2word_flatten'], [FLAGS.max_utr_num, FLAGS.max_seq_length]) 110 | utr_lens = input_data['utr_lens'] 111 | labels = input_data['label_ids'] 112 | 113 | reply_mask_word2word = [] 114 | for i in range(FLAGS.max_utr_num): 115 | reply_mask_utr2word_i = reply_mask_utr2word[i] # [max_seq_length, ] 116 | utr_len_i = utr_lens[i] # [1, ] 117 | reply_mask_utr2word_i_tiled = tf.tile(tf.expand_dims(reply_mask_utr2word_i, 0), [utr_len_i, 1]) # [utr_len, max_seq_length] 118 | reply_mask_word2word.append(reply_mask_utr2word_i_tiled) 119 | 120 | reply_mask_pad = tf.zeros(shape=[FLAGS.max_seq_length - tf.reduce_sum(utr_lens), FLAGS.max_seq_length], dtype=tf.int32) 121 | reply_mask_word2word.append(reply_mask_pad) 122 | reply_mask = tf.concat(reply_mask_word2word, 0) # [max_seq_length, max_seq_length] 123 | print("* Loading reply mask ...") 124 | 125 | return ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, reply_mask, labels 126 | 127 | 128 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, reply_mask, labels, ctx_id, rsp_id, 129 | num_labels, use_one_hot_embeddings): 130 | """Creates a classification model.""" 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | 
speaker_ids=speaker_ids, 138 | reply_mask=reply_mask, 139 | use_one_hot_embeddings=use_one_hot_embeddings) 140 | 141 | target_loss_weight = [1.0, 1.0] 142 | target_loss_weight = tf.convert_to_tensor(target_loss_weight) 143 | 144 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32) 145 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32) 146 | 147 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy 148 | 149 | output_layer = model.get_pooled_output() 150 | 151 | hidden_size = output_layer.shape[-1].value 152 | 153 | output_weights = tf.get_variable( 154 | "output_weights", [num_labels, hidden_size], 155 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 156 | 157 | output_bias = tf.get_variable( 158 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 159 | 160 | with tf.variable_scope("loss"): 161 | # if is_training: 162 | # output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 163 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training) 164 | 165 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 166 | logits = tf.nn.bias_add(logits, output_bias) 167 | 168 | probabilities = tf.sigmoid(logits, name="prob") 169 | logits = tf.squeeze(logits,[1]) 170 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels) 171 | losses = tf.multiply(losses, all_target_loss) 172 | 173 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 174 | 175 | with tf.name_scope("accuracy"): 176 | correct_prediction = tf.equal(tf.sign(probabilities - 0.5), tf.sign(labels - 0.5)) 177 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 178 | 179 | return mean_loss, logits, probabilities, accuracy 180 | 181 | 182 | best_score = 0.0 183 | def run_test(dir_path, op_name, sess, training, accuracy, prob, pair_ids): 184 | 185 | step = 0 186 | t0 = time() 187 | num_test = 0 188 | num_correct = 0.0 189 | mrr = 0 190 | results = defaultdict(list) 191 | 192 | try: 193 | while True: 194 | step += 1 195 | batch_accuracy, predicted_prob, batch_pair_ids = sess.run([accuracy, prob, pair_ids], feed_dict={training: False}) 196 | question_id, answer_id, label = batch_pair_ids 197 | 198 | num_test += len(predicted_prob) 199 | num_correct += len(predicted_prob) * batch_accuracy 200 | for i, prob_score in enumerate(predicted_prob): 201 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0])) 202 | 203 | if step % 100 == 0: 204 | tf.logging.info("n_update %d , %s: Mins Used: %.2f" % 205 | (step, op_name, (time() - t0) / 60.0)) 206 | 207 | except tf.errors.OutOfRangeError: 208 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test)) 209 | accu, precision, recall, f1, loss = metrics.classification_metrics(results) 210 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss)) 211 | 212 | mvp = metrics.mean_average_precision(results) 213 | mrr = metrics.mean_reciprocal_rank(results) 214 | top_1_precision = metrics.top_1_precision(results) 215 | total_valid_query = metrics.get_num_valid_query(results) 216 | print('MAP (mean average precision: {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format( 217 | mvp, mrr, top_1_precision, total_valid_query)) 218 | 219 | out_path = os.path.join(dir_path, "output_test.txt") 220 | print("Saving evaluation to {}".format(out_path)) 221 | with open(out_path, 'w') as f: 
222 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n") 223 | for us_id, v in results.items(): 224 | v.sort(key=operator.itemgetter(2), reverse=True) 225 | for i, rec in enumerate(v): 226 | r_id, label, prob_score = rec 227 | rank = i+1 228 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label)) 229 | return mrr 230 | 231 | 232 | def main(_): 233 | tf.logging.set_verbosity(tf.logging.INFO) 234 | print_configuration_op(FLAGS) 235 | 236 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 237 | 238 | test_data_size = count_data_size(FLAGS.test_dir) 239 | tf.logging.info('test data size: {}'.format(test_data_size)) 240 | 241 | filenames = tf.placeholder(tf.string, shape=[None]) 242 | shuffle_size = tf.placeholder(tf.int64) 243 | dataset = tf.data.TFRecordDataset(filenames) 244 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 245 | dataset = dataset.repeat(1) 246 | # dataset = dataset.shuffle(shuffle_size) 247 | dataset = dataset.batch(FLAGS.eval_batch_size) 248 | iterator = dataset.make_initializable_iterator() 249 | ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, reply_mask, labels = iterator.get_next() 250 | pair_ids = [ctx_id, rsp_id, labels] 251 | 252 | training = tf.placeholder(tf.bool) 253 | mean_loss, logits, probabilities, accuracy = create_model(bert_config = bert_config, 254 | is_training = training, 255 | input_ids = input_sents, 256 | input_mask = input_mask, 257 | segment_ids = segment_ids, 258 | speaker_ids = speaker_ids, 259 | reply_mask = reply_mask, 260 | labels = labels, 261 | ctx_id = ctx_id, 262 | rsp_id = rsp_id, 263 | num_labels = 1, 264 | use_one_hot_embeddings = False) 265 | 266 | 267 | config = tf.ConfigProto(allow_soft_placement=True) 268 | config.gpu_options.allow_growth = True 269 | 270 | if FLAGS.do_eval: 271 | with tf.Session(config=config) as sess: 272 | tf.logging.info("*** Restore model ***") 273 | 274 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir) 275 | variables = tf.trainable_variables() 276 | saver = tf.train.Saver(variables) 277 | saver.restore(sess, ckpt.model_checkpoint_path) 278 | 279 | tf.logging.info('Test begin') 280 | sess.run(iterator.initializer, 281 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1}) 282 | run_test(FLAGS.restore_model_dir, "test", sess, training, accuracy, probabilities, pair_ids) 283 | 284 | 285 | if __name__ == "__main__": 286 | tf.app.run() 287 | 288 | -------------------------------------------------------------------------------- /run_testing_si.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """MPC-BERT testing runner on the downstream task of speaker identification.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import operator 10 | from time import time 11 | from collections import defaultdict 12 | import tensorflow as tf 13 | import optimization 14 | import tokenization 15 | import modeling_speaker as modeling 16 | 17 | flags = tf.flags 18 | FLAGS = flags.FLAGS 19 | 20 | flags.DEFINE_string("task_name", 'Testing', 21 | "The name of the task.") 22 | 23 | flags.DEFINE_string("test_dir", 'test.tfrecord', 24 | "The input test data dir. 
Should contain the .tsv files (or other data files) for the task.") 25 | 26 | flags.DEFINE_string("restore_model_dir", 'output/', 27 | "The output directory where the model checkpoints have been written.") 28 | 29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 30 | "The config json file corresponding to the pre-trained BERT model. " 31 | "This specifies the model architecture.") 32 | 33 | flags.DEFINE_bool("do_eval", True, 34 | "Whether to run eval on the dev set.") 35 | 36 | flags.DEFINE_integer("eval_batch_size", 32, 37 | "Total batch size for predict.") 38 | 39 | flags.DEFINE_integer("max_seq_length", 320, 40 | "The maximum total input sequence length after WordPiece tokenization. " 41 | "Sequences longer than this will be truncated, and sequences shorter " 42 | "than this will be padded.") 43 | 44 | flags.DEFINE_integer("max_utr_num", 7, 45 | "Maximum utterance number.") 46 | 47 | 48 | def print_configuration_op(FLAGS): 49 | print('My Configurations:') 50 | for name, value in FLAGS.__flags.items(): 51 | value=value.value 52 | if type(value) == float: 53 | print(' %s:\t %f'%(name, value)) 54 | elif type(value) == int: 55 | print(' %s:\t %d'%(name, value)) 56 | elif type(value) == str: 57 | print(' %s:\t %s'%(name, value)) 58 | elif type(value) == bool: 59 | print(' %s:\t %s'%(name, value)) 60 | else: 61 | print('%s:\t %s' % (name, value)) 62 | print('End of configuration') 63 | 64 | 65 | def count_data_size(file_name): 66 | sample_nums = 0 67 | for record in tf.python_io.tf_record_iterator(file_name): 68 | sample_nums += 1 69 | return sample_nums 70 | 71 | 72 | def parse_exmp(serial_exmp): 73 | input_data = tf.parse_single_example(serial_exmp, 74 | features={ 75 | "input_sents": 76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 77 | "input_mask": 78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 79 | "segment_ids": 80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 81 | "speaker_ids": 82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 83 | "cls_positions": 84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 85 | "rsp_position": 86 | tf.FixedLenFeature([1], tf.int64), 87 | "label_ids": 88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 89 | } 90 | ) 91 | # So cast all int64 to int32. 
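# A minimal sketch (hypothetical helper, not part of this script): the parser
# yields `rsp_position`, the index of the response utterance, and create_model
# below uses tf.sequence_mask(rsp_position, max_utr_num) so that only the
# utterances preceding the response remain candidate speakers. For example:
def _si_candidate_mask_sketch():
    rsp_position = tf.constant([3, 5])  # responses are utterances 3 and 5
    mask = tf.sequence_mask(rsp_position, 7, dtype=tf.float32)
    with tf.Session() as sess:
        print(sess.run(mask))
    # -> [[1. 1. 1. 0. 0. 0. 0.]
    #     [1. 1. 1. 1. 1. 0. 0.]]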
92 | for name in list(input_data.keys()): 93 | t = input_data[name] 94 | if t.dtype == tf.int64: 95 | t = tf.to_int32(t) 96 | input_data[name] = t 97 | 98 | input_sents = input_data["input_sents"] 99 | input_mask = input_data["input_mask"] 100 | segment_ids= input_data["segment_ids"] 101 | speaker_ids= input_data["speaker_ids"] 102 | cls_positions= input_data["cls_positions"] 103 | rsp_position= input_data["rsp_position"] 104 | labels = input_data['label_ids'] 105 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels 106 | 107 | 108 | def gather_indexes(sequence_tensor, positions): 109 | """Gathers the vectors at the specific positions over a minibatch.""" 110 | # sequence_tensor = [batch_size, seq_length, width] 111 | # positions = [batch_size, max_utr_num] 112 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 113 | batch_size = sequence_shape[0] 114 | seq_length = sequence_shape[1] 115 | width = sequence_shape[2] 116 | 117 | flat_offsets = tf.reshape( 118 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1] 119 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ] 120 | flat_sequence_tensor = tf.reshape(sequence_tensor, 121 | [batch_size * seq_length, width]) 122 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width] 123 | return output_tensor 124 | 125 | 126 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, 127 | num_labels, use_one_hot_embeddings): 128 | """Creates a classification model.""" 129 | model = modeling.BertModel( 130 | config=bert_config, 131 | is_training=is_training, 132 | input_ids=input_ids, 133 | input_mask=input_mask, 134 | token_type_ids=segment_ids, 135 | speaker_ids=speaker_ids, 136 | use_one_hot_embeddings=use_one_hot_embeddings) 137 | 138 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim] 139 | 140 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2) 141 | width = input_shape[-1] 142 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2) 143 | max_utr_num = positions_shape[-1] 144 | 145 | with tf.variable_scope("cls/speaker_restore"): 146 | # We apply one more non-linear transformation before the output layer. 
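# A minimal sketch (hypothetical helper, not part of this script) of the
# bilinear matching that follows: after the transform, the response vector r is
# compared with every utterance u_j via score_j = (r W) . u_j, implemented with
# an einsum projection plus a batched matmul. Assumed shapes: r [batch, 1, dim],
# utrs [batch, max_utr_num, dim], W [dim, dim].
def _bilinear_match_sketch(r, utrs, W):
    projected = tf.einsum('aij,jk->aik', r, W)             # [batch, 1, dim]
    scores = tf.matmul(projected, utrs, transpose_b=True)  # [batch, 1, max_utr_num]
    return tf.squeeze(scores, [1])                         # [batch, max_utr_num]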
147 | with tf.variable_scope("transform"): 148 | input_tensor = tf.layers.dense( 149 | input_tensor, 150 | units=bert_config.hidden_size, 151 | activation=modeling.get_activation(bert_config.hidden_act), 152 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range)) 153 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim] 154 | 155 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim] 156 | 157 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim] 158 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim] 159 | 160 | output_weights = tf.get_variable( 161 | "output_weights", 162 | shape=[width, width], 163 | initializer=modeling.create_initializer(bert_config.initializer_range)) 164 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights), 165 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num] 166 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num] 167 | 168 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 169 | logits = logits * mask + -1e9 * (1-mask) 170 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num] 171 | 172 | # loss 173 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num] 174 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ] 175 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss") 176 | 177 | # accuracy 178 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ] 179 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 180 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ] 181 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 182 | 183 | return mean_loss, logits, log_probs, accuracy 184 | 185 | 186 | def run_test(sess, training, prob, accuracy): 187 | 188 | step = 0 189 | t0 = time() 190 | num_test = 0 191 | num_correct = 0.0 192 | test_accuracy = 0 193 | 194 | try: 195 | while True: 196 | step += 1 197 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False}) 198 | 199 | num_test += len(predicted_prob) 200 | num_correct += len(predicted_prob) * batch_accuracy 201 | 202 | if step % 100 == 0: 203 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0)) 204 | 205 | except tf.errors.OutOfRangeError: 206 | test_accuracy = num_correct / num_test 207 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy)) 208 | 209 | return test_accuracy 210 | 211 | 212 | def main(_): 213 | tf.logging.set_verbosity(tf.logging.INFO) 214 | print_configuration_op(FLAGS) 215 | 216 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 217 | 218 | test_data_size = count_data_size(FLAGS.test_dir) 219 | tf.logging.info('test data size: {}'.format(test_data_size)) 220 | 221 | filenames = tf.placeholder(tf.string, shape=[None]) 222 | shuffle_size = tf.placeholder(tf.int64) 223 | dataset = tf.data.TFRecordDataset(filenames) 224 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 
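# A minimal sketch (hypothetical helper, not part of this script) of the
# single-pass evaluation pattern used here: repeat(1) with shuffling disabled,
# then drain the initializable iterator until tf.errors.OutOfRangeError marks
# the end of the epoch, which is exactly how run_test terminates.
def _single_pass_eval_sketch(sess, iterator, fetches, feed_dict):
    sess.run(iterator.initializer, feed_dict=feed_dict)
    while True:
        try:
            yield sess.run(fetches)
        except tf.errors.OutOfRangeError:
            break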
225 | dataset = dataset.repeat(1) 226 | # dataset = dataset.shuffle(shuffle_size) 227 | dataset = dataset.batch(FLAGS.eval_batch_size) 228 | iterator = dataset.make_initializable_iterator() 229 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels = iterator.get_next() 230 | 231 | training = tf.placeholder(tf.bool) 232 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config, 233 | is_training = training, 234 | input_ids = input_sents, 235 | input_mask = input_mask, 236 | segment_ids = segment_ids, 237 | speaker_ids = speaker_ids, 238 | cls_positions = cls_positions, 239 | rsp_position = rsp_position, 240 | labels = labels, 241 | num_labels = 1, 242 | use_one_hot_embeddings = False) 243 | 244 | config = tf.ConfigProto(allow_soft_placement=True) 245 | config.gpu_options.allow_growth = True 246 | 247 | if FLAGS.do_eval: 248 | with tf.Session(config=config) as sess: 249 | tf.logging.info("*** Restore model ***") 250 | 251 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir) 252 | variables = tf.trainable_variables() 253 | saver = tf.train.Saver(variables) 254 | saver.restore(sess, ckpt.model_checkpoint_path) 255 | 256 | tf.logging.info('Test begin') 257 | sess.run(iterator.initializer, 258 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1}) 259 | run_test(sess, training, log_probs, accuracy) 260 | 261 | 262 | if __name__ == "__main__": 263 | tf.app.run() 264 | -------------------------------------------------------------------------------- /run_testing_si_gift.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """MPC-BERT testing runner on the downstream task of speaker identification.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import operator 10 | from time import time 11 | from collections import defaultdict 12 | import tensorflow as tf 13 | import optimization 14 | import tokenization 15 | import modeling_speaker_gift as modeling 16 | 17 | flags = tf.flags 18 | FLAGS = flags.FLAGS 19 | 20 | flags.DEFINE_string("task_name", 'Testing', 21 | "The name of the task.") 22 | 23 | flags.DEFINE_string("test_dir", 'test.tfrecord', 24 | "The input test data dir. Should contain the .tsv files (or other data files) for the task.") 25 | 26 | flags.DEFINE_string("restore_model_dir", 'output/', 27 | "The output directory where the model checkpoints have been written.") 28 | 29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json', 30 | "The config json file corresponding to the pre-trained BERT model. " 31 | "This specifies the model architecture.") 32 | 33 | flags.DEFINE_bool("do_eval", True, 34 | "Whether to run eval on the dev set.") 35 | 36 | flags.DEFINE_integer("eval_batch_size", 32, 37 | "Total batch size for predict.") 38 | 39 | flags.DEFINE_integer("max_seq_length", 320, 40 | "The maximum total input sequence length after WordPiece tokenization. 
" 41 | "Sequences longer than this will be truncated, and sequences shorter " 42 | "than this will be padded.") 43 | 44 | flags.DEFINE_integer("max_utr_num", 7, 45 | "Maximum utterance number.") 46 | 47 | 48 | def print_configuration_op(FLAGS): 49 | print('My Configurations:') 50 | for name, value in FLAGS.__flags.items(): 51 | value=value.value 52 | if type(value) == float: 53 | print(' %s:\t %f'%(name, value)) 54 | elif type(value) == int: 55 | print(' %s:\t %d'%(name, value)) 56 | elif type(value) == str: 57 | print(' %s:\t %s'%(name, value)) 58 | elif type(value) == bool: 59 | print(' %s:\t %s'%(name, value)) 60 | else: 61 | print('%s:\t %s' % (name, value)) 62 | print('End of configuration') 63 | 64 | 65 | def count_data_size(file_name): 66 | sample_nums = 0 67 | for record in tf.python_io.tf_record_iterator(file_name): 68 | sample_nums += 1 69 | return sample_nums 70 | 71 | 72 | def parse_exmp(serial_exmp): 73 | input_data = tf.parse_single_example(serial_exmp, 74 | features={ 75 | "input_sents": 76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 77 | "input_mask": 78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 79 | "segment_ids": 80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 81 | "speaker_ids": 82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64), 83 | "cls_positions": 84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 85 | "rsp_position": 86 | tf.FixedLenFeature([1], tf.int64), 87 | "label_ids": 88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 89 | "reply_mask_utr2word_flatten": 90 | tf.FixedLenFeature([FLAGS.max_utr_num * FLAGS.max_seq_length], tf.int64), 91 | "utr_lens": 92 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64), 93 | } 94 | ) 95 | # So cast all int64 to int32. 96 | for name in list(input_data.keys()): 97 | t = input_data[name] 98 | if t.dtype == tf.int64: 99 | t = tf.to_int32(t) 100 | input_data[name] = t 101 | 102 | input_sents = input_data["input_sents"] 103 | input_mask = input_data["input_mask"] 104 | segment_ids= input_data["segment_ids"] 105 | speaker_ids= input_data["speaker_ids"] 106 | cls_positions= input_data["cls_positions"] 107 | rsp_position= input_data["rsp_position"] 108 | reply_mask_utr2word = tf.reshape(input_data['reply_mask_utr2word_flatten'], [FLAGS.max_utr_num, FLAGS.max_seq_length]) 109 | utr_lens = input_data['utr_lens'] 110 | labels = input_data['label_ids'] 111 | 112 | reply_mask_word2word = [] 113 | for i in range(FLAGS.max_utr_num): 114 | reply_mask_utr2word_i = reply_mask_utr2word[i] # [max_seq_length, ] 115 | utr_len_i = utr_lens[i] # [1, ] 116 | reply_mask_utr2word_i_tiled = tf.tile(tf.expand_dims(reply_mask_utr2word_i, 0), [utr_len_i, 1]) # [utr_len, max_seq_length] 117 | reply_mask_word2word.append(reply_mask_utr2word_i_tiled) 118 | 119 | reply_mask_pad = tf.zeros(shape=[FLAGS.max_seq_length - tf.reduce_sum(utr_lens), FLAGS.max_seq_length], dtype=tf.int32) 120 | reply_mask_word2word.append(reply_mask_pad) 121 | reply_mask = tf.concat(reply_mask_word2word, 0) # [max_seq_length, max_seq_length] 122 | print("* Loading reply mask ...") 123 | 124 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask 125 | 126 | 127 | def gather_indexes(sequence_tensor, positions): 128 | """Gathers the vectors at the specific positions over a minibatch.""" 129 | # sequence_tensor = [batch_size, seq_length, width] 130 | # positions = [batch_size, max_utr_num] 131 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 132 | batch_size 
= sequence_shape[0] 133 | seq_length = sequence_shape[1] 134 | width = sequence_shape[2] 135 | 136 | flat_offsets = tf.reshape( 137 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1] 138 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ] 139 | flat_sequence_tensor = tf.reshape(sequence_tensor, 140 | [batch_size * seq_length, width]) 141 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width] 142 | return output_tensor 143 | 144 | 145 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask, 146 | num_labels, use_one_hot_embeddings): 147 | """Creates a classification model.""" 148 | model = modeling.BertModel( 149 | config=bert_config, 150 | is_training=is_training, 151 | input_ids=input_ids, 152 | input_mask=input_mask, 153 | token_type_ids=segment_ids, 154 | speaker_ids=speaker_ids, 155 | reply_mask=reply_mask, 156 | use_one_hot_embeddings=use_one_hot_embeddings) 157 | 158 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim] 159 | 160 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2) 161 | width = input_shape[-1] 162 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2) 163 | max_utr_num = positions_shape[-1] 164 | 165 | with tf.variable_scope("cls/speaker_restore"): 166 | # We apply one more non-linear transformation before the output layer. 167 | with tf.variable_scope("transform"): 168 | input_tensor = tf.layers.dense( 169 | input_tensor, 170 | units=bert_config.hidden_size, 171 | activation=modeling.get_activation(bert_config.hidden_act), 172 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range)) 173 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim] 174 | 175 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim] 176 | 177 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim] 178 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim] 179 | 180 | output_weights = tf.get_variable( 181 | "output_weights", 182 | shape=[width, width], 183 | initializer=modeling.create_initializer(bert_config.initializer_range)) 184 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights), 185 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num] 186 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num] 187 | 188 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 189 | logits = logits * mask + -1e9 * (1-mask) 190 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num] 191 | 192 | # loss 193 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num] 194 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ] 195 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss") 196 | 197 | # accuracy 198 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ] 199 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num] 200 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ] 201 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy") 202 | 
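# A minimal sketch (hypothetical helper, not part of this script) of the
# GIFT-specific `reply_mask` consumed above: parse_exmp tiles each
# utterance-to-word row over that utterance's tokens, so every token inherits
# its utterance's reply relations, then zero-pads to
# [max_seq_length, max_seq_length]. In numpy, with two utterances of lengths
# 2 and 1 and max_seq_length = 4:
def _reply_mask_expansion_sketch():
    import numpy as np
    max_seq_length = 4
    utr2word = np.array([[0, 0, 0, 0],    # utterance 0 replies to no tokens
                         [1, 1, 0, 0]])   # utterance 1 replies to utterance 0's two tokens
    utr_lens = [2, 1]
    rows = [np.tile(utr2word[i], (utr_lens[i], 1)) for i in range(len(utr_lens))]
    rows.append(np.zeros((max_seq_length - sum(utr_lens), max_seq_length), dtype=int))
    return np.concatenate(rows, 0)        # [max_seq_length, max_seq_length]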
203 | return mean_loss, logits, log_probs, accuracy 204 | 205 | 206 | def run_test(sess, training, prob, accuracy): 207 | 208 | step = 0 209 | t0 = time() 210 | num_test = 0 211 | num_correct = 0.0 212 | test_accuracy = 0 213 | 214 | try: 215 | while True: 216 | step += 1 217 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False}) 218 | 219 | num_test += len(predicted_prob) 220 | num_correct += len(predicted_prob) * batch_accuracy 221 | 222 | if step % 100 == 0: 223 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0)) 224 | 225 | except tf.errors.OutOfRangeError: 226 | test_accuracy = num_correct / num_test 227 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy)) 228 | 229 | return test_accuracy 230 | 231 | 232 | def main(_): 233 | tf.logging.set_verbosity(tf.logging.INFO) 234 | print_configuration_op(FLAGS) 235 | 236 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 237 | 238 | test_data_size = count_data_size(FLAGS.test_dir) 239 | tf.logging.info('test data size: {}'.format(test_data_size)) 240 | 241 | filenames = tf.placeholder(tf.string, shape=[None]) 242 | shuffle_size = tf.placeholder(tf.int64) 243 | dataset = tf.data.TFRecordDataset(filenames) 244 | dataset = dataset.map(parse_exmp) # Parse the record into tensors. 245 | dataset = dataset.repeat(1) 246 | # dataset = dataset.shuffle(shuffle_size) 247 | dataset = dataset.batch(FLAGS.eval_batch_size) 248 | iterator = dataset.make_initializable_iterator() 249 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask = iterator.get_next() 250 | 251 | training = tf.placeholder(tf.bool) 252 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config, 253 | is_training = training, 254 | input_ids = input_sents, 255 | input_mask = input_mask, 256 | segment_ids = segment_ids, 257 | speaker_ids = speaker_ids, 258 | cls_positions = cls_positions, 259 | rsp_position = rsp_position, 260 | reply_mask = reply_mask, 261 | labels = labels, 262 | num_labels = 1, 263 | use_one_hot_embeddings = False) 264 | 265 | config = tf.ConfigProto(allow_soft_placement=True) 266 | config.gpu_options.allow_growth = True 267 | 268 | if FLAGS.do_eval: 269 | with tf.Session(config=config) as sess: 270 | tf.logging.info("*** Restore model ***") 271 | 272 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir) 273 | variables = tf.trainable_variables() 274 | saver = tf.train.Saver(variables) 275 | saver.restore(sess, ckpt.model_checkpoint_path) 276 | 277 | tf.logging.info('Test begin') 278 | sess.run(iterator.initializer, 279 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1}) 280 | run_test(sess, training, log_probs, accuracy) 281 | 282 | 283 | if __name__ == "__main__": 284 | tf.app.run() 285 | -------------------------------------------------------------------------------- /scripts/run_finetuning.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_finetuning_ar.py \ 3 | --task_name fine_tuning \ 4 | --train_dir ../data/ijcai2019/train_ar.tfrecord \ 5 | --valid_dir ../data/ijcai2019/dev_ar.tfrecord \ 6 | --output_dir ../output/ijcai2019 \ 7 | --do_lower_case True \ 8 | --vocab_file ../uncased_L-12_H-768_A-12_MPCBERT/vocab.txt \ 9 | --bert_config_file ../uncased_L-12_H-768_A-12_MPCBERT/bert_config.json \ 10 | --init_checkpoint ../uncased_L-12_H-768_A-12_MPCBERT/bert_model.ckpt \ 11 | 
--max_seq_length 230 \ 12 | --max_utr_num 7 \ 13 | --do_train True \ 14 | --train_batch_size 16 \ 15 | --learning_rate 2e-5 \ 16 | --num_train_epochs 10 \ 17 | --warmup_proportion 0.1 > log_finetuning_MPCBERT_ar.txt 2>&1 & 18 | -------------------------------------------------------------------------------- /scripts/run_finetuning_gift.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_finetuning_ar_gift.py \ 3 | --task_name fine_tuning \ 4 | --train_dir ../data/ijcai2019/train_ar_gift.tfrecord \ 5 | --valid_dir ../data/ijcai2019/dev_ar_gift.tfrecord \ 6 | --output_dir ../output/ijcai2019 \ 7 | --do_lower_case True \ 8 | --vocab_file ../uncased_L-12_H-768_A-12_MPCBERT/vocab.txt \ 9 | --bert_config_file ../uncased_L-12_H-768_A-12_MPCBERT/bert_config.json \ 10 | --init_checkpoint ../uncased_L-12_H-768_A-12_MPCBERT/bert_model.ckpt \ 11 | --max_seq_length 230 \ 12 | --max_utr_num 7 \ 13 | --do_train True \ 14 | --train_batch_size 16 \ 15 | --learning_rate 2e-5 \ 16 | --num_train_epochs 10 \ 17 | --warmup_proportion 0.1 > log_finetuning_MPCBERT_ar_GIFT.txt 2>&1 & 18 | -------------------------------------------------------------------------------- /scripts/run_pretraining.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_pretraining.py \ 3 | --task_name MPC-BERT-pretraining \ 4 | --input_file ../data/pretraining_data.tfrecord \ 5 | --output_dir ../uncased_L-12_H-768_A-12_MPCBERT \ 6 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \ 7 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \ 8 | --init_checkpoint ../uncased_L-12_H-768_A-12/bert_model.ckpt \ 9 | --max_seq_length 230 \ 10 | --max_utr_length 30 \ 11 | --max_utr_num 7 \ 12 | --max_predictions_per_seq 25 \ 13 | --max_predictions_per_seq_ar 4 \ 14 | --max_predictions_per_seq_sr 2 \ 15 | --max_predictions_per_seq_cd 2 \ 16 | --train_batch_size 4 \ 17 | --learning_rate 5e-5 \ 18 | --mid_save_step 20000 \ 19 | --num_train_epochs 1 \ 20 | --warmup_proportion 0.1 > log_pretraining_MPCBERT.txt 2>&1 & 21 | -------------------------------------------------------------------------------- /scripts/run_testing.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_testing_ar.py \ 3 | --test_dir ../data/ijcai2019/test_ar.tfrecord \ 4 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \ 5 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \ 6 | --max_seq_length 230 \ 7 | --max_utr_num 7 \ 8 | --eval_batch_size 256 \ 9 | --restore_model_dir ../output/ijcai2019/PATH_TO_TEST_MODEL > log_testing_MPCBERT_ar.txt 2>&1 & 10 | -------------------------------------------------------------------------------- /scripts/run_testing_gift.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_testing_ar_gift.py \ 3 | --test_dir ../data/ijcai2019/test_ar_gift.tfrecord \ 4 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \ 5 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \ 6 | --max_seq_length 230 \ 7 | --max_utr_num 7 \ 8 | --eval_batch_size 256 \ 9 | --restore_model_dir ../output/ijcai2019/PATH_TO_TEST_MODEL > log_testing_MPCBERT_ar_GIFT.txt 2>&1 & 10 | -------------------------------------------------------------------------------- /tokenization.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-trained. If this error is wrong, please " 74 | "just comment out this check."
% (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python 2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python 2 and Python 3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python 2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenization.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower
casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because the English Wikipedia does contain 206 | # some Chinese words). 207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenization.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer`. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `char` is a whitespace character.""" 364 | # \t, \n, and \r are technically control characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `char` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `char` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyway, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /uncased_L-12_H-768_A-12/README.md: -------------------------------------------------------------------------------- 1 | ====== Download the BERT base model ====== 2 | 3 | link: https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip 4 | Unzip it and move the files to the path: ./uncased_L-12_H-768_A-12 --------------------------------------------------------------------------------
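As a quick end-to-end check of the `tokenization.py` module above, here is a minimal usage sketch (illustrative only, not part of the repository): it assumes the BERT base vocabulary has already been downloaded to `./uncased_L-12_H-768_A-12/vocab.txt` as described in the README directly above, and the exact WordPiece splits it produces depend on that vocabulary file.

```python
# Minimal sketch of the tokenization pipeline (illustrative; assumes the
# uncased BERT base vocab has been downloaded per the README above and that
# this is run from the repository root so `tokenization` is importable).
import tokenization

tokenizer = tokenization.FullTokenizer(
    vocab_file="./uncased_L-12_H-768_A-12/vocab.txt",
    do_lower_case=True)

# BasicTokenizer lowercases, strips accents, and splits punctuation, so
# "multi-party" becomes ['multi', '-', 'party']; WordpieceTokenizer then
# applies greedy longest-match-first splitting against the vocab, emitting
# '##'-prefixed subwords for out-of-vocab words (e.g. 'unaffable' ->
# ['un', '##aff', '##able']).
tokens = tokenizer.tokenize("Who says what to whom in multi-party conversations?")

# Token ids round-trip through the inverse vocab built in FullTokenizer.
ids = tokenizer.convert_tokens_to_ids(tokens)
assert tokenizer.convert_ids_to_tokens(ids) == tokens
```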