├── README.md
├── __init__.py
├── compute_metrics.py
├── create_finetuning_data_ar.py
├── create_finetuning_data_ar_gift.py
├── create_finetuning_data_rs.py
├── create_finetuning_data_rs_gift.py
├── create_finetuning_data_si.py
├── create_finetuning_data_si_gift.py
├── create_pretraining_data.py
├── data
│   ├── emnlp2016
│   │   ├── README.md
│   │   └── data_preprocess.py
│   └── ijcai2019
│       └── README.md
├── image
│   ├── result_addressee_recognition.png
│   ├── result_addressee_recognition_gift.png
│   ├── result_response_selection.png
│   ├── result_response_selection_gift.png
│   ├── result_speaker_identification.png
│   └── result_speaker_identification_gift.png
├── metrics.py
├── modeling_speaker.py
├── modeling_speaker_gift.py
├── optimization.py
├── run_finetuning_ar.py
├── run_finetuning_ar_gift.py
├── run_finetuning_rs.py
├── run_finetuning_rs_gift.py
├── run_finetuning_si.py
├── run_finetuning_si_gift.py
├── run_pretraining.py
├── run_testing_ar.py
├── run_testing_ar_gift.py
├── run_testing_rs.py
├── run_testing_rs_gift.py
├── run_testing_si.py
├── run_testing_si_gift.py
├── scripts
│   ├── run_finetuning.sh
│   ├── run_finetuning_gift.sh
│   ├── run_pretraining.sh
│   ├── run_testing.sh
│   └── run_testing_gift.sh
├── tokenization.py
└── uncased_L-12_H-768_A-12
    └── README.md
/README.md:
--------------------------------------------------------------------------------
1 | # MPC-BERT & GIFT for Multi-Party Conversation Understanding
2 | This repository contains the source code for the following papers:
3 | - [GIFT: Graph-Induced Fine-Tuning for Multi-Party Conversation Understanding](https://aclanthology.org/2023.acl-long.651.pdf).
4 | Jia-Chen Gu, Zhen-Hua Ling, Quan Liu, Cong Liu, Guoping Hu
5 | _ACL 2023_
6 |
7 | - [MPC-BERT: A Pre-Trained Language Model for Multi-Party Conversation Understanding](https://aclanthology.org/2021.acl-long.285.pdf).
8 | Jia-Chen Gu, Chongyang Tao, Zhen-Hua Ling, Can Xu, Xiubo Geng, Daxin Jiang
9 | _ACL 2021_
10 |
11 |
12 | ## Introduction of MPC-BERT
13 | Recently, various neural models for multi-party conversation (MPC) have achieved impressive improvements on a variety of tasks such as addressee recognition, speaker identification and response prediction.
14 | However, existing MPC methods usually represent interlocutors and utterances individually and ignore the complicated structure inherent in MPCs, which can provide crucial interlocutor and utterance semantics and enhance the conversation understanding process.
15 | To this end, we present MPC-BERT, a pre-trained model for MPC understanding that learns who says what to whom in a unified model through several elaborately designed self-supervised tasks.
16 | Specifically, these tasks fall into two categories: (1) interlocutor structure modeling, including reply-to utterance recognition, identical speaker searching and pointer consistency distinction; and (2) utterance semantics modeling, including masked shared utterance restoration and shared node detection.
17 | We evaluate MPC-BERT on three downstream tasks: addressee recognition, speaker identification and response selection.
18 | Experimental results show that MPC-BERT outperforms previous methods by large margins and achieves new state-of-the-art performance on all three downstream tasks across two benchmarks.
19 |
20 |

21 |
22 | 
23 |
24 | 
25 |
26 |
27 | ## Introduction of GIFT
28 | Addressing the issue of who says what to whom in multi-party conversations (MPCs) has recently attracted a lot of research attention. However, existing methods for MPC understanding typically embed interlocutors and utterances into sequential information flows, or utilize only the surface of the graph structures inherent in MPCs. To this end, we present a plug-and-play and lightweight method named graph-induced fine-tuning (GIFT), which can adapt various Transformer-based pre-trained language models (PLMs) for universal MPC understanding. Specifically, the full and equivalent connections among utterances in a regular Transformer ignore the sparse but distinctive dependency of one utterance on another in MPCs. To distinguish different relationships between utterances, four types of edges are designed to integrate graph-induced signals into attention mechanisms, refining PLMs originally designed for processing sequential texts. We evaluate GIFT by implementing it in three PLMs and testing the performance on three downstream tasks: addressee recognition, speaker identification and response selection. Experimental results show that GIFT significantly improves the performance of all three PLMs on the three downstream tasks and two benchmarks with only 4 additional parameters per encoding layer, achieving new state-of-the-art performance on MPC understanding.
29 |
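The edge types can be pictured with a toy, utterance-level sketch. It is not the repository's implementation (which presumably lives in ```modeling_speaker_gift.py``` and works on token-level attention inside the encoding layers). For illustration only, it assumes that each utterance replies to at most one earlier utterance, that the four edge types are named reply-to, replied-by, reply-self and indirect-reply (see the GIFT paper for the exact definitions), and that each type owns one learned scalar that multiplies the attention logits between two utterances. ```edge_type_matrix``` and ```graph_induced_attention``` are hypothetical helpers.

```python
import math

# Illustrative edge-type names (an assumption of this sketch).
EDGE_TYPES = ("reply-to", "replied-by", "reply-self", "indirect-reply")


def edge_type_matrix(reply_to):
    """Return an n x n matrix of edge-type indices between utterances.

    reply_to[i] is the index of the utterance that utterance i replies to,
    or -1 if it replies to none of the utterances in the context.
    """
    n = len(reply_to)
    types = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                name = "reply-self"
            elif reply_to[i] == j:
                name = "reply-to"
            elif reply_to[j] == i:
                name = "replied-by"
            else:
                name = "indirect-reply"
            types[i][j] = EDGE_TYPES.index(name)
    return types


def graph_induced_attention(logits, types, edge_weights):
    """Scale each attention logit by its edge-type weight, then softmax each row."""
    probs = []
    for row_logits, row_types in zip(logits, types):
        scaled = [s * edge_weights[t] for s, t in zip(row_logits, row_types)]
        peak = max(scaled)  # subtract the row maximum for numerical stability
        exps = [math.exp(s - peak) for s in scaled]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


# Utterance 0 starts the thread; 1 and 2 reply to 0; 3 replies to 1.
types = edge_type_matrix([-1, 0, 0, 1])
weights = [1.5, 1.2, 1.0, 0.5]  # hypothetical values: one scalar per edge type
attn = graph_induced_attention([[0.3, 0.1, 0.2, 0.4]] * 4, types, weights)
```

With all four weights set to 1.0 the function reduces to ordinary softmax attention, which shows how a handful of scalars per layer can re-weight attention according to the reply structure without touching the rest of the PLM.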
30 | 
31 |
32 | 
33 |
34 | 
35 |
36 |
37 | ## Dependencies
38 | - Python 3.6
39 | - TensorFlow 1.13.1
40 |
41 |
42 | ## Download
43 | - Download the [BERT-Base, Uncased model released by Google Research](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip)
44 | and move it to the path ```./uncased_L-12_H-768_A-12```.
45 |
46 | - We also release the [pre-trained MPC-BERT model](https://drive.google.com/file/d/1krmSuy83IQ0XXYyS9KfurnclmprRHgx_/view?usp=sharing);
47 | download it and move it to the path ```./uncased_L-12_H-768_A-12_MPCBERT```. You only need to fine-tune it to reproduce our results.
48 |
49 | - Download the [Hu et al. (2019) dataset](https://drive.google.com/file/d/1qSw9X22oGGbuRtfaOAf3Z7ficn6mZgi9/view?usp=sharing) used in our paper
50 | and move it to the path ```./data/ijcai2019/```.
51 |
52 | - Download the [Ouchi and Tsuboi (2016) dataset](https://drive.google.com/file/d/1nMiH6dGZfWBoOGbIvyBJp8oxhD8PWSNc/view?usp=sharing) used in our paper
53 | and move it to the path ```./data/emnlp2016/```.
54 | Unzip the dataset and run the following commands.
55 | ```
56 | cd data/emnlp2016/
57 | python data_preprocess.py
58 | ```
59 |
60 |
61 | ## Pre-training
62 | Create the pre-training data.
63 | ```
64 | python create_pretraining_data.py
65 | ```
66 | Run the pre-training process.
67 | ```
68 | cd scripts/
69 | bash run_pretraining.sh
70 | ```
71 | The pre-trained model will be saved to the path ```./uncased_L-12_H-768_A-12_MPCBERT```.
72 | Rename the files in this folder so that they match the filenames in Google's BERT release.
73 |
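A minimal sketch of the renaming step, under two assumptions worth checking against the actual folder: that ```run_pretraining.sh``` writes TensorFlow Estimator-style checkpoints named ```model.ckpt-<step>.*```, and that the target names are the ```bert_model.ckpt.*``` files shipped in Google's ```uncased_L-12_H-768_A-12``` release. The helper ```export_as_bert_ckpt``` is hypothetical and not part of this repository; it copies rather than renames, so the original checkpoint stays intact, and it does not touch ```bert_config.json``` or ```vocab.txt```.

```python
import os
import re
import shutil

# Matches e.g. "model.ckpt-20000.index" or "model.ckpt-20000.data-00000-of-00001"
# (assumed Estimator naming).
CKPT_RE = re.compile(r"^model\.ckpt-(\d+)\.(.+)$")


def export_as_bert_ckpt(model_dir):
    """Copy the latest model.ckpt-<step>.* files to bert_model.ckpt.* in model_dir."""
    by_step = {}
    for name in os.listdir(model_dir):
        match = CKPT_RE.match(name)
        if match:
            by_step.setdefault(int(match.group(1)), []).append((name, match.group(2)))
    if not by_step:
        raise FileNotFoundError("no model.ckpt-<step>.* files in " + model_dir)
    latest = max(by_step)  # highest global step
    written = []
    for name, suffix in sorted(by_step[latest]):
        target = "bert_model.ckpt." + suffix
        shutil.copyfile(os.path.join(model_dir, name), os.path.join(model_dir, target))
        written.append(target)
    return written

# Usage (hypothetical): export_as_bert_ckpt("./uncased_L-12_H-768_A-12_MPCBERT")
```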
74 |
75 | ## Regular Fine-tuning and Testing
76 | Take the task of addressee recognition as an example.
77 | Create the fine-tuning data.
78 | ```
79 | python create_finetuning_data_ar.py
80 | ```
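For addressee recognition, ```load_dataset``` in ```create_finetuning_data_ar.py``` turns the speaker and addressee lists of a conversation (context plus response) into a label matrix: entry (u, c) is 1 when utterance c precedes utterance u and was spoken by u's addressee. The standalone sketch below restates that rule on plain lists so it can be checked without TensorFlow; ```addressee_labels``` is a hypothetical name and the integer speaker IDs are made up.

```python
def addressee_labels(speakers, addressees):
    """Restate the labeling rule of create_finetuning_data_ar.load_dataset.

    labels[u][c] is 1 if utterance c precedes utterance u and was spoken by
    the addressee of u, else 0. The response is the last utterance.
    """
    assert len(speakers) == len(addressees)
    n = len(speakers)
    labels = []
    for u, adr in enumerate(addressees):
        row = [0] * n
        for c, spk in enumerate(speakers[:u]):  # consider only the preceding utterances
            if spk == adr:
                row[c] = 1
        labels.append(row)
    return labels


# Speakers 1, 2, 1 in the context; the response (speaker 3) addresses speaker 1.
# Addressee 0 of the first utterance simply matches no speaker.
labels = addressee_labels(speakers=[1, 2, 1, 3], addressees=[0, 1, 2, 1])
# labels[3] == [1, 0, 1, 0]: both utterances by speaker 1 are valid targets.
```

Rows without any positive entry, such as the first utterance here, later receive a label weight of 0.0 in ```convert_examples_to_features```, which presumably excludes them from the loss.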
81 | Run the fine-tuning process.
82 | ```
83 | cd scripts/
84 | bash run_finetuning.sh
85 | ```
86 |
87 | Set the variable ```restore_model_dir``` in ```run_testing.sh``` to the directory of the fine-tuned model.
88 | Run the testing process.
89 | ```
90 | cd scripts/
91 | bash run_testing.sh
92 | ```
93 |
94 |
95 | ## GIFT Fine-tuning and Testing
96 | Take the task of addressee recognition as an example.
97 | Create the fine-tuning data.
98 | ```
99 | python create_finetuning_data_ar_gift.py
100 | ```
101 | Run the fine-tuning process.
102 | ```
103 | cd scripts/
104 | bash run_finetuning_gift.sh
105 | ```
106 |
107 | Set the variable ```restore_model_dir``` in ```run_testing_gift.sh``` to the directory of the fine-tuned model.
108 | Run the testing process.
109 | ```
110 | cd scripts/
111 | bash run_testing_gift.sh
112 | ```
113 |
114 |
115 | ## Downstream Tasks
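116 | When evaluating on other downstream tasks, replace these scripts and their corresponding data with the task-specific versions (```ar```: addressee recognition, ```si```: speaker identification, ```rs```: response selection; drop the ```_gift``` suffix for regular fine-tuning).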
117 | ```
118 | create_finetuning_data_{ar, si, rs}_gift.py
119 | run_finetuning_{ar, si, rs}_gift.py
120 | run_testing_{ar, si, rs}_gift.py
121 | ```
122 | For response selection specifically, testing saves an ```output_test.txt``` file, which records the score of each context-response pair, to ```restore_model_dir```.
123 | Set the variable ```test_out_filename``` in ```compute_metrics.py``` to this file and run ```python compute_metrics.py``` to print the metrics (accuracy, precision, recall, F1, MAP, MRR, Recall_10@{1, 2, 5} and Recall_2@1). Recall_2@1 may vary across runs because its negative candidate is sampled randomly.
124 |
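To sanity-check the Recall_10@k and MRR numbers, the following standalone sketch reads rows in the layout that ```compute_metrics.py``` parses (tab-separated, header line skipped, column 0 the context ID, column 2 the score, column 4 the label; the other columns are ignored). It is not the repository's ```metrics.py```, whose handling of ties and of contexts without a positive candidate may differ; ```load_scores```, ```recall_at_k``` and ```mean_reciprocal_rank``` are hypothetical names, and the example rows are made up.

```python
from collections import defaultdict


def load_scores(lines):
    """Group (score, label) pairs by context ID; the first line is a header."""
    results = defaultdict(list)
    for line in lines[1:]:
        fields = line.strip().split("\t")
        results[fields[0]].append((float(fields[2]), float(fields[4])))
    return results


def recall_at_k(results, k):
    """Fraction of contexts whose true response is ranked within the top k."""
    hits = 0
    for candidates in results.values():
        ranked = sorted(candidates, key=lambda c: c[0], reverse=True)
        hits += any(label == 1.0 for _, label in ranked[:k])
    return hits / len(results)


def mean_reciprocal_rank(results):
    """Average over contexts of 1 / rank of the first true response."""
    total = 0.0
    for candidates in results.values():
        ranked = sorted(candidates, key=lambda c: c[0], reverse=True)
        for rank, (_, label) in enumerate(ranked, start=1):
            if label == 1.0:
                total += 1.0 / rank
                break
    return total / len(results)


rows = [
    "header",
    "c1\tr1\t0.9\t_\t1",
    "c1\tr2\t0.8\t_\t0",
    "c2\tr3\t0.2\t_\t1",
    "c2\tr4\t0.7\t_\t0",
]
scores = load_scores(rows)
print(recall_at_k(scores, 1), recall_at_k(scores, 2), mean_reciprocal_rank(scores))  # 0.5 1.0 0.75
```

A context without any positive candidate would count as a miss in this sketch, so results can differ from ```compute_metrics.py``` if such contexts exist.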
125 |
126 | ## Cite
127 | If you find our work helpful or use the code, please cite the following papers:
128 |
129 | ```
130 | @inproceedings{gu-etal-2023-gift,
131 | title = "{GIFT}: Graph-Induced Fine-Tuning for Multi-Party Conversation Understanding",
132 | author = "Gu, Jia-Chen and
133 | Ling, Zhen-Hua and
134 | Liu, Quan and
135 | Liu, Cong and
136 | Hu, Guoping",
137 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
138 | month = jul,
139 | year = "2023",
140 | address = "Toronto, Canada",
141 | publisher = "Association for Computational Linguistics",
142 | url = "https://aclanthology.org/2023.acl-long.651",
143 | pages = "11645--11658",
144 | }
145 | ```
146 |
147 | ```
148 | @inproceedings{gu-etal-2021-mpc,
149 | title = "{MPC}-{BERT}: A Pre-Trained Language Model for Multi-Party Conversation Understanding",
150 | author = "Gu, Jia-Chen and
151 | Tao, Chongyang and
152 | Ling, Zhen-Hua and
153 | Xu, Can and
154 | Geng, Xiubo and
155 | Jiang, Daxin",
156 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
157 | month = aug,
158 | year = "2021",
159 | address = "Online",
160 | publisher = "Association for Computational Linguistics",
161 | url = "https://aclanthology.org/2021.acl-long.285",
162 | pages = "3682--3692",
163 | }
164 | ```
165 |
166 |
167 | ## Acknowledgments
168 | We thank Wenpeng Hu and Zhangming Chan for providing the processed Hu et al. (2019) dataset used in their [paper](https://www.ijcai.org/proceedings/2019/0696.pdf).
169 | We thank Ran Le for providing the processed Ouchi and Tsuboi (2016) dataset used in their [paper](https://www.aclweb.org/anthology/D19-1199.pdf).
170 | We thank Prasan Yapa for providing a [TF 2.0 version of MPC-BERT](https://github.com/CyraxSector/MPC-BERT-2.0).
171 |
172 |
173 | ## Update
174 | Please keep an eye on this repository if you are interested in our work.
175 | Feel free to contact us (gujc@ustc.edu.cn) or open issues.
176 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/compute_metrics.py:
--------------------------------------------------------------------------------
1 | """Load the output_test.txt file and compute the metrics"""
2 |
3 |
4 | import random
5 | from collections import defaultdict
6 | import metrics
7 |
8 |
9 | test_out_filename = "./output/ijcai2019/PATH_TO_TEST_MODEL/output_test.txt"
10 | print("*"*20 + test_out_filename + "*"*20 + "\n")
11 |
12 | with open(test_out_filename, 'r') as f:
13 |
14 |     # candidate size = 10
15 |     results = defaultdict(list)
16 |     lines = f.readlines()
17 |     for line in lines[1:]:
18 |         line = line.strip().split('\t')
19 |         us_id = line[0]
20 |         r_id = line[1]
21 |         prob_score = float(line[2])
22 |         label = float(line[4])
23 |         results[us_id].append((r_id, label, prob_score))
24 |
25 |     accu, precision, recall, f1, loss = metrics.classification_metrics(results)
26 |     print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
27 |     total_valid_query = metrics.get_num_valid_query(results)
28 |     mvp = metrics.mean_average_precision(results)
29 |     mrr = metrics.mean_reciprocal_rank(results)
30 |     print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format(
31 |         mvp, mrr, total_valid_query))
32 |     top_1_precision = metrics.top_k_precision(results, k=1)
33 |     top_2_precision = metrics.top_k_precision(results, k=2)
34 |     top_5_precision = metrics.top_k_precision(results, k=5)
35 |     print('Recall_10@1: {}\tRecall_10@2: {}\tRecall_10@5: {}\n'.format(
36 |         top_1_precision, top_2_precision, top_5_precision))
37 |
38 |     # candidate size = 2; Recall_2@1 may vary across runs because the negative candidate is sampled randomly
39 |     results_bin = defaultdict(list)
40 |     for us_id, candidates in results.items():
41 |         false_candidates = []
42 |         for candidate in candidates:
43 |             r_id, label, prob_score = candidate
44 |             if label == 1.0:
45 |                 results_bin[us_id].append(candidate)
46 |             if label == 0.0:
47 |                 false_candidates.append(candidate)
48 |         false_candidate = random.sample(false_candidates, 1)
49 |         results_bin[us_id].append(false_candidate[0])
50 |
51 |     accu, precision, recall, f1, loss = metrics.classification_metrics(results_bin)
52 |     print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
53 |     total_valid_query = metrics.get_num_valid_query(results_bin)
54 |     mvp = metrics.mean_average_precision(results_bin)
55 |     mrr = metrics.mean_reciprocal_rank(results_bin)
56 |     top_1_precision = metrics.top_k_precision(results_bin, k=1)
57 |     print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tNum_query: {}'.format(
58 |         mvp, mrr, total_valid_query))
59 |     print('Recall_2@1: {}\n'.format(
60 |         top_1_precision))
61 |
--------------------------------------------------------------------------------
/create_finetuning_data_ar.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import json
3 | import random
4 | import numpy as np
5 | import collections
6 | from tqdm import tqdm
7 | import tokenization
8 | import tensorflow as tf
9 |
10 |
11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """
12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json",
13 | "path to train file")
14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json",
15 | "path to valid file")
16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json",
17 | "path to test file")
18 | tf.flags.DEFINE_integer("max_seq_length", 230,
19 | "max sequence length of concatenated context and response")
20 | tf.flags.DEFINE_integer("max_utr_num", 7,
21 | "Maximum utterance number.")
22 |
23 | """
24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016.
25 | released the original dataset, which is composed of 3 experimental settings according to conversation length.
26 |
27 | In our experiments, we used the version processed and used in
28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019.
29 | """
30 |
31 | # Length-5
32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json",
33 | # "path to train file")
34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json",
35 | # "path to valid file")
36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json",
37 | # "path to test file")
38 | # tf.flags.DEFINE_integer("max_seq_length", 120,
39 | # "max sequence length of concatenated context and response")
40 | # tf.flags.DEFINE_integer("max_utr_num", 5,
41 | # "Maximum utterance number.")
42 |
43 | # Length-10
44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json",
45 | # "path to train file")
46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json",
47 | # "path to valid file")
48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json",
49 | # "path to test file")
50 | # tf.flags.DEFINE_integer("max_seq_length", 220,
51 | # "max sequence length of concatenated context and response")
52 | # tf.flags.DEFINE_integer("max_utr_num", 10,
53 | # "Maximum utterance number.")
54 |
55 | # Length-15
56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json",
57 | # "path to train file")
58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json",
59 | # "path to valid file")
60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json",
61 | # "path to test file")
62 | # tf.flags.DEFINE_integer("max_seq_length", 320,
63 | # "max sequence length of concatenated context and response")
64 | # tf.flags.DEFINE_integer("max_utr_num", 15,
65 | # "Maximum utterance number.")
66 |
67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt",
68 | "path to vocab file")
69 | tf.flags.DEFINE_bool("do_lower_case", True,
70 | "whether to lower case the input text")
71 |
72 |
73 |
74 | def print_configuration_op(FLAGS):
75 |     print('My Configurations:')
76 |     for name, value in FLAGS.__flags.items():
77 |         value = value.value
78 |         if type(value) == float:
79 |             print(' %s:\t %f'%(name, value))
80 |         elif type(value) == int:
81 |             print(' %s:\t %d'%(name, value))
82 |         elif type(value) == str:
83 |             print(' %s:\t %s'%(name, value))
84 |         elif type(value) == bool:
85 |             print(' %s:\t %s'%(name, value))
86 |         else:
87 |             print('%s:\t %s' % (name, value))
88 |     print('End of configuration')
89 |
90 |
91 | def load_dataset(fname):
92 |     dataset = []
93 |     with open(fname, 'r') as f:
94 |         for line in f:
95 |             data = json.loads(line)
96 |             ctx = data['context']
97 |             ctx_spk = data['ctx_spk']
98 |             ctx_adr = data['ctx_adr']
99 |             rsp = data['answer']
100 |             rsp_spk = data['ans_spk']
101 |             rsp_adr = data['ans_adr']
102 |
103 |             integrate_ctx = ctx + [rsp]
104 |             integrate_ctx_spk = ctx_spk + [rsp_spk]
105 |             integrate_ctx_adr = ctx_adr + [rsp_adr]
106 |             assert len(integrate_ctx) == len(integrate_ctx_spk)
107 |             assert len(integrate_ctx) == len(integrate_ctx_adr)
108 |
109 |             label = []
110 |             for utr_id_adr, utr_adr in enumerate(integrate_ctx_adr):
111 |
112 |                 label_utr = [0 for _ in range(len(integrate_ctx))]
113 |                 for cand_utr_id_spk, cand_utr_spk in enumerate(integrate_ctx_spk[:utr_id_adr]):  # consider only the preceding utterances
114 |                     if cand_utr_spk == utr_adr:
115 |                         label_utr[cand_utr_id_spk] = 1
116 |                 label.append(label_utr)
117 |
118 |             dataset.append((ctx, ctx_spk, rsp, rsp_spk, label))
119 |
120 |     print("dataset_size: {}".format(len(dataset)))
121 |     return dataset
122 |
123 |
124 | class InputExample(object):
125 |     def __init__(self, guid, ctx, ctx_spk, rsp, rsp_spk, label):
126 |         """Constructs an InputExample."""
127 |         self.guid = guid
128 |         self.ctx = ctx
129 |         self.ctx_spk = ctx_spk
130 |         self.rsp = rsp
131 |         self.rsp_spk = rsp_spk
132 |         self.label = label
133 |
134 |
135 | def create_examples(lines, set_type):
136 |     """Creates examples for datasets."""
137 |     examples = []
138 |     for (i, line) in enumerate(lines):
139 |         guid = "%s-%s" % (set_type, str(i))
140 |         ctx = [tokenization.convert_to_unicode(utr) for utr in line[0]]
141 |         ctx_spk = line[1]
142 |         rsp = tokenization.convert_to_unicode(line[2])
143 |         rsp_spk = line[3]
144 |         label = line[-1]
145 |         examples.append(InputExample(guid=guid, ctx=ctx, ctx_spk=ctx_spk, rsp=rsp, rsp_spk=rsp_spk, label=label))
146 |     return examples
147 |
148 |
149 | def truncate_seq_pair(ctx_tokens, rsp_tokens, max_length):
150 |     """Truncates a sequence pair in place to the maximum length."""
151 |     while True:
152 |         utr_lens = [len(utr_tokens) for utr_tokens in ctx_tokens]
153 |         total_length = sum(utr_lens) + len(rsp_tokens)
154 |         if total_length <= max_length:
155 |             break
156 |
157 |         # truncate the longest context utterance if the context is longer than the response; otherwise truncate the response
158 |         if sum(utr_lens) > len(rsp_tokens):
159 |             trunc_tokens = ctx_tokens[np.argmax(np.array(utr_lens))]
160 |         else:
161 |             trunc_tokens = rsp_tokens
162 |         assert len(trunc_tokens) >= 1
163 |
164 |         if random.random() < 0.5:
165 |             del trunc_tokens[0]
166 |         else:
167 |             trunc_tokens.pop()
168 |
169 |
170 | class InputFeatures(object):
171 |     """A single set of features of data."""
172 |     def __init__(self, input_sents, input_mask, segment_ids, speaker_ids, cls_positions, label_id, label_weights):
173 |         self.input_sents = input_sents
174 |         self.input_mask = input_mask
175 |         self.segment_ids = segment_ids
176 |         self.speaker_ids = speaker_ids
177 |         self.cls_positions = cls_positions
178 |         self.label_id = label_id
179 |         self.label_weights = label_weights
180 |
181 |
182 | def convert_examples_to_features(examples, max_seq_length, max_utr_num, tokenizer):
183 |     """Converts examples into a list of `InputFeatures`."""
184 |
185 |     features = []
186 |     for example in tqdm(examples, total=len(examples)):
187 |
188 |         ctx_tokens = []
189 |         for utr in example.ctx:
190 |             utr_tokens = tokenizer.tokenize(utr)
191 |             ctx_tokens.append(utr_tokens)
192 |         assert len(ctx_tokens) == len(example.ctx_spk)
193 |
194 |         rsp_tokens = tokenizer.tokenize(example.rsp)
195 |
196 |         # [CLS]s for context, [CLS] for response, [SEP]
197 |         max_num_tokens = max_seq_length - len(ctx_tokens) - 1 - 1
198 |         truncate_seq_pair(ctx_tokens, rsp_tokens, max_num_tokens)
199 |
200 |         tokens = []
201 |         segment_ids = []
202 |         speaker_ids = []
203 |         cls_positions = []
204 |
205 |         # utterances
206 |         for i in range(len(ctx_tokens)):
207 |             utr_tokens = ctx_tokens[i]
208 |             utr_spk = example.ctx_spk[i]
209 |
210 |             cls_positions.append(len(tokens))
211 |             tokens.append("[CLS]")
212 |             segment_ids.append(0)
213 |             speaker_ids.append(utr_spk)
214 |
215 |             for token in utr_tokens:
216 |                 tokens.append(token)
217 |                 segment_ids.append(0)
218 |                 speaker_ids.append(utr_spk)
219 |
220 |         # response
221 |         cls_positions.append(len(tokens))
222 |         tokens.append("[CLS]")
223 |         segment_ids.append(0)
224 |         speaker_ids.append(example.rsp_spk)
225 |
226 |         for token in rsp_tokens:
227 |             tokens.append(token)
228 |             segment_ids.append(0)
229 |             speaker_ids.append(example.rsp_spk)
230 |
231 |         tokens.append("[SEP]")
232 |         segment_ids.append(0)
233 |         speaker_ids.append(example.rsp_spk)
234 |
235 |
236 |         input_sents = tokenizer.convert_tokens_to_ids(tokens)
237 |         input_mask = [1] * len(input_sents)
238 |         assert len(input_sents) <= max_seq_length
239 |         while len(input_sents) < max_seq_length:
240 |             input_sents.append(0)
241 |             input_mask.append(0)
242 |             segment_ids.append(0)
243 |             speaker_ids.append(0)
244 |         assert len(input_sents) == max_seq_length
245 |         assert len(input_mask) == max_seq_length
246 |         assert len(segment_ids) == max_seq_length
247 |         assert len(speaker_ids) == max_seq_length
248 |
249 |         assert len(cls_positions) <= max_utr_num
250 |         while len(cls_positions) < max_utr_num:
251 |             cls_positions.append(0)
252 |         assert len(cls_positions) == max_utr_num
253 |
254 |         label_id = []
255 |         for label_utr in example.label:
256 |             assert len(label_utr) <= max_utr_num
257 |             while len(label_utr) < max_utr_num:
258 |                 label_utr.append(0)
259 |             assert len(label_utr) == max_utr_num
260 |             label_id.append(label_utr)
261 |
262 |         assert len(label_id) <= max_utr_num
263 |         while len(label_id) < max_utr_num:
264 |             label_id.append([0] * max_utr_num)
265 |         assert len(label_id) == max_utr_num
266 |
267 |         label_id_flat = []
268 |         label_weights = []
269 |         for label_utr in label_id:
270 |             label_id_flat.extend(label_utr)
271 |             if sum(label_utr) > 0:
272 |                 label_weights.append(1.0)
273 |             else:
274 |                 label_weights.append(0.0)
275 |         assert len(label_id_flat) == max_utr_num * max_utr_num
276 |         assert len(label_weights) == max_utr_num
277 |
278 |         features.append(
279 |             InputFeatures(
280 |                 input_sents=input_sents,
281 |                 input_mask=input_mask,
282 |                 segment_ids=segment_ids,
283 |                 speaker_ids=speaker_ids,
284 |                 cls_positions=cls_positions,
285 |                 label_id=label_id_flat,
286 |                 label_weights=label_weights))
287 |
288 |     return features
289 |
290 |
291 | def write_instance_to_example_files(instances, output_files):
292 |     writers = []
293 |
294 |     for output_file in output_files:
295 |         writers.append(tf.python_io.TFRecordWriter(output_file))
296 |
297 |     writer_index = 0
298 |     total_written = 0
299 |     for (inst_index, instance) in enumerate(instances):
300 |         features = collections.OrderedDict()
301 |         features["input_sents"] = create_int_feature(instance.input_sents)
302 |         features["input_mask"] = create_int_feature(instance.input_mask)
303 |         features["segment_ids"] = create_int_feature(instance.segment_ids)
304 |         features["speaker_ids"] = create_int_feature(instance.speaker_ids)
305 |         features["cls_positions"] = create_int_feature(instance.cls_positions)
306 |         features["label_ids"] = create_int_feature(instance.label_id)
307 |         features["label_weights"] = create_float_feature(instance.label_weights)
308 |
309 |         tf_example = tf.train.Example(features=tf.train.Features(feature=features))
310 |
311 |         writers[writer_index].write(tf_example.SerializeToString())
312 |         writer_index = (writer_index + 1) % len(writers)
313 |
314 |         total_written += 1
315 |
316 |     print("write_{}_instance_to_example_files".format(total_written))
317 |
318 |     for feature_name in features.keys():
319 |         feature = features[feature_name]
320 |         values = []
321 |         if feature.int64_list.value:
322 |             values = feature.int64_list.value
323 |         elif feature.float_list.value:
324 |             values = feature.float_list.value
325 |         tf.logging.info(
326 |             "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
327 |
328 |     for writer in writers:
329 |         writer.close()
330 |
331 |
332 | def create_int_feature(values):
333 |     feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
334 |     return feature
335 |
336 | def create_float_feature(values):
337 |     feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
338 |     return feature
339 |
340 |
341 |
342 | if __name__ == "__main__":
343 |
344 |     FLAGS = tf.flags.FLAGS
345 |     print_configuration_op(FLAGS)
346 |
347 |     tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
348 |
349 |     filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file]
350 |     filetypes = ["train", "valid", "test"]
351 |     for (filename, filetype) in zip(filenames, filetypes):
352 |         dataset = load_dataset(filename)
353 |         examples = create_examples(dataset, filetype)
354 |         features = convert_examples_to_features(examples, FLAGS.max_seq_length, FLAGS.max_utr_num, tokenizer)
355 |         new_filename = filename[:-5] + "_ar.tfrecord"
356 |         write_instance_to_example_files(features, [new_filename])
357 |         print('Convert {} to {} done'.format(filename, new_filename))
358 |
359 |     print("Sub-process(es) done.")
360 |
--------------------------------------------------------------------------------
/create_finetuning_data_rs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import json
3 | import random
4 | import numpy as np
5 | import collections
6 | from tqdm import tqdm
7 | import tokenization
8 | import tensorflow as tf
9 |
10 |
11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """
12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json",
13 | "path to train file")
14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json",
15 | "path to valid file")
16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json",
17 | "path to test file")
18 | tf.flags.DEFINE_integer("max_seq_length", 230,
19 | "max sequence length of concatenated context and response")
20 | tf.flags.DEFINE_integer("max_utr_num", 7,
21 | "Maximum utterance number.")
22 |
23 | """
24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016.
25 | released the original dataset, which is composed of 3 experimental settings according to conversation length.
26 |
27 | In our experiments, we used the version processed and used in
28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019.
29 | """
30 |
31 | # Length-5
32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json",
33 | # "path to train file")
34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json",
35 | # "path to valid file")
36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json",
37 | # "path to test file")
38 | # tf.flags.DEFINE_integer("max_seq_length", 120,
39 | # "max sequence length of concatenated context and response")
40 | # tf.flags.DEFINE_integer("max_utr_num", 5,
41 | # "Maximum utterance number.")
42 |
43 | # Length-10
44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json",
45 | # "path to train file")
46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json",
47 | # "path to valid file")
48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json",
49 | # "path to test file")
50 | # tf.flags.DEFINE_integer("max_seq_length", 220,
51 | # "max sequence length of concatenated context and response")
52 | # tf.flags.DEFINE_integer("max_utr_num", 10,
53 | # "Maximum utterance number.")
54 |
55 | # Length-15
56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json",
57 | # "path to train file")
58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json",
59 | # "path to valid file")
60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json",
61 | # "path to test file")
62 | # tf.flags.DEFINE_integer("max_seq_length", 320,
63 | # "max sequence length of concatenated context and response")
64 | # tf.flags.DEFINE_integer("max_utr_num", 15,
65 | # "Maximum utterance number.")
66 |
67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt",
68 | "path to vocab file")
69 | tf.flags.DEFINE_bool("do_lower_case", True,
70 | "whether to lower case the input text")
71 |
72 |
73 |
74 | def print_configuration_op(FLAGS):
75 |     print('My Configurations:')
76 |     for name, value in FLAGS.__flags.items():
77 |         value = value.value
78 |         if type(value) == float:
79 |             print(' %s:\t %f'%(name, value))
80 |         elif type(value) == int:
81 |             print(' %s:\t %d'%(name, value))
82 |         elif type(value) == str:
83 |             print(' %s:\t %s'%(name, value))
84 |         elif type(value) == bool:
85 |             print(' %s:\t %s'%(name, value))
86 |         else:
87 |             print('%s:\t %s' % (name, value))
88 |     print('End of configuration')
89 |
90 |
91 | def load_dataset(fname, n_negative):
92 |     ctx_list = []
93 |     ctx_spk_list = []
94 |     rsp_list = []
95 |     rsp_spk_list = []
96 |     with open(fname, 'r') as f:
97 |         for line in f:
98 |             data = json.loads(line)
99 |             ctx_list.append(data['context'])
100 |             ctx_spk_list.append(data['ctx_spk'])
101 |             rsp_list.append(data['answer'])
102 |             rsp_spk_list.append(data['ans_spk'])
103 |     print("matched context-response pairs: {}".format(len(ctx_list)))
104 |
105 |     dataset = []
106 |     index_list = list(range(len(ctx_list)))
107 |     for i in range(len(ctx_list)):
108 |         ctx = ctx_list[i]
109 |         ctx_spk = ctx_spk_list[i]
110 |
111 |         # positive
112 |         rsp = rsp_list[i]
113 |         rsp_spk = rsp_spk_list[i]
114 |         dataset.append((i, ctx, ctx_spk, i, rsp, rsp_spk, 'follow'))
115 |
116 |         # negatives: responses of other randomly sampled contexts, paired with the true response's speaker
117 |         negatives = random.sample(index_list, n_negative)
118 |         while i in negatives:
119 |             negatives = random.sample(index_list, n_negative)
120 |         assert i not in negatives
121 |         for n_id in negatives:
122 |             dataset.append((i, ctx, ctx_spk, n_id, rsp_list[n_id], rsp_spk, 'unfollow'))
123 |
124 |     print("dataset_size: {}".format(len(dataset)))
125 |     return dataset
126 |
127 |
128 | class InputExample(object):
129 |     def __init__(self, guid, ctx_id, ctx, ctx_spk, rsp_id, rsp, rsp_spk, label):
130 |         """Constructs an InputExample."""
131 |         self.guid = guid
132 |         self.ctx_id = ctx_id
133 |         self.ctx = ctx
134 |         self.ctx_spk = ctx_spk
135 |         self.rsp_id = rsp_id
136 |         self.rsp = rsp
137 |         self.rsp_spk = rsp_spk
138 |         self.label = label
139 |
140 |
141 | def create_examples(lines, set_type):
142 |     """Creates examples for datasets."""
143 |     examples = []
144 |     for (i, line) in enumerate(lines):
145 |         guid = "%s-%s" % (set_type, str(i))
146 |         ctx_id = line[0]
147 |         ctx = [tokenization.convert_to_unicode(utr) for utr in line[1]]
148 |         ctx_spk = line[2]
149 |         rsp_id = line[3]
150 |         rsp = tokenization.convert_to_unicode(line[4])
151 |         rsp_spk = line[5]
152 |         label = tokenization.convert_to_unicode(line[-1])
153 |         examples.append(InputExample(guid=guid, ctx_id=ctx_id, ctx=ctx, ctx_spk=ctx_spk,
154 |                                      rsp_id=rsp_id, rsp=rsp, rsp_spk=rsp_spk, label=label))
155 |     return examples
156 |
157 |
158 | def truncate_seq_pair(tokens_a, tokens_b, max_length):
159 | """Truncates a sequence pair in place to the maximum length."""
160 | while True:
161 | total_length = len(tokens_a) + len(tokens_b)
162 | if total_length <= max_length:
163 | break
164 |
165 | if len(tokens_a) > len(tokens_b):
166 | trunc_tokens = tokens_a
167 | else:
168 | trunc_tokens = tokens_b
169 |
170 | if random.random() < 0.5:
171 | del trunc_tokens[0]
172 | else:
173 | trunc_tokens.pop()
174 |
175 |
176 | class InputFeatures(object):
177 | """A single set of features of data."""
178 | def __init__(self, ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, label_id):
179 | self.ctx_id = ctx_id
180 | self.rsp_id = rsp_id
181 | self.input_sents = input_sents
182 | self.input_mask = input_mask
183 | self.segment_ids = segment_ids
184 | self.speaker_ids = speaker_ids
185 | self.label_id = label_id
186 |
187 |
188 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
189 | """Converts `InputExample`s into a list of `InputFeatures`."""
190 |
191 | label_map = {}
192 | for (i, label) in enumerate(label_list): # ['0', '1']
193 | label_map[label] = i
194 |
195 | features = []
196 | for example in tqdm(examples, total=len(examples)):
197 | ctx_id = int(example.ctx_id)
198 | rsp_id = int(example.rsp_id)
199 |
200 | # Keep each context token paired with its speaker id: truncate_seq_pair
201 | # may delete tokens from the front of the context, so a separate speaker
202 | # list indexed by token position would drift out of alignment.
203 | assert len(example.ctx) == len(example.ctx_spk)
204 | ctx_tokens = []
205 | for i in range(len(example.ctx)):
206 | utr_spk = example.ctx_spk[i]
207 | utr_tokens = tokenizer.tokenize(example.ctx[i])
208 | ctx_tokens.extend([(token, utr_spk) for token in utr_tokens])
209 |
210 | rsp_tokens = tokenizer.tokenize(example.rsp)
211 |
212 | # Account for [CLS], [SEP], [SEP] with "- 3"
213 | truncate_seq_pair(ctx_tokens, rsp_tokens, max_seq_length - 3)
214 |
215 |
216 | tokens = []
217 | segment_ids = []
218 | speaker_ids = []
219 | tokens.append("[CLS]")
220 | segment_ids.append(0)
221 | speaker_ids.append(0)
222 | for token, utr_spk in ctx_tokens:
223 | tokens.append(token)
224 | segment_ids.append(0)
225 | speaker_ids.append(utr_spk)
226 | tokens.append("[SEP]")
227 | segment_ids.append(0)
228 | speaker_ids.append(0)
229 |
230 | for token_idx, token in enumerate(rsp_tokens):
231 | tokens.append(token)
232 | segment_ids.append(1)
233 | # speaker_ids.append(0) # 0 for mask speaker
234 | speaker_ids.append(example.rsp_spk)
235 | tokens.append("[SEP]")
236 | segment_ids.append(1)
237 | speaker_ids.append(example.rsp_spk)
238 |
239 | input_sents = tokenizer.convert_tokens_to_ids(tokens)
240 | input_mask = [1] * len(input_sents)
241 | assert len(input_sents) <= max_seq_length
242 | while len(input_sents) < max_seq_length:
243 | input_sents.append(0)
244 | input_mask.append(0)
245 | segment_ids.append(0)
246 | speaker_ids.append(0)
247 | assert len(input_sents) == max_seq_length
248 | assert len(input_mask) == max_seq_length
249 | assert len(segment_ids) == max_seq_length
250 | assert len(speaker_ids) == max_seq_length
251 |
252 | label_id = label_map[example.label]
253 |
254 | features.append(
255 | InputFeatures(
256 | ctx_id=ctx_id,
257 | rsp_id=rsp_id,
258 | input_sents=input_sents,
259 | input_mask=input_mask,
260 | segment_ids=segment_ids,
261 | speaker_ids=speaker_ids,
262 | label_id=label_id))
263 |
264 | return features
265 |
266 |
267 | def write_instance_to_example_files(instances, output_files):
268 | writers = []
269 |
270 | for output_file in output_files:
271 | writers.append(tf.python_io.TFRecordWriter(output_file))
272 |
273 | writer_index = 0
274 | total_written = 0
275 | for (inst_index, instance) in enumerate(instances):
276 | features = collections.OrderedDict()
277 | features["ctx_id"] = create_int_feature([instance.ctx_id])
278 | features["rsp_id"] = create_int_feature([instance.rsp_id])
279 | features["input_sents"] = create_int_feature(instance.input_sents)
280 | features["input_mask"] = create_int_feature(instance.input_mask)
281 | features["segment_ids"] = create_int_feature(instance.segment_ids)
282 | features["speaker_ids"] = create_int_feature(instance.speaker_ids)
283 | features["label_ids"] = create_float_feature([instance.label_id])
284 |
285 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
286 |
287 | writers[writer_index].write(tf_example.SerializeToString())
288 | writer_index = (writer_index + 1) % len(writers)
289 |
290 | total_written += 1
291 |
292 | print("write_{}_instance_to_example_files".format(total_written))
293 |
294 | for feature_name in features.keys():
295 | feature = features[feature_name]
296 | values = []
297 | if feature.int64_list.value:
298 | values = feature.int64_list.value
299 | elif feature.float_list.value:
300 | values = feature.float_list.value
301 | tf.logging.info(
302 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
303 |
304 | for writer in writers:
305 | writer.close()
306 |
307 |
308 | def create_int_feature(values):
309 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
310 | return feature
311 |
312 | def create_float_feature(values):
313 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
314 | return feature
315 |
316 |
317 |
318 | if __name__ == "__main__":
319 |
320 | FLAGS = tf.flags.FLAGS
321 | print_configuration_op(FLAGS)
322 |
323 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
324 | label_list = ["unfollow", "follow"]
325 |
326 | filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file]
327 | filetypes = ["train", "valid", "test"]
328 | file_n_negative = [1, 9, 9]
329 |
330 | for (filename, filetype, n_negative) in zip(filenames, filetypes, file_n_negative):
331 | dataset = load_dataset(filename, n_negative)
332 | examples = create_examples(dataset, filetype)
333 | features = convert_examples_to_features(examples, label_list, FLAGS.max_seq_length, tokenizer)
334 | new_filename = filename[:-5] + "_rs.tfrecord"
335 | write_instance_to_example_files(features, [new_filename])
336 | print('Convert {} to {} done'.format(filename, new_filename))
337 |
338 | print("Sub-process(es) done.")
339 |
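The negative sampling and input layout in `create_finetuning_data_rs.py` can be restated on pre-tokenized toy data. This is a minimal sketch, not code from the repository: `sample_pairs` and `build_rs_input` are hypothetical helper names, the BERT tokenizer, truncation and padding are left out, and a seeded `random.Random` stands in for the module-level `random` calls. It mirrors `load_dataset` (one `follow` pair per context plus `n_negative` `unfollow` pairs drawn from other dialogues, resampled until the true response is excluded) and the `[CLS] context [SEP] response [SEP]` layout built in `convert_examples_to_features`.

```python
import random


def sample_pairs(num_dialogues, n_negative, seed=0):
    """Hypothetical helper: build (ctx_id, rsp_id, label) triples.

    Each context gets its own response as a 'follow' pair plus n_negative
    'unfollow' pairs whose responses come from other dialogues.
    Assumes n_negative < num_dialogues, otherwise resampling never ends.
    """
    rng = random.Random(seed)
    index_list = list(range(num_dialogues))
    pairs = []
    for i in index_list:
        pairs.append((i, i, "follow"))
        negatives = rng.sample(index_list, n_negative)
        # Resample until the true response is not among the negatives.
        while i in negatives:
            negatives = rng.sample(index_list, n_negative)
        pairs.extend((i, n_id, "unfollow") for n_id in negatives)
    return pairs


def build_rs_input(ctx_tokens, ctx_spk, rsp_tokens, rsp_spk):
    """Hypothetical helper: lay out [CLS] context [SEP] response [SEP].

    ctx_spk holds one speaker id per context token. Context positions get
    segment 0 and response positions segment 1; [CLS] and the first [SEP]
    get speaker 0, while the response and its [SEP] get rsp_spk.
    """
    tokens = ["[CLS]"] + ctx_tokens + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    speaker_ids = [0] + ctx_spk + [0]
    tokens += rsp_tokens + ["[SEP]"]
    segment_ids += [1] * (len(rsp_tokens) + 1)
    speaker_ids += [rsp_spk] * (len(rsp_tokens) + 1)
    return tokens, segment_ids, speaker_ids
```

With 20 dialogues, `sample_pairs(20, 9)` yields 200 triples, i.e. ten candidates per context as in the valid and test splits (`file_n_negative = [1, 9, 9]`). Note that the script attaches the positive response's speaker (`rsp_spk`) to its negatives as well, so a negative differs from the positive only in the response text, and `label_list = ["unfollow", "follow"]` maps `follow` to label id 1.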
--------------------------------------------------------------------------------
/create_finetuning_data_si.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import json
3 | import random
4 | import numpy as np
5 | import collections
6 | from tqdm import tqdm
7 | import tokenization
8 | import tensorflow as tf
9 |
10 |
11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """
12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json",
13 | "path to train file")
14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json",
15 | "path to valid file")
16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json",
17 | "path to test file")
18 | tf.flags.DEFINE_integer("max_seq_length", 230,
19 | "max sequence length of concatenated context and response")
20 | tf.flags.DEFINE_integer("max_utr_num", 7,
21 | "Maximum utterance number.")
22 |
23 | """
24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016.
25 | released the original dataset, which comprises three experimental settings defined by conversation length.
26 |
27 | In our experiments, we used the version processed and used in
28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019.
29 | """
30 |
31 | # Length-5
32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json",
33 | # "path to train file")
34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json",
35 | # "path to valid file")
36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json",
37 | # "path to test file")
38 | # tf.flags.DEFINE_integer("max_seq_length", 120,
39 | # "max sequence length of concatenated context and response")
40 | # tf.flags.DEFINE_integer("max_utr_num", 5,
41 | # "Maximum utterance number.")
42 |
43 | # Length-10
44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json",
45 | # "path to train file")
46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json",
47 | # "path to valid file")
48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json",
49 | # "path to test file")
50 | # tf.flags.DEFINE_integer("max_seq_length", 220,
51 | # "max sequence length of concatenated context and response")
52 | # tf.flags.DEFINE_integer("max_utr_num", 10,
53 | # "Maximum utterance number.")
54 |
55 | # Length-15
56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json",
57 | # "path to train file")
58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json",
59 | # "path to valid file")
60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json",
61 | # "path to test file")
62 | # tf.flags.DEFINE_integer("max_seq_length", 320,
63 | # "max sequence length of concatenated context and response")
64 | # tf.flags.DEFINE_integer("max_utr_num", 15,
65 | # "Maximum utterance number.")
66 |
67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt",
68 | "path to vocab file")
69 | tf.flags.DEFINE_bool("do_lower_case", True,
70 | "whether to lower case the input text")
71 |
72 |
73 |
74 | def print_configuration_op(FLAGS):
75 | print('My Configurations:')
76 | for name, value in FLAGS.__flags.items():
77 | value=value.value
78 | if type(value) == float:
79 | print(' %s:\t %f'%(name, value))
80 | elif type(value) == int:
81 | print(' %s:\t %d'%(name, value))
82 | elif type(value) == str:
83 | print(' %s:\t %s'%(name, value))
84 | elif type(value) == bool:
85 | print(' %s:\t %s'%(name, value))
86 | else:
87 | print('%s:\t %s' % (name, value))
88 | print('End of configuration')
89 |
90 |
91 | def load_dataset(fname):
92 | dataset = []
93 | with open(fname, 'r') as f:
94 | for line in f:
95 | data = json.loads(line)
96 | ctx = data['context']
97 | ctx_spk = data['ctx_spk']
98 | rsp = data['answer']
99 | rsp_spk = data['ans_spk']
100 | assert len(ctx) == len(ctx_spk)
101 |
102 | utrs_same_spk_with_rsp_spk = []
103 | for utr_id, utr_spk in enumerate(ctx_spk):
104 | if utr_spk == rsp_spk:
105 | utrs_same_spk_with_rsp_spk.append(utr_id)
106 |
107 | if len(utrs_same_spk_with_rsp_spk) == 0:
108 | continue
109 |
110 | label = [0 for _ in range(len(ctx))]
111 | for utr_id in utrs_same_spk_with_rsp_spk:
112 | label[utr_id] = 1
113 |
114 | dataset.append((ctx, ctx_spk, rsp, rsp_spk, label))
115 |
116 | print("dataset_size: {}".format(len(dataset)))
117 | return dataset
118 |
119 |
120 | class InputExample(object):
121 | def __init__(self, guid, ctx, ctx_spk, rsp, rsp_spk, label):
122 | """Constructs an InputExample."""
123 | self.guid = guid
124 | self.ctx = ctx
125 | self.ctx_spk = ctx_spk
126 | self.rsp = rsp
127 | self.rsp_spk = rsp_spk
128 | self.label = label
129 |
130 |
131 | def create_examples(lines, set_type):
132 | """Creates examples for datasets."""
133 | examples = []
134 | for (i, line) in enumerate(lines):
135 | guid = "%s-%s" % (set_type, str(i))
136 | ctx = [tokenization.convert_to_unicode(utr) for utr in line[0]]
137 | ctx_spk = line[1]
138 | rsp = tokenization.convert_to_unicode(line[2])
139 | rsp_spk = line[3]
140 | label = line[-1]
141 | examples.append(InputExample(guid=guid, ctx=ctx, ctx_spk=ctx_spk, rsp=rsp, rsp_spk=rsp_spk, label=label))
142 | return examples
143 |
144 |
145 | def truncate_seq_pair(ctx_tokens, rsp_tokens, max_length):
146 | """Truncates a sequence pair in place to the maximum length."""
147 | while True:
148 | utr_lens = [len(utr_tokens) for utr_tokens in ctx_tokens]
149 | total_length = sum(utr_lens) + len(rsp_tokens)
150 | if total_length <= max_length:
151 | break
152 |
153 | # truncate the longest utterance or response
154 | if sum(utr_lens) > len(rsp_tokens):
155 | trunc_tokens = ctx_tokens[np.argmax(np.array(utr_lens))]
156 | else:
157 | trunc_tokens = rsp_tokens
158 | assert len(trunc_tokens) >= 1
159 |
160 | if random.random() < 0.5:
161 | del trunc_tokens[0]
162 | else:
163 | trunc_tokens.pop()
164 |
165 |
166 | class InputFeatures(object):
167 | """A single set of features of data."""
168 | def __init__(self, input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, label_id):
169 | self.input_sents = input_sents
170 | self.input_mask = input_mask
171 | self.segment_ids = segment_ids
172 | self.speaker_ids = speaker_ids
173 | self.cls_positions = cls_positions
174 | self.rsp_position = rsp_position
175 | self.label_id = label_id
176 |
177 |
178 | def convert_examples_to_features(examples, max_seq_length, max_utr_num, tokenizer):
179 | """Converts `InputExample`s into a list of `InputFeatures`."""
180 |
181 | features = []
182 | for example in tqdm(examples, total=len(examples)):
183 |
184 | ctx_tokens = []
185 | for utr in example.ctx:
186 | utr_tokens = tokenizer.tokenize(utr)
187 | ctx_tokens.append(utr_tokens)
188 | assert len(ctx_tokens) == len(example.ctx_spk)
189 |
190 | rsp_tokens = tokenizer.tokenize(example.rsp)
191 |
192 | # [CLS]s for context, [CLS] for response, [SEP]
193 | max_num_tokens = max_seq_length - len(ctx_tokens) - 1 - 1
194 | truncate_seq_pair(ctx_tokens, rsp_tokens, max_num_tokens)
195 |
196 |
197 | tokens = []
198 | segment_ids = []
199 | speaker_ids = []
200 | cls_positions = []
201 | rsp_position = []
202 |
203 | # utterances
204 | for i in range(len(ctx_tokens)):
205 | utr_tokens = ctx_tokens[i]
206 | utr_spk = example.ctx_spk[i]
207 |
208 | cls_positions.append(len(tokens))
209 | tokens.append("[CLS]")
210 | segment_ids.append(0)
211 | speaker_ids.append(utr_spk)
212 |
213 | for token in utr_tokens:
214 | tokens.append(token)
215 | segment_ids.append(0)
216 | speaker_ids.append(utr_spk)
217 |
218 | # response
219 | rsp_position.append(len(cls_positions))
220 | cls_positions.append(len(tokens))
221 | tokens.append("[CLS]")
222 | segment_ids.append(0)
223 | # speaker_ids.append(example.rsp_spk)
224 | speaker_ids.append(0) # 0 for mask
225 |
226 | for token in rsp_tokens:
227 | tokens.append(token)
228 | segment_ids.append(0)
229 | # speaker_ids.append(example.rsp_spk)
230 | speaker_ids.append(0)
231 |
232 | tokens.append("[SEP]")
233 | segment_ids.append(0)
234 | # speaker_ids.append(example.rsp_spk)
235 | speaker_ids.append(0)
236 |
237 |
238 | input_sents = tokenizer.convert_tokens_to_ids(tokens)
239 | input_mask = [1] * len(input_sents)
240 | assert len(input_sents) <= max_seq_length
241 | while len(input_sents) < max_seq_length:
242 | input_sents.append(0)
243 | input_mask.append(0)
244 | segment_ids.append(0)
245 | speaker_ids.append(0)
246 | assert len(input_sents) == max_seq_length
247 | assert len(input_mask) == max_seq_length
248 | assert len(segment_ids) == max_seq_length
249 | assert len(speaker_ids) == max_seq_length
250 |
251 | assert len(cls_positions) <= max_utr_num
252 | while len(cls_positions) < max_utr_num:
253 | cls_positions.append(0)
254 | assert len(cls_positions) == max_utr_num
255 |
256 | label_id = example.label
257 | assert len(label_id) <= max_utr_num
258 | while len(label_id) < max_utr_num:
259 | label_id.append(0)
260 | assert len(label_id) == max_utr_num
261 |
262 | features.append(
263 | InputFeatures(
264 | input_sents=input_sents,
265 | input_mask=input_mask,
266 | segment_ids=segment_ids,
267 | speaker_ids=speaker_ids,
268 | cls_positions=cls_positions,
269 | rsp_position=rsp_position,
270 | label_id=label_id))
271 |
272 | return features
273 |
274 |
275 | def write_instance_to_example_files(instances, output_files):
276 | writers = []
277 |
278 | for output_file in output_files:
279 | writers.append(tf.python_io.TFRecordWriter(output_file))
280 |
281 | writer_index = 0
282 | total_written = 0
283 | for (inst_index, instance) in enumerate(instances):
284 | features = collections.OrderedDict()
285 | features["input_sents"] = create_int_feature(instance.input_sents)
286 | features["input_mask"] = create_int_feature(instance.input_mask)
287 | features["segment_ids"] = create_int_feature(instance.segment_ids)
288 | features["speaker_ids"] = create_int_feature(instance.speaker_ids)
289 | features["cls_positions"] = create_int_feature(instance.cls_positions)
290 | features["rsp_position"] = create_int_feature(instance.rsp_position)
291 | features["label_ids"] = create_int_feature(instance.label_id)
292 |
293 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
294 |
295 | writers[writer_index].write(tf_example.SerializeToString())
296 | writer_index = (writer_index + 1) % len(writers)
297 |
298 | total_written += 1
299 |
300 | print("write_{}_instance_to_example_files".format(total_written))
301 |
302 | for feature_name in features.keys():
303 | feature = features[feature_name]
304 | values = []
305 | if feature.int64_list.value:
306 | values = feature.int64_list.value
307 | elif feature.float_list.value:
308 | values = feature.float_list.value
309 | tf.logging.info(
310 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
311 |
312 | for writer in writers:
313 | writer.close()
314 |
315 |
316 | def create_int_feature(values):
317 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
318 | return feature
319 |
320 | def create_float_feature(values):
321 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
322 | return feature
323 |
324 |
325 |
326 | if __name__ == "__main__":
327 |
328 | FLAGS = tf.flags.FLAGS
329 | print_configuration_op(FLAGS)
330 |
331 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
332 |
333 | filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file]
334 | filetypes = ["train", "valid", "test"]
335 | for (filename, filetype) in zip(filenames, filetypes):
336 | dataset = load_dataset(filename)
337 | examples = create_examples(dataset, filetype)
338 | features = convert_examples_to_features(examples, FLAGS.max_seq_length, FLAGS.max_utr_num, tokenizer)
339 | new_filename = filename[:-5] + "_si.tfrecord"
340 | write_instance_to_example_files(features, [new_filename])
341 | print('Convert {} to {} done'.format(filename, new_filename))
342 |
343 | print("Sub-process(es) done.")
344 |
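The labelling and input layout in `create_finetuning_data_si.py` can be restated on pre-tokenized input as a minimal sketch. `si_label` and `si_layout` are hypothetical names, not functions in the repository, and tokenization, zero padding and truncation are left out (the script reserves one `[CLS]` per context utterance plus the response `[CLS]` and the final `[SEP]`, hence the budget `max_seq_length - len(ctx_tokens) - 2`).

```python
def si_label(ctx_spk, rsp_spk, max_utr_num):
    """Hypothetical helper: multi-hot speaker-identification label.

    Marks each context utterance spoken by the response speaker with 1 and
    pads with zeros to max_utr_num. Returns None when the response speaker
    never spoke in the context, mirroring the dialogues load_dataset skips.
    """
    label = [1 if spk == rsp_spk else 0 for spk in ctx_spk]
    if sum(label) == 0:
        return None
    return label + [0] * (max_utr_num - len(label))


def si_layout(ctx_tokens, ctx_spk, rsp_tokens):
    """Hypothetical helper: one [CLS] per utterance, response speaker masked.

    ctx_tokens is a list of token lists, one per utterance. Returns the token
    sequence, per-token speaker ids, the [CLS] positions, and the index of the
    response's [CLS] within cls_positions (rsp_position). Segment ids are all
    0 in this task, so they are omitted.
    """
    tokens, speaker_ids, cls_positions = [], [], []
    for utr_tokens, spk in zip(ctx_tokens, ctx_spk):
        cls_positions.append(len(tokens))
        tokens += ["[CLS]"] + utr_tokens
        speaker_ids += [spk] * (len(utr_tokens) + 1)
    rsp_position = [len(cls_positions)]
    cls_positions.append(len(tokens))
    tokens += ["[CLS]"] + rsp_tokens + ["[SEP]"]
    # The response speaker is the prediction target, so it is masked as 0.
    speaker_ids += [0] * (len(rsp_tokens) + 2)
    return tokens, speaker_ids, cls_positions, rsp_position
```

For a two-utterance context spoken by speakers 1 and 2, `si_layout([["a", "b"], ["c"]], [1, 2], ["d"])` places the `[CLS]` tokens at positions 0, 3 and 5, and `rsp_position` is `[2]` because the response's `[CLS]` is the third entry of `cls_positions`.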
--------------------------------------------------------------------------------
/create_finetuning_data_si_gift.py:
--------------------------------------------------------------------------------
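The script `create_finetuning_data_si_gift.py` extends the speaker-identification preprocessing with the reply graph used by GIFT. Its `load_dataset` appends the response to the context, adds the edge `[len(ctx), ans_idx]`, and fills an utterance-by-utterance matrix with 1 (row utterance replies to column utterance), 2 (row utterance is replied to by column utterance), 3 on the diagonal and 0 elsewhere; `convert_examples_to_features` then repeats each column across the tokens of its utterance, pads, and flattens the result into `reply_mask_utr2word_flatten`. The sketch below restates both steps under those assumptions; `build_reply_mask` and `expand_to_words` are hypothetical helper names, not functions in the repository.

```python
def build_reply_mask(num_ctx, ctx_relation, rsp_relation):
    """Hypothetical helper: utterance-level reply matrix over context + response.

    ctx_relation holds [tgt, src] pairs meaning "utterance tgt replies to src";
    rsp_relation is the index of the utterance the response replies to.
    """
    n = num_ctx + 1  # the response is appended as the last utterance
    relations = ctx_relation + [[num_ctx, rsp_relation]]
    mask = [[0] * n for _ in range(n)]
    for tgt, src in relations:
        mask[tgt][src] = 1  # reply
        mask[src][tgt] = 2  # replied_by
    for d in range(n):
        mask[d][d] = 3  # reply_to_itself
    return mask


def expand_to_words(mask, utr_lens, max_seq_length, max_utr_num):
    """Hypothetical helper: word-level view of the reply matrix.

    Row j repeats mask[j][i] over the utr_lens[i] positions of utterance i
    (context: tokens + [CLS]; response: tokens + [CLS] + [SEP]). Rows are
    padded to max_seq_length, the row count to max_utr_num, then flattened.
    """
    rows = []
    for j in range(len(mask)):
        row = []
        for i, length in enumerate(utr_lens):
            row.extend([mask[j][i]] * length)
        row.extend([0] * (max_seq_length - len(row)))
        rows.append(row)
    rows.extend([0] * max_seq_length for _ in range(max_utr_num - len(rows)))
    return [value for row in rows for value in row]
```

In a small example where utterance 1 replies to utterance 0 and the response replies to utterance 1, `build_reply_mask(2, [[1, 0]], 1)` gives `[[3, 2, 0], [1, 3, 2], [0, 1, 3]]`; `utr_lens = [2, 2, 3]` then corresponds to two one-token context utterances (each plus `[CLS]`) and a one-token response (plus `[CLS]` and `[SEP]`). Because the response occupies a row of its own, `max_utr_num` has to cover the context utterances plus the response.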
1 | # coding=utf-8
2 | import json
3 | import random
4 | import numpy as np
5 | import collections
6 | from tqdm import tqdm
7 | import tokenization
8 | import tensorflow as tf
9 |
10 |
11 | """ Hu et al. GSN: A Graph-Structured Network for Multi-Party Dialogues. IJCAI 2019. """
12 | tf.flags.DEFINE_string("train_file", "./data/ijcai2019/train.json",
13 | "path to train file")
14 | tf.flags.DEFINE_string("valid_file", "./data/ijcai2019/dev.json",
15 | "path to valid file")
16 | tf.flags.DEFINE_string("test_file", "./data/ijcai2019/test.json",
17 | "path to test file")
18 | tf.flags.DEFINE_integer("max_seq_length", 230,
19 | "max sequence length of concatenated context and response")
20 | tf.flags.DEFINE_integer("max_utr_num", 7,
21 | "Maximum utterance number.")
22 |
23 | """
24 | Ouchi et al. Addressee and Response Selection for Multi-Party Conversation. EMNLP 2016.
25 | released the original dataset, which comprises three experimental settings defined by conversation length.
26 |
27 | In our experiments, we used the version processed and used in
28 | Le et al. Who Is Speaking to Whom? Learning to Identify Utterance Addressee in Multi-Party Conversations. EMNLP 2019.
29 | """
30 |
31 | # Length-5
32 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/5_train.json",
33 | # "path to train file")
34 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/5_dev.json",
35 | # "path to valid file")
36 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/5_test.json",
37 | # "path to test file")
38 | # tf.flags.DEFINE_integer("max_seq_length", 120,
39 | # "max sequence length of concatenated context and response")
40 | # tf.flags.DEFINE_integer("max_utr_num", 5,
41 | # "Maximum utterance number.")
42 |
43 | # Length-10
44 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/10_train.json",
45 | # "path to train file")
46 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/10_dev.json",
47 | # "path to valid file")
48 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/10_test.json",
49 | # "path to test file")
50 | # tf.flags.DEFINE_integer("max_seq_length", 220,
51 | # "max sequence length of concatenated context and response")
52 | # tf.flags.DEFINE_integer("max_utr_num", 10,
53 | # "Maximum utterance number.")
54 |
55 | # Length-15
56 | # tf.flags.DEFINE_string("train_file", "./data/emnlp2016/15_train.json",
57 | # "path to train file")
58 | # tf.flags.DEFINE_string("valid_file", "./data/emnlp2016/15_dev.json",
59 | # "path to valid file")
60 | # tf.flags.DEFINE_string("test_file", "./data/emnlp2016/15_test.json",
61 | # "path to test file")
62 | # tf.flags.DEFINE_integer("max_seq_length", 320,
63 | # "max sequence length of concatenated context and response")
64 | # tf.flags.DEFINE_integer("max_utr_num", 15,
65 | # "Maximum utterance number.")
66 |
67 | tf.flags.DEFINE_string("vocab_file", "./uncased_L-12_H-768_A-12/vocab.txt",
68 | "path to vocab file")
69 | tf.flags.DEFINE_bool("do_lower_case", True,
70 | "whether to lower case the input text")
71 |
72 |
73 |
74 | def print_configuration_op(FLAGS):
75 | print('My Configurations:')
76 | for name, value in FLAGS.__flags.items():
77 | value=value.value
78 | if type(value) == float:
79 | print(' %s:\t %f'%(name, value))
80 | elif type(value) == int:
81 | print(' %s:\t %d'%(name, value))
82 | elif type(value) == str:
83 | print(' %s:\t %s'%(name, value))
84 | elif type(value) == bool:
85 | print(' %s:\t %s'%(name, value))
86 | else:
87 | print('%s:\t %s' % (name, value))
88 | print('End of configuration')
89 |
90 |
91 | def load_dataset(fname):
92 | dataset = []
93 | with open(fname, 'r') as f:
94 | for line in f:
95 | data = json.loads(line)
96 | ctx = data['context']
97 | ctx_spk = data['ctx_spk']
98 | ctx_adr = data['ctx_adr']
99 | rsp = data['answer']
100 | rsp_spk = data['ans_spk']
101 | rsp_adr = data['ans_adr']
102 | ctx_relation = data['relation_at']
103 | rsp_relation = data['ans_idx']
104 |
105 | utrs_same_spk_with_rsp_spk = []
106 | for utr_id, utr_spk in enumerate(ctx_spk):
107 | if utr_spk == rsp_spk:
108 | utrs_same_spk_with_rsp_spk.append(utr_id)
109 |
110 | if len(utrs_same_spk_with_rsp_spk) == 0:
111 | continue
112 |
113 | label = [0 for _ in range(len(ctx))]
114 | for utr_id in utrs_same_spk_with_rsp_spk:
115 | label[utr_id] = 1
116 |
117 | # construct the reply mask
118 | integrate_ctx = ctx + [rsp]
119 | integrate_ctx_spk = ctx_spk + [rsp_spk]
120 | integrate_ctx_adr = ctx_adr + [rsp_adr]
121 | assert len(integrate_ctx) == len(integrate_ctx_spk)
122 | assert len(integrate_ctx) == len(integrate_ctx_adr)
123 | integrate_ctx_relation = ctx_relation + [[len(ctx), rsp_relation]]
124 |
125 | reply_mask = [[0 for _ in range(len(integrate_ctx))] for _ in range(len(integrate_ctx))]
126 | for relation in integrate_ctx_relation:
127 | tgt, src = relation
128 | reply_mask[tgt][src] = 1 # reply
129 | reply_mask[src][tgt] = 2 # replied_by
130 | for diagonal in range(len(integrate_ctx)):
131 | reply_mask[diagonal][diagonal] = 3 # reply_to_itself
132 |
133 | dataset.append((ctx, ctx_spk, rsp, rsp_spk, reply_mask, label))
134 |
135 | print("dataset_size: {}".format(len(dataset)))
136 | return dataset
137 |
138 |
139 | class InputExample(object):
140 | def __init__(self, guid, ctx, ctx_spk, rsp, rsp_spk, label, reply_mask):
141 | """Constructs an InputExample."""
142 | self.guid = guid
143 | self.ctx = ctx
144 | self.ctx_spk = ctx_spk
145 | self.rsp = rsp
146 | self.rsp_spk = rsp_spk
147 | self.reply_mask = reply_mask
148 | self.label = label
149 |
150 |
151 | def create_examples(lines, set_type):
152 | """Creates examples for datasets."""
153 | examples = []
154 | for (i, line) in enumerate(lines):
155 | guid = "%s-%s" % (set_type, str(i))
156 | ctx = [tokenization.convert_to_unicode(utr) for utr in line[0]]
157 | ctx_spk = line[1]
158 | rsp = tokenization.convert_to_unicode(line[2])
159 | rsp_spk = line[3]
160 | reply_mask = line[4]
161 | label = line[-1]
162 | examples.append(InputExample(guid=guid, ctx=ctx, ctx_spk=ctx_spk, rsp=rsp, rsp_spk=rsp_spk, reply_mask=reply_mask, label=label))
163 | return examples
164 |
165 |
166 | def truncate_seq_pair(ctx_tokens, rsp_tokens, max_length):
167 | """Truncates a sequence pair in place to the maximum length."""
168 | while True:
169 | utr_lens = [len(utr_tokens) for utr_tokens in ctx_tokens]
170 | total_length = sum(utr_lens) + len(rsp_tokens)
171 | if total_length <= max_length:
172 | break
173 |
174 | # truncate the longest utterance or response
175 | if sum(utr_lens) > len(rsp_tokens):
176 | trunc_tokens = ctx_tokens[np.argmax(np.array(utr_lens))]
177 | else:
178 | trunc_tokens = rsp_tokens
179 | assert len(trunc_tokens) >= 1
180 |
181 | if random.random() < 0.5:
182 | del trunc_tokens[0]
183 | else:
184 | trunc_tokens.pop()
185 |
186 |
187 | class InputFeatures(object):
188 | """A single set of features of data."""
189 | def __init__(self, input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, reply_mask_utr2word_flatten, utr_lens, label_id):
190 | self.input_sents = input_sents
191 | self.input_mask = input_mask
192 | self.segment_ids = segment_ids
193 | self.speaker_ids = speaker_ids
194 | self.cls_positions = cls_positions
195 | self.rsp_position = rsp_position
196 | self.reply_mask_utr2word_flatten = reply_mask_utr2word_flatten
197 | self.utr_lens = utr_lens
198 | self.label_id = label_id
199 |
200 |
201 | def convert_examples_to_features(examples, max_seq_length, max_utr_num, tokenizer):
202 | """Converts `InputExample`s into a list of `InputFeatures`."""
203 |
204 | features = []
205 | for example in tqdm(examples, total=len(examples)):
206 |
207 | ctx_tokens = []
208 | for utr in example.ctx:
209 | utr_tokens = tokenizer.tokenize(utr)
210 | ctx_tokens.append(utr_tokens)
211 | assert len(ctx_tokens) == len(example.ctx_spk)
212 |
213 | rsp_tokens = tokenizer.tokenize(example.rsp)
214 |
215 | # [CLS]s for context, [CLS] for response, [SEP]
216 | max_num_tokens = max_seq_length - len(ctx_tokens) - 1 - 1
217 | truncate_seq_pair(ctx_tokens, rsp_tokens, max_num_tokens)
218 |
219 |
220 | tokens = []
221 | segment_ids = []
222 | speaker_ids = []
223 | cls_positions = []
224 | rsp_position = []
225 | reply_mask_utr2word = [[] for _ in range(len(example.reply_mask))]
226 | utr_lens = []
227 |
228 | # utterances
229 | for i in range(len(ctx_tokens)):
230 | utr_tokens = ctx_tokens[i]
231 | utr_spk = example.ctx_spk[i]
232 |
233 | utr_lens.append(len(utr_tokens) + 1) # +1 for [CLS]
234 | cls_positions.append(len(tokens))
235 | tokens.append("[CLS]")
236 | segment_ids.append(0)
237 | speaker_ids.append(utr_spk)
238 | for j in range(len(example.reply_mask)):
239 | reply_mask_utr2word[j].append(example.reply_mask[j][i])
240 |
241 | for token in utr_tokens:
242 | tokens.append(token)
243 | segment_ids.append(0)
244 | speaker_ids.append(utr_spk)
245 | for j in range(len(example.reply_mask)):
246 | reply_mask_utr2word[j].append(example.reply_mask[j][i])
247 |
248 | # response
249 | utr_lens.append(len(rsp_tokens) + 2) # +2 for [CLS] and [SEP]
250 | rsp_position.append(len(cls_positions))
251 | cls_positions.append(len(tokens))
252 | tokens.append("[CLS]")
253 | segment_ids.append(0)
254 | # speaker_ids.append(example.rsp_spk)
255 | speaker_ids.append(0) # 0 for mask
256 | for j in range(len(example.reply_mask)):
257 | reply_mask_utr2word[j].append(example.reply_mask[j][-1])
258 |
259 | for token in rsp_tokens:
260 | tokens.append(token)
261 | segment_ids.append(0)
262 | # speaker_ids.append(example.rsp_spk)
263 | speaker_ids.append(0)
264 | for j in range(len(example.reply_mask)):
265 | reply_mask_utr2word[j].append(example.reply_mask[j][-1])
266 |
267 | tokens.append("[SEP]")
268 | segment_ids.append(0)
269 | # speaker_ids.append(example.rsp_spk)
270 | speaker_ids.append(0)
271 | for j in range(len(example.reply_mask)):
272 | reply_mask_utr2word[j].append(example.reply_mask[j][-1])
273 |
274 | assert len(utr_lens) == len(reply_mask_utr2word)
275 | for i in range(len(reply_mask_utr2word)):
276 | assert len(reply_mask_utr2word[i]) <= max_seq_length
277 | while len(reply_mask_utr2word[i]) < max_seq_length:
278 | reply_mask_utr2word[i].append(0)
279 | assert len(reply_mask_utr2word[i]) == max_seq_length
280 |
281 | assert len(reply_mask_utr2word) <= max_utr_num
282 | while len(reply_mask_utr2word) < max_utr_num:
283 | reply_mask_utr2word.append([0]*max_seq_length)
284 | assert len(reply_mask_utr2word) == max_utr_num
285 |
286 | reply_mask_utr2word_flatten = []
287 | for x in reply_mask_utr2word:
288 | reply_mask_utr2word_flatten.extend(x)
289 | assert len(reply_mask_utr2word_flatten) == max_seq_length * max_utr_num
290 |
291 | assert len(utr_lens) <= max_utr_num
292 | while len(utr_lens) < max_utr_num:
293 | utr_lens.append(0)
294 | assert len(utr_lens) == max_utr_num
295 |
296 |
297 | input_sents = tokenizer.convert_tokens_to_ids(tokens)
298 | input_mask = [1] * len(input_sents)
299 | assert len(input_sents) <= max_seq_length
300 | while len(input_sents) < max_seq_length:
301 | input_sents.append(0)
302 | input_mask.append(0)
303 | segment_ids.append(0)
304 | speaker_ids.append(0)
305 | assert len(input_sents) == max_seq_length
306 | assert len(input_mask) == max_seq_length
307 | assert len(segment_ids) == max_seq_length
308 | assert len(speaker_ids) == max_seq_length
309 |
310 | assert len(cls_positions) <= max_utr_num
311 | while len(cls_positions) < max_utr_num:
312 | cls_positions.append(0)
313 | assert len(cls_positions) == max_utr_num
314 |
315 | label_id = example.label
316 | assert len(label_id) <= max_utr_num
317 | while len(label_id) < max_utr_num:
318 | label_id.append(0)
319 | assert len(label_id) == max_utr_num
320 |
321 | features.append(
322 | InputFeatures(
323 | input_sents=input_sents,
324 | input_mask=input_mask,
325 | segment_ids=segment_ids,
326 | speaker_ids=speaker_ids,
327 | cls_positions=cls_positions,
328 | rsp_position=rsp_position,
329 | reply_mask_utr2word_flatten=reply_mask_utr2word_flatten,
330 | utr_lens=utr_lens,
331 | label_id=label_id))
332 |
333 | return features
334 |
335 |
336 | def write_instance_to_example_files(instances, output_files):
337 | writers = []
338 |
339 | for output_file in output_files:
340 | writers.append(tf.python_io.TFRecordWriter(output_file))
341 |
342 | writer_index = 0
343 | total_written = 0
344 | for (inst_index, instance) in enumerate(instances):
345 | features = collections.OrderedDict()
346 | features["input_sents"] = create_int_feature(instance.input_sents)
347 | features["input_mask"] = create_int_feature(instance.input_mask)
348 | features["segment_ids"] = create_int_feature(instance.segment_ids)
349 | features["speaker_ids"] = create_int_feature(instance.speaker_ids)
350 | features["cls_positions"] = create_int_feature(instance.cls_positions)
351 | features["rsp_position"] = create_int_feature(instance.rsp_position)
352 | features["reply_mask_utr2word_flatten"] = create_int_feature(instance.reply_mask_utr2word_flatten)
353 | features["utr_lens"] = create_int_feature(instance.utr_lens)
354 | features["label_ids"] = create_int_feature(instance.label_id)
355 |
356 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
357 |
358 | writers[writer_index].write(tf_example.SerializeToString())
359 | writer_index = (writer_index + 1) % len(writers)
360 |
361 | total_written += 1
362 |
363 | print("Wrote {} instances to example files".format(total_written))
364 |
365 | for feature_name in features.keys():
366 | feature = features[feature_name]
367 | values = []
368 | if feature.int64_list.value:
369 | values = feature.int64_list.value
370 | elif feature.float_list.value:
371 | values = feature.float_list.value
372 | tf.logging.info(
373 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
374 |
375 | for writer in writers:
376 | writer.close()
377 |
378 |
379 | def create_int_feature(values):
380 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
381 | return feature
382 |
383 | def create_float_feature(values):
384 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
385 | return feature
386 |
387 |
388 |
389 | if __name__ == "__main__":
390 |
391 | FLAGS = tf.flags.FLAGS
392 | print_configuration_op(FLAGS)
393 |
394 | tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
395 |
396 | filenames = [FLAGS.train_file, FLAGS.valid_file, FLAGS.test_file]
397 | filetypes = ["train", "valid", "test"]
398 | for (filename, filetype) in zip(filenames, filetypes):
399 | dataset = load_dataset(filename)
400 | examples = create_examples(dataset, filetype)
401 | features = convert_examples_to_features(examples, FLAGS.max_seq_length, FLAGS.max_utr_num, tokenizer)
402 | new_filename = filename[:-5] + "_si_gift.tfrecord"
403 | write_instance_to_example_files(features, [new_filename])
404 | print('Convert {} to {} done'.format(filename, new_filename))
405 |
406 | print("Sub-process(es) done.")
407 |
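The padding in `convert_examples_to_features` repeats one pattern for every field: check that a list fits, append zeros up to the target length, then assert the exact length. The sketch below factors that pattern into a helper, assuming zero is the padding value for every field, as it is above. `pad_to_length` and the toy token ids are hypothetical and not part of this repository.

```python
from typing import List


def pad_to_length(values: List[int], length: int, pad_value: int = 0) -> List[int]:
    """Return a copy of `values` right-padded with `pad_value` to exactly `length` items.

    Hypothetical helper mirroring the assert/append/assert pattern used above for
    input_sents, input_mask, segment_ids, speaker_ids, cls_positions and label_id.
    Unlike those loops, it returns a new list and leaves `values` unchanged.
    """
    if len(values) > length:
        raise ValueError("sequence of length {} exceeds limit {}".format(len(values), length))
    return list(values) + [pad_value] * (length - len(values))


# Toy example: 5 real tokens, max_seq_length = 8 (ids are illustrative only).
input_sents = [101, 2023, 2003, 1037, 102]
input_mask = [1] * len(input_sents)  # 1 marks a real token, 0 marks padding
max_seq_length = 8

padded_sents = pad_to_length(input_sents, max_seq_length)
padded_mask = pad_to_length(input_mask, max_seq_length)
print(padded_sents)  # [101, 2023, 2003, 1037, 102, 0, 0, 0]
print(padded_mask)   # [1, 1, 1, 1, 1, 0, 0, 0]
```

Returning a new list instead of appending in place avoids mutating the source object. That matters for `label_id = example.label`, where the loop above pads the example's own label list.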
--------------------------------------------------------------------------------
/data/emnlp2016/README.md:
--------------------------------------------------------------------------------
1 | Place the files 5.json, 10.json and 15.json in this folder.
2 |
3 | Then run the following command from this folder to generate the processed datasets.
4 | ```
5 | python data_preprocess.py
6 | ```
7 |
--------------------------------------------------------------------------------
/data/emnlp2016/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 |
4 |
5 | def get_relation(ctx_spk_index, ctx_adr_index, rsp_spk_index, rsp_adr_index):
6 |
7 | ctx_relation = []
8 | ctx_spk_2_utr = {}
9 | for utr_id, (utr_spk, utr_adr) in enumerate(zip(ctx_spk_index, ctx_adr_index)):
10 | if utr_spk in ctx_spk_2_utr:
11 | ctx_spk_2_utr[utr_spk].append(utr_id)
12 | else:
13 | ctx_spk_2_utr[utr_spk] = [utr_id]
14 |
15 | if utr_adr == -1:
16 | continue
17 |
18 | if utr_adr in ctx_spk_2_utr:
19 | src = ctx_spk_2_utr[utr_adr][-1]
20 | tgt = utr_id
21 | ctx_relation.append([tgt, src])
22 |
23 | if rsp_adr_index in ctx_spk_2_utr:
24 | rsp_idx = ctx_spk_2_utr[rsp_adr_index][-1]
25 | else:
26 | rsp_idx = -1
27 |
28 | return ctx_relation, rsp_idx
29 |
30 |
31 | # main
32 | for dialogue_len in [5, 10, 15]:
33 | print("Processing the dataset of conversation length: {} ...".format(dialogue_len))
34 |
35 | with open("{}.json".format(dialogue_len), "r") as fin:
36 | data = json.load(fin)
37 |
38 | for split, dialogues in data.items():
39 | with open("{}_{}.json".format(dialogue_len, split), "w") as fout:
40 |
41 | for dialogue in tqdm(dialogues, total=len(dialogues)):
42 | assert len(dialogue) == dialogue_len
43 |
44 | user_index = {'-': -1}
45 |
46 | # context
47 | ctx = []
48 | ctx_spk = []
49 | ctx_adr = []
50 | for utterance in dialogue[:-1]:
51 | assert len(utterance) == 3
52 | utr_spk = utterance[0]
53 | utr = utterance[1]
54 | utr_adr = utterance[2]
55 |
56 | if utr_spk not in user_index:
57 | user_index[utr_spk] = len(user_index)
58 | if utr_adr not in user_index:
59 | user_index[utr_adr] = len(user_index)
60 |
61 | ctx.append(utr)
62 | ctx_spk.append(user_index[utr_spk])
63 | ctx_adr.append(user_index[utr_adr])
64 |
65 | # response
66 | response = dialogue[-1]
67 | rsp_spk = response[0]
68 | rsp = response[1]
69 | rsp_adr = response[2]
70 |
71 | if rsp_spk not in user_index:
72 | user_index[rsp_spk] = len(user_index)
73 | assert rsp_adr in user_index
74 | assert rsp_adr != '-'
75 | # if rsp_adr not in user_index:
76 | # user_index[rsp_adr] = len(user_index)
77 |
78 | rsp_spk = user_index[rsp_spk]
79 | rsp_adr = user_index[rsp_adr]
80 |
81 | ctx_relation, rsp_idx = get_relation(ctx_spk, ctx_adr, rsp_spk, rsp_adr)
82 |
83 | d = {}
84 | d['context'] = ctx
85 | d['relation_at'] = ctx_relation
86 | d['ctx_spk'] = ctx_spk
87 | d['ctx_adr'] = ctx_adr
88 |
89 | d['answer'] = rsp
90 | d['ans_idx'] = rsp_idx
91 | d['ans_spk'] = rsp_spk
92 | d['ans_adr'] = rsp_adr
93 |
94 | json_str = json.dumps(d) # indent=2
95 | fout.write(json_str +'\n')
96 |
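`get_relation` adds an edge `[tgt, src]` from each context utterance to the latest earlier utterance spoken by its addressee, and returns the index of the context utterance that the response replies to (or -1). The toy run below restates the same logic (using `dict.setdefault` in place of the if/else) on a hypothetical four-utterance context. The speaker and addressee indices are made up, with -1 standing for the `'-'` (no addressee) entry of `user_index`.

```python
from typing import Dict, List, Tuple


def get_relation(ctx_spk_index: List[int], ctx_adr_index: List[int],
                 rsp_spk_index: int, rsp_adr_index: int) -> Tuple[List[List[int]], int]:
    """Same logic as get_relation in data_preprocess.py (rsp_spk_index is unused there too)."""
    ctx_relation = []
    ctx_spk_2_utr: Dict[int, List[int]] = {}  # speaker index -> utterance ids spoken so far
    for utr_id, (utr_spk, utr_adr) in enumerate(zip(ctx_spk_index, ctx_adr_index)):
        ctx_spk_2_utr.setdefault(utr_spk, []).append(utr_id)
        if utr_adr == -1:  # '-' (no addressee) is mapped to -1
            continue
        if utr_adr in ctx_spk_2_utr:
            # Edge [target, source]: this utterance replies to the addressee's latest utterance.
            ctx_relation.append([utr_id, ctx_spk_2_utr[utr_adr][-1]])
    # The response replies to the latest context utterance of its addressee, if any.
    rsp_idx = ctx_spk_2_utr[rsp_adr_index][-1] if rsp_adr_index in ctx_spk_2_utr else -1
    return ctx_relation, rsp_idx


# Hypothetical dialogue: speakers 1, 2, 1, 3; addressees -, 1, 2, 1.
relation, rsp_idx = get_relation([1, 2, 1, 3], [-1, 1, 2, 1], rsp_spk_index=2, rsp_adr_index=3)
print(relation)  # [[1, 0], [2, 1], [3, 2]]
print(rsp_idx)   # 3
```

An addressee who has not spoken earlier in the context produces no edge, and a response addressed to such a user gets `ans_idx = -1`.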
--------------------------------------------------------------------------------
/data/ijcai2019/README.md:
--------------------------------------------------------------------------------
1 | Place the files train.json, dev.json and test.json in this folder.
2 |
--------------------------------------------------------------------------------
/image/result_addressee_recognition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_addressee_recognition.png
--------------------------------------------------------------------------------
/image/result_addressee_recognition_gift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_addressee_recognition_gift.png
--------------------------------------------------------------------------------
/image/result_response_selection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_response_selection.png
--------------------------------------------------------------------------------
/image/result_response_selection_gift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_response_selection_gift.png
--------------------------------------------------------------------------------
/image/result_speaker_identification.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_speaker_identification.png
--------------------------------------------------------------------------------
/image/result_speaker_identification_gift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JasonForJoy/MPC-BERT/d0171e941facf2066d9c23e2c8f936e41b94a008/image/result_speaker_identification_gift.png
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import math
3 |
4 |
5 | def is_valid_query(v):
6 | num_pos = 0
7 | num_neg = 0
8 | for aid, label, score in v:
9 | if label > 0:
10 | num_pos += 1
11 | else:
12 | num_neg += 1
13 | if num_pos > 0 and num_neg > 0:
14 | return True
15 | else:
16 | return False
17 |
18 |
19 | def get_num_valid_query(results):
20 | num_query = 0
21 | for k, v in results.items():
22 | if not is_valid_query(v):
23 | continue
24 | num_query += 1
25 | return num_query
26 |
27 |
28 | def top_1_precision(results):
29 | num_query = 0
30 | top_1_correct = 0.0
31 | for k, v in results.items():
32 | if not is_valid_query(v):
33 | continue
34 | num_query += 1
35 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
36 | aid, label, score = sorted_v[0]
37 | if label > 0:
38 | top_1_correct += 1
39 |
40 | if num_query > 0:
41 | return top_1_correct / num_query
42 | else:
43 | return 0.0
44 |
45 |
46 | def mean_reciprocal_rank(results):
47 | num_query = 0
48 | mrr = 0.0
49 | for k, v in results.items():
50 | if not is_valid_query(v):
51 | continue
52 |
53 | num_query += 1
54 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
55 | for i, rec in enumerate(sorted_v):
56 | aid, label, score = rec
57 | if label > 0:
58 | mrr += 1.0 / (i + 1)
59 | break
60 |
61 | if num_query == 0:
62 | return 0.0
63 | else:
64 | mrr = mrr / num_query
65 | return mrr
66 |
67 |
68 | def mean_average_precision(results):
69 | num_query = 0
70 | mvp = 0.0
71 | for k, v in results.items():
72 | if not is_valid_query(v):
73 | continue
74 |
75 | num_query += 1
76 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
77 | num_relevant_doc = 0.0
78 | avp = 0.0
79 | for i, rec in enumerate(sorted_v):
80 | aid, label, score = rec
81 | if label > 0:
82 | num_relevant_doc += 1
83 | precision = num_relevant_doc / (i + 1)
84 | avp += precision
85 | avp = avp / num_relevant_doc
86 | mvp += avp
87 |
88 | if num_query == 0:
89 | return 0.0
90 | else:
91 | mvp = mvp / num_query
92 | return mvp
93 |
94 |
95 | def classification_metrics(results):
96 | total_num = 0
97 | total_correct = 0
98 | true_positive = 0
99 | positive_correct = 0
100 | predicted_positive = 0
101 |
102 | loss = 0.0
103 | for k, v in results.items():
104 | for rec in v:
105 | total_num += 1
106 | aid, label, score = rec
107 |
108 | if score > 0.5:
109 | predicted_positive += 1
110 |
111 | if label > 0:
112 | true_positive += 1
113 | loss += -math.log(score + 1e-12)
114 | else:
115 | loss += -math.log(1.0 - score + 1e-12)
116 |
117 | if score > 0.5 and label > 0:
118 | total_correct += 1
119 | positive_correct += 1
120 |
121 | if score < 0.5 and label < 0.5:
122 | total_correct += 1
123 |
124 | accuracy = float(total_correct) / total_num
125 | precision = float(positive_correct) / (predicted_positive + 1e-12)
126 | recall = float(positive_correct) / (true_positive + 1e-12)
127 | F1 = 2.0 * precision * recall / (1e-12 + precision + recall)
128 | return accuracy, precision, recall, F1, loss / total_num
129 |
130 |
131 | def top_k_precision(results, k=1):
132 | num_query = 0
133 | top_k_correct = 0.0
134 | for key, v in results.items():
135 | if not is_valid_query(v):
136 | continue
137 | num_query += 1
138 | sorted_v = sorted(v, key=operator.itemgetter(2), reverse=True)
139 | if k == 1:
140 | aid, label, score = sorted_v[0]
141 | if label > 0:
142 | top_k_correct += 1
143 | elif k == 2:
144 | aid1, label1, score1 = sorted_v[0]
145 | aid2, label2, score2 = sorted_v[1]
146 | if label1 > 0 or label2 > 0:
147 | top_k_correct += 1
148 | elif k == 5:
149 | for vv in sorted_v[0:5]:
150 | label = vv[1]
151 | if label > 0:
152 | top_k_correct += 1
153 | break
154 | else:
155 | raise ValueError("k must be 1, 2 or 5, got {}".format(k))
156 |
157 | if num_query > 0:
158 | return top_k_correct / num_query
159 | else:
160 | return 0.0
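The ranking metrics above all consume `results`, a dict mapping each query (context) id to a list of `(candidate_id, label, score)` tuples, and they skip queries that lack either a positive or a negative candidate. For illustration, the sketch below computes MAP, MRR and top-1 precision in one pass on a hypothetical three-query `results`, treating `label > 0` as a correct candidate. `ranking_metrics` is an illustrative restatement, not a function in `metrics.py`.

```python
import operator
from typing import Dict, List, Tuple

# query_id -> list of (candidate_id, label, score); label > 0 marks a correct candidate.
Results = Dict[str, List[Tuple[str, int, float]]]


def ranking_metrics(results: Results) -> Tuple[float, float, float]:
    """Return (MAP, MRR, P@1) over queries with at least one positive and one negative."""
    ap_sum = rr_sum = p1_sum = 0.0
    num_query = 0
    for candidates in results.values():
        labels = [label for _, label, _ in candidates]
        if not any(l > 0 for l in labels) or all(l > 0 for l in labels):
            continue  # same filter as is_valid_query
        num_query += 1
        ranked = sorted(candidates, key=operator.itemgetter(2), reverse=True)
        hits = 0
        precisions = []
        for rank, (_, label, _) in enumerate(ranked, start=1):
            if label > 0:
                hits += 1
                precisions.append(hits / rank)  # precision at each correct candidate
                if hits == 1:
                    rr_sum += 1.0 / rank  # reciprocal rank of the first correct candidate
        ap_sum += sum(precisions) / len(precisions)
        p1_sum += 1.0 if ranked[0][1] > 0 else 0.0
    if num_query == 0:
        return 0.0, 0.0, 0.0
    return ap_sum / num_query, rr_sum / num_query, p1_sum / num_query


results = {
    "q1": [("a", 1, 0.9), ("b", 0, 0.3), ("c", 0, 0.1)],  # correct answer ranked 1st
    "q2": [("d", 0, 0.8), ("e", 1, 0.6), ("f", 0, 0.2)],  # correct answer ranked 2nd
    "q3": [("g", 1, 0.5)],                                 # no negative: skipped
}
print(ranking_metrics(results))  # (0.75, 0.75, 0.5)
```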
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import re
22 | import tensorflow as tf
23 |
24 |
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 | """Creates an optimizer training op."""
27 | global_step = tf.train.get_or_create_global_step()
28 |
29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 |
31 | # Implements linear decay of the learning rate.
32 | learning_rate = tf.train.polynomial_decay(
33 | learning_rate,
34 | global_step,
35 | num_train_steps,
36 | end_learning_rate=0.0,
37 | power=1.0,
38 | cycle=False)
39 |
40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 | # learning rate will be `global_step/num_warmup_steps * init_lr`.
42 | if num_warmup_steps:
43 | global_steps_int = tf.cast(global_step, tf.int32)
44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 |
46 | global_steps_float = tf.cast(global_steps_int, tf.float32)
47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 |
49 | warmup_percent_done = global_steps_float / warmup_steps_float
50 | warmup_learning_rate = init_lr * warmup_percent_done
51 |
52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 | learning_rate = (
54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 |
56 | # It is recommended that you use this optimizer for fine tuning, since this
57 | # is how the model was trained (note that the Adam m/v variables are NOT
58 | # loaded from init_checkpoint.)
59 | optimizer = AdamWeightDecayOptimizer(
60 | learning_rate=learning_rate,
61 | weight_decay_rate=0.01,
62 | beta_1=0.9,
63 | beta_2=0.999,
64 | epsilon=1e-6,
65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 |
67 | if use_tpu:
68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 |
70 | tvars = tf.trainable_variables()
71 | grads = tf.gradients(loss, tvars)
72 |
73 | # This is how the model was pre-trained.
74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 |
76 | train_op = optimizer.apply_gradients(
77 | zip(grads, tvars), global_step=global_step)
78 |
79 | # Normally the global step update is done inside of `apply_gradients`.
80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
81 | # a different optimizer, you should probably take this line out.
82 | new_global_step = global_step + 1
83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)])
84 | return train_op
85 |
86 |
87 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
88 | """A basic Adam optimizer that includes "correct" L2 weight decay."""
89 |
90 | def __init__(self,
91 | learning_rate,
92 | weight_decay_rate=0.0,
93 | beta_1=0.9,
94 | beta_2=0.999,
95 | epsilon=1e-6,
96 | exclude_from_weight_decay=None,
97 | name="AdamWeightDecayOptimizer"):
98 | """Constructs an AdamWeightDecayOptimizer."""
99 | super(AdamWeightDecayOptimizer, self).__init__(False, name)
100 |
101 | self.learning_rate = learning_rate
102 | self.weight_decay_rate = weight_decay_rate
103 | self.beta_1 = beta_1
104 | self.beta_2 = beta_2
105 | self.epsilon = epsilon
106 | self.exclude_from_weight_decay = exclude_from_weight_decay
107 |
108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
109 | """See base class."""
110 | assignments = []
111 | for (grad, param) in grads_and_vars:
112 | if grad is None or param is None:
113 | continue
114 |
115 | param_name = self._get_variable_name(param.name)
116 |
117 | m = tf.get_variable(
118 | name=param_name + "/adam_m",
119 | shape=param.shape.as_list(),
120 | dtype=tf.float32,
121 | trainable=False,
122 | initializer=tf.zeros_initializer())
123 | v = tf.get_variable(
124 | name=param_name + "/adam_v",
125 | shape=param.shape.as_list(),
126 | dtype=tf.float32,
127 | trainable=False,
128 | initializer=tf.zeros_initializer())
129 |
130 | # Standard Adam update.
131 | next_m = (
132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
133 | next_v = (
134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
135 | tf.square(grad)))
136 |
137 | update = next_m / (tf.sqrt(next_v) + self.epsilon)
138 |
139 | # Just adding the square of the weights to the loss function is *not*
140 | # the correct way of using L2 regularization/weight decay with Adam,
141 | # since that will interact with the m and v parameters in strange ways.
142 | #
143 | # Instead we want to decay the weights in a manner that doesn't interact
144 | # with the m/v parameters. This is equivalent to adding the square
145 | # of the weights to the loss with plain (non-momentum) SGD.
146 | if self._do_use_weight_decay(param_name):
147 | update += self.weight_decay_rate * param
148 |
149 | update_with_lr = self.learning_rate * update
150 |
151 | next_param = param - update_with_lr
152 |
153 | assignments.extend(
154 | [param.assign(next_param),
155 | m.assign(next_m),
156 | v.assign(next_v)])
157 | return tf.group(*assignments, name=name)
158 |
159 | def _do_use_weight_decay(self, param_name):
160 | """Whether to use L2 weight decay for `param_name`."""
161 | if not self.weight_decay_rate:
162 | return False
163 | if self.exclude_from_weight_decay:
164 | for r in self.exclude_from_weight_decay:
165 | if re.search(r, param_name) is not None:
166 | return False
167 | return True
168 |
169 | def _get_variable_name(self, param_name):
170 | """Get the variable name from the tensor name."""
171 | m = re.match("^(.*):\\d+$", param_name)
172 | if m is not None:
173 | param_name = m.group(1)
174 | return param_name
175 |
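`create_optimizer` builds a linear-warmup, linear-decay learning-rate schedule, and `AdamWeightDecayOptimizer` applies Adam without bias correction plus a decoupled weight-decay term. Below is a scalar, pure-Python sketch of both for checking the arithmetic by hand. The function names are hypothetical, and the defaults copy the values passed in `create_optimizer` (weight_decay_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-6).

```python
import math


def learning_rate_at(step: int, init_lr: float, num_train_steps: int, num_warmup_steps: int) -> float:
    """Scalar version of the schedule built in create_optimizer.

    Linear decay: init_lr * (1 - min(step, num_train_steps) / num_train_steps), i.e.
    polynomial_decay with power=1.0, end_learning_rate=0.0, cycle=False.
    While step < num_warmup_steps it is replaced by init_lr * step / num_warmup_steps.
    """
    decayed = init_lr * (1.0 - min(step, num_train_steps) / num_train_steps)
    if num_warmup_steps and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    return decayed


def adam_weight_decay_step(param: float, grad: float, m: float, v: float, lr: float,
                           weight_decay_rate: float = 0.01, beta_1: float = 0.9,
                           beta_2: float = 0.999, epsilon: float = 1e-6,
                           use_weight_decay: bool = True):
    """One AdamWeightDecayOptimizer update for a single scalar weight.

    As in the TF code, m and v are not bias-corrected, and the decay term is added to
    the Adam update rather than to the loss, so it never enters m or v.
    Returns (next_param, next_m, next_v).
    """
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * grad * grad
    update = next_m / (math.sqrt(next_v) + epsilon)
    if use_weight_decay:  # False for names matching "LayerNorm", "layer_norm" or "bias"
        update += weight_decay_rate * param
    return param - lr * update, next_m, next_v


# Hypothetical run: 1000 training steps, 10% warmup, init_lr = 2e-5.
for step in (0, 50, 99, 100, 500, 1000):
    print(step, learning_rate_at(step, 2e-5, 1000, 100))

param, m, v = adam_weight_decay_step(param=1.0, grad=0.5, m=0.0, v=0.0, lr=1e-3)
print(param, m, v)
```

Because the warmup ramps toward `init_lr` while the decayed value is already slightly lower, the rate drops a little at step `num_warmup_steps`. The TF graph above behaves the same way, since it switches between the two values with `is_warmup`.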
--------------------------------------------------------------------------------
/run_finetuning_rs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """MPC-BERT finetuning runner on the downstream task of response selection."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker as modeling
16 | import metrics
17 |
18 | flags = tf.flags
19 | FLAGS = flags.FLAGS
20 |
21 | flags.DEFINE_string("train_dir", 'train.tfrecord',
22 | "Path to the training data file in TFRecord format.")
23 |
24 | flags.DEFINE_string("valid_dir", 'valid.tfrecord',
25 | "Path to the validation data file in TFRecord format.")
26 |
27 | flags.DEFINE_string("output_dir", 'output',
28 | "The output directory where the model checkpoints will be written.")
29 |
30 | flags.DEFINE_string("task_name", 'ResponseSelection',
31 | "The name of the task to train.")
32 |
33 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
34 | "The config json file corresponding to the pre-trained BERT model. "
35 | "This specifies the model architecture.")
36 |
37 | flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt',
38 | "The vocabulary file that the BERT model was trained on.")
39 |
40 | flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt',
41 | "Initial checkpoint (usually from a pre-trained BERT model).")
42 |
43 | flags.DEFINE_bool("do_lower_case", True,
44 | "Whether to lower case the input text. Should be True for uncased "
45 | "models and False for cased models.")
46 |
47 | flags.DEFINE_integer("max_seq_length", 320,
48 | "The maximum total input sequence length after WordPiece tokenization. "
49 | "Sequences longer than this will be truncated, and sequences shorter "
50 | "than this will be padded.")
51 |
52 | flags.DEFINE_integer("max_utr_num", 7,
53 | "Maximum utterance number.")
54 |
55 | flags.DEFINE_bool("do_train", True,
56 | "Whether to run training.")
57 |
58 | flags.DEFINE_float("warmup_proportion", 0.1,
59 | "Proportion of training to perform linear learning rate warmup for. "
60 | "E.g., 0.1 = 10% of training.")
61 |
62 | flags.DEFINE_integer("train_batch_size", 12,
63 | "Total batch size for training.")
64 |
65 | flags.DEFINE_float("learning_rate", 2e-5,
66 | "The initial learning rate for Adam.")
67 |
68 | flags.DEFINE_integer("num_train_epochs", 5,
69 | "Total number of training epochs to perform.")
70 |
71 |
72 |
73 | def print_configuration_op(FLAGS):
74 | print('My Configurations:')
75 | for name, value in FLAGS.__flags.items():
76 | value=value.value
77 | if type(value) == float:
78 | print(' %s:\t %f'%(name, value))
79 | elif type(value) == int:
80 | print(' %s:\t %d'%(name, value))
81 | elif type(value) == str:
82 | print(' %s:\t %s'%(name, value))
83 | elif type(value) == bool:
84 | print(' %s:\t %s'%(name, value))
85 | else:
86 | print('%s:\t %s' % (name, value))
87 | print('End of configuration')
88 |
89 |
90 | def count_data_size(file_name):
91 | sample_nums = 0
92 | for record in tf.python_io.tf_record_iterator(file_name):
93 | sample_nums += 1
94 | return sample_nums
95 |
96 |
97 | def parse_exmp(serial_exmp):
98 | input_data = tf.parse_single_example(serial_exmp,
99 | features={
100 | "ctx_id":
101 | tf.FixedLenFeature([], tf.int64),
102 | "rsp_id":
103 | tf.FixedLenFeature([], tf.int64),
104 | "input_sents":
105 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
106 | "input_mask":
107 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
108 | "segment_ids":
109 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
110 | "speaker_ids":
111 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
112 | "label_ids":
113 | tf.FixedLenFeature([], tf.float32),
114 | }
115 | )
116 | # tf.Example only stores integers as tf.int64, so cast all int64 features to int32.
117 | for name in list(input_data.keys()):
118 | t = input_data[name]
119 | if t.dtype == tf.int64:
120 | t = tf.to_int32(t)
121 | input_data[name] = t
122 |
123 | ctx_id = input_data["ctx_id"]
124 | rsp_id = input_data['rsp_id']
125 | input_sents = input_data["input_sents"]
126 | input_mask = input_data["input_mask"]
127 | segment_ids= input_data["segment_ids"]
128 | speaker_ids= input_data["speaker_ids"]
129 | labels = input_data['label_ids']
130 | return ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels
131 |
132 |
133 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, labels, ctx_id, rsp_id,
134 | num_labels, use_one_hot_embeddings):
135 | """Creates a classification model."""
136 | model = modeling.BertModel(
137 | config=bert_config,
138 | is_training=is_training,
139 | input_ids=input_ids,
140 | input_mask=input_mask,
141 | token_type_ids=segment_ids,
142 | speaker_ids=speaker_ids,
143 | use_one_hot_embeddings=use_one_hot_embeddings)
144 |
145 | target_loss_weight = [1.0, 1.0]
146 | target_loss_weight = tf.convert_to_tensor(target_loss_weight)
147 |
148 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32)
149 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32)
150 |
151 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy
152 |
153 | output_layer = model.get_pooled_output()
154 | hidden_size = output_layer.shape[-1].value
155 |
156 | output_weights = tf.get_variable(
157 | "output_weights", [num_labels, hidden_size],
158 | initializer=tf.truncated_normal_initializer(stddev=0.02))
159 | output_bias = tf.get_variable(
160 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
161 |
162 | with tf.variable_scope("loss"):
163 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training)
164 |
165 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
166 | logits = tf.nn.bias_add(logits, output_bias)
167 |
168 | probabilities = tf.sigmoid(logits, name="prob")
169 | logits = tf.squeeze(logits,[1])
170 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
171 | losses = tf.multiply(losses, all_target_loss)
172 |
173 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
174 |
175 | with tf.name_scope("accuracy"):
176 | correct_prediction = tf.equal(tf.sign(tf.squeeze(probabilities, [1]) - 0.5), tf.sign(labels - 0.5))
177 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
178 |
179 | return mean_loss, logits, probabilities, accuracy
180 |
181 |
182 | def run_epoch(epoch, op_name, sess, training, logits, accuracy, mean_loss, train_opt):
183 |
184 | step = 0
185 | t0 = time()
186 |
187 | try:
188 | while True:
189 | step += 1
190 | batch_logits, batch_loss, _, batch_accuracy = sess.run([logits, mean_loss, train_opt, accuracy], feed_dict={training:True})
191 |
192 | if step % 1000 == 0:
193 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" %
194 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy))
195 |
196 | except tf.errors.OutOfRangeError:
197 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" %
198 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy))
199 | pass
200 |
201 |
202 | best_score = 0.0
203 | def run_test(epoch, op_name, sess, training, accuracy, prob, pair_ids, saver, dir_path):
204 |
205 | step = 0
206 | t0 = time()
207 | num_test = 0
208 | num_correct = 0.0
209 | mrr = 0
210 | results = defaultdict(list)
211 |
212 | try:
213 | while True:
214 | step += 1
215 | batch_accuracy, predicted_prob, pair_ = sess.run([accuracy, prob, pair_ids], feed_dict={training:False})
216 | question_id, answer_id, label = pair_
217 |
218 | num_test += len(predicted_prob)
219 | num_correct += len(predicted_prob) * batch_accuracy
220 | for i, prob_score in enumerate(predicted_prob):
221 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0]))
222 |
223 | if step % 1000 == 0:
224 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f" % (epoch, step, (time() - t0)/60.0 ))
225 |
226 | except tf.errors.OutOfRangeError:
227 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, num_correct / num_test))
228 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
229 | print('Accuracy: {}, Precision: {}, Recall: {}, F1: {}, Loss: {}'.format(accu, precision, recall, f1, loss))
230 |
231 | mvp = metrics.mean_average_precision(results)
232 | mrr = metrics.mean_reciprocal_rank(results)
233 | top_1_precision = metrics.top_1_precision(results)
234 | total_valid_query = metrics.get_num_valid_query(results)
235 | print('MAP (mean average precision): {}, MRR (mean reciprocal rank): {}, Top-1 precision: {}, Num_query: {}'.format(
236 | mvp, mrr, top_1_precision, total_valid_query))
237 |
238 | out_path = os.path.join(dir_path, "output_epoch_{}.txt".format(epoch))
239 | print("Saving evaluation to {}".format(out_path))
240 | with open(out_path, 'w') as f:
241 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
242 | for us_id, v in results.items():
243 | v.sort(key=operator.itemgetter(2), reverse=True)
244 | for i, rec in enumerate(v):
245 | r_id, label, prob_score = rec
246 | rank = i + 1
247 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
248 |
249 | global best_score
250 | if op_name == 'valid' and mrr > best_score:
251 | best_score = mrr
252 | dir_path = os.path.join(dir_path, "epoch_{}".format(epoch))
253 | saver.save(sess, dir_path)
254 | tf.logging.info(">> Save model!")
255 |
256 | return mrr
257 |
258 |
259 |
260 | def main(_):
261 | tf.logging.set_verbosity(tf.logging.INFO)
262 | print_configuration_op(FLAGS)
263 |
264 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
265 |
266 | root_path = FLAGS.output_dir
267 | if not os.path.exists(root_path):
268 | os.makedirs(root_path)
269 | timestamp = str(int(time()))
270 | root_path = os.path.join(root_path, timestamp)
271 | tf.logging.info('root_path: {}'.format(root_path))
272 | if not os.path.exists(root_path):
273 | os.makedirs(root_path)
274 |
275 | train_data_size = count_data_size(FLAGS.train_dir)
276 | tf.logging.info('train data size: {}'.format(train_data_size))
277 | valid_data_size = count_data_size(FLAGS.valid_dir)
278 | tf.logging.info('valid data size: {}'.format(valid_data_size))
279 |
280 | num_train_steps = train_data_size // FLAGS.train_batch_size * FLAGS.num_train_epochs
281 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
282 |
283 | filenames = tf.placeholder(tf.string, shape=[None])
284 | shuffle_size = tf.placeholder(tf.int64)
285 | dataset = tf.data.TFRecordDataset(filenames)
286 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
287 | dataset = dataset.repeat(1)
288 | # The shuffle buffer size is fed at run time (1024 for training, 1 for validation).
289 | dataset = dataset.shuffle(shuffle_size)
290 | dataset = dataset.batch(FLAGS.train_batch_size)
291 | iterator = dataset.make_initializable_iterator()
292 | ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels = iterator.get_next()
293 | pair_ids = [ctx_id, rsp_id, labels]
294 |
295 |
296 | training = tf.placeholder(tf.bool)
297 | mean_loss, logits, probabilities, accuracy = create_model(bert_config = bert_config,
298 | is_training = training,
299 | input_ids = input_sents,
300 | input_mask = input_mask,
301 | segment_ids = segment_ids,
302 | speaker_ids = speaker_ids,
303 | labels = labels,
304 | ctx_id = ctx_id,
305 | rsp_id = rsp_id,
306 | num_labels = 1,
307 | use_one_hot_embeddings = False)
308 |
309 | initialized_variable_names = {}  # Filled below when initializing from a pre-trained checkpoint.
310 | tvars = tf.trainable_variables()
311 | if FLAGS.init_checkpoint:
312 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,FLAGS.init_checkpoint)
313 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)
314 |
315 | tf.logging.info("**** Trainable Variables ****")
316 | for var in tvars:
317 | init_string = ""
318 | if var.name in initialized_variable_names:
319 | init_string = ", *INIT_FROM_CKPT*"
320 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string)
321 |
322 | train_opt = optimization.create_optimizer(mean_loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, False)
323 |
324 | config = tf.ConfigProto(allow_soft_placement=True)
325 | config.gpu_options.allow_growth = True
326 | saver = tf.train.Saver()
327 |
328 | if FLAGS.do_train:
329 | with tf.Session(config=config) as sess:
330 | sess.run(tf.global_variables_initializer())
331 |
332 | for epoch in range(FLAGS.num_train_epochs):
333 | tf.logging.info('Train begin epoch {}'.format(epoch))
334 | sess.run(iterator.initializer,
335 | feed_dict={filenames: [FLAGS.train_dir], shuffle_size: 1024})
336 | run_epoch(epoch, "train", sess, training, logits, accuracy, mean_loss, train_opt)
337 |
338 | tf.logging.info('Valid begin')
339 | sess.run(iterator.initializer,
340 | feed_dict={filenames: [FLAGS.valid_dir], shuffle_size: 1})
341 | run_test(epoch, "valid", sess, training, accuracy, probabilities, pair_ids, saver, root_path)
342 |
343 |
344 | if __name__ == "__main__":
345 | tf.app.run()
346 |
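`create_model` scores each (context, response) pair with a single logit, trains it with `tf.nn.sigmoid_cross_entropy_with_logits` weighted per class by `target_loss_weight`, and reports accuracy by comparing the signs of `probability - 0.5` and `label - 0.5`. Below is a plain-Python sketch of that loss and accuracy for one batch, assuming 0/1 labels. The function names are hypothetical; the stable loss formula is the one given in the TensorFlow documentation for `sigmoid_cross_entropy_with_logits`.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    """Overflow-safe logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sigmoid_cross_entropy_with_logits(x: float, z: float) -> float:
    """Stable form max(x, 0) - x * z + log(1 + exp(-|x|)) of -z*log(p) - (1-z)*log(1-p)."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def sign(t: float) -> int:
    """Return -1, 0 or 1, like tf.sign."""
    return (t > 0) - (t < 0)


def loss_and_accuracy(logits: List[float], labels: List[float],
                      neg_weight: float = 1.0, pos_weight: float = 1.0) -> Tuple[float, float]:
    """Mean weighted loss and accuracy over a batch with one logit per example.

    neg_weight/pos_weight mirror target_loss_weight[0]/[1] in create_model (both 1.0 there).
    """
    if len(logits) != len(labels):
        raise ValueError("need exactly one label per logit")
    total_loss, correct = 0.0, 0
    for x, z in zip(logits, labels):
        weight = pos_weight if z > 0 else neg_weight  # assumes labels are 0 or 1
        total_loss += weight * sigmoid_cross_entropy_with_logits(x, z)
        # Elementwise check: each prediction is compared only with its own label.
        if sign(sigmoid(x) - 0.5) == sign(z - 0.5):
            correct += 1
    return total_loss / len(labels), correct / len(labels)


loss, acc = loss_and_accuracy([2.0, -1.0, 0.5], [1.0, 0.0, 0.0])
print(round(loss, 4), round(acc, 4))
```

The accuracy is only meaningful when each prediction is compared with its own label. In TensorFlow, comparing a `[batch, 1]` tensor with a `[batch]` tensor via `tf.equal` broadcasts to a `[batch, batch]` matrix, so the two shapes must agree first.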
--------------------------------------------------------------------------------
/run_finetuning_si.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """MPC-BERT finetuning runner on the downstream task of speaker identification."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker as modeling
16 |
17 | flags = tf.flags
18 | FLAGS = flags.FLAGS
19 |
20 | flags.DEFINE_string("train_dir", 'train.tfrecord',
21 | "Path to the training data file in TFRecord format.")
22 |
23 | flags.DEFINE_string("valid_dir", 'valid.tfrecord',
24 | "Path to the validation data file in TFRecord format.")
25 |
26 | flags.DEFINE_string("output_dir", 'output',
27 | "The output directory where the model checkpoints will be written.")
28 |
29 | flags.DEFINE_string("task_name", 'SpeakerIdentification',
30 | "The name of the task to train.")
31 |
32 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
33 | "The config json file corresponding to the pre-trained BERT model. "
34 | "This specifies the model architecture.")
35 |
36 | flags.DEFINE_string("vocab_file", 'uncased_L-12_H-768_A-12/vocab.txt',
37 | "The vocabulary file that the BERT model was trained on.")
38 |
39 | flags.DEFINE_string("init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt',
40 | "Initial checkpoint (usually from a pre-trained BERT model).")
41 |
42 | flags.DEFINE_bool("do_lower_case", True,
43 | "Whether to lower case the input text. Should be True for uncased "
44 | "models and False for cased models.")
45 |
46 | flags.DEFINE_integer("max_seq_length", 320,
47 | "The maximum total input sequence length after WordPiece tokenization. "
48 | "Sequences longer than this will be truncated, and sequences shorter "
49 | "than this will be padded.")
50 |
51 | flags.DEFINE_integer("max_utr_num", 7,
52 | "Maximum utterance number.")
53 |
54 | flags.DEFINE_bool("do_train", True,
55 | "Whether to run training.")
56 |
57 | flags.DEFINE_float("warmup_proportion", 0.1,
58 | "Proportion of training to perform linear learning rate warmup for. "
59 | "E.g., 0.1 = 10% of training.")
60 |
61 | flags.DEFINE_integer("train_batch_size", 12,
62 | "Total batch size for training.")
63 |
64 | flags.DEFINE_float("learning_rate", 2e-5,
65 | "The initial learning rate for Adam.")
66 |
67 | flags.DEFINE_integer("num_train_epochs", 5,
68 | "Total number of training epochs to perform.")
69 |
70 |
71 |
72 | def print_configuration_op(FLAGS):
73 | print('My Configurations:')
74 | for name, value in FLAGS.__flags.items():
75 | value=value.value
76 | if type(value) == float:
77 | print(' %s:\t %f'%(name, value))
78 | elif type(value) == int:
79 | print(' %s:\t %d'%(name, value))
80 | elif type(value) == str:
81 | print(' %s:\t %s'%(name, value))
82 | elif type(value) == bool:
83 | print(' %s:\t %s'%(name, value))
84 | else:
85 | print('%s:\t %s' % (name, value))
86 | print('End of configuration')
87 |
88 |
89 | def count_data_size(file_name):
90 | sample_nums = 0
91 | for record in tf.python_io.tf_record_iterator(file_name):
92 | sample_nums += 1
93 | return sample_nums
94 |
95 |
96 | def parse_exmp(serial_exmp):
97 | input_data = tf.parse_single_example(serial_exmp,
98 | features={
99 | "input_sents":
100 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
101 | "input_mask":
102 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
103 | "segment_ids":
104 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
105 | "speaker_ids":
106 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
107 | "cls_positions":
108 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
109 | "rsp_position":
110 | tf.FixedLenFeature([1], tf.int64),
111 | "label_ids":
112 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
113 | }
114 | )
115 | # tf.Example only supports int64 integer features, so cast all int64 tensors to int32.
116 | for name in list(input_data.keys()):
117 | t = input_data[name]
118 | if t.dtype == tf.int64:
119 | t = tf.to_int32(t)
120 | input_data[name] = t
121 |
122 | input_sents = input_data["input_sents"]
123 | input_mask = input_data["input_mask"]
124 | segment_ids= input_data["segment_ids"]
125 | speaker_ids= input_data["speaker_ids"]
126 | cls_positions= input_data["cls_positions"]
127 | rsp_position= input_data["rsp_position"]
128 | labels = input_data['label_ids']
129 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels
130 |
131 |
132 | def gather_indexes(sequence_tensor, positions):
133 | """Gathers the vectors at the specific positions over a minibatch."""
134 | # sequence_tensor = [batch_size, seq_length, width]
135 | # positions = [batch_size, max_utr_num]
136 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
137 | batch_size = sequence_shape[0]
138 | seq_length = sequence_shape[1]
139 | width = sequence_shape[2]
140 |
141 | flat_offsets = tf.reshape(
142 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1]
143 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ]
144 | flat_sequence_tensor = tf.reshape(sequence_tensor,
145 | [batch_size * seq_length, width])
146 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width]
147 | return output_tensor
148 |
149 |
150 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels,
151 | num_labels, use_one_hot_embeddings):
152 | """Creates a classification model."""
153 | model = modeling.BertModel(
154 | config=bert_config,
155 | is_training=is_training,
156 | input_ids=input_ids,
157 | input_mask=input_mask,
158 | token_type_ids=segment_ids,
159 | speaker_ids=speaker_ids,
160 | use_one_hot_embeddings=use_one_hot_embeddings)
161 |
162 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim]
163 |
164 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2)
165 | width = input_shape[-1]
166 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2)
167 | max_utr_num = positions_shape[-1]
168 |
169 | with tf.variable_scope("cls/speaker_restore"):
170 | # We apply one more non-linear transformation before the output layer.
171 | with tf.variable_scope("transform"):
172 | input_tensor = tf.layers.dense(
173 | input_tensor,
174 | units=bert_config.hidden_size,
175 | activation=modeling.get_activation(bert_config.hidden_act),
176 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range))
177 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim]
178 |
179 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim]
180 |
181 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim]
182 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim]
183 |
184 | output_weights = tf.get_variable(
185 | "output_weights",
186 | shape=[width, width],
187 | initializer=modeling.create_initializer(bert_config.initializer_range))
188 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights),
189 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num]
190 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num]
191 |
192 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
193 | logits = logits * mask + -1e9 * (1-mask)
194 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num]
195 |
196 | # loss
197 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num]
198 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ]
199 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss")
200 |
201 | # accuracy
202 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ]
203 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
204 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ]
205 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
206 |
207 | return mean_loss, logits, log_probs, accuracy
208 |
209 |
210 | def run_epoch(epoch, op_name, sess, training, logits, accuracy, mean_loss, train_opt):
211 |
212 | step = 0
213 | t0 = time()
214 |
215 | try:
216 | while True:
217 | step += 1
218 | batch_logits, batch_loss, _, batch_accuracy = sess.run([logits, mean_loss, train_opt, accuracy], feed_dict={training: True})
219 |
220 | if step % 1000 == 0:
221 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" %
222 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy))
223 |
224 | except tf.errors.OutOfRangeError:
225 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f, Loss: %.4f, Accuracy: %.2f" %
226 | (epoch, step, (time() - t0) / 60.0, batch_loss, 100 * batch_accuracy))
227 | pass
228 |
229 |
230 | best_score = 0.0
231 | def run_test(epoch, op_name, sess, training, prob, accuracy, saver, dir_path):
232 |
233 | step = 0
234 | t0 = time()
235 | num_test = 0
236 | num_correct = 0.0
237 | test_accuracy = 0
238 |
239 | try:
240 | while True:
241 | step += 1
242 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False})
243 |
244 | num_test += len(predicted_prob)
245 | num_correct += len(predicted_prob) * batch_accuracy
246 |
247 | if step % 100 == 0:
248 | tf.logging.info("Epoch: %i, Step: %d, Time (min): %.2f" % (epoch, step, (time() - t0)/60.0 ))
249 |
250 | except tf.errors.OutOfRangeError:
251 | test_accuracy = num_correct / num_test
252 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy))
253 |
254 | global best_score
255 | if op_name == 'valid' and test_accuracy > best_score:
256 | best_score = test_accuracy
257 | dir_path = os.path.join(dir_path, "epoch_{}".format(epoch))
258 | saver.save(sess, dir_path)
259 | tf.logging.info(">> Save model!")
260 |
261 | return test_accuracy
262 |
263 |
264 |
265 | def main(_):
266 | tf.logging.set_verbosity(tf.logging.INFO)
267 | print_configuration_op(FLAGS)
268 |
269 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
270 |
271 | root_path = FLAGS.output_dir
272 | if not os.path.exists(root_path):
273 | os.makedirs(root_path)
274 | timestamp = str(int(time()))
275 | root_path = os.path.join(root_path, timestamp)
276 | tf.logging.info('root_path: {}'.format(root_path))
277 | if not os.path.exists(root_path):
278 | os.makedirs(root_path)
279 |
280 | train_data_size = count_data_size(FLAGS.train_dir)
281 | tf.logging.info('train data size: {}'.format(train_data_size))
282 | valid_data_size = count_data_size(FLAGS.valid_dir)
283 | tf.logging.info('valid data size: {}'.format(valid_data_size))
284 |
285 | num_train_steps = train_data_size // FLAGS.train_batch_size * FLAGS.num_train_epochs
286 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
287 |
288 | filenames = tf.placeholder(tf.string, shape=[None])
289 | shuffle_size = tf.placeholder(tf.int64)
290 | dataset = tf.data.TFRecordDataset(filenames)
291 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
292 | dataset = dataset.repeat(1)
293 | # the shuffle buffer size is fed at run time through the shuffle_size placeholder
294 | dataset = dataset.shuffle(shuffle_size)
295 | dataset = dataset.batch(FLAGS.train_batch_size)
296 | iterator = dataset.make_initializable_iterator()
297 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels = iterator.get_next()
298 |
299 |
300 | training = tf.placeholder(tf.bool)
301 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config,
302 | is_training = training,
303 | input_ids = input_sents,
304 | input_mask = input_mask,
305 | segment_ids = segment_ids,
306 | speaker_ids = speaker_ids,
307 | cls_positions = cls_positions,
308 | rsp_position = rsp_position,
309 | labels = labels,
310 | num_labels = 1,
311 | use_one_hot_embeddings = False)
312 | # init model with pre-training
313 | tvars = tf.trainable_variables()
314 | initialized_variable_names = {}  # stays empty when FLAGS.init_checkpoint is unset
315 | if FLAGS.init_checkpoint:
316 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, FLAGS.init_checkpoint)
317 | tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)
318 |
319 | tf.logging.info("**** Trainable Variables ****")
320 | for var in tvars:
321 | init_string = ""
322 | if var.name in initialized_variable_names:
323 | init_string = ", *INIT_FROM_CKPT*"
324 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string)
325 |
326 | train_opt = optimization.create_optimizer(mean_loss, FLAGS.learning_rate, num_train_steps, num_warmup_steps, False)
327 |
328 | config = tf.ConfigProto(allow_soft_placement=True)
329 | config.gpu_options.allow_growth = True
330 | saver = tf.train.Saver()
331 |
332 | if FLAGS.do_train:
333 | with tf.Session(config=config) as sess:
334 | sess.run(tf.global_variables_initializer())
335 |
336 | for epoch in range(FLAGS.num_train_epochs):
337 | tf.logging.info('Train begin epoch {}'.format(epoch))
338 | sess.run(iterator.initializer,
339 | feed_dict={filenames: [FLAGS.train_dir], shuffle_size: 1024})
340 | run_epoch(epoch, "train", sess, training, logits, accuracy, mean_loss, train_opt)
341 |
342 | tf.logging.info('Valid begin')
343 | sess.run(iterator.initializer,
344 | feed_dict={filenames: [FLAGS.valid_dir], shuffle_size: 1})
345 | run_test(epoch, "valid", sess, training, log_probs, accuracy, saver, root_path)
346 |
347 |
348 | if __name__ == "__main__":
349 | tf.app.run()
350 |
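The speaker-identification head in `create_model` scores the response slot against every utterance slot with a bilinear form: the response vector, times `output_weights`, times the transpose of the utterance matrix. It then applies `tf.sequence_mask(rsp_position, max_utr_num)`, so only slots before the response remain eligible.

The sketch below re-expresses that arithmetic for a single example in plain Python, assuming exactly one gold slot per example. The vectors and weights are made-up toy values. `masked_bilinear_log_probs` and `nll_and_prediction` are illustrative names, not functions from this repository.

```python
import math


def masked_bilinear_log_probs(rsp, utrs, weight, rsp_position):
    """Log-probabilities over utterance slots for one example.

    rsp: response vector of length d (the slot picked out by rsp_position)
    utrs: max_utr_num utterance vectors, each of length d
    weight: d x d matrix playing the role of "output_weights"
    rsp_position: only slots j < rsp_position stay eligible (tf.sequence_mask)
    """
    d = len(rsp)
    # rsp^T W as a length-d row vector (the einsum 'aij,jk->aik' step).
    rw = [sum(rsp[i] * weight[i][k] for i in range(d)) for k in range(d)]
    logits = []
    for j, u in enumerate(utrs):
        score = sum(rw[k] * u[k] for k in range(d))  # (rsp^T W) u_j
        mask = 1.0 if j < rsp_position else 0.0
        logits.append(score * mask + -1e9 * (1.0 - mask))
    # Numerically stable log-softmax.
    top = max(logits)
    log_z = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_z for x in logits]


def nll_and_prediction(log_probs, gold_slot):
    """Negative log-likelihood of a single gold slot and the argmax prediction."""
    loss = -log_probs[gold_slot]
    pred = max(range(len(log_probs)), key=lambda j: log_probs[j])
    return loss, pred


identity = [[1.0, 0.0], [0.0, 1.0]]
utrs = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]  # slot 2 is the response itself
lp = masked_bilinear_log_probs(utrs[2], utrs, identity, rsp_position=2)
loss, pred = nll_and_prediction(lp, gold_slot=0)
print(pred, round(loss, 4))
```

Without the mask, slot 2 (the response matching itself) would tie with slot 0. The `-1e9 * (1 - mask)` term pushes its log-probability to roughly -1e9, so it can never win the argmax.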
--------------------------------------------------------------------------------
/run_testing_ar.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """MPC-BERT testing runner on the downstream task of addressee recognition."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker as modeling
16 |
17 | flags = tf.flags
18 | FLAGS = flags.FLAGS
19 |
20 | flags.DEFINE_string("task_name", 'Testing',
21 | "The name of the task.")
22 |
23 | flags.DEFINE_string("test_dir", 'test.tfrecord',
24 | "Path to the test data file in TFRecord format.")
25 |
26 | flags.DEFINE_string("restore_model_dir", 'output/',
27 | "The output directory where the model checkpoints have been written.")
28 |
29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
30 | "The config json file corresponding to the pre-trained BERT model. "
31 | "This specifies the model architecture.")
32 |
33 | flags.DEFINE_bool("do_eval", True,
34 | "Whether to run evaluation on the test set.")
35 |
36 | flags.DEFINE_integer("eval_batch_size", 32,
37 | "Total batch size for evaluation.")
38 |
39 | flags.DEFINE_integer("max_seq_length", 320,
40 | "The maximum total input sequence length after WordPiece tokenization. "
41 | "Sequences longer than this will be truncated, and sequences shorter "
42 | "than this will be padded.")
43 |
44 | flags.DEFINE_integer("max_utr_num", 7,
45 | "Maximum utterance number.")
46 |
47 |
48 | def print_configuration_op(FLAGS):
49 | print('My Configurations:')
50 | for name, value in FLAGS.__flags.items():
51 | value=value.value
52 | if type(value) == float:
53 | print(' %s:\t %f'%(name, value))
54 | elif type(value) == int:
55 | print(' %s:\t %d'%(name, value))
56 | elif type(value) == str:
57 | print(' %s:\t %s'%(name, value))
58 | elif type(value) == bool:
59 | print(' %s:\t %s'%(name, value))
60 | else:
61 | print('%s:\t %s' % (name, value))
62 | print('End of configuration')
63 |
64 |
65 | def count_data_size(file_name):
66 | sample_nums = 0
67 | for record in tf.python_io.tf_record_iterator(file_name):
68 | sample_nums += 1
69 | return sample_nums
70 |
71 |
72 | def parse_exmp(serial_exmp):
73 | input_data = tf.parse_single_example(serial_exmp,
74 | features={
75 | "input_sents":
76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
77 | "input_mask":
78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
79 | "segment_ids":
80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
81 | "speaker_ids":
82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
83 | "cls_positions":
84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
85 | "label_ids":
86 | tf.FixedLenFeature([FLAGS.max_utr_num*FLAGS.max_utr_num], tf.int64),
87 | "label_weights":
88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.float32),
89 | }
90 | )
91 | # tf.Example only supports int64 integer features, so cast all int64 tensors to int32.
92 | for name in list(input_data.keys()):
93 | t = input_data[name]
94 | if t.dtype == tf.int64:
95 | t = tf.to_int32(t)
96 | input_data[name] = t
97 |
98 | input_sents = input_data["input_sents"]
99 | input_mask = input_data["input_mask"]
100 | segment_ids= input_data["segment_ids"]
101 | speaker_ids= input_data["speaker_ids"]
102 | cls_positions= input_data["cls_positions"]
103 | labels = input_data['label_ids']
104 | label_weights = input_data['label_weights']
105 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, labels, label_weights
106 |
107 |
108 | def gather_indexes(sequence_tensor, positions):
109 | """Gathers the vectors at the specific positions over a minibatch."""
110 | # sequence_tensor = [batch_size, seq_length, width]
111 | # positions = [batch_size, max_utr_num]
112 |
113 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
114 | batch_size = sequence_shape[0]
115 | seq_length = sequence_shape[1]
116 | width = sequence_shape[2]
117 |
118 | flat_offsets = tf.reshape(
119 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1]
120 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ]
121 | flat_sequence_tensor = tf.reshape(sequence_tensor,
122 | [batch_size * seq_length, width])
123 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width]
124 | return output_tensor
125 |
126 |
127 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, labels, label_weights,
128 | use_one_hot_embeddings):
129 | """Creates a classification model."""
130 | model = modeling.BertModel(
131 | config=bert_config,
132 | is_training=is_training,
133 | input_ids=input_ids,
134 | input_mask=input_mask,
135 | token_type_ids=segment_ids,
136 | speaker_ids=speaker_ids,
137 | use_one_hot_embeddings=use_one_hot_embeddings)
138 |
139 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2)
140 | max_utr_num = positions_shape[-1]
141 |
142 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim]
143 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2)
144 | width = input_shape[-1]
145 |
146 | with tf.variable_scope("cls/addressee_recognize"):
147 | # We apply one more non-linear transformation before the output layer.
148 | with tf.variable_scope("transform"):
149 | input_tensor = tf.layers.dense(
150 | input_tensor,
151 | units=bert_config.hidden_size,
152 | activation=modeling.get_activation(bert_config.hidden_act),
153 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range))
154 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim]
155 |
156 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim]
157 | output_weights = tf.get_variable(
158 | "output_weights",
159 | shape=[width, width],
160 | initializer=modeling.create_initializer(bert_config.initializer_range))
161 | logits = tf.matmul(tf.einsum('aij,jk->aik', input_tensor, output_weights),
162 | input_tensor, transpose_b=True) # [batch_size, max_utr_num, max_utr_num]
163 |
164 | # e.g. for max_utr_num = 5: mask = [[0. 0. 0. 0. 0.]
165 | # [1. 0. 0. 0. 0.]
166 | # [1. 1. 0. 0. 0.]
167 | # [1. 1. 1. 0. 0.]
168 | # [1. 1. 1. 1. 0.]]
169 | mask = tf.matrix_band_part(tf.ones((max_utr_num, max_utr_num)), -1, 0) - tf.matrix_band_part(tf.ones((max_utr_num, max_utr_num)), 0, 0)
170 | logits = logits * mask + -1e9 * (1-mask) # [batch_size, max_utr_num, max_utr_num]
171 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num, max_utr_num]
172 | one_hot_labels = tf.reshape(labels, [-1, max_utr_num, max_utr_num])
173 | one_hot_labels = tf.cast(one_hot_labels, "float")
174 |
175 | # loss
176 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, max_utr_num]
177 | numerator = tf.reduce_sum(label_weights * per_example_loss) # [1, ]
178 | denominator = tf.reduce_sum(label_weights) + 1e-5 # [1, ]
179 | mean_loss = numerator / denominator
180 |
181 | # accuracy
182 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, max_utr_num]
183 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num, max_utr_num]
184 | correct_prediction_utr = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, max_utr_num]
185 | numerator_utr = tf.reduce_sum(label_weights * correct_prediction_utr) # [1, ]
186 | accuracy_utr = numerator_utr / denominator
187 |
188 | correct_prediction_sess = tf.equal(tf.reduce_sum(label_weights * correct_prediction_utr, -1),
189 | tf.reduce_sum(label_weights, -1)) # [batch_size, ]
190 | accuracy_sess = tf.reduce_mean(tf.cast(correct_prediction_sess, "float"), name="accuracy")
191 |
192 | return mean_loss, logits, log_probs, accuracy_utr, accuracy_sess, tf.reduce_sum(label_weights)
193 |
194 |
195 | def run_test(sess, training, prob, accuracy_utr, accuracy_sess, num_utr):
196 |
197 | step = 0
198 | t0 = time()
199 | num_test_utr = 0
200 | num_correct_utr = 0.0
201 | test_accuracy_utr = 0
202 | num_test_sess = 0
203 | num_correct_sess = 0.0
204 | test_accuracy_sess = 0
205 |
206 | try:
207 | while True:
208 | step += 1
209 | batch_accuracy_utr, batch_accuracy_sess, batch_predicted_prob, batch_num_utr = sess.run(
210 | [accuracy_utr, accuracy_sess, prob, num_utr], feed_dict={training: False})
211 |
212 | num_test_utr += int(batch_num_utr)
213 | num_correct_utr += batch_num_utr * batch_accuracy_utr
214 |
215 | num_test_sess += len(batch_predicted_prob)
216 | num_correct_sess += len(batch_predicted_prob) * batch_accuracy_sess
217 |
218 | if step % 100 == 0:
219 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0))
220 |
221 | except tf.errors.OutOfRangeError:
222 | test_accuracy_utr = num_correct_utr / num_test_utr
223 | print('num_test_utterance: {}, test_accuracy_utr: {}'.format(num_test_utr, test_accuracy_utr))
224 | test_accuracy_sess = num_correct_sess / num_test_sess
225 | print('num_test_session: {}, test_accuracy_sess: {}'.format(num_test_sess, test_accuracy_sess))
226 |
227 | return test_accuracy_sess
228 |
229 |
230 | def main(_):
231 | tf.logging.set_verbosity(tf.logging.INFO)
232 | print_configuration_op(FLAGS)
233 |
234 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
235 |
236 | test_data_size = count_data_size(FLAGS.test_dir)
237 | tf.logging.info('test data size: {}'.format(test_data_size))
238 |
239 | filenames = tf.placeholder(tf.string, shape=[None])
240 | shuffle_size = tf.placeholder(tf.int64)
241 | dataset = tf.data.TFRecordDataset(filenames)
242 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
243 | dataset = dataset.repeat(1)
244 | # dataset = dataset.shuffle(shuffle_size)
245 | dataset = dataset.batch(FLAGS.eval_batch_size)
246 | iterator = dataset.make_initializable_iterator()
247 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, labels, label_weights = iterator.get_next()
248 |
249 | training = tf.placeholder(tf.bool)
250 | mean_loss, logits, log_probs, accuracy_utr, accuracy_sess, num_utr = create_model(bert_config = bert_config,
251 | is_training = training,
252 | input_ids = input_sents,
253 | input_mask = input_mask,
254 | segment_ids = segment_ids,
255 | speaker_ids = speaker_ids,
256 | cls_positions = cls_positions,
257 | labels = labels,
258 | label_weights = label_weights,
259 | use_one_hot_embeddings = False)
260 |
261 | config = tf.ConfigProto(allow_soft_placement=True)
262 | config.gpu_options.allow_growth = True
263 |
264 | if FLAGS.do_eval:
265 | with tf.Session(config=config) as sess:
266 | tf.logging.info("*** Restore model ***")
267 |
268 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
269 | variables = tf.trainable_variables()
270 | saver = tf.train.Saver(variables)
271 | saver.restore(sess, ckpt.model_checkpoint_path)
272 |
273 | tf.logging.info('Test begin')
274 | sess.run(iterator.initializer,
275 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
276 | run_test(sess, training, log_probs, accuracy_utr, accuracy_sess, num_utr)
277 |
278 |
279 | if __name__ == "__main__":
280 | tf.app.run()
281 |
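In `run_testing_ar.py`, addressee recognition is scored for every utterance at once. The mask `matrix_band_part(ones, -1, 0) - matrix_band_part(ones, 0, 0)` keeps the strictly lower triangle, so utterance i may only pick an earlier utterance j < i. `label_weights` decides which utterances are counted.

The plain-Python sketch below mirrors the mask and the two metrics the runner prints. It works on predicted and gold addressee indices rather than one-hot tensors, and it omits the `+ 1e-5` guard in the denominator. The function names and toy inputs are hypothetical.

```python
def strictly_lower_mask(n):
    """mask[i][j] == 1.0 iff j < i, i.e. utterance i may only address an earlier j.

    Same result as band_part(ones, -1, 0) (lower triangle with diagonal)
    minus band_part(ones, 0, 0) (diagonal only).
    """
    lower = [[1.0 if j <= i else 0.0 for j in range(n)] for i in range(n)]
    diag = [[1.0 if j == i else 0.0 for j in range(n)] for i in range(n)]
    return [[lower[i][j] - diag[i][j] for j in range(n)] for i in range(n)]


def addressee_accuracies(predictions, gold, weights):
    """Utterance-level and session-level accuracy.

    predictions, gold: per session, one addressee index per utterance
    weights: per session, one 0/1 label weight per utterance (0 = not scored)
    """
    correct_utr, total_utr, correct_sess = 0, 0, 0
    for pred, ref, w in zip(predictions, gold, weights):
        hits = sum(wi for p, r, wi in zip(pred, ref, w) if p == r)
        correct_utr += hits
        total_utr += sum(w)
        # A session counts only if every weighted utterance in it is right.
        correct_sess += int(hits == sum(w))
    return correct_utr / total_utr, correct_sess / len(predictions)


for row in strictly_lower_mask(3):
    print(row)
utr_acc, sess_acc = addressee_accuracies(
    predictions=[[0, 0, 1], [0, 0, 0]],
    gold=[[0, 0, 1], [0, 0, 1]],
    weights=[[0, 1, 1], [0, 1, 1]],
)
print(utr_acc, sess_acc)
```

`run_test` accumulates per-batch results rather than computing these figures directly:

- For utterance accuracy, it multiplies each batch accuracy by the batch's weighted utterance count before dividing by the totals. This matches the direct computation above up to the `1e-5` guard.
- For session accuracy, it multiplies each batch accuracy by the batch size before dividing by the totals. This matches the direct computation above exactly.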
--------------------------------------------------------------------------------
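The next runner, `run_testing_ar_gift.py`, differs mainly in `parse_exmp`. There, the stored `reply_mask_utr2word` matrix is expanded into a word-to-word `reply_mask`: row i is tiled once per token of utterance i (`utr_lens[i]`), and all-zero rows pad the result to `max_seq_length`.

The sketch below reproduces only that reshaping in plain Python. What a 1 entry means (which reply relation it encodes) is defined by `create_finetuning_data_ar_gift.py` and `modeling_speaker_gift.py`, which are not shown here. `build_reply_mask` and the toy rows are hypothetical.

```python
def build_reply_mask(reply_mask_utr2word, utr_lens, max_seq_length):
    """Expand an utterance-to-word mask into a max_seq_length x max_seq_length mask.

    reply_mask_utr2word: max_utr_num rows, each of length max_seq_length
    utr_lens: token count of each utterance (padding utterances have length 0)
    """
    pad_rows = max_seq_length - sum(utr_lens)
    if pad_rows < 0:
        # The TensorFlow graph would fail here on a negative-sized tf.zeros shape.
        raise ValueError("utterance lengths exceed max_seq_length")
    rows = []
    for row, n in zip(reply_mask_utr2word, utr_lens):
        # tf.tile(tf.expand_dims(row, 0), [n, 1]): one copy of row i per token of utterance i.
        rows.extend(list(row) for _ in range(n))
    # All-zero rows for the padded tail of the sequence.
    rows.extend([0] * max_seq_length for _ in range(pad_rows))
    return rows


utr2word = [
    [1, 1, 0, 0, 0, 0],  # hypothetical row for utterance 0
    [1, 1, 1, 1, 0, 0],  # hypothetical row for utterance 1
    [0, 0, 0, 0, 0, 0],  # padding utterance
]
mask = build_reply_mask(utr2word, utr_lens=[2, 2, 0], max_seq_length=6)
for row in mask:
    print(row)
```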
/run_testing_ar_gift.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """MPC-BERT testing runner on the downstream task of addressee recognition."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker_gift as modeling
16 |
17 | flags = tf.flags
18 | FLAGS = flags.FLAGS
19 |
20 | flags.DEFINE_string("task_name", 'Testing',
21 | "The name of the task.")
22 |
23 | flags.DEFINE_string("test_dir", 'test.tfrecord',
24 | "Path to the test data file in TFRecord format.")
25 |
26 | flags.DEFINE_string("restore_model_dir", 'output/',
27 | "The output directory where the model checkpoints have been written.")
28 |
29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
30 | "The config json file corresponding to the pre-trained BERT model. "
31 | "This specifies the model architecture.")
32 |
33 | flags.DEFINE_bool("do_eval", True,
34 | "Whether to run evaluation on the test set.")
35 |
36 | flags.DEFINE_integer("eval_batch_size", 32,
37 | "Total batch size for evaluation.")
38 |
39 | flags.DEFINE_integer("max_seq_length", 320,
40 | "The maximum total input sequence length after WordPiece tokenization. "
41 | "Sequences longer than this will be truncated, and sequences shorter "
42 | "than this will be padded.")
43 |
44 | flags.DEFINE_integer("max_utr_num", 7,
45 | "Maximum utterance number.")
46 |
47 |
48 | def print_configuration_op(FLAGS):
49 | print('My Configurations:')
50 | for name, value in FLAGS.__flags.items():
51 | value=value.value
52 | if type(value) == float:
53 | print(' %s:\t %f'%(name, value))
54 | elif type(value) == int:
55 | print(' %s:\t %d'%(name, value))
56 | elif type(value) == str:
57 | print(' %s:\t %s'%(name, value))
58 | elif type(value) == bool:
59 | print(' %s:\t %s'%(name, value))
60 | else:
61 | print('%s:\t %s' % (name, value))
62 | print('End of configuration')
63 |
64 |
65 | def count_data_size(file_name):
66 | sample_nums = 0
67 | for record in tf.python_io.tf_record_iterator(file_name):
68 | sample_nums += 1
69 | return sample_nums
70 |
71 |
72 | def parse_exmp(serial_exmp):
73 | input_data = tf.parse_single_example(serial_exmp,
74 | features={
75 | "input_sents":
76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
77 | "input_mask":
78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
79 | "segment_ids":
80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
81 | "speaker_ids":
82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
83 | "cls_positions":
84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
85 | "rsp_position":
86 | tf.FixedLenFeature([1], tf.int64),
87 | "label_ids":
88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
89 | "reply_mask_utr2word_flatten":
90 | tf.FixedLenFeature([FLAGS.max_utr_num * FLAGS.max_seq_length], tf.int64),
91 | "utr_lens":
92 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
93 | }
94 | )
95 | # tf.Example only supports int64 integer features, so cast all int64 tensors to int32.
96 | for name in list(input_data.keys()):
97 | t = input_data[name]
98 | if t.dtype == tf.int64:
99 | t = tf.to_int32(t)
100 | input_data[name] = t
101 |
102 | input_sents = input_data["input_sents"]
103 | input_mask = input_data["input_mask"]
104 | segment_ids= input_data["segment_ids"]
105 | speaker_ids= input_data["speaker_ids"]
106 | cls_positions= input_data["cls_positions"]
107 | rsp_position= input_data["rsp_position"]
108 | reply_mask_utr2word = tf.reshape(input_data['reply_mask_utr2word_flatten'], [FLAGS.max_utr_num, FLAGS.max_seq_length])
109 | utr_lens = input_data['utr_lens']
110 | labels = input_data['label_ids']
111 |
112 | reply_mask_word2word = []
113 | for i in range(FLAGS.max_utr_num):
114 | reply_mask_utr2word_i = reply_mask_utr2word[i] # [max_seq_length, ]
115 | utr_len_i = utr_lens[i] # [1, ]
116 | reply_mask_utr2word_i_tiled = tf.tile(tf.expand_dims(reply_mask_utr2word_i, 0), [utr_len_i, 1]) # [utr_len, max_seq_length]
117 | reply_mask_word2word.append(reply_mask_utr2word_i_tiled)
118 |
119 | reply_mask_pad = tf.zeros(shape=[FLAGS.max_seq_length - tf.reduce_sum(utr_lens), FLAGS.max_seq_length], dtype=tf.int32)
120 | reply_mask_word2word.append(reply_mask_pad)
121 | reply_mask = tf.concat(reply_mask_word2word, 0) # [max_seq_length, max_seq_length]
122 | print("* Loading reply mask ...")
123 |
124 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask
125 |
126 |
127 | def gather_indexes(sequence_tensor, positions):
128 | """Gathers the vectors at the specific positions over a minibatch."""
129 | # sequence_tensor = [batch_size, seq_length, width]
130 | # positions = [batch_size, max_utr_num]
131 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
132 | batch_size = sequence_shape[0]
133 | seq_length = sequence_shape[1]
134 | width = sequence_shape[2]
135 |
136 | flat_offsets = tf.reshape(
137 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1]
138 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ]
139 | flat_sequence_tensor = tf.reshape(sequence_tensor,
140 | [batch_size * seq_length, width])
141 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width]
142 | return output_tensor
143 |
144 |
145 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask,
146 | num_labels, use_one_hot_embeddings):
147 | """Creates a classification model."""
148 | model = modeling.BertModel(
149 | config=bert_config,
150 | is_training=is_training,
151 | input_ids=input_ids,
152 | input_mask=input_mask,
153 | token_type_ids=segment_ids,
154 | speaker_ids=speaker_ids,
155 | reply_mask=reply_mask,
156 | use_one_hot_embeddings=use_one_hot_embeddings)
157 |
158 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim]
159 |
160 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2)
161 | width = input_shape[-1]
162 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2)
163 | max_utr_num = positions_shape[-1]
164 |
165 | with tf.variable_scope("cls/addressee_recognize"):
166 | # We apply one more non-linear transformation before the output layer.
167 | with tf.variable_scope("transform"):
168 | input_tensor = tf.layers.dense(
169 | input_tensor,
170 | units=bert_config.hidden_size,
171 | activation=modeling.get_activation(bert_config.hidden_act),
172 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range))
173 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim]
174 |
175 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim]
176 |
177 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim]
178 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim]
179 |
180 | output_weights = tf.get_variable(
181 | "output_weights",
182 | shape=[width, width],
183 | initializer=modeling.create_initializer(bert_config.initializer_range))
184 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights),
185 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num]
186 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num]
187 |
188 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
189 | logits = logits * mask + -1e9 * (1-mask)
190 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num]
191 |
192 | # loss
193 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num]
194 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ]
195 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss")
196 |
197 | # accuracy
198 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ]
199 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
200 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ]
201 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
202 |
203 | return mean_loss, logits, log_probs, accuracy
204 |
205 |
206 | def run_test(sess, training, prob, accuracy):
207 |
208 | step = 0
209 | t0 = time()
210 | num_test = 0
211 | num_correct = 0.0
212 | test_accuracy = 0
213 |
214 | try:
215 | while True:
216 | step += 1
217 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False})
218 |
219 | num_test += len(predicted_prob)
220 | num_correct += len(predicted_prob) * batch_accuracy
221 |
222 | if step % 100 == 0:
223 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0))
224 |
225 | except tf.errors.OutOfRangeError:
226 | test_accuracy = num_correct / num_test
227 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy))
228 |
229 | return test_accuracy
230 |
231 |
232 | def main(_):
233 | tf.logging.set_verbosity(tf.logging.INFO)
234 | print_configuration_op(FLAGS)
235 |
236 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
237 |
238 | test_data_size = count_data_size(FLAGS.test_dir)
239 | tf.logging.info('test data size: {}'.format(test_data_size))
240 |
241 | filenames = tf.placeholder(tf.string, shape=[None])
242 | shuffle_size = tf.placeholder(tf.int64)
243 | dataset = tf.data.TFRecordDataset(filenames)
244 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
245 | dataset = dataset.repeat(1)
246 | # dataset = dataset.shuffle(shuffle_size)
247 | dataset = dataset.batch(FLAGS.eval_batch_size)
248 | iterator = dataset.make_initializable_iterator()
249 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask = iterator.get_next()
250 |
251 | training = tf.placeholder(tf.bool)
252 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config,
253 | is_training = training,
254 | input_ids = input_sents,
255 | input_mask = input_mask,
256 | segment_ids = segment_ids,
257 | speaker_ids = speaker_ids,
258 | cls_positions = cls_positions,
259 | rsp_position = rsp_position,
260 | reply_mask = reply_mask,
261 | labels = labels,
262 | num_labels = 1,
263 | use_one_hot_embeddings = False)
264 |
265 | config = tf.ConfigProto(allow_soft_placement=True)
266 | config.gpu_options.allow_growth = True
267 |
268 | if FLAGS.do_eval:
269 | with tf.Session(config=config) as sess:
270 | tf.logging.info("*** Restore model ***")
271 |
272 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
273 | variables = tf.trainable_variables()
274 | saver = tf.train.Saver(variables)
275 | saver.restore(sess, ckpt.model_checkpoint_path)
276 |
277 | tf.logging.info('Test begin')
278 | sess.run(iterator.initializer,
279 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
280 | run_test(sess, training, log_probs, accuracy)
281 |
282 |
283 | if __name__ == "__main__":
284 | tf.app.run()
285 |
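The addressee recognition head above does three things. It gathers each utterance's [CLS] vector using flat offsets. It then scores every utterance against the response with a bilinear form. Finally, it masks positions at or after the response before the log-softmax.

The NumPy sketch below mirrors that computation on toy shapes so the tensor bookkeeping can be checked by hand. It is illustrative only:

- It omits the dense + layer-norm transform.
- The random inputs, the sizes, and the helper name `addressee_log_probs` are assumptions, not part of the repository.

```python
import numpy as np


def gather_indexes(sequence, positions):
    """NumPy analogue of gather_indexes: [B, T, D] rows at positions [B, K] -> [B * K, D]."""
    batch_size, seq_length, width = sequence.shape
    # Shift each example's positions into a flattened [B * T, D] view.
    flat_offsets = (np.arange(batch_size) * seq_length).reshape(-1, 1)  # [B, 1]
    flat_positions = (positions + flat_offsets).reshape(-1)  # [B * K]
    flat_sequence = sequence.reshape(batch_size * seq_length, width)
    return flat_sequence[flat_positions]  # [B * K, D]


def addressee_log_probs(utr_vectors, rsp_position, weights):
    """Hypothetical helper: bilinear score of each utterance against the response.

    utr_vectors: [B, U, D], rsp_position: [B, 1], weights: [D, D].
    """
    batch_size, max_utr_num, width = utr_vectors.shape
    rsp = gather_indexes(utr_vectors, rsp_position).reshape(batch_size, 1, width)
    logits = (rsp @ weights @ utr_vectors.transpose(0, 2, 1))[:, 0, :]  # [B, U]
    # Same semantics as tf.sequence_mask: only utterances before the response are candidates.
    mask = (np.arange(max_utr_num)[None, :] < rsp_position).astype(logits.dtype)
    logits = logits * mask + -1e9 * (1.0 - mask)
    # Numerically stable log-softmax over the utterance axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


rng = np.random.default_rng(0)
sequence_output = rng.normal(size=(2, 10, 4))  # [batch, seq_len, hidden]
cls_positions = np.array([[0, 3, 6], [0, 2, 5]])  # [CLS] token index of each utterance
rsp_position = np.array([[2], [1]])  # index of the response utterance
utr_vectors = gather_indexes(sequence_output, cls_positions).reshape(2, 3, 4)
log_probs = addressee_log_probs(utr_vectors, rsp_position, rng.normal(size=(4, 4)))
predictions = log_probs.argmax(axis=-1)
```

In the second example, the response is utterance 1, so utterance 0 is the only unmasked candidate and must be predicted.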
--------------------------------------------------------------------------------
/run_testing_rs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """MPC-BERT testing runner on the downstream task of response selection."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker as modeling
16 | import metrics
17 |
18 | flags = tf.flags
19 | FLAGS = flags.FLAGS
20 |
21 | flags.DEFINE_string("task_name", 'Testing',
22 | "The name of the task.")
23 |
24 | flags.DEFINE_string("test_dir", 'test.tfrecord',
25 | "Path to the input test data file in TFRecord format.")
26 |
27 | flags.DEFINE_string("restore_model_dir", 'output/',
28 | "The output directory where the model checkpoints have been written.")
29 |
30 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
31 | "The config json file corresponding to the pre-trained BERT model. "
32 | "This specifies the model architecture.")
33 |
34 | flags.DEFINE_bool("do_eval", True,
35 | "Whether to run evaluation on the test set.")
36 |
37 | flags.DEFINE_integer("eval_batch_size", 32,
38 | "Total batch size for evaluation.")
39 |
40 | flags.DEFINE_integer("max_seq_length", 320,
41 | "The maximum total input sequence length after WordPiece tokenization. "
42 | "Sequences longer than this will be truncated, and sequences shorter "
43 | "than this will be padded.")
44 |
45 |
46 | def print_configuration_op(FLAGS):
47 | print('My Configurations:')
48 | for name, value in FLAGS.__flags.items():
49 | value=value.value
50 | if type(value) == float:
51 | print(' %s:\t %f'%(name, value))
52 | elif type(value) == int:
53 | print(' %s:\t %d'%(name, value))
54 | elif type(value) == str:
55 | print(' %s:\t %s'%(name, value))
56 | elif type(value) == bool:
57 | print(' %s:\t %s'%(name, value))
58 | else:
59 | print('%s:\t %s' % (name, value))
60 | print('End of configuration')
61 |
62 |
63 | def count_data_size(file_name):
64 | sample_nums = 0
65 | for record in tf.python_io.tf_record_iterator(file_name):
66 | sample_nums += 1
67 | return sample_nums
68 |
69 |
70 | def parse_exmp(serial_exmp):
71 | input_data = tf.parse_single_example(serial_exmp,
72 | features={
73 | "ctx_id":
74 | tf.FixedLenFeature([], tf.int64),
75 | "rsp_id":
76 | tf.FixedLenFeature([], tf.int64),
77 | "input_sents":
78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
79 | "input_mask":
80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
81 | "segment_ids":
82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
83 | "speaker_ids":
84 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
85 | "label_ids":
86 | tf.FixedLenFeature([], tf.float32),
87 | }
88 | )
89 | # tf.Example only supports tf.int64, so cast all int64 features to int32.
90 | for name in list(input_data.keys()):
91 | t = input_data[name]
92 | if t.dtype == tf.int64:
93 | t = tf.to_int32(t)
94 | input_data[name] = t
95 |
96 | ctx_id = input_data["ctx_id"]
97 | rsp_id = input_data['rsp_id']
98 | input_sents = input_data["input_sents"]
99 | input_mask = input_data["input_mask"]
100 | segment_ids= input_data["segment_ids"]
101 | speaker_ids= input_data["speaker_ids"]
102 | labels = input_data['label_ids']
103 | return ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels
104 |
105 |
106 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, labels, ctx_id, rsp_id,
107 | num_labels, use_one_hot_embeddings):
108 | """Creates a classification model."""
109 | model = modeling.BertModel(
110 | config=bert_config,
111 | is_training=is_training,
112 | input_ids=input_ids,
113 | input_mask=input_mask,
114 | token_type_ids=segment_ids,
115 | speaker_ids=speaker_ids,
116 | use_one_hot_embeddings=use_one_hot_embeddings)
117 |
118 | target_loss_weight = [1.0, 1.0]
119 | target_loss_weight = tf.convert_to_tensor(target_loss_weight)
120 |
121 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32)
122 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32)
123 |
124 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy
125 |
126 | output_layer = model.get_pooled_output()
127 |
128 | hidden_size = output_layer.shape[-1].value
129 |
130 | output_weights = tf.get_variable(
131 | "output_weights", [num_labels, hidden_size],
132 | initializer=tf.truncated_normal_initializer(stddev=0.02))
133 |
134 | output_bias = tf.get_variable(
135 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
136 |
137 | with tf.variable_scope("loss"):
138 | # if is_training:
139 | # output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
140 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training)
141 |
142 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
143 | logits = tf.nn.bias_add(logits, output_bias)
144 |
145 | probabilities = tf.sigmoid(logits, name="prob")
146 | logits = tf.squeeze(logits,[1])
147 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
148 | losses = tf.multiply(losses, all_target_loss)
149 |
150 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
151 |
152 | with tf.name_scope("accuracy"):
153 | correct_prediction = tf.equal(tf.sign(tf.squeeze(probabilities, [1]) - 0.5), tf.sign(labels - 0.5)) # squeeze [batch_size, 1] -> [batch_size, ] so it is not broadcast against labels
154 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
155 |
156 | return mean_loss, logits, probabilities, accuracy
157 |
158 |
159 | best_score = 0.0
160 | def run_test(dir_path, op_name, sess, training, accuracy, prob, pair_ids):
161 |
162 | step = 0
163 | t0 = time()
164 | num_test = 0
165 | num_correct = 0.0
166 | mrr = 0
167 | results = defaultdict(list)
168 |
169 | try:
170 | while True:
171 | step += 1
172 | batch_accuracy, predicted_prob, batch_pair_ids = sess.run([accuracy, prob, pair_ids], feed_dict={training: False})
173 | question_id, answer_id, label = batch_pair_ids
174 |
175 | num_test += len(predicted_prob)
176 | num_correct += len(predicted_prob) * batch_accuracy
177 | for i, prob_score in enumerate(predicted_prob):
178 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0]))
179 |
180 | if step % 100 == 0:
181 | tf.logging.info("Step %d, %s, Time (min): %.2f" %
182 | (step, op_name, (time() - t0) / 60.0))
183 |
184 | except tf.errors.OutOfRangeError:
185 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test))
186 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
187 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
188 |
189 | mvp = metrics.mean_average_precision(results)
190 | mrr = metrics.mean_reciprocal_rank(results)
191 | top_1_precision = metrics.top_1_precision(results)
192 | total_valid_query = metrics.get_num_valid_query(results)
193 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(
194 | mvp, mrr, top_1_precision, total_valid_query))
195 |
196 | out_path = os.path.join(dir_path, "output_test.txt")
197 | print("Saving evaluation to {}".format(out_path))
198 | with open(out_path, 'w') as f:
199 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
200 | for us_id, v in results.items():
201 | v.sort(key=operator.itemgetter(2), reverse=True)
202 | for i, rec in enumerate(v):
203 | r_id, label, prob_score = rec
204 | rank = i+1
205 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
206 | return mrr
207 |
208 |
209 | def main(_):
210 | tf.logging.set_verbosity(tf.logging.INFO)
211 | print_configuration_op(FLAGS)
212 |
213 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
214 |
215 | test_data_size = count_data_size(FLAGS.test_dir)
216 | tf.logging.info('test data size: {}'.format(test_data_size))
217 |
218 | filenames = tf.placeholder(tf.string, shape=[None])
219 | shuffle_size = tf.placeholder(tf.int64)
220 | dataset = tf.data.TFRecordDataset(filenames)
221 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
222 | dataset = dataset.repeat(1)
223 | # dataset = dataset.shuffle(shuffle_size)
224 | dataset = dataset.batch(FLAGS.eval_batch_size)
225 | iterator = dataset.make_initializable_iterator()
226 | ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, labels = iterator.get_next()
227 | pair_ids = [ctx_id, rsp_id, labels]
228 |
229 | training = tf.placeholder(tf.bool)
230 | mean_loss, logits, probabilities, accuracy = create_model(bert_config = bert_config,
231 | is_training = training,
232 | input_ids = input_sents,
233 | input_mask = input_mask,
234 | segment_ids = segment_ids,
235 | speaker_ids = speaker_ids,
236 | labels = labels,
237 | ctx_id = ctx_id,
238 | rsp_id = rsp_id,
239 | num_labels = 1,
240 | use_one_hot_embeddings = False)
241 |
242 |
243 | config = tf.ConfigProto(allow_soft_placement=True)
244 | config.gpu_options.allow_growth = True
245 |
246 | if FLAGS.do_eval:
247 | with tf.Session(config=config) as sess:
248 | tf.logging.info("*** Restore model ***")
249 |
250 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
251 | variables = tf.trainable_variables()
252 | saver = tf.train.Saver(variables)
253 | saver.restore(sess, ckpt.model_checkpoint_path)
254 |
255 | tf.logging.info('Test begin')
256 | sess.run(iterator.initializer,
257 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
258 | run_test(FLAGS.restore_model_dir, "test", sess, training, accuracy, probabilities, pair_ids)
259 |
260 |
261 | if __name__ == "__main__":
262 | tf.app.run()
263 |
264 |
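`run_test` groups each (response id, label, score) triple under its context id. It then passes the dictionary to `metrics.py`, which is not shown in this section.

The following is a hypothetical pure-Python sketch of two of those ranking metrics. It assumes that MRR uses the rank of the first positive response. It also assumes that contexts without any positive response are skipped. The repository's `metrics.mean_reciprocal_rank` and `metrics.top_1_precision` may differ in such details.

```python
from collections import defaultdict


def mean_reciprocal_rank(results):
    """results: ctx_id -> list of (rsp_id, label, score), as built in run_test."""
    reciprocal_ranks = []
    for candidates in results.values():
        # Same ordering as the output file: highest score first.
        ranked = sorted(candidates, key=lambda rec: rec[2], reverse=True)
        for rank, (_, label, _) in enumerate(ranked, start=1):
            if label > 0:
                reciprocal_ranks.append(1.0 / rank)
                break  # only the first positive response counts
    return sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0.0


def top_1_precision(results):
    hits, valid = 0, 0
    for candidates in results.values():
        if not any(label > 0 for _, label, _ in candidates):
            continue  # assumption: contexts with no positive response are skipped
        valid += 1
        best = max(candidates, key=lambda rec: rec[2])
        hits += int(best[1] > 0)
    return hits / valid if valid else 0.0


# Toy scores: (ctx_id, rsp_id, label, score)
results = defaultdict(list)
for ctx_id, rsp_id, label, score in [(0, 10, 1.0, 0.9), (0, 11, 0.0, 0.4),
                                     (1, 20, 0.0, 0.8), (1, 21, 1.0, 0.3)]:
    results[ctx_id].append((rsp_id, label, score))

mrr = mean_reciprocal_rank(results)
top1 = top_1_precision(results)
```

Context 0 ranks its positive response first (reciprocal rank 1). Context 1 ranks it second (reciprocal rank 0.5). This gives an MRR of 0.75 and a top-1 precision of 0.5.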
--------------------------------------------------------------------------------
/run_testing_rs_gift.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """GIFT testing runner (MPC-BERT backbone) on the downstream task of response selection."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker_gift as modeling
16 | import metrics
17 |
18 | flags = tf.flags
19 | FLAGS = flags.FLAGS
20 |
21 | flags.DEFINE_string("task_name", 'Testing',
22 | "The name of the task.")
23 |
24 | flags.DEFINE_string("test_dir", 'test.tfrecord',
25 | "Path to the input test data file in TFRecord format.")
26 |
27 | flags.DEFINE_string("restore_model_dir", 'output/',
28 | "The output directory where the model checkpoints have been written.")
29 |
30 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
31 | "The config json file corresponding to the pre-trained BERT model. "
32 | "This specifies the model architecture.")
33 |
34 | flags.DEFINE_bool("do_eval", True,
35 | "Whether to run evaluation on the test set.")
36 |
37 | flags.DEFINE_integer("eval_batch_size", 32,
38 | "Total batch size for evaluation.")
39 |
40 | flags.DEFINE_integer("max_seq_length", 320,
41 | "The maximum total input sequence length after WordPiece tokenization. "
42 | "Sequences longer than this will be truncated, and sequences shorter "
43 | "than this will be padded.")
44 |
45 | flags.DEFINE_integer("max_utr_num", 7,
46 | "Maximum utterance number.")
47 |
48 |
49 | def print_configuration_op(FLAGS):
50 | print('My Configurations:')
51 | for name, value in FLAGS.__flags.items():
52 | value=value.value
53 | if type(value) == float:
54 | print(' %s:\t %f'%(name, value))
55 | elif type(value) == int:
56 | print(' %s:\t %d'%(name, value))
57 | elif type(value) == str:
58 | print(' %s:\t %s'%(name, value))
59 | elif type(value) == bool:
60 | print(' %s:\t %s'%(name, value))
61 | else:
62 | print('%s:\t %s' % (name, value))
63 | print('End of configuration')
64 |
65 |
66 | def count_data_size(file_name):
67 | sample_nums = 0
68 | for record in tf.python_io.tf_record_iterator(file_name):
69 | sample_nums += 1
70 | return sample_nums
71 |
72 |
73 | def parse_exmp(serial_exmp):
74 | input_data = tf.parse_single_example(serial_exmp,
75 | features={
76 | "ctx_id":
77 | tf.FixedLenFeature([], tf.int64),
78 | "rsp_id":
79 | tf.FixedLenFeature([], tf.int64),
80 | "input_sents":
81 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
82 | "input_mask":
83 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
84 | "segment_ids":
85 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
86 | "speaker_ids":
87 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
88 | "reply_mask_utr2word_flatten":
89 | tf.FixedLenFeature([FLAGS.max_utr_num * FLAGS.max_seq_length], tf.int64),
90 | "utr_lens":
91 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
92 | "label_ids":
93 | tf.FixedLenFeature([], tf.float32),
94 | }
95 | )
96 | # tf.Example only supports tf.int64, so cast all int64 features to int32.
97 | for name in list(input_data.keys()):
98 | t = input_data[name]
99 | if t.dtype == tf.int64:
100 | t = tf.to_int32(t)
101 | input_data[name] = t
102 |
103 | ctx_id = input_data["ctx_id"]
104 | rsp_id = input_data['rsp_id']
105 | input_sents = input_data["input_sents"]
106 | input_mask = input_data["input_mask"]
107 | segment_ids= input_data["segment_ids"]
108 | speaker_ids= input_data["speaker_ids"]
109 | reply_mask_utr2word = tf.reshape(input_data['reply_mask_utr2word_flatten'], [FLAGS.max_utr_num, FLAGS.max_seq_length])
110 | utr_lens = input_data['utr_lens']
111 | labels = input_data['label_ids']
112 |
113 | reply_mask_word2word = []
114 | for i in range(FLAGS.max_utr_num):
115 | reply_mask_utr2word_i = reply_mask_utr2word[i] # [max_seq_length, ]
116 | utr_len_i = utr_lens[i] # scalar: number of tokens in utterance i
117 | reply_mask_utr2word_i_tiled = tf.tile(tf.expand_dims(reply_mask_utr2word_i, 0), [utr_len_i, 1]) # [utr_len, max_seq_length]
118 | reply_mask_word2word.append(reply_mask_utr2word_i_tiled)
119 |
120 | reply_mask_pad = tf.zeros(shape=[FLAGS.max_seq_length - tf.reduce_sum(utr_lens), FLAGS.max_seq_length], dtype=tf.int32)
121 | reply_mask_word2word.append(reply_mask_pad)
122 | reply_mask = tf.concat(reply_mask_word2word, 0) # [max_seq_length, max_seq_length]
123 | print("* Loading reply mask ...")
124 |
125 | return ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, reply_mask, labels
126 |
127 |
128 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, reply_mask, labels, ctx_id, rsp_id,
129 | num_labels, use_one_hot_embeddings):
130 | """Creates a classification model."""
131 | model = modeling.BertModel(
132 | config=bert_config,
133 | is_training=is_training,
134 | input_ids=input_ids,
135 | input_mask=input_mask,
136 | token_type_ids=segment_ids,
137 | speaker_ids=speaker_ids,
138 | reply_mask=reply_mask,
139 | use_one_hot_embeddings=use_one_hot_embeddings)
140 |
141 | target_loss_weight = [1.0, 1.0]
142 | target_loss_weight = tf.convert_to_tensor(target_loss_weight)
143 |
144 | flagx = tf.cast(tf.greater(labels, 0), dtype=tf.float32)
145 | flagy = tf.cast(tf.equal(labels, 0), dtype=tf.float32)
146 |
147 | all_target_loss = target_loss_weight[1] * flagx + target_loss_weight[0] * flagy
148 |
149 | output_layer = model.get_pooled_output()
150 |
151 | hidden_size = output_layer.shape[-1].value
152 |
153 | output_weights = tf.get_variable(
154 | "output_weights", [num_labels, hidden_size],
155 | initializer=tf.truncated_normal_initializer(stddev=0.02))
156 |
157 | output_bias = tf.get_variable(
158 | "output_bias", [num_labels], initializer=tf.zeros_initializer())
159 |
160 | with tf.variable_scope("loss"):
161 | # if is_training:
162 | # output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
163 | output_layer = tf.layers.dropout(output_layer, rate=0.1, training=is_training)
164 |
165 | logits = tf.matmul(output_layer, output_weights, transpose_b=True)
166 | logits = tf.nn.bias_add(logits, output_bias)
167 |
168 | probabilities = tf.sigmoid(logits, name="prob")
169 | logits = tf.squeeze(logits,[1])
170 | losses = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
171 | losses = tf.multiply(losses, all_target_loss)
172 |
173 | mean_loss = tf.reduce_mean(losses, name="mean_loss") + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
174 |
175 | with tf.name_scope("accuracy"):
176 | correct_prediction = tf.equal(tf.sign(tf.squeeze(probabilities, [1]) - 0.5), tf.sign(labels - 0.5)) # squeeze [batch_size, 1] -> [batch_size, ] so it is not broadcast against labels
177 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
178 |
179 | return mean_loss, logits, probabilities, accuracy
180 |
181 |
182 | best_score = 0.0
183 | def run_test(dir_path, op_name, sess, training, accuracy, prob, pair_ids):
184 |
185 | step = 0
186 | t0 = time()
187 | num_test = 0
188 | num_correct = 0.0
189 | mrr = 0
190 | results = defaultdict(list)
191 |
192 | try:
193 | while True:
194 | step += 1
195 | batch_accuracy, predicted_prob, batch_pair_ids = sess.run([accuracy, prob, pair_ids], feed_dict={training: False})
196 | question_id, answer_id, label = batch_pair_ids
197 |
198 | num_test += len(predicted_prob)
199 | num_correct += len(predicted_prob) * batch_accuracy
200 | for i, prob_score in enumerate(predicted_prob):
201 | results[question_id[i]].append((answer_id[i], label[i], prob_score[0]))
202 |
203 | if step % 100 == 0:
204 | tf.logging.info("Step %d, %s, Time (min): %.2f" %
205 | (step, op_name, (time() - t0) / 60.0))
206 |
207 | except tf.errors.OutOfRangeError:
208 | print('num_test_samples: {} test_accuracy: {}'.format(num_test, num_correct / num_test))
209 | accu, precision, recall, f1, loss = metrics.classification_metrics(results)
210 | print('Accuracy: {}, Precision: {} Recall: {} F1: {} Loss: {}'.format(accu, precision, recall, f1, loss))
211 |
212 | mvp = metrics.mean_average_precision(results)
213 | mrr = metrics.mean_reciprocal_rank(results)
214 | top_1_precision = metrics.top_1_precision(results)
215 | total_valid_query = metrics.get_num_valid_query(results)
216 | print('MAP (mean average precision): {}\tMRR (mean reciprocal rank): {}\tTop-1 precision: {}\tNum_query: {}'.format(
217 | mvp, mrr, top_1_precision, total_valid_query))
218 |
219 | out_path = os.path.join(dir_path, "output_test.txt")
220 | print("Saving evaluation to {}".format(out_path))
221 | with open(out_path, 'w') as f:
222 | f.write("query_id\tdocument_id\tscore\trank\trelevance\n")
223 | for us_id, v in results.items():
224 | v.sort(key=operator.itemgetter(2), reverse=True)
225 | for i, rec in enumerate(v):
226 | r_id, label, prob_score = rec
227 | rank = i+1
228 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(us_id, r_id, prob_score, rank, label))
229 | return mrr
230 |
231 |
232 | def main(_):
233 | tf.logging.set_verbosity(tf.logging.INFO)
234 | print_configuration_op(FLAGS)
235 |
236 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
237 |
238 | test_data_size = count_data_size(FLAGS.test_dir)
239 | tf.logging.info('test data size: {}'.format(test_data_size))
240 |
241 | filenames = tf.placeholder(tf.string, shape=[None])
242 | shuffle_size = tf.placeholder(tf.int64)
243 | dataset = tf.data.TFRecordDataset(filenames)
244 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
245 | dataset = dataset.repeat(1)
246 | # dataset = dataset.shuffle(shuffle_size)
247 | dataset = dataset.batch(FLAGS.eval_batch_size)
248 | iterator = dataset.make_initializable_iterator()
249 | ctx_id, rsp_id, input_sents, input_mask, segment_ids, speaker_ids, reply_mask, labels = iterator.get_next()
250 | pair_ids = [ctx_id, rsp_id, labels]
251 |
252 | training = tf.placeholder(tf.bool)
253 | mean_loss, logits, probabilities, accuracy = create_model(bert_config = bert_config,
254 | is_training = training,
255 | input_ids = input_sents,
256 | input_mask = input_mask,
257 | segment_ids = segment_ids,
258 | speaker_ids = speaker_ids,
259 | reply_mask = reply_mask,
260 | labels = labels,
261 | ctx_id = ctx_id,
262 | rsp_id = rsp_id,
263 | num_labels = 1,
264 | use_one_hot_embeddings = False)
265 |
266 |
267 | config = tf.ConfigProto(allow_soft_placement=True)
268 | config.gpu_options.allow_growth = True
269 |
270 | if FLAGS.do_eval:
271 | with tf.Session(config=config) as sess:
272 | tf.logging.info("*** Restore model ***")
273 |
274 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
275 | variables = tf.trainable_variables()
276 | saver = tf.train.Saver(variables)
277 | saver.restore(sess, ckpt.model_checkpoint_path)
278 |
279 | tf.logging.info('Test begin')
280 | sess.run(iterator.initializer,
281 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
282 | run_test(FLAGS.restore_model_dir, "test", sess, training, accuracy, probabilities, pair_ids)
283 |
284 |
285 | if __name__ == "__main__":
286 | tf.app.run()
287 |
288 |
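In the GIFT variant, `parse_exmp` expands the utterance-to-word reply mask into a word-to-word mask. Every token of utterance i receives a copy of row i. Padded utterance slots (length 0) contribute no rows, and the rows for padding tokens are zero.

Below is a NumPy sketch of the same expansion. The helper name `build_reply_mask` and the toy input (two real utterances, one empty slot, six tokens) are illustrative assumptions.

```python
import numpy as np


def build_reply_mask(reply_mask_utr2word, utr_lens, max_seq_length):
    """Hypothetical helper: expand a [max_utr_num, max_seq_length] utterance-to-word mask
    into a [max_seq_length, max_seq_length] word-to-word mask."""
    rows = []
    for i, utr_len in enumerate(utr_lens):
        # Repeat row i once per token of utterance i (zero times for an empty slot).
        rows.append(np.repeat(reply_mask_utr2word[i][None, :], int(utr_len), axis=0))
    # Rows for padding tokens after the last utterance stay all zero.
    pad_rows = max_seq_length - int(np.sum(utr_lens))
    rows.append(np.zeros((pad_rows, max_seq_length), dtype=reply_mask_utr2word.dtype))
    return np.concatenate(rows, axis=0)


utr2word = np.array([[1, 1, 0, 0, 0, 0],   # utterance 0 (tokens 0-1)
                     [1, 1, 1, 1, 1, 0],   # utterance 1 (tokens 2-4)
                     [0, 0, 0, 0, 0, 0]])  # empty utterance slot
utr_lens = np.array([2, 3, 0])
reply_mask = build_reply_mask(utr2word, utr_lens, max_seq_length=6)
```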
--------------------------------------------------------------------------------
/run_testing_si.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """MPC-BERT testing runner on the downstream task of speaker identification."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker as modeling
16 |
17 | flags = tf.flags
18 | FLAGS = flags.FLAGS
19 |
20 | flags.DEFINE_string("task_name", 'Testing',
21 | "The name of the task.")
22 |
23 | flags.DEFINE_string("test_dir", 'test.tfrecord',
24 | "Path to the input test data file in TFRecord format.")
25 |
26 | flags.DEFINE_string("restore_model_dir", 'output/',
27 | "The output directory where the model checkpoints have been written.")
28 |
29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
30 | "The config json file corresponding to the pre-trained BERT model. "
31 | "This specifies the model architecture.")
32 |
33 | flags.DEFINE_bool("do_eval", True,
34 | "Whether to run evaluation on the test set.")
35 |
36 | flags.DEFINE_integer("eval_batch_size", 32,
37 | "Total batch size for evaluation.")
38 |
39 | flags.DEFINE_integer("max_seq_length", 320,
40 | "The maximum total input sequence length after WordPiece tokenization. "
41 | "Sequences longer than this will be truncated, and sequences shorter "
42 | "than this will be padded.")
43 |
44 | flags.DEFINE_integer("max_utr_num", 7,
45 | "Maximum utterance number.")
46 |
47 |
48 | def print_configuration_op(FLAGS):
49 | print('My Configurations:')
50 | for name, value in FLAGS.__flags.items():
51 | value=value.value
52 | if type(value) == float:
53 | print(' %s:\t %f'%(name, value))
54 | elif type(value) == int:
55 | print(' %s:\t %d'%(name, value))
56 | elif type(value) == str:
57 | print(' %s:\t %s'%(name, value))
58 | elif type(value) == bool:
59 | print(' %s:\t %s'%(name, value))
60 | else:
61 | print('%s:\t %s' % (name, value))
62 | print('End of configuration')
63 |
64 |
65 | def count_data_size(file_name):
66 | sample_nums = 0
67 | for record in tf.python_io.tf_record_iterator(file_name):
68 | sample_nums += 1
69 | return sample_nums
70 |
71 |
72 | def parse_exmp(serial_exmp):
73 | input_data = tf.parse_single_example(serial_exmp,
74 | features={
75 | "input_sents":
76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
77 | "input_mask":
78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
79 | "segment_ids":
80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
81 | "speaker_ids":
82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
83 | "cls_positions":
84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
85 | "rsp_position":
86 | tf.FixedLenFeature([1], tf.int64),
87 | "label_ids":
88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
89 | }
90 | )
91 | # tf.Example only supports tf.int64, so cast all int64 features to int32.
92 | for name in list(input_data.keys()):
93 | t = input_data[name]
94 | if t.dtype == tf.int64:
95 | t = tf.to_int32(t)
96 | input_data[name] = t
97 |
98 | input_sents = input_data["input_sents"]
99 | input_mask = input_data["input_mask"]
100 | segment_ids= input_data["segment_ids"]
101 | speaker_ids= input_data["speaker_ids"]
102 | cls_positions= input_data["cls_positions"]
103 | rsp_position= input_data["rsp_position"]
104 | labels = input_data['label_ids']
105 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels
106 |
107 |
108 | def gather_indexes(sequence_tensor, positions):
109 | """Gathers the vectors at the specific positions over a minibatch."""
110 | # sequence_tensor = [batch_size, seq_length, width]
111 | # positions = [batch_size, max_utr_num]
112 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
113 | batch_size = sequence_shape[0]
114 | seq_length = sequence_shape[1]
115 | width = sequence_shape[2]
116 |
117 | flat_offsets = tf.reshape(
118 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1]
119 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ]
120 | flat_sequence_tensor = tf.reshape(sequence_tensor,
121 | [batch_size * seq_length, width])
122 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width]
123 | return output_tensor
124 |
125 |
126 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels,
127 | num_labels, use_one_hot_embeddings):
128 | """Creates a classification model."""
129 | model = modeling.BertModel(
130 | config=bert_config,
131 | is_training=is_training,
132 | input_ids=input_ids,
133 | input_mask=input_mask,
134 | token_type_ids=segment_ids,
135 | speaker_ids=speaker_ids,
136 | use_one_hot_embeddings=use_one_hot_embeddings)
137 |
138 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim]
139 |
140 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2)
141 | width = input_shape[-1]
142 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2)
143 | max_utr_num = positions_shape[-1]
144 |
145 | with tf.variable_scope("cls/speaker_restore"):
146 | # We apply one more non-linear transformation before the output layer.
147 | with tf.variable_scope("transform"):
148 | input_tensor = tf.layers.dense(
149 | input_tensor,
150 | units=bert_config.hidden_size,
151 | activation=modeling.get_activation(bert_config.hidden_act),
152 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range))
153 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim]
154 |
155 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim]
156 |
157 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim]
158 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim]
159 |
160 | output_weights = tf.get_variable(
161 | "output_weights",
162 | shape=[width, width],
163 | initializer=modeling.create_initializer(bert_config.initializer_range))
164 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights),
165 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num]
166 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num]
167 |
168 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
169 | logits = logits * mask + -1e9 * (1-mask)
170 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num]
171 |
172 | # loss
173 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num]
174 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ]
175 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss")
176 |
177 | # accuracy
178 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ]
179 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
180 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ]
181 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
182 |
183 | return mean_loss, logits, log_probs, accuracy
184 |
185 |
186 | def run_test(sess, training, prob, accuracy):
187 |
188 | step = 0
189 | t0 = time()
190 | num_test = 0
191 | num_correct = 0.0
192 | test_accuracy = 0
193 |
194 | try:
195 | while True:
196 | step += 1
197 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False})
198 |
199 | num_test += len(predicted_prob)
200 | num_correct += len(predicted_prob) * batch_accuracy
201 |
202 | if step % 100 == 0:
203 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0))
204 |
205 | except tf.errors.OutOfRangeError:
206 | test_accuracy = num_correct / num_test
207 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy))
208 |
209 | return test_accuracy
210 |
211 |
212 | def main(_):
213 | tf.logging.set_verbosity(tf.logging.INFO)
214 | print_configuration_op(FLAGS)
215 |
216 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
217 |
218 | test_data_size = count_data_size(FLAGS.test_dir)
219 | tf.logging.info('test data size: {}'.format(test_data_size))
220 |
221 | filenames = tf.placeholder(tf.string, shape=[None])
222 | shuffle_size = tf.placeholder(tf.int64)
223 | dataset = tf.data.TFRecordDataset(filenames)
224 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
225 | dataset = dataset.repeat(1)
226 | # dataset = dataset.shuffle(shuffle_size)
227 | dataset = dataset.batch(FLAGS.eval_batch_size)
228 | iterator = dataset.make_initializable_iterator()
229 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels = iterator.get_next()
230 |
231 | training = tf.placeholder(tf.bool)
232 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config,
233 | is_training = training,
234 | input_ids = input_sents,
235 | input_mask = input_mask,
236 | segment_ids = segment_ids,
237 | speaker_ids = speaker_ids,
238 | cls_positions = cls_positions,
239 | rsp_position = rsp_position,
240 | labels = labels,
241 | num_labels = 1,
242 | use_one_hot_embeddings = False)
243 |
244 | config = tf.ConfigProto(allow_soft_placement=True)
245 | config.gpu_options.allow_growth = True
246 |
247 | if FLAGS.do_eval:
248 | with tf.Session(config=config) as sess:
249 | tf.logging.info("*** Restore model ***")
250 |
251 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
252 | variables = tf.trainable_variables()
253 | saver = tf.train.Saver(variables)
254 | saver.restore(sess, ckpt.model_checkpoint_path)
255 |
256 | tf.logging.info('Test begin')
257 | sess.run(iterator.initializer,
258 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
259 | run_test(sess, training, log_probs, accuracy)
260 |
261 |
262 | if __name__ == "__main__":
263 | tf.app.run()
264 |
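In `create_model` above, `gather_indexes` flattens the batch to `[batch_size * seq_length, width]` and adds a per-example offset (`flat_offsets`) so that a single `tf.gather` picks every utterance's [CLS] vector. The logits are then restricted with `tf.sequence_mask(rsp_position, max_utr_num)`, so only utterances whose index is below `rsp_position` can be predicted as sharing the response's speaker. The following is a minimal pure-Python sketch of those two steps, not the repository's code; the nested-list inputs and the helper names are illustrative assumptions.

```python
import math


def gather_indexes(sequence, positions):
    """Pure-Python analogue of gather_indexes (illustrative only).

    sequence:  [batch_size][seq_length][width] nested lists
    positions: [batch_size][num_positions] indices into each example
    Returns a flat [batch_size * num_positions][width] list, like tf.gather.
    """
    seq_length = len(sequence[0])
    # Flatten [batch_size, seq_length, width] -> [batch_size * seq_length, width].
    flat = [vec for example in sequence for vec in example]
    gathered = []
    for b, example_positions in enumerate(positions):
        offset = b * seq_length  # plays the role of flat_offsets
        for p in example_positions:
            gathered.append(flat[offset + p])
    return gathered


def masked_log_softmax(logits, rsp_position):
    """Log-softmax over utterances, keeping only indices < rsp_position.

    Mirrors `logits * mask + -1e9 * (1 - mask)` with
    mask = tf.sequence_mask(rsp_position, max_utr_num).
    """
    masked = [x if i < rsp_position else -1e9 for i, x in enumerate(logits)]
    m = max(masked)
    log_sum = m + math.log(sum(math.exp(x - m) for x in masked))
    return [x - log_sum for x in masked]


seq = [[[0, 0], [1, 1], [2, 2]],
       [[3, 3], [4, 4], [5, 5]]]
print(gather_indexes(seq, [[0, 2], [1, 2]]))  # [[0, 0], [2, 2], [4, 4], [5, 5]]

log_probs = masked_log_softmax([1.0, 3.0, 10.0], rsp_position=2)
prediction = max(range(len(log_probs)), key=lambda i: log_probs[i])
print(prediction)  # 1: index 2 is masked out despite its larger raw logit
```

Using a large negative constant rather than removing masked entries keeps every row the same length (`max_utr_num`), which is what lets the TF code stay fully batched.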
--------------------------------------------------------------------------------
/run_testing_si_gift.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | """MPC-BERT (GIFT) testing runner on the downstream task of speaker identification."""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import operator
10 | from time import time
11 | from collections import defaultdict
12 | import tensorflow as tf
13 | import optimization
14 | import tokenization
15 | import modeling_speaker_gift as modeling
16 |
17 | flags = tf.flags
18 | FLAGS = flags.FLAGS
19 |
20 | flags.DEFINE_string("task_name", 'Testing',
21 | "The name of the task.")
22 |
23 | flags.DEFINE_string("test_dir", 'test.tfrecord',
24 | "Path to the input test data file in TFRecord format (e.g. test.tfrecord).")
25 |
26 | flags.DEFINE_string("restore_model_dir", 'output/',
27 | "The output directory where the model checkpoints have been written.")
28 |
29 | flags.DEFINE_string("bert_config_file", 'uncased_L-12_H-768_A-12/bert_config.json',
30 | "The config json file corresponding to the pre-trained BERT model. "
31 | "This specifies the model architecture.")
32 |
33 | flags.DEFINE_bool("do_eval", True,
34 | "Whether to run evaluation on the test set.")
35 |
36 | flags.DEFINE_integer("eval_batch_size", 32,
37 | "Total batch size for evaluation.")
38 |
39 | flags.DEFINE_integer("max_seq_length", 320,
40 | "The maximum total input sequence length after WordPiece tokenization. "
41 | "Sequences longer than this will be truncated, and sequences shorter "
42 | "than this will be padded.")
43 |
44 | flags.DEFINE_integer("max_utr_num", 7,
45 | "Maximum utterance number.")
46 |
47 |
48 | def print_configuration_op(FLAGS):
49 | print('My Configurations:')
50 | for name, value in FLAGS.__flags.items():
51 | value=value.value
52 | if type(value) == float:
53 | print(' %s:\t %f'%(name, value))
54 | elif type(value) == int:
55 | print(' %s:\t %d'%(name, value))
56 | elif type(value) == str:
57 | print(' %s:\t %s'%(name, value))
58 | elif type(value) == bool:
59 | print(' %s:\t %s'%(name, value))
60 | else:
61 | print('%s:\t %s' % (name, value))
62 | print('End of configuration')
63 |
64 |
65 | def count_data_size(file_name):
66 | sample_nums = 0
67 | for record in tf.python_io.tf_record_iterator(file_name):
68 | sample_nums += 1
69 | return sample_nums
70 |
71 |
72 | def parse_exmp(serial_exmp):
73 | input_data = tf.parse_single_example(serial_exmp,
74 | features={
75 | "input_sents":
76 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
77 | "input_mask":
78 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
79 | "segment_ids":
80 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
81 | "speaker_ids":
82 | tf.FixedLenFeature([FLAGS.max_seq_length], tf.int64),
83 | "cls_positions":
84 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
85 | "rsp_position":
86 | tf.FixedLenFeature([1], tf.int64),
87 | "label_ids":
88 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
89 | "reply_mask_utr2word_flatten":
90 | tf.FixedLenFeature([FLAGS.max_utr_num * FLAGS.max_seq_length], tf.int64),
91 | "utr_lens":
92 | tf.FixedLenFeature([FLAGS.max_utr_num], tf.int64),
93 | }
94 | )
95 | # tf.Example only supports tf.int64, so cast all int64 features to int32.
96 | for name in list(input_data.keys()):
97 | t = input_data[name]
98 | if t.dtype == tf.int64:
99 | t = tf.to_int32(t)
100 | input_data[name] = t
101 |
102 | input_sents = input_data["input_sents"]
103 | input_mask = input_data["input_mask"]
104 | segment_ids= input_data["segment_ids"]
105 | speaker_ids= input_data["speaker_ids"]
106 | cls_positions= input_data["cls_positions"]
107 | rsp_position= input_data["rsp_position"]
108 | reply_mask_utr2word = tf.reshape(input_data['reply_mask_utr2word_flatten'], [FLAGS.max_utr_num, FLAGS.max_seq_length])
109 | utr_lens = input_data['utr_lens']
110 | labels = input_data['label_ids']
111 |
112 | reply_mask_word2word = []
113 | for i in range(FLAGS.max_utr_num):
114 | reply_mask_utr2word_i = reply_mask_utr2word[i] # [max_seq_length, ]
115 | utr_len_i = utr_lens[i] # scalar
116 | reply_mask_utr2word_i_tiled = tf.tile(tf.expand_dims(reply_mask_utr2word_i, 0), [utr_len_i, 1]) # [utr_len, max_seq_length]
117 | reply_mask_word2word.append(reply_mask_utr2word_i_tiled)
118 |
119 | reply_mask_pad = tf.zeros(shape=[FLAGS.max_seq_length - tf.reduce_sum(utr_lens), FLAGS.max_seq_length], dtype=tf.int32)
120 | reply_mask_word2word.append(reply_mask_pad)
121 | reply_mask = tf.concat(reply_mask_word2word, 0) # [max_seq_length, max_seq_length]
122 | print("* Loading reply mask ...")
123 |
124 | return input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask
125 |
126 |
127 | def gather_indexes(sequence_tensor, positions):
128 | """Gathers the vectors at the specific positions over a minibatch."""
129 | # sequence_tensor = [batch_size, seq_length, width]
130 | # positions = [batch_size, max_utr_num]
131 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
132 | batch_size = sequence_shape[0]
133 | seq_length = sequence_shape[1]
134 | width = sequence_shape[2]
135 |
136 | flat_offsets = tf.reshape(
137 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) # [batch_size, 1]
138 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) # [batch_size*max_utr_num, ]
139 | flat_sequence_tensor = tf.reshape(sequence_tensor,
140 | [batch_size * seq_length, width])
141 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) # [batch_size*max_utr_num, width]
142 | return output_tensor
143 |
144 |
145 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask,
146 | num_labels, use_one_hot_embeddings):
147 | """Creates a classification model."""
148 | model = modeling.BertModel(
149 | config=bert_config,
150 | is_training=is_training,
151 | input_ids=input_ids,
152 | input_mask=input_mask,
153 | token_type_ids=segment_ids,
154 | speaker_ids=speaker_ids,
155 | reply_mask=reply_mask,
156 | use_one_hot_embeddings=use_one_hot_embeddings)
157 |
158 | input_tensor = gather_indexes(model.get_sequence_output(), cls_positions) # [batch_size*max_utr_num, dim]
159 |
160 | input_shape = modeling.get_shape_list(input_tensor, expected_rank=2)
161 | width = input_shape[-1]
162 | positions_shape = modeling.get_shape_list(cls_positions, expected_rank=2)
163 | max_utr_num = positions_shape[-1]
164 |
165 | with tf.variable_scope("cls/speaker_restore"):
166 | # We apply one more non-linear transformation before the output layer.
167 | with tf.variable_scope("transform"):
168 | input_tensor = tf.layers.dense(
169 | input_tensor,
170 | units=bert_config.hidden_size,
171 | activation=modeling.get_activation(bert_config.hidden_act),
172 | kernel_initializer=modeling.create_initializer(bert_config.initializer_range))
173 | input_tensor = modeling.layer_norm(input_tensor) # [batch_size*max_utr_num, dim]
174 |
175 | input_tensor = tf.reshape(input_tensor, [-1, max_utr_num, width]) # [batch_size, max_utr_num, dim]
176 |
177 | rsp_tensor = gather_indexes(input_tensor, rsp_position) # [batch_size*1, dim]
178 | rsp_tensor = tf.reshape(rsp_tensor, [-1, 1, width]) # [batch_size, 1, dim]
179 |
180 | output_weights = tf.get_variable(
181 | "output_weights",
182 | shape=[width, width],
183 | initializer=modeling.create_initializer(bert_config.initializer_range))
184 | logits = tf.matmul(tf.einsum('aij,jk->aik', rsp_tensor, output_weights),
185 | input_tensor, transpose_b=True) # [batch_size, 1, max_utr_num]
186 | logits = tf.squeeze(logits, [1]) # [batch_size, max_utr_num]
187 |
188 | mask = tf.sequence_mask(tf.reshape(rsp_position, [-1, ]), max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
189 | logits = logits * mask + -1e9 * (1-mask)
190 | log_probs = tf.nn.log_softmax(logits, axis=-1) # [batch_size, max_utr_num]
191 |
192 | # loss
193 | one_hot_labels = tf.cast(labels, "float") # [batch_size, max_utr_num]
194 | per_example_loss = - tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # [batch_size, ]
195 | mean_loss = tf.reduce_mean(per_example_loss, name="mean_loss")
196 |
197 | # accuracy
198 | predictions = tf.argmax(log_probs, axis=-1, output_type=tf.int32) # [batch_size, ]
199 | predictions_one_hot = tf.one_hot(predictions, depth=max_utr_num, dtype=tf.float32) # [batch_size, max_utr_num]
200 | correct_prediction = tf.reduce_sum(predictions_one_hot * one_hot_labels, -1) # [batch_size, ]
201 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"), name="accuracy")
202 |
203 | return mean_loss, logits, log_probs, accuracy
204 |
205 |
206 | def run_test(sess, training, prob, accuracy):
207 |
208 | step = 0
209 | t0 = time()
210 | num_test = 0
211 | num_correct = 0.0
212 | test_accuracy = 0
213 |
214 | try:
215 | while True:
216 | step += 1
217 | batch_accuracy, predicted_prob = sess.run([accuracy, prob], feed_dict={training: False})
218 |
219 | num_test += len(predicted_prob)
220 | num_correct += len(predicted_prob) * batch_accuracy
221 |
222 | if step % 100 == 0:
223 | tf.logging.info("Step %d, Time (min): %.2f" % (step, (time() - t0) / 60.0))
224 |
225 | except tf.errors.OutOfRangeError:
226 | test_accuracy = num_correct / num_test
227 | print('num_test_samples: {}, test_accuracy: {}'.format(num_test, test_accuracy))
228 |
229 | return test_accuracy
230 |
231 |
232 | def main(_):
233 | tf.logging.set_verbosity(tf.logging.INFO)
234 | print_configuration_op(FLAGS)
235 |
236 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
237 |
238 | test_data_size = count_data_size(FLAGS.test_dir)
239 | tf.logging.info('test data size: {}'.format(test_data_size))
240 |
241 | filenames = tf.placeholder(tf.string, shape=[None])
242 | shuffle_size = tf.placeholder(tf.int64)
243 | dataset = tf.data.TFRecordDataset(filenames)
244 | dataset = dataset.map(parse_exmp) # Parse the record into tensors.
245 | dataset = dataset.repeat(1)
246 | # dataset = dataset.shuffle(shuffle_size)
247 | dataset = dataset.batch(FLAGS.eval_batch_size)
248 | iterator = dataset.make_initializable_iterator()
249 | input_sents, input_mask, segment_ids, speaker_ids, cls_positions, rsp_position, labels, reply_mask = iterator.get_next()
250 |
251 | training = tf.placeholder(tf.bool)
252 | mean_loss, logits, log_probs, accuracy = create_model(bert_config = bert_config,
253 | is_training = training,
254 | input_ids = input_sents,
255 | input_mask = input_mask,
256 | segment_ids = segment_ids,
257 | speaker_ids = speaker_ids,
258 | cls_positions = cls_positions,
259 | rsp_position = rsp_position,
260 | reply_mask = reply_mask,
261 | labels = labels,
262 | num_labels = 1,
263 | use_one_hot_embeddings = False)
264 |
265 | config = tf.ConfigProto(allow_soft_placement=True)
266 | config.gpu_options.allow_growth = True
267 |
268 | if FLAGS.do_eval:
269 | with tf.Session(config=config) as sess:
270 | tf.logging.info("*** Restore model ***")
271 |
272 | ckpt = tf.train.get_checkpoint_state(FLAGS.restore_model_dir)
273 | variables = tf.trainable_variables()
274 | saver = tf.train.Saver(variables)
275 | saver.restore(sess, ckpt.model_checkpoint_path)
276 |
277 | tf.logging.info('Test begin')
278 | sess.run(iterator.initializer,
279 | feed_dict={filenames: [FLAGS.test_dir], shuffle_size: 1})
280 | run_test(sess, training, log_probs, accuracy)
281 |
282 |
283 | if __name__ == "__main__":
284 | tf.app.run()
285 |
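`parse_exmp` in the GIFT runner turns the per-utterance `reply_mask_utr2word` rows into a square token-level `reply_mask` that is passed to `modeling.BertModel`: each utterance's row is tiled once per token of that utterance, and zero rows pad the result to `max_seq_length`. The sketch below reproduces that expansion in plain Python for a toy input. It assumes, as the TF code implicitly does, that utterance tokens are laid out contiguously from position 0 in the order given by `utr_lens`; the function name and the example values are hypothetical, and how the 0/1 entries are consumed is defined in `modeling_speaker_gift.py`, which is not shown here.

```python
def expand_reply_mask(reply_mask_utr2word, utr_lens, max_seq_length):
    """Expand an utterance-to-word reply mask into a word-to-word mask.

    reply_mask_utr2word: [max_utr_num][max_seq_length] rows of 0/1, one per utterance
    utr_lens:            [max_utr_num] token count of each utterance (0 for padding)
    Returns [max_seq_length][max_seq_length]: each token of utterance i gets a
    copy of row i, and the remaining rows are zero padding.
    """
    total = sum(utr_lens)
    if total > max_seq_length:
        # The TF version would need a negative-size zero block here and fail.
        raise ValueError("sum(utr_lens) exceeds max_seq_length")
    rows = []
    for row, utr_len in zip(reply_mask_utr2word, utr_lens):
        # Equivalent of tf.tile(tf.expand_dims(row, 0), [utr_len, 1]).
        rows.extend(list(row) for _ in range(utr_len))
    # Equivalent of the tf.zeros padding block appended before tf.concat.
    rows.extend([0] * max_seq_length for _ in range(max_seq_length - total))
    return rows


mask = expand_reply_mask([[1, 1, 0, 0, 0],
                          [1, 1, 1, 1, 0]], utr_lens=[2, 2], max_seq_length=5)
for r in mask:
    print(r)
```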
--------------------------------------------------------------------------------
/scripts/run_finetuning.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_finetuning_ar.py \
3 | --task_name fine_tuning \
4 | --train_dir ../data/ijcai2019/train_ar.tfrecord \
5 | --valid_dir ../data/ijcai2019/dev_ar.tfrecord \
6 | --output_dir ../output/ijcai2019 \
7 | --do_lower_case True \
8 | --vocab_file ../uncased_L-12_H-768_A-12_MPCBERT/vocab.txt \
9 | --bert_config_file ../uncased_L-12_H-768_A-12_MPCBERT/bert_config.json \
10 | --init_checkpoint ../uncased_L-12_H-768_A-12_MPCBERT/bert_model.ckpt \
11 | --max_seq_length 230 \
12 | --max_utr_num 7 \
13 | --do_train True \
14 | --train_batch_size 16 \
15 | --learning_rate 2e-5 \
16 | --num_train_epochs 10 \
17 | --warmup_proportion 0.1 > log_finetuning_MPCBERT_ar.txt 2>&1 &
18 |
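`run_finetuning.sh` sets `--train_batch_size 16`, `--num_train_epochs 10` and `--warmup_proportion 0.1`. The sketch below estimates the resulting number of warmup steps and the learning rate at a given step. It assumes that `run_finetuning_ar.py` and `optimization.py` follow Google's reference BERT code, which uses linear warmup followed by linear decay to zero. Neither file is shown here, so treat this as an assumption. The helper name and the training-set size of 10,000 examples are hypothetical.

```python
def bert_style_schedule(num_examples, train_batch_size, num_train_epochs,
                        warmup_proportion, init_lr):
    """Hypothetical helper: step counts and LR schedule in the style of BERT's reference code."""
    num_train_steps = int(num_examples / train_batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)

    def lr_at(step):
        if step < num_warmup_steps:
            # Linear warmup from 0 to init_lr.
            return init_lr * step / num_warmup_steps
        # Linear (power=1.0 polynomial) decay to 0, clamped at num_train_steps.
        return init_lr * (1.0 - min(step, num_train_steps) / num_train_steps)

    return num_train_steps, num_warmup_steps, lr_at


steps, warmup, lr_at = bert_style_schedule(
    num_examples=10000, train_batch_size=16, num_train_epochs=10,
    warmup_proportion=0.1, init_lr=2e-5)
print(steps, warmup)  # 6250 625
print(lr_at(warmup), lr_at(steps))
```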
--------------------------------------------------------------------------------
/scripts/run_finetuning_gift.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_finetuning_ar_gift.py \
3 | --task_name fine_tuning \
4 | --train_dir ../data/ijcai2019/train_ar_gift.tfrecord \
5 | --valid_dir ../data/ijcai2019/dev_ar_gift.tfrecord \
6 | --output_dir ../output/ijcai2019 \
7 | --do_lower_case True \
8 | --vocab_file ../uncased_L-12_H-768_A-12_MPCBERT/vocab.txt \
9 | --bert_config_file ../uncased_L-12_H-768_A-12_MPCBERT/bert_config.json \
10 | --init_checkpoint ../uncased_L-12_H-768_A-12_MPCBERT/bert_model.ckpt \
11 | --max_seq_length 230 \
12 | --max_utr_num 7 \
13 | --do_train True \
14 | --train_batch_size 16 \
15 | --learning_rate 2e-5 \
16 | --num_train_epochs 10 \
17 | --warmup_proportion 0.1 > log_finetuning_MPCBERT_ar_GIFT.txt 2>&1 &
18 |
--------------------------------------------------------------------------------
/scripts/run_pretraining.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_pretraining.py \
3 | --task_name MPC-BERT-pretraining \
4 | --input_file ../data/pretraining_data.tfrecord \
5 | --output_dir ../uncased_L-12_H-768_A-12_MPCBERT \
6 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \
7 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \
8 | --init_checkpoint ../uncased_L-12_H-768_A-12/bert_model.ckpt \
9 | --max_seq_length 230 \
10 | --max_utr_length 30 \
11 | --max_utr_num 7 \
12 | --max_predictions_per_seq 25 \
13 | --max_predictions_per_seq_ar 4 \
14 | --max_predictions_per_seq_sr 2 \
15 | --max_predictions_per_seq_cd 2 \
16 | --train_batch_size 4 \
17 | --learning_rate 5e-5 \
18 | --mid_save_step 20000 \
19 | --num_train_epochs 1 \
20 | --warmup_proportion 0.1 > log_pretraining_MPCBERT.txt 2>&1 &
21 |
--------------------------------------------------------------------------------
/scripts/run_testing.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_testing_ar.py \
3 | --test_dir ../data/ijcai2019/test_ar.tfrecord \
4 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \
5 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \
6 | --max_seq_length 230 \
7 | --max_utr_num 7 \
8 | --eval_batch_size 256 \
9 | --restore_model_dir ../output/ijcai2019/PATH_TO_TEST_MODEL > log_testing_MPCBERT_ar.txt 2>&1 &
10 |
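The testing runners report accuracy through `run_test`, which accumulates `len(predicted_prob) * batch_accuracy` rather than averaging per-batch accuracies. With `--eval_batch_size 256`, the final batch is usually smaller than the others, and a plain mean over batches would over-weight it. The sketch below performs the same weighted aggregation. The batch sizes and accuracies are hypothetical. It also guards the empty-test-set case, where the original `num_correct / num_test` would raise `ZeroDivisionError`.

```python
def aggregate_accuracy(batches):
    """batches: iterable of (batch_size, batch_accuracy) pairs, one per sess.run call."""
    num_test = 0
    num_correct = 0.0
    for batch_size, batch_accuracy in batches:
        num_test += batch_size
        num_correct += batch_size * batch_accuracy  # correct examples in this batch
    if num_test == 0:
        raise ValueError("no test examples were evaluated")
    return num_test, num_correct / num_test


# Hypothetical: one full batch of 256 and a final partial batch of 44.
num_test, accuracy = aggregate_accuracy([(256, 0.5), (44, 1.0)])
print(num_test, round(accuracy, 4))  # 300 0.5733
naive = (0.5 + 1.0) / 2  # 0.75: unweighted mean over batches, biased toward the small batch
```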
--------------------------------------------------------------------------------
/scripts/run_testing_gift.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=0 python -u ../run_testing_ar_gift.py \
3 | --test_dir ../data/ijcai2019/test_ar_gift.tfrecord \
4 | --vocab_file ../uncased_L-12_H-768_A-12/vocab.txt \
5 | --bert_config_file ../uncased_L-12_H-768_A-12/bert_config.json \
6 | --max_seq_length 230 \
7 | --max_utr_num 7 \
8 | --eval_batch_size 256 \
9 | --restore_model_dir ../output/ijcai2019/PATH_TO_TEST_MODEL > log_testing_MPCBERT_ar_GIFT.txt 2>&1 &
10 |
--------------------------------------------------------------------------------
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Tokenization classes."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import re
23 | import unicodedata
24 | import six
25 | import tensorflow as tf
26 |
27 |
28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
29 | """Checks whether the casing config is consistent with the checkpoint name."""
30 |
31 | # The casing has to be passed in by the user and there is no explicit check
32 | # as to whether it matches the checkpoint. The casing information probably
33 | # should have been stored in the bert_config.json file, but it's not, so
34 | # we have to heuristically detect it to validate.
35 |
36 | if not init_checkpoint:
37 | return
38 |
39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
40 | if m is None:
41 | return
42 |
43 | model_name = m.group(1)
44 |
45 | lower_models = [
46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
48 | ]
49 |
50 | cased_models = [
51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
52 | "multi_cased_L-12_H-768_A-12"
53 | ]
54 |
55 | is_bad_config = False
56 | if model_name in lower_models and not do_lower_case:
57 | is_bad_config = True
58 | actual_flag = "False"
59 | case_name = "lowercased"
60 | opposite_flag = "True"
61 |
62 | if model_name in cased_models and do_lower_case:
63 | is_bad_config = True
64 | actual_flag = "True"
65 | case_name = "cased"
66 | opposite_flag = "False"
67 |
68 | if is_bad_config:
69 | raise ValueError(
70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
71 | "However, `%s` seems to be a %s model, so you "
72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
73 | "how the model was pre-trained. If this error is wrong, please "
74 | "just comment out this check." % (actual_flag, init_checkpoint,
75 | model_name, case_name, opposite_flag))
76 |
77 |
78 | def convert_to_unicode(text):
79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
80 | if six.PY3:
81 | if isinstance(text, str):
82 | return text
83 | elif isinstance(text, bytes):
84 | return text.decode("utf-8", "ignore")
85 | else:
86 | raise ValueError("Unsupported string type: %s" % (type(text)))
87 | elif six.PY2:
88 | if isinstance(text, str):
89 | return text.decode("utf-8", "ignore")
90 | elif isinstance(text, unicode):
91 | return text
92 | else:
93 | raise ValueError("Unsupported string type: %s" % (type(text)))
94 | else:
95 | raise ValueError("Not running on Python2 or Python 3?")
96 |
97 |
98 | def printable_text(text):
99 | """Returns text encoded in a way suitable for print or `tf.logging`."""
100 |
101 | # These functions want `str` for both Python2 and Python3, but in one case
102 | # it's a Unicode string and in the other it's a byte string.
103 | if six.PY3:
104 | if isinstance(text, str):
105 | return text
106 | elif isinstance(text, bytes):
107 | return text.decode("utf-8", "ignore")
108 | else:
109 | raise ValueError("Unsupported string type: %s" % (type(text)))
110 | elif six.PY2:
111 | if isinstance(text, str):
112 | return text
113 | elif isinstance(text, unicode):
114 | return text.encode("utf-8")
115 | else:
116 | raise ValueError("Unsupported string type: %s" % (type(text)))
117 | else:
118 | raise ValueError("Not running on Python2 or Python 3?")
119 |
120 |
121 | def load_vocab(vocab_file):
122 | """Loads a vocabulary file into a dictionary."""
123 | vocab = collections.OrderedDict()
124 | index = 0
125 | with tf.gfile.GFile(vocab_file, "r") as reader:
126 | while True:
127 | token = convert_to_unicode(reader.readline())
128 | if not token:
129 | break
130 | token = token.strip()
131 | vocab[token] = index
132 | index += 1
133 | return vocab
134 |
135 |
136 | def convert_by_vocab(vocab, items):
137 | """Converts a sequence of [tokens|ids] using the vocab."""
138 | output = []
139 | for item in items:
140 | output.append(vocab[item])
141 | return output
142 |
143 |
144 | def convert_tokens_to_ids(vocab, tokens):
145 | return convert_by_vocab(vocab, tokens)
146 |
147 |
148 | def convert_ids_to_tokens(inv_vocab, ids):
149 | return convert_by_vocab(inv_vocab, ids)
150 |
151 |
152 | def whitespace_tokenize(text):
153 | """Runs basic whitespace cleaning and splitting on a piece of text."""
154 | text = text.strip()
155 | if not text:
156 | return []
157 | tokens = text.split()
158 | return tokens
159 |
160 |
161 | class FullTokenizer(object):
162 | """Runs end-to-end tokenization."""
163 |
164 | def __init__(self, vocab_file, do_lower_case=True):
165 | self.vocab = load_vocab(vocab_file)
166 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
169 |
170 | def tokenize(self, text):
171 | split_tokens = []
172 | for token in self.basic_tokenizer.tokenize(text):
173 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
174 | split_tokens.append(sub_token)
175 |
176 | return split_tokens
177 |
178 | def convert_tokens_to_ids(self, tokens):
179 | return convert_by_vocab(self.vocab, tokens)
180 |
181 | def convert_ids_to_tokens(self, ids):
182 | return convert_by_vocab(self.inv_vocab, ids)
183 |
184 |
185 | class BasicTokenizer(object):
186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
187 |
188 | def __init__(self, do_lower_case=True):
189 | """Constructs a BasicTokenizer.
190 |
191 | Args:
192 | do_lower_case: Whether to lower case the input.
193 | """
194 | self.do_lower_case = do_lower_case
195 |
196 | def tokenize(self, text):
197 | """Tokenizes a piece of text."""
198 | text = convert_to_unicode(text)
199 | text = self._clean_text(text)
200 |
201 | # This was added on November 1st, 2018 for the multilingual and Chinese
202 | # models. This is also applied to the English models now, but it doesn't
203 | # matter since the English models were not trained on any Chinese data
204 | # and generally don't have any Chinese data in them (there are Chinese
205 | # characters in the vocabulary because Wikipedia does have some Chinese
206 | # words in the English Wikipedia).
207 | text = self._tokenize_chinese_chars(text)
208 |
209 | orig_tokens = whitespace_tokenize(text)
210 | split_tokens = []
211 | for token in orig_tokens:
212 | if self.do_lower_case:
213 | token = token.lower()
214 | token = self._run_strip_accents(token)
215 | split_tokens.extend(self._run_split_on_punc(token))
216 |
217 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
218 | return output_tokens
219 |
220 | def _run_strip_accents(self, text):
221 | """Strips accents from a piece of text."""
222 | text = unicodedata.normalize("NFD", text)
223 | output = []
224 | for char in text:
225 | cat = unicodedata.category(char)
226 | if cat == "Mn":
227 | continue
228 | output.append(char)
229 | return "".join(output)
230 |
231 | def _run_split_on_punc(self, text):
232 | """Splits punctuation on a piece of text."""
233 | chars = list(text)
234 | i = 0
235 | start_new_word = True
236 | output = []
237 | while i < len(chars):
238 | char = chars[i]
239 | if _is_punctuation(char):
240 | output.append([char])
241 | start_new_word = True
242 | else:
243 | if start_new_word:
244 | output.append([])
245 | start_new_word = False
246 | output[-1].append(char)
247 | i += 1
248 |
249 | return ["".join(x) for x in output]
250 |
251 | def _tokenize_chinese_chars(self, text):
252 | """Adds whitespace around any CJK character."""
253 | output = []
254 | for char in text:
255 | cp = ord(char)
256 | if self._is_chinese_char(cp):
257 | output.append(" ")
258 | output.append(char)
259 | output.append(" ")
260 | else:
261 | output.append(char)
262 | return "".join(output)
263 |
264 | def _is_chinese_char(self, cp):
265 | """Checks whether CP is the codepoint of a CJK character."""
266 | # This defines a "chinese character" as anything in the CJK Unicode block:
267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
268 | #
269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
270 | # despite its name. The modern Korean Hangul alphabet is a different block,
271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write
272 | # space-separated words, so they are not treated specially and handled
273 | # like all of the other languages.
274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
275 | (cp >= 0x3400 and cp <= 0x4DBF) or #
276 | (cp >= 0x20000 and cp <= 0x2A6DF) or #
277 | (cp >= 0x2A700 and cp <= 0x2B73F) or #
278 | (cp >= 0x2B740 and cp <= 0x2B81F) or #
279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
280 | (cp >= 0xF900 and cp <= 0xFAFF) or #
281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): #
282 | return True
283 |
284 | return False
285 |
286 | def _clean_text(self, text):
287 | """Performs invalid character removal and whitespace cleanup on text."""
288 | output = []
289 | for char in text:
290 | cp = ord(char)
291 | if cp == 0 or cp == 0xfffd or _is_control(char):
292 | continue
293 | if _is_whitespace(char):
294 | output.append(" ")
295 | else:
296 | output.append(char)
297 | return "".join(output)
298 |
299 |
300 | class WordpieceTokenizer(object):
301 | """Runs WordPiece tokenization."""
302 |
303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
304 | self.vocab = vocab
305 | self.unk_token = unk_token
306 | self.max_input_chars_per_word = max_input_chars_per_word
307 |
308 | def tokenize(self, text):
309 | """Tokenizes a piece of text into its word pieces.
310 |
311 | This uses a greedy longest-match-first algorithm to perform tokenization
312 | using the given vocabulary.
313 |
314 | For example:
315 | input = "unaffable"
316 | output = ["un", "##aff", "##able"]
317 |
318 | Args:
319 | text: A single token or whitespace separated tokens. This should have
320 | already been passed through `BasicTokenizer`.
321 |
322 | Returns:
323 | A list of wordpiece tokens.
324 | """
325 |
326 | text = convert_to_unicode(text)
327 |
328 | output_tokens = []
329 | for token in whitespace_tokenize(text):
330 | chars = list(token)
331 | if len(chars) > self.max_input_chars_per_word:
332 | output_tokens.append(self.unk_token)
333 | continue
334 |
335 | is_bad = False
336 | start = 0
337 | sub_tokens = []
338 | while start < len(chars):
339 | end = len(chars)
340 | cur_substr = None
341 | while start < end:
342 | substr = "".join(chars[start:end])
343 | if start > 0:
344 | substr = "##" + substr
345 | if substr in self.vocab:
346 | cur_substr = substr
347 | break
348 | end -= 1
349 | if cur_substr is None:
350 | is_bad = True
351 | break
352 | sub_tokens.append(cur_substr)
353 | start = end
354 |
355 | if is_bad:
356 | output_tokens.append(self.unk_token)
357 | else:
358 | output_tokens.extend(sub_tokens)
359 | return output_tokens
360 |
361 |
362 | def _is_whitespace(char):
363 | """Checks whether `char` is a whitespace character."""
364 | # \t, \n, and \r are technically control characters but we treat them
365 | # as whitespace since they are generally considered as such.
366 | if char == " " or char == "\t" or char == "\n" or char == "\r":
367 | return True
368 | cat = unicodedata.category(char)
369 | if cat == "Zs":
370 | return True
371 | return False
372 |
373 |
374 | def _is_control(char):
375 | """Checks whether `char` is a control character."""
376 | # These are technically control characters but we count them as whitespace
377 | # characters.
378 | if char == "\t" or char == "\n" or char == "\r":
379 | return False
380 | cat = unicodedata.category(char)
381 | if cat in ("Cc", "Cf"):
382 | return True
383 | return False
384 |
385 |
386 | def _is_punctuation(char):
387 | """Checks whether `char` is a punctuation character."""
388 | cp = ord(char)
389 | # We treat all non-letter/number ASCII as punctuation.
390 | # Characters such as "^", "$", and "`" are not in the Unicode
391 | # Punctuation class but we treat them as punctuation anyway, for
392 | # consistency.
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
395 | return True
396 | cat = unicodedata.category(char)
397 | if cat.startswith("P"):
398 | return True
399 | return False
400 |
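`WordpieceTokenizer.tokenize` above performs greedy longest-match-first segmentation. It repeatedly takes the longest prefix of the remaining characters that appears in the vocabulary, with continuation pieces carrying a `##` prefix. If any position cannot be matched, the whole word is mapped to `[UNK]`. The standalone sketch below applies the same loop to a single word. The function name and the three-entry toy vocabulary are hypothetical; the real vocabulary comes from `vocab.txt`.

```python
def wordpiece(token, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    """Greedy longest-match-first WordPiece split of one word (illustrative sketch)."""
    if len(token) > max_input_chars_per_word:
        return [unk_token]
    pieces = []
    start = 0
    while start < len(token):
        end = len(token)
        cur_substr = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            substr = token[start:end]
            if start > 0:
                substr = "##" + substr  # continuation pieces carry the ## prefix
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1
        if cur_substr is None:
            return [unk_token]  # one unmatched position turns the whole word into [UNK]
        pieces.append(cur_substr)
        start = end
    return pieces


toy_vocab = {"un", "##aff", "##able"}
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece("unx", toy_vocab))        # ['[UNK]']
```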
--------------------------------------------------------------------------------
/uncased_L-12_H-768_A-12/README.md:
--------------------------------------------------------------------------------
1 | ====== Download the BERT base model ======
2 |
3 | link: https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
4 | Unzip it and place the files under: ./uncased_L-12_H-768_A-12
--------------------------------------------------------------------------------