├── .gitignore
├── README.md
├── corpus
│   ├── dev.txt.in
│   ├── dev.txt.out
│   ├── test.txt.in
│   ├── test.txt.out
│   ├── train.txt.in
│   └── train.txt.out
├── models
│   └── config.json
├── requirements.txt
├── seq2seq.py
├── seq2seq_model.py
├── utils
│   ├── data_utils.py
│   └── distributions.py
└── vrae.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generating Sentences from a Continuous Space
2 |
3 | Tensorflow implementation of [Generating Sentences from a Continuous Space](https://arxiv.org/abs/1511.06349).
4 |
5 | ## Prerequisites
6 | 1. Python packages:
7 |     - Python 3.4 or higher
8 |     - Tensorflow r0.12
9 |     - Numpy
10 |
11 | ## Setting up the environment:
12 | 1. Clone this repository:
13 | ```shell=
14 | git clone https://github.com/Chung-I/Variational-Recurrent-Autoencoder-Tensorflow.git
15 | ```
16 | 2. Set up the conda environment:
17 | ```shell=
18 | conda create -n vrae python=3.6
19 | conda activate vrae
20 | ```
21 | 3. Install python package requirements:
22 | ```shell=
23 | pip install -r requirements.txt
24 | ```
25 | ## Usage
26 |
27 |
28 | Training:
29 | ```shell=
30 | python vrae.py --model_dir models --do train --new True
31 | ```
32 |
33 | Reconstruction:
34 | ```shell=
35 | python vrae.py --model_dir models --do reconstruct --new False --input input.txt --output output.txt
36 | ```
37 |
38 | Sampling (reads only the first line of `input.txt`, generates `num_pts` samples, and writes them to `output.txt`):
39 | ```shell=
40 | python vrae.py --model_dir models --do sample --new False --input input.txt --output output.txt
41 | ```
42 |
43 | Interpolation (requires `input.txt` to consist of exactly two sentences; generates `num_pts` interpolations between them and writes the interpolated sentences to `output.txt`; see the example at the end of this README):
44 | ```shell=
45 | python vrae.py --model_dir models --do interpolate --new False --input input.txt --output output.txt
46 | ```
47 |
48 | `model_dir`: The directory containing the config file `config.json` and the checkpoint files.
49 |
50 | `do`: Accepts 4 values: `train`, `reconstruct`, `sample`, or `interpolate`.
51 |
52 | `new`: Creates a model with fresh parameters if set to `True`; otherwise reads model parameters from the checkpoints in `model_dir`.
53 |
54 | ## config.json
55 |
56 | Hyperparameters are not passed on the command line as in [tensorflow/models/rnn/translate/translate.py](https://github.com/tensorflow/tensorflow/blob/r0.12/tensorflow/models/rnn/translate/translate.py). Instead, [vrae.py](https://github.com/Chung-I/Variational-Recurrent-Autoencoder-Tensorflow/blob/master/vrae.py) reads hyperparameters from [config.json](https://github.com/Chung-I/Variational-Recurrent-Autoencoder-Tensorflow/blob/master/models/config.json) in `model_dir`.
57 |
58 | Below are the hyperparameters in [config.json](https://github.com/Chung-I/Variational-Recurrent-Autoencoder-Tensorflow/blob/master/models/config.json):
59 |
60 | - `model`:
61 |   - `size`: embedding size, and encoder/decoder state size.
62 |   - `latent_dim`: latent space size.
63 |   - `en_vocab_size`: source vocabulary size.
64 |   - `fr_vocab_size`: target vocabulary size.
65 |   - `data_dir`: path to the corpus.
66 |   - `num_layers`: number of layers for the encoder and decoder.
67 |   - `use_lstm`: whether to use LSTM cells for the encoder and decoder. `BasicLSTMCell` is used if set to `True`; otherwise `GRUCell` is used.
68 |   - `buckets`: a list of pairs of [input size, output size] for each bucket.
69 |   - `bidirectional`: `bidirectional_rnn` is used if set to `True`.
70 |   - `probabilistic`: variance is set to zero if set to `False`.
71 |   - `orthogonal_initializer`: `orthogonal_initializer` is used if set to `True`; otherwise `uniform_unit_scaling_initializer` is used.
72 |   - `iaf`: [inverse autoregressive flow](https://github.com/openai/iaf) is used if set to `True`.
73 |   - `activation`: activation for the encoder-to-latent layer and the latent-to-decoder layer.
74 |     - `elu`: exponential linear unit.
75 |     - `prelu`: parametric rectified linear unit. (default)
76 |     - `None`: linear.
77 | - `train`:
78 |   - `batch_size`
79 |   - `beam_size`: beam size for decoding. __Warning__: beam search is still under implementation. A `NotImplementedError` is raised if `beam_size` is set greater than 1.
80 |   - `learning_rate`: learning rate passed to the `AdamOptimizer`.
81 |   - `steps_per_checkpoint`: save a checkpoint every `steps_per_checkpoint` steps.
82 |   - `anneal`: do [KL cost annealing](https://aclweb.org/anthology/K/K16/K16-1002.pdf#page=4) if set to `True`.
83 |   - `kl_rate_rise_factor`: the KL term weight is increased by this amount every `steps_per_checkpoint` steps.
84 |   - `max_train_data_size`: limit on the size of training data (0: no limit).
85 |   - `feed_previous`: If `True`, only the first of decoder_inputs will be
86 |     used (the "GO" symbol), and all other decoder inputs will be generated by `next = embedding_lookup(embedding, argmax(previous_output))`. In effect, this implements a greedy decoder. It can also be used during training to emulate http://arxiv.org/abs/1506.03099. If `False`, `decoder_inputs` are used as given (the standard decoder case).
87 |   - `kl_min`: the [minimum information constraint](https://arxiv.org/pdf/1606.04934v1.pdf#page=7). Should be a non-negative float (0 means no constraint).
88 |   - `max_gradient_norm`: gradients will be clipped to at most this norm.
89 |   - `word_dropout_keep_prob`: the probability that each conditioned-on word token is kept; with probability `1 - word_dropout_keep_prob` a token is replaced by the generic unknown-word token `UNK`. When equal to 0, the decoder sees no input.
90 |
91 | - `reconstruct`:
92 |   - `feed_previous`
93 |   - `word_dropout_keep_prob`
94 | - `sample`:
95 |   - `feed_previous`
96 |   - `word_dropout_keep_prob`
97 |   - `num_pts`: sample `num_pts` points.
98 | - `interpolate`:
99 |   - `feed_previous`
100 |   - `word_dropout_keep_prob`
101 |   - `num_pts`: generate `num_pts` interpolated points.
102 |
103 | ## Data
104 |
105 | The Penn Treebank corpus is included in the repo. We also provide a [Chinese poem corpus](https://drive.google.com/file/d/178u6rYoupyT9crrIxXwHBMQ9v-7VhyqL/view?usp=sharing), [its preprocessed version](https://drive.google.com/file/d/1jfUuuVDf0dg9KZtof7Q-gotVmo-Pd3cF/view?usp=sharing) (point `data_dir` in `config.json` at it), and [its pretrained model](https://drive.google.com/file/d/1jfUuuVDf0dg9KZtof7Q-gotVmo-Pd3cF/view?usp=sharing) (point `model_dir` at it), all of which can be found [here](https://drive.google.com/drive/folders/1d7185c4qL6laphyEf5GZRV0I-2aSZcRK?usp=sharing).
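### Example: interpolating between two sentences

A minimal end-to-end run of the `interpolate` mode, assuming a checkpoint has already been trained into `models` (the two sentences below are hypothetical placeholders; use sentences made of words in your corpus vocabulary):

```shell=
printf 'the company said it expects higher sales\nthe stock market fell sharply yesterday\n' > input.txt
python vrae.py --model_dir models --do interpolate --new False --input input.txt --output output.txt
cat output.txt  # num_pts sentences, shifting gradually from the first input toward the second
```

With `"num_pts": 10` under `interpolate` in the provided `config.json`, `output.txt` will contain 10 sentences decoded from points between the two latent codes.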
106 | -------------------------------------------------------------------------------- /models/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "size": 256, 4 | "latent_dim": 16, 5 | "en_vocab_size": 20000, 6 | "fr_vocab_size": 20000, 7 | "data_dir": "corpus", 8 | "num_layers": 1, 9 | "use_lstm": false, 10 | "buckets": [[18,19]], 11 | "bidirectional": false, 12 | "probabilistic": true, 13 | "orthogonal_initializer": true, 14 | "iaf": true, 15 | "activation": "prelu" 16 | }, 17 | "train": { 18 | "batch_size": 256, 19 | "beam_size": 1, 20 | "learning_rate": 0.001, 21 | "kl_rate_rise_factor": 0.001, 22 | "kl_rate_rise_time": 50000, 23 | "max_train_data_size": 0, 24 | "steps_per_checkpoint": 2000, 25 | "feed_previous": true, 26 | "kl_min": 4, 27 | "max_gradient_norm": 5.0, 28 | "word_dropout_keep_prob": 0.0, 29 | "anneal": false 30 | }, 31 | "reconstruct": { 32 | "feed_previous": true, 33 | "word_dropout_keep_prob": 0.0 34 | }, 35 | "sample": { 36 | "feed_previous": true, 37 | "word_dropout_keep_prob": 0.0, 38 | "num_pts": 10 39 | }, 40 | "interpolate": { 41 | "feed_previous": true, 42 | "word_dropout_keep_prob": 0.0, 43 | "num_pts": 10 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==0.12.1 2 | -------------------------------------------------------------------------------- /seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Library for creating sequence-to-sequence models in TensorFlow. 17 | 18 | Sequence-to-sequence recurrent neural networks can learn complex functions 19 | that map input sequences to output sequences. These models yield very good 20 | results on a number of tasks, such as speech recognition, parsing, machine 21 | translation, or even constructing automated replies to emails. 22 | 23 | Before using this module, it is recommended to read the TensorFlow tutorial 24 | on sequence-to-sequence models. It explains the basic concepts of this module 25 | and shows an end-to-end example of how to build a translation model. 26 | https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html 27 | 28 | Here is an overview of functions available in this module. They all use 29 | a very similar interface, so after reading the above tutorial and using 30 | one of them, others should be easy to substitute. 31 | 32 | * Full sequence-to-sequence models. 33 | - basic_rnn_seq2seq: The most basic RNN-RNN model. 34 | - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights. 35 | - embedding_rnn_seq2seq: The basic model with input embedding. 
36 | - embedding_tied_rnn_seq2seq: The tied model with input embedding. 37 | - embedding_attention_seq2seq: Advanced model with input embedding and 38 | the neural attention mechanism; recommended for complex tasks. 39 | 40 | * Multi-task sequence-to-sequence models. 41 | - one2many_rnn_seq2seq: The embedding model with multiple decoders. 42 | 43 | * Decoders (when you write your own encoder, you can use these to decode; 44 | e.g., if you want to write a model that generates captions for images). 45 | - rnn_decoder: The basic decoder based on a pure RNN. 46 | - attention_decoder: A decoder that uses the attention mechanism. 47 | 48 | * Losses. 49 | - sequence_loss: Loss for a sequence model returning average log-perplexity. 50 | - sequence_loss_by_example: As above, but not averaging over all examples. 51 | 52 | * model_with_buckets: A convenience function to create models with bucketing 53 | (see the tutorial above for an explanation of why and how to use it). 54 | """ 55 | 56 | from __future__ import absolute_import 57 | from __future__ import division 58 | from __future__ import print_function 59 | 60 | # We disable pylint because we need python3 compatibility. 61 | from six.moves import xrange # pylint: disable=redefined-builtin 62 | from six.moves import zip # pylint: disable=redefined-builtin 63 | 64 | from tensorflow.python import shape 65 | from tensorflow.python.framework import dtypes 66 | from tensorflow.python.framework import ops 67 | from tensorflow.python.ops import array_ops 68 | from tensorflow.python.ops import control_flow_ops 69 | from tensorflow.python.ops import embedding_ops 70 | from tensorflow.python.ops import math_ops 71 | from tensorflow.python.ops import nn_ops 72 | from tensorflow.python.ops import rnn 73 | from tensorflow.python.ops import rnn_cell 74 | from tensorflow.python.ops import variable_scope 75 | from tensorflow.python.util import nest 76 | import tensorflow as tf 77 | import numpy as np 78 | from utils.distributions import DiagonalGaussian 79 | 80 | # TODO(ebrevdo): Remove once _linear is fully deprecated. 81 | linear = rnn_cell._linear # pylint: disable=protected-access 82 | 83 | 84 | 85 | def prelu(_x): 86 | with tf.variable_scope("prelu"): 87 | alphas = tf.get_variable('alpha', _x.get_shape()[-1], 88 | initializer=tf.constant_initializer(0.0), 89 | dtype=tf.float32) 90 | pos = tf.nn.relu(_x) 91 | neg = alphas * (_x - abs(_x)) * 0.5 92 | return pos + neg 93 | 94 | 95 | def _extract_argmax_and_embed(embedding, output_projection=None, 96 | update_embedding=True): 97 | """Get a loop_function that extracts the previous symbol and embeds it. 98 | 99 | Args: 100 | embedding: embedding tensor for symbols. 101 | output_projection: None or a pair (W, B). If provided, each fed previous 102 | output will first be multiplied by W and added B. 103 | update_embedding: Boolean; if False, the gradients will not propagate 104 | through the embeddings. 105 | 106 | Returns: 107 | A loop function. 108 | """ 109 | def loop_function(prev, _): 110 | if output_projection is not None: 111 | prev = nn_ops.xw_plus_b( 112 | prev, output_projection[0], output_projection[1]) 113 | prev_symbol = math_ops.argmax(prev, 1) 114 | # Note that gradients will not propagate through the second parameter of 115 | # embedding_lookup. 
116 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 117 | if not update_embedding: 118 | emb_prev = array_ops.stop_gradient(emb_prev) 119 | return emb_prev 120 | return loop_function 121 | 122 | 123 | def rnn_decoder(decoder_inputs, initial_state, cell, word_dropout_keep_prob=1, replace_inp=None, 124 | loop_function=None, scope=None): 125 | """RNN decoder for the sequence-to-sequence model. 126 | 127 | Args: 128 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 129 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 130 | cell: rnn_cell.RNNCell defining the cell function and size. 131 | loop_function: If not None, this function will be applied to the i-th output 132 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 133 | except for the first element ("GO" symbol). This can be used for decoding, 134 | but also for training to emulate http://arxiv.org/abs/1506.03099. 135 | Signature -- loop_function(prev, i) = next 136 | * prev is a 2D Tensor of shape [batch_size x output_size], 137 | * i is an integer, the step number (when advanced control is needed), 138 | * next is a 2D Tensor of shape [batch_size x input_size]. 139 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 140 | 141 | Returns: 142 | A tuple of the form (outputs, state), where: 143 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 144 | shape [batch_size x output_size] containing generated outputs. 145 | state: The state of each cell at the final time-step. 146 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 147 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 148 | states can be the same. They are different for LSTM cells though.) 149 | """ 150 | with variable_scope.variable_scope(scope or "rnn_decoder"): 151 | state = initial_state 152 | outputs = [] 153 | prev = None 154 | seq_len = len(decoder_inputs) 155 | keep = tf.select(tf.random_uniform([seq_len]) < word_dropout_keep_prob, 156 | tf.fill([seq_len], True), tf.fill([seq_len], False)) 157 | for i, inp in enumerate(decoder_inputs): 158 | if loop_function is not None and prev is not None: 159 | with variable_scope.variable_scope("loop_function", reuse=True): 160 | if word_dropout_keep_prob < 1: 161 | inp = tf.cond(keep[i], lambda: loop_function(prev, i), lambda: replace_inp) 162 | else: 163 | inp = loop_function(prev, i) 164 | if i > 0: 165 | variable_scope.get_variable_scope().reuse_variables() 166 | output, state = cell(inp, state) 167 | outputs.append(output) 168 | if loop_function is not None: 169 | prev = output 170 | return outputs, state 171 | 172 | 173 | def beam_rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, 174 | scope=None,output_projection=None, beam_size=1): 175 | """RNN decoder for the sequence-to-sequence model. 176 | 177 | Args: 178 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 179 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 180 | cell: rnn_cell.RNNCell defining the cell function and size. 181 | loop_function: If not None, this function will be applied to the i-th output 182 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 183 | except for the first element ("GO" symbol). This can be used for decoding, 184 | but also for training to emulate http://arxiv.org/abs/1506.03099. 
185 | Signature -- loop_function(prev, i) = next 186 | * prev is a 2D Tensor of shape [batch_size x output_size], 187 | * i is an integer, the step number (when advanced control is needed), 188 | * next is a 2D Tensor of shape [batch_size x input_size]. 189 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 190 | 191 | Returns: 192 | A tuple of the form (outputs, state), where: 193 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 194 | shape [batch_size x output_size] containing generated outputs. 195 | state: The state of each cell at the final time-step. 196 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 197 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 198 | states can be the same. They are different for LSTM cells though.) 199 | """ 200 | with variable_scope.variable_scope(scope or "rnn_decoder"): 201 | state = initial_state 202 | outputs = [] 203 | prev = None 204 | log_beam_probs, beam_path, beam_symbols = [],[],[] 205 | state_size = int(initial_state.get_shape().with_rank(2)[1]) 206 | 207 | for i, inp in enumerate(decoder_inputs): 208 | if loop_function is not None and prev is not None: 209 | with variable_scope.variable_scope("loop_function", reuse=True): 210 | inp = loop_function(prev, i,log_beam_probs, beam_path, beam_symbols) 211 | if i > 0: 212 | variable_scope.get_variable_scope().reuse_variables() 213 | 214 | input_size = inp.get_shape().with_rank(2)[1] 215 | x = inp 216 | output, state = cell(x, state) 217 | 218 | if loop_function is not None: 219 | prev = output 220 | if i ==0: 221 | states =[] 222 | for kk in range(beam_size): 223 | states.append(state) 224 | state = tf.reshape(tf.concat(0, states), [-1, state_size]) 225 | 226 | outputs.append(tf.argmax(nn_ops.xw_plus_b( 227 | output, output_projection[0], output_projection[1]), dimension=1)) 228 | return outputs, state, tf.reshape(tf.concat(0, beam_path),[-1,beam_size]), tf.reshape(tf.concat(0, beam_symbols),[-1,beam_size]) 229 | 230 | 231 | def embedding_rnn_decoder(decoder_inputs, 232 | initial_state, 233 | cell, 234 | embedding, 235 | num_symbols, 236 | embedding_size, 237 | word_dropout_keep_prob=1, 238 | replace_input=None, 239 | output_projection=None, 240 | feed_previous=False, 241 | update_embedding_for_previous=True, 242 | weight_initializer=None, 243 | beam_size=1, 244 | scope=None): 245 | """RNN decoder with embedding and a pure-decoding option. 246 | 247 | Args: 248 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 249 | initial_state: 2D Tensor [batch_size x cell.state_size]. 250 | cell: rnn_cell.RNNCell defining the cell function. 251 | num_symbols: Integer, how many symbols come into the embedding. 252 | embedding_size: Integer, the length of the embedding vector for each symbol. 253 | output_projection: None or a pair (W, B) of output projection weights and 254 | biases; W has shape [output_size x num_symbols] and B has 255 | shape [num_symbols]; if provided and feed_previous=True, each fed 256 | previous output will first be multiplied by W and added B. 257 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 258 | used (the "GO" symbol), and all other decoder inputs will be generated by: 259 | next = embedding_lookup(embedding, argmax(previous_output)), 260 | In effect, this implements a greedy decoder. It can also be used 261 | during training to emulate http://arxiv.org/abs/1506.03099. 262 | If False, decoder_inputs are used as given (the standard decoder case). 
263 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 264 | only the embedding for the first symbol of decoder_inputs (the "GO" 265 | symbol) will be updated by back propagation. Embeddings for the symbols 266 | generated from the decoder itself remain unchanged. This parameter has 267 | no effect if feed_previous=False. 268 | scope: VariableScope for the created subgraph; defaults to 269 | "embedding_rnn_decoder". 270 | 271 | Returns: 272 | A tuple of the form (outputs, state), where: 273 | outputs: A list of the same length as decoder_inputs of 2D Tensors. The 274 | output is of shape [batch_size x cell.output_size] when 275 | output_projection is not None (and represents the dense representation 276 | of predicted tokens). It is of shape [batch_size x num_decoder_symbols] 277 | when output_projection is None. 278 | state: The state of each decoder cell in each time-step. This is a list 279 | with length len(decoder_inputs) -- one item for each time-step. 280 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 281 | 282 | Raises: 283 | ValueError: When output_projection has the wrong shape. 284 | """ 285 | with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope: 286 | if output_projection is not None: 287 | dtype = scope.dtype 288 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 289 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 290 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 291 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 292 | 293 | if not embedding: 294 | embedding = variable_scope.get_variable("embedding", [num_symbols, embedding_size], 295 | initializer=weight_initializer()) 296 | 297 | if beam_size > 1: 298 | loop_function = _extract_beam_search( 299 | embedding, beam_size,num_symbols,embedding_size, output_projection, 300 | update_embedding_for_previous) 301 | else: 302 | loop_function = _extract_argmax_and_embed( 303 | embedding, output_projection, 304 | update_embedding_for_previous) if feed_previous else None 305 | 306 | emb_inp = [ 307 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs] 308 | if beam_size > 1: 309 | return beam_rnn_decoder(emb_inp, initial_state, cell,loop_function=loop_function, 310 | output_projection=output_projection, beam_size=beam_size) 311 | 312 | return rnn_decoder(emb_inp, initial_state, cell, word_dropout_keep_prob, replace_input, 313 | loop_function=loop_function) 314 | 315 | 316 | def embedding_attention_encoder(encoder_inputs, 317 | cell, 318 | num_encoder_symbols, 319 | embedding_size, 320 | dtype=None, 321 | scope=None): 322 | """Embedding sequence-to-sequence model with attention. 323 | 324 | This model first embeds encoder_inputs by a newly created embedding (of shape 325 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 326 | embedded encoder_inputs into a state vector. It keeps the outputs of this 327 | RNN at every step to use for attention later. Next, it embeds decoder_inputs 328 | by another newly created embedding (of shape [num_decoder_symbols x 329 | input_size]). Then it runs attention decoder, initialized with the last 330 | encoder state, on embedded decoder_inputs and attending to encoder outputs. 331 | 332 | Warning: when output_projection is None, the size of the attention vectors 333 | and variables will be made proportional to num_decoder_symbols, can be large. 
334 | 335 | Args: 336 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 337 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 338 | cell: rnn_cell.RNNCell defining the cell function and size. 339 | num_encoder_symbols: Integer; number of symbols on the encoder side. 340 | num_decoder_symbols: Integer; number of symbols on the decoder side. 341 | embedding_size: Integer, the length of the embedding vector for each symbol. 342 | num_heads: Number of attention heads that read from attention_states. 343 | output_projection: None or a pair (W, B) of output projection weights and 344 | biases; W has shape [output_size x num_decoder_symbols] and B has 345 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 346 | fed previous output will first be multiplied by W and added B. 347 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 348 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 349 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 350 | If False, decoder_inputs are used as given (the standard decoder case). 351 | dtype: The dtype of the initial RNN state (default: tf.float32). 352 | scope: VariableScope for the created subgraph; defaults to 353 | "embedding_attention_seq2seq". 354 | initial_state_attention: If False (default), initial attentions are zero. 355 | If True, initialize the attentions from the initial state and attention 356 | states. 357 | 358 | Returns: 359 | A tuple of the form (outputs, state), where: 360 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 361 | shape [batch_size x num_decoder_symbols] containing the generated 362 | outputs. 363 | state: The state of each decoder cell at the final time-step. 364 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 365 | """ 366 | with variable_scope.variable_scope( 367 | scope or "embedding_attention_encoder", dtype=dtype) as scope: 368 | dtype = scope.dtype 369 | # Encoder. 370 | encoder_cell = rnn_cell.EmbeddingWrapper( 371 | cell, embedding_classes=num_encoder_symbols, 372 | embedding_size=embedding_size) 373 | encoder_outputs, encoder_state = rnn.rnn( 374 | encoder_cell, encoder_inputs, dtype=dtype) 375 | 376 | # First calculate a concatenation of encoder outputs to put attention on. 377 | top_states = [array_ops.reshape(e, [-1, 1, cell.output_size]) 378 | for e in encoder_outputs] 379 | attention_states = array_ops.concat(1, top_states) 380 | 381 | return encoder_state, attention_states 382 | 383 | 384 | def embedding_encoder(encoder_inputs, 385 | cell, 386 | embedding, 387 | num_symbols, 388 | embedding_size, 389 | bidirectional=False, 390 | dtype=None, 391 | weight_initializer=None, 392 | scope=None): 393 | 394 | with variable_scope.variable_scope( 395 | scope or "embedding_encoder", dtype=dtype) as scope: 396 | dtype = scope.dtype 397 | # Encoder. 
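    # `embedding` may be provided by the caller (Seq2SeqModel passes its shared
    # `enc_embedding` here); a fresh embedding variable is created below only
    # when none is given.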
398 | if not embedding: 399 | embedding = variable_scope.get_variable("embedding", [num_symbols, embedding_size], 400 | initializer=weight_initializer()) 401 | emb_inp = [embedding_ops.embedding_lookup(embedding, i) for i in encoder_inputs] 402 | if bidirectional: 403 | _, output_state_fw, output_state_bw = rnn.bidirectional_rnn(cell, cell, emb_inp, 404 | dtype=dtype) 405 | encoder_state = tf.concat(1, [output_state_fw, output_state_bw]) 406 | else: 407 | _, encoder_state = rnn.rnn( 408 | cell, emb_inp, dtype=dtype) 409 | 410 | return encoder_state 411 | 412 | 413 | def sequence_loss_by_example(logits, targets, weights, 414 | average_across_timesteps=True, 415 | softmax_loss_function=None, name=None): 416 | """Weighted cross-entropy loss for a sequence of logits (per example). 417 | 418 | Args: 419 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 420 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 421 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 422 | average_across_timesteps: If set, divide the returned cost by the total 423 | label weight. 424 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 425 | to be used instead of the standard softmax (the default if this is None). 426 | name: Optional name for this operation, default: "sequence_loss_by_example". 427 | 428 | Returns: 429 | 1D batch-sized float Tensor: The log-perplexity for each sequence. 430 | 431 | Raises: 432 | ValueError: If len(logits) is different from len(targets) or len(weights). 433 | """ 434 | if len(targets) != len(logits) or len(weights) != len(logits): 435 | raise ValueError("Lengths of logits, weights, and targets must be the same " 436 | "%d, %d, %d." % (len(logits), len(weights), len(targets))) 437 | with ops.name_scope(name, "sequence_loss_by_example", 438 | logits + targets + weights): 439 | log_perp_list = [] 440 | for logit, target, weight in zip(logits, targets, weights): 441 | if softmax_loss_function is None: 442 | # TODO(irving,ebrevdo): This reshape is needed because 443 | # sequence_loss_by_example is called with scalars sometimes, which 444 | # violates our general scalar strictness policy. 445 | target = array_ops.reshape(target, [-1]) 446 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 447 | logit, target) 448 | else: 449 | crossent = softmax_loss_function(logit, target) 450 | log_perp_list.append(crossent * weight) 451 | log_perps = math_ops.add_n(log_perp_list) 452 | if average_across_timesteps: 453 | total_size = math_ops.add_n(weights) 454 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 455 | log_perps /= total_size 456 | return log_perps 457 | 458 | 459 | def sequence_loss(logits, targets, weights, 460 | average_across_timesteps=True, average_across_batch=True, 461 | softmax_loss_function=None, name=None): 462 | """Weighted cross-entropy loss for a sequence of logits, batch-collapsed. 463 | 464 | Args: 465 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 466 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 467 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 468 | average_across_timesteps: If set, divide the returned cost by the total 469 | label weight. 470 | average_across_batch: If set, divide the returned cost by the batch size. 
471 |     softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
472 |       to be used instead of the standard softmax (the default if this is None).
473 |     name: Optional name for this operation, defaults to "sequence_loss".
474 |
475 |   Returns:
476 |     A scalar float Tensor: The average log-perplexity per symbol (weighted).
477 |
478 |   Raises:
479 |     ValueError: If len(logits) is different from len(targets) or len(weights).
480 |   """
481 |   with ops.name_scope(name, "sequence_loss", logits + targets + weights):
482 |     cost = math_ops.reduce_sum(sequence_loss_by_example(
483 |         logits, targets, weights,
484 |         average_across_timesteps=average_across_timesteps,
485 |         softmax_loss_function=softmax_loss_function))
486 |     if average_across_batch:
487 |       batch_size = array_ops.shape(targets[0])[0]
488 |       return cost / math_ops.cast(batch_size, cost.dtype)
489 |     else:
490 |       return cost
491 |
492 |
493 | def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
494 |                        buckets, seq2seq, softmax_loss_function=None,
495 |                        per_example_loss=False, name=None):
496 |   """Create a sequence-to-sequence model with support for bucketing.
497 |
498 |   The seq2seq argument is a function that defines a sequence-to-sequence model,
499 |   e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))
500 |
501 |   Args:
502 |     encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
503 |     decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
504 |     targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
505 |     weights: List of 1D batch-sized float-Tensors to weight the targets.
506 |     buckets: A list of pairs of (input size, output size) for each bucket.
507 |     seq2seq: A sequence-to-sequence model function; it takes 2 inputs that
508 |       agree with encoder_inputs and decoder_inputs, and returns a pair
509 |       consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
510 |     softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
511 |       to be used instead of the standard softmax (the default if this is None).
512 |     per_example_loss: Boolean. If set, the returned loss will be a batch-sized
513 |       tensor of losses for each sequence in the batch. If unset, it will be
514 |       a scalar with the averaged loss from all examples.
515 |     name: Optional name for this operation, defaults to "model_with_buckets".
516 |
517 |   Returns:
518 |     A tuple of the form (outputs, losses), where:
519 |       outputs: The outputs for each bucket. Its j'th element consists of a list
520 |         of 2D Tensors. The shape of output tensors can be either
521 |         [batch_size x output_size] or [batch_size x num_decoder_symbols]
522 |         depending on the seq2seq model used.
523 |       losses: List of scalar Tensors, representing losses for each bucket, or,
524 |         if per_example_loss is set, a list of 1D batch-sized float Tensors.
525 |
526 |   Raises:
527 |     ValueError: If length of encoder_inputs, targets, or weights is smaller
528 |       than the largest (last) bucket.
529 |   """
530 |   if len(encoder_inputs) < buckets[-1][0]:
531 |     raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
532 |                      "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
533 |   if len(targets) < buckets[-1][1]:
534 |     raise ValueError("Length of targets (%d) must be at least that of last "
535 |                      "bucket (%d)." % (len(targets), buckets[-1][1]))
536 |   if len(weights) < buckets[-1][1]:
537 |     raise ValueError("Length of weights (%d) must be at least that of last "
538 |                      "bucket (%d)." % (len(weights), buckets[-1][1]))
539 |
540 |   all_inputs = encoder_inputs + decoder_inputs + targets + weights
541 |   losses = []
542 |   outputs = []
543 |   with ops.name_scope(name, "model_with_buckets", all_inputs):
544 |     for j, bucket in enumerate(buckets):
545 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
546 |                                          reuse=True if j > 0 else None):
547 |         bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
548 |                                     decoder_inputs[:bucket[1]])
549 |         outputs.append(bucket_outputs)
550 |         if per_example_loss:
551 |           losses.append(sequence_loss_by_example(
552 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
553 |               softmax_loss_function=softmax_loss_function))
554 |         else:
555 |           losses.append(sequence_loss(
556 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
557 |               softmax_loss_function=softmax_loss_function))
558 |
559 |   return outputs, losses
560 |
561 |
562 | def autoencoder_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
563 |                              buckets, encoder, decoder, softmax_loss_function=None,
564 |                              per_example_loss=False, name=None):
565 |   """Create an RNN autoencoder with support for bucketing.
566 |
567 |   The encoder and decoder arguments are functions that define the two halves
568 |   of the model, e.g., as built from embedding_encoder and embedding_rnn_decoder.
569 |
570 |   Args:
571 |     encoder_inputs: A list of Tensors to feed the encoder.
572 |     decoder_inputs: A list of Tensors to feed the decoder.
573 |     targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
574 |     weights: List of 1D batch-sized float-Tensors to weight the targets.
575 |     buckets: A list of pairs of (input size, output size) for each bucket.
576 |     encoder: An encoder function; it takes encoder_inputs and returns the
577 |       final encoder state.
578 |     decoder: A decoder function; it maps (encoder state, decoder_inputs) to a pair of (outputs, states).
579 |     softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
580 |       to be used instead of the standard softmax (the default if this is None).
581 |     per_example_loss: Boolean. If set, the returned loss will be a batch-sized
582 |       tensor of losses for each sequence in the batch. If unset, it will be
583 |       a scalar with the averaged loss from all examples.
584 |     name: Optional name for this operation, defaults to "model_with_buckets".
585 |
586 |   Returns:
587 |     A tuple of the form (outputs, losses), where:
588 |       outputs: The outputs for each bucket. Its j'th element consists of a list
589 |         of 2D Tensors. The shape of output tensors can be either
590 |         [batch_size x output_size] or [batch_size x num_decoder_symbols]
591 |         depending on the decoder used.
592 |       losses: List of scalar Tensors, representing losses for each bucket, or,
593 |         if per_example_loss is set, a list of 1D batch-sized float Tensors.
594 |
595 |   Raises:
596 |     ValueError: If length of encoder_inputs, targets, or weights is smaller
597 |       than the largest (last) bucket.
598 |   """
599 |   if len(encoder_inputs) < buckets[-1][0]:
600 |     raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
601 |                      "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
602 |   if len(targets) < buckets[-1][1]:
603 |     raise ValueError("Length of targets (%d) must be at least that of last "
604 |                      "bucket (%d)." % (len(targets), buckets[-1][1]))
605 |   if len(weights) < buckets[-1][1]:
606 |     raise ValueError("Length of weights (%d) must be at least that of last "
607 |                      "bucket (%d)." % (len(weights), buckets[-1][1]))
608 |
609 |   all_inputs = encoder_inputs + decoder_inputs + targets + weights
610 |   losses = []
611 |   outputs = []
612 |   with ops.name_scope(name, "model_with_buckets", all_inputs):
613 |     for j, bucket in enumerate(buckets):
614 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
615 |                                          reuse=True if j > 0 else None):
616 |         encoder_state = encoder(encoder_inputs[:bucket[0]])
617 |         bucket_outputs, _ = decoder(encoder_state, decoder_inputs[:bucket[1]])
618 |         outputs.append(bucket_outputs)
619 |         if per_example_loss:
620 |           losses.append(sequence_loss_by_example(
621 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
622 |               softmax_loss_function=softmax_loss_function))
623 |         else:
624 |           losses.append(sequence_loss(
625 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
626 |               softmax_loss_function=softmax_loss_function))
627 |
628 |   return outputs, losses
629 |
630 |
631 | def sample(means,
632 |            logvars,
633 |            latent_dim,
634 |            iaf=True,
635 |            kl_min=None,
636 |            anneal=False,
637 |            kl_rate=None,
638 |            dtype=None):
639 |   """Perform sampling and calculate KL divergence.
640 |
641 |   Args:
642 |     means: tensor of shape (batch_size, latent_dim).
643 |     logvars: tensor of shape (batch_size, latent_dim).
644 |     latent_dim: dimension of the latent space.
645 |     iaf: whether to apply a linear IAF transform.
646 |     kl_min: lower bound for the KL divergence.
647 |     anneal: whether to perform KL cost annealing.
648 |     kl_rate: the KL divergence is multiplied by kl_rate if anneal is set to True.
649 |   Returns:
650 |     latent_vector: latent variable after sampling. A vector of shape (batch_size, latent_dim).
651 |     kl_obj: objective to be minimized for the KL term.
652 |     kl_cost: the actual KL divergence.
653 | """ 654 | if iaf: 655 | with tf.variable_scope('iaf'): 656 | prior = DiagonalGaussian(tf.zeros_like(means, dtype=dtype), 657 | tf.zeros_like(logvars, dtype=dtype)) 658 | posterior = DiagonalGaussian(means, logvars) 659 | z = posterior.sample 660 | 661 | logqs = posterior.logps(z) 662 | L = tf.get_variable("inverse_cholesky", [latent_dim, latent_dim], dtype=dtype, initializer=tf.zeros_initializer) 663 | diag_one = tf.ones([latent_dim], dtype=dtype) 664 | L = tf.matrix_set_diag(L, diag_one) 665 | mask = np.tril(np.ones([latent_dim,latent_dim])) 666 | L = L * mask 667 | latent_vector = tf.matmul(z, L) 668 | logps = prior.logps(latent_vector) 669 | kl_cost = logqs - logps 670 | else: 671 | noise = tf.random_normal(tf.shape(mean)) 672 | sample = mean + tf.exp(0.5 * logvar) * noise 673 | kl_cost = -0.5 * (logvars - tf.square(means) - 674 | tf.exp(logvars) + 1.0) 675 | kl_ave = tf.reduce_mean(kl_cost, [0]) #mean of kl_cost over batches 676 | kl_obj = kl_cost = tf.reduce_sum(kl_ave) 677 | if kl_min: 678 | kl_obj = tf.reduce_sum(tf.maximum(kl_ave, kl_min)) 679 | if anneal: 680 | kl_obj = kl_obj * kl_rate 681 | 682 | return latent_vector, kl_obj, kl_cost #both kl_obj and kl_cost are scalar 683 | 684 | 685 | def encoder_to_latent(encoder_state, 686 | embedding_size, 687 | latent_dim, 688 | num_layers, 689 | activation=tf.nn.relu, 690 | use_lstm=False, 691 | enc_state_bidirectional=False, 692 | dtype=None): 693 | concat_state_size = num_layers * embedding_size 694 | if enc_state_bidirectional: 695 | concat_state_size *= 2 696 | if use_lstm: 697 | concat_state_size *= 2 698 | if num_layers > 1: 699 | encoder_state = list(map(lambda state_tuple: tf.concat(1, state_tuple), encoder_state)) 700 | else: 701 | encoder_state = tf.concat(1, encoder_state) 702 | if num_layers > 1: 703 | encoder_state = tf.concat(1, encoder_state) 704 | with tf.variable_scope('encoder_to_latent'): 705 | w = tf.get_variable("w",[concat_state_size, 2 * latent_dim], 706 | dtype=dtype) 707 | b = tf.get_variable("b", [2 * latent_dim], dtype=dtype) 708 | mean_logvar = prelu(tf.matmul(encoder_state, w) + b) 709 | mean, logvar = tf.split(1, 2, mean_logvar) 710 | 711 | return mean, logvar 712 | 713 | 714 | def latent_to_decoder(latent_vector, 715 | embedding_size, 716 | latent_dim, 717 | num_layers, 718 | activation=tf.nn.relu, 719 | use_lstm=False, 720 | dtype=None): 721 | 722 | concat_state_size = num_layers * embedding_size 723 | if use_lstm: 724 | concat_state_size *= 2 725 | with tf.variable_scope('latent_to_decoder'): 726 | w = tf.get_variable("w",[latent_dim, concat_state_size], 727 | dtype=dtype) 728 | b = tf.get_variable("b", [concat_state_size], dtype=dtype) 729 | decoder_initial_state = prelu(tf.matmul(latent_vector, w) + b) 730 | if num_layers > 1: 731 | decoder_initial_state = tuple(tf.split(1, num_layers, decoder_initial_state)) 732 | if use_lstm: 733 | decoder_initial_state = [tuple(tf.split(1, 2, single_layer_state)) for single_layer_state in decoder_initial_state] 734 | elif use_lstm: 735 | decoder_initial_state = tuple(tf.split(1, 2, decoder_initial_state)) 736 | 737 | return decoder_initial_state 738 | 739 | 740 | def variational_autoencoder_with_buckets(encoder_inputs, decoder_inputs, targets, weights, 741 | buckets, encoder, decoder, enc_latent, latent_dec, sample, kl_f, 742 | probabilistic=False, 743 | softmax_loss_function=None, 744 | per_example_loss=False, name=None): 745 | """Create a sequence-to-sequence model with support for bucketing. 
746 |
747 |   The encoder, enc_latent, sample, latent_dec, decoder, and kl_f arguments
748 |   are functions that define the corresponding stages of the model.
749 |
750 |   Args:
751 |     encoder_inputs: A list of Tensors to feed the encoder.
752 |     decoder_inputs: A list of Tensors to feed the decoder.
753 |     targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
754 |     weights: List of 1D batch-sized float-Tensors to weight the targets.
755 |     buckets: A list of pairs of (input size, output size) for each bucket.
756 |     encoder: A function mapping encoder_inputs to a final encoder state.
757 |     enc_latent: A function mapping the encoder state to (mean, logvar), from
758 |       which sample draws a latent vector that latent_dec and decoder turn into outputs; kl_f computes the KL divergence.
759 |     softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
760 |       to be used instead of the standard softmax (the default if this is None).
761 |     per_example_loss: Boolean. If set, the returned loss will be a batch-sized
762 |       tensor of losses for each sequence in the batch. If unset, it will be
763 |       a scalar with the averaged loss from all examples.
764 |     name: Optional name for this operation, defaults to "variational_autoencoder_with_buckets".
765 |
766 |   Returns:
767 |     A tuple of the form (outputs, losses, KL_divergences), where:
768 |       outputs: The outputs for each bucket. Its j'th element consists of a list
769 |         of 2D Tensors. The shape of output tensors can be either
770 |         [batch_size x output_size] or [batch_size x num_decoder_symbols]
771 |         depending on the decoder used.
772 |       losses: List of scalar Tensors, representing losses for each bucket, or,
773 |         if per_example_loss is set, a list of 1D batch-sized float Tensors.
774 |       KL_divergences: List of scalar KL terms, one for each bucket.
775 |   Raises:
776 |     ValueError: If length of encoder_inputs, targets, or weights is smaller
777 |       than the largest (last) bucket.
778 |   """
779 |   if len(encoder_inputs) < buckets[-1][0]:
780 |     raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
781 |                      "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
782 |   if len(targets) < buckets[-1][1]:
783 |     raise ValueError("Length of targets (%d) must be at least that of last "
784 |                      "bucket (%d)." % (len(targets), buckets[-1][1]))
785 |   if len(weights) < buckets[-1][1]:
786 |     raise ValueError("Length of weights (%d) must be at least that of last "
787 |                      "bucket (%d)." % (len(weights), buckets[-1][1]))
788 |
789 |   all_inputs = encoder_inputs + decoder_inputs + targets + weights
790 |   losses = []
791 |   outputs = []
792 |   KL_divergences = []
793 |   with ops.name_scope(name, "variational_autoencoder_with_buckets", all_inputs):
794 |     for j, bucket in enumerate(buckets):
795 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
796 |                                          reuse=True if j > 0 else None):
797 |         encoder_last_state = encoder(encoder_inputs[:bucket[0]])
798 |         mean, logvar = enc_latent(encoder_last_state)
799 |         if probabilistic:
800 |           latent_vector = sample(mean, logvar)
801 |         else:
802 |           latent_vector = mean
803 |         decoder_initial_state = latent_dec(latent_vector)
804 |         bucket_outputs, _ = decoder(decoder_initial_state, decoder_inputs[:bucket[1]])
805 |         outputs.append(bucket_outputs)
806 |         total_size = math_ops.add_n(weights[:bucket[1]])
807 |         total_size += 1e-12
808 |         KL_divergences.append(tf.reduce_mean(kl_f(mean, logvar) / total_size))
809 |         if per_example_loss:
810 |           losses.append(sequence_loss_by_example(
811 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
812 |               softmax_loss_function=softmax_loss_function))
813 |         else:
814 |           losses.append(sequence_loss(
815 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
816 |               softmax_loss_function=softmax_loss_function))
817 |
818 |   return outputs, losses, KL_divergences
819 |
820 |
821 | def variational_encoder_with_buckets(encoder_inputs, buckets, encoder,
822 |                                      enc_latent, softmax_loss_function=None,
823 |                                      per_example_loss=False, name=None):
824 |   """Create the encoder part of a variational model with support for bucketing.
825 |   """
826 |   if len(encoder_inputs) < buckets[-1][0]:
827 |     raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
828 |                      "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
829 |
830 |   all_inputs = encoder_inputs
831 |   means = []
832 |   logvars = []
833 |   with ops.name_scope(name, "variational_encoder_with_buckets", all_inputs):
834 |     for j, bucket in enumerate(buckets):
835 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
836 |                                          reuse=True if j > 0 else None):
837 |         encoder_last_state = encoder(encoder_inputs[:bucket[0]])
838 |         mean, logvar = enc_latent(encoder_last_state)
839 |         means.append(mean)
840 |         logvars.append(logvar)
841 |
842 |   return means, logvars
843 |
844 |
845 | def variational_decoder_with_buckets(means, logvars, decoder_inputs,
846 |                                      targets, weights,
847 |                                      buckets, decoder, latent_dec, sample,
848 |                                      softmax_loss_function=None,
849 |                                      per_example_loss=False, name=None):
850 |   """Create the decoder part of a variational model with support for bucketing.
851 |   """
852 |   if len(targets) < buckets[-1][1]:
853 |     raise ValueError("Length of targets (%d) must be at least that of last "
854 |                      "bucket (%d)." % (len(targets), buckets[-1][1]))
855 |   if len(weights) < buckets[-1][1]:
856 |     raise ValueError("Length of weights (%d) must be at least that of last "
857 |                      "bucket (%d)." % (len(weights), buckets[-1][1]))
858 |
859 |   all_inputs = decoder_inputs + targets + weights
860 |   losses = []
861 |   outputs = []
862 |   KL_objs = []
863 |   KL_costs = []
864 |   with ops.name_scope(name, "variational_decoder_with_buckets", all_inputs):
865 |     for j, bucket in enumerate(buckets):
866 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
867 |                                          reuse=True if j > 0 else None):
868 |
869 |         latent_vector, kl_obj, kl_cost = sample(means[j], logvars[j])
870 |         decoder_initial_state = latent_dec(latent_vector)
871 |
872 |         bucket_outputs, _ = decoder(decoder_initial_state, decoder_inputs[:bucket[1]])
873 |         outputs.append(bucket_outputs)
874 |         total_size = math_ops.add_n(weights[:bucket[1]])
875 |         total_size += 1e-12  # avoid division by zero for all-zero weights
876 |         KL_objs.append(tf.reduce_mean(kl_obj / total_size))
877 |         KL_costs.append(tf.reduce_mean(kl_cost / total_size))
878 |         if per_example_loss:
879 |           losses.append(sequence_loss_by_example(
880 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
881 |               softmax_loss_function=softmax_loss_function))
882 |         else:
883 |           losses.append(sequence_loss(
884 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
885 |               softmax_loss_function=softmax_loss_function))
886 |
887 |   return outputs, losses, KL_objs, KL_costs
888 |
889 |
890 | def variational_beam_decoder_with_buckets(means, logvars, decoder_inputs,
891 |                                           targets, weights,
892 |                                           buckets, decoder, latent_dec, kl_f, sample, iaf=False,
893 |                                           softmax_loss_function=None,
894 |                                           per_example_loss=False, name=None):
895 |   """Create the beam-search decoder part of a variational model with bucketing.
896 |   """
897 |   if len(targets) < buckets[-1][1]:
898 |     raise ValueError("Length of targets (%d) must be at least that of last "
899 |                      "bucket (%d)." % (len(targets), buckets[-1][1]))
900 |   if len(weights) < buckets[-1][1]:
901 |     raise ValueError("Length of weights (%d) must be at least that of last "
902 |                      "bucket (%d)." % (len(weights), buckets[-1][1]))
903 |
904 |   all_inputs = decoder_inputs + targets + weights
905 |   losses = []
906 |   outputs = []
907 |   beam_paths = []
908 |   beam_symbols = []
909 |   KL_objs, KL_costs = [], []
910 |   with ops.name_scope(name, "variational_beam_decoder_with_buckets", all_inputs):
911 |     for j, bucket in enumerate(buckets):
912 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
913 |                                          reuse=True if j > 0 else None):
914 |         latent_vector, kl_obj, kl_cost = sample(means[j], logvars[j])
915 |         decoder_initial_state = latent_dec(latent_vector)
916 |
917 |         bucket_outputs, _, beam_path, beam_symbol = decoder(decoder_initial_state, decoder_inputs[:bucket[1]])
918 |         outputs.append(bucket_outputs)
919 |         beam_paths.append(beam_path)
920 |         beam_symbols.append(beam_symbol)
921 |         total_size = math_ops.add_n(weights[:bucket[1]])
922 |         total_size += 1e-12
923 |         KL_objs.append(tf.reduce_mean(kl_obj / total_size)); KL_costs.append(tf.reduce_mean(kl_cost / total_size))
924 |         if per_example_loss:
925 |           losses.append(sequence_loss_by_example(
926 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
927 |               softmax_loss_function=softmax_loss_function))
928 |         else:
929 |           losses.append(sequence_loss(
930 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
931 |               softmax_loss_function=softmax_loss_function))
932 |
933 |   return outputs, losses, KL_objs, KL_costs
934 |
--------------------------------------------------------------------------------
/seq2seq_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Variational sequence-to-sequence model."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import random
23 |
24 | import numpy as np
25 | from six.moves import xrange  # pylint: disable=redefined-builtin
26 | import tensorflow as tf
27 |
28 | import utils.data_utils as data_utils
29 | import seq2seq
30 | from tensorflow.python.ops import variable_scope
31 |
32 | class Seq2SeqModel(object):
33 |   """Variational sequence-to-sequence model for multiple buckets.
34 |
35 |   This class implements a multi-layer recurrent neural network as encoder,
36 |   and an RNN decoder whose initial state is produced from a latent code,
37 |   as described in http://arxiv.org/abs/1511.06349 - please look there for
38 |   details, or into the seq2seq library for the complete model implementation.
39 |   This class also allows using GRU cells in addition to LSTM cells, and
40 |   sampled softmax to handle large output vocabulary size. A bi-directional
41 |   encoder can be enabled with the `bidirectional` flag, as in
42 |   http://arxiv.org/abs/1409.0473
43 |   and sampled softmax is described in Section 3 of the following paper.
44 |   http://arxiv.org/abs/1412.2007
45 |   """
46 |
47 |   def __init__(self,
48 |                source_vocab_size,
49 |                target_vocab_size,
50 |                buckets,
51 |                size,
52 |                num_layers,
53 |                latent_dim,
54 |                max_gradient_norm,
55 |                batch_size,
56 |                learning_rate,
57 |                kl_min=2,
58 |                word_dropout_keep_prob=1.0,
59 |                anneal=False,
60 |                kl_rate_rise_factor=None,
61 |                use_lstm=False,
62 |                num_samples=512,
63 |                optimizer=None,
64 |                activation=tf.nn.relu,
65 |                forward_only=False,
66 |                feed_previous=True,
67 |                bidirectional=False,
68 |                weight_initializer=None,
69 |                bias_initializer=None,
70 |                iaf=False,
71 |                dtype=tf.float32):
72 |     """Create the model.
73 |
74 |     Args:
75 |       source_vocab_size: size of the source vocabulary.
76 |       target_vocab_size: size of the target vocabulary.
77 |       buckets: a list of pairs (I, O), where I specifies maximum input length
78 |         that will be processed in that bucket, and O specifies maximum output
79 |         length. Training instances that have inputs longer than I or outputs
80 |         longer than O will be pushed to the next bucket and padded accordingly.
81 |         We assume that the list is sorted, e.g., [(2, 4), (8, 16)].
82 |       size: number of units in each layer of the model.
83 |       num_layers: number of layers in the model.
84 |       max_gradient_norm: gradients will be clipped to maximally this norm.
85 |       batch_size: the size of the batches used during training;
86 |         the model construction is independent of batch_size, so it can be
87 |         changed after initialization if this is convenient, e.g., for decoding.
88 |       learning_rate: learning rate to start with.
89 |       use_lstm: if true, we use LSTM cells instead of GRU cells.
90 | num_samples: number of samples for sampled softmax. 91 | forward_only: if set, we do not construct the backward pass in the model. 92 | dtype: the data type to use to store internal variables. 93 | """ 94 | self.source_vocab_size = source_vocab_size 95 | self.target_vocab_size = target_vocab_size 96 | self.latent_dim = latent_dim 97 | self.buckets = buckets 98 | self.batch_size = batch_size 99 | self.word_dropout_keep_prob = word_dropout_keep_prob 100 | self.kl_min = kl_min 101 | feed_previous = feed_previous or forward_only 102 | 103 | self.learning_rate = tf.Variable( 104 | float(learning_rate), trainable=False, dtype=dtype) 105 | 106 | self.enc_embedding = tf.get_variable("enc_embedding", [source_vocab_size, size], dtype=dtype, initializer=weight_initializer()) 107 | 108 | self.dec_embedding = tf.get_variable("dec_embedding", [target_vocab_size, size], dtype=dtype, initializer=weight_initializer()) 109 | 110 | self.kl_rate = tf.Variable( 111 | 0.0, trainable=False, dtype=dtype) 112 | self.new_kl_rate = tf.placeholder(tf.float32, shape=[], name="new_kl_rate") 113 | self.kl_rate_update = tf.assign(self.kl_rate, self.new_kl_rate) 114 | 115 | self.replace_input = tf.placeholder(tf.int32, shape=[None], name="replace_input") 116 | replace_input = tf.nn.embedding_lookup(self.dec_embedding, self.replace_input) 117 | 118 | self.global_step = tf.Variable(0, trainable=False) 119 | 120 | # If we use sampled softmax, we need an output projection. 121 | output_projection = None 122 | softmax_loss_function = None 123 | # Sampled softmax only makes sense if we sample less than vocabulary size. 124 | if num_samples > 0 and num_samples < self.target_vocab_size: 125 | w_t = tf.get_variable("proj_w", [self.target_vocab_size, size], dtype=dtype, initializer=weight_initializer()) 126 | w = tf.transpose(w_t) 127 | b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype, initializer=bias_initializer) 128 | output_projection = (w, b) 129 | 130 | def sampled_loss(inputs, labels): 131 | labels = tf.reshape(labels, [-1, 1]) 132 | # We need to compute the sampled_softmax_loss using 32bit floats to 133 | # avoid numerical instabilities. 134 | local_w_t = tf.cast(w_t, tf.float32) 135 | local_b = tf.cast(b, tf.float32) 136 | local_inputs = tf.cast(inputs, tf.float32) 137 | return tf.cast( 138 | tf.nn.sampled_softmax_loss(local_w_t, local_b, local_inputs, labels, 139 | num_samples, self.target_vocab_size), 140 | dtype) 141 | softmax_loss_function = sampled_loss 142 | # Create the internal multi-layer cell for our RNN. 
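    # NOTE: `use_lstm` (from config.json) swaps the GRU cell below for a
    # BasicLSTMCell; its (c, h) state tuple doubles the concatenated state
    # size that encoder_to_latent / latent_to_decoder project from and to.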
143 |     single_cell = tf.nn.rnn_cell.GRUCell(size)
144 |     if use_lstm:
145 |       single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
146 | 
147 |     cell = single_cell
148 |     if num_layers > 1: cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
149 |     def encoder_f(encoder_inputs):
150 |       return seq2seq.embedding_encoder(
151 |           encoder_inputs,
152 |           cell,
153 |           self.enc_embedding,
154 |           num_symbols=source_vocab_size,
155 |           embedding_size=size,
156 |           bidirectional=bidirectional,
157 |           weight_initializer=weight_initializer,
158 |           dtype=dtype)
159 | 
160 |     def decoder_f(encoder_state, decoder_inputs):
161 |       return seq2seq.embedding_rnn_decoder(
162 |           decoder_inputs,
163 |           encoder_state,
164 |           cell,
165 |           embedding=self.dec_embedding,
166 |           word_dropout_keep_prob=word_dropout_keep_prob,
167 |           replace_input=replace_input,
168 |           num_symbols=target_vocab_size,
169 |           embedding_size=size,
170 |           output_projection=output_projection,
171 |           feed_previous=feed_previous,
172 |           weight_initializer=weight_initializer)
173 | 
174 |     def enc_latent_f(encoder_state):
175 |       return seq2seq.encoder_to_latent(
176 |           encoder_state,
177 |           embedding_size=size,
178 |           latent_dim=latent_dim,
179 |           num_layers=num_layers,
180 |           activation=activation,
181 |           use_lstm=use_lstm,
182 |           enc_state_bidirectional=bidirectional,
183 |           dtype=dtype)
184 | 
185 |     def latent_dec_f(latent_vector):
186 |       return seq2seq.latent_to_decoder(latent_vector,
187 |           embedding_size=size,
188 |           latent_dim=latent_dim,
189 |           num_layers=num_layers,
190 |           activation=activation,
191 |           use_lstm=use_lstm,
192 |           dtype=dtype)
193 | 
194 | 
195 |     def sample_f(mean, logvar):
196 |       return seq2seq.sample(
197 |           mean,
198 |           logvar,
199 |           latent_dim,
200 |           iaf,
201 |           kl_min,
202 |           anneal,
203 |           self.kl_rate,
204 |           dtype)
205 | 
206 |     # The seq2seq function: we use embedding for the input and attention.
207 |     def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
208 |       return tf.nn.seq2seq.embedding_attention_seq2seq(
209 |           encoder_inputs,
210 |           decoder_inputs,
211 |           cell,
212 |           num_encoder_symbols=source_vocab_size,
213 |           num_decoder_symbols=target_vocab_size,
214 |           embedding_size=size,
215 |           output_projection=output_projection,
216 |           feed_previous=do_decode,
217 |           dtype=dtype)
218 | 
219 | 
220 |     # Feeds for inputs.
221 |     self.encoder_inputs = []
222 |     self.decoder_inputs = []
223 |     self.target_weights = []
224 |     for i in xrange(buckets[-1][0]):  # Last bucket is the biggest one.
225 |       self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
226 |                                                 name="encoder{0}".format(i)))
227 |     for i in xrange(buckets[-1][1] + 1):
228 |       self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
229 |                                                 name="decoder{0}".format(i)))
230 |       self.target_weights.append(tf.placeholder(dtype, shape=[None],
231 |                                                 name="weight{0}".format(i)))
232 | 
233 |     # Our targets are decoder inputs shifted by one.
234 |     targets = [self.decoder_inputs[i + 1]
235 |                for i in xrange(len(self.decoder_inputs) - 1)]
236 | 
237 | 
238 |     self.means, self.logvars = seq2seq.variational_encoder_with_buckets(
239 |         self.encoder_inputs, buckets, encoder_f, enc_latent_f,
240 |         softmax_loss_function=softmax_loss_function)
241 |     self.outputs, self.losses, self.KL_objs, self.KL_costs = seq2seq.variational_decoder_with_buckets(
242 |         self.means, self.logvars, self.decoder_inputs, targets,
243 |         self.target_weights, buckets, decoder_f, latent_dec_f,
244 |         sample_f, softmax_loss_function=softmax_loss_function)
245 | 
246 |     # If we use output projection, we need to project outputs for decoding.
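    # With sampled softmax the decoder RNN emits size-dimensional vectors;
    # the projection below turns each one into full-vocabulary logits, e.g.
    # a [batch_size, size] output becomes [batch_size, target_vocab_size]
    # via tf.matmul(output, w) + b, where (w, b) = output_projection.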
247 | if output_projection is not None: 248 | for b in xrange(len(buckets)): 249 | self.outputs[b] = [ 250 | tf.matmul(output, output_projection[0]) + output_projection[1] 251 | for output in self.outputs[b] 252 | ] 253 | # Gradients and SGD update operation for training the model. 254 | params = tf.trainable_variables() 255 | if not forward_only: 256 | self.gradient_norms = [] 257 | self.updates = [] 258 | for b in xrange(len(buckets)): 259 | total_loss = self.losses[b] + self.KL_objs[b] 260 | gradients = tf.gradients(total_loss, params) 261 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, 262 | max_gradient_norm) 263 | self.gradient_norms.append(norm) 264 | self.updates.append(optimizer.apply_gradients( 265 | zip(clipped_gradients, params), global_step=self.global_step)) 266 | 267 | self.saver = tf.train.Saver(tf.global_variables()) 268 | 269 | 270 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 271 | bucket_id, forward_only, prob, beam_size=1): 272 | """Run a step of the model feeding the given inputs. 273 | 274 | Args: 275 | session: tensorflow session to use. 276 | encoder_inputs: list of numpy int vectors to feed as encoder inputs. 277 | decoder_inputs: list of numpy int vectors to feed as decoder inputs. 278 | target_weights: list of numpy float vectors to feed as target weights. 279 | bucket_id: which bucket of the model to use. 280 | forward_only: whether to do the backward step or only forward. 281 | 282 | Returns: 283 | A triple consisting of gradient norm (or None if we did not do backward), 284 | average perplexity, and the outputs. 285 | 286 | Raises: 287 | ValueError: if length of encoder_inputs, decoder_inputs, or 288 | target_weights disagrees with bucket size for the specified bucket_id. 289 | """ 290 | # Check if the sizes match. 291 | encoder_size, decoder_size = self.buckets[bucket_id] 292 | if len(encoder_inputs) != encoder_size: 293 | raise ValueError("Encoder length must be equal to the one in bucket," 294 | " %d != %d." % (len(encoder_inputs), encoder_size)) 295 | if len(decoder_inputs) != decoder_size: 296 | raise ValueError("Decoder length must be equal to the one in bucket," 297 | " %d != %d." % (len(decoder_inputs), decoder_size)) 298 | if len(target_weights) != decoder_size: 299 | raise ValueError("Weights length must be equal to the one in bucket," 300 | " %d != %d." % (len(target_weights), decoder_size)) 301 | 302 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 303 | input_feed = {} 304 | for l in xrange(encoder_size): 305 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 306 | for l in xrange(decoder_size): 307 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 308 | input_feed[self.target_weights[l].name] = target_weights[l] 309 | if self.word_dropout_keep_prob < 1: 310 | input_feed[self.replace_input.name] = np.full((self.batch_size), data_utils.UNK_ID, dtype=np.int32) 311 | 312 | # Since our targets are decoder inputs shifted by one, we need one more. 313 | last_target = self.decoder_inputs[decoder_size].name 314 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 315 | if not prob: 316 | input_feed[self.logvars[bucket_id]] = np.full((self.batch_size, self.latent_dim), -800.0, dtype=np.float32) 317 | 318 | # Output feed: depends on whether we do a backward step or not. 319 | if not forward_only: 320 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 321 | self.gradient_norms[bucket_id], # Gradient norm. 
322 | self.losses[bucket_id], 323 | self.KL_costs[bucket_id]] # Loss for this batch. 324 | else: 325 | output_feed = [self.losses[bucket_id], self.KL_costs[bucket_id]] # Loss for this batch. 326 | for l in xrange(decoder_size): # Output logits. 327 | output_feed.append(self.outputs[bucket_id][l]) 328 | 329 | outputs = session.run(output_feed, input_feed) 330 | if not forward_only: 331 | return outputs[1], outputs[2], outputs[3], None # Gradient norm, loss, KL divergence, no outputs. 332 | else: 333 | return None, outputs[0], outputs[1], outputs[2:] # no gradient norm, loss, KL divergence, outputs. 334 | 335 | 336 | def encode_to_latent(self, session, encoder_inputs, bucket_id): 337 | 338 | # Check if the sizes match. 339 | encoder_size, _ = self.buckets[bucket_id] 340 | if len(encoder_inputs) != encoder_size: 341 | raise ValueError("Encoder length must be equal to the one in bucket," 342 | " %d != %d." % (len(encoder_inputs), encoder_size)) 343 | 344 | input_feed = {} 345 | for l in xrange(encoder_size): 346 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 347 | 348 | 349 | output_feed = [self.means[bucket_id], self.logvars[bucket_id]] 350 | means, logvars = session.run(output_feed, input_feed) 351 | 352 | return means, logvars 353 | 354 | 355 | def decode_from_latent(self, session, means, logvars, bucket_id, decoder_inputs, target_weights): 356 | 357 | _, decoder_size = self.buckets[bucket_id] 358 | # Input feed: means. 359 | input_feed = {self.means[bucket_id]: means} 360 | input_feed[self.logvars[bucket_id]] = logvars 361 | 362 | for l in xrange(decoder_size): 363 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 364 | input_feed[self.target_weights[l].name] = target_weights[l] 365 | if self.word_dropout_keep_prob < 1: 366 | input_feed[self.replace_input.name] = np.full((self.batch_size), data_utils.UNK_ID, dtype=np.int32) 367 | 368 | last_target = self.decoder_inputs[decoder_size].name 369 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 370 | output_feed = [] 371 | for l in xrange(decoder_size): # Output logits. 372 | output_feed.append(self.outputs[bucket_id][l]) 373 | 374 | outputs = session.run(output_feed, input_feed) 375 | 376 | return outputs 377 | 378 | def get_batch(self, data, bucket_id): 379 | """Get a random batch of data from the specified bucket, prepare for step. 380 | 381 | To feed data in step(..) it must be a list of batch-major vectors, while 382 | data here contains single length-major cases. So the main logic of this 383 | function is to re-index data cases to be in the proper format for feeding. 384 | 385 | Args: 386 | data: a tuple of size len(self.buckets) in which each element contains 387 | lists of pairs of input and output data that we use to create a batch. 388 | bucket_id: integer, which bucket to get the batch for. 389 | 390 | Returns: 391 | The triple (encoder_inputs, decoder_inputs, target_weights) for 392 | the constructed batch that has the proper format to call step(...) later. 393 | """ 394 | encoder_size, decoder_size = self.buckets[bucket_id] 395 | encoder_inputs, decoder_inputs = [], [] 396 | 397 | # Get a random batch of encoder and decoder inputs from data, 398 | # pad them if needed, reverse encoder inputs and add GO to decoder. 399 | for _ in xrange(self.batch_size): 400 | encoder_input, decoder_input = random.choice(data[bucket_id]) 401 | 402 | # Encoder inputs are padded and then reversed. 
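      # Illustration: with encoder_size=5 and encoder_input=[4, 7, 9],
      # padding gives [4, 7, 9, PAD_ID, PAD_ID] and reversing yields
      # [PAD_ID, PAD_ID, 9, 7, 4] (PAD_ID is 0 in data_utils).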
403 | encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input)) 404 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) 405 | 406 | # Decoder inputs get an extra "GO" symbol, and are padded then. 407 | decoder_pad_size = decoder_size - len(decoder_input) - 1 408 | decoder_inputs.append([data_utils.GO_ID] + decoder_input + 409 | [data_utils.PAD_ID] * decoder_pad_size) 410 | 411 | # Now we create batch-major vectors from the data selected above. 412 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] 413 | 414 | # Batch encoder inputs are just re-indexed encoder_inputs. 415 | for length_idx in xrange(encoder_size): 416 | batch_encoder_inputs.append( 417 | np.array([encoder_inputs[batch_idx][length_idx] 418 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 419 | 420 | # Batch decoder inputs are re-indexed decoder_inputs, we create weights. 421 | for length_idx in xrange(decoder_size): 422 | batch_decoder_inputs.append( 423 | np.array([decoder_inputs[batch_idx][length_idx] 424 | for batch_idx in xrange(self.batch_size)], dtype=np.int32)) 425 | 426 | # Create target_weights to be 0 for targets that are padding. 427 | batch_weight = np.ones(self.batch_size, dtype=np.float32) 428 | for batch_idx in xrange(self.batch_size): 429 | # We set weight to 0 if the corresponding target is a PAD symbol. 430 | # The corresponding target is decoder_input shifted by 1 forward. 431 | if length_idx < decoder_size - 1: 432 | target = decoder_inputs[batch_idx][length_idx + 1] 433 | if length_idx == decoder_size - 1 or target == data_utils.PAD_ID: 434 | batch_weight[batch_idx] = 0.0 435 | batch_weights.append(batch_weight) 436 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights 437 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for downloading data from WMT, tokenizing, vocabularies.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import gzip 22 | import os 23 | import re 24 | import tarfile 25 | 26 | from six.moves import urllib 27 | 28 | from tensorflow.python.platform import gfile 29 | import tensorflow as tf 30 | 31 | # Special vocabulary symbols - we always put them at the start. 32 | _PAD = "_PAD" 33 | _GO = "_GO" 34 | _EOS = "_EOS" 35 | _UNK = "_UNK" 36 | _START_VOCAB = [_PAD, _GO, _EOS, _UNK] 37 | 38 | PAD_ID = 0 39 | GO_ID = 1 40 | EOS_ID = 2 41 | UNK_ID = 3 42 | 43 | # Regular expressions used to tokenize. 44 | _WORD_SPLIT = re.compile("([.,!?\"':;)(])") 45 | _DIGIT_RE = re.compile(r"\d") 46 | 47 | # URLs for WMT data. 
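# Note: these URLs are only consumed by maybe_download below; prepare_wmt_data
# in this repository reads train.txt/dev.txt directly from data_dir rather than
# downloading the WMT corpus.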
48 | _WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
49 | _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
50 | 
51 | 
52 | def maybe_download(directory, filename, url):
53 |   """Download filename from url unless it's already in directory."""
54 |   if not os.path.exists(directory):
55 |     print("Creating directory %s" % directory)
56 |     os.mkdir(directory)
57 |   filepath = os.path.join(directory, filename)
58 |   if not os.path.exists(filepath):
59 |     print("Downloading %s to %s" % (url, filepath))
60 |     filepath, _ = urllib.request.urlretrieve(url, filepath)
61 |     statinfo = os.stat(filepath)
62 |     print("Successfully downloaded", filename, statinfo.st_size, "bytes")
63 |   return filepath
64 | 
65 | 
66 | def gunzip_file(gz_path, new_path):
67 |   """Unzips from gz_path into new_path."""
68 |   print("Unpacking %s to %s" % (gz_path, new_path))
69 |   with gzip.open(gz_path, "rb") as gz_file:
70 |     with open(new_path, "wb") as new_file:
71 |       for line in gz_file:
72 |         new_file.write(line)
73 | 
74 | 
75 | def basic_tokenizer(sentence):
76 |   """Very basic tokenizer: split the sentence into a list of tokens."""
77 |   words = []
78 |   for space_separated_fragment in sentence.strip().split():
79 |     words.extend(_WORD_SPLIT.split(space_separated_fragment))
80 |   return [w for w in words if w]
81 | 
82 | 
83 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, embedding_path,
84 |                       tokenizer=None, normalize_digits=True):
85 |   """Create vocabulary file (if it does not exist yet) from data file.
86 | 
87 |   Data file is assumed to contain one sentence per line. Each sentence is
88 |   tokenized and digits are normalized (if normalize_digits is set).
89 |   Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
90 |   We write it to vocabulary_path in a one-token-per-line format, so that the
91 |   token in the first line gets id=0, the token in the second line gets id=1, and so on.
92 | 
93 |   Args:
94 |     vocabulary_path: path where the vocabulary will be created.
95 |     data_path: data file that will be used to create vocabulary.
96 |     max_vocabulary_size: limit on the size of the created vocabulary.
97 |     tokenizer: a function to use to tokenize each data sentence;
98 |       if None, basic_tokenizer will be used.
99 |     normalize_digits: Boolean; if true, all digits are replaced by 0s.
100 |   """
101 |   if not gfile.Exists(vocabulary_path) or not gfile.Exists(embedding_path):
102 |     print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
103 |     print("Creating embedding file %s from data %s" % (embedding_path, data_path))
104 |     vocab = {}
105 |     with gfile.GFile(data_path, mode="r") as f:
106 |       counter = 0
107 |       for line in f:
108 |         counter += 1
109 |         if counter % 100000 == 0:
110 |           print("  processing line %d" % counter)
111 |         tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
112 |         for w in tokens:
113 |           word = _DIGIT_RE.sub("0", w) if normalize_digits else w
114 |           if word in vocab:
115 |             vocab[word] += 1
116 |           else:
117 |             vocab[word] = 1
118 |       vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
119 |       if len(vocab_list) > max_vocabulary_size:
120 |         vocab_list = vocab_list[:max_vocabulary_size]
121 |       with gfile.GFile(vocabulary_path, mode="w") as vocab_file:
122 |         with gfile.GFile(embedding_path, mode="w") as embedding_file:
123 |           for w in vocab_list:
124 |             vocab_file.write(w + "\n")
125 |             embedding_file.write(w + "\n")
126 | 
127 | 
128 | def initialize_vocabulary(vocabulary_path):
129 |   """Initialize vocabulary from file.
130 | 
131 |   We assume the vocabulary is stored one-item-per-line, so a file:
132 |     dog
133 |     cat
134 |   will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
135 |   also return the reversed-vocabulary ["dog", "cat"].
136 | 
137 |   Args:
138 |     vocabulary_path: path to the file containing the vocabulary.
139 | 
140 |   Returns:
141 |     a pair: the vocabulary (a dictionary mapping string to integers), and
142 |     the reversed vocabulary (a list, which reverses the vocabulary mapping).
143 | 
144 |   Raises:
145 |     ValueError: if the provided vocabulary_path does not exist.
146 |   """
147 |   if gfile.Exists(vocabulary_path):
148 |     rev_vocab = []
149 |     with gfile.GFile(vocabulary_path, mode="r") as f:
150 |       rev_vocab.extend(f.readlines())
151 |     rev_vocab = [line.strip() for line in rev_vocab]
152 |     vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
153 |     return vocab, rev_vocab
154 |   else:
155 |     raise ValueError("Vocabulary file %s not found." % vocabulary_path)
156 | 
157 | 
158 | def sentence_to_token_ids(sentence, vocabulary,
159 |                           tokenizer=None, normalize_digits=True):
160 |   """Convert a string to list of integers representing token-ids.
161 | 
162 |   For example, a sentence "I have a dog" may become tokenized into
163 |   ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
164 |   "a": 4, "dog": 7} this function will return [1, 2, 4, 7].
165 | 
166 |   Args:
167 |     sentence: the sentence in string format to convert to token-ids.
168 |     vocabulary: a dictionary mapping tokens to integers.
169 |     tokenizer: a function to use to tokenize each sentence;
170 |       if None, basic_tokenizer will be used.
171 |     normalize_digits: Boolean; if true, all digits are replaced by 0s.
172 | 
173 |   Returns:
174 |     a list of integers, the token-ids for the sentence.
175 |   """
176 | 
177 |   if tokenizer:
178 |     words = tokenizer(sentence)
179 |   else:
180 |     words = basic_tokenizer(sentence)
181 |   if not normalize_digits:
182 |     return [vocabulary.get(w, UNK_ID) for w in words]
183 |   # Normalize digits by 0 before looking words up in the vocabulary.
184 |   return [vocabulary.get(_DIGIT_RE.sub("0", w), UNK_ID) for w in words]
185 | 
186 | 
187 | def data_to_token_ids(data_path, target_path, vocabulary_path,
188 |                       tokenizer=None, normalize_digits=True):
189 |   """Tokenize data file and turn into token-ids using given vocabulary file.
190 | 
191 |   This function loads data line-by-line from data_path, calls the above
192 |   sentence_to_token_ids, and saves the result to target_path. See comment
193 |   for sentence_to_token_ids on the details of token-ids format.
194 | 
195 |   Args:
196 |     data_path: path to the data file in one-sentence-per-line format.
197 |     target_path: path where the file with token-ids will be created.
198 |     vocabulary_path: path to the vocabulary file.
199 |     tokenizer: a function to use to tokenize each sentence;
200 |       if None, basic_tokenizer will be used.
201 |     normalize_digits: Boolean; if true, all digits are replaced by 0s.
202 | """ 203 | if not gfile.Exists(target_path): 204 | print("Tokenizing data in %s" % data_path) 205 | vocab, _ = initialize_vocabulary(vocabulary_path) 206 | with gfile.GFile(data_path, mode="r") as data_file: 207 | with gfile.GFile(target_path, mode="w") as tokens_file: 208 | counter = 0 209 | for line in data_file: 210 | counter += 1 211 | if counter % 100000 == 0: 212 | print(" tokenizing line %d" % counter) 213 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, 214 | normalize_digits) 215 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 216 | 217 | 218 | def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size, 219 | load_embeddings=False, tokenizer=None): 220 | """Get WMT data into data_dir, create vocabularies and tokenize data. 221 | 222 | Args: 223 | data_dir: directory in which the data sets will be stored. 224 | en_vocabulary_size: size of the English vocabulary to create and use. 225 | fr_vocabulary_size: size of the French vocabulary to create and use. 226 | tokenizer: a function to use to tokenize each data sentence; 227 | if None, basic_tokenizer will be used. 228 | 229 | Returns: 230 | A tuple of 6 elements: 231 | (1) path to the token-ids for English training data-set, 232 | (2) path to the token-ids for French training data-set, 233 | (3) path to the token-ids for English development data-set, 234 | (4) path to the token-ids for French development data-set, 235 | (5) path to the English vocabulary file, 236 | (6) path to the French vocabulary file. 237 | """ 238 | # Get wmt data to the specified directory. 239 | train_path = os.path.join(data_dir, "train.txt") 240 | dev_path = os.path.join(data_dir, "dev.txt") 241 | 242 | # Create vocabularies of the appropriate sizes. 243 | fr_vocab_path = os.path.join(data_dir, "vocab%d.out" % fr_vocabulary_size) 244 | en_vocab_path = os.path.join(data_dir, "vocab%d.in" % en_vocabulary_size) 245 | create_vocabulary(fr_vocab_path, train_path + ".out", fr_vocabulary_size, 246 | os.path.join(data_dir, "dec_embedding{0}.tsv".format(fr_vocabulary_size)), 247 | tokenizer) 248 | create_vocabulary(en_vocab_path, train_path + ".in", en_vocabulary_size, 249 | os.path.join(data_dir, "enc_embedding{0}.tsv".format(en_vocabulary_size)), 250 | tokenizer) 251 | #if load_embeddings: 252 | # embed_utils.save_embeddings(fr_vocab_path, "embed5000.txt") 253 | # embed_utils.save_embeddings(en_vocab_path, "embed5000.txt") 254 | 255 | 256 | # Create token ids for the training data. 257 | fr_train_ids_path = train_path + (".ids%d.out" % fr_vocabulary_size) 258 | en_train_ids_path = train_path + (".ids%d.in" % en_vocabulary_size) 259 | data_to_token_ids(train_path + ".out", fr_train_ids_path, fr_vocab_path, tokenizer) 260 | data_to_token_ids(train_path + ".in", en_train_ids_path, en_vocab_path, tokenizer) 261 | 262 | # Create token ids for the development data. 
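  # For example, with fr_vocabulary_size=20000 this writes dev.txt.ids20000.out
  # next to dev.txt, mirroring the train.txt.ids20000.* files created above.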
263 |   fr_dev_ids_path = dev_path + (".ids%d.out" % fr_vocabulary_size)
264 |   en_dev_ids_path = dev_path + (".ids%d.in" % en_vocabulary_size)
265 |   data_to_token_ids(dev_path + ".out", fr_dev_ids_path, fr_vocab_path, tokenizer)
266 |   data_to_token_ids(dev_path + ".in", en_dev_ids_path, en_vocab_path, tokenizer)
267 | 
268 |   return (en_train_ids_path, fr_train_ids_path,
269 |           en_dev_ids_path, fr_dev_ids_path,
270 |           en_vocab_path, fr_vocab_path)
271 | --------------------------------------------------------------------------------
/utils/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | 
5 | def gaussian_diag_logps(mean, logvar, sample=None):
6 |     if sample is None:
7 |         noise = tf.random_normal(tf.shape(mean))
8 |         sample = mean + tf.exp(0.5 * logvar) * noise
9 | 
10 |     return -0.5 * (np.log(2 * np.pi) + logvar + tf.square(sample - mean) / tf.exp(logvar))
11 | 
12 | 
13 | class DiagonalGaussian(object):
14 | 
15 |     def __init__(self, mean, logvar, sample=None):
16 |         self.mean = mean
17 |         self.logvar = logvar
18 | 
19 |         if sample is None:
20 |             noise = tf.random_normal(tf.shape(mean))
21 |             sample = mean + tf.exp(0.5 * logvar) * noise
22 |         self.sample = sample
23 | 
24 |     def logps(self, sample):
25 |         return gaussian_diag_logps(self.mean, self.logvar, sample)
26 | 
27 | 
28 | def discretized_logistic(mean, logscale, binsize=1 / 256.0, sample=None):
29 |     scale = tf.exp(logscale)
30 |     sample = (tf.floor(sample / binsize) * binsize - mean) / scale
31 |     logp = tf.log(tf.sigmoid(sample + binsize / scale) - tf.sigmoid(sample) + 1e-7)
32 |     return tf.reduce_sum(logp, [1, 2, 3])
33 | 
34 | 
35 | def logsumexp(x):
36 |     x_max = tf.reduce_max(x, [1], keep_dims=True)
37 |     return tf.reshape(x_max, [-1]) + tf.log(tf.reduce_sum(tf.exp(x - x_max), [1]))
38 | 
39 | 
40 | def repeat(x, n):
41 |     if n == 1:
42 |         return x
43 | 
44 |     shape = list(map(int, x.get_shape().as_list()))
45 |     shape[0] *= n
46 |     idx = tf.range(tf.shape(x)[0])
47 |     idx = tf.reshape(idx, [-1, 1])
48 |     idx = tf.tile(idx, [1, n])
49 |     idx = tf.reshape(idx, [-1])
50 |     x = tf.gather(x, idx)
51 |     x.set_shape(shape)
52 |     return x
53 | 
54 | 
55 | def compute_lowerbound(log_pxz, sum_kl_costs, k=1):
56 |     if k == 1:
57 |         return sum_kl_costs - log_pxz
58 | 
59 |     # log 1/k \sum p(x | z) * p(z) / q(z | x) = -log(k) + logsumexp(log p(x|z) + log p(z) - log q(z|x))
60 |     log_pxz = tf.reshape(log_pxz, [-1, k])
61 |     sum_kl_costs = tf.reshape(sum_kl_costs, [-1, k])
62 |     return - (- tf.log(float(k)) + logsumexp(log_pxz - sum_kl_costs))
63 | --------------------------------------------------------------------------------
/vrae.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Binary for training variational recurrent autoencoders and decoding from them.
17 | 
18 | Running this program with --do train reads the corpus under the data_dir set
19 | in config.json, tokenizes it in a very basic way, and then starts training a
20 | model, saving checkpoints to --model_dir.
21 | 
22 | Running with --do reconstruct, sample, or interpolate restores the latest
23 | checkpoint and decodes the sentences in --input into --output.
24 | 
25 | See the following papers for more information on neural translation models.
26 |  * http://arxiv.org/abs/1409.3215
27 |  * http://arxiv.org/abs/1409.0473
28 |  * http://arxiv.org/abs/1412.2007
29 | """
30 | from __future__ import absolute_import
31 | from __future__ import division
32 | from __future__ import print_function
33 | 
34 | import math
35 | import os
36 | import sys
37 | import time
38 | import logging
39 | import json
40 | 
41 | import numpy as np
42 | from six.moves import xrange  # pylint: disable=redefined-builtin
43 | import tensorflow as tf
44 | 
45 | import utils.data_utils as data_utils
46 | import seq2seq_model
47 | from tensorflow.python.platform import gfile
48 | 
49 | tf.app.flags.DEFINE_string("model_dir", "models", "directory of the model.")
50 | tf.app.flags.DEFINE_boolean("new", True, "whether this is a new model or not.")
51 | tf.app.flags.DEFINE_string("do", "train", "what to do. accepts train, interpolate, sample, and reconstruct.")
52 | tf.app.flags.DEFINE_string("input", None, "input filename for reconstruct, sample, and interpolate.")
53 | tf.app.flags.DEFINE_string("output", None, "output filename for reconstruct, sample, and interpolate.")
54 | 
55 | FLAGS = tf.app.flags.FLAGS
56 | 
57 | def prelu(x):
58 |   with tf.variable_scope("prelu") as scope:
59 |     alphas = tf.get_variable("alphas", [], initializer=tf.constant_initializer(0.0), dtype=tf.float32)
60 |     return tf.nn.relu(x) - tf.mul(alphas, tf.nn.relu(-x))
61 | 
62 | 
63 | # We use a number of buckets and pad to the closest one for efficiency.
64 | # See seq2seq_model.Seq2SeqModel for details of how they work.
65 | 
66 | 
67 | def read_data(source_path, target_path, config, max_size=None):
68 |   """Read data from source and target files and put into buckets.
69 | 
70 |   Args:
71 |     source_path: path to the files with token-ids for the source language.
72 |     target_path: path to the file with token-ids for the target language;
73 |       it must be aligned with the source file: n-th line contains the desired
74 |       output for n-th line from the source_path.
75 |     max_size: maximum number of lines to read, all others will be ignored;
76 |       if 0 or None, data files will be read completely (no limit).
77 | 
78 |   Returns:
79 |     data_set: a list of length len(config.buckets); data_set[n] contains a list of
80 |       (source, target) pairs read from the provided data files that fit
81 |       into the n-th bucket, i.e., such that len(source) < config.buckets[n][0] and
82 |       len(target) < config.buckets[n][1]; source and target are lists of token-ids.
83 |   """
84 |   data_set = [[] for _ in config.buckets]
85 |   with tf.gfile.GFile(source_path, mode="r") as source_file:
86 |     with tf.gfile.GFile(target_path, mode="r") as target_file:
87 |       source, target = source_file.readline(), target_file.readline()
88 |       counter = 0
89 |       while source and target and (not max_size or counter < max_size):
90 |         counter += 1
91 |         if counter % 100000 == 0:
92 |           print("  reading data line %d" % counter)
93 |           sys.stdout.flush()
94 |         source_ids = [int(x) for x in source.split()]
95 |         target_ids = [int(x) for x in target.split()]
96 |         target_ids.append(data_utils.EOS_ID)
97 |         for bucket_id, (source_size, target_size) in enumerate(config.buckets):
98 |           if len(source_ids) < source_size and len(target_ids) < target_size:
99 |             data_set[bucket_id].append([source_ids, target_ids])
100 |             break
101 |         source, target = source_file.readline(), target_file.readline()
102 |   return data_set
103 | 
104 | 
105 | def create_model(session, config, forward_only):
106 |   """Create translation model and initialize or load parameters in session."""
107 |   dtype = tf.float32
108 |   optimizer = None
109 |   if not forward_only:
110 |     optimizer = tf.train.AdamOptimizer(config.learning_rate)
111 |   if config.activation == "elu":
112 |     activation = tf.nn.elu
113 |   elif config.activation == "prelu":
114 |     activation = prelu
115 |   else:
116 |     activation = tf.identity
117 | 
118 |   weight_initializer = tf.orthogonal_initializer if config.orthogonal_initializer else tf.uniform_unit_scaling_initializer
119 |   bias_initializer = tf.zeros_initializer
120 | 
121 |   model = seq2seq_model.Seq2SeqModel(
122 |       config.en_vocab_size,
123 |       config.fr_vocab_size,
124 |       config.buckets,
125 |       config.size,
126 |       config.num_layers,
127 |       config.latent_dim,
128 |       config.max_gradient_norm,
129 |       config.batch_size,
130 |       config.learning_rate,
131 |       config.kl_min,
132 |       config.word_dropout_keep_prob,
133 |       config.anneal,
134 |       use_lstm=config.use_lstm,
135 |       optimizer=optimizer,
136 |       activation=activation,
137 |       forward_only=forward_only,
138 |       feed_previous=config.feed_previous,
139 |       bidirectional=config.bidirectional,
140 |       weight_initializer=weight_initializer,
141 |       bias_initializer=bias_initializer,
142 |       iaf=config.iaf,
143 |       dtype=dtype)
144 |   ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
145 |   if not FLAGS.new and ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
146 |     print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
147 |     model.saver.restore(session, ckpt.model_checkpoint_path)
148 |   else:
149 |     print("Created model with fresh parameters.")
150 |     session.run(tf.global_variables_initializer())
151 |   return model
152 | 
153 | 
154 | def train(config):
155 |   """Train an en->fr translation model using WMT data."""
156 |   # Prepare WMT data.
157 |   print("Preparing WMT data in %s" % config.data_dir)
158 |   en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
159 |       config.data_dir, config.en_vocab_size, config.fr_vocab_size, config.load_embeddings)
160 | 
161 |   with tf.Session() as sess:
162 |     if not os.path.exists(FLAGS.model_dir):
163 |       os.makedirs(FLAGS.model_dir)
164 | 
165 |     # Create model.
166 |     print("Creating %d layers of %d units." % (config.num_layers, config.size))
167 |     model = create_model(sess, config, False)
168 | 
169 |     if not config.probabilistic:
170 |       sess.run(model.kl_rate_update, feed_dict={model.new_kl_rate: 0.0})
171 | 
172 |     train_writer = tf.summary.FileWriter(os.path.join(FLAGS.model_dir, "train"), graph=sess.graph)
173 |     dev_writer = tf.summary.FileWriter(os.path.join(FLAGS.model_dir, "test"), graph=sess.graph)
174 | 
175 |     # Read data into buckets and compute their sizes.
176 |     print ("Reading development and training data (limit: %d)."
177 |            % config.max_train_data_size)
178 | 
179 |     dev_set = read_data(en_dev, fr_dev, config)
180 |     train_set = read_data(en_train, fr_train, config, config.max_train_data_size)
181 |     train_bucket_sizes = [len(train_set[b]) for b in xrange(len(config.buckets))]
182 |     train_total_size = float(sum(train_bucket_sizes))
183 | 
184 |     # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
185 |     # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
186 |     # the size of the i-th training bucket, as used later.
187 |     train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
188 |                            for i in xrange(len(train_bucket_sizes))]
189 | 
190 |     # This is the training loop.
191 |     step_time, loss = 0.0, 0.0
192 |     KL_loss = 0.0
193 |     current_step = model.global_step.eval()
194 |     step_loss_summaries = []
195 |     step_KL_loss_summaries = []
196 |     overall_start_time = time.time()
197 |     while True:
198 |       # Choose a bucket according to data distribution. We pick a random number
199 |       # in [0, 1] and use the corresponding interval in train_buckets_scale.
200 |       random_number_01 = np.random.random_sample()
201 |       bucket_id = min([i for i in xrange(len(train_buckets_scale))
202 |                        if train_buckets_scale[i] > random_number_01])
203 | 
204 |       # Get a batch and make a step.
205 |       start_time = time.time()
206 |       encoder_inputs, decoder_inputs, target_weights = model.get_batch(
207 |           train_set, bucket_id)
208 |       _, step_loss, step_KL_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
209 |                                                  target_weights, bucket_id, False, config.probabilistic)
210 | 
211 |       if config.anneal and model.global_step.eval() > config.kl_rate_rise_time and model.kl_rate.eval() < 1:
212 |         new_kl_rate = model.kl_rate.eval() + config.kl_rate_rise_factor
213 |         sess.run(model.kl_rate_update, feed_dict={model.new_kl_rate: new_kl_rate})
214 | 
215 |       step_time += (time.time() - start_time) / config.steps_per_checkpoint
216 |       step_loss_summaries.append(tf.Summary(value=[tf.Summary.Value(tag="step loss", simple_value=float(step_loss))]))
217 |       step_KL_loss_summaries.append(tf.Summary(value=[tf.Summary.Value(tag="KL step loss", simple_value=float(step_KL_loss))]))
218 |       loss += step_loss / config.steps_per_checkpoint
219 |       KL_loss += step_KL_loss / config.steps_per_checkpoint
220 |       current_step = model.global_step.eval()
221 | 
222 |       # Once in a while, we save checkpoint, print statistics, and run evals.
223 |       if current_step % config.steps_per_checkpoint == 0:
224 |         # Print statistics for the previous epoch.
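        # Perplexity is exp(cross-entropy); `loss` was averaged over the last
        # steps_per_checkpoint steps above, so math.exp(loss) is the running
        # training perplexity, and the `< 300` guard only avoids float overflow.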
225 |         perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
226 |         print ("global step %d learning rate %.4f step-time %.2f perplexity "
227 |                "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
228 |                          step_time, perplexity))
229 | 
230 |         print ("global step %d learning rate %.4f step-time %.2f KL divergence "
231 |                "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
232 |                          step_time, KL_loss))
233 |         wall_time = time.time() - overall_start_time
234 |         print("time passed: {0}".format(wall_time))
235 | 
236 |         # Add perplexity, KL divergence to summary and stats.
237 |         perp_summary = tf.Summary(value=[tf.Summary.Value(tag="train perplexity", simple_value=perplexity)])
238 |         train_writer.add_summary(perp_summary, current_step)
239 |         KL_loss_summary = tf.Summary(value=[tf.Summary.Value(tag="KL divergence", simple_value=KL_loss)])
240 |         train_writer.add_summary(KL_loss_summary, current_step)
241 |         for i, summary in enumerate(step_loss_summaries):
242 |           train_writer.add_summary(summary, current_step - config.steps_per_checkpoint + i)
243 |         step_loss_summaries = []
244 |         for i, summary in enumerate(step_KL_loss_summaries):
245 |           train_writer.add_summary(summary, current_step - config.steps_per_checkpoint + i)
246 |         step_KL_loss_summaries = []
247 | 
248 |         # Save checkpoint and zero timer and loss.
249 |         checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name + ".ckpt")
250 |         model.saver.save(sess, checkpoint_path, global_step=model.global_step)
251 |         step_time, loss, KL_loss = 0.0, 0.0, 0.0
252 | 
253 |         # Run evals on development set and print their perplexity.
254 |         eval_losses = []
255 |         eval_KL_losses = []
256 |         eval_bucket_num = 0
257 |         for bucket_id in xrange(len(config.buckets)):
258 |           if len(dev_set[bucket_id]) == 0:
259 |             print("  eval: empty bucket %d" % (bucket_id))
260 |             continue
261 |           eval_bucket_num += 1
262 |           encoder_inputs, decoder_inputs, target_weights = model.get_batch(
263 |               dev_set, bucket_id)
264 |           _, eval_loss, eval_KL_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
265 |                                                      target_weights, bucket_id, True, config.probabilistic)
266 |           eval_losses.append(float(eval_loss))
267 |           eval_KL_losses.append(float(eval_KL_loss))
268 |           eval_ppx = math.exp(float(eval_loss)) if eval_loss < 300 else float(
269 |               "inf")
270 |           print("  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
271 | 
272 |           eval_perp_summary = tf.Summary(value=[tf.Summary.Value(tag="eval perplexity for bucket {0}".format(bucket_id), simple_value=eval_ppx)])
273 |           dev_writer.add_summary(eval_perp_summary, current_step)
274 | 
275 |         mean_eval_loss = sum(eval_losses) / float(eval_bucket_num)
276 |         mean_eval_KL_loss = sum(eval_KL_losses) / float(eval_bucket_num)
277 |         mean_eval_ppx = math.exp(float(mean_eval_loss))
278 |         print("  eval: mean perplexity {0}".format(mean_eval_ppx))
279 | 
280 |         eval_loss_summary = tf.Summary(value=[tf.Summary.Value(tag="mean eval perplexity", simple_value=float(mean_eval_ppx))])
281 |         dev_writer.add_summary(eval_loss_summary, current_step)
282 |         eval_KL_loss_summary = tf.Summary(value=[tf.Summary.Value(tag="mean eval KL divergence", simple_value=float(mean_eval_KL_loss))])
283 |         dev_writer.add_summary(eval_KL_loss_summary, current_step)
284 | 
285 | 
286 | def reconstruct(sess, model, config):
287 |   model.batch_size = 1  # We decode one sentence at a time.
288 |   model.probabilistic = config.probabilistic
289 |   beam_size = config.beam_size
290 | 
291 |   # Load vocabularies.
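  # Vocabulary files are the ones created by prepare_wmt_data next to the
  # corpus, e.g. with en_vocab_size=20000: <data_dir>/vocab20000.in for the
  # encoder side and <data_dir>/vocab20000.out for the decoder side.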
292 |   en_vocab_path = os.path.join(config.data_dir,
293 |                                "vocab%d.in" % config.en_vocab_size)
294 |   fr_vocab_path = os.path.join(config.data_dir,
295 |                                "vocab%d.out" % config.fr_vocab_size)
296 |   en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
297 |   _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
298 | 
299 |   # Decode from standard input.
300 |   outputs = []
301 |   with gfile.GFile(FLAGS.input, "r") as fs:
302 |     sentences = fs.readlines()
303 |   for i, sentence in enumerate(sentences):
304 |     # Get token-ids for the input sentence.
305 |     token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
306 |     # Which bucket does it belong to?
307 |     bucket_id = len(config.buckets) - 1
308 |     for i, bucket in enumerate(config.buckets):
309 |       if bucket[0] >= len(token_ids):
310 |         bucket_id = i
311 |         break
312 |     else:
313 |       logging.warning("Sentence truncated: %s", sentence)
314 | 
315 |     encoder_inputs, decoder_inputs, target_weights = model.get_batch(
316 |         {bucket_id: [(token_ids, [])]}, bucket_id)
317 | 
318 |     if beam_size > 1:
319 |       path, symbol, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
320 |                                                target_weights, bucket_id, True, config.probabilistic, beam_size)
321 | 
322 |       k = output_logits[0]
323 |       paths = []
324 |       for kk in range(beam_size):
325 |         paths.append([])
326 |       curr = list(range(beam_size))
327 |       num_steps = len(path)
328 |       for i in range(num_steps - 1, -1, -1):
329 |         for kk in range(beam_size):
330 |           paths[kk].append(symbol[i][curr[kk]])
331 |           curr[kk] = path[i][curr[kk]]
332 |       recos = set()
333 |       for kk in range(beam_size):
334 |         output = [int(logit) for logit in paths[kk][::-1]]
335 | 
336 |         if data_utils.EOS_ID in output:
337 |           output = output[:output.index(data_utils.EOS_ID)]
338 |         output = " ".join([rev_fr_vocab[word] for word in output]) + "\n"
339 |         outputs.append(output)
340 | 
341 |     else:
342 |       # Get output logits for the sentence.
343 |       _, _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
344 |                                           target_weights, bucket_id, True, config.probabilistic)
345 |       # This is a greedy decoder - outputs are just argmaxes of output_logits.
346 |       output = [int(np.argmax(logit, axis=1)) for logit in output_logits]
347 |       # If there is an EOS symbol in outputs, cut them at that point.
348 |       if data_utils.EOS_ID in output:
349 |         output = output[:output.index(data_utils.EOS_ID)]
350 |       output = " ".join([rev_fr_vocab[word] for word in output]) + "\n"
351 |       outputs.append(output)
352 |   with gfile.GFile(FLAGS.output, "w") as enc_dec_f:
353 |     for output in outputs:
354 |       enc_dec_f.write(output)
355 | 
356 | 
357 | def encode(sess, model, config, sentences):
358 |   # Load vocabularies.
359 |   en_vocab_path = os.path.join(config.data_dir,
360 |                                "vocab%d.in" % config.en_vocab_size)
361 |   fr_vocab_path = os.path.join(config.data_dir,
362 |                                "vocab%d.out" % config.fr_vocab_size)
363 |   en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
364 |   _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
365 | 
366 |   means = []
367 |   logvars = []
368 |   for i, sentence in enumerate(sentences):
369 |     # Get token-ids for the input sentence.
370 |     token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
371 |     # Which bucket does it belong to?
372 |     bucket_id = len(config.buckets) - 1
373 |     for i, bucket in enumerate(config.buckets):
374 |       if bucket[0] >= len(token_ids):
375 |         bucket_id = i
376 |         break
377 |     else:
378 |       logging.warning("Sentence truncated: %s", sentence)
379 | 
380 |     # Get a 1-element batch to feed the sentence to the model.
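    # get_batch expects {bucket_id: [(source_ids, target_ids)]}; with
    # model.batch_size == 1 (as set in reconstruct/encode_interpolate) the
    # "random" batch is just this sentence, padded and reversed for the encoder.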
381 |     encoder_inputs, _, _ = model.get_batch(
382 |         {bucket_id: [(token_ids, [])]}, bucket_id)
383 |     # Get output logits for the sentence.
384 |     mean, logvar = model.encode_to_latent(sess, encoder_inputs, bucket_id)
385 |     means.append(mean)
386 |     logvars.append(logvar)
387 | 
388 |   return means, logvars
389 | 
390 | 
391 | def decode(sess, model, config, means, logvars, bucket_id):
392 |   fr_vocab_path = os.path.join(config.data_dir,
393 |                                "vocab%d.out" % config.fr_vocab_size)
394 |   _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)
395 | 
396 |   _, decoder_inputs, target_weights = model.get_batch(
397 |       {bucket_id: [([], [])]}, bucket_id)
398 |   outputs = []
399 |   for mean, logvar in zip(means, logvars):
400 |     mean = mean.reshape(1, -1)
401 |     logvar = logvar.reshape(1, -1)
402 |     output_logits = model.decode_from_latent(sess, mean, logvar, bucket_id, decoder_inputs, target_weights)
403 |     output = [int(np.argmax(logit, axis=1)) for logit in output_logits]
404 |     # If there is an EOS symbol in outputs, cut them at that point.
405 |     if data_utils.EOS_ID in output:
406 |       output = output[:output.index(data_utils.EOS_ID)]
407 |     output = " ".join([rev_fr_vocab[word] for word in output]) + "\n"
408 |     outputs.append(output)
409 | 
410 |   return outputs
411 |   # Print out French sentence corresponding to outputs.
412 | 
413 | def n_sample(sess, model, config):
414 |   bucket_id = len(config.buckets) - 1
415 |   with gfile.GFile(FLAGS.input, "r") as fs:
416 |     sentences = fs.readlines()
417 |   mean, logvar = encode(sess, model, config, sentences)
418 |   mean = mean[0][0]
419 |   logvar = logvar[0][0]
420 |   means = [mean] * config.num_pts
421 |   neg_inf_logvar = np.full(logvar.shape, -800.0, dtype=np.float32)
422 |   logvars = [neg_inf_logvar] + [logvar] * (config.num_pts - 1)
423 |   outputs = decode(sess, model, config, means, logvars, bucket_id)
424 |   with gfile.GFile(FLAGS.output, "w") as sample_f:
425 |     for output in outputs:
426 |       sample_f.write(output)
427 | 
428 | 
429 | def interpolate(sess, model, config, means, logvars, num_pts):
430 |   if len(means) != 2:
431 |     raise ValueError("there should be two sentences when interpolating. "
432 |                      "number of sentences: %d." % len(means))
433 |   if num_pts < 3:
434 |     raise ValueError("there should be more than two points when interpolating. "
435 |                      "number of points: %d." % num_pts)
436 |   pts = []
437 |   for s, e in zip(means[0][0].tolist(), means[1][0].tolist()):
438 |     pts.append(np.linspace(s, e, num_pts))
439 | 
440 |   pts = np.array(pts)
441 |   pts = pts.T
442 |   pts = [np.array(pt) for pt in pts.tolist()]
443 |   bucket_id = len(config.buckets) - 1
444 |   logvars = [np.full(pt.shape, -800.0, dtype=np.float32) for pt in pts]
445 |   outputs = decode(sess, model, config, pts, logvars, bucket_id)
446 | 
447 |   return outputs
448 | 
449 | def encode_interpolate(sess, model, config):
450 |   with gfile.GFile(FLAGS.input, "r") as fs:
451 |     sentences = fs.readlines()
452 |   model.batch_size = 1
453 |   model.probabilistic = config.probabilistic
454 |   means, logvars = encode(sess, model, config, sentences)
455 |   outputs = interpolate(sess, model, config, means, logvars, config.num_pts)
456 |   with gfile.GFile(FLAGS.output, "w") as interp_f:
457 |     for output in outputs:
458 |       interp_f.write(output)
459 | 
460 | class Struct(object):
461 |   def __init__(self, **entries):
462 |     self.__dict__.update(entries)
463 |     if self.__dict__.get("kl_min") is None:
464 |       self.__dict__.update({ "kl_min": None })
465 |     if self.__dict__.get("max_gradient_norm") is None:
466 |       self.__dict__.update({ "max_gradient_norm": 5.0 })
467 |     if self.__dict__.get("load_embeddings") is None:
468 |       self.__dict__.update({ "load_embeddings": False })
469 |     if self.__dict__.get("batch_size") is None:
470 |       self.__dict__.update({ "batch_size": 1 })
471 |     if self.__dict__.get("learning_rate") is None:
472 |       self.__dict__.update({ "learning_rate": 0.001 })
473 |     if self.__dict__.get("anneal") is None:
474 |       self.__dict__.update({ "anneal": False })
475 |     if self.__dict__.get("beam_size") is None:
476 |       self.__dict__.update({ "beam_size": 1 })
477 |     if self.__dict__.get("beam_size") > 1:
478 |       raise NotImplementedError("Beam search is still under implementation.")
479 |   def update(self, **entries):
480 |     self.__dict__.update(entries)
481 | 
482 | 
483 | def main(_):
484 | 
485 |   with open(os.path.join(FLAGS.model_dir, "config.json")) as config_file:
486 |     configs = json.load(config_file)
487 | 
488 |   FLAGS.model_name = os.path.basename(os.path.normpath(FLAGS.model_dir))
489 |   behavior = ["train", "interpolate", "reconstruct", "sample"]
490 |   if FLAGS.do not in behavior:
491 |     raise ValueError("argument \"do\" is not one of the following: train, interpolate, reconstruct, or sample.")
492 | 
493 |   if FLAGS.do != "train":
494 |     FLAGS.new = False
495 | 
496 |   config = Struct(**configs["model"])
497 |   config.update(**configs[FLAGS.do])
498 |   interp_config = Struct(**configs["model"])
499 |   interp_config.update(**configs["interpolate"])
500 |   enc_dec_config = Struct(**configs["model"])
501 |   enc_dec_config.update(**configs["reconstruct"])
502 |   sample_config = Struct(**configs["model"])
503 |   sample_config.update(**configs["sample"])
504 | 
505 |   if FLAGS.do == "reconstruct":
506 |     with tf.Session() as sess:
507 |       model = create_model(sess, enc_dec_config, True)
508 |       reconstruct(sess, model, enc_dec_config)
509 |   elif FLAGS.do == "interpolate":
510 |     with tf.Session() as sess:
511 |       model = create_model(sess, interp_config, True)
512 |       encode_interpolate(sess, model, interp_config)
513 |   elif FLAGS.do == "sample":
514 |     with tf.Session() as sess:
515 |       model = create_model(sess, sample_config, True)
516 |       n_sample(sess, model, sample_config)
517 |   elif FLAGS.do == "train":
518 |     train(config)
519 | 
520 | if __name__ == "__main__":
521 |   tf.app.run()
522 | --------------------------------------------------------------------------------