├── README.md ├── dataset.py ├── main.ipynb ├── main.py └── pointer.py /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Pointer-Networks 2 | 3 | Tensorflow implementation of Pointer Networks 4 | 5 | ## Reference 6 | - Oriol Vinyals, Meire Fortunato, Navdeep Jaitly, "Pointer Networks" [arXiv:1506.03134](http://arxiv.org/abs/1506.03134) 7 | - Oriol Vinyals, Samy Bengio, Manjunath Kudlur, "Order Matters: Sequence to sequence for sets" [arXiv:1511.06391](https://arxiv.org/abs/1511.06391) 8 | 9 | ## Task 10 | - We implemented '4.4. Sorting Experiment' of "Order Matters: Sequence to sequence for sets" 11 | - N unordered random floating point numbers between 0 and 1 => sorted order 12 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import numpy as np 4 | 5 | 6 | class DataGenerator(object): 7 | 8 | def __init__(self): 9 | """Construct a DataGenerator.""" 10 | pass 11 | 12 | def next_batch(self, batch_size, N, train_mode=True): 13 | """Return the next `batch_size` examples from this data set.""" 14 | 15 | # A sequence of random numbers from [0, 1] 16 | reader_input_batch = [] 17 | 18 | # Sorted sequence that we feed to encoder 19 | # In inference we feed an unordered sequence again 20 | decoder_input_batch = [] 21 | 22 | # Ordered sequence where one hot vector encodes position in the input array 23 | writer_outputs_batch = [] 24 | for _ in range(N): 25 | reader_input_batch.append(np.zeros([batch_size, 1])) 26 | for _ in range(N + 1): 27 | decoder_input_batch.append(np.zeros([batch_size, 1])) 28 | writer_outputs_batch.append(np.zeros([batch_size, N + 1])) 29 | 30 | for b in range(batch_size): 31 | shuffle = np.random.permutation(N) 32 | sequence = np.sort(np.random.random(N)) 33 | shuffled_sequence = sequence[shuffle] 34 | 35 | for i in range(N): 36 | reader_input_batch[i][b] = shuffled_sequence[i] 37 | if train_mode: 38 | decoder_input_batch[i + 1][b] = sequence[i] 39 | else: 40 | decoder_input_batch[i + 1][b] = shuffled_sequence[i] 41 | writer_outputs_batch[shuffle[i]][b, i + 1] = 1.0 42 | 43 | # Points to the stop symbol 44 | writer_outputs_batch[N][b, 0] = 1.0 45 | 46 | return reader_input_batch, decoder_input_batch, writer_outputs_batch 47 | -------------------------------------------------------------------------------- /main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "\"\"\"Implementation of Pointer networks: http://arxiv.org/pdf/1506.03134v1.pdf.\n", 12 | "\"\"\"" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": { 19 | "collapsed": false 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "from __future__ import absolute_import, division, print_function\n", 24 | "\n", 25 | "import random\n", 26 | "\n", 27 | "import numpy as np\n", 28 | "import tensorflow as tf\n", 29 | "\n", 30 | "from dataset import DataGenerator\n", 31 | "from pointer import pointer_decoder" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "collapsed": false 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "flags = tf.app.flags\n", 43 | "FLAGS = flags.FLAGS\n", 44 | "flags.DEFINE_integer('batch_size', 32, 'Batch size. ')\n", 45 | "flags.DEFINE_integer('max_steps', 10, 'Number of numbers to sort. ')\n", 46 | "flags.DEFINE_integer('rnn_size', 32, 'RNN size. ')" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": { 53 | "collapsed": false 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "class PointerNetwork(object):\n", 58 | " \n", 59 | " def __init__(self, max_len, input_size, size, num_layers, max_gradient_norm, batch_size, learning_rate, learning_rate_decay_factor):\n", 60 | " \"\"\"Create the network. A simplified network that handles only sorting.\n", 61 | " \n", 62 | " Args:\n", 63 | " max_len: maximum length of the model.\n", 64 | " input_size: size of the inputs data.\n", 65 | " size: number of units in each layer of the model.\n", 66 | " num_layers: number of layers in the model.\n", 67 | " max_gradient_norm: gradients will be clipped to maximally this norm.\n", 68 | " batch_size: the size of the batches used during training;\n", 69 | " the model construction is independent of batch_size, so it can be\n", 70 | " changed after initialization if this is convenient, e.g., for decoding.\n", 71 | " learning_rate: learning rate to start with.\n", 72 | " learning_rate_decay_factor: decay learning rate by this much when needed.\n", 73 | " \"\"\"\n", 74 | " self.batch_size = batch_size\n", 75 | " self.learning_rate = tf.Variable(float(learning_rate), trainable=False)\n", 76 | " self.learning_rate_decay_op = self.learning_rate.assign(\n", 77 | " self.learning_rate * learning_rate_decay_factor)\n", 78 | " self.global_step = tf.Variable(0, trainable=False)\n", 79 | "\n", 80 | " \n", 81 | " cell = tf.nn.rnn_cell.GRUCell(size)\n", 82 | " if num_layers > 1:\n", 83 | " cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)\n", 84 | " \n", 85 | " self.encoder_inputs = []\n", 86 | " self.decoder_inputs = []\n", 87 | " self.decoder_targets = []\n", 88 | " self.target_weights = []\n", 89 | " for i in range(max_len):\n", 90 | " self.encoder_inputs.append(tf.placeholder(\n", 91 | " tf.float32, [batch_size, input_size], name=\"EncoderInput%d\" % i))\n", 92 | "\n", 93 | " for i in range(max_len + 1):\n", 94 | " self.decoder_inputs.append(tf.placeholder(\n", 95 | " tf.float32, [batch_size, input_size], name=\"DecoderInput%d\" % i))\n", 96 | " self.decoder_targets.append(tf.placeholder(\n", 97 | " tf.float32, [batch_size, max_len + 1], name=\"DecoderTarget%d\" % i)) # one hot\n", 98 | " self.target_weights.append(tf.placeholder(\n", 99 | " tf.float32, [batch_size, 1], name=\"TargetWeight%d\" % i))\n", 100 | "\n", 101 | " \n", 102 | " # Encoder\n", 103 | " \n", 104 | " # Need for attention\n", 105 | " encoder_outputs, final_state = tf.nn.rnn(cell, self.encoder_inputs, dtype = tf.float32)\n", 106 | " \n", 107 | " # Need a dummy output to point on it. End of decoding.\n", 108 | " encoder_outputs = [tf.zeros([FLAGS.batch_size, FLAGS.rnn_size])] + encoder_outputs\n", 109 | "\n", 110 | " # First calculate a concatenation of encoder outputs to put attention on.\n", 111 | " top_states = [tf.reshape(e, [-1, 1, cell.output_size])\n", 112 | " for e in encoder_outputs]\n", 113 | " attention_states = tf.concat(1, top_states)\n", 114 | "\n", 115 | " with tf.variable_scope(\"decoder\"):\n", 116 | " outputs, states, _ = pointer_decoder(\n", 117 | " self.decoder_inputs, final_state, attention_states, cell)\n", 118 | "\n", 119 | " with tf.variable_scope(\"decoder\", reuse=True):\n", 120 | " predictions, _, inps = pointer_decoder(\n", 121 | " self.decoder_inputs, final_state, attention_states, cell, feed_prev=True)\n", 122 | " \n", 123 | " self.predictions = predictions\n", 124 | "\n", 125 | " self.outputs = outputs\n", 126 | " self.inps = inps\n", 127 | " # move code below to a separate function as in TF examples\n", 128 | " \n", 129 | " \n", 130 | " def create_feed_dict(self, encoder_input_data, decoder_input_data, decoder_target_data):\n", 131 | " feed_dict = {}\n", 132 | " for placeholder, data in zip(self.encoder_inputs, encoder_input_data):\n", 133 | " feed_dict[placeholder] = data\n", 134 | "\n", 135 | " for placeholder, data in zip(self.decoder_inputs, decoder_input_data):\n", 136 | " feed_dict[placeholder] = data\n", 137 | "\n", 138 | " for placeholder, data in zip(self.decoder_targets, decoder_target_data):\n", 139 | " feed_dict[placeholder] = data\n", 140 | "\n", 141 | " for placeholder in self.target_weights:\n", 142 | " feed_dict[placeholder] = np.ones([self.batch_size, 1])\n", 143 | "\n", 144 | " return feed_dict\n", 145 | "\n", 146 | " def step(self):\n", 147 | "\n", 148 | " loss = 0.0\n", 149 | " for output, target, weight in zip(self.outputs, self.decoder_targets, self.target_weights):\n", 150 | " loss += tf.nn.softmax_cross_entropy_with_logits(output, target) * weight\n", 151 | "\n", 152 | " loss = tf.reduce_mean(loss)\n", 153 | "\n", 154 | " test_loss = 0.0\n", 155 | " for output, target, weight in zip(self.predictions, self.decoder_targets, self.target_weights):\n", 156 | " test_loss += tf.nn.softmax_cross_entropy_with_logits(output, target) * weight\n", 157 | "\n", 158 | " test_loss = tf.reduce_mean(test_loss)\n", 159 | "\n", 160 | " optimizer = tf.train.AdamOptimizer()\n", 161 | " train_op = optimizer.minimize(loss)\n", 162 | " \n", 163 | " train_loss_value = 0.0\n", 164 | " test_loss_value = 0.0\n", 165 | " \n", 166 | " correct_order = 0\n", 167 | " all_order = 0\n", 168 | "\n", 169 | " with tf.Session() as sess:\n", 170 | " merged = tf.merge_all_summaries()\n", 171 | " writer = tf.train.SummaryWriter(\"/tmp/pointer_logs\", sess.graph)\n", 172 | " init = tf.initialize_all_variables()\n", 173 | " sess.run(init)\n", 174 | " for i in range(10000):\n", 175 | " encoder_input_data, decoder_input_data, targets_data = dataset.next_batch(\n", 176 | " FLAGS.batch_size, FLAGS.max_steps)\n", 177 | "\n", 178 | " # Train\n", 179 | " feed_dict = self.create_feed_dict(\n", 180 | " encoder_input_data, decoder_input_data, targets_data)\n", 181 | " d_x, l = sess.run([loss, train_op], feed_dict=feed_dict)\n", 182 | " train_loss_value = 0.9 * train_loss_value + 0.1 * d_x\n", 183 | " \n", 184 | " if i % 100 == 0:\n", 185 | " print('Step: %d' % i)\n", 186 | " print(\"Train: \", train_loss_value)\n", 187 | "\n", 188 | " encoder_input_data, decoder_input_data, targets_data = dataset.next_batch(\n", 189 | " FLAGS.batch_size, FLAGS.max_steps, train_mode=False)\n", 190 | " # Test\n", 191 | " feed_dict = self.create_feed_dict(\n", 192 | " encoder_input_data, decoder_input_data, targets_data)\n", 193 | " inps_ = sess.run(self.inps, feed_dict=feed_dict)\n", 194 | "\n", 195 | " predictions = sess.run(self.predictions, feed_dict=feed_dict)\n", 196 | " \n", 197 | " test_loss_value = 0.9 * test_loss_value + 0.1 * sess.run(test_loss, feed_dict=feed_dict)\n", 198 | "\n", 199 | " if i % 100 == 0:\n", 200 | " print(\"Test: \", test_loss_value)\n", 201 | "\n", 202 | " predictions_order = np.concatenate([np.expand_dims(prediction , 0) for prediction in predictions])\n", 203 | " predictions_order = np.argmax(predictions_order, 2).transpose(1, 0)[:,0:FLAGS.max_steps]\n", 204 | " \n", 205 | " input_order = np.concatenate([np.expand_dims(encoder_input_data_ , 0) for encoder_input_data_ in encoder_input_data])\n", 206 | " input_order = np.argsort(input_order, 0).squeeze().transpose(1, 0)+1\n", 207 | " \n", 208 | " correct_order += np.sum(np.all(predictions_order == input_order,\n", 209 | " axis=1))\n", 210 | " all_order += FLAGS.batch_size\n", 211 | "\n", 212 | " if i % 100 == 0:\n", 213 | " print('Correct order / All order: %f' % (correct_order / all_order))\n", 214 | " correct_order = 0\n", 215 | " all_order = 0\n", 216 | " \n", 217 | " # print(encoder_input_data, decoder_input_data, targets_data)\n", 218 | " # print(inps_)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": { 225 | "collapsed": false 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "pointer_network = PointerNetwork(FLAGS.max_steps, 1, FLAGS.rnn_size, 1, 5, FLAGS.batch_size, 1e-2, 0.95)\n", 230 | "dataset = DataGenerator()\n", 231 | "pointer_network.step()" 232 | ] 233 | } 234 | ], 235 | "metadata": { 236 | "kernelspec": { 237 | "display_name": "Python 2", 238 | "language": "python", 239 | "name": "python2" 240 | }, 241 | "language_info": { 242 | "codemirror_mode": { 243 | "name": "ipython", 244 | "version": 2 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 | "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython2", 251 | "version": "2.7.12" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 0 256 | } 257 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Implementation of Pointer networks: http://arxiv.org/pdf/1506.03134v1.pdf. 2 | """ 3 | 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import random 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | from dataset import DataGenerator 13 | from pointer import pointer_decoder 14 | 15 | 16 | flags = tf.app.flags 17 | FLAGS = flags.FLAGS 18 | flags.DEFINE_integer('batch_size', 32, 'Batch size. ') 19 | flags.DEFINE_integer('max_steps', 10, 'Number of numbers to sort. ') 20 | flags.DEFINE_integer('rnn_size', 32, 'RNN size. ') 21 | 22 | 23 | class PointerNetwork(object): 24 | 25 | def __init__(self, max_len, input_size, size, num_layers, max_gradient_norm, batch_size, learning_rate, learning_rate_decay_factor): 26 | """Create the network. A simplified network that handles only sorting. 27 | 28 | Args: 29 | max_len: maximum length of the model. 30 | input_size: size of the inputs data. 31 | size: number of units in each layer of the model. 32 | num_layers: number of layers in the model. 33 | max_gradient_norm: gradients will be clipped to maximally this norm. 34 | batch_size: the size of the batches used during training; 35 | the model construction is independent of batch_size, so it can be 36 | changed after initialization if this is convenient, e.g., for decoding. 37 | learning_rate: learning rate to start with. 38 | learning_rate_decay_factor: decay learning rate by this much when needed. 39 | """ 40 | self.batch_size = batch_size 41 | self.learning_rate = tf.Variable(float(learning_rate), trainable=False) 42 | self.learning_rate_decay_op = self.learning_rate.assign( 43 | self.learning_rate * learning_rate_decay_factor) 44 | self.global_step = tf.Variable(0, trainable=False) 45 | 46 | 47 | cell = tf.nn.rnn_cell.GRUCell(size) 48 | if num_layers > 1: 49 | cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers) 50 | 51 | self.encoder_inputs = [] 52 | self.decoder_inputs = [] 53 | self.decoder_targets = [] 54 | self.target_weights = [] 55 | for i in range(max_len): 56 | self.encoder_inputs.append(tf.placeholder( 57 | tf.float32, [batch_size, input_size], name="EncoderInput%d" % i)) 58 | 59 | for i in range(max_len + 1): 60 | self.decoder_inputs.append(tf.placeholder( 61 | tf.float32, [batch_size, input_size], name="DecoderInput%d" % i)) 62 | self.decoder_targets.append(tf.placeholder( 63 | tf.float32, [batch_size, max_len + 1], name="DecoderTarget%d" % i)) # one hot 64 | self.target_weights.append(tf.placeholder( 65 | tf.float32, [batch_size, 1], name="TargetWeight%d" % i)) 66 | 67 | 68 | # Encoder 69 | 70 | # Need for attention 71 | encoder_outputs, final_state = tf.nn.rnn(cell, self.encoder_inputs, dtype = tf.float32) 72 | 73 | # Need a dummy output to point on it. End of decoding. 74 | encoder_outputs = [tf.zeros([FLAGS.batch_size, FLAGS.rnn_size])] + encoder_outputs 75 | 76 | # First calculate a concatenation of encoder outputs to put attention on. 77 | top_states = [tf.reshape(e, [-1, 1, cell.output_size]) 78 | for e in encoder_outputs] 79 | attention_states = tf.concat(1, top_states) 80 | 81 | with tf.variable_scope("decoder"): 82 | outputs, states, _ = pointer_decoder( 83 | self.decoder_inputs, final_state, attention_states, cell) 84 | 85 | with tf.variable_scope("decoder", reuse=True): 86 | predictions, _, inps = pointer_decoder( 87 | self.decoder_inputs, final_state, attention_states, cell, feed_prev=True) 88 | 89 | self.predictions = predictions 90 | 91 | self.outputs = outputs 92 | self.inps = inps 93 | # move code below to a separate function as in TF examples 94 | 95 | 96 | def create_feed_dict(self, encoder_input_data, decoder_input_data, decoder_target_data): 97 | feed_dict = {} 98 | for placeholder, data in zip(self.encoder_inputs, encoder_input_data): 99 | feed_dict[placeholder] = data 100 | 101 | for placeholder, data in zip(self.decoder_inputs, decoder_input_data): 102 | feed_dict[placeholder] = data 103 | 104 | for placeholder, data in zip(self.decoder_targets, decoder_target_data): 105 | feed_dict[placeholder] = data 106 | 107 | for placeholder in self.target_weights: 108 | feed_dict[placeholder] = np.ones([self.batch_size, 1]) 109 | 110 | return feed_dict 111 | 112 | def step(self): 113 | 114 | loss = 0.0 115 | for output, target, weight in zip(self.outputs, self.decoder_targets, self.target_weights): 116 | loss += tf.nn.softmax_cross_entropy_with_logits(output, target) * weight 117 | 118 | loss = tf.reduce_mean(loss) 119 | 120 | test_loss = 0.0 121 | for output, target, weight in zip(self.predictions, self.decoder_targets, self.target_weights): 122 | test_loss += tf.nn.softmax_cross_entropy_with_logits(output, target) * weight 123 | 124 | test_loss = tf.reduce_mean(test_loss) 125 | 126 | optimizer = tf.train.AdamOptimizer() 127 | train_op = optimizer.minimize(loss) 128 | 129 | train_loss_value = 0.0 130 | test_loss_value = 0.0 131 | 132 | correct_order = 0 133 | all_order = 0 134 | 135 | with tf.Session() as sess: 136 | merged = tf.merge_all_summaries() 137 | writer = tf.train.SummaryWriter("/tmp/pointer_logs", sess.graph) 138 | init = tf.initialize_all_variables() 139 | sess.run(init) 140 | for i in range(10000): 141 | encoder_input_data, decoder_input_data, targets_data = dataset.next_batch( 142 | FLAGS.batch_size, FLAGS.max_steps) 143 | 144 | # Train 145 | feed_dict = self.create_feed_dict( 146 | encoder_input_data, decoder_input_data, targets_data) 147 | d_x, l = sess.run([loss, train_op], feed_dict=feed_dict) 148 | train_loss_value = 0.9 * train_loss_value + 0.1 * d_x 149 | 150 | if i % 100 == 0: 151 | print('Step: %d' % i) 152 | print("Train: ", train_loss_value) 153 | 154 | encoder_input_data, decoder_input_data, targets_data = dataset.next_batch( 155 | FLAGS.batch_size, FLAGS.max_steps, train_mode=False) 156 | # Test 157 | feed_dict = self.create_feed_dict( 158 | encoder_input_data, decoder_input_data, targets_data) 159 | inps_ = sess.run(self.inps, feed_dict=feed_dict) 160 | 161 | predictions = sess.run(self.predictions, feed_dict=feed_dict) 162 | 163 | test_loss_value = 0.9 * test_loss_value + 0.1 * sess.run(test_loss, feed_dict=feed_dict) 164 | 165 | if i % 100 == 0: 166 | print("Test: ", test_loss_value) 167 | 168 | predictions_order = np.concatenate([np.expand_dims(prediction , 0) for prediction in predictions]) 169 | predictions_order = np.argmax(predictions_order, 2).transpose(1, 0)[:,0:FLAGS.max_steps] 170 | 171 | input_order = np.concatenate([np.expand_dims(encoder_input_data_ , 0) for encoder_input_data_ in encoder_input_data]) 172 | input_order = np.argsort(input_order, 0).squeeze().transpose(1, 0)+1 173 | 174 | correct_order += np.sum(np.all(predictions_order == input_order, 175 | axis=1)) 176 | all_order += FLAGS.batch_size 177 | 178 | if i % 100 == 0: 179 | print('Correct order / All order: %f' % (correct_order / all_order)) 180 | correct_order = 0 181 | all_order = 0 182 | 183 | # print(encoder_input_data, decoder_input_data, targets_data) 184 | # print(inps_) 185 | 186 | 187 | 188 | if __name__ == "__main__": 189 | # TODO: replace other with params 190 | pointer_network = PointerNetwork(FLAGS.max_steps, 1, FLAGS.rnn_size, 191 | 1, 5, FLAGS.batch_size, 1e-2, 0.95) 192 | dataset = DataGenerator() 193 | pointer_network.step() 194 | 195 | -------------------------------------------------------------------------------- /pointer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A pointer-network helper. 17 | Based on attenton_decoder implementation from TensorFlow 18 | https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | 27 | import tensorflow as tf 28 | 29 | from tensorflow.python.framework import dtypes 30 | from tensorflow.python.framework import ops 31 | from tensorflow.python.ops import array_ops 32 | from tensorflow.python.ops import control_flow_ops 33 | from tensorflow.python.ops import embedding_ops 34 | from tensorflow.python.ops import math_ops 35 | from tensorflow.python.ops import nn_ops 36 | from tensorflow.python.ops import rnn 37 | from tensorflow.python.ops import rnn_cell 38 | from tensorflow.python.ops import sparse_ops 39 | from tensorflow.python.ops import variable_scope as vs 40 | 41 | 42 | def pointer_decoder(decoder_inputs, initial_state, attention_states, cell, 43 | feed_prev=True, dtype=dtypes.float32, scope=None): 44 | """RNN decoder with pointer net for the sequence-to-sequence model. 45 | Args: 46 | decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. 47 | initial_state: 2D Tensor [batch_size x cell.state_size]. 48 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 49 | cell: rnn_cell.RNNCell defining the cell function and size. 50 | dtype: The dtype to use for the RNN initial state (default: tf.float32). 51 | scope: VariableScope for the created subgraph; default: "pointer_decoder". 52 | Returns: 53 | outputs: A list of the same length as decoder_inputs of 2D Tensors of shape 54 | [batch_size x output_size]. These represent the generated outputs. 55 | Output i is computed from input i (which is either i-th decoder_inputs. 56 | First, we run the cell 57 | on a combination of the input and previous attention masks: 58 | cell_output, new_state = cell(linear(input, prev_attn), prev_state). 59 | Then, we calculate new attention masks: 60 | new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) 61 | and then we calculate the output: 62 | output = linear(cell_output, new_attn). 63 | states: The state of each decoder cell in each time-step. This is a list 64 | with length len(decoder_inputs) -- one item for each time-step. 65 | Each item is a 2D Tensor of shape [batch_size x cell.state_size]. 66 | """ 67 | if not decoder_inputs: 68 | raise ValueError("Must provide at least 1 input to attention decoder.") 69 | if not attention_states.get_shape()[1:2].is_fully_defined(): 70 | raise ValueError("Shape[1] and [2] of attention_states must be known: %s" 71 | % attention_states.get_shape()) 72 | 73 | with vs.variable_scope(scope or "point_decoder"): 74 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 75 | input_size = decoder_inputs[0].get_shape()[1].value 76 | attn_length = attention_states.get_shape()[1].value 77 | attn_size = attention_states.get_shape()[2].value 78 | 79 | # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before. 80 | hidden = array_ops.reshape( 81 | attention_states, [-1, attn_length, 1, attn_size]) 82 | 83 | attention_vec_size = attn_size # Size of query vectors for attention. 84 | k = vs.get_variable("AttnW", [1, 1, attn_size, attention_vec_size]) 85 | hidden_features = nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME") 86 | v = vs.get_variable("AttnV", [attention_vec_size]) 87 | 88 | states = [initial_state] 89 | 90 | def attention(query): 91 | """Point on hidden using hidden_features and query.""" 92 | with vs.variable_scope("Attention"): 93 | y = rnn_cell._linear(query, attention_vec_size, True) 94 | y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) 95 | # Attention mask is a softmax of v^T * tanh(...). 96 | s = math_ops.reduce_sum( 97 | v * math_ops.tanh(hidden_features + y), [2, 3]) 98 | return s 99 | 100 | outputs = [] 101 | prev = None 102 | batch_attn_size = array_ops.pack([batch_size, attn_size]) 103 | attns = array_ops.zeros(batch_attn_size, dtype=dtype) 104 | 105 | attns.set_shape([None, attn_size]) 106 | inps = [] 107 | for i in xrange(len(decoder_inputs)): 108 | if i > 0: 109 | vs.get_variable_scope().reuse_variables() 110 | inp = decoder_inputs[i] 111 | 112 | if feed_prev and i > 0: 113 | inp = tf.pack(decoder_inputs) 114 | inp = tf.transpose(inp, perm=[1, 0, 2]) 115 | inp = tf.reshape(inp, [-1, attn_length, input_size]) 116 | inp = tf.reduce_sum(inp * tf.reshape(tf.nn.softmax(output), [-1, attn_length, 1]), 1) 117 | inp = tf.stop_gradient(inp) 118 | inps.append(inp) 119 | 120 | # Use the same inputs in inference, order internaly 121 | 122 | # Merge input and previous attentions into one vector of the right size. 123 | x = rnn_cell._linear([inp, attns], cell.output_size, True) 124 | # Run the RNN. 125 | cell_output, new_state = cell(x, states[-1]) 126 | states.append(new_state) 127 | # Run the attention mechanism. 128 | output = attention(new_state) 129 | 130 | outputs.append(output) 131 | 132 | return outputs, states, inps 133 | --------------------------------------------------------------------------------