├── data ├── run_train_word2vec.sh ├── train_word2vec.py └── prepare.py ├── scripts └── esim │ ├── run.sh │ ├── data_iterator.py │ └── main.py ├── README.md └── LICENSE /data/run_train_word2vec.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | python train_word2vec.py ./ubuntu_data/train.txt ./embedding_w2v_d300.txt 3 | -------------------------------------------------------------------------------- /scripts/esim/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR=../../data/ubuntu_data_concat 4 | 5 | CUDA_VISIBLE_DEVICES=0 python -u main.py \ 6 | --train_file=$DATA_DIR/train.txt \ 7 | --valid_file=$DATA_DIR/valid.txt \ 8 | --test_file=$DATA_DIR/test.txt \ 9 | --vocab_file=$DATA_DIR/vocab.txt \ 10 | --output_dir=result \ 11 | --embedding_file=../../data/embedding_w2v_d300.txt \ 12 | --maxlen_1=400 \ 13 | --maxlen_2=150 \ 14 | --hidden_size=300 \ 15 | --train_batch_size=16 \ 16 | --valid_batch_size=16 \ 17 | --test_batch_size=16 \ 18 | --fix_embedding=True \ 19 | --patience=1 \ 20 | > log.txt 2>&1 & 21 | 22 | -------------------------------------------------------------------------------- /data/train_word2vec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (C) 2019 Alibaba Group Holding Limited 4 | # Copyright (C) 2017 Pan Yang (panyangnlp@gmail.com) 5 | 6 | from __future__ import print_function 7 | 8 | import logging 9 | import os 10 | import sys 11 | import multiprocessing 12 | 13 | from gensim.models import Word2Vec 14 | from gensim.models.word2vec import LineSentence 15 | 16 | if __name__ == '__main__': 17 | program = os.path.basename(sys.argv[0]) 18 | logger = logging.getLogger(program) 19 | 20 | logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s') 21 | logging.root.setLevel(level=logging.INFO) 22 | logger.info("running %s" % ' 
'.join(sys.argv)) 23 | 24 | # check and process input arguments 25 | if len(sys.argv) < 3: 26 | print("Using: python train_word2vec.py [input_text] [output_word_vector]") 27 | sys.exit(1) 28 | input_file, output_file = sys.argv[1:3] 29 | sentences = [] 30 | for line in open(input_file): 31 | texts = line.decode("utf-8").replace("\n", "").split("\t")[1:] 32 | for uter in texts: 33 | sentences.append(uter.split()) 34 | 35 | model = Word2Vec(sentences, size=300, window=5, min_count=5, sg=1, 36 | workers=multiprocessing.cpu_count()) 37 | 38 | model.wv.save_word2vec_format(output_file, binary=False) 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESIM for Multi-turn Response Selection Task 2 | 3 | ## Introduction 4 | If you use this code as part of any published research, please acknowledge one of the following papers. 5 | 6 | ``` 7 | @inproceedings{chen2019sequential, 8 | title={Sequential Matching Model for End-to-end Multi-turn Response Selection}, 9 | author={Chen, Qian and Wang, Wen}, 10 | booktitle={ICASSP 2019-2019 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 11 | pages={7350--7354}, 12 | year={2019}, 13 | organization={IEEE} 14 | } 15 | ``` 16 | 17 | ``` 18 | @article{DBLP:journals/corr/abs-1901-02609, 19 | author = {Chen, Qian and Wang, Wen}, 20 | title = {Sequential Attention-based Network for Noetic End-to-End Response Selection}, 21 | journal = {CoRR}, 22 | volume = {abs/1901.02609}, 23 | year = {2019}, 24 | url = {http://arxiv.org/abs/1901.02609}, 25 | } 26 | ``` 27 | 28 | ## Requirement 29 | 1. gensim 30 | ```bash 31 | pip install gensim 32 | ``` 33 | 34 | 2. Tensorflow 1.9-1.12 + Python2.7 35 | 36 | ## Steps 37 | 1. Download the [Ubuntu dataset](https://www.dropbox.com/s/2fdn26rj6h9bpvl/ubuntu_data.zip?dl=0 38 | ) released by (Xu et al, 2017) 39 | 40 | 2. 
Unzip the dataset and put data directory into `data/` 41 | 42 | 3. Preprocess dataset, including concatenatate context and build vocabulary 43 | ```bash 44 | cd data 45 | python prepare.py 46 | ``` 47 | 48 | 4. Train word2vec 49 | ```bash 50 | bash run_train_word2vec.sh 51 | ``` 52 | 53 | 5. Train and test ESIM, the log information is in `log.txt` file. You could find an example log file in `log_example.txt`. 54 | ```bash 55 | cd scripts/esim 56 | bash run.sh 57 | ``` -------------------------------------------------------------------------------- /data/prepare.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Generate dictionary file to plain format, one line one token""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import sys 23 | import os 24 | import numpy 25 | 26 | from collections import OrderedDict 27 | 28 | 29 | def build_dictionary(org_path, dst_path, is_lowercase=False): 30 | token_to_freqs = OrderedDict() 31 | count = 0 32 | with open(org_path, 'r') as f: 33 | for line in f: 34 | if is_lowercase: 35 | line = line.lower() 36 | arr = line.strip().split('\t') 37 | assert len(arr) == 3 38 | 39 | for text in arr[1:]: 40 | tokens = text.split(' ') 41 | for w in tokens: 42 | if w in token_to_freqs: 43 | token_to_freqs[w] += 1 44 | else: 45 | token_to_freqs[w] = 1 46 | if count % 10000 == 0: 47 | print(count) 48 | count += 1 49 | 50 | tokens = token_to_freqs.keys() 51 | freqs = token_to_freqs.values() 52 | 53 | sorted_idx = numpy.argsort(freqs) 54 | sorted_tokens = [tokens[i] for i in sorted_idx[::-1]] 55 | 56 | token_to_idx = OrderedDict() 57 | token_to_idx['_PAD_'] = 0 # default, padding 58 | token_to_idx['_UNK_'] = 1 # out-of-vocabulary 59 | token_to_idx['_BOS_'] = 2 # begin of sentence token 60 | token_to_idx['_EOS_'] = 3 # end of sentence token 61 | 62 | for i, t in enumerate(sorted_tokens): 63 | token_to_idx[t] = i + 4 64 | 65 | with open(dst_path, 'w') as f: 66 | for t in token_to_idx.keys(): 67 | f.write(t + '\n') 68 | 69 | print('Dict size', len(token_to_idx)) 70 | 71 | 72 | def concat_context(org_file, dst_file): 73 | with open(org_file, 'r') as fi: 74 | with open(dst_file, 'w') as fo: 75 | for idx, line in enumerate(fi): 76 | arr = line.strip().split('\t') 77 | label = arr[0] 78 | context = ' __eou__ __eot__ '.join( 79 | arr[1:-1]) + ' __eou__ __eot__ ' 80 | response = arr[-1] 81 | fo.write('\t'.join([label, context, response]) + '\n') 82 | 83 | 84 | def make_dirs(dirs): 85 | for d in dirs: 86 | if not os.path.exists(d): 87 | os.makedirs(d) 88 | 89 | 
90 | if __name__ == '__main__': 91 | base_dir = os.path.dirname(os.path.realpath(__file__)) 92 | org_dir = os.path.join(base_dir, 'ubuntu_data/') 93 | dst_dir = os.path.join(base_dir, 'ubuntu_data_concat/') 94 | make_dirs([dst_dir]) 95 | 96 | print("***** Concatenate Context ***** ") 97 | concat_context(os.path.join(org_dir, 'test.txt'), 98 | os.path.join(dst_dir, 'test.txt')) 99 | concat_context(os.path.join(org_dir, 'valid.txt'), 100 | os.path.join(dst_dir, 'valid.txt')) 101 | concat_context(os.path.join(org_dir, 'train.txt'), 102 | os.path.join(dst_dir, 'train.txt')) 103 | 104 | print("***** Obtain Dictionary ***** ") 105 | build_dictionary(os.path.join(dst_dir, 'train.txt'), 106 | os.path.join(dst_dir, 'vocab.txt'), 107 | is_lowercase=False) 108 | -------------------------------------------------------------------------------- /scripts/esim/data_iterator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Text iterator 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy 24 | import random 25 | import math 26 | 27 | 28 | class TextIterator: 29 | """Create text iterator for sequence pair classification problem. 
class TextIterator:
    """Create text iterator for sequence pair classification problem.

    Data file is assumed to contain one sample per line. The format is
    label\tsequence1\tsequence2.

    Args:
        input_file: path of the input text file.
        token_to_idx: a dictionary, which convert token to index
        batch_size: mini-batch size
        vocab_size: limit on the size of the vocabulary, if token index is
            larger than vocab_size, return UNK (index 1)
        shuffle: Boolean; if true, we will first sort a buffer of samples by
            sequence length, and then shuffle it by batch-level.
        factor: buffer size is factor * batch-size
    """

    def __init__(self, input_file, token_to_idx,
                 batch_size=128, vocab_size=-1, shuffle=True, factor=20):
        self.input_file = open(input_file, 'r')
        self.token_to_idx = token_to_idx
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.shuffle = shuffle
        self.end_of_data = False
        self.instance_buffer = []
        # buffer for shuffle
        self.max_buffer_size = batch_size * factor

    def __iter__(self):
        return self

    def next(self):
        """Return the next batch: a list of [label, sent1_ids, sent2_ids].

        Each sentence is wrapped with _BOS_ (2) / _EOS_ (3) and mapped to
        indices; out-of-vocabulary tokens map to _UNK_ (1). Raises
        StopIteration at end of file and rewinds for the next epoch.
        """
        if self.end_of_data:
            self.end_of_data = False
            self.input_file.seek(0)
            raise StopIteration

        instance = []

        if len(self.instance_buffer) == 0:
            # Refill the buffer with up to max_buffer_size samples.
            for _ in range(self.max_buffer_size):
                line = self.input_file.readline()
                if line == "":
                    break
                arr = line.strip().split('\t')
                assert len(arr) == 3
                self.instance_buffer.append(
                    [arr[0], arr[1].split(' '), arr[2].split(' ')])

            if self.shuffle:
                # Sort buffered samples by total token count so each batch
                # holds similar-length samples (less padding), then shuffle
                # the batches themselves.
                # BUG FIX: the original computed ins[1] + ins[2], which
                # concatenates the token *lists* instead of summing their
                # *lengths*, so the sort was not by length at all.
                length_list = [len(ins[1]) + len(ins[2])
                               for ins in self.instance_buffer]

                length_idx = numpy.array(length_list).argsort()

                # shuffle mini-batch order; list(...) is needed because
                # random.shuffle cannot shuffle a range object on Python 3.
                tindex = []
                small_index = list(range(
                    int(math.ceil(len(length_idx) * 1. / self.batch_size))))
                random.shuffle(small_index)
                for i in small_index:
                    if (i + 1) * self.batch_size > len(length_idx):
                        tindex.extend(length_idx[i * self.batch_size:])
                    else:
                        tindex.extend(
                            length_idx[i * self.batch_size:(i + 1) * self.batch_size])

                self.instance_buffer = [self.instance_buffer[i] for i in tindex]

        if len(self.instance_buffer) == 0:
            self.end_of_data = False
            self.input_file.seek(0)
            raise StopIteration

        try:
            # actual work here
            while True:
                # read from source file and map to word index
                try:
                    current_instance = self.instance_buffer.pop(0)
                except IndexError:
                    break

                label = current_instance[0]
                sent1 = current_instance[1]
                sent2 = current_instance[2]

                sent1.insert(0, '_BOS_')
                sent1.append('_EOS_')
                sent1 = [self.token_to_idx[w]
                         if w in self.token_to_idx else 1 for w in sent1]
                if self.vocab_size > 0:
                    sent1 = [w if w < self.vocab_size else 1 for w in sent1]

                sent2.insert(0, '_BOS_')
                sent2.append('_EOS_')
                sent2 = [self.token_to_idx[w] if w in self.token_to_idx else 1
                         for w in sent2]
                if self.vocab_size > 0:
                    sent2 = [w if w < self.vocab_size else 1 for w in sent2]

                instance.append([label, sent1, sent2])

                if len(instance) >= self.batch_size:
                    break
        except IOError:
            self.end_of_data = True

        if len(instance) <= 0:
            self.end_of_data = False
            self.input_file.seek(0)
            raise StopIteration

        return instance

    # Python 3 iterator protocol (backward-compatible alias; the py2-style
    # .next() calls keep working).
    __next__ = next
Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /scripts/esim/main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """ESIM""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import sys 24 | import time 25 | import collections 26 | import numpy 27 | import tensorflow as tf 28 | 29 | from data_iterator import TextIterator 30 | 31 | 32 | flags = tf.flags 33 | FLAGS = flags.FLAGS 34 | 35 | # Required parameters 36 | flags.DEFINE_string("train_file", None, "The train file path.") 37 | flags.DEFINE_string("valid_file", None, "The validation file path.") 38 | flags.DEFINE_string("test_file", None, "The test file path.") 39 | flags.DEFINE_string("vocab_file", None, "The vocabulary file.") 40 | flags.DEFINE_string("output_dir", None, 41 | "The output directory where the model checkpoints will be written.") 42 | 43 | # Other parameters 44 | flags.DEFINE_string("embedding_file", None, "The pre-trained embedding file path.") 45 | 46 | flags.DEFINE_bool("fix_embedding", True, "Whether to fix embedding during training.") 47 | flags.DEFINE_bool("use_cudnn", True, "Whether to cudnn version BiLSTM.") 48 | 49 | flags.DEFINE_integer("disp_freq", 10, "The display frequence of log information.") 50 | flags.DEFINE_integer("maxlen_1", 400, "The maximum total first input sequence length.") 51 | flags.DEFINE_integer("maxlen_2", 150, "The maximum total second input sequence length.") 52 | flags.DEFINE_integer("hidden_size", 300, "The hidden size of hidden states for BiLSTM and MLP.") 53 | flags.DEFINE_integer("dim_word", 300, "The dim of word embedding.") 54 | flags.DEFINE_integer("patience", 1, "The early stopping patience.") 55 | flags.DEFINE_integer("vocab_size", 100000, "The vocab size.") 56 | flags.DEFINE_integer("train_batch_size", 16, "Total batch size for training.") 57 | flags.DEFINE_integer("valid_batch_size", 16, "Total batch size for validation.") 58 | flags.DEFINE_integer("test_batch_size", 16, "Total batch size for test.") 59 | flags.DEFINE_integer("max_train_epochs", 50, 
"Max number of training epochs to perform.") 60 | flags.DEFINE_integer("num_labels", 2, "Number of labels.") 61 | 62 | flags.DEFINE_float("learning_rate", 2e-4, "The initial learning rate for Adam.") 63 | flags.DEFINE_float("clip_c", 10., "Gradient clipping threshold.") 64 | 65 | 66 | def prepare_data(instance): 67 | """ padding the data with minibatch 68 | Args: 69 | instance: [list, list, list] for [labels, seqs_x, seqs_y] 70 | 71 | Return: 72 | x: int64 numpy.array of shape [seq_length_x, batch_size]. 73 | x_mask: float32 numpy.array of shape [seq_length_x, batch_size]. 74 | y: int64 numpy.array of shape [seq_length_y, batch_size]. 75 | y_mask: float32 numpy.array of shape [seq_length_y, batch_size]. 76 | l: int64 numpy.array of shape [batch_size, ]. 77 | """ 78 | 79 | seqs_x = [] 80 | seqs_y = [] 81 | labels = [] 82 | 83 | for ins in instance: 84 | labels.append(ins[0]) 85 | seqs_x.append(ins[1]) 86 | seqs_y.append(ins[2]) 87 | 88 | lengths_x = [len(s) for s in seqs_x] 89 | lengths_y = [len(s) for s in seqs_y] 90 | 91 | maxlen_1 = FLAGS.maxlen_1 92 | maxlen_2 = FLAGS.maxlen_2 93 | 94 | new_seqs_x = [] 95 | new_seqs_y = [] 96 | new_lengths_x = [] 97 | new_lengths_y = [] 98 | new_labels = [] 99 | for l_x, s_x, l_y, s_y, l in zip(lengths_x, seqs_x, lengths_y, seqs_y, labels): 100 | if l_x > maxlen_1: 101 | new_seqs_x.append(s_x[-maxlen_1:]) 102 | new_lengths_x.append(maxlen_1) 103 | else: 104 | new_seqs_x.append(s_x) 105 | new_lengths_x.append(l_x) 106 | if l_y > maxlen_2: 107 | new_seqs_y.append(s_y[:maxlen_2]) 108 | new_lengths_y.append(maxlen_2) 109 | else: 110 | new_seqs_y.append(s_y) 111 | new_lengths_y.append(l_y) 112 | 113 | new_labels.append(l) 114 | 115 | lengths_x = new_lengths_x 116 | seqs_x = new_seqs_x 117 | lengths_y = new_lengths_y 118 | seqs_y = new_seqs_y 119 | labels = new_labels 120 | 121 | if len(lengths_x) < 1 or len(lengths_y) < 1: 122 | return None 123 | 124 | n_samples = len(seqs_x) 125 | maxlen_1 = numpy.max(lengths_x) 126 | maxlen_2 = 
numpy.max(lengths_y) 127 | 128 | x = numpy.zeros((maxlen_1, n_samples)).astype("int64") 129 | y = numpy.zeros((maxlen_2, n_samples)).astype("int64") 130 | x_mask = numpy.zeros((maxlen_1, n_samples)).astype("float32") 131 | y_mask = numpy.zeros((maxlen_2, n_samples)).astype("float32") 132 | l = numpy.zeros((n_samples,)).astype("int64") 133 | 134 | for idx, (s_x, s_y, ll) in enumerate(zip(seqs_x, seqs_y, labels)): 135 | x[:lengths_x[idx], idx] = s_x 136 | x_mask[:lengths_x[idx], idx] = 1. 137 | y[:lengths_y[idx], idx] = s_y 138 | y_mask[:lengths_y[idx], idx] = 1. 139 | l[idx] = ll 140 | 141 | return (x, x_mask, y, y_mask, l) 142 | 143 | 144 | def bilstm_layer_cudnn(input_data, num_layers, rnn_size, keep_prob=1.): 145 | """Multi-layer BiLSTM cudnn version, faster 146 | Args: 147 | input_data: float32 Tensor of shape [seq_length, batch_size, dim]. 148 | num_layers: int64 scalar, number of layers. 149 | rnn_size: int64 scalar, hidden size for undirectional LSTM. 150 | keep_prob: float32 scalar, keep probability of dropout between BiLSTM layers 151 | 152 | Return: 153 | output: float32 Tensor of shape [seq_length, batch_size, dim * 2] 154 | 155 | """ 156 | with tf.variable_scope("bilstm", reuse=tf.AUTO_REUSE): 157 | lstm = tf.contrib.cudnn_rnn.CudnnLSTM( 158 | num_layers=num_layers, 159 | num_units=rnn_size, 160 | input_mode="linear_input", 161 | direction="bidirectional", 162 | dropout=1 - keep_prob) 163 | 164 | # to do, how to include input_mask 165 | outputs, output_states = lstm(inputs=input_data) 166 | 167 | return outputs 168 | 169 | 170 | def bilstm_layer(input_data, num_layers, rnn_size, keep_prob=1.): 171 | """Multi-layer BiLSTM 172 | Args: 173 | input_data: float32 Tensor of shape [seq_length, batch_size, dim]. 174 | num_layers: int64 scalar, number of layers. 175 | rnn_size: int64 scalar, hidden size for undirectional LSTM. 
def load_word_embedding(token_to_idx):
    """Build the embedding variable, optionally initialized from a
    pre-trained word2vec text file (``FLAGS.embedding_file``).

    Tokens without a pre-trained vector keep a small random init.

    Args:
        token_to_idx: dict mapping token string to vocabulary index.

    Returns:
        embedding: float32 variable of shape [vocab_size, dim_word];
        frozen when ``FLAGS.fix_embedding`` is true.
    """
    # Small random init for out-of-embedding tokens.
    embedding_np = 0.02 * \
        numpy.random.randn(FLAGS.vocab_size, FLAGS.dim_word).astype("float32")

    if FLAGS.embedding_file:
        with open(FLAGS.embedding_file, "r") as f:
            for line in f:
                tokens = line.strip().split(" ")
                token = tokens[0]
                if token in token_to_idx and token_to_idx[token] < FLAGS.vocab_size:
                    # BUG FIX: the original assigned `map(float, tokens[1:])`,
                    # which under Python 3 is a lazy iterator and cannot be
                    # broadcast into a numpy row; convert eagerly instead.
                    embedding_np[token_to_idx[token], :] = numpy.array(
                        tokens[1:], dtype="float32")

    embedding = tf.get_variable(
        "embedding",
        shape=[FLAGS.vocab_size, FLAGS.dim_word],
        initializer=tf.constant_initializer(embedding_np),
        trainable=not FLAGS.fix_embedding)
    return embedding
