├── .gitignore
├── LICENSE
├── README.md
├── attention_decoder.py
├── copynet
│   ├── attention_decoder.py
│   ├── dynamic_decoder.py
│   ├── main.py
│   ├── model.py
│   └── output_projection.py
├── dynamic_decoder.py
├── image
│   ├── demo.png
│   └── evaluation.png
├── main.py
├── memnet
│   ├── attention_decoder.py
│   ├── dynamic_decoder.py
│   ├── main.py
│   ├── model.py
│   └── output_projection.py
├── model.py
└── output_projection.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.py[cod] 3 | *.so 4 | *.egg 5 | *.egg-info 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Commonsense Knowledge Aware Conversation Generation with Graph Attention 2 | 3 | ## Introduction 4 | 5 | Commonsense knowledge is vital to many natural language processing tasks. 
In this [paper](https://www.ijcai.org/proceedings/2018/0643.pdf), we present a novel open-domain conversation generation model to demonstrate how large-scale commonsense knowledge can facilitate language understanding and generation. Given a user post, the model retrieves relevant knowledge graphs from a knowledge base and then encodes the graphs with a static graph attention mechanism, which augments the semantic information of the post and thus supports better understanding of the post. Then, during word generation, the model attentively reads the retrieved knowledge graphs and the knowledge triples within each graph to facilitate better generation through a dynamic graph attention mechanism, as shown in Figure 1. 6 | 7 | ![image](https://raw.githubusercontent.com/tuxchow/ccm/master/image/demo.png) 8 | 9 | This project is a TensorFlow implementation of our work, [CCM](http://coai.cs.tsinghua.edu.cn/hml/media/files/2018_commonsense_ZhouHao_3_TYVQ7Iq.pdf). 10 | 11 | ## Dependencies 12 | 13 | * Python 2.7 14 | * NumPy 15 | * TensorFlow 1.3.0 16 | 17 | ## Quick Start 18 | 19 | * Dataset 20 | 21 | The Commonsense Conversation Dataset contains one-turn post-response pairs with the corresponding commonsense knowledge graphs. Each pair is associated with some knowledge graphs retrieved from ConceptNet. We have applied filtering rules to retain high-quality and useful knowledge graphs. 22 | 23 | Please [download](http://coai.cs.tsinghua.edu.cn/hml/dataset/#commonsense) the Commonsense Conversation Dataset into the `data` directory. 24 | 25 | * Train 26 | 27 | ```python main.py ``` 28 | 29 | The model reaches the expected performance after about 20 epochs of training. 30 | 31 | * Test 32 | 33 | ```python main.py --is_train False ``` 34 | 35 | This command tests the model: the statistical results and the generated text are written to the 'test.res' file and the 'test.log' file, respectively. 36 | 37 | 38 | ## Details 39 | 40 | ### Training 41 | 42 | You can change the model parameters with the following flags (a complete example invocation is given near the end of this README): 43 | 44 | --units xxx the number of hidden units 45 | --layers xxx the number of RNN layers 46 | --batch_size xxx the batch size to use during training 47 | --per_checkpoint xxx the number of steps between saving and evaluating the model 48 | --train_dir xxx the training directory 49 | 50 | ### Evaluation 51 | 52 | ![image](https://raw.githubusercontent.com/tuxchow/ccm/master/image/evaluation.png) 53 | 54 | ## Paper 55 | 56 | Hao Zhou, Tom Young, Minlie Huang, Haizhou Zhao, Jingfang Xu, Xiaoyan Zhu. 57 | [Commonsense Knowledge Aware Conversation Generation with Graph Attention.](http://coai.cs.tsinghua.edu.cn/hml/media/files/2018_commonsense_ZhouHao_3_TYVQ7Iq.pdf) 58 | IJCAI-ECAI 2018, Stockholm, Sweden. 59 | 60 | **Please kindly cite our paper if the paper or the code is helpful to you.** 61 | 62 | 63 | ## Acknowledgments 64 | 65 | Thanks for the kind help of Prof. Minlie Huang and Prof. Xiaoyan Zhu, and for the support of my teammates.
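## Example

The flags listed under Training can be combined into a single command. The invocation below is only a sketch: the flag names are the ones defined in `main.py`, and the values shown simply make the defaults explicit rather than recommend settings.

```python main.py --units 512 --layers 2 --batch_size 100 --per_checkpoint 1000 --train_dir ./train ```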
66 | 67 | ## License 68 | 69 | Apache License 2.0 70 | 71 | -------------------------------------------------------------------------------- /copynet/attention_decoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import tensorflow as tf 5 | 6 | from tensorflow.contrib.layers.python.layers import layers 7 | from tensorflow.python.ops import rnn_cell_impl 8 | from tensorflow.python.framework import dtypes 9 | from tensorflow.python.framework import function 10 | from tensorflow.python.framework import ops 11 | from tensorflow.python.ops import array_ops 12 | from tensorflow.python.ops import control_flow_ops 13 | from tensorflow.python.ops import gen_data_flow_ops 14 | from tensorflow.python.ops import tensor_array_ops 15 | from tensorflow.python.ops import math_ops 16 | from tensorflow.python.ops import nn_ops 17 | from tensorflow.python.ops import variable_scope 18 | from tensorflow.python.util import nest 19 | 20 | def attention_decoder_fn_train(encoder_state, 21 | attention_keys, 22 | attention_values, 23 | attention_score_fn, 24 | attention_construct_fn, 25 | output_alignments=False, 26 | max_length=None, 27 | name=None): 28 | """Attentional decoder function for `dynamic_rnn_decoder` during training. 29 | 30 | The `attention_decoder_fn_train` is a training function for an 31 | attention-based sequence-to-sequence model. It should be used when 32 | `dynamic_rnn_decoder` is in the training mode. 33 | 34 | The `attention_decoder_fn_train` is called with a set of the user arguments 35 | and returns the `decoder_fn`, which can be passed to the 36 | `dynamic_rnn_decoder`, such that 37 | 38 | ``` 39 | dynamic_fn_train = attention_decoder_fn_train(encoder_state) 40 | outputs_train, state_train = dynamic_rnn_decoder( 41 | decoder_fn=dynamic_fn_train, ...) 42 | ``` 43 | 44 | Further usage can be found in the `kernel_tests/seq2seq_test.py`. 45 | 46 | Args: 47 | encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`. 48 | attention_keys: to be compared with target states. 49 | attention_values: to be used to construct context vectors. 50 | attention_score_fn: to compute similarity between key and target states. 51 | attention_construct_fn: to build attention states. 52 | name: (default: `None`) NameScope for the decoder function; 53 | defaults to "simple_decoder_fn_train" 54 | 55 | Returns: 56 | A decoder function with the required interface of `dynamic_rnn_decoder` 57 | intended for training. 58 | """ 59 | with ops.name_scope(name, "attention_decoder_fn_train", [ 60 | encoder_state, attention_keys, attention_values, attention_score_fn, 61 | attention_construct_fn 62 | ]): 63 | pass 64 | 65 | def decoder_fn(time, cell_state, cell_input, cell_output, context_state): 66 | """Decoder function used in the `dynamic_rnn_decoder` for training. 67 | 68 | Args: 69 | time: positive integer constant reflecting the current timestep. 70 | cell_state: state of RNNCell. 71 | cell_input: input provided by `dynamic_rnn_decoder`. 72 | cell_output: output of RNNCell. 73 | context_state: context state provided by `dynamic_rnn_decoder`. 74 | 75 | Returns: 76 | A tuple (done, next state, next input, emit output, next context state) 77 | where: 78 | 79 | done: `None`, which is used by the `dynamic_rnn_decoder` to indicate 80 | that `sequence_lengths` in `dynamic_rnn_decoder` should be used. 
81 | 82 | next state: `cell_state`, this decoder function does not modify the 83 | given state. 84 | 85 | next input: `cell_input`, this decoder function does not modify the 86 | given input. The input could be modified when applying e.g. attention. 87 | 88 | emit output: `cell_output`, this decoder function does not modify the 89 | given output. 90 | 91 | next context state: `context_state`, this decoder function does not 92 | modify the given context state. The context state could be modified when 93 | applying e.g. beam search. 94 | """ 95 | with ops.name_scope( 96 | name, "attention_decoder_fn_train", 97 | [time, cell_state, cell_input, cell_output, context_state]): 98 | if cell_state is None: # first call, return encoder_state 99 | cell_state = encoder_state 100 | 101 | # init attention 102 | attention = _init_attention(encoder_state) 103 | if output_alignments: 104 | context_state = tensor_array_ops.TensorArray(dtype=dtypes.float32, tensor_array_name="alignments_ta", size=max_length, dynamic_size=True, infer_shape=False) 105 | else: 106 | # construct attention 107 | #cell_output = tf.Print(cell_output, [context_state.stack()], summarize=1e8) 108 | attention = attention_construct_fn(cell_output, attention_keys, attention_values) 109 | if output_alignments: 110 | attention, alignments = attention 111 | context_state = context_state.write(time-1, alignments) 112 | 113 | cell_output = attention 114 | 115 | # combine cell_input and attention 116 | next_input = array_ops.concat([cell_input, attention], 1) 117 | 118 | return (None, cell_state, next_input, cell_output, context_state) 119 | 120 | return decoder_fn 121 | 122 | 123 | def attention_decoder_fn_inference(output_fn, 124 | encoder_state, 125 | attention_keys, 126 | attention_values, 127 | attention_score_fn, 128 | attention_construct_fn, 129 | embeddings, 130 | start_of_sequence_id, 131 | end_of_sequence_id, 132 | maximum_length, 133 | num_decoder_symbols, 134 | dtype=dtypes.int32, 135 | selector_fn=None, 136 | imem=None, 137 | name=None): 138 | """Attentional decoder function for `dynamic_rnn_decoder` during inference. 139 | 140 | The `attention_decoder_fn_inference` is a simple inference function for a 141 | sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is 142 | in the inference mode. 143 | 144 | The `attention_decoder_fn_inference` is called with user arguments 145 | and returns the `decoder_fn`, which can be passed to the 146 | `dynamic_rnn_decoder`, such that 147 | 148 | ``` 149 | dynamic_fn_inference = attention_decoder_fn_inference(...) 150 | outputs_inference, state_inference = dynamic_rnn_decoder( 151 | decoder_fn=dynamic_fn_inference, ...) 152 | ``` 153 | 154 | Further usage can be found in the `kernel_tests/seq2seq_test.py`. 155 | 156 | Args: 157 | output_fn: An output function to project your `cell_output` onto class 158 | logits. 159 | 160 | An example of an output function; 161 | 162 | ``` 163 | tf.variable_scope("decoder") as varscope 164 | output_fn = lambda x: layers.linear(x, num_decoder_symbols, 165 | scope=varscope) 166 | 167 | outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...) 168 | logits_train = output_fn(outputs_train) 169 | 170 | varscope.reuse_variables() 171 | logits_inference, state_inference = seq2seq.dynamic_rnn_decoder( 172 | output_fn=output_fn, ...) 173 | ``` 174 | 175 | If `None` is supplied it will act as an identity function, which 176 | might be wanted when using the RNNCell `OutputProjectionWrapper`. 
177 | 178 | encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`. 179 | attention_keys: to be compared with target states. 180 | attention_values: to be used to construct context vectors. 181 | attention_score_fn: to compute similarity between key and target states. 182 | attention_construct_fn: to build attention states. 183 | embeddings: The embeddings matrix used for the decoder sized 184 | `[num_decoder_symbols, embedding_size]`. 185 | start_of_sequence_id: The start of sequence ID in the decoder embeddings. 186 | end_of_sequence_id: The end of sequence ID in the decoder embeddings. 187 | maximum_length: The maximum allowed number of time steps to decode. 188 | num_decoder_symbols: The number of classes to decode at each time step. 189 | dtype: (default: `dtypes.int32`) The default data type to use when 190 | handling integer objects. selector_fn: (default: `None`) computes, from the cell output, the probability of copying an entity from `imem` instead of generating a vocabulary word. imem: (default: `None`) an external memory of entity representations that the decoder can copy from. 191 | name: (default: `None`) NameScope for the decoder function; 192 | defaults to "attention_decoder_fn_inference" 193 | 194 | Returns: 195 | A decoder function with the required interface of `dynamic_rnn_decoder` 196 | intended for inference. 197 | """ 198 | with ops.name_scope(name, "attention_decoder_fn_inference", [ 199 | output_fn, encoder_state, attention_keys, attention_values, 200 | attention_score_fn, attention_construct_fn, embeddings, imem, 201 | start_of_sequence_id, end_of_sequence_id, maximum_length, 202 | num_decoder_symbols, dtype 203 | ]): 204 | start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype) 205 | end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype) 206 | maximum_length = ops.convert_to_tensor(maximum_length, dtype) 207 | num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype) 208 | encoder_info = nest.flatten(encoder_state)[0] 209 | batch_size = encoder_info.get_shape()[0].value 210 | if output_fn is None: 211 | output_fn = lambda x: x 212 | if batch_size is None: 213 | batch_size = array_ops.shape(encoder_info)[0] 214 | 215 | def decoder_fn(time, cell_state, cell_input, cell_output, context_state): 216 | """Decoder function used in the `dynamic_rnn_decoder` for inference. 217 | 218 | The main difference between this decoder function and the `decoder_fn` in 219 | `attention_decoder_fn_train` is how `next_cell_input` is calculated. In 220 | this decoder function we calculate the next input by applying an argmax across 221 | the feature dimension of the output from the decoder. This is a 222 | greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014) 223 | use beam-search instead. 224 | 225 | Args: 226 | time: positive integer constant reflecting the current timestep. 227 | cell_state: state of RNNCell. 228 | cell_input: input provided by `dynamic_rnn_decoder`. 229 | cell_output: output of RNNCell. 230 | context_state: context state provided by `dynamic_rnn_decoder`. 231 | 232 | Returns: 233 | A tuple (done, next state, next input, emit output, next context state) 234 | where: 235 | 236 | done: A boolean vector to indicate which sentences have reached an 237 | `end_of_sequence_id`. This is used for early stopping by the 238 | `dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with 239 | all elements as `true` is returned. 240 | 241 | next state: `cell_state`, this decoder function does not modify the 242 | given state. 243 | 244 | next input: The embedding from argmax of the `cell_output` is used as 245 | `next_input`.
246 | 247 | emit output: If `output_fn is None` the supplied `cell_output` is 248 | returned, else the `output_fn` is used to update the `cell_output` 249 | before calculating `next_input` and returning `cell_output`. 250 | 251 | next context state: `context_state`, this decoder function does not 252 | modify the given context state. The context state could be modified when 253 | applying e.g. beam search. 254 | 255 | Raises: 256 | ValueError: if cell_input is not None. 257 | 258 | """ 259 | with ops.name_scope( 260 | name, "attention_decoder_fn_inference", 261 | [time, cell_state, cell_input, cell_output, context_state]): 262 | if cell_input is not None: 263 | raise ValueError("Expected cell_input to be None, but saw: %s" % 264 | cell_input) 265 | if cell_output is None: 266 | # invariant that this is time == 0 267 | next_input_id = array_ops.ones( 268 | [batch_size,], dtype=dtype) * (start_of_sequence_id) 269 | done = array_ops.zeros([batch_size,], dtype=dtypes.bool) 270 | cell_state = encoder_state 271 | cell_output = array_ops.zeros( 272 | [num_decoder_symbols], dtype=dtypes.float32) 273 | cell_input = array_ops.gather(embeddings, next_input_id) 274 | 275 | # init attention 276 | attention = _init_attention(encoder_state) 277 | if imem is not None: 278 | context_state = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="output_ids_ta", size=maximum_length, dynamic_size=True, infer_shape=False) 279 | else: 280 | # construct attention 281 | attention = attention_construct_fn(cell_output, attention_keys, 282 | attention_values) 283 | if type(attention) is tuple: 284 | attention, alignment = attention 285 | cell_output = attention 286 | alignment = tf.reshape(alignment, [batch_size, -1]) 287 | #cell_output = output_fn(cell_output) # logits 288 | #next_input_id = math_ops.cast( 289 | # math_ops.argmax(cell_output, 1), dtype=dtype) 290 | #done = math_ops.equal(next_input_id, end_of_sequence_id) 291 | #cell_input = array_ops.gather(embeddings, next_input_id) 292 | selector = selector_fn(cell_output) 293 | logit = output_fn(cell_output) 294 | word_prob = nn_ops.softmax(logit) * (1 - selector) 295 | entity_prob = alignment * selector 296 | mask = array_ops.reshape(math_ops.cast(math_ops.greater(tf.reduce_max(word_prob, 1), tf.reduce_max(entity_prob, 1)), dtype=dtypes.float32), [-1,1]) 297 | cell_input = mask * array_ops.gather(embeddings, math_ops.cast(math_ops.argmax(word_prob, 1), dtype=dtype)) + (1 - mask) * array_ops.gather_nd(imem, array_ops.concat([array_ops.reshape(math_ops.range(batch_size, dtype=dtype), [-1,1]), array_ops.reshape(math_ops.cast(math_ops.argmax(entity_prob, 1), dtype=dtype), [-1,1])], axis=1)) 298 | 299 | mask = array_ops.reshape(math_ops.cast(mask, dtype=dtype), [-1]) 300 | input_id = mask * math_ops.cast(math_ops.argmax(word_prob, 1), dtype=dtype) + (mask - 1) * math_ops.cast(math_ops.argmax(entity_prob, 1), dtype=dtype) 301 | context_state = context_state.write(time-1, input_id) 302 | done = array_ops.reshape(math_ops.equal(input_id, end_of_sequence_id), [-1]) 303 | #done = tf.Print(done, ['selector', selector, 'mask', mask], summarize=1e6) 304 | cell_output = logit 305 | 306 | else: 307 | cell_output = attention 308 | 309 | # argmax decoder 310 | cell_output = output_fn(cell_output) # logits 311 | next_input_id = math_ops.cast( 312 | math_ops.argmax(cell_output, 1), dtype=dtype) 313 | done = math_ops.equal(next_input_id, end_of_sequence_id) 314 | cell_input = array_ops.gather(embeddings, next_input_id) 315 | 316 | # combine cell_input and attention 
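# (the chosen embedding is either an ordinary vocabulary word or, in the copy branch above, an entity fetched from `imem`; the attention vector is appended to it below so the cell conditions on both, i.e. the input-feeding scheme of Luong et al., 2015)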
317 | next_input = array_ops.concat([cell_input, attention], 1) 318 | 319 | # if time > maxlen, return all true vector 320 | done = control_flow_ops.cond( 321 | math_ops.greater(time, maximum_length), 322 | lambda: array_ops.ones([batch_size,], dtype=dtypes.bool), 323 | lambda: done) 324 | return (done, cell_state, next_input, cell_output, context_state) 325 | 326 | return decoder_fn 327 | 328 | def attention_decoder_fn_beam_inference(output_fn, 329 | encoder_state, 330 | attention_keys, 331 | attention_values, 332 | attention_score_fn, 333 | attention_construct_fn, 334 | embeddings, 335 | start_of_sequence_id, 336 | end_of_sequence_id, 337 | maximum_length, 338 | num_decoder_symbols, 339 | beam_size, 340 | remove_unk=False, 341 | d_rate=0.0, 342 | dtype=dtypes.int32, 343 | name=None): 344 | """Attentional decoder function for `dynamic_rnn_decoder` during beam-search inference. 345 | The `attention_decoder_fn_beam_inference` is a beam-search inference function for a 346 | sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is 347 | in the inference mode. 348 | The `attention_decoder_fn_beam_inference` is called with user arguments 349 | and returns the `decoder_fn`, which can be passed to the 350 | `dynamic_rnn_decoder`, such that 351 | ``` 352 | dynamic_fn_inference = attention_decoder_fn_inference(...) 353 | outputs_inference, state_inference = dynamic_rnn_decoder( 354 | decoder_fn=dynamic_fn_inference, ...) 355 | ``` 356 | Further usage can be found in the `kernel_tests/seq2seq_test.py`. 357 | Args: 358 | output_fn: An output function to project your `cell_output` onto class 359 | logits. 360 | An example of an output function: 361 | ``` 362 | tf.variable_scope("decoder") as varscope 363 | output_fn = lambda x: layers.linear(x, num_decoder_symbols, 364 | scope=varscope) 365 | outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...) 366 | logits_train = output_fn(outputs_train) 367 | varscope.reuse_variables() 368 | logits_inference, state_inference = seq2seq.dynamic_rnn_decoder( 369 | output_fn=output_fn, ...) 370 | ``` 371 | If `None` is supplied it will act as an identity function, which 372 | might be wanted when using the RNNCell `OutputProjectionWrapper`. 373 | encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`. 374 | attention_keys: to be compared with target states. 375 | attention_values: to be used to construct context vectors. 376 | attention_score_fn: to compute similarity between key and target states. 377 | attention_construct_fn: to build attention states. 378 | embeddings: The embeddings matrix used for the decoder sized 379 | `[num_decoder_symbols, embedding_size]`. 380 | start_of_sequence_id: The start of sequence ID in the decoder embeddings. 381 | end_of_sequence_id: The end of sequence ID in the decoder embeddings. 382 | maximum_length: The maximum allowed number of time steps to decode. 383 | num_decoder_symbols: The number of classes to decode at each time step. beam_size: The number of beam hypotheses kept alive for each input at every decoding step. 384 | dtype: (default: `dtypes.int32`) The default data type to use when 385 | handling integer objects. 386 | name: (default: `None`) NameScope for the decoder function; 387 | defaults to "attention_decoder_fn_inference" 388 | Returns: 389 | A decoder function with the required interface of `dynamic_rnn_decoder` 390 | intended for beam-search inference.
391 | """ 392 | with ops.name_scope(name, "attention_decoder_fn_inference", [ 393 | output_fn, encoder_state, attention_keys, attention_values, 394 | attention_score_fn, attention_construct_fn, embeddings, 395 | start_of_sequence_id, end_of_sequence_id, maximum_length, 396 | num_decoder_symbols, dtype 397 | ]): 398 | state_size = int(encoder_state[0].get_shape().with_rank(2)[1]) 399 | state = [] 400 | for s in encoder_state: 401 | state.append(array_ops.reshape(array_ops.concat([array_ops.reshape(s, [-1, 1, state_size])]*beam_size, 1), [-1, state_size])) 402 | encoder_state = tuple(state) 403 | origin_batch = array_ops.shape(attention_values)[0] 404 | attn_length = array_ops.shape(attention_values)[1] 405 | attention_values = array_ops.reshape(array_ops.concat([array_ops.reshape(attention_values, [-1, 1, attn_length, state_size])]*beam_size, 1), [-1, attn_length, state_size]) 406 | attn_size = array_ops.shape(attention_keys)[2] 407 | attention_keys = array_ops.reshape(array_ops.concat([array_ops.reshape(attention_keys, [-1, 1, attn_length, attn_size])]*beam_size, 1), [-1, attn_length, attn_size]) 408 | start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype) 409 | end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype) 410 | maximum_length = ops.convert_to_tensor(maximum_length, dtype) 411 | num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype) 412 | encoder_info = nest.flatten(encoder_state)[0] 413 | batch_size = encoder_info.get_shape()[0].value 414 | if output_fn is None: 415 | output_fn = lambda x: x 416 | if batch_size is None: 417 | batch_size = array_ops.shape(encoder_info)[0] 418 | #beam_size = ops.convert_to_tensor(beam_size, dtype) 419 | 420 | def decoder_fn(time, cell_state, cell_input, cell_output, context_state): 421 | """Decoder function used in the `dynamic_rnn_decoder` for inference. 422 | The main difference between this decoder function and the `decoder_fn` in 423 | `attention_decoder_fn_train` is how `next_cell_input` is calculated. In 424 | decoder function we calculate the next input by applying an argmax across 425 | the feature dimension of the output from the decoder. This is a 426 | greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014) 427 | use beam-search instead. 428 | Args: 429 | time: positive integer constant reflecting the current timestep. 430 | cell_state: state of RNNCell. 431 | cell_input: input provided by `dynamic_rnn_decoder`. 432 | cell_output: output of RNNCell. 433 | context_state: context state provided by `dynamic_rnn_decoder`. 434 | Returns: 435 | A tuple (done, next state, next input, emit output, next context state) 436 | where: 437 | done: A boolean vector to indicate which sentences has reached a 438 | `end_of_sequence_id`. This is used for early stopping by the 439 | `dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with 440 | all elements as `true` is returned. 441 | next state: `cell_state`, this decoder function does not modify the 442 | given state. 443 | next input: The embedding from argmax of the `cell_output` is used as 444 | `next_input`. 445 | emit output: If `output_fn is None` the supplied `cell_output` is 446 | returned, else the `output_fn` is used to update the `cell_output` 447 | before calculating `next_input` and returning `cell_output`. 448 | next context state: `context_state`, this decoder function does not 449 | modify the given context state. The context state could be modified when 450 | applying e.g. beam search. 
451 | Raises: 452 | ValueError: if cell_input is not None. 453 | """ 454 | with ops.name_scope( 455 | name, "attention_decoder_fn_inference", 456 | [time, cell_state, cell_input, cell_output, context_state]): 457 | if cell_input is not None: 458 | raise ValueError("Expected cell_input to be None, but saw: %s" % 459 | cell_input) 460 | if cell_output is None: 461 | # invariant that this is time == 0 462 | next_input_id = array_ops.ones( 463 | [batch_size,], dtype=dtype) * (start_of_sequence_id) 464 | done = array_ops.zeros([batch_size,], dtype=dtypes.bool) 465 | cell_state = encoder_state 466 | cell_output = array_ops.zeros( 467 | [num_decoder_symbols], dtype=dtypes.float32) 468 | cell_input = array_ops.gather(embeddings, next_input_id) 469 | 470 | # init attention 471 | attention = _init_attention(encoder_state) 472 | # init context state 473 | log_beam_probs = tensor_array_ops.TensorArray(dtype=dtypes.float32, tensor_array_name="log_beam_probs", size=maximum_length, dynamic_size=True, infer_shape=False) 474 | beam_parents = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="beam_parents", size=maximum_length, dynamic_size=True, infer_shape=False) 475 | beam_symbols = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="beam_symbols", size=maximum_length, dynamic_size=True, infer_shape=False) 476 | result_probs = tensor_array_ops.TensorArray(dtype=dtypes.float32, tensor_array_name="result_probs", size=maximum_length, dynamic_size=True, infer_shape=False) 477 | result_parents = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="result_parents", size=maximum_length, dynamic_size=True, infer_shape=False) 478 | result_symbols = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="result_symbols", size=maximum_length, dynamic_size=True, infer_shape=False) 479 | context_state = (log_beam_probs, beam_parents, beam_symbols, result_probs, result_parents, result_symbols) 480 | else: 481 | # construct attention 482 | attention = attention_construct_fn(cell_output, attention_keys, 483 | attention_values) 484 | cell_output = attention 485 | 486 | # beam search decoder 487 | (log_beam_probs, beam_parents, beam_symbols, result_probs, result_parents, result_symbols) = context_state 488 | 489 | cell_output = output_fn(cell_output) # logits 490 | cell_output = nn_ops.softmax(cell_output) 491 | 492 | 493 | cell_output = array_ops.split(cell_output, [2, num_decoder_symbols-2], 1)[1] 494 | 495 | tmp_output = array_ops.gather(cell_output, math_ops.range(origin_batch)*beam_size) 496 | 497 | probs = control_flow_ops.cond( 498 | math_ops.equal(time, ops.convert_to_tensor(1, dtype)), 499 | lambda: math_ops.log(tmp_output+ops.convert_to_tensor(1e-20, dtypes.float32)), 500 | lambda: math_ops.log(cell_output+ops.convert_to_tensor(1e-20, dtypes.float32)) + array_ops.reshape(log_beam_probs.read(time-2), [-1, 1])) 501 | 502 | probs = array_ops.reshape(probs, [origin_batch, -1]) 503 | best_probs, indices = nn_ops.top_k(probs, beam_size * 2) 504 | #indices = array_ops.reshape(indices, [-1]) 505 | indices_flatten = array_ops.reshape(indices, [-1]) + array_ops.reshape(array_ops.concat([array_ops.reshape(math_ops.range(origin_batch)*((num_decoder_symbols-2)*beam_size), [-1, 1])]*(beam_size*2), 1), [origin_batch*beam_size*2]) 506 | best_probs_flatten = array_ops.reshape(best_probs, [-1]) 507 | 508 | symbols = indices_flatten % (num_decoder_symbols - 2) 509 | symbols = symbols + 2 510 | parents = indices_flatten // (num_decoder_symbols - 2) 511 | 512 | probs_wo_eos 
= best_probs + 1e5*math_ops.cast(math_ops.cast((indices%(num_decoder_symbols-2)+2)-end_of_sequence_id, dtypes.bool), dtypes.float32) 513 | 514 | best_probs_wo_eos, indices_wo_eos = nn_ops.top_k(probs_wo_eos, beam_size) 515 | 516 | indices_wo_eos = array_ops.reshape(indices_wo_eos, [-1]) + array_ops.reshape(array_ops.concat([array_ops.reshape(math_ops.range(origin_batch)*(beam_size*2), [-1, 1])]*beam_size, 1), [origin_batch*beam_size]) 517 | 518 | _probs = array_ops.gather(best_probs_flatten, indices_wo_eos) 519 | _symbols = array_ops.gather(symbols, indices_wo_eos) 520 | _parents = array_ops.gather(parents, indices_wo_eos) 521 | 522 | 523 | log_beam_probs = log_beam_probs.write(time-1, _probs) 524 | beam_symbols = beam_symbols.write(time-1, _symbols) 525 | beam_parents = beam_parents.write(time-1, _parents) 526 | result_probs = result_probs.write(time-1, best_probs_flatten) 527 | result_symbols = result_symbols.write(time-1, symbols) 528 | result_parents = result_parents.write(time-1, parents) 529 | 530 | 531 | next_input_id = array_ops.reshape(_symbols, [batch_size]) 532 | 533 | state_size = int(cell_state[0].get_shape().with_rank(2)[1]) 534 | attn_size = int(attention.get_shape().with_rank(2)[1]) 535 | state = [] 536 | for j in cell_state: 537 | state.append(array_ops.reshape(array_ops.gather(j, _parents), [-1, state_size])) 538 | cell_state = tuple(state) 539 | attention = array_ops.reshape(array_ops.gather(attention, _parents), [-1, attn_size]) 540 | 541 | done = math_ops.equal(next_input_id, end_of_sequence_id) 542 | cell_input = array_ops.gather(embeddings, next_input_id) 543 | 544 | # combine cell_input and attention 545 | next_input = array_ops.concat([cell_input, attention], 1) 546 | 547 | # if time > maxlen, return all true vector 548 | done = control_flow_ops.cond( 549 | math_ops.greater(time, maximum_length), 550 | lambda: array_ops.ones([batch_size,], dtype=dtypes.bool), 551 | lambda: array_ops.zeros([batch_size,], dtype=dtypes.bool)) 552 | return (done, cell_state, next_input, cell_output, (log_beam_probs, beam_parents, beam_symbols, result_probs, result_parents, result_symbols))#context_state) 553 | 554 | return decoder_fn 555 | 556 | ## Helper functions ## 557 | def prepare_attention(attention_states, 558 | attention_option, 559 | num_units, 560 | imem=None, 561 | output_alignments=False, 562 | reuse=False): 563 | """Prepare keys/values/functions for attention. 564 | Args: 565 | attention_states: hidden states to attend over. 566 | attention_option: how to compute attention, either "luong" or "bahdanau". 567 | num_units: hidden state dimension. 568 | reuse: whether to reuse variable scope. 569 | Returns: 570 | attention_keys: to be compared with target states. 571 | attention_values: to be used to construct context vectors. 572 | attention_score_fn: to compute similarity between key and target states. 573 | attention_construct_fn: to build attention states. 
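Example (a minimal sketch; `attention_states` and `num_units` stand in for the caller's own tensors and sizes, and `imem` keeps its default):
```
keys, values, score_fn, construct_fn = prepare_attention(
    attention_states, "bahdanau", num_units)
```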
574 | """ 575 | 576 | # Prepare attention keys / values from attention_states 577 | with variable_scope.variable_scope("attention_keys", reuse=reuse) as scope: 578 | attention_keys = layers.linear( 579 | attention_states, num_units, biases_initializer=None, scope=scope) 580 | attention_values = attention_states 581 | 582 | if imem is not None: 583 | if type(imem) is tuple: 584 | with variable_scope.variable_scope("imem_graph", reuse=reuse) as scope: 585 | attention_keys2, attention_states2 = array_ops.split(layers.linear( 586 | imem[0], num_units*2, biases_initializer=None, scope=scope), [num_units, num_units], axis=2) 587 | with variable_scope.variable_scope("imem_triple", reuse=reuse) as scope: 588 | attention_keys3, attention_states3 = array_ops.split(layers.linear( 589 | imem[1], num_units*2, biases_initializer=None, scope=scope), [num_units, num_units], axis=3) 590 | attention_keys = (attention_keys, attention_keys2, attention_keys3) 591 | attention_values = (attention_states, attention_states2, attention_states3) 592 | else: 593 | with variable_scope.variable_scope("imem", reuse=reuse) as scope: 594 | attention_keys2, attention_states2 = array_ops.split(layers.linear( 595 | imem, num_units*2, biases_initializer=None, scope=scope), [num_units, num_units], axis=2) 596 | attention_keys = (attention_keys, attention_keys2) 597 | attention_values = (attention_states, attention_states2) 598 | 599 | 600 | 601 | # Attention score function 602 | if imem is None: 603 | attention_score_fn = _create_attention_score_fn("attention_score", num_units, 604 | attention_option, reuse) 605 | else: 606 | attention_score_fn = (_create_attention_score_fn("attention_score", num_units, 607 | attention_option, reuse), 608 | _create_attention_score_fn("imem_score", num_units, 609 | "luong", reuse, output_alignments=output_alignments)) 610 | 611 | # Attention construction function 612 | attention_construct_fn = _create_attention_construct_fn("attention_construct", 613 | num_units, 614 | attention_score_fn, 615 | reuse) 616 | 617 | return (attention_keys, attention_values, attention_score_fn, 618 | attention_construct_fn) 619 | 620 | 621 | def _init_attention(encoder_state): 622 | """Initialize attention. Handling both LSTM and GRU. 623 | Args: 624 | encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`. 625 | Returns: 626 | attn: initial zero attention vector. 627 | """ 628 | 629 | # Multi- vs single-layer 630 | # TODO(thangluong): is this the best way to check? 631 | if isinstance(encoder_state, tuple): 632 | top_state = encoder_state[-1] 633 | else: 634 | top_state = encoder_state 635 | 636 | # LSTM vs GRU 637 | if isinstance(top_state, rnn_cell_impl.LSTMStateTuple): 638 | attn = array_ops.zeros_like(top_state.h) 639 | else: 640 | attn = array_ops.zeros_like(top_state) 641 | 642 | return attn 643 | 644 | 645 | def _create_attention_construct_fn(name, num_units, attention_score_fn, reuse): 646 | """Function to compute attention vectors. 647 | Args: 648 | name: to label variables. 649 | num_units: hidden state dimension. 650 | attention_score_fn: to compute similarity between key and target states. 651 | reuse: whether to reuse variable scope. 652 | Returns: 653 | attention_construct_fn: to build attention states. 
654 | """ 655 | with variable_scope.variable_scope(name, reuse=reuse) as scope: 656 | 657 | def construct_fn(attention_query, attention_keys, attention_values): 658 | alignments = None 659 | if type(attention_score_fn) is tuple: 660 | context0 = attention_score_fn[0](attention_query, attention_keys[0], 661 | attention_values[0]) 662 | if len(attention_keys) == 2: 663 | context1 = attention_score_fn[1](attention_query, attention_keys[1], 664 | attention_values[1]) 665 | elif len(attention_keys) == 3: 666 | context1 = attention_score_fn[1](attention_query, attention_keys[1:], 667 | attention_values[1:]) 668 | if type(context1) is tuple: 669 | if len(context1) == 2: 670 | context1, alignments = context1 671 | concat_input = array_ops.concat([attention_query, context0, context1], 1) 672 | elif len(context1) == 3: 673 | context1, context2, alignments = context1 674 | concat_input = array_ops.concat([attention_query, context0, context1, context2], 1) 675 | else: 676 | concat_input = array_ops.concat([attention_query, context0, context1], 1) 677 | else: 678 | context = attention_score_fn(attention_query, attention_keys, 679 | attention_values) 680 | concat_input = array_ops.concat([attention_query, context], 1) 681 | attention = layers.linear( 682 | concat_input, num_units, biases_initializer=None, scope=scope) 683 | if alignments is None: 684 | return attention 685 | else: 686 | return attention, alignments 687 | 688 | return construct_fn 689 | 690 | 691 | # keys: [batch_size, attention_length, attn_size] 692 | # query: [batch_size, 1, attn_size] 693 | # return weights [batch_size, attention_length] 694 | @function.Defun(func_name="attn_add_fun", noinline=True) 695 | def _attn_add_fun(v, keys, query): 696 | return math_ops.reduce_sum(v * math_ops.tanh(keys + query), [2]) 697 | 698 | 699 | @function.Defun(func_name="attn_mul_fun", noinline=True) 700 | def _attn_mul_fun(keys, query): 701 | return math_ops.reduce_sum(keys * query, [2]) 702 | 703 | 704 | def _create_attention_score_fn(name, 705 | num_units, 706 | attention_option, 707 | reuse, 708 | output_alignments=False, 709 | dtype=dtypes.float32): 710 | """Different ways to compute attention scores. 711 | Args: 712 | name: to label variables. 713 | num_units: hidden state dimension. 714 | attention_option: how to compute attention, either "luong" or "bahdanau". 715 | "bahdanau": additive (Bahdanau et al., ICLR'2015) 716 | "luong": multiplicative (Luong et al., EMNLP'2015) 717 | reuse: whether to reuse variable scope. 718 | dtype: (default: `dtypes.float32`) data type to use. 719 | Returns: 720 | attention_score_fn: to compute similarity between key and target states. 721 | """ 722 | with variable_scope.variable_scope(name, reuse=reuse): 723 | if attention_option == "bahdanau": 724 | query_w = variable_scope.get_variable( 725 | "attnW", [num_units, num_units], dtype=dtype) 726 | score_v = variable_scope.get_variable("attnV", [num_units], dtype=dtype) 727 | 728 | def attention_score_fn(query, keys, values): 729 | """Put attention masks on attention_values using attention_keys and query. 730 | Args: 731 | query: A Tensor of shape [batch_size, num_units]. 732 | keys: A Tensor of shape [batch_size, attention_length, num_units]. 733 | values: A Tensor of shape [batch_size, attention_length, num_units]. 734 | Returns: 735 | context_vector: A Tensor of shape [batch_size, num_units]. 736 | Raises: 737 | ValueError: if attention_option is neither "luong" or "bahdanau". 
738 | """ 739 | triple_keys, triple_values = None, None 740 | 741 | if type(keys) is tuple: 742 | keys, triple_keys = keys 743 | values, triple_values = values 744 | 745 | if attention_option == "bahdanau": 746 | # transform query 747 | query = math_ops.matmul(query, query_w) 748 | 749 | # reshape query: [batch_size, 1, num_units] 750 | query = array_ops.reshape(query, [-1, 1, num_units]) 751 | 752 | 753 | # attn_fun 754 | scores = _attn_add_fun(score_v, keys, query) 755 | elif attention_option == "luong": 756 | # reshape query: [batch_size, 1, num_units] 757 | query = array_ops.reshape(query, [-1, 1, num_units]) 758 | 759 | # attn_fun 760 | scores = _attn_mul_fun(keys, query) 761 | else: 762 | raise ValueError("Unknown attention option %s!" % attention_option) 763 | 764 | # Compute alignment weights 765 | # scores: [batch_size, length] 766 | # alignments: [batch_size, length] 767 | # TODO(thangluong): not normalize over padding positions. 768 | alignments = nn_ops.softmax(scores) 769 | #alignments = tf.Print(alignments, [alignments], summarize=1000) 770 | 771 | 772 | # Now calculate the attention-weighted vector. 773 | new_alignments = array_ops.expand_dims(alignments, 2) 774 | context_vector = math_ops.reduce_sum(new_alignments * values, [1]) 775 | 776 | context_vector.set_shape([None, num_units]) 777 | 778 | if triple_values is not None: 779 | triple_scores = math_ops.reduce_sum(triple_keys * array_ops.reshape(query, [-1, 1, 1, num_units]), [3]) 780 | triple_alignments = nn_ops.softmax(triple_scores) 781 | context_triples = math_ops.reduce_sum(array_ops.expand_dims(triple_alignments, 3) * triple_values, [2]) 782 | context_graph_triples = math_ops.reduce_sum(new_alignments * context_triples, [1]) 783 | context_graph_triples.set_shape([None, num_units]) 784 | return context_vector, context_graph_triples, new_alignments * triple_alignments 785 | else: 786 | if output_alignments: 787 | return context_vector, alignments 788 | else: 789 | return context_vector 790 | 791 | return attention_score_fn 792 | -------------------------------------------------------------------------------- /copynet/dynamic_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Seq2seq layer operations for use in neural networks. 
17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from tensorflow.contrib import layers 24 | from tensorflow.python.framework import ops 25 | from tensorflow.python.ops import array_ops 26 | from tensorflow.python.ops import control_flow_ops 27 | from tensorflow.python.ops import math_ops 28 | from tensorflow.python.ops import rnn 29 | from tensorflow.python.ops import tensor_array_ops 30 | from tensorflow.python.ops import variable_scope as vs 31 | 32 | __all__ = ["dynamic_rnn_decoder"] 33 | 34 | def dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None, 35 | parallel_iterations=None, swap_memory=False, 36 | time_major=False, scope=None, name=None): 37 | """ Dynamic RNN decoder for a sequence-to-sequence model specified by 38 | RNNCell and decoder function. 39 | 40 | The `dynamic_rnn_decoder` is similar to the `tf.python.ops.rnn.dynamic_rnn` 41 | as the decoder does not make any assumptions of sequence length and batch 42 | size of the input. 43 | 44 | The `dynamic_rnn_decoder` has two modes: training or inference and expects 45 | the user to create seperate functions for each. 46 | 47 | Under both training and inference, both `cell` and `decoder_fn` are expected, 48 | where `cell` performs computation at every timestep using `raw_rnn`, and 49 | `decoder_fn` allows modeling of early stopping, output, state, and next 50 | input and context. 51 | 52 | When training the user is expected to supply `inputs`. At every time step a 53 | slice of the supplied input is fed to the `decoder_fn`, which modifies and 54 | returns the input for the next time step. 55 | 56 | `sequence_length` is needed at training time, i.e., when `inputs` is not 57 | None, for dynamic unrolling. At test time, when `inputs` is None, 58 | `sequence_length` is not needed. 59 | 60 | Under inference `inputs` is expected to be `None` and the input is inferred 61 | solely from the `decoder_fn`. 62 | 63 | Args: 64 | cell: An instance of RNNCell. 65 | decoder_fn: A function that takes time, cell state, cell input, 66 | cell output and context state. It returns a early stopping vector, 67 | cell state, next input, cell output and context state. 68 | Examples of decoder_fn can be found in the decoder_fn.py folder. 69 | inputs: The inputs for decoding (embedded format). 70 | 71 | If `time_major == False` (default), this must be a `Tensor` of shape: 72 | `[batch_size, max_time, ...]`. 73 | 74 | If `time_major == True`, this must be a `Tensor` of shape: 75 | `[max_time, batch_size, ...]`. 76 | 77 | The input to `cell` at each time step will be a `Tensor` with dimensions 78 | `[batch_size, ...]`. 79 | 80 | sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 81 | if `inputs` is not None and `sequence_length` is None it is inferred 82 | from the `inputs` as the maximal possible sequence length. 83 | parallel_iterations: (Default: 32). The number of iterations to run in 84 | parallel. Those operations which do not have any temporal dependency 85 | and can be run in parallel, will be. This parameter trades off 86 | time for space. Values >> 1 use more memory but take less time, 87 | while smaller values use less memory but computations take longer. 88 | swap_memory: Transparently swap the tensors produced in forward inference 89 | but needed for back prop from GPU to CPU. This allows training RNNs 90 | which would typically not fit on a single GPU, with very minimal (or no) 91 | performance penalty. 
92 | time_major: The shape format of the `inputs` and `outputs` Tensors. 93 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 94 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 95 | Using `time_major = True` is a bit more efficient because it avoids 96 | transposes at the beginning and end of the RNN calculation. However, 97 | most TensorFlow data is batch-major, so by default this function 98 | accepts input and emits output in batch-major form. 99 | scope: VariableScope for the `raw_rnn`; 100 | defaults to None. 101 | name: NameScope for the decoder; 102 | defaults to "dynamic_rnn_decoder" 103 | 104 | Returns: 105 | A tuple (outputs, final_state, final_context_state) where: 106 | 107 | outputs: the RNN output 'Tensor'. 108 | 109 | If time_major == False (default), this will be a `Tensor` shaped: 110 | `[batch_size, max_time, cell.output_size]`. 111 | 112 | If time_major == True, this will be a `Tensor` shaped: 113 | `[max_time, batch_size, cell.output_size]`. 114 | 115 | final_state: The final state, shaped 116 | `[batch_size, cell.state_size]`. 117 | 118 | final_context_state: The context state returned by the final call 119 | to decoder_fn. This is useful if the context state maintains internal 120 | data which is required after the graph is run. 121 | For example, one way to diversify the inference output is to use 122 | a stochastic decoder_fn, in which case one would want to store the 123 | decoded outputs, not just the RNN outputs. This can be done by 124 | maintaining a TensorArray in context_state and storing the decoded 125 | output of each iteration therein. 126 | 127 | Raises: 128 | ValueError: if `inputs` is not None and has fewer than two dimensions. 129 | """ 130 | with ops.name_scope(name, "dynamic_rnn_decoder", 131 | [cell, decoder_fn, inputs, sequence_length, 132 | parallel_iterations, swap_memory, time_major, scope]): 133 | if inputs is not None: 134 | # Convert to tensor 135 | inputs = ops.convert_to_tensor(inputs) 136 | 137 | # Test input dimensions 138 | if inputs.get_shape().ndims is not None and ( 139 | inputs.get_shape().ndims < 2): 140 | raise ValueError("Inputs must have at least two dimensions") 141 | # Setup of RNN (dimensions, sizes, length, initial state, dtype) 142 | if not time_major: 143 | # [batch, seq, features] -> [seq, batch, features] 144 | inputs = array_ops.transpose(inputs, perm=[1, 0, 2]) 145 | 146 | dtype = inputs.dtype 147 | # Get data input information 148 | input_depth = int(inputs.get_shape()[2]) 149 | batch_depth = inputs.get_shape()[1].value 150 | max_time = inputs.get_shape()[0].value 151 | if max_time is None: 152 | max_time = array_ops.shape(inputs)[0] 153 | # Setup decoder inputs as TensorArray 154 | inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time) 155 | inputs_ta = inputs_ta.unstack(inputs) 156 | 157 | def loop_fn(time, cell_output, cell_state, loop_state): 158 | if cell_state is None: # first call, before while loop (in raw_rnn) 159 | if cell_output is not None: 160 | raise ValueError("Expected cell_output to be None when cell_state " 161 | "is None, but saw: %s" % cell_output) 162 | if loop_state is not None: 163 | raise ValueError("Expected loop_state to be None when cell_state " 164 | "is None, but saw: %s" % loop_state) 165 | context_state = None 166 | else: # subsequent calls, inside while loop, after cell execution 167 | if isinstance(loop_state, tuple): 168 | (done, context_state) = loop_state 169 | else: 170 | done = loop_state 171 | context_state =
172 |
173 |       # call decoder function
174 |       if inputs is not None:  # training
175 |         # get next_cell_input
176 |         if cell_state is None:
177 |           next_cell_input = inputs_ta.read(0)
178 |         else:
179 |           if batch_depth is not None:
180 |             batch_size = batch_depth
181 |           else:
182 |             batch_size = array_ops.shape(done)[0]
183 |           next_cell_input = control_flow_ops.cond(
184 |               math_ops.equal(time, max_time),
185 |               lambda: array_ops.zeros([batch_size, input_depth], dtype=dtype),
186 |               lambda: inputs_ta.read(time))
187 |         (next_done, next_cell_state, next_cell_input, emit_output,
188 |          next_context_state) = decoder_fn(time, cell_state, next_cell_input,
189 |                                           cell_output, context_state)
190 |       else:  # inference
191 |         # next_cell_input is obtained through decoder_fn
192 |         (next_done, next_cell_state, next_cell_input, emit_output,
193 |          next_context_state) = decoder_fn(time, cell_state, None, cell_output,
194 |                                           context_state)
195 |
196 |       # check if we are done
197 |       if next_done is None:  # training
198 |         next_done = time >= sequence_length
199 |
200 |       # build next_loop_state
201 |       if next_context_state is None:
202 |         next_loop_state = next_done
203 |       else:
204 |         next_loop_state = (next_done, next_context_state)
205 |
206 |       return (next_done, next_cell_input, next_cell_state,
207 |               emit_output, next_loop_state)
208 |
209 |     # Run raw_rnn function
210 |     outputs_ta, final_state, final_loop_state = rnn.raw_rnn(
211 |         cell, loop_fn, parallel_iterations=parallel_iterations,
212 |         swap_memory=swap_memory, scope=scope)
213 |     outputs = outputs_ta.stack()
214 |
215 |     # Get final context_state, if generated by user
216 |     if isinstance(final_loop_state, tuple):
217 |       final_context_state = final_loop_state[1]
218 |     else:
219 |       final_context_state = None
220 |
221 |     if not time_major:
222 |       # [seq, batch, features] -> [batch, seq, features]
223 |       outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
224 |     return outputs, final_state, final_context_state
225 |
--------------------------------------------------------------------------------
/copynet/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.python.framework import constant_op
4 | import sys
5 | import json
6 | import math
7 | import os
8 | import time
9 | import random
10 | import sqlite3
11 | random.seed(time.time())
12 | from model import Model, _START_VOCAB
13 |
14 | tf.app.flags.DEFINE_boolean("is_train", True, "Set to False for inference.")
15 | tf.app.flags.DEFINE_integer("symbols", 30000, "vocabulary size.")
16 | tf.app.flags.DEFINE_integer("num_entities", 21471, "entity vocabulary size.")
17 | tf.app.flags.DEFINE_integer("num_relations", 44, "relation size.")
18 | tf.app.flags.DEFINE_integer("embed_units", 300, "Size of word embedding.")
19 | tf.app.flags.DEFINE_integer("trans_units", 100, "Size of trans embedding.")
20 | tf.app.flags.DEFINE_integer("units", 512, "Size of each model layer.")
21 | tf.app.flags.DEFINE_integer("layers", 2, "Number of layers in the model.")
22 | tf.app.flags.DEFINE_boolean("copy_use", True, "use copy mechanism or not.")
23 | tf.app.flags.DEFINE_integer("batch_size", 100, "Batch size to use during training.")
24 | tf.app.flags.DEFINE_string("data_dir", "./data", "Data directory")
25 | tf.app.flags.DEFINE_string("train_dir", "./train", "Training directory.")
26 | tf.app.flags.DEFINE_integer("per_checkpoint", 1000, "How many steps to do per checkpoint.")
27 | tf.app.flags.DEFINE_integer("inference_version", 0, "The version for inferencing.")
28 | tf.app.flags.DEFINE_boolean("log_parameters", True, "Set to True to show the parameters")
29 | tf.app.flags.DEFINE_string("inference_path", "test", "Set filename of inference, default is screen")
30 |
31 | FLAGS = tf.app.flags.FLAGS
32 | if FLAGS.train_dir[-1] == '/': FLAGS.train_dir = FLAGS.train_dir[:-1]
33 | csk_triples, csk_entities, kb_dict = [], [], []
34 |
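# Data-format note (added; inferred from the loading code below): resource.txt
# is expected to hold one JSON object whose keys include 'csk_triples',
# 'csk_entities', 'vocab_dict' and 'dict_csk', while trainset.txt,
# validset.txt and testset.txt each carry one JSON-encoded example per line.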
for inferencing.") 28 | tf.app.flags.DEFINE_boolean("log_parameters", True, "Set to True to show the parameters") 29 | tf.app.flags.DEFINE_string("inference_path", "test", "Set filename of inference, default isscreen") 30 | 31 | FLAGS = tf.app.flags.FLAGS 32 | if FLAGS.train_dir[-1] == '/': FLAGS.train_dir = FLAGS.train_dir[:-1] 33 | csk_triples, csk_entities, kb_dict = [], [], [] 34 | 35 | def prepare_data(path, is_train=True): 36 | global csk_entities, csk_triples, kb_dict 37 | 38 | with open('%s/resource.txt' % path) as f: 39 | d = json.loads(f.readline()) 40 | 41 | csk_triples = d['csk_triples'] 42 | csk_entities = d['csk_entities'] 43 | raw_vocab = d['vocab_dict'] 44 | kb_dict = d['dict_csk'] 45 | 46 | data_train, data_dev, data_test = [], [], [] 47 | 48 | if is_train: 49 | with open('%s/trainset.txt' % path) as f: 50 | for idx, line in enumerate(f): 51 | #if idx == 100000: break 52 | if idx % 100000 == 0: print('read train file line %d' % idx) 53 | data_train.append(json.loads(line)) 54 | 55 | with open('%s/validset.txt' % path) as f: 56 | for line in f: 57 | data_dev.append(json.loads(line)) 58 | 59 | with open('%s/testset.txt' % path) as f: 60 | for line in f: 61 | data_test.append(json.loads(line)) 62 | 63 | return raw_vocab, data_train, data_dev, data_test 64 | 65 | def build_vocab(path, raw_vocab, trans='transE'): 66 | print("Creating word vocabulary...") 67 | vocab_list = _START_VOCAB + sorted(raw_vocab, key=raw_vocab.get, reverse=True) 68 | if len(vocab_list) > FLAGS.symbols: 69 | vocab_list = vocab_list[:FLAGS.symbols] 70 | 71 | print("Creating entity vocabulary...") 72 | entity_list = ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T'] 73 | with open('%s/entity.txt' % path) as f: 74 | for i, line in enumerate(f): 75 | e = line.strip() 76 | entity_list.append(e) 77 | 78 | print("Creating relation vocabulary...") 79 | relation_list = [] 80 | with open('%s/relation.txt' % path) as f: 81 | for i, line in enumerate(f): 82 | r = line.strip() 83 | relation_list.append(r) 84 | 85 | print("Loading word vectors...") 86 | vectors = {} 87 | with open('%s/glove.840B.300d.txt' % path) as f: 88 | for i, line in enumerate(f): 89 | if i % 100000 == 0: 90 | print(" processing line %d" % i) 91 | s = line.strip() 92 | word = s[:s.find(' ')] 93 | vector = s[s.find(' ')+1:] 94 | vectors[word] = vector 95 | 96 | embed = [] 97 | for word in vocab_list: 98 | if word in vectors: 99 | vector = map(float, vectors[word].split()) 100 | else: 101 | vector = np.zeros((FLAGS.embed_units), dtype=np.float32) 102 | embed.append(vector) 103 | embed = np.array(embed, dtype=np.float32) 104 | 105 | print("Loading entity vectors...") 106 | entity_embed = [] 107 | with open('%s/entity_%s.txt' % (path, trans)) as f: 108 | for i, line in enumerate(f): 109 | s = line.strip().split('\t') 110 | entity_embed.append(map(float, s)) 111 | 112 | print("Loading relation vectors...") 113 | relation_embed = [] 114 | with open('%s/relation_%s.txt' % (path, trans)) as f: 115 | for i, line in enumerate(f): 116 | s = line.strip().split('\t') 117 | relation_embed.append(s) 118 | 119 | entity_relation_embed = np.array(entity_embed+relation_embed, dtype=np.float32) 120 | entity_embed = np.array(entity_embed, dtype=np.float32) 121 | relation_embed = np.array(relation_embed, dtype=np.float32) 122 | 123 | return vocab_list, embed, entity_list, entity_embed, relation_list, relation_embed, entity_relation_embed 124 | 125 | def gen_batched_data(data): 126 | global csk_entities, csk_triples, kb_dict 127 | encoder_len = 
144 |     for item in data:
145 |         posts.append(padding(item['post'], encoder_len))
146 |         responses.append(padding(item['response'], decoder_len))
147 |         posts_length.append(len(item['post'])+1)
148 |         responses_length.append(len(item['response'])+1)
149 |         all_triples.append(padding_triple([csk_triples[x].split(', ') for triple in item['all_triples'] for x in triple], triple_len))
150 |         match_index = []
151 |         for x in item['match_index']:
152 |             _index = [-1] * triple_len
153 |             if x[0] == -1 and x[1] == -1:
154 |                 match_index.append(-1)
155 |             else:
156 |                 match_index.append(sum([len(m) for m in item['all_triples'][:(x[0]-1)]]) + 1 + x[1])
157 |         match_triples.append(match_index + [-1]*(decoder_len-len(match_index)))
158 |
159 |         if not FLAGS.is_train:
160 |             entity = ['_NONE']
161 |             entity += [csk_entities[x] for ent in item['all_entities'] for x in ent]
162 |             entities.append(entity+['_NONE']*(triple_len-len(entity)))
163 |
164 |
165 |     batched_data = {'posts': np.array(posts),
166 |                     'responses': np.array(responses),
167 |                     'posts_length': posts_length,
168 |                     'responses_length': responses_length,
169 |                     'triples': np.array(all_triples),
170 |                     'entities': np.array(entities),
171 |                     'match_triples': np.array(match_triples)}
172 |     return batched_data
173 |
174 | def train(model, sess, data_train):
175 |     batched_data = gen_batched_data(data_train)
176 |     outputs = model.step_decoder(sess, batched_data, kb_use=True)
177 |     return np.sum(outputs[0])
178 |
179 | def generate_summary(model, sess, data_train):
180 |     selected_data = [random.choice(data_train) for i in range(FLAGS.batch_size)]
181 |     batched_data = gen_batched_data(selected_data)
182 |     summary = model.step_decoder(sess, batched_data, kb_use=True, forward_only=True, summary=True)[-1]
183 |     return summary
184 |
185 |
186 | def evaluate(model, sess, data_dev, summary_writer):
187 |     loss = np.zeros((1, ))
188 |     st, ed, times = 0, FLAGS.batch_size, 0
189 |     while st < len(data_dev):
190 |         selected_data = data_dev[st:ed]
191 |         batched_data = gen_batched_data(selected_data)
192 |         outputs = model.step_decoder(sess, batched_data, kb_use=True, forward_only=True)
193 |         loss += np.sum(outputs[0])
194 |         st, ed = ed, ed+FLAGS.batch_size
195 |         times += 1
196 |     loss /= len(data_dev)
197 |     summary = tf.Summary()
198 |     summary.value.add(tag='decoder_loss/dev', simple_value=loss)
199 |     summary.value.add(tag='perplexity/dev', simple_value=np.exp(loss))
200 |     summary_writer.add_summary(summary, model.global_step.eval())
201 |     print('    perplexity on dev set: %.2f' % np.exp(loss))
202 |
203 |
204 | def get_steps(train_dir):
205 |     a = os.walk(train_dir)
206 |     for root, dirs, files in a:
207 |         if root == train_dir:
208 |             filenames = files
209 |
210 |     steps, metafiles, datafiles, indexfiles = [], [], [], []
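    # Added note: checkpoints written by the Saver (pad_step_number=True) look
    # like checkpoint-00123000.index / .data-... / .meta; the loop below
    # buckets them by suffix, and int(f[11:-6]) later recovers the global step
    # by stripping 'checkpoint-' (11 chars) and '.index' (6 chars).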
211 |     for filename in filenames:
212 |         if 'meta' in filename:
213 |             metafiles.append(filename)
214 |         if 'data' in filename:
215 |             datafiles.append(filename)
216 |         if 'index' in filename:
217 |             indexfiles.append(filename)
218 |
219 |     metafiles.sort()
220 |     datafiles.sort()
221 |     indexfiles.sort(reverse=True)
222 |
223 |     for f in indexfiles:
224 |         steps.append(int(f[11:-6]))
225 |
226 |     return steps
227 |
228 | def test(sess, saver, data_dev, setnum=5000):
229 |     with open('%s/stopwords' % FLAGS.data_dir) as f:
230 |         stopwords = json.loads(f.readline())
231 |     steps = get_steps(FLAGS.train_dir)
232 |     low_step = 0
233 |     high_step = 800000
234 |     with open('%s.res' % FLAGS.inference_path, 'w') as resfile, open('%s.log' % FLAGS.inference_path, 'w') as outfile:
235 |         for step in [step for step in steps if step > low_step and step < high_step]:
236 |             outfile.write('test for model-%d\n' % step)
237 |             model_path = '%s/checkpoint-%08d' % (FLAGS.train_dir, step)
238 |             print('restore from %s' % model_path)
239 |             try:
240 |                 saver.restore(sess, model_path)
241 |             except Exception:
242 |                 continue
243 |             st, ed = 0, FLAGS.batch_size
244 |             results = []
245 |             loss = []
246 |             while st < len(data_dev):
247 |                 selected_data = data_dev[st:ed]
248 |                 batched_data = gen_batched_data(selected_data)
249 |                 responses, ppx_loss = sess.run(['decoder_1/generation:0', 'decoder/ppx_loss:0'], {'enc_inps:0': batched_data['posts'], 'enc_lens:0': batched_data['posts_length'], 'dec_inps:0': batched_data['responses'], 'dec_lens:0': batched_data['responses_length'], 'entities:0': batched_data['entities'], 'triples:0': batched_data['triples'], 'match_triples:0': batched_data['match_triples']})
250 |                 loss += [x for x in ppx_loss]
251 |                 for response in responses:
252 |                     result = []
253 |                     for token in response:
254 |                         if token != '_EOS':
255 |                             result.append(token)
256 |                         else:
257 |                             break
258 |                     results.append(result)
259 |                 st, ed = ed, ed+FLAGS.batch_size
260 |             match_entity_sum = [.0] * 4
261 |             cnt = 0
262 |             for post, response, result, match_triples, triples, entities in zip([data['post'] for data in data_dev], [data['response'] for data in data_dev], results, [data['match_triples'] for data in data_dev], [data['all_triples'] for data in data_dev], [data['all_entities'] for data in data_dev]):
263 |                 setidx = cnt // setnum
264 |                 result_matched_entities = []
265 |                 triples = [csk_triples[tri] for triple in triples for tri in triple]
266 |                 match_triples = [csk_triples[triple] for triple in match_triples]
267 |                 entities = [csk_entities[x] for entity in entities for x in entity]
268 |                 matches = [x for triple in match_triples for x in [triple.split(', ')[0], triple.split(', ')[2]] if x in response]
269 |
270 |                 for word in result:
271 |                     if word not in stopwords and word in entities:
272 |                         result_matched_entities.append(word)
273 |                 outfile.write('post: %s\nresponse: %s\nresult: %s\nmatch_entity: %s\n\n' % (' '.join(post), ' '.join(response), ' '.join(result), ' '.join(result_matched_entities)))
274 |                 match_entity_sum[setidx] += len(set(result_matched_entities))
275 |                 cnt += 1
276 |             match_entity_sum = [m / setnum for m in match_entity_sum] + [sum(match_entity_sum) / len(data_dev)]
277 |             losses = [np.sum(loss[x:x+setnum]) / float(setnum) for x in range(0, setnum*4, setnum)] + [np.sum(loss) / float(setnum*4)]
278 |             losses = [np.exp(x) for x in losses]
279 |             def show(x):
280 |                 return ', '.join([str(v) for v in x])
281 |             outfile.write('model: %d\n\tperplexity: %s\n\tmatch_entity_rate: %s\n%s\n\n' % (step, show(losses), show(match_entity_sum), '='*50))
282 |             resfile.write('model: %d\n\tperplexity: %s\n\tmatch_entity_rate: %s\n\n' % (step, show(losses), show(match_entity_sum)))
283 |             outfile.flush()
284 |             resfile.flush()
285 |     return results
286 |
287 | config = tf.ConfigProto()
288 | config.gpu_options.allow_growth = True
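# Usage note (added; flag names are defined at the top of this file, paths
# are illustrative): training resumes from the latest checkpoint in
# --train_dir, e.g.
#     python main.py --data_dir=./data --train_dir=./train
# while setting --is_train=False sweeps every saved checkpoint on the test
# set, e.g.
#     python main.py --is_train=False --inference_path=test
# which writes test.log and test.res to the working directory.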
289 | with tf.Session(config=config) as sess:
290 |     if FLAGS.is_train:
291 |         raw_vocab, data_train, data_dev, data_test = prepare_data(FLAGS.data_dir)
292 |         vocab, embed, entity_vocab, entity_embed, relation_vocab, relation_embed, entity_relation_embed = build_vocab(FLAGS.data_dir, raw_vocab)
293 |         FLAGS.num_entities = len(entity_vocab)
294 |         print(FLAGS.__flags)
295 |         model = Model(
296 |                 FLAGS.symbols,
297 |                 FLAGS.embed_units,
298 |                 FLAGS.units,
299 |                 FLAGS.layers,
300 |                 embed,
301 |                 entity_relation_embed,
302 |                 num_entities=len(entity_vocab)+len(relation_vocab),
303 |                 num_trans_units=FLAGS.trans_units,
304 |                 output_alignments=FLAGS.copy_use)
305 |         if tf.train.get_checkpoint_state(FLAGS.train_dir):
306 |             print("Reading model parameters from %s" % FLAGS.train_dir)
307 |             model.saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))
308 |         else:
309 |             print("Created model with fresh parameters.")
310 |             tf.global_variables_initializer().run()
311 |             op_in = model.symbol2index.insert(constant_op.constant(vocab),
312 |                     constant_op.constant(range(FLAGS.symbols), dtype=tf.int64))
313 |             sess.run(op_in)
314 |             op_out = model.index2symbol.insert(constant_op.constant(
315 |                     range(FLAGS.symbols), dtype=tf.int64), constant_op.constant(vocab))
316 |             sess.run(op_out)
317 |             op_in = model.entity2index.insert(constant_op.constant(entity_vocab+relation_vocab),
318 |                     constant_op.constant(range(len(entity_vocab)+len(relation_vocab)), dtype=tf.int64))
319 |             sess.run(op_in)
320 |             op_out = model.index2entity.insert(constant_op.constant(
321 |                     range(len(entity_vocab)+len(relation_vocab)), dtype=tf.int64), constant_op.constant(entity_vocab+relation_vocab))
322 |             sess.run(op_out)
323 |
324 |         if FLAGS.log_parameters:
325 |             model.print_parameters()
326 |
327 |         summary_writer = tf.summary.FileWriter('%s/log' % FLAGS.train_dir, sess.graph)
328 |         loss_step, time_step = np.zeros((1, )), .0
329 |         previous_losses = [1e18]*3
330 |         train_len = len(data_train)
331 |         while True:
332 |             st, ed = 0, FLAGS.batch_size * FLAGS.per_checkpoint
333 |             random.shuffle(data_train)
334 |             while st < train_len:
335 |                 start_time = time.time()
336 |                 for batch in range(st, ed, FLAGS.batch_size):
337 |                     loss_step += train(model, sess, data_train[batch:batch+FLAGS.batch_size]) / (ed - st)
338 |
339 |                 show = lambda a: '[%s]' % (' '.join(['%.2f' % x for x in a]))
340 |                 print("global step %d learning rate %.4f step-time %.2f loss %f perplexity %s"
341 |                         % (model.global_step.eval(), model.lr,
342 |                         (time.time() - start_time) / ((ed - st) / FLAGS.batch_size), loss_step, show(np.exp(loss_step))))
343 |                 model.saver.save(sess, '%s/checkpoint' % FLAGS.train_dir,
344 |                         global_step=model.global_step)
345 |                 summary = tf.Summary()
346 |                 summary.value.add(tag='decoder_loss/train', simple_value=loss_step)
347 |                 summary.value.add(tag='perplexity/train', simple_value=np.exp(loss_step))
348 |                 summary_writer.add_summary(summary, model.global_step.eval())
349 |                 summary_model = generate_summary(model, sess, data_train)
350 |                 summary_writer.add_summary(summary_model, model.global_step.eval())
351 |                 evaluate(model, sess, data_dev, summary_writer)
352 |                 previous_losses = previous_losses[1:]+[np.sum(loss_step)]
353 |                 loss_step, time_step = np.zeros((1, )), .0
354 |                 st, ed = ed, min(train_len, ed + FLAGS.batch_size * 
FLAGS.per_checkpoint) 355 | model.saver_epoch.save(sess, '%s/epoch/checkpoint' % FLAGS.train_dir, global_step=model.global_step) 356 | else: 357 | model = Model( 358 | FLAGS.symbols, 359 | FLAGS.embed_units, 360 | FLAGS.units, 361 | FLAGS.layers, 362 | embed=None, 363 | num_entities=FLAGS.num_entities+FLAGS.num_relations, 364 | num_trans_units=FLAGS.trans_units, 365 | output_alignments=FLAGS.copy_use) 366 | 367 | if FLAGS.inference_version == 0: 368 | model_path = tf.train.latest_checkpoint(FLAGS.train_dir) 369 | else: 370 | model_path = '%s/checkpoint-%08d' % (FLAGS.train_dir, FLAGS.inference_version) 371 | print('restore from %s' % model_path) 372 | model.saver.restore(sess, model_path) 373 | saver = model.saver 374 | 375 | raw_vocab, data_train, data_dev, data_test = prepare_data(FLAGS.data_dir, is_train=False) 376 | 377 | test(sess, saver, data_test, setnum=5000) 378 | 379 | -------------------------------------------------------------------------------- /copynet/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tensorflow.python.ops.nn import dynamic_rnn 5 | from tensorflow.contrib.rnn import GRUCell, LSTMCell, MultiRNNCell 6 | from tensorflow.contrib.seq2seq.python.ops.loss import sequence_loss 7 | from tensorflow.contrib.lookup.lookup_ops import MutableHashTable 8 | from tensorflow.contrib.layers.python.layers import layers 9 | from dynamic_decoder import dynamic_rnn_decoder 10 | from output_projection import output_projection_layer 11 | from attention_decoder import * 12 | from tensorflow.contrib.session_bundle import exporter 13 | 14 | PAD_ID = 0 15 | UNK_ID = 1 16 | GO_ID = 2 17 | EOS_ID = 3 18 | NONE_ID = 0 19 | _START_VOCAB = ['_PAD', '_UNK', '_GO', '_EOS'] 20 | 21 | class Model(object): 22 | def __init__(self, 23 | num_symbols, 24 | num_embed_units, 25 | num_units, 26 | num_layers, 27 | embed, 28 | entity_embed=None, 29 | num_entities=0, 30 | num_trans_units=100, 31 | learning_rate=0.0001, 32 | learning_rate_decay_factor=0.95, 33 | max_gradient_norm=5.0, 34 | num_samples=512, 35 | max_length=60, 36 | output_alignments=True, 37 | use_lstm=False): 38 | 39 | self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps') # batch*len 40 | self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens') # batch 41 | self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps') # batch*len 42 | self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens') # batch 43 | self.entities = tf.placeholder(tf.string, (None, None), 'entities') # batch 44 | self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks') # batch 45 | self.triples = tf.placeholder(tf.string, (None, None, 3), 'triples') # batch 46 | self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples') # batch 47 | self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples') # batch 48 | self.match_triples = tf.placeholder(tf.int32, (None, None), 'match_triples') # batch 49 | encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts)) 50 | triple_num = tf.shape(self.triples)[1] 51 | 52 | #use_triples = tf.reduce_sum(tf.cast(tf.greater_equal(self.match_triples, 0), tf.float32), axis=-1) 53 | one_hot_triples = tf.one_hot(self.match_triples, triple_num) 54 | use_triples = tf.reduce_sum(one_hot_triples, axis=[2]) 55 | 56 | self.symbol2index = MutableHashTable( 57 | key_dtype=tf.string, 58 | value_dtype=tf.int64, 59 | default_value=UNK_ID, 60 | 
shared_name="in_table", 61 | name="in_table", 62 | checkpoint=True) 63 | self.index2symbol = MutableHashTable( 64 | key_dtype=tf.int64, 65 | value_dtype=tf.string, 66 | default_value='_UNK', 67 | shared_name="out_table", 68 | name="out_table", 69 | checkpoint=True) 70 | self.entity2index = MutableHashTable( 71 | key_dtype=tf.string, 72 | value_dtype=tf.int64, 73 | default_value=NONE_ID, 74 | shared_name="entity_in_table", 75 | name="entity_in_table", 76 | checkpoint=True) 77 | self.index2entity = MutableHashTable( 78 | key_dtype=tf.int64, 79 | value_dtype=tf.string, 80 | default_value='_NONE', 81 | shared_name="entity_out_table", 82 | name="entity_out_table", 83 | checkpoint=True) 84 | # build the vocab table (string to index) 85 | 86 | 87 | self.posts_word_id = self.symbol2index.lookup(self.posts) # batch*len 88 | self.posts_entity_id = self.entity2index.lookup(self.posts) # batch*len 89 | #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6) 90 | self.responses_target = self.symbol2index.lookup(self.responses) #batch*len 91 | 92 | batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1] 93 | self.responses_word_id = tf.concat([tf.ones([batch_size, 1], dtype=tf.int64)*GO_ID, 94 | tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1) # batch*len 95 | self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1, 96 | decoder_len), reverse=True, axis=1), [-1, decoder_len]) 97 | 98 | # build the embedding table (index to vector) 99 | if embed is None: 100 | # initialize the embedding randomly 101 | self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32) 102 | else: 103 | # initialize the embedding by pre-trained word vectors 104 | self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed) 105 | if entity_embed is None: 106 | # initialize the embedding randomly 107 | self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units], tf.float32, trainable=False) 108 | else: 109 | # initialize the embedding by pre-trained word vectors 110 | self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32, initializer=entity_embed, trainable=False) 111 | 112 | self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units, activation=tf.tanh, name='trans_transformation') 113 | padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units], dtype=tf.float32, initializer=tf.zeros_initializer()) 114 | 115 | self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0) 116 | 117 | triples_embedding = tf.reshape(tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)), [encoder_batch_size, triple_num, 3 * num_trans_units]) 118 | entities_word_embedding = tf.reshape(tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)), [encoder_batch_size, -1, num_embed_units]) 119 | 120 | 121 | self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id) #batch*len*unit 122 | self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id) #batch*len*unit 123 | 124 | encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)]) 125 | decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)]) 126 | 127 | # rnn encoder 128 | encoder_output, encoder_state = dynamic_rnn(encoder_cell, self.encoder_input, 129 | self.posts_length, dtype=tf.float32, 
scope="encoder") 130 | 131 | # get output projection function 132 | output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(num_units, 133 | num_symbols, num_samples) 134 | 135 | 136 | 137 | with tf.variable_scope('decoder'): 138 | # get attention function 139 | attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \ 140 | = prepare_attention(encoder_output, 'bahdanau', num_units, imem=triples_embedding, output_alignments=output_alignments)#'luong', num_units) 141 | 142 | decoder_fn_train = attention_decoder_fn_train( 143 | encoder_state, attention_keys_init, attention_values_init, 144 | attention_score_fn_init, attention_construct_fn_init, output_alignments=output_alignments, max_length=tf.reduce_max(self.responses_length)) 145 | self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(decoder_cell, decoder_fn_train, 146 | self.decoder_input, self.responses_length, scope="decoder_rnn") 147 | if output_alignments: 148 | self.alignments = tf.transpose(alignments_ta.stack(), perm=[1,0,2]) 149 | #self.alignments = tf.Print(self.alignments, [self.alignments], summarize=1e8) 150 | self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(self.decoder_output, self.responses_target, self.decoder_mask, self.alignments, triples_embedding, use_triples, one_hot_triples) 151 | self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss') 152 | #self.decoder_loss = tf.Print(self.decoder_loss, ['decoder_loss', self.decoder_loss], summarize=1e6) 153 | else: 154 | self.decoder_loss, self.sentence_ppx = sequence_loss(self.decoder_output, 155 | self.responses_target, self.decoder_mask) 156 | self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss') 157 | 158 | with tf.variable_scope('decoder', reuse=True): 159 | # get attention function 160 | attention_keys, attention_values, attention_score_fn, attention_construct_fn \ 161 | = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True, imem=triples_embedding, output_alignments=output_alignments)#'luong', num_units) 162 | decoder_fn_inference = attention_decoder_fn_inference( 163 | output_fn, encoder_state, attention_keys, attention_values, 164 | attention_score_fn, attention_construct_fn, self.embed, GO_ID, 165 | EOS_ID, max_length, num_symbols, imem=entities_word_embedding, selector_fn=selector_fn) 166 | 167 | 168 | self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(decoder_cell, 169 | decoder_fn_inference, scope="decoder_rnn") 170 | if output_alignments: 171 | output_len = tf.shape(self.decoder_distribution)[1] 172 | output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len))) 173 | word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64) 174 | entity_ids = tf.reshape(tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]), [-1]) 175 | entities = tf.reshape(tf.gather(tf.reshape(self.entities, [-1]), entity_ids), [-1, output_len]) 176 | words = self.index2symbol.lookup(word_ids) 177 | self.generation = tf.where(output_ids > 0, words, entities, name='generation') 178 | else: 179 | self.generation_index = tf.argmax(self.decoder_distribution, 2) 180 | 181 | self.generation = self.index2symbol.lookup(self.generation_index, name='generation') 182 | 183 | 184 | # initialize the training process 185 | self.learning_rate = tf.Variable(float(learning_rate), 186 | trainable=False, dtype=tf.float32) 187 | self.learning_rate_decay_op = 
self.learning_rate.assign(
188 |             self.learning_rate * learning_rate_decay_factor)
189 |         self.global_step = tf.Variable(0, trainable=False)
190 |
191 |         self.params = tf.global_variables()
192 |
193 |         # calculate the gradient of parameters
194 |         #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
195 |         opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
196 |         self.lr = opt._lr
197 |
198 |         gradients = tf.gradients(self.decoder_loss, self.params)
199 |         clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients,
200 |                 max_gradient_norm)
201 |         self.update = opt.apply_gradients(zip(clipped_gradients, self.params),
202 |                 global_step=self.global_step)
203 |
204 |         tf.summary.scalar('decoder_loss', self.decoder_loss)
205 |         for each in tf.trainable_variables():
206 |             tf.summary.histogram(each.name, each)
207 |
208 |         self.merged_summary_op = tf.summary.merge_all()
209 |
210 |         self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
211 |                 max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0)
212 |
213 |         self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000, pad_step_number=True)
214 |
215 |
216 |     def print_parameters(self):
217 |         for item in self.params:
218 |             print('%s: %s' % (item.name, item.get_shape()))
219 |
220 |     def step_decoder(self, session, data, forward_only=False, summary=False, kb_use=True):
221 |         input_feed = {self.posts: data['posts'],
222 |                       self.posts_length: data['posts_length'],
223 |                       self.responses: data['responses'],
224 |                       self.responses_length: data['responses_length'],
225 |                       self.triples: data['triples'],
226 |                       self.match_triples: data['match_triples']}
227 |
228 |         if forward_only:
229 |             output_feed = [self.sentence_ppx]
230 |         else:
231 |             output_feed = [self.sentence_ppx, self.gradient_norm, self.update]
232 |         if summary:
233 |             output_feed.append(self.merged_summary_op)
234 |         return session.run(output_feed, input_feed)
235 |
--------------------------------------------------------------------------------
/copynet/output_projection.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.layers.python.layers import layers
3 | from tensorflow.python.ops import variable_scope
4 |
5 | def output_projection_layer(num_units, num_symbols, num_samples=None, name="output_projection"):
6 |     def output_fn(outputs):
7 |         return layers.linear(outputs, num_symbols, scope=name)
8 |
9 |     def selector_fn(outputs):
10 |         return tf.sigmoid(layers.linear(outputs, 1, scope='selector'))
11 |
12 |     def sequence_loss(outputs, targets, masks):
13 |         with variable_scope.variable_scope('decoder_rnn'):
14 |             batch_size = tf.shape(outputs)[0]
15 |             logits = layers.linear(outputs, num_symbols, scope=name)
16 |             logits = tf.reshape(logits, [-1, num_symbols])
17 |             local_labels = tf.reshape(targets, [-1])
18 |             local_masks = tf.reshape(masks, [-1])
19 |
20 |             local_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=local_labels, logits=logits)
21 |             local_loss = local_loss * local_masks
22 |             ppx_loss = tf.reduce_sum(tf.reshape(local_loss, [batch_size, -1]), axis=1) / tf.reduce_sum(masks, axis=1)
23 |
24 |             loss = tf.reduce_sum(local_loss)
25 |             total_size = tf.reduce_sum(local_masks)
26 |             total_size += 1e-12 # to avoid division by 0 for all-0 weights
27 |
28 |             return loss / total_size, ppx_loss
29 |
30 |     def sampled_sequence_loss(outputs, targets, masks):
31 |         with variable_scope.variable_scope('decoder_rnn/%s' % name):
32 |             weights = tf.transpose(tf.get_variable("weights", [num_units, num_symbols]))
33 |             bias = tf.get_variable("biases", [num_symbols])
34 |
35 |             local_labels = tf.reshape(targets, [-1, 1])
36 |             local_outputs = tf.reshape(outputs, [-1, num_units])
37 |             local_masks = tf.reshape(masks, [-1])
38 |
39 |             local_loss = tf.nn.sampled_softmax_loss(weights, bias, local_labels,
40 |                     local_outputs, num_samples, num_symbols)
41 |             local_loss = local_loss * local_masks
42 |
43 |             loss = tf.reduce_sum(local_loss)
44 |             total_size = tf.reduce_sum(local_masks)
45 |             total_size += 1e-12 # to avoid division by 0 for all-0 weights
46 |
47 |             return loss / total_size
48 |
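    # Added sketch of the copy-mechanism loss computed by total_loss below
    # (numbers are made up): a learned gate `selector` interpolates the
    # probability of the gold token under the vocabulary softmax (word_prob)
    # with the attention mass placed on the gold knowledge triple
    # (triple_prob):
    #     final_prob = word_prob * (1 - selector) + triple_prob * selector
    # e.g. word_prob = 0.4, triple_prob = 0.9, selector = 0.8 gives
    # final_prob = 0.4 * 0.2 + 0.9 * 0.8 = 0.80. For the reported perplexity
    # the gate is replaced by the ground-truth indicator use_entities, so
    # ppx_prob scores whichever source (word or entity) the reference used.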
49 |     def total_loss(outputs, targets, masks, alignments, triples_embedding, use_entities, entity_targets):
50 |         local_masks = tf.reshape(masks, [-1])
51 |
52 |         logits = layers.linear(outputs, num_symbols, scope='decoder_rnn/%s' % name)
53 |         one_hot_targets = tf.one_hot(targets, num_symbols)
54 |         word_prob = tf.reduce_sum(tf.nn.softmax(logits) * one_hot_targets, axis=2)
55 |         selector = tf.squeeze(tf.sigmoid(layers.linear(outputs, 1, scope='decoder_rnn/selector')))
56 |
57 |         triple_prob = tf.reduce_sum(alignments * entity_targets, axis=[2])
58 |         cast_selector = tf.cast(tf.reduce_sum(alignments, axis=2) > tf.reduce_sum(tf.nn.softmax(logits), axis=2), tf.float32)  # computed but not used below
59 |         final_prob = word_prob * (1 - selector) + triple_prob * selector
60 |         ppx_prob = word_prob * (1 - use_entities) + triple_prob * use_entities
61 |         final_loss = tf.reshape(- tf.log(1e-12 + final_prob), [-1]) * local_masks
62 |         ppx_loss = tf.reshape(- tf.log(1e-12 + ppx_prob), [-1]) * local_masks
63 |         sentence_ppx = tf.reduce_sum(- tf.log(1e-12 + ppx_prob) * masks, axis=1)
64 |
65 |         loss = tf.reduce_sum(final_loss)
66 |         #loss = tf.Print(loss, ['use_entity', tf.reduce_min(use_entities), tf.reduce_max(use_entities), 'triple_prob', tf.reduce_min(triple_prob), 'word_prob', tf.reduce_min(word_prob), 'final_prob', tf.reduce_min(final_prob), 'final_loss', tf.reduce_min(final_loss)], summarize=1e6)
67 |         total_size = tf.reduce_sum(local_masks)
68 |         total_size += 1e-12 # to avoid division by 0 for all-0 weights
69 |
70 |         return loss / total_size, tf.reduce_sum(ppx_loss) / total_size, sentence_ppx / tf.reduce_sum(masks, axis=1)
71 |
72 |
73 |
74 |     return output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss
75 |
76 |
--------------------------------------------------------------------------------
/dynamic_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Seq2seq layer operations for use in neural networks.
17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from tensorflow.contrib import layers 24 | from tensorflow.python.framework import ops 25 | from tensorflow.python.ops import array_ops 26 | from tensorflow.python.ops import control_flow_ops 27 | from tensorflow.python.ops import math_ops 28 | from tensorflow.python.ops import rnn 29 | from tensorflow.python.ops import tensor_array_ops 30 | from tensorflow.python.ops import variable_scope as vs 31 | 32 | __all__ = ["dynamic_rnn_decoder"] 33 | 34 | def dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None, 35 | parallel_iterations=None, swap_memory=False, 36 | time_major=False, scope=None, name=None): 37 | """ Dynamic RNN decoder for a sequence-to-sequence model specified by 38 | RNNCell and decoder function. 39 | 40 | The `dynamic_rnn_decoder` is similar to the `tf.python.ops.rnn.dynamic_rnn` 41 | as the decoder does not make any assumptions of sequence length and batch 42 | size of the input. 43 | 44 | The `dynamic_rnn_decoder` has two modes: training or inference and expects 45 | the user to create seperate functions for each. 46 | 47 | Under both training and inference, both `cell` and `decoder_fn` are expected, 48 | where `cell` performs computation at every timestep using `raw_rnn`, and 49 | `decoder_fn` allows modeling of early stopping, output, state, and next 50 | input and context. 51 | 52 | When training the user is expected to supply `inputs`. At every time step a 53 | slice of the supplied input is fed to the `decoder_fn`, which modifies and 54 | returns the input for the next time step. 55 | 56 | `sequence_length` is needed at training time, i.e., when `inputs` is not 57 | None, for dynamic unrolling. At test time, when `inputs` is None, 58 | `sequence_length` is not needed. 59 | 60 | Under inference `inputs` is expected to be `None` and the input is inferred 61 | solely from the `decoder_fn`. 62 | 63 | Args: 64 | cell: An instance of RNNCell. 65 | decoder_fn: A function that takes time, cell state, cell input, 66 | cell output and context state. It returns a early stopping vector, 67 | cell state, next input, cell output and context state. 68 | Examples of decoder_fn can be found in the decoder_fn.py folder. 69 | inputs: The inputs for decoding (embedded format). 70 | 71 | If `time_major == False` (default), this must be a `Tensor` of shape: 72 | `[batch_size, max_time, ...]`. 73 | 74 | If `time_major == True`, this must be a `Tensor` of shape: 75 | `[max_time, batch_size, ...]`. 76 | 77 | The input to `cell` at each time step will be a `Tensor` with dimensions 78 | `[batch_size, ...]`. 79 | 80 | sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 81 | if `inputs` is not None and `sequence_length` is None it is inferred 82 | from the `inputs` as the maximal possible sequence length. 83 | parallel_iterations: (Default: 32). The number of iterations to run in 84 | parallel. Those operations which do not have any temporal dependency 85 | and can be run in parallel, will be. This parameter trades off 86 | time for space. Values >> 1 use more memory but take less time, 87 | while smaller values use less memory but computations take longer. 88 | swap_memory: Transparently swap the tensors produced in forward inference 89 | but needed for back prop from GPU to CPU. This allows training RNNs 90 | which would typically not fit on a single GPU, with very minimal (or no) 91 | performance penalty. 
92 |     time_major: The shape format of the `inputs` and `outputs` Tensors.
93 |       If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
94 |       If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
95 |       Using `time_major = True` is a bit more efficient because it avoids
96 |       transposes at the beginning and end of the RNN calculation. However,
97 |       most TensorFlow data is batch-major, so by default this function
98 |       accepts input and emits output in batch-major form.
99 |     scope: VariableScope for the `raw_rnn`;
100 |       defaults to None.
101 |     name: NameScope for the decoder;
102 |       defaults to "dynamic_rnn_decoder".
103 |
104 |   Returns:
105 |     A tuple (outputs, final_state, final_context_state) where:
106 |
107 |     outputs: the RNN output `Tensor`.
108 |
109 |       If time_major == False (default), this will be a `Tensor` shaped:
110 |       `[batch_size, max_time, cell.output_size]`.
111 |
112 |       If time_major == True, this will be a `Tensor` shaped:
113 |       `[max_time, batch_size, cell.output_size]`.
114 |
115 |     final_state: The final state, shaped
116 |       `[batch_size, cell.state_size]`.
117 |
118 |     final_context_state: The context state returned by the final call
119 |       to decoder_fn. This is useful if the context state maintains internal
120 |       data which is required after the graph is run.
121 |       For example, one way to diversify the inference output is to use
122 |       a stochastic decoder_fn, in which case one would want to store the
123 |       decoded outputs, not just the RNN outputs. This can be done by
124 |       maintaining a TensorArray in context_state and storing the decoded
125 |       output of each iteration therein.
126 |
127 |   Raises:
128 |     ValueError: if `inputs` is not None and has fewer than two dimensions.
129 |   """
130 |   with ops.name_scope(name, "dynamic_rnn_decoder",
131 |                       [cell, decoder_fn, inputs, sequence_length,
132 |                        parallel_iterations, swap_memory, time_major, scope]):
133 |     if inputs is not None:
134 |       # Convert to tensor
135 |       inputs = ops.convert_to_tensor(inputs)
136 |
137 |       # Test input dimensions
138 |       if inputs.get_shape().ndims is not None and (
139 |           inputs.get_shape().ndims < 2):
140 |         raise ValueError("Inputs must have at least two dimensions")
141 |       # Setup of RNN (dimensions, sizes, length, initial state, dtype)
142 |       if not time_major:
143 |         # [batch, seq, features] -> [seq, batch, features]
144 |         inputs = array_ops.transpose(inputs, perm=[1, 0, 2])
145 |
146 |       dtype = inputs.dtype
147 |       # Get data input information
148 |       input_depth = int(inputs.get_shape()[2])
149 |       batch_depth = inputs.get_shape()[1].value
150 |       max_time = inputs.get_shape()[0].value
151 |       if max_time is None:
152 |         max_time = array_ops.shape(inputs)[0]
153 |       # Setup decoder inputs as TensorArray
154 |       inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time)
155 |       inputs_ta = inputs_ta.unstack(inputs)
156 |
157 |     def loop_fn(time, cell_output, cell_state, loop_state):
158 |       if cell_state is None:  # first call, before while loop (in raw_rnn)
159 |         if cell_output is not None:
160 |           raise ValueError("Expected cell_output to be None when cell_state "
161 |                            "is None, but saw: %s" % cell_output)
162 |         if loop_state is not None:
163 |           raise ValueError("Expected loop_state to be None when cell_state "
164 |                            "is None, but saw: %s" % loop_state)
165 |         context_state = None
166 |       else:  # subsequent calls, inside while loop, after cell execution
167 |         if isinstance(loop_state, tuple):
168 |           (done, context_state) = loop_state
169 |         else:
170 |           done = loop_state
171 |           context_state = None
172 |
173 |       # call decoder function
174 |       if inputs is not None:  # training
175 |         # get next_cell_input
176 |         if cell_state is None:
177 |           next_cell_input = inputs_ta.read(0)
178 |         else:
179 |           if batch_depth is not None:
180 |             batch_size = batch_depth
181 |           else:
182 |             batch_size = array_ops.shape(done)[0]
183 |           next_cell_input = control_flow_ops.cond(
184 |               math_ops.equal(time, max_time),
185 |               lambda: array_ops.zeros([batch_size, input_depth], dtype=dtype),
186 |               lambda: inputs_ta.read(time))
187 |         (next_done, next_cell_state, next_cell_input, emit_output,
188 |          next_context_state) = decoder_fn(time, cell_state, next_cell_input,
189 |                                           cell_output, context_state)
190 |       else:  # inference
191 |         # next_cell_input is obtained through decoder_fn
192 |         (next_done, next_cell_state, next_cell_input, emit_output,
193 |          next_context_state) = decoder_fn(time, cell_state, None, cell_output,
194 |                                           context_state)
195 |
196 |       # check if we are done
197 |       if next_done is None:  # training
198 |         next_done = time >= sequence_length
199 |
200 |       # build next_loop_state
201 |       if next_context_state is None:
202 |         next_loop_state = next_done
203 |       else:
204 |         next_loop_state = (next_done, next_context_state)
205 |
206 |       return (next_done, next_cell_input, next_cell_state,
207 |               emit_output, next_loop_state)
208 |
209 |     # Run raw_rnn function
210 |     outputs_ta, final_state, final_loop_state = rnn.raw_rnn(
211 |         cell, loop_fn, parallel_iterations=parallel_iterations,
212 |         swap_memory=swap_memory, scope=scope)
213 |     outputs = outputs_ta.stack()
214 |
215 |     # Get final context_state, if generated by user
216 |     if isinstance(final_loop_state, tuple):
217 |       final_context_state = final_loop_state[1]
218 |     else:
219 |       final_context_state = None
220 |
221 |     if not time_major:
222 |       # [seq, batch, features] -> [batch, seq, features]
223 |       outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
224 |     return outputs, final_state, final_context_state
225 |
--------------------------------------------------------------------------------
/image/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuxchow/ccm/a41378011a82c78c522ff067fe09bccef10d62f5/image/demo.png
--------------------------------------------------------------------------------
/image/evaluation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tuxchow/ccm/a41378011a82c78c522ff067fe09bccef10d62f5/image/evaluation.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import json
4 | from tensorflow.python.framework import constant_op
5 | import sys
6 | import math
7 | import os
8 | import time
9 | import random
10 | random.seed(time.time())
11 | from model import Model, _START_VOCAB
12 |
13 | tf.app.flags.DEFINE_boolean("is_train", True, "Set to False for inference.")
14 | tf.app.flags.DEFINE_integer("symbols", 30000, "vocabulary size.")
15 | tf.app.flags.DEFINE_integer("num_entities", 21471, "entity vocabulary size.")
16 | tf.app.flags.DEFINE_integer("num_relations", 44, "relation size.")
17 | tf.app.flags.DEFINE_integer("embed_units", 300, "Size of word embedding.")
18 | tf.app.flags.DEFINE_integer("trans_units", 100, "Size of trans embedding.")
19 | tf.app.flags.DEFINE_integer("units", 512, "Size of each model layer.")
20 | 
tf.app.flags.DEFINE_integer("layers", 2, "Number of layers in the model.") 21 | tf.app.flags.DEFINE_integer("batch_size", 100, "Batch size to use during training.") 22 | tf.app.flags.DEFINE_string("data_dir", "./data", "Data directory") 23 | tf.app.flags.DEFINE_string("train_dir", "./train", "Training directory.") 24 | tf.app.flags.DEFINE_integer("per_checkpoint", 1000, "How many steps to do per checkpoint.") 25 | tf.app.flags.DEFINE_integer("inference_version", 0, "The version for inferencing.") 26 | tf.app.flags.DEFINE_boolean("log_parameters", True, "Set to True to show the parameters") 27 | tf.app.flags.DEFINE_string("inference_path", "test", "Set filename of inference") 28 | 29 | FLAGS = tf.app.flags.FLAGS 30 | if FLAGS.train_dir[-1] == '/': FLAGS.train_dir = FLAGS.train_dir[:-1] 31 | csk_triples, csk_entities, kb_dict = [], [], [] 32 | 33 | def prepare_data(path, is_train=True): 34 | global csk_entities, csk_triples, kb_dict 35 | 36 | with open('%s/resource.txt' % path) as f: 37 | d = json.loads(f.readline()) 38 | 39 | csk_triples = d['csk_triples'] 40 | csk_entities = d['csk_entities'] 41 | raw_vocab = d['vocab_dict'] 42 | kb_dict = d['dict_csk'] 43 | 44 | data_train, data_dev, data_test = [], [], [] 45 | 46 | if is_train: 47 | with open('%s/trainset.txt' % path) as f: 48 | for idx, line in enumerate(f): 49 | #if idx == 100000: break 50 | if idx % 100000 == 0: print('read train file line %d' % idx) 51 | data_train.append(json.loads(line)) 52 | 53 | with open('%s/validset.txt' % path) as f: 54 | for line in f: 55 | data_dev.append(json.loads(line)) 56 | 57 | with open('%s/testset.txt' % path) as f: 58 | for line in f: 59 | data_test.append(json.loads(line)) 60 | 61 | return raw_vocab, data_train, data_dev, data_test 62 | 63 | def build_vocab(path, raw_vocab, trans='transE'): 64 | print("Creating word vocabulary...") 65 | vocab_list = _START_VOCAB + sorted(raw_vocab, key=raw_vocab.get, reverse=True) 66 | if len(vocab_list) > FLAGS.symbols: 67 | vocab_list = vocab_list[:FLAGS.symbols] 68 | 69 | print("Creating entity vocabulary...") 70 | entity_list = ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T'] 71 | with open('%s/entity.txt' % path) as f: 72 | for i, line in enumerate(f): 73 | e = line.strip() 74 | entity_list.append(e) 75 | 76 | print("Creating relation vocabulary...") 77 | relation_list = [] 78 | with open('%s/relation.txt' % path) as f: 79 | for i, line in enumerate(f): 80 | r = line.strip() 81 | relation_list.append(r) 82 | 83 | print("Loading word vectors...") 84 | vectors = {} 85 | with open('%s/glove.840B.300d.txt' % path) as f: 86 | for i, line in enumerate(f): 87 | if i % 100000 == 0: 88 | print(" processing line %d" % i) 89 | s = line.strip() 90 | word = s[:s.find(' ')] 91 | vector = s[s.find(' ')+1:] 92 | vectors[word] = vector 93 | 94 | embed = [] 95 | for word in vocab_list: 96 | if word in vectors: 97 | vector = map(float, vectors[word].split()) 98 | else: 99 | vector = np.zeros((FLAGS.embed_units), dtype=np.float32) 100 | embed.append(vector) 101 | embed = np.array(embed, dtype=np.float32) 102 | 103 | print("Loading entity vectors...") 104 | entity_embed = [] 105 | with open('%s/entity_%s.txt' % (path, trans)) as f: 106 | for i, line in enumerate(f): 107 | s = line.strip().split('\t') 108 | entity_embed.append(map(float, s)) 109 | 110 | print("Loading relation vectors...") 111 | relation_embed = [] 112 | with open('%s/relation_%s.txt' % (path, trans)) as f: 113 | for i, line in enumerate(f): 114 | s = line.strip().split('\t') 115 | 
relation_embed.append(s) 116 | 117 | entity_relation_embed = np.array(entity_embed+relation_embed, dtype=np.float32) 118 | entity_embed = np.array(entity_embed, dtype=np.float32) 119 | relation_embed = np.array(relation_embed, dtype=np.float32) 120 | 121 | return vocab_list, embed, entity_list, entity_embed, relation_list, relation_embed, entity_relation_embed 122 | 123 | def gen_batched_data(data): 124 | global csk_entities, csk_triples, kb_dict 125 | encoder_len = max([len(item['post']) for item in data])+1 126 | decoder_len = max([len(item['response']) for item in data])+1 127 | triple_num = max([len(item['all_triples']) for item in data])+1 128 | triple_len = max([len(tri) for item in data for tri in item['all_triples']]) 129 | max_length = 20 130 | posts, responses, posts_length, responses_length = [], [], [], [] 131 | entities, triples, matches, post_triples, response_triples = [], [], [], [], [] 132 | match_entities, all_entities = [], [] 133 | match_triples, all_triples = [], [] 134 | NAF = ['_NAF_H', '_NAF_R', '_NAF_T'] 135 | 136 | def padding(sent, l): 137 | return sent + ['_EOS'] + ['_PAD'] * (l-len(sent)-1) 138 | 139 | def padding_triple(triple, num, l): 140 | newtriple = [] 141 | triple = [[NAF]] + triple 142 | for tri in triple: 143 | newtriple.append(tri + [['_PAD_H', '_PAD_R', '_PAD_T']] * (l-len(tri))) 144 | pad_triple = [['_PAD_H', '_PAD_R', '_PAD_T']] * l 145 | return newtriple + [pad_triple] * (num - len(newtriple)) 146 | 147 | for item in data: 148 | posts.append(padding(item['post'], encoder_len)) 149 | responses.append(padding(item['response'], decoder_len)) 150 | posts_length.append(len(item['post'])+1) 151 | responses_length.append(len(item['response'])+1) 152 | all_triples.append(padding_triple([[csk_triples[x].split(', ') for x in triple] for triple in item['all_triples']], triple_num, triple_len)) 153 | post_triples.append([[x] for x in item['post_triples']] + [[0]] * (encoder_len - len(item['post_triples']))) 154 | response_triples.append([NAF] + [NAF if x == -1 else csk_triples[x].split(', ') for x in item['response_triples']] + [NAF] * (decoder_len - 1 - len(item['response_triples']))) 155 | match_index = [] 156 | for idx, x in enumerate(item['match_index']): 157 | _index = [-1] * triple_num 158 | if x[0] == -1 and x[1] == -1: 159 | match_index.append(_index) 160 | else: 161 | _index[x[0]] = x[1] 162 | t = all_triples[-1][x[0]][x[1]] 163 | assert(t == response_triples[-1][idx+1]) 164 | match_index.append(_index) 165 | match_triples.append(match_index + [[-1]*triple_num]*(decoder_len-len(match_index))) 166 | 167 | if not FLAGS.is_train: 168 | entity = [['_NONE']*triple_len] 169 | for ent in item['all_entities']: 170 | entity.append([csk_entities[x] for x in ent] + ['_NONE'] * (triple_len-len(ent))) 171 | entities.append(entity+[['_NONE']*triple_len]*(triple_num-len(entity))) 172 | 173 | 174 | batched_data = {'posts': np.array(posts), 175 | 'responses': np.array(responses), 176 | 'posts_length': posts_length, 177 | 'responses_length': responses_length, 178 | 'triples': np.array(all_triples), 179 | 'entities': np.array(entities), 180 | 'posts_triple': np.array(post_triples), 181 | 'responses_triple': np.array(response_triples), 182 | 'match_triples': np.array(match_triples)} 183 | 184 | return batched_data 185 | 186 | def train(model, sess, data_train): 187 | batched_data = gen_batched_data(data_train) 188 | outputs = model.step_decoder(sess, batched_data) 189 | return np.sum(outputs[0]) 190 | 191 | def generate_summary(model, sess, data_train): 192 | 
selected_data = [random.choice(data_train) for i in range(FLAGS.batch_size)] 193 | batched_data = gen_batched_data(selected_data) 194 | summary = model.step_decoder(sess, batched_data, forward_only=True, summary=True)[-1] 195 | return summary 196 | 197 | 198 | def evaluate(model, sess, data_dev, summary_writer): 199 | loss = np.zeros((1, )) 200 | st, ed, times = 0, FLAGS.batch_size, 0 201 | while st < len(data_dev): 202 | selected_data = data_dev[st:ed] 203 | batched_data = gen_batched_data(selected_data) 204 | outputs = model.step_decoder(sess, batched_data, forward_only=True) 205 | loss += np.sum(outputs[0]) 206 | st, ed = ed, ed+FLAGS.batch_size 207 | times += 1 208 | loss /= len(data_dev) 209 | summary = tf.Summary() 210 | summary.value.add(tag='decoder_loss/dev', simple_value=loss) 211 | summary.value.add(tag='perplexity/dev', simple_value=np.exp(loss)) 212 | summary_writer.add_summary(summary, model.global_step.eval()) 213 | print(' perplexity on dev set: %.2f' % np.exp(loss)) 214 | 215 | def get_steps(train_dir): 216 | a = os.walk(train_dir) 217 | for root, dirs, files in a: 218 | if root == train_dir: 219 | filenames = files 220 | 221 | steps, metafiles, datafiles, indexfiles = [], [], [], [] 222 | for filename in filenames: 223 | if 'meta' in filename: 224 | metafiles.append(filename) 225 | if 'data' in filename: 226 | datafiles.append(filename) 227 | if 'index' in filename: 228 | indexfiles.append(filename) 229 | 230 | metafiles.sort() 231 | datafiles.sort() 232 | indexfiles.sort(reverse=True) 233 | 234 | for f in indexfiles: 235 | steps.append(int(f[11:-6])) 236 | 237 | return steps 238 | 239 | def test(sess, saver, data_dev, setnum=5000): 240 | with open('%s/stopwords' % FLAGS.data_dir) as f: 241 | stopwords = json.loads(f.readline()) 242 | steps = get_steps(FLAGS.train_dir) 243 | low_step = 00000 244 | high_step = 800000 245 | with open('%s.res' % FLAGS.inference_path, 'w') as resfile, open('%s.log' % FLAGS.inference_path, 'w') as outfile: 246 | for step in [step for step in steps if step > low_step and step < high_step]: 247 | outfile.write('test for model-%d\n' % step) 248 | model_path = '%s/checkpoint-%08d' % (FLAGS.train_dir, step) 249 | print('restore from %s' % model_path) 250 | try: 251 | saver.restore(sess, model_path) 252 | except: 253 | continue 254 | st, ed = 0, FLAGS.batch_size 255 | results = [] 256 | loss = [] 257 | while st < len(data_dev): 258 | selected_data = data_dev[st:ed] 259 | batched_data = gen_batched_data(selected_data) 260 | responses, ppx_loss = sess.run(['decoder_1/generation:0', 'decoder/ppx_loss:0'], {'enc_inps:0': batched_data['posts'], 'enc_lens:0': batched_data['posts_length'], 'dec_inps:0': batched_data['responses'], 'dec_lens:0': batched_data['responses_length'], 'entities:0': batched_data['entities'], 'triples:0': batched_data['triples'], 'match_triples:0': batched_data['match_triples'], 'enc_triples:0': batched_data['posts_triple'], 'dec_triples:0': batched_data['responses_triple']}) 261 | loss += [x for x in ppx_loss] 262 | for response in responses: 263 | result = [] 264 | for token in response: 265 | if token != '_EOS': 266 | result.append(token) 267 | else: 268 | break 269 | results.append(result) 270 | st, ed = ed, ed+FLAGS.batch_size 271 | match_entity_sum = [.0] * 4 272 | cnt = 0 273 | for post, response, result, match_triples, triples, entities in zip([data['post'] for data in data_dev], [data['response'] for data in data_dev], results, [data['match_triples'] for data in data_dev], [data['all_triples'] for data in data_dev], 
[data['all_entities'] for data in data_dev]):
274 |                 setidx = cnt // setnum
275 |                 result_matched_entities = []
276 |                 triples = [csk_triples[tri] for triple in triples for tri in triple]
277 |                 match_triples = [csk_triples[triple] for triple in match_triples]
278 |                 entities = [csk_entities[x] for entity in entities for x in entity]
279 |                 matches = [x for triple in match_triples for x in [triple.split(', ')[0], triple.split(', ')[2]] if x in response]
280 |
281 |                 for word in result:
282 |                     if word not in stopwords and word in entities:
283 |                         result_matched_entities.append(word)
284 |                 outfile.write('post: %s\nresponse: %s\nresult: %s\nmatch_entity: %s\n\n' % (' '.join(post), ' '.join(response), ' '.join(result), ' '.join(result_matched_entities)))
285 |                 match_entity_sum[setidx] += len(set(result_matched_entities))
286 |                 cnt += 1
287 |             match_entity_sum = [m / setnum for m in match_entity_sum] + [sum(match_entity_sum) / len(data_dev)]
288 |             losses = [np.sum(loss[x:x+setnum]) / float(setnum) for x in range(0, setnum*4, setnum)] + [np.sum(loss) / float(setnum*4)]
289 |             losses = [np.exp(x) for x in losses]
290 |             def show(x):
291 |                 return ', '.join([str(v) for v in x])
292 |             outfile.write('model: %d\n\tperplexity: %s\n\tmatch_entity_rate: %s\n%s\n\n' % (step, show(losses), show(match_entity_sum), '='*50))
293 |             resfile.write('model: %d\n\tperplexity: %s\n\tmatch_entity_rate: %s\n\n' % (step, show(losses), show(match_entity_sum)))
294 |             outfile.flush()
295 |             resfile.flush()
296 |     return results
297 |
298 | config = tf.ConfigProto()
299 | config.gpu_options.allow_growth = True
300 | with tf.Session(config=config) as sess:
301 |     if FLAGS.is_train:
302 |         raw_vocab, data_train, data_dev, data_test = prepare_data(FLAGS.data_dir)
303 |         vocab, embed, entity_vocab, entity_embed, relation_vocab, relation_embed, entity_relation_embed = build_vocab(FLAGS.data_dir, raw_vocab)
304 |         FLAGS.num_entities = len(entity_vocab)
305 |         print(FLAGS.__flags)
306 |         model = Model(
307 |                 FLAGS.symbols,
308 |                 FLAGS.embed_units,
309 |                 FLAGS.units,
310 |                 FLAGS.layers,
311 |                 embed,
312 |                 entity_relation_embed,
313 |                 num_entities=len(entity_vocab)+len(relation_vocab),
314 |                 num_trans_units=FLAGS.trans_units)
315 |         if tf.train.get_checkpoint_state(FLAGS.train_dir):
316 |             print("Reading model parameters from %s" % FLAGS.train_dir)
317 |             model.saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))
318 |         else:
319 |             print("Created model with fresh parameters.")
320 |             tf.global_variables_initializer().run()
321 |             op_in = model.symbol2index.insert(constant_op.constant(vocab),
322 |                     constant_op.constant(range(FLAGS.symbols), dtype=tf.int64))
323 |             sess.run(op_in)
324 |             op_out = model.index2symbol.insert(constant_op.constant(
325 |                     range(FLAGS.symbols), dtype=tf.int64), constant_op.constant(vocab))
326 |             sess.run(op_out)
327 |             op_in = model.entity2index.insert(constant_op.constant(entity_vocab+relation_vocab),
328 |                     constant_op.constant(range(len(entity_vocab)+len(relation_vocab)), dtype=tf.int64))
329 |             sess.run(op_in)
330 |             op_out = model.index2entity.insert(constant_op.constant(
331 |                     range(len(entity_vocab)+len(relation_vocab)), dtype=tf.int64), constant_op.constant(entity_vocab+relation_vocab))
332 |             sess.run(op_out)
333 |
334 |         if FLAGS.log_parameters:
335 |             model.print_parameters()
336 |
337 |         summary_writer = tf.summary.FileWriter('%s/log' % FLAGS.train_dir, sess.graph)
338 |         loss_step, time_step = np.zeros((1, )), .0
339 |         previous_losses = [1e18]*3
340 |         train_len = len(data_train)
341 |         while True:
342 |             st, ed = 0, 
FLAGS.batch_size * FLAGS.per_checkpoint 343 | random.shuffle(data_train) 344 | while st < train_len: 345 | start_time = time.time() 346 | for batch in range(st, ed, FLAGS.batch_size): 347 | loss_step += train(model, sess, data_train[batch:batch+FLAGS.batch_size]) / (ed - st) 348 | 349 | show = lambda a: '[%s]' % (' '.join(['%.2f' % x for x in a])) 350 | print("global step %d learning rate %.4f step-time %.2f loss %f perplexity %s" 351 | % (model.global_step.eval(), model.lr, 352 | (time.time() - start_time) / ((ed - st) / FLAGS.batch_size), loss_step, show(np.exp(loss_step)))) 353 | model.saver.save(sess, '%s/checkpoint' % FLAGS.train_dir, 354 | global_step=model.global_step) 355 | summary = tf.Summary() 356 | summary.value.add(tag='decoder_loss/train', simple_value=loss_step) 357 | summary.value.add(tag='perplexity/train', simple_value=np.exp(loss_step)) 358 | summary_writer.add_summary(summary, model.global_step.eval()) 359 | summary_model = generate_summary(model, sess, data_train) 360 | summary_writer.add_summary(summary_model, model.global_step.eval()) 361 | evaluate(model, sess, data_dev, summary_writer) 362 | previous_losses = previous_losses[1:]+[np.sum(loss_step)] 363 | loss_step, time_step = np.zeros((1, )), .0 364 | st, ed = ed, min(train_len, ed + FLAGS.batch_size * FLAGS.per_checkpoint) 365 | model.saver_epoch.save(sess, '%s/epoch/checkpoint' % FLAGS.train_dir, global_step=model.global_step) 366 | else: 367 | model = Model( 368 | FLAGS.symbols, 369 | FLAGS.embed_units, 370 | FLAGS.units, 371 | FLAGS.layers, 372 | embed=None, 373 | num_entities=FLAGS.num_entities+FLAGS.num_relations, 374 | num_trans_units=FLAGS.trans_units) 375 | 376 | if FLAGS.inference_version == 0: 377 | model_path = tf.train.latest_checkpoint(FLAGS.train_dir) 378 | else: 379 | model_path = '%s/checkpoint-%08d' % (FLAGS.train_dir, FLAGS.inference_version) 380 | print('restore from %s' % model_path) 381 | model.saver.restore(sess, model_path) 382 | saver = model.saver 383 | 384 | raw_vocab, data_train, data_dev, data_test = prepare_data(FLAGS.data_dir, is_train=False) 385 | 386 | test(sess, saver, data_test, setnum=5000) 387 | 388 | -------------------------------------------------------------------------------- /memnet/dynamic_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Seq2seq layer operations for use in neural networks. 
17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from tensorflow.contrib import layers 24 | from tensorflow.python.framework import ops 25 | from tensorflow.python.ops import array_ops 26 | from tensorflow.python.ops import control_flow_ops 27 | from tensorflow.python.ops import math_ops 28 | from tensorflow.python.ops import rnn 29 | from tensorflow.python.ops import tensor_array_ops 30 | from tensorflow.python.ops import variable_scope as vs 31 | 32 | __all__ = ["dynamic_rnn_decoder"] 33 | 34 | def dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None, 35 | parallel_iterations=None, swap_memory=False, 36 | time_major=False, scope=None, name=None): 37 | """ Dynamic RNN decoder for a sequence-to-sequence model specified by 38 | RNNCell and decoder function. 39 | 40 | The `dynamic_rnn_decoder` is similar to the `tf.python.ops.rnn.dynamic_rnn` 41 | as the decoder does not make any assumptions of sequence length and batch 42 | size of the input. 43 | 44 | The `dynamic_rnn_decoder` has two modes: training or inference and expects 45 | the user to create seperate functions for each. 46 | 47 | Under both training and inference, both `cell` and `decoder_fn` are expected, 48 | where `cell` performs computation at every timestep using `raw_rnn`, and 49 | `decoder_fn` allows modeling of early stopping, output, state, and next 50 | input and context. 51 | 52 | When training the user is expected to supply `inputs`. At every time step a 53 | slice of the supplied input is fed to the `decoder_fn`, which modifies and 54 | returns the input for the next time step. 55 | 56 | `sequence_length` is needed at training time, i.e., when `inputs` is not 57 | None, for dynamic unrolling. At test time, when `inputs` is None, 58 | `sequence_length` is not needed. 59 | 60 | Under inference `inputs` is expected to be `None` and the input is inferred 61 | solely from the `decoder_fn`. 62 | 63 | Args: 64 | cell: An instance of RNNCell. 65 | decoder_fn: A function that takes time, cell state, cell input, 66 | cell output and context state. It returns a early stopping vector, 67 | cell state, next input, cell output and context state. 68 | Examples of decoder_fn can be found in the decoder_fn.py folder. 69 | inputs: The inputs for decoding (embedded format). 70 | 71 | If `time_major == False` (default), this must be a `Tensor` of shape: 72 | `[batch_size, max_time, ...]`. 73 | 74 | If `time_major == True`, this must be a `Tensor` of shape: 75 | `[max_time, batch_size, ...]`. 76 | 77 | The input to `cell` at each time step will be a `Tensor` with dimensions 78 | `[batch_size, ...]`. 79 | 80 | sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 81 | if `inputs` is not None and `sequence_length` is None it is inferred 82 | from the `inputs` as the maximal possible sequence length. 83 | parallel_iterations: (Default: 32). The number of iterations to run in 84 | parallel. Those operations which do not have any temporal dependency 85 | and can be run in parallel, will be. This parameter trades off 86 | time for space. Values >> 1 use more memory but take less time, 87 | while smaller values use less memory but computations take longer. 88 | swap_memory: Transparently swap the tensors produced in forward inference 89 | but needed for back prop from GPU to CPU. This allows training RNNs 90 | which would typically not fit on a single GPU, with very minimal (or no) 91 | performance penalty. 
92 | time_major: The shape format of the `inputs` and `outputs` Tensors. 93 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 94 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 95 | Using `time_major = True` is a bit more efficient because it avoids 96 | transposes at the beginning and end of the RNN calculation. However, 97 | most TensorFlow data is batch-major, so by default this function 98 | accepts input and emits output in batch-major form. 99 | scope: VariableScope for the `raw_rnn`; 100 | defaults to None. 101 | name: NameScope for the decoder; 102 | defaults to "dynamic_rnn_decoder" 103 | 104 | Returns: 105 | A tuple (outputs, final_state, final_context_state) where: 106 | 107 | outputs: the RNN output `Tensor`. 108 | 109 | If time_major == False (default), this will be a `Tensor` shaped: 110 | `[batch_size, max_time, cell.output_size]`. 111 | 112 | If time_major == True, this will be a `Tensor` shaped: 113 | `[max_time, batch_size, cell.output_size]`. 114 | 115 | final_state: The final state, shaped 116 | `[batch_size, cell.state_size]`. 117 | 118 | final_context_state: The context state returned by the final call 119 | to decoder_fn. This is useful if the context state maintains internal 120 | data which is required after the graph is run. 121 | For example, one way to diversify the inference output is to use 122 | a stochastic decoder_fn, in which case one would want to store the 123 | decoded outputs, not just the RNN outputs. This can be done by 124 | maintaining a TensorArray in context_state and storing the decoded 125 | output of each iteration therein. 126 | 127 | Raises: 128 | ValueError: if inputs is not None and has fewer than two dimensions. 129 | """ 130 | with ops.name_scope(name, "dynamic_rnn_decoder", 131 | [cell, decoder_fn, inputs, sequence_length, 132 | parallel_iterations, swap_memory, time_major, scope]): 133 | if inputs is not None: 134 | # Convert to tensor 135 | inputs = ops.convert_to_tensor(inputs) 136 | 137 | # Test input dimensions 138 | if inputs.get_shape().ndims is not None and ( 139 | inputs.get_shape().ndims < 2): 140 | raise ValueError("Inputs must have at least two dimensions") 141 | # Setup of RNN (dimensions, sizes, length, initial state, dtype) 142 | if not time_major: 143 | # [batch, seq, features] -> [seq, batch, features] 144 | inputs = array_ops.transpose(inputs, perm=[1, 0, 2]) 145 | 146 | dtype = inputs.dtype 147 | # Get data input information 148 | input_depth = int(inputs.get_shape()[2]) 149 | batch_depth = inputs.get_shape()[1].value 150 | max_time = inputs.get_shape()[0].value 151 | if max_time is None: 152 | max_time = array_ops.shape(inputs)[0] 153 | # Setup decoder inputs as TensorArray 154 | inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time) 155 | inputs_ta = inputs_ta.unstack(inputs) 156 | 157 | def loop_fn(time, cell_output, cell_state, loop_state): 158 | if cell_state is None: # first call, before while loop (in raw_rnn) 159 | if cell_output is not None: 160 | raise ValueError("Expected cell_output to be None when cell_state " 161 | "is None, but saw: %s" % cell_output) 162 | if loop_state is not None: 163 | raise ValueError("Expected loop_state to be None when cell_state " 164 | "is None, but saw: %s" % loop_state) 165 | context_state = None 166 | else: # subsequent calls, inside while loop, after cell execution 167 | if isinstance(loop_state, tuple): 168 | (done, context_state) = loop_state 169 | else: 170 | done = loop_state 171 | context_state =
None 172 | 173 | # call decoder function 174 | if inputs is not None: # training 175 | # get next_cell_input 176 | if cell_state is None: 177 | next_cell_input = inputs_ta.read(0) 178 | else: 179 | if batch_depth is not None: 180 | batch_size = batch_depth 181 | else: 182 | batch_size = array_ops.shape(done)[0] 183 | next_cell_input = control_flow_ops.cond( 184 | math_ops.equal(time, max_time), 185 | lambda: array_ops.zeros([batch_size, input_depth], dtype=dtype), 186 | lambda: inputs_ta.read(time)) 187 | (next_done, next_cell_state, next_cell_input, emit_output, 188 | next_context_state) = decoder_fn(time, cell_state, next_cell_input, 189 | cell_output, context_state) 190 | else: # inference 191 | # next_cell_input is obtained through decoder_fn 192 | (next_done, next_cell_state, next_cell_input, emit_output, 193 | next_context_state) = decoder_fn(time, cell_state, None, cell_output, 194 | context_state) 195 | 196 | # check if we are done 197 | if next_done is None: # training 198 | next_done = time >= sequence_length 199 | 200 | # build next_loop_state 201 | if next_context_state is None: 202 | next_loop_state = next_done 203 | else: 204 | next_loop_state = (next_done, next_context_state) 205 | 206 | return (next_done, next_cell_input, next_cell_state, 207 | emit_output, next_loop_state) 208 | 209 | # Run raw_rnn function 210 | outputs_ta, final_state, final_loop_state = rnn.raw_rnn( 211 | cell, loop_fn, parallel_iterations=parallel_iterations, 212 | swap_memory=swap_memory, scope=scope) 213 | outputs = outputs_ta.stack() 214 | 215 | # Get final context_state, if generated by user 216 | if isinstance(final_loop_state, tuple): 217 | final_context_state = final_loop_state[1] 218 | else: 219 | final_context_state = None 220 | 221 | if not time_major: 222 | # [seq, batch, features] -> [batch, seq, features] 223 | outputs = array_ops.transpose(outputs, perm=[1, 0, 2]) 224 | return outputs, final_state, final_context_state 225 | -------------------------------------------------------------------------------- /memnet/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.framework import constant_op 4 | import sys 5 | import json 6 | import math 7 | import os 8 | import time 9 | import random 10 | import sqlite3 11 | random.seed(time.time()) 12 | from model import Model, _START_VOCAB 13 | 14 | tf.app.flags.DEFINE_boolean("is_train", True, "Set to False to run inference.") 15 | tf.app.flags.DEFINE_integer("symbols", 30000, "vocabulary size.") 16 | tf.app.flags.DEFINE_integer("num_entities", 21471, "entity vocabulary size.") 17 | tf.app.flags.DEFINE_integer("num_relations", 44, "relation size.") 18 | tf.app.flags.DEFINE_integer("embed_units", 300, "Size of word embedding.") 19 | tf.app.flags.DEFINE_integer("trans_units", 100, "Size of trans embedding.") 20 | tf.app.flags.DEFINE_integer("units", 512, "Size of each model layer.") 21 | tf.app.flags.DEFINE_integer("layers", 2, "Number of layers in the model.") 22 | tf.app.flags.DEFINE_boolean("copy_use", False, "use copy mechanism or not.") 23 | tf.app.flags.DEFINE_integer("batch_size", 100, "Batch size to use during training.") 24 | tf.app.flags.DEFINE_string("data_dir", "./data", "Data directory") 25 | tf.app.flags.DEFINE_string("train_dir", "./train", "Training directory.") 26 | tf.app.flags.DEFINE_integer("per_checkpoint", 1000, "How many steps to do per checkpoint.") 27 | tf.app.flags.DEFINE_integer("inference_version", 0, "The version
for inferencing.") 28 | tf.app.flags.DEFINE_boolean("log_parameters", True, "Set to True to show the parameters") 29 | tf.app.flags.DEFINE_string("inference_path", "test", "Set filename of inference, default isscreen") 30 | 31 | FLAGS = tf.app.flags.FLAGS 32 | if FLAGS.train_dir[-1] == '/': FLAGS.train_dir = FLAGS.train_dir[:-1] 33 | csk_triples, csk_entities, kb_dict = [], [], [] 34 | 35 | def prepare_data(path, is_train=True): 36 | global csk_entities, csk_triples, kb_dict 37 | 38 | with open('%s/resource.txt' % path) as f: 39 | d = json.loads(f.readline()) 40 | 41 | csk_triples = d['csk_triples'] 42 | csk_entities = d['csk_entities'] 43 | raw_vocab = d['vocab_dict'] 44 | kb_dict = d['dict_csk'] 45 | 46 | data_train, data_dev, data_test = [], [], [] 47 | 48 | if is_train: 49 | with open('%s/trainset.txt' % path) as f: 50 | for idx, line in enumerate(f): 51 | #if idx == 100000: break 52 | if idx % 100000 == 0: print('read train file line %d' % idx) 53 | data_train.append(json.loads(line)) 54 | 55 | with open('%s/validset.txt' % path) as f: 56 | for line in f: 57 | data_dev.append(json.loads(line)) 58 | 59 | with open('%s/testset.txt' % path) as f: 60 | for line in f: 61 | data_test.append(json.loads(line)) 62 | 63 | return raw_vocab, data_train, data_dev, data_test 64 | 65 | def build_vocab(path, raw_vocab, trans='transE'): 66 | print("Creating word vocabulary...") 67 | vocab_list = _START_VOCAB + sorted(raw_vocab, key=raw_vocab.get, reverse=True) 68 | if len(vocab_list) > FLAGS.symbols: 69 | vocab_list = vocab_list[:FLAGS.symbols] 70 | 71 | print("Creating entity vocabulary...") 72 | entity_list = ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T'] 73 | with open('%s/entity.txt' % path) as f: 74 | for i, line in enumerate(f): 75 | e = line.strip() 76 | entity_list.append(e) 77 | 78 | print("Creating relation vocabulary...") 79 | relation_list = [] 80 | with open('%s/relation.txt' % path) as f: 81 | for i, line in enumerate(f): 82 | r = line.strip() 83 | relation_list.append(r) 84 | 85 | print("Loading word vectors...") 86 | vectors = {} 87 | with open('%s/glove.840B.300d.txt' % path) as f: 88 | for i, line in enumerate(f): 89 | if i % 100000 == 0: 90 | print(" processing line %d" % i) 91 | s = line.strip() 92 | word = s[:s.find(' ')] 93 | vector = s[s.find(' ')+1:] 94 | vectors[word] = vector 95 | 96 | embed = [] 97 | for word in vocab_list: 98 | if word in vectors: 99 | vector = map(float, vectors[word].split()) 100 | else: 101 | vector = np.zeros((FLAGS.embed_units), dtype=np.float32) 102 | embed.append(vector) 103 | embed = np.array(embed, dtype=np.float32) 104 | 105 | print("Loading entity vectors...") 106 | entity_embed = [] 107 | with open('%s/entity_%s.txt' % (path, trans)) as f: 108 | for i, line in enumerate(f): 109 | s = line.strip().split('\t') 110 | entity_embed.append(map(float, s)) 111 | 112 | print("Loading relation vectors...") 113 | relation_embed = [] 114 | with open('%s/relation_%s.txt' % (path, trans)) as f: 115 | for i, line in enumerate(f): 116 | s = line.strip().split('\t') 117 | relation_embed.append(s) 118 | 119 | entity_relation_embed = np.array(entity_embed+relation_embed, dtype=np.float32) 120 | entity_embed = np.array(entity_embed, dtype=np.float32) 121 | relation_embed = np.array(relation_embed, dtype=np.float32) 122 | 123 | return vocab_list, embed, entity_list, entity_embed, relation_list, relation_embed, entity_relation_embed 124 | 125 | def gen_batched_data(data): 126 | global csk_entities, csk_triples, kb_dict 127 | encoder_len = 
max([len(item['post']) for item in data])+1 128 | decoder_len = max([len(item['response']) for item in data])+1 129 | triple_len = max([sum([len(tri) for tri in item['all_triples']]) for item in data ])+1 130 | max_length = 20 131 | posts, responses, posts_length, responses_length = [], [], [], [] 132 | entities, triples, matches, post_triples, response_triples = [], [], [], [], [] 133 | match_entities, all_entities = [], [] 134 | match_triples, all_triples = [], [] 135 | NAF = ['_NAF_H', '_NAF_R', '_NAF_T'] 136 | PAD = ['_PAD_H', '_PAD_R', '_PAD_T'] 137 | 138 | def padding(sent, l): 139 | return sent + ['_EOS'] + ['_PAD'] * (l-len(sent)-1) 140 | 141 | def padding_triple(triple, l): 142 | return [NAF] + triple + [PAD] * (l - len(triple) - 1) 143 | 144 | for item in data: 145 | posts.append(padding(item['post'], encoder_len)) 146 | responses.append(padding(item['response'], decoder_len)) 147 | posts_length.append(len(item['post'])+1) 148 | responses_length.append(len(item['response'])+1) 149 | all_triples.append(padding_triple([csk_triples[x].split(', ') for triple in item['all_triples'] for x in triple], triple_len)) 150 | match_index = [] 151 | for x in item['match_index']: 152 | _index = [-1] * triple_len 153 | if x[0] == -1 and x[1] == -1: 154 | match_index.append(-1) 155 | else: 156 | match_index.append(sum([len(m) for m in item['all_triples'][:(x[0]-1)]]) + 1 + x[1]) 157 | match_triples.append(match_index + [-1]*(decoder_len-len(match_index))) 158 | 159 | if not FLAGS.is_train: 160 | entity = ['_NONE'] 161 | entity += [csk_entities[x] for ent in item['all_entities'] for x in ent] 162 | entities.append(entity+['_NONE']*(triple_len-len(entity))) 163 | 164 | 165 | batched_data = {'posts': np.array(posts), 166 | 'responses': np.array(responses), 167 | 'posts_length': posts_length, 168 | 'responses_length': responses_length, 169 | 'triples': np.array(all_triples), 170 | 'entities': np.array(entities), 171 | 'match_triples': np.array(match_triples)} 172 | return batched_data 173 | 174 | def train(model, sess, data_train): 175 | batched_data = gen_batched_data(data_train) 176 | outputs = model.step_decoder(sess, batched_data, kb_use=True) 177 | return np.sum(outputs[0]) 178 | 179 | def generate_summary(model, sess, data_train): 180 | selected_data = [random.choice(data_train) for i in range(FLAGS.batch_size)] 181 | batched_data = gen_batched_data(selected_data) 182 | summary = model.step_decoder(sess, batched_data, kb_use=True, forward_only=True, summary=True)[-1] 183 | return summary 184 | 185 | 186 | def evaluate(model, sess, data_dev, summary_writer): 187 | loss = np.zeros((1, )) 188 | st, ed, times = 0, FLAGS.batch_size, 0 189 | while st < len(data_dev): 190 | selected_data = data_dev[st:ed] 191 | batched_data = gen_batched_data(selected_data) 192 | outputs = model.step_decoder(sess, batched_data, kb_use=True, forward_only=True) 193 | loss += np.sum(outputs[0]) 194 | st, ed = ed, ed+FLAGS.batch_size 195 | times += 1 196 | loss /= len(data_dev) 197 | summary = tf.Summary() 198 | summary.value.add(tag='decoder_loss/dev', simple_value=loss) 199 | summary.value.add(tag='perplexity/dev', simple_value=np.exp(loss)) 200 | summary_writer.add_summary(summary, model.global_step.eval()) 201 | print(' perplexity on dev set: %.2f' % np.exp(loss)) 202 | 203 | 204 | def get_steps(train_dir): 205 | a = os.walk(train_dir) 206 | for root, dirs, files in a: 207 | if root == train_dir: 208 | filenames = files 209 | 210 | steps, metafiles, datafiles, indexfiles = [], [], [], [] 211 | for filename in 
filenames: 212 | if 'meta' in filename: 213 | metafiles.append(filename) 214 | if 'data' in filename: 215 | datafiles.append(filename) 216 | if 'index' in filename: 217 | indexfiles.append(filename) 218 | 219 | metafiles.sort() 220 | datafiles.sort() 221 | indexfiles.sort(reverse=True) 222 | 223 | for f in indexfiles: 224 | steps.append(int(f[11:-6])) 225 | 226 | return steps 227 | 228 | def test(sess, saver, data_dev, setnum=5000): 229 | with open('%s/stopwords' % FLAGS.data_dir) as f: 230 | stopwords = json.loads(f.readline()) 231 | steps = get_steps(FLAGS.train_dir) 232 | low_step = 0 233 | high_step = 800000 234 | with open('%s.res' % FLAGS.inference_path, 'w') as resfile, open('%s.log' % FLAGS.inference_path, 'w') as outfile: 235 | for step in [step for step in steps if step > low_step and step < high_step]: 236 | outfile.write('test for model-%d\n' % step) 237 | model_path = '%s/checkpoint-%08d' % (FLAGS.train_dir, step) 238 | print('restore from %s' % model_path) 239 | try: 240 | saver.restore(sess, model_path) 241 | except Exception: # skip checkpoints that cannot be restored 242 | continue 243 | st, ed = 0, FLAGS.batch_size 244 | results = [] 245 | loss = [] 246 | while st < len(data_dev): 247 | selected_data = data_dev[st:ed] 248 | batched_data = gen_batched_data(selected_data) 249 | responses, ppx_loss = sess.run(['decoder_1/generation:0', 'decoder/ppx_loss:0'], {'enc_inps:0': batched_data['posts'], 'enc_lens:0': batched_data['posts_length'], 'dec_inps:0': batched_data['responses'], 'dec_lens:0': batched_data['responses_length'], 'entities:0': batched_data['entities'], 'triples:0': batched_data['triples'], 'match_triples:0': batched_data['match_triples']}) 250 | loss += [x for x in ppx_loss] 251 | for response in responses: 252 | result = [] 253 | for token in response: 254 | if token != '_EOS': 255 | result.append(token) 256 | else: 257 | break 258 | results.append(result) 259 | st, ed = ed, ed+FLAGS.batch_size 260 | match_entity_sum = [.0] * 4 261 | cnt = 0 262 | for post, response, result, match_triples, triples, entities in zip([data['post'] for data in data_dev], [data['response'] for data in data_dev], results, [data['match_triples'] for data in data_dev], [data['all_triples'] for data in data_dev], [data['all_entities'] for data in data_dev]): 263 | setidx = cnt // setnum  # integer division: index of the current evaluation subset 264 | result_matched_entities = [] 265 | triples = [csk_triples[tri] for triple in triples for tri in triple] 266 | match_triples = [csk_triples[triple] for triple in match_triples] 267 | entities = [csk_entities[x] for entity in entities for x in entity] 268 | matches = [x for triple in match_triples for x in [triple.split(', ')[0], triple.split(', ')[2]] if x in response] 269 | 270 | for word in result: 271 | if word not in stopwords and word in entities: 272 | result_matched_entities.append(word) 273 | outfile.write('post: %s\nresponse: %s\nresult: %s\nmatch_entity: %s\n\n' % (' '.join(post), ' '.join(response), ' '.join(result), ' '.join(result_matched_entities))) 274 | match_entity_sum[setidx] += len(set(result_matched_entities)) 275 | cnt += 1 276 | match_entity_sum = [m / setnum for m in match_entity_sum] + [sum(match_entity_sum) / len(data_dev)] 277 | losses = [np.sum(loss[x:x+setnum]) / float(setnum) for x in range(0, setnum*4, setnum)] + [np.sum(loss) / float(setnum*4)] 278 | losses = [np.exp(x) for x in losses] 279 | def show(x): 280 | return ', '.join([str(v) for v in x]) 281 | outfile.write('model: %d\n\tperplexity: %s\n\tmatch_entity_rate: %s\n%s\n\n' % (step, show(losses), show(match_entity_sum), '='*50)) 282 | resfile.write('model:
%d\n\tperplexity: %s\n\tmatch_entity_rate: %s\n\n' % (step, show(losses), show(match_entity_sum))) 283 | outfile.flush() 284 | resfile.flush() 285 | return results 286 | 287 | config = tf.ConfigProto() 288 | config.gpu_options.allow_growth = True 289 | with tf.Session(config=config) as sess: 290 | if FLAGS.is_train: 291 | raw_vocab, data_train, data_dev, data_test = prepare_data(FLAGS.data_dir) 292 | vocab, embed, entity_vocab, entity_embed, relation_vocab, relation_embed, entity_relation_embed = build_vocab(FLAGS.data_dir, raw_vocab) 293 | FLAGS.num_entities = len(entity_vocab) 294 | print(FLAGS.__flags) 295 | model = Model( 296 | FLAGS.symbols, 297 | FLAGS.embed_units, 298 | FLAGS.units, 299 | FLAGS.layers, 300 | embed, 301 | entity_relation_embed, 302 | num_entities=len(entity_vocab)+len(relation_vocab), 303 | num_trans_units=FLAGS.trans_units, 304 | output_alignments=FLAGS.copy_use) 305 | if tf.train.get_checkpoint_state(FLAGS.train_dir): 306 | print("Reading model parameters from %s" % FLAGS.train_dir) 307 | model.saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir)) 308 | else: 309 | print("Created model with fresh parameters.") 310 | tf.global_variables_initializer().run() 311 | op_in = model.symbol2index.insert(constant_op.constant(vocab), 312 | constant_op.constant(range(FLAGS.symbols), dtype=tf.int64)) 313 | sess.run(op_in) 314 | op_out = model.index2symbol.insert(constant_op.constant( 315 | range(FLAGS.symbols), dtype=tf.int64), constant_op.constant(vocab)) 316 | sess.run(op_out) 317 | op_in = model.entity2index.insert(constant_op.constant(entity_vocab+relation_vocab), 318 | constant_op.constant(range(len(entity_vocab)+len(relation_vocab)), dtype=tf.int64)) 319 | sess.run(op_in) 320 | op_out = model.index2entity.insert(constant_op.constant( 321 | range(len(entity_vocab)+len(relation_vocab)), dtype=tf.int64), constant_op.constant(entity_vocab+relation_vocab)) 322 | sess.run(op_out) 323 | 324 | if FLAGS.log_parameters: 325 | model.print_parameters() 326 | 327 | summary_writer = tf.summary.FileWriter('%s/log' % FLAGS.train_dir, sess.graph) 328 | loss_step, time_step = np.zeros((1, )), .0 329 | previous_losses = [1e18]*3 330 | train_len = len(data_train) 331 | while True: 332 | st, ed = 0, FLAGS.batch_size * FLAGS.per_checkpoint 333 | random.shuffle(data_train) 334 | while st < train_len: 335 | start_time = time.time() 336 | for batch in range(st, ed, FLAGS.batch_size): 337 | loss_step += train(model, sess, data_train[batch:batch+FLAGS.batch_size]) / (ed - st) 338 | 339 | show = lambda a: '[%s]' % (' '.join(['%.2f' % x for x in a])) 340 | print("global step %d learning rate %.4f step-time %.2f loss %f perplexity %s" 341 | % (model.global_step.eval(), model.lr, 342 | (time.time() - start_time) / ((ed - st) / FLAGS.batch_size), loss_step, show(np.exp(loss_step)))) 343 | model.saver.save(sess, '%s/checkpoint' % FLAGS.train_dir, 344 | global_step=model.global_step) 345 | summary = tf.Summary() 346 | summary.value.add(tag='decoder_loss/train', simple_value=loss_step) 347 | summary.value.add(tag='perplexity/train', simple_value=np.exp(loss_step)) 348 | summary_writer.add_summary(summary, model.global_step.eval()) 349 | summary_model = generate_summary(model, sess, data_train) 350 | summary_writer.add_summary(summary_model, model.global_step.eval()) 351 | evaluate(model, sess, data_dev, summary_writer) 352 | previous_losses = previous_losses[1:]+[np.sum(loss_step)] 353 | loss_step, time_step = np.zeros((1, )), .0 354 | st, ed = ed, min(train_len, ed + FLAGS.batch_size *
FLAGS.per_checkpoint) 355 | model.saver_epoch.save(sess, '%s/epoch/checkpoint' % FLAGS.train_dir, global_step=model.global_step) 356 | else: 357 | model = Model( 358 | FLAGS.symbols, 359 | FLAGS.embed_units, 360 | FLAGS.units, 361 | FLAGS.layers, 362 | embed=None, 363 | num_entities=FLAGS.num_entities+FLAGS.num_relations, 364 | num_trans_units=FLAGS.trans_units, 365 | output_alignments=FLAGS.copy_use) 366 | 367 | if FLAGS.inference_version == 0: 368 | model_path = tf.train.latest_checkpoint(FLAGS.train_dir) 369 | else: 370 | model_path = '%s/checkpoint-%08d' % (FLAGS.train_dir, FLAGS.inference_version) 371 | print('restore from %s' % model_path) 372 | model.saver.restore(sess, model_path) 373 | saver = model.saver 374 | 375 | raw_vocab, data_train, data_dev, data_test = prepare_data(FLAGS.data_dir, is_train=False) 376 | 377 | test(sess, saver, data_test, setnum=5000) 378 | 379 | -------------------------------------------------------------------------------- /memnet/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tensorflow.python.ops.nn import dynamic_rnn 5 | from tensorflow.contrib.rnn import GRUCell, LSTMCell, MultiRNNCell 6 | from tensorflow.contrib.seq2seq.python.ops.loss import sequence_loss 7 | from tensorflow.contrib.lookup.lookup_ops import MutableHashTable 8 | from tensorflow.contrib.layers.python.layers import layers 9 | from dynamic_decoder import dynamic_rnn_decoder 10 | from output_projection import output_projection_layer 11 | from attention_decoder import * 12 | from tensorflow.contrib.session_bundle import exporter 13 | 14 | PAD_ID = 0 15 | UNK_ID = 1 16 | GO_ID = 2 17 | EOS_ID = 3 18 | NONE_ID = 0 19 | _START_VOCAB = ['_PAD', '_UNK', '_GO', '_EOS'] 20 | 21 | class Model(object): 22 | def __init__(self, 23 | num_symbols, 24 | num_embed_units, 25 | num_units, 26 | num_layers, 27 | embed, 28 | entity_embed=None, 29 | num_entities=0, 30 | num_trans_units=100, 31 | learning_rate=0.0001, 32 | learning_rate_decay_factor=0.95, 33 | max_gradient_norm=5.0, 34 | num_samples=512, 35 | max_length=60, 36 | output_alignments=True, 37 | use_lstm=False): 38 | 39 | self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps') # batch*len 40 | self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens') # batch 41 | self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps') # batch*len 42 | self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens') # batch 43 | self.entities = tf.placeholder(tf.string, (None, None), 'entities') # batch 44 | self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks') # batch 45 | self.triples = tf.placeholder(tf.string, (None, None, 3), 'triples') # batch 46 | self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples') # batch 47 | self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples') # batch 48 | self.match_triples = tf.placeholder(tf.int32, (None, None), 'match_triples') # batch 49 | encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts)) 50 | triple_num = tf.shape(self.triples)[1] 51 | 52 | #use_triples = tf.reduce_sum(tf.cast(tf.greater_equal(self.match_triples, 0), tf.float32), axis=-1) 53 | one_hot_triples = tf.one_hot(self.match_triples, triple_num) 54 | use_triples = tf.reduce_sum(one_hot_triples, axis=[2]) 55 | 56 | self.symbol2index = MutableHashTable( 57 | key_dtype=tf.string, 58 | value_dtype=tf.int64, 59 | default_value=UNK_ID, 60 | 
shared_name="in_table", 61 | name="in_table", 62 | checkpoint=True) 63 | self.index2symbol = MutableHashTable( 64 | key_dtype=tf.int64, 65 | value_dtype=tf.string, 66 | default_value='_UNK', 67 | shared_name="out_table", 68 | name="out_table", 69 | checkpoint=True) 70 | self.entity2index = MutableHashTable( 71 | key_dtype=tf.string, 72 | value_dtype=tf.int64, 73 | default_value=NONE_ID, 74 | shared_name="entity_in_table", 75 | name="entity_in_table", 76 | checkpoint=True) 77 | self.index2entity = MutableHashTable( 78 | key_dtype=tf.int64, 79 | value_dtype=tf.string, 80 | default_value='_NONE', 81 | shared_name="entity_out_table", 82 | name="entity_out_table", 83 | checkpoint=True) 84 | # build the vocab table (string to index) 85 | 86 | 87 | self.posts_word_id = self.symbol2index.lookup(self.posts) # batch*len 88 | self.posts_entity_id = self.entity2index.lookup(self.posts) # batch*len 89 | #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6) 90 | self.responses_target = self.symbol2index.lookup(self.responses) #batch*len 91 | 92 | batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1] 93 | self.responses_word_id = tf.concat([tf.ones([batch_size, 1], dtype=tf.int64)*GO_ID, 94 | tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1) # batch*len 95 | self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1, 96 | decoder_len), reverse=True, axis=1), [-1, decoder_len]) 97 | 98 | # build the embedding table (index to vector) 99 | if embed is None: 100 | # initialize the embedding randomly 101 | self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32) 102 | else: 103 | # initialize the embedding by pre-trained word vectors 104 | self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed) 105 | if entity_embed is None: 106 | # initialize the embedding randomly 107 | self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units], tf.float32, trainable=False) 108 | else: 109 | # initialize the embedding by pre-trained word vectors 110 | self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32, initializer=entity_embed, trainable=False) 111 | 112 | self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units, activation=tf.tanh, name='trans_transformation') 113 | padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units], dtype=tf.float32, initializer=tf.zeros_initializer()) 114 | 115 | self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0) 116 | 117 | triples_embedding = tf.reshape(tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)), [encoder_batch_size, triple_num, 3 * num_trans_units]) 118 | entities_word_embedding = tf.reshape(tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)), [encoder_batch_size, -1, num_embed_units]) 119 | 120 | 121 | self.encoder_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id) #batch*len*unit 122 | self.decoder_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id) #batch*len*unit 123 | 124 | encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)]) 125 | decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)]) 126 | 127 | # rnn encoder 128 | encoder_output, encoder_state = dynamic_rnn(encoder_cell, self.encoder_input, 129 | self.posts_length, dtype=tf.float32, 
scope="encoder") 130 | 131 | # get output projection function 132 | output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(num_units, 133 | num_symbols, num_samples) 134 | 135 | 136 | 137 | with tf.variable_scope('decoder'): 138 | # get attention function 139 | attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \ 140 | = prepare_attention(encoder_output, 'bahdanau', num_units, imem=triples_embedding, output_alignments=output_alignments)#'luong', num_units) 141 | 142 | decoder_fn_train = attention_decoder_fn_train( 143 | encoder_state, attention_keys_init, attention_values_init, 144 | attention_score_fn_init, attention_construct_fn_init, output_alignments=output_alignments, max_length=tf.reduce_max(self.responses_length)) 145 | self.decoder_output, _, alignments_ta = dynamic_rnn_decoder(decoder_cell, decoder_fn_train, 146 | self.decoder_input, self.responses_length, scope="decoder_rnn") 147 | if output_alignments: 148 | self.alignments = tf.transpose(alignments_ta.stack(), perm=[1,0,2]) 149 | #self.alignments = tf.Print(self.alignments, [self.alignments], summarize=1e8) 150 | self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(self.decoder_output, self.responses_target, self.decoder_mask, self.alignments, triples_embedding, use_triples, one_hot_triples) 151 | self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss') 152 | #self.decoder_loss = tf.Print(self.decoder_loss, ['decoder_loss', self.decoder_loss], summarize=1e6) 153 | else: 154 | self.decoder_loss, self.sentence_ppx = sequence_loss(self.decoder_output, 155 | self.responses_target, self.decoder_mask) 156 | self.sentence_ppx = tf.identity(self.sentence_ppx, 'ppx_loss') 157 | 158 | with tf.variable_scope('decoder', reuse=True): 159 | # get attention function 160 | attention_keys, attention_values, attention_score_fn, attention_construct_fn \ 161 | = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True, imem=triples_embedding, output_alignments=output_alignments)#'luong', num_units) 162 | decoder_fn_inference = attention_decoder_fn_inference( 163 | output_fn, encoder_state, attention_keys, attention_values, 164 | attention_score_fn, attention_construct_fn, self.embed, GO_ID, 165 | EOS_ID, max_length, num_symbols, imem=entities_word_embedding, selector_fn=selector_fn) 166 | 167 | 168 | self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(decoder_cell, 169 | decoder_fn_inference, scope="decoder_rnn") 170 | if output_alignments: 171 | output_len = tf.shape(self.decoder_distribution)[1] 172 | output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len))) 173 | word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64) 174 | entity_ids = tf.reshape(tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]), [-1]) 175 | entities = tf.reshape(tf.gather(tf.reshape(self.entities, [-1]), entity_ids), [-1, output_len]) 176 | words = self.index2symbol.lookup(word_ids) 177 | self.generation = tf.where(output_ids > 0, words, entities, name='generation') 178 | else: 179 | self.generation_index = tf.argmax(self.decoder_distribution, 2) 180 | 181 | self.generation = self.index2symbol.lookup(self.generation_index, name='generation') 182 | 183 | 184 | # initialize the training process 185 | self.learning_rate = tf.Variable(float(learning_rate), 186 | trainable=False, dtype=tf.float32) 187 | self.learning_rate_decay_op = 
self.learning_rate.assign( 188 | self.learning_rate * learning_rate_decay_factor) 189 | self.global_step = tf.Variable(0, trainable=False) 190 | 191 | self.params = tf.global_variables() 192 | 193 | # calculate the gradient of parameters 194 | #opt = tf.train.GradientDescentOptimizer(self.learning_rate) 195 | opt = tf.train.AdamOptimizer(learning_rate=learning_rate) 196 | self.lr = opt._lr 197 | 198 | gradients = tf.gradients(self.decoder_loss, self.params) 199 | clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, 200 | max_gradient_norm) 201 | self.update = opt.apply_gradients(zip(clipped_gradients, self.params), 202 | global_step=self.global_step) 203 | 204 | tf.summary.scalar('decoder_loss', self.decoder_loss) 205 | for each in tf.trainable_variables(): 206 | tf.summary.histogram(each.name, each) 207 | 208 | self.merged_summary_op = tf.summary.merge_all() 209 | 210 | self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, 211 | max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0) 212 | 213 | self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000, pad_step_number=True) 214 | 215 | 216 | def print_parameters(self): 217 | for item in self.params: 218 | print('%s: %s' % (item.name, item.get_shape())) 219 | 220 | def step_decoder(self, session, data, forward_only=False, summary=False, kb_use=True): 221 | input_feed = {self.posts: data['posts'], 222 | self.posts_length: data['posts_length'], 223 | self.responses: data['responses'], 224 | self.responses_length: data['responses_length'], 225 | self.triples: data['triples'], 226 | self.match_triples: data['match_triples']} 227 | 228 | if forward_only: 229 | output_feed = [self.sentence_ppx] 230 | else: 231 | output_feed = [self.sentence_ppx, self.gradient_norm, self.update] 232 | if summary: 233 | output_feed.append(self.merged_summary_op) 234 | return session.run(output_feed, input_feed) 235 | -------------------------------------------------------------------------------- /memnet/output_projection.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import layers 3 | from tensorflow.python.ops import variable_scope 4 | 5 | def output_projection_layer(num_units, num_symbols, num_samples=None, name="output_projection"): 6 | def output_fn(outputs): 7 | return layers.linear(outputs, num_symbols, scope=name) 8 | 9 | def selector_fn(outputs): 10 | return tf.sigmoid(layers.linear(outputs, 1, scope='selector')) 11 | 12 | def sequence_loss(outputs, targets, masks): 13 | with variable_scope.variable_scope('decoder_rnn'): 14 | batch_size = tf.shape(outputs)[0] 15 | logits = layers.linear(outputs, num_symbols, scope=name) 16 | logits = tf.reshape(logits, [-1, num_symbols]) 17 | local_labels = tf.reshape(targets, [-1]) 18 | local_masks = tf.reshape(masks, [-1]) 19 | 20 | local_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=local_labels, logits=logits) 21 | local_loss = local_loss * local_masks 22 | ppx_loss = tf.reduce_sum(tf.reshape(local_loss, [batch_size, -1]), axis=1) / tf.reduce_sum(masks, axis=1) 23 | 24 | loss = tf.reduce_sum(local_loss) 25 | total_size = tf.reduce_sum(local_masks) 26 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 27 | 28 | return loss / total_size, ppx_loss 29 | 30 | def sampled_sequence_loss(outputs, targets, masks): 31 | with variable_scope.variable_scope('decoder_rnn/%s' % name): 32 | weights = 
tf.transpose(tf.get_variable("weights", [num_units, num_symbols])) 33 | bias = tf.get_variable("biases", [num_symbols]) 34 | 35 | local_labels = tf.reshape(targets, [-1, 1]) 36 | local_outputs = tf.reshape(outputs, [-1, num_units]) 37 | local_masks = tf.reshape(masks, [-1]) 38 | 39 | local_loss = tf.nn.sampled_softmax_loss(weights, bias, local_labels, 40 | local_outputs, num_samples, num_symbols) 41 | local_loss = local_loss * local_masks 42 | 43 | loss = tf.reduce_sum(local_loss) 44 | total_size = tf.reduce_sum(local_masks) 45 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 46 | 47 | return loss / total_size 48 | 49 | def total_loss(outputs, targets, masks, alignments, triples_embedding, use_entities, entity_targets): 50 | local_masks = tf.reshape(masks, [-1]) 51 | 52 | logits = layers.linear(outputs, num_symbols, scope='decoder_rnn/%s' % name) 53 | one_hot_targets = tf.one_hot(targets, num_symbols) 54 | word_prob = tf.reduce_sum(tf.nn.softmax(logits) * one_hot_targets, axis=2) 55 | selector = tf.squeeze(tf.sigmoid(layers.linear(outputs, 1, scope='decoder_rnn/selector'))) 56 | 57 | triple_prob = tf.reduce_sum(alignments * entity_targets, axis=[2]) 58 | cast_selector = tf.cast(tf.reduce_sum(alignments, axis=2) > tf.reduce_sum(tf.nn.softmax(logits), axis=2), tf.float32) 59 | final_prob = word_prob * (1 - selector) + triple_prob * selector 60 | ppx_prob = word_prob * (1 - use_entities) + triple_prob * use_entities 61 | final_loss = tf.reshape( - tf.log(1e-12 + final_prob), [-1]) * local_masks 62 | ppx_loss = tf.reshape( - tf.log(1e-12 + ppx_prob), [-1]) * local_masks 63 | sentence_ppx = tf.reduce_sum( - tf.log(1e-12 + ppx_prob) * masks, axis=1) 64 | 65 | loss = tf.reduce_sum(final_loss) 66 | #loss = tf.Print(loss, ['use_entity', tf.reduce_min(use_entities), tf.reduce_max(use_entities), 'triple_prob',tf.reduce_min(triple_prob), 'word_prob', tf.reduce_min(word_prob), 'final_prob', tf.reduce_min(final_prob), 'final_loss', tf.reduce_min(final_loss)], summarize=1e6) 67 | total_size = tf.reduce_sum(local_masks) 68 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 69 | 70 | return loss / total_size, tf.reduce_sum(ppx_loss) / total_size, sentence_ppx / tf.reduce_sum(masks, axis=1) 71 | 72 | 73 | 74 | return output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss 75 | 76 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from tensorflow.python.ops.nn import dynamic_rnn 5 | from tensorflow.contrib.rnn import GRUCell, LSTMCell, MultiRNNCell 6 | from tensorflow.contrib.seq2seq.python.ops.loss import sequence_loss 7 | from tensorflow.contrib.lookup.lookup_ops import MutableHashTable 8 | from tensorflow.contrib.layers.python.layers import layers 9 | from dynamic_decoder import dynamic_rnn_decoder 10 | from output_projection import output_projection_layer 11 | from attention_decoder import * 12 | from tensorflow.contrib.session_bundle import exporter 13 | 14 | PAD_ID = 0 15 | UNK_ID = 1 16 | GO_ID = 2 17 | EOS_ID = 3 18 | NONE_ID = 0 19 | _START_VOCAB = ['_PAD', '_UNK', '_GO', '_EOS'] 20 | 21 | class Model(object): 22 | def __init__(self, 23 | num_symbols, 24 | num_embed_units, 25 | num_units, 26 | num_layers, 27 | embed, 28 | entity_embed=None, 29 | num_entities=0, 30 | num_trans_units=100, 31 | learning_rate=0.0001, 32 | learning_rate_decay_factor=0.95, 33 | 
max_gradient_norm=5.0, 34 | num_samples=500, 35 | max_length=60, 36 | mem_use=True, 37 | output_alignments=True, 38 | use_lstm=False): 39 | 40 | self.posts = tf.placeholder(tf.string, (None, None), 'enc_inps') # batch*len 41 | self.posts_length = tf.placeholder(tf.int32, (None), 'enc_lens') # batch 42 | self.responses = tf.placeholder(tf.string, (None, None), 'dec_inps') # batch*len 43 | self.responses_length = tf.placeholder(tf.int32, (None), 'dec_lens') # batch 44 | self.entities = tf.placeholder(tf.string, (None, None, None), 'entities') # batch 45 | self.entity_masks = tf.placeholder(tf.string, (None, None), 'entity_masks') # batch 46 | self.triples = tf.placeholder(tf.string, (None, None, None, 3), 'triples') # batch 47 | self.posts_triple = tf.placeholder(tf.int32, (None, None, 1), 'enc_triples') # batch 48 | self.responses_triple = tf.placeholder(tf.string, (None, None, 3), 'dec_triples') # batch 49 | self.match_triples = tf.placeholder(tf.int32, (None, None, None), 'match_triples') # batch 50 | 51 | encoder_batch_size, encoder_len = tf.unstack(tf.shape(self.posts)) 52 | triple_num = tf.shape(self.triples)[1] 53 | triple_len = tf.shape(self.triples)[2] 54 | one_hot_triples = tf.one_hot(self.match_triples, triple_len) 55 | use_triples = tf.reduce_sum(one_hot_triples, axis=[2, 3]) 56 | 57 | self.symbol2index = MutableHashTable( 58 | key_dtype=tf.string, 59 | value_dtype=tf.int64, 60 | default_value=UNK_ID, 61 | shared_name="in_table", 62 | name="in_table", 63 | checkpoint=True) 64 | self.index2symbol = MutableHashTable( 65 | key_dtype=tf.int64, 66 | value_dtype=tf.string, 67 | default_value='_UNK', 68 | shared_name="out_table", 69 | name="out_table", 70 | checkpoint=True) 71 | self.entity2index = MutableHashTable( 72 | key_dtype=tf.string, 73 | value_dtype=tf.int64, 74 | default_value=NONE_ID, 75 | shared_name="entity_in_table", 76 | name="entity_in_table", 77 | checkpoint=True) 78 | self.index2entity = MutableHashTable( 79 | key_dtype=tf.int64, 80 | value_dtype=tf.string, 81 | default_value='_NONE', 82 | shared_name="entity_out_table", 83 | name="entity_out_table", 84 | checkpoint=True) 85 | # build the vocab table (string to index) 86 | 87 | 88 | self.posts_word_id = self.symbol2index.lookup(self.posts) # batch*len 89 | self.posts_entity_id = self.entity2index.lookup(self.posts) # batch*len 90 | #self.posts_word_id = tf.Print(self.posts_word_id, ['use_triples', use_triples, 'one_hot_triples', one_hot_triples], summarize=1e6) 91 | self.responses_target = self.symbol2index.lookup(self.responses) #batch*len 92 | 93 | batch_size, decoder_len = tf.shape(self.responses)[0], tf.shape(self.responses)[1] 94 | self.responses_word_id = tf.concat([tf.ones([batch_size, 1], dtype=tf.int64)*GO_ID, 95 | tf.split(self.responses_target, [decoder_len-1, 1], 1)[0]], 1) # batch*len 96 | self.decoder_mask = tf.reshape(tf.cumsum(tf.one_hot(self.responses_length-1, 97 | decoder_len), reverse=True, axis=1), [-1, decoder_len]) 98 | 99 | # build the embedding table (index to vector) 100 | if embed is None: 101 | # initialize the embedding randomly 102 | self.embed = tf.get_variable('word_embed', [num_symbols, num_embed_units], tf.float32) 103 | else: 104 | # initialize the embedding by pre-trained word vectors 105 | self.embed = tf.get_variable('word_embed', dtype=tf.float32, initializer=embed) 106 | if entity_embed is None: 107 | # initialize the embedding randomly 108 | self.entity_trans = tf.get_variable('entity_embed', [num_entities, num_trans_units], tf.float32, trainable=False) 109 | else: 110 | # 
initialize the embedding by pre-trained trans vectors 111 | self.entity_trans = tf.get_variable('entity_embed', dtype=tf.float32, initializer=entity_embed, trainable=False) 112 | 113 | self.entity_trans_transformed = tf.layers.dense(self.entity_trans, num_trans_units, activation=tf.tanh, name='trans_transformation') 114 | padding_entity = tf.get_variable('entity_padding_embed', [7, num_trans_units], dtype=tf.float32, initializer=tf.zeros_initializer()) 115 | 116 | self.entity_embed = tf.concat([padding_entity, self.entity_trans_transformed], axis=0) 117 | 118 | triples_embedding = tf.reshape(tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.triples)), [encoder_batch_size, triple_num, -1, 3 * num_trans_units]) 119 | entities_word_embedding = tf.reshape(tf.nn.embedding_lookup(self.embed, self.symbol2index.lookup(self.entities)), [encoder_batch_size, -1, num_embed_units]) 120 | 121 | head, relation, tail = tf.split(triples_embedding, [num_trans_units] * 3, axis=3) 122 | 123 | with tf.variable_scope('graph_attention'): 124 | head_tail = tf.concat([head, tail], axis=3) 125 | head_tail_transformed = tf.layers.dense(head_tail, num_trans_units, activation=tf.tanh, name='head_tail_transform') 126 | relation_transformed = tf.layers.dense(relation, num_trans_units, name='relation_transform') 127 | e_weight = tf.reduce_sum(relation_transformed * head_tail_transformed, axis=3) 128 | alpha_weight = tf.nn.softmax(e_weight) 129 | graph_embed = tf.reduce_sum(tf.expand_dims(alpha_weight, 3) * head_tail, axis=2) 130 | 131 | 132 | graph_embed_input = tf.gather_nd(graph_embed, tf.concat([tf.tile(tf.reshape(tf.range(encoder_batch_size, dtype=tf.int32), [-1, 1, 1]), [1, encoder_len, 1]), self.posts_triple], axis=2)) 133 | 134 | triple_embed_input = tf.reshape(tf.nn.embedding_lookup(self.entity_embed, self.entity2index.lookup(self.responses_triple)), [batch_size, decoder_len, 3 * num_trans_units]) 135 | 136 | post_word_input = tf.nn.embedding_lookup(self.embed, self.posts_word_id) #batch*len*unit 137 | response_word_input = tf.nn.embedding_lookup(self.embed, self.responses_word_id) #batch*len*unit 138 | 139 | self.encoder_input = tf.concat([post_word_input, graph_embed_input], axis=2) 140 | self.decoder_input = tf.concat([response_word_input, triple_embed_input], axis=2) 141 | 142 | encoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)]) 143 | decoder_cell = MultiRNNCell([GRUCell(num_units) for _ in range(num_layers)]) 144 | 145 | # rnn encoder 146 | encoder_output, encoder_state = dynamic_rnn(encoder_cell, self.encoder_input, 147 | self.posts_length, dtype=tf.float32, scope="encoder") 148 | 149 | # get output projection function 150 | output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss = output_projection_layer(num_units, 151 | num_symbols, num_samples) 152 | 153 | 154 | 155 | with tf.variable_scope('decoder'): 156 | # get attention function 157 | attention_keys_init, attention_values_init, attention_score_fn_init, attention_construct_fn_init \ 158 | = prepare_attention(encoder_output, 'bahdanau', num_units, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)#'luong', num_units) 159 | 160 | decoder_fn_train = attention_decoder_fn_train( 161 | encoder_state, attention_keys_init, attention_values_init, 162 | attention_score_fn_init, attention_construct_fn_init, output_alignments=output_alignments and mem_use, max_length=tf.reduce_max(self.responses_length)) 163 | self.decoder_output, _, alignments_ta = 
dynamic_rnn_decoder(decoder_cell, decoder_fn_train, 164 | self.decoder_input, self.responses_length, scope="decoder_rnn") 165 | if output_alignments: 166 | self.alignments = tf.transpose(alignments_ta.stack(), perm=[1,0,2,3]) 167 | self.decoder_loss, self.ppx_loss, self.sentence_ppx = total_loss(self.decoder_output, self.responses_target, self.decoder_mask, self.alignments, triples_embedding, use_triples, one_hot_triples) 168 | self.sentence_ppx = tf.identity(self.sentence_ppx, name='ppx_loss') 169 | else: 170 | self.decoder_loss = sequence_loss(self.decoder_output, 171 | self.responses_target, self.decoder_mask) 172 | 173 | with tf.variable_scope('decoder', reuse=True): 174 | # get attention function 175 | attention_keys, attention_values, attention_score_fn, attention_construct_fn \ 176 | = prepare_attention(encoder_output, 'bahdanau', num_units, reuse=True, imem=(graph_embed, triples_embedding), output_alignments=output_alignments and mem_use)#'luong', num_units) 177 | decoder_fn_inference = attention_decoder_fn_inference( 178 | output_fn, encoder_state, attention_keys, attention_values, 179 | attention_score_fn, attention_construct_fn, self.embed, GO_ID, 180 | EOS_ID, max_length, num_symbols, imem=(entities_word_embedding, tf.reshape(triples_embedding, [encoder_batch_size, -1, 3*num_trans_units])), selector_fn=selector_fn) 181 | 182 | 183 | self.decoder_distribution, _, output_ids_ta = dynamic_rnn_decoder(decoder_cell, 184 | decoder_fn_inference, scope="decoder_rnn") 185 | 186 | output_len = tf.shape(self.decoder_distribution)[1] 187 | output_ids = tf.transpose(output_ids_ta.gather(tf.range(output_len))) 188 | word_ids = tf.cast(tf.clip_by_value(output_ids, 0, num_symbols), tf.int64) 189 | entity_ids = tf.reshape(tf.clip_by_value(-output_ids, 0, num_symbols) + tf.reshape(tf.range(encoder_batch_size) * tf.shape(entities_word_embedding)[1], [-1, 1]), [-1]) 190 | entities = tf.reshape(tf.gather(tf.reshape(self.entities, [-1]), entity_ids), [-1, output_len]) 191 | words = self.index2symbol.lookup(word_ids) 192 | self.generation = tf.where(output_ids > 0, words, entities) 193 | self.generation = tf.identity(self.generation, name='generation') 194 | 195 | 196 | # initialize the training process 197 | self.learning_rate = tf.Variable(float(learning_rate), 198 | trainable=False, dtype=tf.float32) 199 | self.learning_rate_decay_op = self.learning_rate.assign( 200 | self.learning_rate * learning_rate_decay_factor) 201 | self.global_step = tf.Variable(0, trainable=False) 202 | 203 | self.params = tf.global_variables() 204 | 205 | opt = tf.train.AdamOptimizer(learning_rate=learning_rate) 206 | self.lr = opt._lr 207 | 208 | gradients = tf.gradients(self.decoder_loss, self.params) 209 | clipped_gradients, self.gradient_norm = tf.clip_by_global_norm(gradients, 210 | max_gradient_norm) 211 | self.update = opt.apply_gradients(zip(clipped_gradients, self.params), 212 | global_step=self.global_step) 213 | 214 | tf.summary.scalar('decoder_loss', self.decoder_loss) 215 | for each in tf.trainable_variables(): 216 | tf.summary.histogram(each.name, each) 217 | 218 | self.merged_summary_op = tf.summary.merge_all() 219 | 220 | self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, 221 | max_to_keep=3, pad_step_number=True, keep_checkpoint_every_n_hours=1.0) 222 | self.saver_epoch = tf.train.Saver(write_version=tf.train.SaverDef.V2, max_to_keep=1000, pad_step_number=True) 223 | 224 | 225 | def print_parameters(self): 226 | for item in self.params: 227 | print('%s: %s' % (item.name, item.get_shape())) 
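    # --- Editor's note (illustrative sketch, not part of the original file) ---
    # Minimal usage of step_decoder below, assuming `model` is an instance of
    # this class, `sess` is an active tf.Session with the lookup tables filled,
    # and `batched_data` is a dict from gen_batched_data() in main.py with keys
    # 'posts', 'posts_length', 'responses', 'responses_length', 'triples',
    # 'posts_triple', 'responses_triple' and 'match_triples':
    #
    #     sentence_ppx, grad_norm, _ = model.step_decoder(sess, batched_data)
    #     step_loss = np.sum(sentence_ppx)  # the reduction train() in main.py applies
    #     dev_ppx = model.step_decoder(sess, batched_data, forward_only=True)[0]
    #
    # With summary=True the merged summary protobuf is appended as the last
    # element of the returned list, which is what generate_summary() consumes.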
228 | 229 | def step_decoder(self, session, data, forward_only=False, summary=False): 230 | input_feed = {self.posts: data['posts'], 231 | self.posts_length: data['posts_length'], 232 | self.responses: data['responses'], 233 | self.responses_length: data['responses_length'], 234 | self.triples: data['triples'], 235 | self.posts_triple: data['posts_triple'], 236 | self.responses_triple: data['responses_triple'], 237 | self.match_triples: data['match_triples']} 238 | 239 | if forward_only: 240 | output_feed = [self.sentence_ppx] 241 | else: 242 | output_feed = [self.sentence_ppx, self.gradient_norm, self.update] 243 | if summary: 244 | output_feed.append(self.merged_summary_op) 245 | return session.run(output_feed, input_feed) 246 | -------------------------------------------------------------------------------- /output_projection.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import layers 3 | from tensorflow.python.ops import variable_scope 4 | 5 | def output_projection_layer(num_units, num_symbols, num_samples=None, name="output_projection"): 6 | def output_fn(outputs): 7 | return layers.linear(outputs, num_symbols, scope=name) 8 | 9 | def selector_fn(outputs): 10 | selector = tf.sigmoid(layers.linear(outputs, 1, scope='selector')) 11 | return selector 12 | 13 | def sequence_loss(outputs, targets, masks): 14 | with variable_scope.variable_scope('decoder_rnn'): 15 | logits = layers.linear(outputs, num_symbols, scope=name) 16 | logits = tf.reshape(logits, [-1, num_symbols]) 17 | local_labels = tf.reshape(targets, [-1]) 18 | local_masks = tf.reshape(masks, [-1]) 19 | 20 | local_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=local_labels, logits=logits) 21 | local_loss = local_loss * local_masks 22 | 23 | loss = tf.reduce_sum(local_loss) 24 | total_size = tf.reduce_sum(local_masks) 25 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 26 | 27 | return loss / total_size 28 | 29 | def sampled_sequence_loss(outputs, targets, masks): 30 | with variable_scope.variable_scope('decoder_rnn/%s' % name): 31 | weights = tf.transpose(tf.get_variable("weights", [num_units, num_symbols])) 32 | bias = tf.get_variable("biases", [num_symbols]) 33 | 34 | local_labels = tf.reshape(targets, [-1, 1]) 35 | local_outputs = tf.reshape(outputs, [-1, num_units]) 36 | local_masks = tf.reshape(masks, [-1]) 37 | 38 | local_loss = tf.nn.sampled_softmax_loss(weights, bias, local_labels, 39 | local_outputs, num_samples, num_symbols) 40 | local_loss = local_loss * local_masks 41 | 42 | loss = tf.reduce_sum(local_loss) 43 | total_size = tf.reduce_sum(local_masks) 44 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 45 | 46 | return loss / total_size 47 | 48 | def total_loss(outputs, targets, masks, alignments, triples_embedding, use_entities, entity_targets): 49 | batch_size = tf.shape(outputs)[0] 50 | local_masks = tf.reshape(masks, [-1]) 51 | 52 | logits = layers.linear(outputs, num_symbols, scope='decoder_rnn/%s' % name) 53 | one_hot_targets = tf.one_hot(targets, num_symbols) 54 | word_prob = tf.reduce_sum(tf.nn.softmax(logits) * one_hot_targets, axis=2) 55 | selector = tf.squeeze(tf.sigmoid(layers.linear(outputs, 1, scope='decoder_rnn/selector'))) 56 | 57 | triple_prob = tf.reduce_sum(alignments * entity_targets, axis=[2, 3]) 58 | ppx_prob = word_prob * (1 - use_entities) + triple_prob * use_entities 59 | final_prob = word_prob * (1 - selector) * (1 - use_entities) + 
triple_prob * selector * use_entities 60 | final_loss = tf.reduce_sum(tf.reshape( - tf.log(1e-12 + final_prob), [-1]) * local_masks) 61 | ppx_loss = tf.reduce_sum(tf.reshape( - tf.log(1e-12 + ppx_prob), [-1]) * local_masks) 62 | sentence_ppx = tf.reduce_sum(tf.reshape(tf.reshape( - tf.log(1e-12 + ppx_prob), [-1]) * local_masks, [batch_size, -1]), axis=1) 63 | selector_loss = tf.reduce_sum(tf.reshape( - tf.log(1e-12 + selector * use_entities + (1 - selector) * (1 - use_entities)), [-1]) * local_masks) 64 | 65 | loss = final_loss + selector_loss 66 | total_size = tf.reduce_sum(local_masks) 67 | total_size += 1e-12 # to avoid division by 0 for all-0 weights 68 | 69 | return loss / total_size, ppx_loss / total_size, sentence_ppx / tf.reduce_sum(masks, axis=1) 70 | 71 | 72 | 73 | return output_fn, selector_fn, sequence_loss, sampled_sequence_loss, total_loss 74 | 75 | --------------------------------------------------------------------------------
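Editor's note: the gated generation probability implemented by total_loss in /output_projection.py above can be sanity-checked outside the TensorFlow graph. The following is a minimal NumPy re-derivation, not part of the repository; all values are made up for illustration, and the array names simply mirror the tensors of the same names (word_prob, triple_prob, selector, use_entities, masks).

import numpy as np

# Toy batch: 2 sentences x 3 decoding steps (hypothetical values).
word_prob    = np.array([[0.6, 0.5, 0.9], [0.7, 0.4, 0.8]])   # prob. of the target word under the vocab softmax
triple_prob  = np.array([[0.2, 0.8, 0.1], [0.1, 0.9, 0.2]])   # attention mass on the gold triple (alignments * entity_targets)
selector     = np.array([[0.1, 0.9, 0.2], [0.2, 0.8, 0.1]])   # predicted entity-vs-word gate (sigmoid output)
use_entities = np.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]])   # gold gate: 1 where the target token is an entity
masks        = np.ones_like(word_prob)                        # decoder_mask; all ones here (no padding)

# Generation loss mixes the two channels with the *predicted* gate:
final_prob = word_prob * (1 - selector) * (1 - use_entities) + triple_prob * selector * use_entities
# Perplexity instead uses the *gold* gate, so it does not depend on the selector:
ppx_prob = word_prob * (1 - use_entities) + triple_prob * use_entities

local_masks = masks.reshape(-1)
final_loss = (-np.log(1e-12 + final_prob).reshape(-1) * local_masks).sum()
selector_loss = (-np.log(1e-12 + selector * use_entities + (1 - selector) * (1 - use_entities)).reshape(-1) * local_masks).sum()
loss = (final_loss + selector_loss) / (local_masks.sum() + 1e-12)

# Per-sentence average cross-entropy; evaluate() in main.py averages these over
# the dev set and applies np.exp to report perplexity.
sentence_ce = (-np.log(1e-12 + ppx_prob) * masks).sum(axis=1) / masks.sum(axis=1)
print(loss, np.exp(sentence_ce))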