def create_model(embedding):
    """Build the ESIM computation graph: embedding, encoding, local
    inference, composition, pooling and the final MLP classifier.

    Args:
        embedding: float32 Tensor of shape [vocab_size, dim_word].

    Returns:
        probability: float32 Tensor [batch_size, num_labels], softmax output.
        cost: float32 Tensor [batch_size], per-sample cross entropy.
    """
    # Placeholders; all sequence tensors are time-major
    # ([seq_length, batch_size]).
    x1 = tf.placeholder(tf.int64, shape=[None, None], name="x1")
    x1_mask = tf.placeholder(tf.float32, shape=[None, None], name="x1_mask")
    x2 = tf.placeholder(tf.int64, shape=[None, None], name="x2")
    x2_mask = tf.placeholder(tf.float32, shape=[None, None], name="x2_mask")
    y = tf.placeholder(tf.int64, shape=[None], name="y")
    keep_rate = tf.placeholder(tf.float32, [], name="keep_rate")

    # Embed, apply dropout, then zero out padded positions.
    emb1 = tf.nn.dropout(tf.nn.embedding_lookup(embedding, x1), keep_rate)
    emb2 = tf.nn.dropout(tf.nn.embedding_lookup(embedding, x2), keep_rate)
    emb1 = emb1 * tf.expand_dims(x1_mask, -1)
    emb2 = emb2 * tf.expand_dims(x2_mask, -1)

    # Both encoder variants share the same signature.
    encoder = bilstm_layer_cudnn if FLAGS.use_cudnn else bilstm_layer

    # Input encoding of the sentence pair.
    with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
        x1_enc = encoder(emb1, 1, FLAGS.hidden_size)
        x2_enc = encoder(emb2, 1, FLAGS.hidden_size)

    x1_enc = x1_enc * tf.expand_dims(x1_mask, -1)
    x2_enc = x2_enc * tf.expand_dims(x2_mask, -1)

    # Local inference (cross attention), then ESIM feature enhancement:
    # [enc; dual; enc * dual; enc - dual].
    x1_dual, x2_dual = local_inference(x1_enc, x1_mask, x2_enc, x2_mask)
    x1_match = tf.concat([x1_enc, x1_dual, x1_enc * x1_dual, x1_enc - x1_dual], 2)
    x2_match = tf.concat([x2_enc, x2_dual, x2_enc * x2_dual, x2_enc - x2_dual], 2)

    # Project the enhanced (high-dim) features back down to hidden_size.
    with tf.variable_scope("projection", reuse=tf.AUTO_REUSE):
        x1_match_mapping = tf.layers.dense(
            x1_match, FLAGS.hidden_size,
            activation=tf.nn.relu,
            name="fnn",
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))
        # Same dense layer (shared weights) for the second sequence.
        x2_match_mapping = tf.layers.dense(
            x2_match, FLAGS.hidden_size,
            activation=tf.nn.relu,
            name="fnn",
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            reuse=True)

    x1_match_mapping = tf.nn.dropout(x1_match_mapping, keep_rate)
    x2_match_mapping = tf.nn.dropout(x2_match_mapping, keep_rate)

    # Inference composition over the projected features.
    with tf.variable_scope("composition", reuse=tf.AUTO_REUSE):
        x1_cmp = encoder(x1_match_mapping, 1, FLAGS.hidden_size)
        x2_cmp = encoder(x2_match_mapping, 1, FLAGS.hidden_size)

    # Masked mean and max pooling over time for both sequences.
    logit_x1_sum = tf.reduce_sum(x1_cmp * tf.expand_dims(x1_mask, -1), 0) / \
        tf.expand_dims(tf.reduce_sum(x1_mask, 0), 1)
    logit_x1_max = tf.reduce_max(x1_cmp * tf.expand_dims(x1_mask, -1), 0)
    logit_x2_sum = tf.reduce_sum(x2_cmp * tf.expand_dims(x2_mask, -1), 0) / \
        tf.expand_dims(tf.reduce_sum(x2_mask, 0), 1)
    logit_x2_max = tf.reduce_max(x2_cmp * tf.expand_dims(x2_mask, -1), 0)

    logit = tf.concat([logit_x1_sum, logit_x1_max, logit_x2_sum, logit_x2_max], 1)

    # Final MLP classifier.
    with tf.variable_scope("classifier", reuse=tf.AUTO_REUSE):
        logit = tf.nn.dropout(logit, keep_rate)
        logit = tf.layers.dense(
            logit, FLAGS.hidden_size,
            activation=tf.nn.tanh,
            name="fnn1",
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))

        logit = tf.nn.dropout(logit, keep_rate)
        # The ranking metrics downstream assume a binary positive/negative
        # label set.
        assert FLAGS.num_labels == 2
        logit = tf.layers.dense(
            logit, FLAGS.num_labels,
            activation=None,
            name="fnn2",
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02))

    cost = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y, logits=logit)
    probability = tf.nn.softmax(logit)

    return probability, cost
def average_precision(sort_data):
    """Average precision (AP) for one ranked candidate list.

    Example: for returned labels 1, 0, 0, 1, 1, 1 the precision at each
    relevant position is 1/1, 2/4, 3/5, 4/6, and
    AP = (1 + 2/4 + 3/5 + 4/6) / 4 ~= 0.69.

    Args:
        sort_data: list of (score, gold_label) tuples already sorted by
            score descending; gold_label is 0 or 1.

    Returns:
        Average precision as a float. Returns 0.0 when the list contains
        no relevant item (the original raised ZeroDivisionError there).
    """
    count_gold = 0
    sum_precision = 0.0

    # enumerate from 1 so `rank` is the 1-based ranking position.
    for rank, (_, label) in enumerate(sort_data, start=1):
        if label == 1:
            count_gold += 1
            sum_precision += 1. * count_gold / rank

    if count_gold == 0:
        # Robustness fix: avoid ZeroDivisionError on a session without
        # any positive label.
        return 0.0

    return sum_precision / count_gold
def evaluation_one_session(data):
    """Rank one session's candidates by score and compute all metrics.

    Args:
        data: list of (score, gold_label) tuples for one session.

    Returns:
        Tuple (AP, RR, Precision@1, Recall@1, Recall@2, Recall@5).
    """
    ranked = sorted(data, key=lambda pair: pair[0], reverse=True)
    return (average_precision(ranked),
            reciprocal_rank(ranked),
            precision_at_position_1(ranked),
            recall_at_position_k(ranked, 1),
            recall_at_position_k(ranked, 2),
            recall_at_position_k(ranked, 5))
def load_vocab(vocab_file):
    """Load a vocabulary file (one token per line) into an ordered
    token -> index mapping.

    Lines are stripped of surrounding whitespace; indices follow file
    order, so a blank line yields the empty-string token.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r") as f:
        for index, line in enumerate(f):
            vocab[line.strip()] = index
    return vocab
batch_size=FLAGS.train_batch_size, shuffle=False) 643 | 644 | # Initialize the word embedding 645 | tf.logging.info("***** Initialize Word Embedding *****") 646 | embedding = load_word_embedding(token_to_idx) 647 | 648 | # Build graph 649 | tf.logging.info("***** Build Computation Graph *****") 650 | probability_op, cost_op = create_model(embedding) 651 | loss_op = tf.reduce_mean(cost_op) 652 | 653 | lr = tf.Variable(0.0, name="learning_rate", trainable=False) 654 | 655 | optimizer = tf.train.AdamOptimizer(learning_rate=lr) 656 | 657 | tf.logging.info("***** Trainable Variables *****") 658 | 659 | tvars = tf.trainable_variables() 660 | for var in tvars: 661 | tf.logging.info(" name = %s, shape = %s", var.name, var.shape) 662 | 663 | if FLAGS.clip_c > 0.: 664 | grads, _ = tf.clip_by_global_norm( 665 | tf.gradients(cost_op, tvars), FLAGS.clip_c) 666 | 667 | train_op = optimizer.apply_gradients(zip(grads, tvars)) 668 | init = tf.global_variables_initializer() 669 | saver = tf.train.Saver(max_to_keep=5) 670 | 671 | # training process 672 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 673 | sess.run(init) 674 | 675 | uidx = 0 676 | bad_counter = 0 677 | history_errs = [] 678 | 679 | current_lr = FLAGS.learning_rate 680 | sess.run(tf.assign(lr, current_lr)) 681 | 682 | for eidx in range(FLAGS.max_train_epochs): 683 | tf.logging.info("***** Training at Epoch %s *****", eidx) 684 | n_samples = 0 685 | for instance in train: 686 | n_samples += len(instance) 687 | uidx += 1 688 | 689 | (batch_x1, batch_x1_mask, batch_x2, batch_x2_mask, batch_y) = prepare_data( 690 | instance) 691 | 692 | if batch_x1 is None: 693 | tf.logging.info("Minibatch with zero sample") 694 | uidx -= 1 695 | continue 696 | 697 | ud_start = time.time() 698 | _, loss = sess.run([train_op, loss_op], 699 | feed_dict={ 700 | "x1:0": batch_x1, "x1_mask:0": batch_x1_mask, 701 | "x2:0": batch_x2, "x2_mask:0": batch_x2_mask, 702 | "y:0": batch_y, "keep_rate:0": 0.5}) 703 | ud = 
time.time() - ud_start 704 | 705 | if numpy.mod(uidx, FLAGS.disp_freq) == 0: 706 | tf.logging.info( 707 | "epoch %s update %s loss %s samples/sec %s", eidx, uidx, loss, 1. * batch_x1.shape[1] / ud) 708 | 709 | tf.logging.info("***** Evaluation at Epoch %s *****", eidx) 710 | tf.logging.info("seen samples %s each epoch", n_samples) 711 | tf.logging.info("current learning rate: %s", current_lr) 712 | 713 | # validate model on validation set and early stop if necessary 714 | valid_metrics, valid_scores = predict_metrics( 715 | sess, cost_op, probability_op, valid) 716 | 717 | # select best model based on recall@1 of validation set 718 | valid_err = 1.0 - valid_metrics[3] 719 | history_errs.append(valid_err) 720 | 721 | tf.logging.info( 722 | "valid set: MAP %s MRR %s Precision@1 %s Recall@1 %s Recall@2 %s Recall@5 %s", *valid_metrics) 723 | 724 | test_metrics, test_scores = predict_metrics( 725 | sess, cost_op, probability_op, test) 726 | 727 | tf.logging.info( 728 | "test set: MAP %s MRR %s Precision@1 %s Recall@1 %s Recall@2 %s Recall@5 %s", *test_metrics) 729 | 730 | if eidx == 0 or valid_err <= numpy.array(history_errs).min(): 731 | best_epoch_num = eidx 732 | tf.logging.info( 733 | "saving current best model at epoch %s based on metrics on valid set", best_epoch_num) 734 | saver.save(sess, os.path.join( 735 | FLAGS.output_dir, "model_epoch_{}.ckpt".format(best_epoch_num))) 736 | 737 | if valid_err > numpy.array(history_errs).min(): 738 | bad_counter += 1 739 | tf.logging.info("bad_counter: %s", bad_counter) 740 | 741 | current_lr = current_lr * 0.5 742 | sess.run(tf.assign(lr, current_lr)) 743 | tf.logging.info( 744 | "half the current learning rate to %s", current_lr) 745 | 746 | if bad_counter > FLAGS.patience: 747 | tf.logging.info("***** Early Stop *****") 748 | estop = True 749 | break 750 | 751 | # evaluation process 752 | tf.logging.info("***** Final Result ***** ") 753 | tf.logging.info( 754 | "restore best model at epoch %s ", best_epoch_num) 755 | 
if __name__ == "__main__":
    # Fail fast when any required data-path flag is missing.
    for required_flag in ("train_file", "valid_file", "test_file",
                          "vocab_file", "output_dir"):
        flags.mark_flag_as_required(required_flag)

    tf.app.run()