├── .gitignore ├── LICENSE ├── README.md ├── TextCorrector.ipynb ├── correct_text.py ├── data_reader.py ├── dtc_lambda.py ├── preprocessors ├── __init__.py └── preprocess_movie_dialogs.py ├── seq2seq.py ├── text_corrector_data_readers.py └── text_corrector_models.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | .ipynb_checkpoints/ 3 | *.pyc 4 | *.swp 5 | .DS_Store 6 | *.zip 7 | upload_lambda_s3.sh 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Text Corrector 2 | 3 | Deep Text Corrector uses [TensorFlow](https://www.tensorflow.org/) to train sequence-to-sequence models that are capable of automatically correcting small grammatical errors in conversational written English (e.g. SMS messages). 
4 | It does this by taking English text samples that are known to be mostly grammatically correct and randomly introducing a handful of small grammatical errors (e.g. removing articles) to each sentence to produce input-output pairs (where the output is the original sample), which are then used to train a sequence-to-sequence model. 5 | 6 | See [this blog post](http://atpaino.com/2017/01/03/deep-text-correcter.html) for a more thorough write-up of this work. 7 | 8 | ## Motivation 9 | While context-sensitive spell-check systems are able to automatically correct a large number of input errors in instant messaging, email, and SMS messages, they are unable to correct even simple grammatical errors. 10 | For example, the message "I'm going to store" would be unaffected by typical autocorrection systems, when the user most likely intended to write "I'm going to _the_ store". 11 | These kinds of simple grammatical mistakes are common in so-called "learner English", and constructing systems capable of detecting and correcting these mistakes has been the subject of multiple [CoNLL shared tasks](http://www.aclweb.org/anthology/W14-1701.pdf). 12 | 13 | The goal of this project is to train sequence-to-sequence models that are capable of automatically correcting such errors. 14 | Specifically, the models are trained to provide a function mapping a potentially errant input sequence to a sequence with all (small) grammatical errors corrected. 15 | Given these models, it would be possible to construct tools to help correct these simple errors in written communications, such as emails, instant messaging, etc. 16 | 17 | ## Correcting Grammatical Errors with Deep Learning 18 | The basic idea behind this project is that we can generate large training datasets for the task of grammar correction by starting with grammatically correct samples and introducing small errors to produce input-output pairs, which can then be used to train a sequence-to-sequence model. 19 | The details of how we construct these datasets, train models using them, and produce predictions for this task are described below. 20 | 21 | ### Datasets 22 | To create a dataset for Deep Text Corrector models, we start with a large collection of mostly grammatically correct samples of conversational written English. 23 | The primary dataset considered in this project is the [Cornell Movie-Dialogs Corpus](http://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html), which contains over 300k lines from movie scripts. 24 | This was the largest collection of conversational written English I could find that was mostly grammatically correct. 25 | 26 | Given a sample of text like this, the next step is to generate input-output pairs to be used during training. 27 | This is done by: 28 | 1. Drawing a sample sentence from the dataset. 29 | 2. Setting the input sequence to this sentence after randomly applying certain perturbations. 30 | 3. Setting the output sequence to the unperturbed sentence. 31 | 32 | where the perturbations applied in step (2) are intended to introduce small grammatical errors which we would like the model to learn to correct. 33 | Thus far, these perturbations are limited to the following: 34 | - subtraction of articles (a, an, the) 35 | - subtraction of the second part of a verb contraction (e.g. "'ve", "'ll", "'s", "'m") 36 | - replacement of a few common homophones with one of their counterparts (e.g. replacing "their" with "there", "then" with "than") 37 | 38 | The rates with which these perturbations are introduced are loosely based on figures taken from the [CoNLL 2014 Shared Task on Grammatical Error Correction](http://www.aclweb.org/anthology/W14-1701.pdf). 39 | In this project, each perturbation is applied in 25% of cases where it could potentially be applied. 40 | 
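The perturbation step can be sketched roughly as follows. This is a minimal illustration with hypothetical names, not the project's actual `data_reader.py` logic:

```
import random

ARTICLES = {"a", "an", "the"}
CONTRACTION_PARTS = {"'ve", "'ll", "'s", "'m"}
HOMOPHONES = {"their": "there", "then": "than"}

def perturb(tokens, rate=0.25):
    """Randomly introduce small grammatical errors into a tokenized sentence."""
    output = []
    for token in tokens:
        if token in ARTICLES and random.random() < rate:
            continue  # subtract the article
        if token in CONTRACTION_PARTS and random.random() < rate:
            continue  # subtract the second part of the contraction
        if token in HOMOPHONES and random.random() < rate:
            output.append(HOMOPHONES[token])  # swap in the homophone
            continue
        output.append(token)
    return output

target = "i 'm going to the store".split()
source = perturb(target)  # (source, target) is one training pair
```

Because the perturbations are random, repeating this sampling over the same corpus yields distinct pairs, which is how the dataset is augmented for training (see below).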
replacing "their" with "there", "then" with "than") 37 | 38 | The rates with which these perturbations are introduced are loosely based on figures taken from the [CoNLL 2014 Shared Task on Grammatical Error Correction](http://www.aclweb.org/anthology/W14-1701.pdf). 39 | In this project, each perturbation is applied in 25% of cases where it could potentially be applied. 40 | 41 | ### Training 42 | To artificially increase the dataset when training a sequence model, we perform the sampling strategy described above multiple times to arrive at 2-3x the number of input-output pairs. 43 | Given this augmented dataset, training proceeds in a very similar manner to [TensorFlow's sequence-to-sequence tutorial](https://www.tensorflow.org/tutorials/seq2seq/). 44 | That is, we train a sequence-to-sequence model using LSTM encoders and decoders with an attention mechanism as described in [Bahdanau et al., 2014](http://arxiv.org/abs/1409.0473) using stochastic gradient descent. 45 | 46 | ### Decoding 47 | 48 | Instead of using the most probable decoding according to the seq2seq model, this project takes advantage of the unique structure of the problem to impose the prior that all tokens in a decoded sequence should either exist in the input sequence or belong to a set of "corrective" tokens. 49 | The "corrective" token set is constructed during training and contains all tokens seen in the target, but not the source, for at least one sample in the training set. 50 | The intuition here is that the errors seen during training involve the misuse of a relatively small vocabulary of common words (e.g. "the", "an", "their") and that the model should only be allowed to perform corrections in this domain. 51 | 52 | This prior is carried out through a modification to the seq2seq model's decoding loop in addition to a post-processing step that resolves out-of-vocabulary (OOV) tokens: 53 | 54 | **Biased Decoding** 55 | 56 | To restrict the decoding such that it only ever chooses tokens from the input sequence or corrective token set, this project applies a binary mask to the model's logits prior to extracting the prediction to be fed into the next time step. 57 | This mask is constructed such that `mask[i] == 1.0 if (i in input or corrective_tokens) else 0.0`. 58 | Since this mask is applited to the result of a softmax transormation (which guarantees all outputs are non-negative), we can be sure that only input or corrective tokens are ever selected. 59 | 60 | Note that this logic is not used during training, as this would only serve to eliminate potentially useful signal from the model. 61 | 62 | **Handling OOV Tokens** 63 | 64 | Since the decoding bias described above is applied within the truncated vocabulary used by the model, we will still see the unknown token in its output for any OOV tokens. 65 | The more generic problem of resolving these OOV tokens is non-trivial (e.g. see [Addressing the Rare Word Problem in NMT](https://arxiv.org/pdf/1410.8206v4.pdf)), but in this project we can again take advantage of its unique structure to create a fairly straightforward OOV token resolution scheme. 66 | That is, if we assume the sequence of OOV tokens in the input is equal to the sequence of OOV tokens in the output sequence, then we can trivially assign the appropriate token to each "unknown" token encountered int he decoding. 
67 | Empirically, and intuitively, this appears to be an appropriate assumption, as the relatively simple class of errors these models are being trained to address should never include mistakes that warrant the insertion or removal of a rare token. 68 | 69 | ## Experiments and Results 70 | 71 | Below are some anecdotal and aggregate results from experiments using the Deep Text Corrector model with the [Cornell Movie-Dialogs Corpus](http://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html). 72 | The dataset consists of 304,713 lines from movie scripts, of which 243,768 lines were used to train the model and 30,474 lines each were used for the validation and testing sets. 73 | The sets were selected such that no lines from the same movie were present in both the training and testing sets. 74 | 75 | The model being evaluated below is a sequence-to-sequence model, with attention, where the encoder and decoder were both 2-layer, 512 hidden unit LSTMs. 76 | The model was trained with a vocabulary of the 2k most common words seen in the training set. 77 | 78 | ### Aggregate Performance 79 | Below are reported the BLEU scores and accuracy numbers over the test dataset for both a trained model and a baseline, where the baseline is the identity function (which assumes no errors exist in the input). 80 | 81 | You'll notice that the model outperforms this baseline for all bucket sizes in terms of accuracy, and outperforms all but one in terms of BLEU score. 82 | This tells us that applying the Deep Text Corrector model to a potentially errant writing sample would, on average, result in a more grammatically correct writing sample. 83 | Anyone who tends to make errors similar to those the model has been trained on could therefore benefit from passing their messages through this model. 84 | 85 | ``` 86 | Bucket 0: (10, 10) 87 | Baseline BLEU = 0.8341 88 | Model BLEU = 0.8516 89 | Baseline Accuracy: 0.9083 90 | Model Accuracy: 0.9384 91 | Bucket 1: (15, 15) 92 | Baseline BLEU = 0.8850 93 | Model BLEU = 0.8860 94 | Baseline Accuracy: 0.8156 95 | Model Accuracy: 0.8491 96 | Bucket 2: (20, 20) 97 | Baseline BLEU = 0.8876 98 | Model BLEU = 0.8880 99 | Baseline Accuracy: 0.7291 100 | Model Accuracy: 0.7817 101 | Bucket 3: (40, 40) 102 | Baseline BLEU = 0.9099 103 | Model BLEU = 0.9045 104 | Baseline Accuracy: 0.6073 105 | Model Accuracy: 0.6425 106 | ``` 107 | 108 | ### Examples 109 | Decoding a sentence with a missing article: 110 | 111 | ``` 112 | In [31]: decode("Kvothe went to market") 113 | Out[31]: 'Kvothe went to the market' 114 | ``` 115 | 116 | Decoding a sentence with then/than confusion: 117 | 118 | ``` 119 | In [30]: decode("the Cardinals did better then the Cubs in the offseason") 120 | Out[30]: 'the Cardinals did better than the Cubs in the offseason' 121 | ``` 122 | 123 | 124 | ## Implementation Details 125 | This project reuses and slightly extends TensorFlow's [`Seq2SeqModel`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/rnn/translate/seq2seq_model.py), which itself implements a sequence-to-sequence model with an attention mechanism as described in https://arxiv.org/pdf/1412.7449v3.pdf. 126 | The primary contributions of this project are: 127 | 128 | - `data_reader.py`: an abstract class that defines the interface for classes which are capable of reading a source dataset and producing input-output pairs, where the input is a grammatically incorrect variant of a source sentence and the output is the original sentence. 
129 | - `text_corrector_data_readers.py`: contains a few implementations of `DataReader`, one over the [Penn Treebank dataset](http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz) and one over the [Cornell Movie-Dialogs Corpus](http://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html). 130 | - `text_corrector_models.py`: contains a version of `Seq2SeqModel` modified such that it implements the logic described in [Biased Decoding](#biased-decoding). 131 | - `correct_text.py`: a collection of helper functions that together allow for training a model and using it to decode errant input sequences (at test time). The `decode` method defined here implements the [OOV token resolution logic](#handling-oov-tokens). This also defines a main method, and can be invoked from the command line. It was largely derived from TensorFlow's [`translate.py`](https://www.tensorflow.org/tutorials/seq2seq/). 132 | - `TextCorrector.ipynb`: an IPython notebook which ties together all of the above pieces to allow for the training and evaluation of the model in an interactive fashion. 133 | 134 | ### Example Usage 135 | Note: this project requires TensorFlow version >= 0.11. See [this page](https://www.tensorflow.org/get_started/os_setup) for setup instructions. 136 | 137 | **Preprocess Movie Dialog Data** 138 | ``` 139 | python preprocessors/preprocess_movie_dialogs.py --raw_data movie_lines.txt \ 140 | --out_file preprocessed_movie_lines.txt 141 | ``` 142 | This preprocessed file can then be split up however you like to create training, validation, and testing sets. 143 | 144 | **Training:** 145 | ``` 146 | python correct_text.py --train_path /movie_dialog_train.txt \ 147 | --val_path /movie_dialog_val.txt \ 148 | --config DefaultMovieDialogConfig \ 149 | --data_reader_type MovieDialogReader \ 150 | --model_path /movie_dialog_model 151 | ``` 152 | 153 | **Testing:** 154 | ``` 155 | python correct_text.py --test_path /movie_dialog_test.txt \ 156 | --config DefaultMovieDialogConfig \ 157 | --data_reader_type MovieDialogReader \ 158 | --model_path /movie_dialog_model \ 159 | --decode 160 | ``` 161 | 162 | -------------------------------------------------------------------------------- /TextCorrector.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "from __future__ import print_function\n", 12 | "\n", 13 | "import os\n", 14 | "import time\n", 15 | "import numpy as np\n", 16 | "import tensorflow as tf\n", 17 | "import pandas as pd\n", 18 | "from collections import defaultdict\n", 19 | "\n", 20 | "from sklearn.metrics import roc_auc_score, accuracy_score\n", 21 | "import nltk\n", 22 | "\n", 23 | "from correct_text import train, decode, decode_sentence, evaluate_accuracy, create_model,\\\n", 24 | " get_corrective_tokens, DefaultPTBConfig, DefaultMovieDialogConfig\n", 25 | "from text_corrector_data_readers import PTBDataReader, MovieDialogReader\n", 26 | "\n", 27 | "%matplotlib inline" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": { 34 | "collapsed": false 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "root_data_path = \"/Users/atpaino/data/textcorrecter/dialog_corpus\"\n", 39 | "train_path = os.path.join(root_data_path, \"movie_lines.txt\")\n", 40 | "val_path = 
os.path.join(root_data_path, \"cleaned_dialog_val.txt\")\n", 41 | "test_path = os.path.join(root_data_path, \"cleaned_dialog_test.txt\")\n", 42 | "model_path = os.path.join(root_data_path, \"dialog_correcter_model_testnltk\")\n", 43 | "config = DefaultMovieDialogConfig()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## Train" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": { 57 | "collapsed": false 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "data_reader = MovieDialogReader(config, train_path)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 4, 67 | "metadata": { 68 | "collapsed": false 69 | }, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "Reading data; train = /Users/atpaino/data/textcorrecter/dialog_corpus/movie_lines.txt, test = /Users/atpaino/data/textcorrecter/dialog_corpus/cleaned_dialog_val.txt\n" 76 | ] 77 | }, 78 | { 79 | "ename": "KeyboardInterrupt", 80 | "evalue": "", 81 | "output_type": "error", 82 | "traceback": [ 83 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 84 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 85 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_reader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 86 | "\u001b[0;32m/Users/atpaino/github/deep-text-correcter/correct_text.pyc\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(data_reader, train_path, test_path, model_path)\u001b[0m\n\u001b[1;32m 138\u001b[0m \"Reading data; train = {}, test = {}\".format(train_path, test_path))\n\u001b[1;32m 139\u001b[0m \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata_reader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0mtrain_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata_reader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0mtest_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata_reader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 87 | "\u001b[0;32m/Users/atpaino/github/deep-text-correcter/data_reader.pyc\u001b[0m in \u001b[0;36mbuild_dataset\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;31m# dropouts.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset_copies\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0msource\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 128\u001b[0m for bucket_id, (source_size, target_size) in enumerate(\n\u001b[1;32m 129\u001b[0m self.config.buckets):\n", 88 | "\u001b[0;32m/Users/atpaino/github/deep-text-correcter/data_reader.pyc\u001b[0m in \u001b[0;36mread_samples\u001b[0;34m(self, path)\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[1;32m 114\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0msource_words\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_words\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_samples_by_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 115\u001b[0;31m \u001b[0msource\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_token_to_id\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mword\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msource_words\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 116\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_token_to_id\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mword\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mword\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtarget_words\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mEOS_ID\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 89 | "\u001b[0;32m/Users/atpaino/github/deep-text-correcter/data_reader.pyc\u001b[0m in \u001b[0;36mconvert_token_to_id\u001b[0;34m(self, token)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mreturn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \"\"\"\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mtoken_with_id\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtoken\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoken_to_id\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munknown_token\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoken_to_id\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken_with_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 90 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "train(data_reader, train_path, val_path, model_path)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## Decode sentences" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 3, 108 | "metadata": { 109 | "collapsed": true 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "data_reader = MovieDialogReader(config, train_path, dropout_prob=0.25, replacement_prob=0.25, dataset_copies=1)" 114 
| ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "metadata": { 120 | "collapsed": true 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "corrective_tokens = get_corrective_tokens(data_reader, train_path)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "metadata": { 131 | "collapsed": true 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "import pickle\n", 136 | "with open(os.path.join(root_data_path, \"corrective_tokens.pickle\"), \"w\") as f:\n", 137 | " pickle.dump(corrective_tokens, f)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 8, 143 | "metadata": { 144 | "collapsed": false 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "import pickle\n", 149 | "with open(os.path.join(root_data_path, \"token_to_id.pickle\"), \"w\") as f:\n", 150 | " pickle.dump(data_reader.token_to_id, f)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 5, 156 | "metadata": { 157 | "collapsed": false 158 | }, 159 | "outputs": [ 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | "Reading model parameters from /Users/atpaino/data/textcorrecter/dialog_corpus/dialog_correcter_model/translate.ckpt-41900\n" 165 | ] 166 | } 167 | ], 168 | "source": [ 169 | "sess = tf.InteractiveSession()\n", 170 | "model = create_model(sess, True, model_path, config=config)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 7, 176 | "metadata": { 177 | "collapsed": false, 178 | "scrolled": false 179 | }, 180 | "outputs": [ 181 | { 182 | "name": "stdout", 183 | "output_type": "stream", 184 | "text": [ 185 | "Input: you must have girlfriend\n", 186 | "Output: you must have a girlfriend\n", 187 | "\n" 188 | ] 189 | } 190 | ], 191 | "source": [ 192 | "# Test a sample from the test dataset.\n", 193 | "decoded = decode_sentence(sess, model, data_reader, \"you must have girlfriend\", corrective_tokens=corrective_tokens)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 7, 199 | "metadata": { 200 | "collapsed": false 201 | }, 202 | "outputs": [ 203 | { 204 | "ename": "NameError", 205 | "evalue": "name 'decoded' is not defined", 206 | "output_type": "error", 207 | "traceback": [ 208 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 209 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 210 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdecoded\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 211 | "\u001b[0;31mNameError\u001b[0m: name 'decoded' is not defined" 212 | ] 213 | } 214 | ], 215 | "source": [ 216 | "decoded" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 6, 222 | "metadata": { 223 | "collapsed": false 224 | }, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "Input: did n't you say that they 're going to develop this revolutionary new thing ...\n", 231 | "Output: did n't you say that they 're going to develop this revolutionary new thing ...\n", 232 | "\n" 233 | ] 234 | } 235 | ], 236 | "source": [ 237 | "decoded = decode_sentence(sess, model, data_reader,\n", 238 | " \"did n't you say that they 're going to develop this revolutionary new thing ...\",\n", 239 | " corrective_tokens=corrective_tokens)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 9, 245 | 
"metadata": { 246 | "collapsed": false 247 | }, 248 | "outputs": [ 249 | { 250 | "data": { 251 | "text/plain": [ 252 | "['kvothe', 'went', 'to', 'the', 'market']" 253 | ] 254 | }, 255 | "execution_count": 9, 256 | "metadata": {}, 257 | "output_type": "execute_result" 258 | } 259 | ], 260 | "source": [ 261 | "decode_sentence(sess, model, data_reader, \"kvothe went to market\", corrective_tokens=corrective_tokens, verbose=False)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 10, 267 | "metadata": { 268 | "collapsed": false 269 | }, 270 | "outputs": [ 271 | { 272 | "data": { 273 | "text/plain": [ 274 | "['blablahblah', 'and', 'bladdddd', 'went', 'to', 'the', 'market']" 275 | ] 276 | }, 277 | "execution_count": 10, 278 | "metadata": {}, 279 | "output_type": "execute_result" 280 | } 281 | ], 282 | "source": [ 283 | "decode_sentence(sess, model, data_reader, \"blablahblah and bladdddd went to market\", corrective_tokens=corrective_tokens,\n", 284 | " verbose=False)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 11, 290 | "metadata": { 291 | "collapsed": false 292 | }, 293 | "outputs": [ 294 | { 295 | "data": { 296 | "text/plain": [ 297 | "['do', 'you', 'have', 'a', 'book']" 298 | ] 299 | }, 300 | "execution_count": 11, 301 | "metadata": {}, 302 | "output_type": "execute_result" 303 | } 304 | ], 305 | "source": [ 306 | "decode_sentence(sess, model, data_reader, \"do you have book\", corrective_tokens=corrective_tokens, verbose=False)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 10, 312 | "metadata": { 313 | "collapsed": false 314 | }, 315 | "outputs": [ 316 | { 317 | "data": { 318 | "text/plain": [ 319 | "['the', 'cardinals', 'did', 'better', 'than', 'the', 'cubs']" 320 | ] 321 | }, 322 | "execution_count": 10, 323 | "metadata": {}, 324 | "output_type": "execute_result" 325 | } 326 | ], 327 | "source": [ 328 | "decode_sentence(sess, model, data_reader, \"the cardinals did better then the cubs\", corrective_tokens=corrective_tokens, verbose=False)" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 6, 334 | "metadata": { 335 | "collapsed": false 336 | }, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "Bucket 0: (10, 10)\n", 343 | "\tBaseline BLEU = 0.8354\n", 344 | "\tModel BLEU = 0.8492\n", 345 | "\tBaseline Accuracy: 0.9090\n", 346 | "\tModel Accuracy: 0.9354\n", 347 | "Bucket 1: (15, 15)\n", 348 | "\tBaseline BLEU = 0.8826\n", 349 | "\tModel BLEU = 0.8595\n", 350 | "\tBaseline Accuracy: 0.8055\n", 351 | "\tModel Accuracy: 0.8149\n", 352 | "Bucket 2: (20, 20)\n", 353 | "\tBaseline BLEU = 0.8880\n", 354 | "\tModel BLEU = 0.8216\n", 355 | "\tBaseline Accuracy: 0.7301\n", 356 | "\tModel Accuracy: 0.6689\n", 357 | "Bucket 3: (40, 40)\n", 358 | "\tBaseline BLEU = 0.9097\n", 359 | "\tModel BLEU = 0.6357\n", 360 | "\tBaseline Accuracy: 0.5981\n", 361 | "\tModel Accuracy: 0.2283\n" 362 | ] 363 | } 364 | ], 365 | "source": [ 366 | "# 4 layers, 40k steps\n", 367 | "errors = evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path)#, max_samples=1000)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 9, 373 | "metadata": { 374 | "collapsed": false 375 | }, 376 | "outputs": [ 377 | { 378 | "name": "stdout", 379 | "output_type": "stream", 380 | "text": [ 381 | "Bucket 0: (10, 10)\n", 382 | "\tBaseline BLEU = 0.8368\n", 383 | "\tModel BLEU = 0.8425\n", 384 | "\tBaseline Accuracy: 0.9110\n", 
385 | "\tModel Accuracy: 0.9303\n", 386 | "Bucket 1: (15, 15)\n", 387 | "\tBaseline BLEU = 0.8818\n", 388 | "\tModel BLEU = 0.8459\n", 389 | "\tBaseline Accuracy: 0.8063\n", 390 | "\tModel Accuracy: 0.8014\n", 391 | "Bucket 2: (20, 20)\n", 392 | "\tBaseline BLEU = 0.8891\n", 393 | "\tModel BLEU = 0.7986\n", 394 | "\tBaseline Accuracy: 0.7309\n", 395 | "\tModel Accuracy: 0.6281\n", 396 | "Bucket 3: (40, 40)\n", 397 | "\tBaseline BLEU = 0.9099\n", 398 | "\tModel BLEU = 0.5997\n", 399 | "\tBaseline Accuracy: 0.6007\n", 400 | "\tModel Accuracy: 0.1607\n" 401 | ] 402 | } 403 | ], 404 | "source": [ 405 | "# 4 layers, 30k steps\n", 406 | "errors = evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path)#, max_samples=1000)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 13, 412 | "metadata": { 413 | "collapsed": false 414 | }, 415 | "outputs": [ 416 | { 417 | "name": "stdout", 418 | "output_type": "stream", 419 | "text": [ 420 | "Bucket 0: (10, 10)\n", 421 | "\tBaseline BLEU = 0.8330\n", 422 | "\tModel BLEU = 0.8335\n", 423 | "\tBaseline Accuracy: 0.9067\n", 424 | "\tModel Accuracy: 0.9218\n", 425 | "Bucket 1: (15, 15)\n", 426 | "\tBaseline BLEU = 0.8772\n", 427 | "\tModel BLEU = 0.8100\n", 428 | "\tBaseline Accuracy: 0.7980\n", 429 | "\tModel Accuracy: 0.7437\n", 430 | "Bucket 2: (20, 20)\n", 431 | "\tBaseline BLEU = 0.8898\n", 432 | "\tModel BLEU = 0.7636\n", 433 | "\tBaseline Accuracy: 0.7366\n", 434 | "\tModel Accuracy: 0.5370\n", 435 | "Bucket 3: (40, 40)\n", 436 | "\tBaseline BLEU = 0.9098\n", 437 | "\tModel BLEU = 0.5387\n", 438 | "\tBaseline Accuracy: 0.6041\n", 439 | "\tModel Accuracy: 0.1117\n" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "# 4 layers, 20k steps\n", 445 | "errors = evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path)#, max_samples=1000)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 16, 451 | "metadata": { 452 | "collapsed": false 453 | }, 454 | "outputs": [ 455 | { 456 | "name": "stdout", 457 | "output_type": "stream", 458 | "text": [ 459 | "Bucket 0: (10, 10)\n", 460 | "\tBaseline BLEU = 0.8341\n", 461 | "\tModel BLEU = 0.8516\n", 462 | "\tBaseline Accuracy: 0.9083\n", 463 | "\tModel Accuracy: 0.9384\n", 464 | "Bucket 1: (15, 15)\n", 465 | "\tBaseline BLEU = 0.8850\n", 466 | "\tModel BLEU = 0.8860\n", 467 | "\tBaseline Accuracy: 0.8156\n", 468 | "\tModel Accuracy: 0.8491\n", 469 | "Bucket 2: (20, 20)\n", 470 | "\tBaseline BLEU = 0.8876\n", 471 | "\tModel BLEU = 0.8880\n", 472 | "\tBaseline Accuracy: 0.7291\n", 473 | "\tModel Accuracy: 0.7817\n", 474 | "Bucket 3: (40, 40)\n", 475 | "\tBaseline BLEU = 0.9099\n", 476 | "\tModel BLEU = 0.9045\n", 477 | "\tBaseline Accuracy: 0.6073\n", 478 | "\tModel Accuracy: 0.6425\n" 479 | ] 480 | } 481 | ], 482 | "source": [ 483 | "errors = evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path)#, max_samples=1000)" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": 15, 489 | "metadata": { 490 | "collapsed": false, 491 | "scrolled": false 492 | }, 493 | "outputs": [ 494 | { 495 | "name": "stdout", 496 | "output_type": "stream", 497 | "text": [ 498 | "Decoding: you beg for mercy in a second .\n", 499 | "Target: you 'll beg for mercy in a second .\n", 500 | "\n", 501 | "Decoding: i 'm dying for a shower . you could use the one too . and we 'd better check that bandage .\n", 502 | "Target: i 'm dying for a shower . you could use one too . 
and we 'd better check that bandage .\n", 503 | "\n", 504 | "Decoding: whatever ... they 've become hotshot computer guys so they get a job to build el computer grande ... skynet ... for the government . right ?\n", 505 | "Target: whatever ... they become the hotshot computer guys so they get the job to build el computer grande ... skynet ... for the government . right ?\n", 506 | "\n", 507 | "Decoding: did n't you say that they 're going to develop this revolutionary a new thing ...\n", 508 | "Target: did n't you say that they 're going to develop this revolutionary new thing ...\n", 509 | "\n", 510 | "Decoding: bag some z ?\n", 511 | "Target: bag some z 's ?\n", 512 | "\n", 513 | "Decoding: sleep . it 'll be a light soon .\n", 514 | "Target: sleep . it 'll be light soon .\n", 515 | "\n", 516 | "Decoding: well , at least i know what to name him . i do n't suppose you 'd know who father is ? so i do n't tell him to get lost when i meet him .\n", 517 | "Target: well , at least i know what to name him . i do n't suppose you 'd know who the father is ? so i do n't tell him to get lost when i meet him .\n", 518 | "\n", 519 | "Decoding: we got ta get you to doctor .\n", 520 | "Target: we got ta get you to a doctor .\n", 521 | "\n", 522 | "Decoding: hunter killers . patrol machines . a build in automated factories . most of us were rounded up , put in camps ... for orderly disposal .\n", 523 | "Target: hunter killers . patrol machines . build in automated factories . most of us were rounded up , put in camps ... for orderly disposal .\n", 524 | "\n", 525 | "Decoding: but outside , it 's a living human tissue . flesh , skin , hair ... blood . grown for the cyborgs .\n", 526 | "Target: but outside , it 's living human tissue . flesh , skin , hair ... blood . grown for the cyborgs .\n", 527 | "\n", 528 | "Decoding: you heard enough . decide . are you going to release me ?\n", 529 | "Target: you 've heard enough . decide . are you going to release me ?\n", 530 | "\n", 531 | "Decoding: okay . okay . but this ... cyborg ... if it metal ...\n", 532 | "Target: okay . okay . but this ... cyborg ... if it 's metal ...\n", 533 | "\n", 534 | "Decoding: you go naked . something about the field generated by living organism . nothing dead will go .\n", 535 | "Target: you go naked . something about the field generated by a living organism . nothing dead will go .\n", 536 | "\n", 537 | "Decoding: ca n't . nobody goes home . nobody else comes through . it just him and me .\n", 538 | "Target: ca n't . nobody goes home . nobody else comes through . it 's just him and me .\n", 539 | "\n", 540 | "Decoding: i see . and this ... computer , thinks it can win by killing the mother of its enemy , kill- ing him , in effect , before he is even conceived ? sort of retroactive abortion ?\n", 541 | "Target: i see . and this ... computer , thinks it can win by killing the mother of its enemy , kill- ing him , in effect , before he is even conceived ? a sort of retroactive abortion ?\n", 542 | "\n", 543 | "Decoding: skynet . a computer defense system built for sac-norad by cyber dynamics . modified series 4800 .\n", 544 | "Target: skynet . a computer defense system built for sac-norad by cyber dynamics . 
a modified series 4800 .\n", 545 | "\n", 546 | "Decoding: a year 2027 ?\n", 547 | "Target: the year 2027 ?\n", 548 | "\n", 549 | "Decoding: with one thirty a second under perry , from '21 to '27 --\n", 550 | "Target: with the one thirty second under perry , from '21 to '27 --\n", 551 | "\n", 552 | "Decoding: why do n't you just stretch out here and get some sleep . it take your mom 's a good hour to get here from redlands .\n", 553 | "Target: why do n't you just stretch out here and get some sleep . it 'll take your mom a good hour to get here from redlands .\n", 554 | "\n", 555 | "Decoding: lieutenant , are you sure it them ? maybe i should see the ... bodies .\n", 556 | "Target: lieutenant , are you sure it 's them ? maybe i should see the ... bodies .\n", 557 | "\n", 558 | "Decoding: i already did . no answer at the door and the apartment manager 's out . i keeping them there .\n", 559 | "Target: i already did . no answer at the door and the apartment manager 's out . i 'm keeping them there .\n", 560 | "\n", 561 | "Decoding: that stuff two hours cold .\n", 562 | "Target: that stuff 's two hours cold .\n", 563 | "\n", 564 | "Decoding: you got ta be kidding me . the new guys 'll be short-stroking it over this one . one-day pattern killer .\n", 565 | "Target: you got ta be kidding me . the new guys 'll be short-stroking it over this one . a one-day pattern killer .\n", 566 | "\n", 567 | "Decoding: give me a short version .\n", 568 | "Target: give me the short version .\n", 569 | "\n", 570 | "Decoding: because it 's fair . give me the next quarter . if you still feel this way , vote your shares ...\n", 571 | "Target: because it 's fair . give me next quarter . if you still feel this way , vote your shares ...\n", 572 | "\n", 573 | "Decoding: it 's probably will . in fact , i 'd go so far as to say it 's almost certainly will , in time . why should i settle for that ?\n", 574 | "Target: it probably will . in fact , i 'd go so far as to say it almost certainly will , in time . why should i settle for that ?\n", 575 | "\n", 576 | "Decoding: stock will turn .\n", 577 | "Target: the stock will turn .\n", 578 | "\n", 579 | "Decoding: you want to know what it is ? what 's it all about ? john . chapter nine . verse twenty-five .\n", 580 | "Target: you want to know what it is ? what it 's all about ? john . chapter nine . verse twenty-five .\n", 581 | "\n", 582 | "Decoding: i only mention it because i took a test this afternoon , down on montgomery street .\n", 583 | "Target: i only mention it because i took the test this afternoon , down on montgomery street .\n", 584 | "\n", 585 | "Decoding: christine ! mister van orton is valued customer ...\n", 586 | "Target: christine ! mister van orton is a valued customer ...\n", 587 | "\n", 588 | "Decoding: a single ?\n", 589 | "Target: single ?\n", 590 | "\n", 591 | "Decoding: there 's another gig starting in saudi arabia . i just a walk-on this time though . bit-part .\n", 592 | "Target: there 's another gig starting in saudi arabia . i 'm just a walk-on this time though . bit-part .\n", 593 | "\n", 594 | "Decoding: no ! you take another step , i shoot ! they 're trying to kill me ...\n", 595 | "Target: no ! you take another step , i 'll shoot ! they 're trying to kill me ...\n", 596 | "\n", 597 | "Decoding: listen very carefully , i 'm telling the truth ... this is a game . this was all the game .\n", 598 | "Target: listen very carefully , i 'm telling the truth ... this is the game . this was all the game .\n", 599 | "\n", 600 | "Decoding: that 's gun . 
that 's ... that 's not automatic . the guard had an automatic ...\n", 601 | "Target: that gun . that ... that 's not automatic . the guard had an automatic ...\n", 602 | "\n", 603 | "Decoding: take a picture out .\n", 604 | "Target: take the picture out .\n", 605 | "\n", 606 | "Decoding: yeah . first communion . are n't i little angel ?\n", 607 | "Target: yeah . first communion . are n't i a little angel ?\n", 608 | "\n", 609 | "Decoding: let me go get some clothes on . we talk , okay ? be right back .\n", 610 | "Target: let me go get some clothes on . we 'll talk , okay ? be right back .\n", 611 | "\n", 612 | "Decoding: i 'm tired . i 'm sorry , i should go . i 've been enough of nuisance .\n", 613 | "Target: i 'm tired . i 'm sorry , i should go . i 've been enough of a nuisance .\n", 614 | "\n", 615 | "Decoding: they said five hundred . i said six . they said man in the gray flannel suit . i think i said , you mean the attractive guy in the gray flannel suit ?\n", 616 | "Target: they said five hundred . i said six . they said the man in the gray flannel suit . i think i said , you mean the attractive guy in the gray flannel suit ?\n", 617 | "\n", 618 | "Decoding: i have a confession to make . someone gave me six-hundred dollars to spill a drinks on you , as a practical joke .\n", 619 | "Target: i have a confession to make . someone gave me six-hundred dollars to spill drinks on you , as a practical joke .\n", 620 | "\n", 621 | "Decoding: maitre d ' called you christine .\n", 622 | "Target: the maitre d ' called you christine .\n", 623 | "\n", 624 | "Decoding: i know owner of campton place . i could talk to him in the morning .\n", 625 | "Target: i know the owner of campton place . i could talk to him in the morning .\n", 626 | "\n", 627 | "Decoding: fresh shirt ...\n", 628 | "Target: a fresh shirt ...\n", 629 | "\n", 630 | "Decoding: investment banking . moving money from a place to place .\n", 631 | "Target: investment banking . moving money from place to place .\n", 632 | "\n", 633 | "Decoding: what 's the c .r .s . ?\n", 634 | "Target: what 's c .r .s . ?\n", 635 | "\n", 636 | "Decoding: this is a c .r .s .\n", 637 | "Target: this is c .r .s .\n", 638 | "\n", 639 | "Decoding: their ladder here .\n", 640 | "Target: there 's a ladder here .\n", 641 | "\n", 642 | "Decoding: this is n't attempt to be gallant . if i do n't lift you , how are you going to get there ?\n", 643 | "Target: this is n't an attempt to be gallant . if i do n't lift you , how are you going to get there ?\n", 644 | "\n", 645 | "Decoding: are you suggesting we wait till someone 's finds us ?\n", 646 | "Target: are you suggesting we wait till someone finds us ?\n", 647 | "\n", 648 | "Decoding: `` ... wait for help . '' wait for help . i 'm not opening that specifically warns me not to .\n", 649 | "Target: `` ... wait for help . '' wait for help . i 'm not opening a door that specifically warns me not to .\n", 650 | "\n", 651 | "Decoding: read what it says : `` warning , do < u > not < /u > attempt to open . if elevator stops , use the emergency ... ``\n", 652 | "Target: read what it says : `` warning , do < u > not < /u > attempt to open . if elevator stops , use emergency ... ``\n", 653 | "\n", 654 | "Decoding: long story . i found this key in the mouth of wooden harlequin .\n", 655 | "Target: long story . 
i found this key in the mouth of a wooden harlequin .\n", 656 | "\n", 657 | "Decoding: how do you know that way ?\n", 658 | "Target: how do you know that 's the way ?\n", 659 | "\n", 660 | "Decoding: it 's run by company ... they play elaborate pranks . things like this . i 'm really only now finding out myself .\n", 661 | "Target: it 's run by a company ... they play elaborate pranks . things like this . i 'm really only now finding out myself .\n", 662 | "\n", 663 | "Decoding: you got to be kidding .\n", 664 | "Target: you 've got to be kidding .\n", 665 | "\n", 666 | "Decoding: i do n't think he breathing .\n", 667 | "Target: i do n't think he 's breathing .\n", 668 | "\n", 669 | "Decoding: a bad month . you did exact the same thing to me last week .\n", 670 | "Target: a bad month . you did the exact same thing to me last week .\n", 671 | "\n", 672 | "Decoding: yeah , yeah . she 's called a cab . said something about catching plane .\n", 673 | "Target: yeah , yeah . she called a cab . said something about catching a plane .\n", 674 | "\n", 675 | "Decoding: oh , god yes please . thanks , man . i take you up on that .\n", 676 | "Target: oh , god yes please . thanks , man . i 'll take you up on that .\n", 677 | "\n", 678 | "Decoding: this ... ? oh , this is just ... this is bill .\n", 679 | "Target: this ... ? oh , this is just ... this is the bill .\n", 680 | "\n", 681 | "Decoding: baby , they were all over the house with metal detectors . they switched your gun with look-alike , rigged barrel , loaded with blanks . pop-gun .\n", 682 | "Target: baby , they were all over the house with metal detectors . they switched your gun with a look-alike , rigged barrel , loaded with blanks . pop-gun .\n", 683 | "\n", 684 | "Decoding: you dodged bullet .\n", 685 | "Target: you dodged a bullet .\n", 686 | "\n", 687 | "Decoding: c .r .s . who do you think ? jesus h . , thank your lucky charms . to think what i 've almost got you into .\n", 688 | "Target: c .r .s . who do you think ? jesus h . , thank your lucky charms . to think what i almost got you into .\n", 689 | "\n", 690 | "Decoding: it 's profound life experience .\n", 691 | "Target: it 's a profound life experience .\n", 692 | "\n", 693 | "Decoding: you 've heard of it . you 've seen other people having it . they 're entertainment service , but more than that .\n", 694 | "Target: you 've heard of it . you 've seen other people having it . they 're an entertainment service , but more than that .\n", 695 | "\n", 696 | "Decoding: they make your life fun . there 's only guarantee is you will not be bored .\n", 697 | "Target: they make your life fun . their only guarantee is you will not be bored .\n", 698 | "\n", 699 | "Decoding: not after i done with it . actually , i 've been here . in grad-school i bought crystal-meth from the maitre d ' .\n", 700 | "Target: not after i 'm done with it . actually , i 've been here . in grad-school i bought crystal-meth from the maitre d ' .\n", 701 | "\n", 702 | "Decoding: that 's why it 's a classic . come on , man ... how 'bout hug ... ?\n", 703 | "Target: that 's why it 's a classic . come on , man ... how 'bout a hug ... ?\n", 704 | "\n", 705 | "Decoding: how much is it ? a few thousand , at least . a rolex like that ... lucky for you 've missed it .\n", 706 | "Target: how much is it ? a few thousand , at least . a rolex like that ... lucky for you they missed it .\n", 707 | "\n", 708 | "Decoding: i told you , they hired me over the phone . 
i 've never met anyone .\n", 709 | "Target: i told you , they hired me over the phone . i never met anyone .\n", 710 | "\n", 711 | "Decoding: i do n't want money . i 'm pulling back curtain . i 'm here to meet the wizard .\n", 712 | "Target: i do n't want money . i 'm pulling back the curtain . i 'm here to meet the wizard .\n", 713 | "\n", 714 | "Decoding: tell them the cops are after you ... tell them you got to talk to someone , i 'm threatening to blow the whistle .\n", 715 | "Target: tell them the cops are after you ... tell them you 've got to talk to someone , i 'm threatening to blow the whistle .\n", 716 | "\n", 717 | "Decoding: they own the whole building . they just move from the floor to floor .\n", 718 | "Target: they own the whole building . they just move from floor to floor .\n", 719 | "\n", 720 | "Decoding: look , it was just a job . nothing personal , ya know ? i play my part , improvise little . that 's what i 'm good at .\n", 721 | "Target: look , it was just a job . nothing personal , ya know ? i play my part , improvise a little . that 's what i 'm good at .\n", 722 | "\n", 723 | "Decoding: that 's right -- you 're left-brain the word fetishist .\n", 724 | "Target: that 's right -- you 're a left-brain word fetishist .\n", 725 | "\n", 726 | "Decoding: one guarantee . payment 's entirely at your brother discretion and , as a gift , dependent on your satisfaction .\n", 727 | "Target: one guarantee . payment 's entirely at your brother 's discretion and , as a gift , dependent on your satisfaction .\n", 728 | "\n", 729 | "Decoding: your brother was a client with our branch . we do a sort of informal scoring . his numbers were outstanding . sure you 're not hungry at all ... ? tung hoy , best in chinatown ...\n", 730 | "Target: your brother was a client with our london branch . we do a sort of informal scoring . his numbers were outstanding . sure you 're not hungry at all ... ? tung hoy , best in chinatown ...\n", 731 | "\n", 732 | "Decoding: key ?\n", 733 | "Target: the key ?\n", 734 | "\n", 735 | "Decoding: nobody 's worried about your father .\n", 736 | "Target: nobody worried about your father .\n", 737 | "\n", 738 | "Decoding: there 's been a break in . lock this door and stay here . do n't move muscle .\n", 739 | "Target: there 's been a break in . lock this door and stay here . do n't move a muscle .\n", 740 | "\n", 741 | "Decoding: i do n't know what you 're talking about . what happened ?\n", 742 | "Target: i do n't know what you 're talking about . what 's happened ?\n", 743 | "\n", 744 | "Decoding: did alarm go off ? the house ... they ... you did n't see ... ?\n", 745 | "Target: did the alarm go off ? the house ... they ... you did n't see ... ?\n", 746 | "\n", 747 | "Decoding: then then .\n", 748 | "Target: goodnight then .\n", 749 | "\n", 750 | "Decoding: okay . i think he into some sort of new personal improvement cult .\n", 751 | "Target: okay . i think he 's into some sort of new personal improvement cult .\n", 752 | "\n", 753 | "Decoding: dinner in the oven .\n", 754 | "Target: dinner 's in the oven .\n", 755 | "\n", 756 | "Decoding: there was incident a few days ago ... a nervous breakdown , they said . the police took him . they left this address , in case anyone ...\n", 757 | "Target: there was an incident a few days ago ... a nervous breakdown , they said . the police took him . 
they left this address , in case anyone ...\n", 758 | "\n", 759 | "Decoding: what 's trouble ?\n", 760 | "Target: what 's the trouble ?\n", 761 | "\n", 762 | "Decoding: mister ... seymour butts .\n", 763 | "Target: a mister ... seymour butts .\n", 764 | "\n", 765 | "Decoding: what 's the gentleman , maria ?\n", 766 | "Target: what gentleman , maria ?\n", 767 | "\n", 768 | "Decoding: i would n't mention following , except he was very insistent . it 's obviously some sort of prank ...\n", 769 | "Target: i would n't mention the following , except he was very insistent . it 's obviously some sort of prank ...\n", 770 | "\n", 771 | "Decoding: i send your regrets . honestly , why must i even bother ?\n", 772 | "Target: i 'll send your regrets . honestly , why must i even bother ?\n", 773 | "\n", 774 | "Decoding: the hinchberger 's wedding .\n", 775 | "Target: the hinchberger wedding .\n", 776 | "\n", 777 | "Decoding: invitations : museum gala .\n", 778 | "Target: invitations : the museum gala .\n", 779 | "\n", 780 | "Decoding: nice touch . does a game use real bullets ... ?\n", 781 | "Target: nice touch . does the game use real bullets ... ?\n", 782 | "\n", 783 | "Decoding: it 's what they do . it 's like ... being toyed with by a bunch of ... depraved children\n", 784 | "Target: it 's what they do . it 's like ... being toyed with by a bunch of ... depraved children .\n", 785 | "\n", 786 | "Decoding: find out about a company called the c .r .s . consumer recreation services .\n", 787 | "Target: find out about a company called c .r .s . consumer recreation services .\n", 788 | "\n", 789 | "Decoding: someone 's playing hardball . it 's complicated . can i ask favor ?\n", 790 | "Target: someone 's playing hardball . it 's complicated . can i ask a favor ?\n", 791 | "\n", 792 | "Decoding: how 's the concerned should i be ?\n", 793 | "Target: how concerned should i be ?\n", 794 | "\n", 795 | "Decoding: that you 've a involved conrad ... is unforgivable . i am now your enemy .\n", 796 | "Target: that you 've involved conrad ... is unforgivable . i am now your enemy .\n", 797 | "\n", 798 | "Decoding: what happened ...\n", 799 | "Target: what 's happened ...\n", 800 | "\n", 801 | "Decoding: modelling small-group dynamics in formation of narrative hallucinations . you brought us here to scare us . insomnia , that was just a decoy issue . you 're disgusting .\n", 802 | "Target: modelling small-group dynamics in the formation of narrative hallucinations . you brought us here to scare us . insomnia , that was just a decoy issue . you 're disgusting .\n", 803 | "\n", 804 | "Decoding: come on . these are the typically sentimental gestures of depraved industrialist .\n", 805 | "Target: come on . these are the typically sentimental gestures of a depraved industrialist .\n", 806 | "\n", 807 | "Decoding: the children . children hugh crain built the house for . the children he never had .\n", 808 | "Target: the children . the children hugh crain built the house for . the children he never had .\n", 809 | "\n", 810 | "Decoding: obsessive worrier . join club . and you ? i 'd guess ...\n", 811 | "Target: obsessive worrier . join the club . and you ? 
i 'd guess ...\n", 812 | "\n", 813 | "Decoding: so why did you need the addam family mansion for a scientific test ?\n", 814 | "Target: so why did you need the addam 's family mansion for a scientific test ?\n", 815 | "\n", 816 | "Decoding: -- how much is this car 's worth ?\n", 817 | "Target: -- how much is this car worth ?\n", 818 | "\n", 819 | "Decoding: you do n't really believe it haunted ... do you believe in ghosts ?\n", 820 | "Target: you do n't really believe it 's haunted ... do you believe in ghosts ?\n", 821 | "\n", 822 | "Decoding: so could you ! is this some fucked up the idea of art , putting someone else 's name to a painting ?\n", 823 | "Target: so could you ! is this some fucked up idea of art , putting someone else 's name to a painting ?\n", 824 | "\n", 825 | "Decoding: and why did n't marrow tell < u > us < /u > ? does n't he a trust women ? that fuck .\n", 826 | "Target: and why did n't marrow tell < u > us < /u > ? does n't he trust women ? that fuck .\n", 827 | "\n", 828 | "Decoding: nah , you 're going crazy with doubt , all of your mistakes are coming back up the pipes , and it 's worse than nightmare . --\n", 829 | "Target: nah , you 're going crazy with doubt , all of your mistakes are coming back up the pipes , and it 's worse than a nightmare . --\n", 830 | "\n", 831 | "Decoding: not the way you 've constructed your group , it just not ethical !\n", 832 | "Target: not the way you 've constructed your group , it 's just not ethical !\n", 833 | "\n", 834 | "Decoding: children want me . they 're calling me . they need me .\n", 835 | "Target: the children want me . they 're calling me . they need me .\n", 836 | "\n", 837 | "Decoding: i looked at theo . she had look on her face .\n", 838 | "Target: i looked at theo . she had a look on her face .\n", 839 | "\n", 840 | "Decoding: i was n't thinking about my mother bathroom .\n", 841 | "Target: i was n't thinking about my mother 's bathroom .\n", 842 | "\n", 843 | "Decoding: so ... smell ... is ... smell is sense that triggers the most powerful memories . and memory can trigger a smell .\n", 844 | "Target: so ... smell ... is ... smell is the sense that triggers the most powerful memories . and a memory can trigger a smell .\n", 845 | "\n", 846 | "Decoding: in the bathroom in my mother 's room , toilet was next to old wooden table . it smelled like that wood .\n", 847 | "Target: in the bathroom in my mother 's room , the toilet was next to an old wooden table . it smelled like that wood .\n", 848 | "\n", 849 | "Decoding: cold sensation . who felt it first ?\n", 850 | "Target: the cold sensation . who felt it first ?\n", 851 | "\n", 852 | "Decoding: i really ... honored to be part of this study , jim .\n", 853 | "Target: i 'm really ... honored to be part of this study , jim .\n", 854 | "\n", 855 | "Decoding: nell . good enough . and i jim .\n", 856 | "Target: nell . good enough . and i 'm jim .\n", 857 | "\n", 858 | "Decoding: that ? that 's a hill house .\n", 859 | "Target: that ? that 's hill house .\n", 860 | "\n", 861 | "Decoding: here 's how they 're organized . groups of five , very different personalities : scored all over the kiersey temperament sorter just like you asked for . and they all score high on insomnia charts .\n", 862 | "Target: here 's how they 're organized . groups of five , very different personalities : scored all over the kiersey temperament sorter just like you asked for . and they all score high on the insomnia charts .\n", 863 | "\n", 864 | "Decoding: you hear the vibrations in the wire . 
there 's magnetic pulse in the wires , you feel it . i could test it .\n", 865 | "Target: you hear the vibrations in the wire . there 's a magnetic pulse in the wires , you feel it . i could test it .\n", 866 | "\n", 867 | "Decoding: but experiment was a failure .\n", 868 | "Target: but the experiment was a failure .\n", 869 | "\n", 870 | "Decoding: he wandering around house , and nell heard him . she thought it was ghosts . let 's go look for him again .\n", 871 | "Target: he 's wandering around the house , and nell heard him . she thought it was ghosts . let 's go look for him again .\n", 872 | "\n", 873 | "Decoding: i 'll take her with me to university tomorrow . i ca n't believe i read the test wrong . i did n't see anything that looked like she was suicidal .\n", 874 | "Target: i 'll take her with me to the university tomorrow . i ca n't believe i read the test wrong . i did n't see anything that looked like she was suicidal .\n", 875 | "\n", 876 | "Decoding: no , but nell been here longer than i have .\n", 877 | "Target: no , but nell 's been here longer than i have .\n", 878 | "\n", 879 | "Decoding: rene crain . up there . rope . ship 's hawser . hard to tie . do n't know how she 's got it .\n", 880 | "Target: rene crain . up there . rope . ship 's hawser . hard to tie . do n't know how she got it .\n", 881 | "\n", 882 | "Decoding: mrs . dudley be waiting for you .\n", 883 | "Target: mrs . dudley 'll be waiting for you .\n", 884 | "\n", 885 | "Decoding: that 's a good question . what is it about fences ? sometimes a locked chain makes people on both sides of fence just a little more comfortable . why would that be ?\n", 886 | "Target: that 's a good question . what is it about fences ? sometimes a locked chain makes people on both sides of the fence just a little more comfortable . why would that be ?\n", 887 | "\n", 888 | "Decoding: well , i 've never lived with a beauty . you must love working here .\n", 889 | "Target: well , i 've never lived with beauty . you must love working here .\n", 890 | "\n", 891 | "Decoding: nell , it makes sense . it 's all makes sense . you and i , we were scaring each other , working each other up .\n", 892 | "Target: nell , it makes sense . it all makes sense . you and i , we were scaring each other , working each other up .\n", 893 | "\n" 894 | ] 895 | } 896 | ], 897 | "source": [ 898 | "for decoding, target in errors:\n", 899 | " print(\"Decoding: \" + \" \".join(decoding))\n", 900 | " print(\"Target: \" + \" \".join(target) + \"\\n\")" 901 | ] 902 | } 903 | ], 904 | "metadata": { 905 | "kernelspec": { 906 | "display_name": "Python 2", 907 | "language": "python", 908 | "name": "python2" 909 | }, 910 | "language_info": { 911 | "codemirror_mode": { 912 | "name": "ipython", 913 | "version": 2 914 | }, 915 | "file_extension": ".py", 916 | "mimetype": "text/x-python", 917 | "name": "python", 918 | "nbconvert_exporter": "python", 919 | "pygments_lexer": "ipython2", 920 | "version": "2.7.11" 921 | } 922 | }, 923 | "nbformat": 4, 924 | "nbformat_minor": 0 925 | } 926 | -------------------------------------------------------------------------------- /correct_text.py: -------------------------------------------------------------------------------- 1 | """Program used to create, train, and evaluate "text correcting" models. 2 | 3 | Defines utilities that allow for: 4 | 1. Creating a TextCorrectorModel 5 | 2. Training a TextCorrectorModel using a given DataReader (i.e. a data source) 6 | 3. 
Decoding predictions from a trained TextCorrectorModel 7 | 8 | The program is best run from the command line using the flags defined below or 9 | through an IPython notebook. 10 | 11 | Note: this has been mostly copied from Tensorflow's translate.py demo 12 | """ 13 | 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | 18 | import math 19 | import os 20 | import sys 21 | import time 22 | from collections import defaultdict 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | from data_reader import EOS_ID 28 | from text_corrector_data_readers import MovieDialogReader, PTBDataReader 29 | 30 | from text_corrector_models import TextCorrectorModel 31 | 32 | tf.app.flags.DEFINE_string("config", "TestConfig", "Name of config to use.") 33 | tf.app.flags.DEFINE_string("data_reader_type", "MovieDialogReader", 34 | "Type of data reader to use.") 35 | tf.app.flags.DEFINE_string("train_path", "train", "Training data path.") 36 | tf.app.flags.DEFINE_string("val_path", "val", "Validation data path.") 37 | tf.app.flags.DEFINE_string("test_path", "test", "Testing data path.") 38 | tf.app.flags.DEFINE_string("model_path", "model", "Path where the model is " 39 | "saved.") 40 | tf.app.flags.DEFINE_boolean("decode", False, "Whether we should decode data " 41 | "at test_path. The default is to " 42 | "train a model and save it at " 43 | "model_path.") 44 | 45 | FLAGS = tf.app.flags.FLAGS 46 | 47 | 48 | class TestConfig(): 49 | # We use a number of buckets and pad to the closest one for efficiency. 50 | buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] 51 | 52 | steps_per_checkpoint = 20 53 | max_steps = 100 54 | 55 | max_vocabulary_size = 10000 56 | 57 | size = 128 58 | num_layers = 1 59 | max_gradient_norm = 5.0 60 | batch_size = 64 61 | learning_rate = 0.5 62 | learning_rate_decay_factor = 0.99 63 | 64 | use_lstm = False 65 | use_rms_prop = False 66 | 67 | 68 | class DefaultPTBConfig(): 69 | buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] 70 | 71 | steps_per_checkpoint = 100 72 | max_steps = 20000 73 | 74 | max_vocabulary_size = 10000 75 | 76 | size = 512 77 | num_layers = 2 78 | max_gradient_norm = 5.0 79 | batch_size = 64 80 | learning_rate = 0.5 81 | learning_rate_decay_factor = 0.99 82 | 83 | use_lstm = False 84 | use_rms_prop = False 85 | 86 | 87 | class DefaultMovieDialogConfig(): 88 | buckets = [(10, 10), (15, 15), (20, 20), (40, 40)] 89 | 90 | steps_per_checkpoint = 100 91 | max_steps = 20000 92 | 93 | # The OOV resolution scheme used in decode() allows us to use a much smaller 94 | # vocabulary. 
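# (Illustrative sketch of that OOV scheme, with hypothetical tokens: a rare
# word maps to UNK on the way into the model, and decode() later swaps any
# emitted UNK back for the corresponding OOV input token, so a 2k vocabulary
# need not destroy rare words:
#     source tokens:  ["kvothe", "went", "to", "market"]
#     model input:    ["UNK", "went", "to", "market"]
#     model output:   ["UNK", "went", "to", "the", "market"]
#     final decoding: ["kvothe", "went", "to", "the", "market"])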
95 |     max_vocabulary_size = 2000
96 | 
97 |     size = 512
98 |     num_layers = 4
99 |     max_gradient_norm = 5.0
100 |     batch_size = 64
101 |     learning_rate = 0.5
102 |     learning_rate_decay_factor = 0.99
103 | 
104 |     use_lstm = True
105 |     use_rms_prop = False
106 | 
107 |     projection_bias = 0.0
108 | 
109 | 
110 | def create_model(session, forward_only, model_path, config=TestConfig()):
111 |     """Create translation model and initialize or load parameters in session."""
112 |     model = TextCorrectorModel(
113 |         config.max_vocabulary_size,
114 |         config.max_vocabulary_size,
115 |         config.buckets,
116 |         config.size,
117 |         config.num_layers,
118 |         config.max_gradient_norm,
119 |         config.batch_size,
120 |         config.learning_rate,
121 |         config.learning_rate_decay_factor,
122 |         use_lstm=config.use_lstm,
123 |         forward_only=forward_only,
124 |         config=config)
125 |     ckpt = tf.train.get_checkpoint_state(model_path)
126 |     if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
127 |         print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
128 |         model.saver.restore(session, ckpt.model_checkpoint_path)
129 |     else:
130 |         print("Created model with fresh parameters.")
131 |         session.run(tf.initialize_all_variables())
132 |     return model
133 | 
134 | 
135 | def train(data_reader, train_path, test_path, model_path):
136 |     """Trains a text-correcting model on the given data, periodically saving checkpoints to model_path."""
137 |     print(
138 |         "Reading data; train = {}, test = {}".format(train_path, test_path))
139 |     config = data_reader.config
140 |     train_data = data_reader.build_dataset(train_path)
141 |     test_data = data_reader.build_dataset(test_path)
142 | 
143 |     with tf.Session() as sess:
144 |         # Create model.
145 |         print(
146 |             "Creating %d layers of %d units." % (
147 |                 config.num_layers, config.size))
148 |         model = create_model(sess, False, model_path, config=config)
149 | 
150 |         # Read data into buckets and compute their sizes.
151 |         train_bucket_sizes = [len(train_data[b]) for b in
152 |                               range(len(config.buckets))]
153 |         print("Training bucket sizes: {}".format(train_bucket_sizes))
154 |         train_total_size = float(sum(train_bucket_sizes))
155 |         print("Total train size: {}".format(train_total_size))
156 | 
157 |         # A bucket scale is a list of increasing numbers from 0 to 1 that
158 |         # we'll use to select a bucket. The length of [scale[i], scale[i+1]] is
159 |         # proportional to the size of the i-th training bucket, as used later.
160 |         train_buckets_scale = [
161 |             sum(train_bucket_sizes[:i + 1]) / train_total_size
162 |             for i in range(len(train_bucket_sizes))]
163 | 
164 |         # This is the training loop.
165 |         step_time, loss = 0.0, 0.0
166 |         current_step = 0
167 |         previous_losses = []
168 |         while current_step < config.max_steps:
169 |             # Choose a bucket according to data distribution. We pick a random
170 |             # number in [0, 1] and use the corresponding interval in
171 |             # train_buckets_scale.
172 |             random_number_01 = np.random.random_sample()
173 |             bucket_id = min([i for i in range(len(train_buckets_scale))
174 |                              if train_buckets_scale[i] > random_number_01])
175 | 
176 |             # Get a batch and make a step.
177 |             start_time = time.time()
178 |             encoder_inputs, decoder_inputs, target_weights = model.get_batch(
179 |                 train_data, bucket_id)
180 |             _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
181 |                                          target_weights, bucket_id, False)
182 |             step_time += (time.time() - start_time) / config \
183 |                 .steps_per_checkpoint
184 |             loss += step_loss / config.steps_per_checkpoint
185 |             current_step += 1
186 | 
187 |             # Once in a while, we save checkpoint, print statistics, and run
188 |             # evals.
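# (For intuition on the statistics printed below: perplexity is exp(mean
# loss), so a loss of 2.3 nats corresponds to a perplexity of roughly 10,
# i.e. the model is about as uncertain as a uniform choice among ~10 tokens
# at each step.)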
189 |             if current_step % config.steps_per_checkpoint == 0:
190 |                 # Print statistics for the previous epoch.
191 |                 perplexity = math.exp(float(loss)) if loss < 300 else float(
192 |                     "inf")
193 |                 print("global step %d learning rate %.4f step-time %.2f "
194 |                       "perplexity %.2f" % (
195 |                           model.global_step.eval(), model.learning_rate.eval(),
196 |                           step_time, perplexity))
197 |                 # Decrease learning rate if no improvement was seen over last
198 |                 # 3 times.
199 |                 if len(previous_losses) > 2 and loss > max(
200 |                         previous_losses[-3:]):
201 |                     sess.run(model.learning_rate_decay_op)
202 |                 previous_losses.append(loss)
203 |                 # Save checkpoint and zero timer and loss.
204 |                 checkpoint_path = os.path.join(model_path, "translate.ckpt")
205 |                 model.saver.save(sess, checkpoint_path,
206 |                                  global_step=model.global_step)
207 |                 step_time, loss = 0.0, 0.0
208 |                 # Run evals on development set and print their perplexity.
209 |                 for bucket_id in range(len(config.buckets)):
210 |                     if len(test_data[bucket_id]) == 0:
211 |                         print("  eval: empty bucket %d" % (bucket_id))
212 |                         continue
213 |                     encoder_inputs, decoder_inputs, target_weights = \
214 |                         model.get_batch(test_data, bucket_id)
215 |                     _, eval_loss, _ = model.step(sess, encoder_inputs,
216 |                                                  decoder_inputs,
217 |                                                  target_weights, bucket_id,
218 |                                                  True)
219 |                     eval_ppx = math.exp(
220 |                         float(eval_loss)) if eval_loss < 300 else float(
221 |                         "inf")
222 |                     print("  eval: bucket %d perplexity %.2f" % (
223 |                         bucket_id, eval_ppx))
224 |                 sys.stdout.flush()
225 | 
226 | 
227 | def get_corrective_tokens(data_reader, train_path):
228 |     # TODO: this should be part of the model, learned during training
229 |     corrective_tokens = set()
230 |     for source_tokens, target_tokens in data_reader.read_samples_by_string(
231 |             train_path):
232 |         corrective_tokens.update(set(target_tokens) - set(source_tokens))
233 |     return corrective_tokens
234 | 
235 | 
236 | def decode(sess, model, data_reader, data_to_decode, corrective_tokens=set(),
237 |            verbose=True):
238 |     """
239 |     Decodes the given input data with the given trained model.
240 |     :param sess: TensorFlow session in which the model was created
241 |     :param model: trained TextCorrectorModel used to decode
242 |     :param data_reader: DataReader used to map tokens to and from ids
243 |     :param data_to_decode: an iterable of token lists representing the input
244 |         data we want to decode
245 |     :param corrective_tokens: tokens the decoder may emit even when absent from the input
246 |     :param verbose: whether to print each input/output pair as it is decoded
247 |     :return: a generator over the decoded outputs, one token list per input
248 |     """
249 |     model.batch_size = 1
250 | 
251 |     corrective_tokens_mask = np.zeros(model.target_vocab_size)
252 |     corrective_tokens_mask[EOS_ID] = 1.0
253 |     for token in corrective_tokens:
254 |         corrective_tokens_mask[data_reader.convert_token_to_id(token)] = 1.0
255 | 
256 |     for tokens in data_to_decode:
257 |         token_ids = [data_reader.convert_token_to_id(token) for token in tokens]
258 | 
259 |         # Which bucket does it belong to?
260 |         matching_buckets = [b for b in range(len(model.buckets))
261 |                             if model.buckets[b][0] > len(token_ids)]
262 |         if not matching_buckets:
263 |             # The input string has more tokens than the largest bucket, so we
264 |             # have to skip it.
265 |             continue
266 | 
267 |         bucket_id = min(matching_buckets)
268 | 
269 |         # Get a 1-element batch to feed the sentence to the model.
270 |         encoder_inputs, decoder_inputs, target_weights = model.get_batch(
271 |             {bucket_id: [(token_ids, [])]}, bucket_id)
272 | 
273 |         # Get output logits for the sentence.
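# (Note on the step below: corrective_tokens_mask, built at the top of
# decode(), lets the decoder emit tokens that never occur in the source,
# typically the dropped articles this model is meant to restore; e.g. if
# "the" had the hypothetical id 7, then corrective_tokens_mask[7] == 1.0.)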
274 |         _, _, output_logits = model.step(
275 |             sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
276 |             True, corrective_tokens=corrective_tokens_mask)
277 | 
278 |         oov_input_tokens = [token for token in tokens if
279 |                             data_reader.is_unknown_token(token)]
280 | 
281 |         outputs = []
282 |         next_oov_token_idx = 0
283 | 
284 |         for logit in output_logits:
285 | 
286 |             max_likelihood_token_id = int(np.argmax(logit, axis=1))
287 |             # First check to see if this logit most likely points to the EOS
288 |             # identifier.
289 |             if max_likelihood_token_id == EOS_ID:
290 |                 break
291 | 
292 |             token = data_reader.convert_id_to_token(max_likelihood_token_id)
293 |             if data_reader.is_unknown_token(token):
294 |                 # Replace the "unknown" token with the most probable OOV
295 |                 # token from the input.
296 |                 if next_oov_token_idx < len(oov_input_tokens):
297 |                     # If we still have OOV input tokens available,
298 |                     # pick the next available one.
299 |                     token = oov_input_tokens[next_oov_token_idx]
300 |                     # Advance to the next OOV input token.
301 |                     next_oov_token_idx += 1
302 |                 else:
303 |                     # If we've already used all OOV input tokens,
304 |                     # then we just leave the token as "UNK"
305 |                     pass
306 | 
307 |             outputs.append(token)
308 | 
309 |         if verbose:
310 |             decoded_sentence = " ".join(outputs)
311 | 
312 |             print("Input: {}".format(" ".join(tokens)))
313 |             print("Output: {}\n".format(decoded_sentence))
314 | 
315 |         yield outputs
316 | 
317 | 
318 | def decode_sentence(sess, model, data_reader, sentence, corrective_tokens=set(),
319 |                     verbose=True):
320 |     """Used with InteractiveSession in an IPython notebook."""
321 |     return next(decode(sess, model, data_reader, [sentence.split()],
322 |                        corrective_tokens=corrective_tokens, verbose=verbose))
323 | 
324 | 
325 | def evaluate_accuracy(sess, model, data_reader, corrective_tokens, test_path,
326 |                       max_samples=None):
327 |     """Evaluates the accuracy and BLEU score of the given model."""
328 | 
329 |     import nltk  # Loading here to avoid having to bundle it in lambda.
330 | 
331 |     # Build a collection of "baseline" and model-based hypotheses, where the
332 |     # baseline is just the (potentially errant) source sequence.
333 |     baseline_hypotheses = defaultdict(list)  # The model's input
334 |     model_hypotheses = defaultdict(list)  # The model's actual predictions
335 |     targets = defaultdict(list)  # Ground truth
336 | 
337 |     errors = []
338 | 
339 |     n_samples_by_bucket = defaultdict(int)
340 |     n_correct_model_by_bucket = defaultdict(int)
341 |     n_correct_baseline_by_bucket = defaultdict(int)
342 |     n_samples = 0
343 | 
344 |     # Evaluate the model against all samples in the test data set.
345 |     for source, target in data_reader.read_samples_by_string(test_path):
346 | 
347 |         matching_buckets = [i for i, bucket in enumerate(model.buckets) if
348 |                             len(source) < bucket[0]]
349 |         if not matching_buckets:
350 |             continue
351 | 
352 |         bucket_id = matching_buckets[0]
353 | 
354 |         decoding = next(
355 |             decode(sess, model, data_reader, [source],
356 |                    corrective_tokens=corrective_tokens, verbose=False))
357 |         model_hypotheses[bucket_id].append(decoding)
358 |         if decoding == target:
359 |             n_correct_model_by_bucket[bucket_id] += 1
360 |         else:
361 |             errors.append((decoding, target))
362 | 
363 |         baseline_hypotheses[bucket_id].append(source)
364 |         if source == target:
365 |             n_correct_baseline_by_bucket[bucket_id] += 1
366 | 
367 |         # nltk.corpus_bleu expects a list of one or more reference
368 |         # translations per sample, so we wrap the target list in another list
369 |         # here.
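# (e.g. a target of ["so", "could", "you"] is appended as
# [["so", "could", "you"]]: a single-reference list per sample. The tokens
# here are hypothetical.)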
370 | targets[bucket_id].append([target]) 371 | 372 | n_samples_by_bucket[bucket_id] += 1 373 | n_samples += 1 374 | 375 | if max_samples is not None and n_samples > max_samples: 376 | break 377 | 378 | # Measure the corpus BLEU score and accuracy for the model and baseline 379 | # across all buckets. 380 | for bucket_id in targets.keys(): 381 | baseline_bleu_score = nltk.translate.bleu_score.corpus_bleu( 382 | targets[bucket_id], baseline_hypotheses[bucket_id]) 383 | model_bleu_score = nltk.translate.bleu_score.corpus_bleu( 384 | targets[bucket_id], model_hypotheses[bucket_id]) 385 | print("Bucket {}: {}".format(bucket_id, model.buckets[bucket_id])) 386 | print("\tBaseline BLEU = {:.4f}\n\tModel BLEU = {:.4f}".format( 387 | baseline_bleu_score, model_bleu_score)) 388 | print("\tBaseline Accuracy: {:.4f}".format( 389 | 1.0 * n_correct_baseline_by_bucket[bucket_id] / 390 | n_samples_by_bucket[bucket_id])) 391 | print("\tModel Accuracy: {:.4f}".format( 392 | 1.0 * n_correct_model_by_bucket[bucket_id] / 393 | n_samples_by_bucket[bucket_id])) 394 | 395 | return errors 396 | 397 | 398 | def main(_): 399 | # Determine which config we should use. 400 | if FLAGS.config == "TestConfig": 401 | config = TestConfig() 402 | elif FLAGS.config == "DefaultMovieDialogConfig": 403 | config = DefaultMovieDialogConfig() 404 | elif FLAGS.config == "DefaultPTBConfig": 405 | config = DefaultPTBConfig() 406 | else: 407 | raise ValueError("config argument not recognized; must be one of: " 408 | "TestConfig, DefaultPTBConfig, " 409 | "DefaultMovieDialogConfig") 410 | 411 | # Determine which kind of DataReader we want to use. 412 | if FLAGS.data_reader_type == "MovieDialogReader": 413 | data_reader = MovieDialogReader(config, FLAGS.train_path) 414 | elif FLAGS.data_reader_type == "PTBDataReader": 415 | data_reader = PTBDataReader(config, FLAGS.train_path) 416 | else: 417 | raise ValueError("data_reader_type argument not recognized; must be " 418 | "one of: MovieDialogReader, PTBDataReader") 419 | 420 | if FLAGS.decode: 421 | # Decode test sentences. 422 | with tf.Session() as session: 423 | model = create_model(session, True, FLAGS.model_path, config=config) 424 | print("Loaded model. Beginning decoding.") 425 | decodings = decode(session, model=model, data_reader=data_reader, 426 | data_to_decode=data_reader.read_tokens( 427 | FLAGS.test_path), verbose=False) 428 | # Write the decoded tokens to stdout. 429 | for tokens in decodings: 430 | print(" ".join(tokens)) 431 | sys.stdout.flush() 432 | else: 433 | print("Training model.") 434 | train(data_reader, FLAGS.train_path, FLAGS.val_path, FLAGS.model_path) 435 | 436 | 437 | if __name__ == "__main__": 438 | tf.app.run() 439 | -------------------------------------------------------------------------------- /data_reader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import Counter 6 | 7 | # Define constants associated with the usual special-case tokens. 8 | PAD_ID = 0 9 | GO_ID = 1 10 | EOS_ID = 2 11 | 12 | PAD_TOKEN = "PAD" 13 | EOS_TOKEN = "EOS" 14 | GO_TOKEN = "GO" 15 | 16 | 17 | class DataReader(object): 18 | 19 | def __init__(self, config, train_path=None, token_to_id=None, 20 | special_tokens=(), dataset_copies=1): 21 | self.config = config 22 | self.dataset_copies = dataset_copies 23 | 24 | # Construct vocabulary. 
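# (Toy walk-through of the construction below, with hypothetical counts:
# Counter({"the": 3, "a": 2, "cat": 1}) sorts to
# [("the", 3), ("a", 2), ("cat", 1)]; after the special tokens are
# prepended, ids follow their declared order, PAD=0, GO=1, EOS=2,
# then "the"=3, "a"=4, "cat"=5.)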
25 |         max_vocabulary_size = self.config.max_vocabulary_size
26 | 
27 |         if train_path is None:
28 |             self.token_to_id = token_to_id
29 |         else:
30 |             token_counts = Counter()
31 | 
32 |             for tokens in self.read_tokens(train_path):
33 |                 token_counts.update(tokens)
34 | 
35 |             self.token_counts = token_counts
36 | 
37 |             # Keep the most frequent tokens, up to max_vocabulary_size words.
38 |             count_pairs = sorted(token_counts.items(),
39 |                                  key=lambda x: (-x[1], x[0]))
40 |             vocabulary, _ = list(zip(*count_pairs))
41 |             vocabulary = list(vocabulary)
42 |             # Insert the special tokens at the beginning.
43 |             vocabulary[0:0] = special_tokens
44 |             full_token_and_id = zip(vocabulary, range(len(vocabulary)))
45 |             self.full_token_to_id = dict(full_token_and_id)
46 |             self.token_to_id = dict(full_token_and_id[:max_vocabulary_size])
47 | 
48 |         self.id_to_token = {v: k for k, v in self.token_to_id.items()}
49 | 
50 |     def read_tokens(self, path):
51 |         """
52 |         Reads the given file line by line and yields the list of tokens present
53 |         in each line.
54 | 
55 |         :param path: path to the file to read
56 |         :return: a generator over token lists, one per line
57 |         """
58 |         raise NotImplementedError("Must implement read_tokens")
59 | 
60 |     def read_samples_by_string(self, path):
61 |         """
62 |         Reads the given file line by line and yields the word-form of each
63 |         derived sample.
64 | 
65 |         :param path: path to the file to read
66 |         :return: a generator over (source tokens, target tokens) pairs
67 |         """
68 |         raise NotImplementedError("Must implement read_samples_by_string")
69 | 
70 |     def unknown_token(self):
71 |         raise NotImplementedError("Must implement unknown_token")
72 | 
73 |     def convert_token_to_id(self, token):
74 |         """
75 |         Converts a token to its id, mapping OOV tokens to the unknown token.
76 |         :param token: the token to look up
77 |         :return: the integer id associated with the token
78 |         """
79 |         token_with_id = token if token in self.token_to_id else \
80 |             self.unknown_token()
81 |         return self.token_to_id[token_with_id]
82 | 
83 |     def convert_id_to_token(self, token_id):
84 |         return self.id_to_token[token_id]
85 | 
86 |     def is_unknown_token(self, token):
87 |         """
88 |         True if the given token is out of the vocabulary used or if it is the
89 |         actual unknown token.
90 | 
91 |         :param token: the token to check
92 |         :return: whether the token should be treated as unknown
93 |         """
94 |         return token not in self.token_to_id or token == self.unknown_token()
95 | 
96 |     def sentence_to_token_ids(self, sentence):
97 |         """
98 |         Converts a whitespace-delimited sentence into a list of word ids.
99 |         """
100 |         return [self.convert_token_to_id(word) for word in sentence.split()]
101 | 
102 |     def token_ids_to_tokens(self, word_ids):
103 |         """
104 |         Converts a list of word ids to a list of their corresponding words.
105 |         """
106 |         return [self.convert_id_to_token(word) for word in word_ids]
107 | 
108 |     def read_samples(self, path):
109 |         """
110 |         Reads samples in id form, appending EOS to each target sequence.
111 |         :param path: path to the file to read samples from
112 |         :return: a generator over (source ids, target ids) pairs
113 |         """
114 |         for source_words, target_words in self.read_samples_by_string(path):
115 |             source = [self.convert_token_to_id(word) for word in source_words]
116 |             target = [self.convert_token_to_id(word) for word in target_words]
117 |             target.append(EOS_ID)
118 | 
119 |             yield source, target
120 | 
121 |     def build_dataset(self, path):
122 |         dataset = [[] for _ in self.config.buckets]
123 | 
124 |         # Make multiple copies of the dataset so that we synthesize different
125 |         # dropouts.
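# (Each extra copy is another pass over the data; since the reader's
# perturbations are presumably random per read, the same target sentence can
# contribute several different errant sources, so dataset_copies > 1 acts as
# a simple form of data augmentation.)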
126 | for _ in range(self.dataset_copies): 127 | for source, target in self.read_samples(path): 128 | for bucket_id, (source_size, target_size) in enumerate( 129 | self.config.buckets): 130 | if len(source) < source_size and len( 131 | target) < target_size: 132 | dataset[bucket_id].append([source, target]) 133 | break 134 | 135 | return dataset 136 | 137 | -------------------------------------------------------------------------------- /dtc_lambda.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import json 4 | import os 5 | import pickle 6 | 7 | import boto3 8 | import tensorflow as tf 9 | 10 | from correct_text import create_model, DefaultMovieDialogConfig, decode_sentence 11 | from text_corrector_data_readers import MovieDialogReader 12 | 13 | 14 | def safe_mkdir(path): 15 | try: 16 | os.mkdir(path) 17 | except OSError: 18 | pass 19 | 20 | 21 | def download(client, filename, local_path=None, s3_path=None): 22 | if s3_path is None: 23 | s3_path = MODEL_PARAMS_DIR + "/" + filename 24 | if local_path is None: 25 | local_path = os.path.join(MODEL_PATH, filename) 26 | 27 | print("Downloading " + filename) 28 | client.download_file(BUCKET_NAME, s3_path, local_path) 29 | 30 | 31 | # Define resources on S3. 32 | BUCKET_NAME = "deeptextcorrecter" 33 | ROOT_DATA_PATH = "/tmp/" 34 | MODEL_PARAMS_DIR = "model_params" 35 | MODEL_PATH = os.path.join(ROOT_DATA_PATH, MODEL_PARAMS_DIR) 36 | 37 | # Create tmp dirs for storing data locally. 38 | safe_mkdir(ROOT_DATA_PATH) 39 | safe_mkdir(MODEL_PATH) 40 | 41 | # Download files from S3 to local disk. 42 | s3_client = boto3.client('s3') 43 | 44 | model_ckpt = "41900" 45 | tf_meta_filename = "translate.ckpt-{}.meta".format(model_ckpt) 46 | download(s3_client, tf_meta_filename) 47 | 48 | tf_params_filename = "translate.ckpt-{}".format(model_ckpt) 49 | download(s3_client, tf_params_filename) 50 | 51 | tf_ckpt_filename = "checkpoint" 52 | download(s3_client, tf_ckpt_filename) 53 | 54 | corrective_tokens_filename = "corrective_tokens.pickle" 55 | corrective_tokens_path = os.path.join(ROOT_DATA_PATH, 56 | corrective_tokens_filename) 57 | download(s3_client, corrective_tokens_filename, 58 | local_path=corrective_tokens_path) 59 | 60 | token_to_id_filename = "token_to_id.pickle" 61 | token_to_id_path = os.path.join(ROOT_DATA_PATH, token_to_id_filename) 62 | download(s3_client, token_to_id_filename, local_path=token_to_id_path) 63 | 64 | # Load model. 
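# (The module-level code below runs once per Lambda container at cold start;
# process_event, defined at the end of this file, then serves each
# invocation. A hypothetical test event of
#     {"text": "this is a example of it correcting text"}
# would come back as {"input": <the text above>, "output": <the corrected
# sentence>}.)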
65 | config = DefaultMovieDialogConfig() 66 | sess = tf.Session() 67 | print("Loading model") 68 | model = create_model(sess, True, MODEL_PATH, config=config) 69 | print("Loaded model") 70 | 71 | with open(corrective_tokens_path) as f: 72 | corrective_tokens = pickle.load(f) 73 | with open(token_to_id_path) as f: 74 | token_to_id = pickle.load(f) 75 | data_reader = MovieDialogReader(config, token_to_id=token_to_id) 76 | print("Done initializing.") 77 | 78 | 79 | def process_event(event, context): 80 | print("Received event: " + json.dumps(event, indent=2)) 81 | 82 | outputs = decode_sentence(sess, model, data_reader, event["text"], 83 | corrective_tokens=corrective_tokens, 84 | verbose=False) 85 | return {"input": event["text"], "output": " ".join(outputs)} 86 | -------------------------------------------------------------------------------- /preprocessors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atpaino/deep-text-corrector/ebf44ee81f61b21d378bbe6b98d715ae42521f6d/preprocessors/__init__.py -------------------------------------------------------------------------------- /preprocessors/preprocess_movie_dialogs.py: -------------------------------------------------------------------------------- 1 | """Preprocesses Cornell Movie Dialog data.""" 2 | import nltk 3 | import tensorflow as tf 4 | 5 | tf.app.flags.DEFINE_string("raw_data", "", "Raw data path") 6 | tf.app.flags.DEFINE_string("out_file", "", "File to write preprocessed data " 7 | "to.") 8 | 9 | FLAGS = tf.app.flags.FLAGS 10 | 11 | 12 | def main(_): 13 | with open(FLAGS.raw_data, "r") as raw_data, \ 14 | open(FLAGS.out_file, "w") as out: 15 | for line in raw_data: 16 | parts = line.split(" +++$+++ ") 17 | dialog_line = parts[-1] 18 | s = dialog_line.strip().lower().decode("utf-8", "ignore") 19 | preprocessed_line = " ".join(nltk.word_tokenize(s)) 20 | out.write(preprocessed_line + "\n") 21 | 22 | if __name__ == "__main__": 23 | tf.app.run() 24 | -------------------------------------------------------------------------------- /seq2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Library for creating sequence-to-sequence models in TensorFlow. 17 | 18 | Sequence-to-sequence recurrent neural networks can learn complex functions 19 | that map input sequences to output sequences. These models yield very good 20 | results on a number of tasks, such as speech recognition, parsing, machine 21 | translation, or even constructing automated replies to emails. 22 | 23 | Before using this module, it is recommended to read the TensorFlow tutorial 24 | on sequence-to-sequence models. 
It explains the basic concepts of this module 25 | and shows an end-to-end example of how to build a translation model. 26 | https://www.tensorflow.org/versions/master/tutorials/seq2seq/index.html 27 | 28 | Here is an overview of functions available in this module. They all use 29 | a very similar interface, so after reading the above tutorial and using 30 | one of them, others should be easy to substitute. 31 | 32 | * Full sequence-to-sequence models. 33 | - basic_rnn_seq2seq: The most basic RNN-RNN model. 34 | - tied_rnn_seq2seq: The basic model with tied encoder and decoder weights. 35 | - embedding_rnn_seq2seq: The basic model with input embedding. 36 | - embedding_tied_rnn_seq2seq: The tied model with input embedding. 37 | - embedding_attention_seq2seq: Advanced model with input embedding and 38 | the neural attention mechanism; recommended for complex tasks. 39 | 40 | * Multi-task sequence-to-sequence models. 41 | - one2many_rnn_seq2seq: The embedding model with multiple decoders. 42 | 43 | * Decoders (when you write your own encoder, you can use these to decode; 44 | e.g., if you want to write a model that generates captions for images). 45 | - rnn_decoder: The basic decoder based on a pure RNN. 46 | - attention_decoder: A decoder that uses the attention mechanism. 47 | 48 | * Losses. 49 | - sequence_loss: Loss for a sequence model returning average log-perplexity. 50 | - sequence_loss_by_example: As above, but not averaging over all examples. 51 | 52 | * model_with_buckets: A convenience function to create models with bucketing 53 | (see the tutorial above for an explanation of why and how to use it). 54 | """ 55 | 56 | from __future__ import absolute_import 57 | from __future__ import division 58 | from __future__ import print_function 59 | 60 | # We disable pylint because we need python3 compatibility. 61 | from six.moves import xrange # pylint: disable=redefined-builtin 62 | from six.moves import zip # pylint: disable=redefined-builtin 63 | 64 | from tensorflow.python import shape 65 | from tensorflow.python.framework import dtypes 66 | from tensorflow.python.framework import ops 67 | from tensorflow.python.ops import array_ops 68 | from tensorflow.python.ops import control_flow_ops 69 | from tensorflow.python.ops import embedding_ops 70 | from tensorflow.python.ops import math_ops 71 | from tensorflow.python.ops import nn_ops 72 | from tensorflow.python.ops import rnn 73 | from tensorflow.python.ops import rnn_cell 74 | from tensorflow.python.ops import variable_scope 75 | from tensorflow.python.util import nest 76 | 77 | # TODO(ebrevdo): Remove once _linear is fully deprecated. 78 | linear = rnn_cell._linear # pylint: disable=protected-access 79 | 80 | 81 | def _extract_argmax_and_embed(embedding, output_projection=None, 82 | update_embedding=True): 83 | """Get a loop_function that extracts the previous symbol and embeds it. 84 | 85 | Args: 86 | embedding: embedding tensor for symbols. 87 | output_projection: None or a pair (W, B). If provided, each fed previous 88 | output will first be multiplied by W and added B. 89 | update_embedding: Boolean; if False, the gradients will not propagate 90 | through the embeddings. 91 | 92 | Returns: 93 | A loop function. 94 | """ 95 | def loop_function(prev, _): 96 | # decoder outputs thus far. 
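# (That is, prev holds the previous time-step's raw decoder output; the code
# below projects it through (W, B) when an output_projection is given, takes
# the argmax symbol, and returns that symbol's embedding as the next input,
# i.e. a greedy decoder.)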
97 | if output_projection is not None: 98 | prev = nn_ops.xw_plus_b( 99 | prev, output_projection[0], output_projection[1]) 100 | prev_symbol = math_ops.argmax(prev, 1) 101 | # Note that gradients will not propagate through the second parameter of 102 | # embedding_lookup. 103 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 104 | if not update_embedding: 105 | emb_prev = array_ops.stop_gradient(emb_prev) 106 | return emb_prev, prev_symbol 107 | return loop_function 108 | 109 | 110 | def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, 111 | scope=None): 112 | """RNN decoder for the sequence-to-sequence model. 113 | 114 | Args: 115 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 116 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 117 | cell: rnn_cell.RNNCell defining the cell function and size. 118 | loop_function: If not None, this function will be applied to the i-th output 119 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 120 | except for the first element ("GO" symbol). This can be used for decoding, 121 | but also for training to emulate http://arxiv.org/abs/1506.03099. 122 | Signature -- loop_function(prev, i) = next 123 | * prev is a 2D Tensor of shape [batch_size x output_size], 124 | * i is an integer, the step number (when advanced control is needed), 125 | * next is a 2D Tensor of shape [batch_size x input_size]. 126 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 127 | 128 | Returns: 129 | A tuple of the form (outputs, state), where: 130 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 131 | shape [batch_size x output_size] containing generated outputs. 132 | state: The state of each cell at the final time-step. 133 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 134 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 135 | states can be the same. They are different for LSTM cells though.) 136 | """ 137 | with variable_scope.variable_scope(scope or "rnn_decoder"): 138 | state = initial_state 139 | outputs = [] 140 | prev = None 141 | for i, inp in enumerate(decoder_inputs): 142 | if loop_function is not None and prev is not None: 143 | with variable_scope.variable_scope("loop_function", reuse=True): 144 | inp = loop_function(prev, i) 145 | if i > 0: 146 | variable_scope.get_variable_scope().reuse_variables() 147 | output, state = cell(inp, state) 148 | outputs.append(output) 149 | if loop_function is not None: 150 | prev = output 151 | return outputs, state 152 | 153 | 154 | def basic_rnn_seq2seq( 155 | encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None): 156 | """Basic RNN sequence-to-sequence model. 157 | 158 | This model first runs an RNN to encode encoder_inputs into a state vector, 159 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 160 | Encoder and decoder use the same RNN cell type, but don't share parameters. 161 | 162 | Args: 163 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 164 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 165 | cell: rnn_cell.RNNCell defining the cell function and size. 166 | dtype: The dtype of the initial state of the RNN cell (default: tf.float32). 167 | scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". 
168 | 169 | Returns: 170 | A tuple of the form (outputs, state), where: 171 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 172 | shape [batch_size x output_size] containing the generated outputs. 173 | state: The state of each decoder cell in the final time-step. 174 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 175 | """ 176 | with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): 177 | _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype) 178 | return rnn_decoder(decoder_inputs, enc_state, cell) 179 | 180 | 181 | def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, 182 | loop_function=None, dtype=dtypes.float32, scope=None): 183 | """RNN sequence-to-sequence model with tied encoder and decoder parameters. 184 | 185 | This model first runs an RNN to encode encoder_inputs into a state vector, and 186 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 187 | Encoder and decoder use the same RNN cell and share parameters. 188 | 189 | Args: 190 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 191 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 192 | cell: rnn_cell.RNNCell defining the cell function and size. 193 | loop_function: If not None, this function will be applied to i-th output 194 | in order to generate i+1-th input, and decoder_inputs will be ignored, 195 | except for the first element ("GO" symbol), see rnn_decoder for details. 196 | dtype: The dtype of the initial state of the rnn cell (default: tf.float32). 197 | scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". 198 | 199 | Returns: 200 | A tuple of the form (outputs, state), where: 201 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 202 | shape [batch_size x output_size] containing the generated outputs. 203 | state: The state of each decoder cell in each time-step. This is a list 204 | with length len(decoder_inputs) -- one item for each time-step. 205 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 206 | """ 207 | with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): 208 | scope = scope or "tied_rnn_seq2seq" 209 | _, enc_state = rnn.rnn( 210 | cell, encoder_inputs, dtype=dtype, scope=scope) 211 | variable_scope.get_variable_scope().reuse_variables() 212 | return rnn_decoder(decoder_inputs, enc_state, cell, 213 | loop_function=loop_function, scope=scope) 214 | 215 | 216 | def embedding_rnn_decoder(decoder_inputs, 217 | initial_state, 218 | cell, 219 | num_symbols, 220 | embedding_size, 221 | output_projection=None, 222 | feed_previous=False, 223 | update_embedding_for_previous=True, 224 | scope=None): 225 | """RNN decoder with embedding and a pure-decoding option. 226 | 227 | Args: 228 | decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs). 229 | initial_state: 2D Tensor [batch_size x cell.state_size]. 230 | cell: rnn_cell.RNNCell defining the cell function. 231 | num_symbols: Integer, how many symbols come into the embedding. 232 | embedding_size: Integer, the length of the embedding vector for each symbol. 233 | output_projection: None or a pair (W, B) of output projection weights and 234 | biases; W has shape [output_size x num_symbols] and B has 235 | shape [num_symbols]; if provided and feed_previous=True, each fed 236 | previous output will first be multiplied by W and added B. 
237 | feed_previous: Boolean; if True, only the first of decoder_inputs will be 238 | used (the "GO" symbol), and all other decoder inputs will be generated by: 239 | next = embedding_lookup(embedding, argmax(previous_output)), 240 | In effect, this implements a greedy decoder. It can also be used 241 | during training to emulate http://arxiv.org/abs/1506.03099. 242 | If False, decoder_inputs are used as given (the standard decoder case). 243 | update_embedding_for_previous: Boolean; if False and feed_previous=True, 244 | only the embedding for the first symbol of decoder_inputs (the "GO" 245 | symbol) will be updated by back propagation. Embeddings for the symbols 246 | generated from the decoder itself remain unchanged. This parameter has 247 | no effect if feed_previous=False. 248 | scope: VariableScope for the created subgraph; defaults to 249 | "embedding_rnn_decoder". 250 | 251 | Returns: 252 | A tuple of the form (outputs, state), where: 253 | outputs: A list of the same length as decoder_inputs of 2D Tensors. The 254 | output is of shape [batch_size x cell.output_size] when 255 | output_projection is not None (and represents the dense representation 256 | of predicted tokens). It is of shape [batch_size x num_decoder_symbols] 257 | when output_projection is None. 258 | state: The state of each decoder cell in each time-step. This is a list 259 | with length len(decoder_inputs) -- one item for each time-step. 260 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 261 | 262 | Raises: 263 | ValueError: When output_projection has the wrong shape. 264 | """ 265 | with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope: 266 | if output_projection is not None: 267 | dtype = scope.dtype 268 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 269 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 270 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 271 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 272 | 273 | embedding = variable_scope.get_variable("embedding", 274 | [num_symbols, embedding_size]) 275 | loop_function = _extract_argmax_and_embed( 276 | embedding, output_projection, 277 | update_embedding_for_previous) if feed_previous else None 278 | emb_inp = ( 279 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs) 280 | return rnn_decoder(emb_inp, initial_state, cell, 281 | loop_function=loop_function) 282 | 283 | 284 | def embedding_rnn_seq2seq(encoder_inputs, 285 | decoder_inputs, 286 | cell, 287 | num_encoder_symbols, 288 | num_decoder_symbols, 289 | embedding_size, 290 | output_projection=None, 291 | feed_previous=False, 292 | dtype=None, 293 | scope=None): 294 | """Embedding RNN sequence-to-sequence model. 295 | 296 | This model first embeds encoder_inputs by a newly created embedding (of shape 297 | [num_encoder_symbols x input_size]). Then it runs an RNN to encode 298 | embedded encoder_inputs into a state vector. Next, it embeds decoder_inputs 299 | by another newly created embedding (of shape [num_decoder_symbols x 300 | input_size]). Then it runs RNN decoder, initialized with the last 301 | encoder state, on embedded decoder_inputs. 302 | 303 | Args: 304 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 305 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 306 | cell: rnn_cell.RNNCell defining the cell function and size. 307 | num_encoder_symbols: Integer; number of symbols on the encoder side. 
308 | num_decoder_symbols: Integer; number of symbols on the decoder side. 309 | embedding_size: Integer, the length of the embedding vector for each symbol. 310 | output_projection: None or a pair (W, B) of output projection weights and 311 | biases; W has shape [output_size x num_decoder_symbols] and B has 312 | shape [num_decoder_symbols]; if provided and feed_previous=True, each 313 | fed previous output will first be multiplied by W and added B. 314 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 315 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 316 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 317 | If False, decoder_inputs are used as given (the standard decoder case). 318 | dtype: The dtype of the initial state for both the encoder and encoder 319 | rnn cells (default: tf.float32). 320 | scope: VariableScope for the created subgraph; defaults to 321 | "embedding_rnn_seq2seq" 322 | 323 | Returns: 324 | A tuple of the form (outputs, state), where: 325 | outputs: A list of the same length as decoder_inputs of 2D Tensors. The 326 | output is of shape [batch_size x cell.output_size] when 327 | output_projection is not None (and represents the dense representation 328 | of predicted tokens). It is of shape [batch_size x num_decoder_symbols] 329 | when output_projection is None. 330 | state: The state of each decoder cell in each time-step. This is a list 331 | with length len(decoder_inputs) -- one item for each time-step. 332 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 333 | """ 334 | with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope: 335 | if dtype is not None: 336 | scope.set_dtype(dtype) 337 | else: 338 | dtype = scope.dtype 339 | 340 | # Encoder. 341 | encoder_cell = rnn_cell.EmbeddingWrapper( 342 | cell, embedding_classes=num_encoder_symbols, 343 | embedding_size=embedding_size) 344 | _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype) 345 | 346 | # Decoder. 347 | if output_projection is None: 348 | cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 349 | 350 | if isinstance(feed_previous, bool): 351 | return embedding_rnn_decoder( 352 | decoder_inputs, 353 | encoder_state, 354 | cell, 355 | num_decoder_symbols, 356 | embedding_size, 357 | output_projection=output_projection, 358 | feed_previous=feed_previous) 359 | 360 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 361 | def decoder(feed_previous_bool): 362 | reuse = None if feed_previous_bool else True 363 | with variable_scope.variable_scope( 364 | variable_scope.get_variable_scope(), reuse=reuse) as scope: 365 | outputs, state = embedding_rnn_decoder( 366 | decoder_inputs, encoder_state, cell, num_decoder_symbols, 367 | embedding_size, output_projection=output_projection, 368 | feed_previous=feed_previous_bool, 369 | update_embedding_for_previous=False) 370 | state_list = [state] 371 | if nest.is_sequence(state): 372 | state_list = nest.flatten(state) 373 | return outputs + state_list 374 | 375 | outputs_and_state = control_flow_ops.cond(feed_previous, 376 | lambda: decoder(True), 377 | lambda: decoder(False)) 378 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 
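# (cond returns a single flat list: the first outputs_len entries are the
# decoder outputs, and the tail is the flattened final state, which is
# re-packed below to mirror the encoder state's structure.)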
379 | state_list = outputs_and_state[outputs_len:] 380 | state = state_list[0] 381 | if nest.is_sequence(encoder_state): 382 | state = nest.pack_sequence_as(structure=encoder_state, 383 | flat_sequence=state_list) 384 | return outputs_and_state[:outputs_len], state 385 | 386 | 387 | def embedding_tied_rnn_seq2seq(encoder_inputs, 388 | decoder_inputs, 389 | cell, 390 | num_symbols, 391 | embedding_size, 392 | num_decoder_symbols=None, 393 | output_projection=None, 394 | feed_previous=False, 395 | dtype=None, 396 | scope=None): 397 | """Embedding RNN sequence-to-sequence model with tied (shared) parameters. 398 | 399 | This model first embeds encoder_inputs by a newly created embedding (of shape 400 | [num_symbols x input_size]). Then it runs an RNN to encode embedded 401 | encoder_inputs into a state vector. Next, it embeds decoder_inputs using 402 | the same embedding. Then it runs RNN decoder, initialized with the last 403 | encoder state, on embedded decoder_inputs. The decoder output is over symbols 404 | from 0 to num_decoder_symbols - 1 if num_decoder_symbols is none; otherwise it 405 | is over 0 to num_symbols - 1. 406 | 407 | Args: 408 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 409 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 410 | cell: rnn_cell.RNNCell defining the cell function and size. 411 | num_symbols: Integer; number of symbols for both encoder and decoder. 412 | embedding_size: Integer, the length of the embedding vector for each symbol. 413 | num_decoder_symbols: Integer; number of output symbols for decoder. If 414 | provided, the decoder output is over symbols 0 to num_decoder_symbols - 1. 415 | Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that 416 | this assumes that the vocabulary is set up such that the first 417 | num_decoder_symbols of num_symbols are part of decoding. 418 | output_projection: None or a pair (W, B) of output projection weights and 419 | biases; W has shape [output_size x num_symbols] and B has 420 | shape [num_symbols]; if provided and feed_previous=True, each 421 | fed previous output will first be multiplied by W and added B. 422 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 423 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 424 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 425 | If False, decoder_inputs are used as given (the standard decoder case). 426 | dtype: The dtype to use for the initial RNN states (default: tf.float32). 427 | scope: VariableScope for the created subgraph; defaults to 428 | "embedding_tied_rnn_seq2seq". 429 | 430 | Returns: 431 | A tuple of the form (outputs, state), where: 432 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 433 | shape [batch_size x output_symbols] containing the generated 434 | outputs where output_symbols = num_decoder_symbols if 435 | num_decoder_symbols is not None otherwise output_symbols = num_symbols. 436 | state: The state of each decoder cell at the final time-step. 437 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 438 | 439 | Raises: 440 | ValueError: When output_projection has the wrong shape. 
441 |   """
442 |   with variable_scope.variable_scope(
443 |       scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope:
444 |     dtype = scope.dtype
445 | 
446 |     if output_projection is not None:
447 |       proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
448 |       proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
449 |       proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
450 |       proj_biases.get_shape().assert_is_compatible_with([num_symbols])
451 | 
452 |     embedding = variable_scope.get_variable(
453 |         "embedding", [num_symbols, embedding_size], dtype=dtype)
454 | 
455 |     emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x)
456 |                           for x in encoder_inputs]
457 |     emb_decoder_inputs = [embedding_ops.embedding_lookup(embedding, x)
458 |                           for x in decoder_inputs]
459 | 
460 |     output_symbols = num_symbols
461 |     if num_decoder_symbols is not None:
462 |       output_symbols = num_decoder_symbols
463 |     if output_projection is None:
464 |       cell = rnn_cell.OutputProjectionWrapper(cell, output_symbols)
465 | 
466 |     if isinstance(feed_previous, bool):
467 |       loop_function = _extract_argmax_and_embed(
468 |           embedding, output_projection, True) if feed_previous else None
469 |       return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell,
470 |                               loop_function=loop_function, dtype=dtype)
471 | 
472 |     # If feed_previous is a Tensor, we construct 2 graphs and use cond.
473 |     def decoder(feed_previous_bool):
474 |       loop_function = _extract_argmax_and_embed(
475 |           embedding, output_projection, False) if feed_previous_bool else None
476 |       reuse = None if feed_previous_bool else True
477 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
478 |                                          reuse=reuse):
479 |         outputs, state = tied_rnn_seq2seq(
480 |             emb_encoder_inputs, emb_decoder_inputs, cell,
481 |             loop_function=loop_function, dtype=dtype)
482 |         state_list = [state]
483 |         if nest.is_sequence(state):
484 |           state_list = nest.flatten(state)
485 |         return outputs + state_list
486 | 
487 |     outputs_and_state = control_flow_ops.cond(feed_previous,
488 |                                               lambda: decoder(True),
489 |                                               lambda: decoder(False))
490 |     outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
491 |     state_list = outputs_and_state[outputs_len:]
492 |     state = state_list[0]
493 |     # Calculate zero-state to know its structure.
494 |     static_batch_size = encoder_inputs[0].get_shape()[0]
495 |     for inp in encoder_inputs[1:]:
496 |       static_batch_size.merge_with(inp.get_shape()[0])
497 |     batch_size = static_batch_size.value
498 |     if batch_size is None:
499 |       batch_size = array_ops.shape(encoder_inputs[0])[0]
500 |     zero_state = cell.zero_state(batch_size, dtype)
501 |     if nest.is_sequence(zero_state):
502 |       state = nest.pack_sequence_as(structure=zero_state,
503 |                                     flat_sequence=state_list)
504 |     return outputs_and_state[:outputs_len], state
505 | 
506 | 
507 | def attention_decoder(decoder_inputs,
508 |                       initial_state,
509 |                       attention_states,
510 |                       cell,
511 |                       output_size=None,
512 |                       num_heads=1,
513 |                       loop_function=None,
514 |                       dtype=None,
515 |                       scope=None,
516 |                       initial_state_attention=False):
517 |   """RNN decoder with attention for the sequence-to-sequence model.
518 | 
519 |   In this context "attention" means that, during decoding, the RNN can look up
520 |   information in the additional tensor attention_states, and it does this by
521 |   focusing on a few entries from the tensor. This model has proven to yield
522 |   especially good results in a number of sequence-to-sequence tasks. This
523 |   implementation is based on http://arxiv.org/abs/1412.7449 (see below for
524 |   details). It is recommended for complex sequence-to-sequence tasks.
525 | 
526 |   Args:
527 |     decoder_inputs: A list of 2D Tensors [batch_size x input_size].
528 |     initial_state: 2D Tensor [batch_size x cell.state_size].
529 |     attention_states: 3D Tensor [batch_size x attn_length x attn_size].
530 |     cell: rnn_cell.RNNCell defining the cell function and size.
531 |     output_size: Size of the output vectors; if None, we use cell.output_size.
532 |     num_heads: Number of attention heads that read from attention_states.
533 |     loop_function: If not None, this function will be applied to i-th output
534 |       in order to generate i+1-th input, and decoder_inputs will be ignored,
535 |       except for the first element ("GO" symbol). This can be used for decoding,
536 |       but also for training to emulate http://arxiv.org/abs/1506.03099.
537 |       Signature -- loop_function(prev, i) = next
538 |         * prev is a 2D Tensor of shape [batch_size x output_size],
539 |         * i is an integer, the step number (when advanced control is needed),
540 |         * next is a 2D Tensor of shape [batch_size x input_size].
541 |     dtype: The dtype to use for the RNN initial state (default: tf.float32).
542 |     scope: VariableScope for the created subgraph; default: "attention_decoder".
543 |     initial_state_attention: If False (default), initial attentions are zero.
544 |       If True, initialize the attentions from the initial state and attention
545 |       states -- useful when we wish to resume decoding from a previously
546 |       stored decoder state and attention states.
547 | 
548 |   Returns:
549 |     A tuple of the form (outputs, state), where:
550 |       outputs: A list of the same length as decoder_inputs of 2D Tensors of
551 |         shape [batch_size x output_size]. These represent the generated outputs.
552 |         Output i is computed from input i (which is either the i-th element
553 |         of decoder_inputs or loop_function(output {i-1}, i)) as follows.
554 |         First, we run the cell on a combination of the input and previous
555 |         attention masks:
556 |           cell_output, new_state = cell(linear(input, prev_attn), prev_state).
557 |         Then, we calculate new attention masks:
558 |           new_attn = softmax(V^T * tanh(W * attention_states + U * new_state))
559 |         and then we calculate the output:
560 |           output = linear(cell_output, new_attn).
561 |       state: The state of each decoder cell at the final time-step.
562 |         It is a 2D Tensor of shape [batch_size x cell.state_size].
563 | 
564 |   Raises:
565 |     ValueError: when num_heads is not positive, there are no inputs, shapes
566 |       of attention_states are not set, or input size cannot be inferred
567 |       from the input.
568 |   """
569 |   if not decoder_inputs:
570 |     raise ValueError("Must provide at least 1 input to attention decoder.")
571 |   if num_heads < 1:
572 |     raise ValueError("With fewer than one head, use a non-attention decoder.")
573 |   if attention_states.get_shape()[2].value is None:
574 |     raise ValueError("Shape[2] of attention_states must be known: %s"
575 |                      % attention_states.get_shape())
576 |   if output_size is None:
577 |     output_size = cell.output_size
578 | 
579 |   with variable_scope.variable_scope(
580 |       scope or "attention_decoder", dtype=dtype) as scope:
581 |     dtype = scope.dtype
582 | 
583 |     batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
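    # Shape bookkeeping (illustrative): attention_states arrives as
    # [batch_size, attn_length, attn_size]; it is reshaped below to
    # [batch_size, attn_length, 1, attn_size] so that the W1 * h_t terms for
    # every encoder position can be computed at once with a 1x1 conv2d.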
584 |     attn_length = attention_states.get_shape()[1].value
585 |     if attn_length is None:
586 |       attn_length = array_ops.shape(attention_states)[1]
587 |     attn_size = attention_states.get_shape()[2].value
588 | 
589 |     # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
590 |     hidden = array_ops.reshape(
591 |         attention_states, [-1, attn_length, 1, attn_size])
592 |     hidden_features = []
593 |     v = []
594 |     attention_vec_size = attn_size  # Size of query vectors for attention.
595 |     for a in xrange(num_heads):
596 |       k = variable_scope.get_variable("AttnW_%d" % a,
597 |                                       [1, 1, attn_size, attention_vec_size])
598 |       hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
599 |       v.append(
600 |           variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))
601 | 
602 |     state = initial_state
603 | 
604 |     def attention(query):
605 |       """Put attention masks on hidden using hidden_features and query."""
606 |       ds = []  # Results of attention reads will be stored here.
607 |       if nest.is_sequence(query):  # If the query is a tuple, flatten it.
608 |         query_list = nest.flatten(query)
609 |         for q in query_list:  # Check that ndims == 2 if specified.
610 |           ndims = q.get_shape().ndims
611 |           if ndims:
612 |             assert ndims == 2
613 |         query = array_ops.concat(1, query_list)
614 |       for a in xrange(num_heads):
615 |         with variable_scope.variable_scope("Attention_%d" % a):
616 |           y = linear(query, attention_vec_size, True)
617 |           y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
618 |           # Attention mask is a softmax of v^T * tanh(...).
619 |           s = math_ops.reduce_sum(
620 |               v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
621 |           a = nn_ops.softmax(s)
622 |           # Now calculate the attention-weighted vector d.
623 |           d = math_ops.reduce_sum(
624 |               array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden,
625 |               [1, 2])
626 |           ds.append(array_ops.reshape(d, [-1, attn_size]))
627 |       return ds
628 | 
629 |     outputs = []
630 |     prev = None
631 |     batch_attn_size = array_ops.pack([batch_size, attn_size])
632 |     attns = [array_ops.zeros(batch_attn_size, dtype=dtype)
633 |              for _ in xrange(num_heads)]
634 |     for a in attns:  # Ensure the second shape of attention vectors is set.
635 |       a.set_shape([None, attn_size])
636 |     if initial_state_attention:
637 |       attns = attention(initial_state)
638 | 
639 |     for i, inp in enumerate(decoder_inputs):
640 |       if i > 0:
641 |         variable_scope.get_variable_scope().reuse_variables()
642 |       # If loop_function is set, we use it instead of decoder_inputs.
643 |       if loop_function is not None and prev is not None:
644 |         with variable_scope.variable_scope("loop_function", reuse=True):
645 |           # inp is the embedding to be used as the next input;
646 |           # inp_symbol is the index of the token to be used as the
647 |           # next input.
648 |           inp, inp_symbol = loop_function(prev, i)
649 |       # Merge input and previous attentions into one vector of the right size.
650 |       input_size = inp.get_shape().with_rank(2)[1]
651 |       if input_size.value is None:
652 |         raise ValueError("Could not infer input size from input: %s" % inp.name)
653 |       x = linear([inp] + attns, input_size, True)
654 |       # Run the RNN.
655 |       cell_output, state = cell(x, state)
656 |       # Run the attention mechanism.
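      # The attention read is refreshed from the new cell state at every
      # step. When initial_state_attention=True, step 0 must reuse the
      # variables already created by the attention(initial_state) call
      # above, hence the reuse=True scope in the first branch below.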
657 |       if i == 0 and initial_state_attention:
658 |         with variable_scope.variable_scope(variable_scope.get_variable_scope(),
659 |                                            reuse=True):
660 |           attns = attention(state)
661 |       else:
662 |         attns = attention(state)
663 | 
664 |       with variable_scope.variable_scope("AttnOutputProjection"):
665 |         output = linear([cell_output] + attns, output_size, True)
666 |       if loop_function is not None:
667 |         prev = output
668 |       outputs.append(output)
669 | 
670 |   return outputs, state
671 | 
672 | 
673 | def embedding_attention_decoder(decoder_inputs,
674 |                                 initial_state,
675 |                                 attention_states,
676 |                                 cell,
677 |                                 num_symbols,
678 |                                 embedding_size,
679 |                                 num_heads=1,
680 |                                 output_size=None,
681 |                                 output_projection=None,
682 |                                 feed_previous=False,
683 |                                 update_embedding_for_previous=True,
684 |                                 dtype=None,
685 |                                 scope=None,
686 |                                 initial_state_attention=False,
687 |                                 loop_fn_factory=_extract_argmax_and_embed):
688 |   """RNN decoder with embedding and attention and a pure-decoding option.
689 | 
690 |   Args:
691 |     decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
692 |     initial_state: 2D Tensor [batch_size x cell.state_size].
693 |     attention_states: 3D Tensor [batch_size x attn_length x attn_size].
694 |     cell: rnn_cell.RNNCell defining the cell function.
695 |     num_symbols: Integer, how many symbols come into the embedding.
696 |     embedding_size: Integer, the length of the embedding vector for each symbol.
697 |     num_heads: Number of attention heads that read from attention_states.
698 |     output_size: Size of the output vectors; if None, use cell.output_size.
699 |     output_projection: None or a pair (W, B) of output projection weights and
700 |       biases; W has shape [output_size x num_symbols] and B has shape
701 |       [num_symbols]; if provided and feed_previous=True, each fed previous
702 |       output will first be multiplied by W and have B added.
703 |     feed_previous: Boolean; if True, only the first of decoder_inputs will be
704 |       used (the "GO" symbol), and all other decoder inputs will be generated by:
705 |         next = embedding_lookup(embedding, argmax(previous_output)).
706 |       In effect, this implements a greedy decoder. It can also be used
707 |       during training to emulate http://arxiv.org/abs/1506.03099.
708 |       If False, decoder_inputs are used as given (the standard decoder case).
709 |     update_embedding_for_previous: Boolean; if False and feed_previous=True,
710 |       only the embedding for the first symbol of decoder_inputs (the "GO"
711 |       symbol) will be updated by back propagation. Embeddings for the symbols
712 |       generated from the decoder itself remain unchanged. This parameter has
713 |       no effect if feed_previous=False.
714 |     dtype: The dtype to use for the RNN initial states (default: tf.float32).
715 |     scope: VariableScope for the created subgraph; defaults to
716 |       "embedding_attention_decoder".
717 |     initial_state_attention: If False (default), initial attentions are zero.
718 |       If True, initialize the attentions from the initial state and attention
719 |       states -- useful when we wish to resume decoding from a previously
720 |       stored decoder state and attention states.
721 | 
722 |   Returns:
723 |     A tuple of the form (outputs, state), where:
724 |       outputs: A list of the same length as decoder_inputs of 2D Tensors with
725 |         shape [batch_size x output_size] containing the generated outputs.
726 |       state: The state of each decoder cell at the final time-step.
727 |         It is a 2D Tensor of shape [batch_size x cell.state_size].
728 | 
729 |   Raises:
730 |     ValueError: When output_projection has the wrong shape.
731 |   """
732 |   if output_size is None:
733 |     output_size = cell.output_size
734 |   if output_projection is not None:
735 |     proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
736 |     proj_biases.get_shape().assert_is_compatible_with([num_symbols])
737 | 
738 |   with variable_scope.variable_scope(
739 |       scope or "embedding_attention_decoder", dtype=dtype) as scope:
740 | 
741 |     embedding = variable_scope.get_variable("embedding",
742 |                                             [num_symbols, embedding_size])
743 |     loop_function = loop_fn_factory(
744 |         embedding, output_projection,
745 |         update_embedding_for_previous) if feed_previous else None
746 |     emb_inp = [
747 |         embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs]
748 |     return attention_decoder(
749 |         emb_inp,
750 |         initial_state,
751 |         attention_states,
752 |         cell,
753 |         output_size=output_size,
754 |         num_heads=num_heads,
755 |         loop_function=loop_function,
756 |         initial_state_attention=initial_state_attention)
757 | 
758 | 
759 | def embedding_attention_seq2seq(encoder_inputs,
760 |                                 decoder_inputs,
761 |                                 cell,
762 |                                 num_encoder_symbols,
763 |                                 num_decoder_symbols,
764 |                                 embedding_size,
765 |                                 num_heads=1,
766 |                                 output_projection=None,
767 |                                 feed_previous=False,
768 |                                 dtype=None,
769 |                                 scope=None,
770 |                                 initial_state_attention=False,
771 |                                 loop_fn_factory=_extract_argmax_and_embed):
772 |   """Embedding sequence-to-sequence model with attention.
773 | 
774 |   This model first embeds encoder_inputs by a newly created embedding (of shape
775 |   [num_encoder_symbols x embedding_size]). Then it runs an RNN to encode
776 |   embedded encoder_inputs into a state vector. It keeps the outputs of this
777 |   RNN at every step to use for attention later. Next, it embeds decoder_inputs
778 |   by another newly created embedding (of shape [num_decoder_symbols x
779 |   embedding_size]). Then it runs attention decoder, initialized with the last
780 |   encoder state, on embedded decoder_inputs and attending to encoder outputs.
781 | 
782 |   Args:
783 |     encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
784 |     decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
785 |     cell: rnn_cell.RNNCell defining the cell function and size.
786 |     num_encoder_symbols: Integer; number of symbols on the encoder side.
787 |     num_decoder_symbols: Integer; number of symbols on the decoder side.
788 |     embedding_size: Integer, the length of the embedding vector for each symbol.
789 |     num_heads: Number of attention heads that read from attention_states.
790 |     output_projection: None or a pair (W, B) of output projection weights and
791 |       biases; W has shape [output_size x num_decoder_symbols] and B has
792 |       shape [num_decoder_symbols]; if provided and feed_previous=True, each
793 |       fed previous output will first be multiplied by W and have B added.
794 |     feed_previous: Boolean or scalar Boolean Tensor; if True, only the first
795 |       of decoder_inputs will be used (the "GO" symbol), and all other decoder
796 |       inputs will be taken from previous outputs (as in embedding_rnn_decoder).
797 |       If False, decoder_inputs are used as given (the standard decoder case).
798 |     dtype: The dtype of the initial RNN state (default: tf.float32).
799 |     scope: VariableScope for the created subgraph; defaults to
800 |       "embedding_attention_seq2seq".
801 |     initial_state_attention: If False (default), initial attentions are zero.
802 |       If True, initialize the attentions from the initial state and attention
803 |       states.
804 | 805 | Returns: 806 | A tuple of the form (outputs, state), where: 807 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 808 | shape [batch_size x num_decoder_symbols] containing the generated 809 | outputs. 810 | state: The state of each decoder cell at the final time-step. 811 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 812 | """ 813 | with variable_scope.variable_scope( 814 | scope or "embedding_attention_seq2seq", dtype=dtype) as scope: 815 | dtype = scope.dtype 816 | # Encoder. 817 | encoder_cell = rnn_cell.EmbeddingWrapper( 818 | cell, embedding_classes=num_encoder_symbols, 819 | embedding_size=embedding_size) 820 | encoder_outputs, encoder_state = rnn.rnn( 821 | encoder_cell, encoder_inputs, dtype=dtype) 822 | 823 | # First calculate a concatenation of encoder outputs to put attention on. 824 | top_states = [array_ops.reshape(e, [-1, 1, cell.output_size]) 825 | for e in encoder_outputs] 826 | attention_states = array_ops.concat(1, top_states) 827 | 828 | # Decoder. 829 | output_size = None 830 | if output_projection is None: 831 | cell = rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 832 | output_size = num_decoder_symbols 833 | 834 | if isinstance(feed_previous, bool): 835 | return embedding_attention_decoder( 836 | decoder_inputs, 837 | encoder_state, 838 | attention_states, 839 | cell, 840 | num_decoder_symbols, 841 | embedding_size, 842 | num_heads=num_heads, 843 | output_size=output_size, 844 | output_projection=output_projection, 845 | feed_previous=feed_previous, 846 | initial_state_attention=initial_state_attention, 847 | loop_fn_factory=loop_fn_factory) 848 | 849 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 850 | def decoder(feed_previous_bool): 851 | reuse = None if feed_previous_bool else True 852 | with variable_scope.variable_scope( 853 | variable_scope.get_variable_scope(), reuse=reuse) as scope: 854 | outputs, state = embedding_attention_decoder( 855 | decoder_inputs, 856 | encoder_state, 857 | attention_states, 858 | cell, 859 | num_decoder_symbols, 860 | embedding_size, 861 | num_heads=num_heads, 862 | output_size=output_size, 863 | output_projection=output_projection, 864 | feed_previous=feed_previous_bool, 865 | update_embedding_for_previous=False, 866 | initial_state_attention=initial_state_attention, 867 | loop_fn_factory=loop_fn_factory) 868 | state_list = [state] 869 | if nest.is_sequence(state): 870 | state_list = nest.flatten(state) 871 | return outputs + state_list 872 | 873 | outputs_and_state = control_flow_ops.cond(feed_previous, 874 | lambda: decoder(True), 875 | lambda: decoder(False)) 876 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 877 | state_list = outputs_and_state[outputs_len:] 878 | state = state_list[0] 879 | if nest.is_sequence(encoder_state): 880 | state = nest.pack_sequence_as(structure=encoder_state, 881 | flat_sequence=state_list) 882 | return outputs_and_state[:outputs_len], state 883 | 884 | 885 | def one2many_rnn_seq2seq(encoder_inputs, 886 | decoder_inputs_dict, 887 | cell, 888 | num_encoder_symbols, 889 | num_decoder_symbols_dict, 890 | embedding_size, 891 | feed_previous=False, 892 | dtype=None, 893 | scope=None): 894 | """One-to-many RNN sequence-to-sequence model (multi-task). 895 | 896 | This is a multi-task sequence-to-sequence model with one encoder and multiple 897 | decoders. 
Reference to multi-task sequence-to-sequence learning can be found
898 |   here: http://arxiv.org/abs/1511.06114
899 | 
900 |   Args:
901 |     encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
902 |     decoder_inputs_dict: A dictionary mapping decoder name (string) to
903 |       the corresponding decoder_inputs; each decoder_inputs is a list of 1D
904 |       Tensors of shape [batch_size]; num_decoders is defined as
905 |       len(decoder_inputs_dict).
906 |     cell: rnn_cell.RNNCell defining the cell function and size.
907 |     num_encoder_symbols: Integer; number of symbols on the encoder side.
908 |     num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
909 |       integer specifying number of symbols for the corresponding decoder;
910 |       len(num_decoder_symbols_dict) must be equal to num_decoders.
911 |     embedding_size: Integer, the length of the embedding vector for each symbol.
912 |     feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of
913 |       decoder_inputs will be used (the "GO" symbol), and all other decoder
914 |       inputs will be taken from previous outputs (as in embedding_rnn_decoder).
915 |       If False, decoder_inputs are used as given (the standard decoder case).
916 |     dtype: The dtype of the initial state for both the encoder and decoder
917 |       rnn cells (default: tf.float32).
918 |     scope: VariableScope for the created subgraph; defaults to
919 |       "one2many_rnn_seq2seq".
920 | 
921 |   Returns:
922 |     A tuple of the form (outputs_dict, state_dict), where:
923 |       outputs_dict: A mapping from decoder name (string) to a list of the same
924 |         length as decoder_inputs_dict[name]; each element in the list is a 2D
925 |         Tensor with shape [batch_size x num_decoder_symbols_dict[name]]
926 |         containing the generated outputs.
927 |       state_dict: A mapping from decoder name (string) to the final state of the
928 |         corresponding decoder RNN; it is a 2D Tensor of shape
929 |         [batch_size x cell.state_size].
930 |   """
931 |   outputs_dict = {}
932 |   state_dict = {}
933 | 
934 |   with variable_scope.variable_scope(
935 |       scope or "one2many_rnn_seq2seq", dtype=dtype) as scope:
936 |     dtype = scope.dtype
937 | 
938 |     # Encoder.
939 |     encoder_cell = rnn_cell.EmbeddingWrapper(
940 |         cell, embedding_classes=num_encoder_symbols,
941 |         embedding_size=embedding_size)
942 |     _, encoder_state = rnn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
943 | 
944 |     # Decoder.
945 |     for name, decoder_inputs in decoder_inputs_dict.items():
946 |       num_decoder_symbols = num_decoder_symbols_dict[name]
947 | 
948 |       with variable_scope.variable_scope("one2many_decoder_" + str(
949 |           name)) as scope:
950 |         decoder_cell = rnn_cell.OutputProjectionWrapper(cell,
951 |                                                         num_decoder_symbols)
952 |         if isinstance(feed_previous, bool):
953 |           outputs, state = embedding_rnn_decoder(
954 |               decoder_inputs, encoder_state, decoder_cell, num_decoder_symbols,
955 |               embedding_size, feed_previous=feed_previous)
956 |         else:
957 |           # If feed_previous is a Tensor, we construct 2 graphs and use cond.
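          # Both cond branches below build the same decoder graph: the True
          # branch creates the variables (reuse=None) and the False branch
          # reuses them (reuse=True), so feed-previous and teacher-forced
          # decoding share one set of parameters and can be switched at
          # run time.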
958 | def filled_embedding_rnn_decoder(feed_previous): 959 | """The current decoder with a fixed feed_previous parameter.""" 960 | # pylint: disable=cell-var-from-loop 961 | reuse = None if feed_previous else True 962 | vs = variable_scope.get_variable_scope() 963 | with variable_scope.variable_scope(vs, reuse=reuse): 964 | outputs, state = embedding_rnn_decoder( 965 | decoder_inputs, encoder_state, decoder_cell, 966 | num_decoder_symbols, embedding_size, 967 | feed_previous=feed_previous) 968 | # pylint: enable=cell-var-from-loop 969 | state_list = [state] 970 | if nest.is_sequence(state): 971 | state_list = nest.flatten(state) 972 | return outputs + state_list 973 | 974 | outputs_and_state = control_flow_ops.cond( 975 | feed_previous, 976 | lambda: filled_embedding_rnn_decoder(True), 977 | lambda: filled_embedding_rnn_decoder(False)) 978 | # Outputs length is the same as for decoder inputs. 979 | outputs_len = len(decoder_inputs) 980 | outputs = outputs_and_state[:outputs_len] 981 | state_list = outputs_and_state[outputs_len:] 982 | state = state_list[0] 983 | if nest.is_sequence(encoder_state): 984 | state = nest.pack_sequence_as(structure=encoder_state, 985 | flat_sequence=state_list) 986 | outputs_dict[name] = outputs 987 | state_dict[name] = state 988 | 989 | return outputs_dict, state_dict 990 | 991 | 992 | def sequence_loss_by_example(logits, targets, weights, 993 | average_across_timesteps=True, 994 | softmax_loss_function=None, name=None): 995 | """Weighted cross-entropy loss for a sequence of logits (per example). 996 | 997 | Args: 998 | logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols]. 999 | targets: List of 1D batch-sized int32 Tensors of the same length as logits. 1000 | weights: List of 1D batch-sized float-Tensors of the same length as logits. 1001 | average_across_timesteps: If set, divide the returned cost by the total 1002 | label weight. 1003 | softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch 1004 | to be used instead of the standard softmax (the default if this is None). 1005 | name: Optional name for this operation, default: "sequence_loss_by_example". 1006 | 1007 | Returns: 1008 | 1D batch-sized float Tensor: The log-perplexity for each sequence. 1009 | 1010 | Raises: 1011 | ValueError: If len(logits) is different from len(targets) or len(weights). 1012 | """ 1013 | if len(targets) != len(logits) or len(weights) != len(logits): 1014 | raise ValueError("Lengths of logits, weights, and targets must be the same " 1015 | "%d, %d, %d." % (len(logits), len(weights), len(targets))) 1016 | with ops.name_scope(name, "sequence_loss_by_example", 1017 | logits + targets + weights): 1018 | log_perp_list = [] 1019 | for logit, target, weight in zip(logits, targets, weights): 1020 | if softmax_loss_function is None: 1021 | # TODO(irving,ebrevdo): This reshape is needed because 1022 | # sequence_loss_by_example is called with scalars sometimes, which 1023 | # violates our general scalar strictness policy. 1024 | target = array_ops.reshape(target, [-1]) 1025 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 1026 | logit, target) 1027 | else: 1028 | crossent = softmax_loss_function(logit, target) 1029 | log_perp_list.append(crossent * weight) 1030 | log_perps = math_ops.add_n(log_perp_list) 1031 | if average_across_timesteps: 1032 | total_size = math_ops.add_n(weights) 1033 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 
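      # Worked example (illustrative): with per-step cross-entropies
      # [2.0, 1.0, 3.0] and weights [1.0, 1.0, 0.0] (the last step is
      # padding), log_perps = 2.0 + 1.0 + 0.0 = 3.0 and total_size = 2.0,
      # so the division below yields an average of 1.5 per real target token.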
1034 |       log_perps /= total_size
1035 |   return log_perps
1036 | 
1037 | 
1038 | def sequence_loss(logits, targets, weights,
1039 |                   average_across_timesteps=True, average_across_batch=True,
1040 |                   softmax_loss_function=None, name=None):
1041 |   """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
1042 | 
1043 |   Args:
1044 |     logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
1045 |     targets: List of 1D batch-sized int32 Tensors of the same length as logits.
1046 |     weights: List of 1D batch-sized float-Tensors of the same length as logits.
1047 |     average_across_timesteps: If set, divide the returned cost by the total
1048 |       label weight.
1049 |     average_across_batch: If set, divide the returned cost by the batch size.
1050 |     softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
1051 |       to be used instead of the standard softmax (the default if this is None).
1052 |     name: Optional name for this operation, defaults to "sequence_loss".
1053 | 
1054 |   Returns:
1055 |     A scalar float Tensor: The average log-perplexity per symbol (weighted).
1056 | 
1057 |   Raises:
1058 |     ValueError: If len(logits) is different from len(targets) or len(weights).
1059 |   """
1060 |   with ops.name_scope(name, "sequence_loss", logits + targets + weights):
1061 |     cost = math_ops.reduce_sum(sequence_loss_by_example(
1062 |         logits, targets, weights,
1063 |         average_across_timesteps=average_across_timesteps,
1064 |         softmax_loss_function=softmax_loss_function))
1065 |     if average_across_batch:
1066 |       batch_size = array_ops.shape(targets[0])[0]
1067 |       return cost / math_ops.cast(batch_size, cost.dtype)
1068 |     else:
1069 |       return cost
1070 | 
1071 | 
1072 | def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights,
1073 |                        buckets, seq2seq, softmax_loss_function=None,
1074 |                        per_example_loss=False, name=None):
1075 |   """Create a sequence-to-sequence model with support for bucketing.
1076 | 
1077 |   The seq2seq argument is a function that defines a sequence-to-sequence model,
1078 |   e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))
1079 | 
1080 |   Args:
1081 |     encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
1082 |     decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
1083 |     targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
1084 |     weights: List of 1D batch-sized float-Tensors to weight the targets.
1085 |     buckets: A list of pairs of (input size, output size) for each bucket.
1086 |     seq2seq: A sequence-to-sequence model function; it takes two inputs that
1087 |       agree with encoder_inputs and decoder_inputs, and returns a pair
1088 |       consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
1089 |     softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
1090 |       to be used instead of the standard softmax (the default if this is None).
1091 |     per_example_loss: Boolean. If set, the returned loss will be a batch-sized
1092 |       tensor of losses for each sequence in the batch. If unset, it will be
1093 |       a scalar with the averaged loss from all examples.
1094 |     name: Optional name for this operation, defaults to "model_with_buckets".
1095 | 
1096 |   Returns:
1097 |     A tuple of the form (outputs, losses), where:
1098 |       outputs: The outputs for each bucket. Its j'th element consists of a list
1099 |         of 2D Tensors. The shape of output tensors can be either
1100 |         [batch_size x output_size] or [batch_size x num_decoder_symbols]
1101 |         depending on the seq2seq model used.
1102 |       losses: List of scalar Tensors, representing losses for each bucket, or,
1103 |         if per_example_loss is set, a list of 1D batch-sized float Tensors.
1104 | 
1105 |   Raises:
1106 |     ValueError: If length of encoder_inputs, targets, or weights is smaller
1107 |       than the largest (last) bucket.
1108 |   """
1109 |   if len(encoder_inputs) < buckets[-1][0]:
1110 |     raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
1111 |                      "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
1112 |   if len(targets) < buckets[-1][1]:
1113 |     raise ValueError("Length of targets (%d) must be at least that of last "
1114 |                      "bucket (%d)." % (len(targets), buckets[-1][1]))
1115 |   if len(weights) < buckets[-1][1]:
1116 |     raise ValueError("Length of weights (%d) must be at least that of last "
1117 |                      "bucket (%d)." % (len(weights), buckets[-1][1]))
1118 | 
1119 |   all_inputs = encoder_inputs + decoder_inputs + targets + weights
1120 |   losses = []
1121 |   outputs = []
1122 |   with ops.name_scope(name, "model_with_buckets", all_inputs):
1123 |     for j, bucket in enumerate(buckets):
1124 |       with variable_scope.variable_scope(variable_scope.get_variable_scope(),
1125 |                                          reuse=True if j > 0 else None):
1126 |         bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
1127 |                                     decoder_inputs[:bucket[1]])
1128 |         outputs.append(bucket_outputs)
1129 |         if per_example_loss:
1130 |           losses.append(sequence_loss_by_example(
1131 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
1132 |               softmax_loss_function=softmax_loss_function))
1133 |         else:
1134 |           losses.append(sequence_loss(
1135 |               outputs[-1], targets[:bucket[1]], weights[:bucket[1]],
1136 |               softmax_loss_function=softmax_loss_function))
1137 | 
1138 |   return outputs, losses
1139 | 
--------------------------------------------------------------------------------
/text_corrector_data_readers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import random
6 | 
7 | from data_reader import DataReader, PAD_TOKEN, EOS_TOKEN, GO_TOKEN
8 | 
9 | 
10 | class PTBDataReader(DataReader):
11 |     """
12 |     DataReader used to read in the Penn Treebank dataset.
13 |     """
14 | 
15 |     UNKNOWN_TOKEN = "<unk>"  # already defined in the source data
16 | 
17 |     DROPOUT_WORDS = {"a", "an", "the"}
18 |     DROPOUT_PROB = 0.25
19 | 
20 |     REPLACEMENTS = {"there": "their", "their": "there"}
21 |     REPLACEMENT_PROB = 0.25
22 | 
23 |     def __init__(self, config, train_path):
24 |         super(PTBDataReader, self).__init__(
25 |             config, train_path, special_tokens=[PAD_TOKEN, GO_TOKEN, EOS_TOKEN])
26 | 
27 |         self.UNKNOWN_ID = self.token_to_id[PTBDataReader.UNKNOWN_TOKEN]
28 | 
29 |     def read_samples_by_string(self, path):
30 | 
31 |         for line in self.read_tokens(path):
32 |             source = []
33 |             target = []
34 | 
35 |             for token in line:
36 |                 target.append(token)
37 | 
38 |                 # Randomly dropout some words from the input.
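                # Example corruption (illustrative): for the target
                # "i left the keys in their car", dropping "the" and applying
                # the their->there replacement can yield the source
                # "i left keys in there car".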
39 | dropout_word = (token in PTBDataReader.DROPOUT_WORDS and 40 | random.random() < PTBDataReader.DROPOUT_PROB) 41 | replace_word = (token in PTBDataReader.REPLACEMENTS and 42 | random.random() < 43 | PTBDataReader.REPLACEMENT_PROB) 44 | 45 | if replace_word: 46 | source.append(PTBDataReader.REPLACEMENTS[token]) 47 | elif not dropout_word: 48 | source.append(token) 49 | 50 | yield source, target 51 | 52 | def unknown_token(self): 53 | return PTBDataReader.UNKNOWN_TOKEN 54 | 55 | def read_tokens(self, path): 56 | with open(path, "r") as f: 57 | for line in f: 58 | yield line.rstrip().lstrip().split() 59 | 60 | 61 | class MovieDialogReader(DataReader): 62 | """ 63 | DataReader used to read and tokenize data from the Cornell open movie 64 | dialog dataset. 65 | """ 66 | 67 | UNKNOWN_TOKEN = "UNK" 68 | 69 | DROPOUT_TOKENS = {"a", "an", "the", "'ll", "'s", "'m", "'ve"} # Add "to" 70 | 71 | REPLACEMENTS = {"there": "their", "their": "there", "then": "than", 72 | "than": "then"} 73 | # Add: "be":"to" 74 | 75 | def __init__(self, config, train_path=None, token_to_id=None, 76 | dropout_prob=0.25, replacement_prob=0.25, dataset_copies=2): 77 | super(MovieDialogReader, self).__init__( 78 | config, train_path=train_path, token_to_id=token_to_id, 79 | special_tokens=[ 80 | PAD_TOKEN, GO_TOKEN, EOS_TOKEN, 81 | MovieDialogReader.UNKNOWN_TOKEN], 82 | dataset_copies=dataset_copies) 83 | 84 | self.dropout_prob = dropout_prob 85 | self.replacement_prob = replacement_prob 86 | 87 | self.UNKNOWN_ID = self.token_to_id[MovieDialogReader.UNKNOWN_TOKEN] 88 | 89 | def read_samples_by_string(self, path): 90 | for tokens in self.read_tokens(path): 91 | source = [] 92 | target = [] 93 | 94 | for token in tokens: 95 | target.append(token) 96 | 97 | # Randomly dropout some words from the input. 98 | dropout_token = (token in MovieDialogReader.DROPOUT_TOKENS and 99 | random.random() < self.dropout_prob) 100 | replace_token = (token in MovieDialogReader.REPLACEMENTS and 101 | random.random() < self.replacement_prob) 102 | 103 | if replace_token: 104 | source.append(MovieDialogReader.REPLACEMENTS[token]) 105 | elif not dropout_token: 106 | source.append(token) 107 | 108 | yield source, target 109 | 110 | def unknown_token(self): 111 | return MovieDialogReader.UNKNOWN_TOKEN 112 | 113 | def read_tokens(self, path): 114 | with open(path, "r") as f: 115 | for line in f: 116 | yield line.lower().strip().split() 117 | 118 | -------------------------------------------------------------------------------- /text_corrector_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.python.ops import array_ops 10 | from tensorflow.python.ops import embedding_ops 11 | from tensorflow.python.ops import math_ops 12 | from tensorflow.python.ops import nn_ops 13 | 14 | import seq2seq 15 | from data_reader import PAD_ID, GO_ID 16 | 17 | 18 | class TextCorrectorModel(object): 19 | """Sequence-to-sequence model used to correct grammatical errors in text. 
20 | 
21 |     NOTE: mostly copied from TensorFlow's seq2seq_model.py; only modifications
22 |     are:
23 |     - the introduction of RMSProp as an optional optimization algorithm
24 |     - the introduction of a "projection bias" that biases decoding towards
25 |       selecting tokens that appeared in the input
26 |     """
27 | 
28 |     def __init__(self, source_vocab_size, target_vocab_size, buckets, size,
29 |                  num_layers, max_gradient_norm, batch_size, learning_rate,
30 |                  learning_rate_decay_factor, use_lstm=False,
31 |                  num_samples=512, forward_only=False, config=None,
32 |                  corrective_tokens_mask=None):
33 |         """Create the model.
34 | 
35 |         Args:
36 |             source_vocab_size: size of the source vocabulary.
37 |             target_vocab_size: size of the target vocabulary.
38 |             buckets: a list of pairs (I, O), where I specifies maximum input
39 |                 length that will be processed in that bucket, and O specifies
40 |                 maximum output length. Training instances with inputs longer
41 |                 than I or outputs longer than O will be pushed to the next
42 |                 bucket and padded accordingly. We assume that the list is
43 |                 sorted, e.g., [(2, 4), (8, 16)].
44 |             size: number of units in each layer of the model.
45 |             num_layers: number of layers in the model.
46 |             max_gradient_norm: gradients will be clipped to maximally this norm.
47 |             batch_size: the size of the batches used during training;
48 |                 the model construction is independent of batch_size, so it can be
49 |                 changed after initialization if this is convenient, e.g.,
50 |                 for decoding.
51 |             learning_rate: learning rate to start with.
52 |             learning_rate_decay_factor: decay learning rate by this much when
53 |                 needed.
54 |             use_lstm: if true, we use LSTM cells instead of GRU cells.
55 |             num_samples: number of samples for sampled softmax.
56 |             forward_only: if set, we do not construct the backward pass in the
57 |                 model.
58 |         """
59 |         self.source_vocab_size = source_vocab_size
60 |         self.target_vocab_size = target_vocab_size
61 |         self.buckets = buckets
62 |         self.batch_size = batch_size
63 |         self.learning_rate = tf.Variable(float(learning_rate), trainable=False)
64 |         self.learning_rate_decay_op = self.learning_rate.assign(
65 |             self.learning_rate * learning_rate_decay_factor)
66 |         self.global_step = tf.Variable(0, trainable=False)
67 |         self.config = config
68 | 
69 |         # Feeds for inputs.
70 |         self.encoder_inputs = []
71 |         self.decoder_inputs = []
72 |         self.target_weights = []
73 |         for i in range(buckets[-1][0]):  # Last bucket is the biggest one.
74 |             self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
75 |                                                       name="encoder{0}".format(
76 |                                                           i)))
77 |         for i in range(buckets[-1][1] + 1):
78 |             self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
79 |                                                       name="decoder{0}".format(
80 |                                                           i)))
81 |             self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
82 |                                                       name="weight{0}".format(
83 |                                                           i)))
84 | 
85 |         # One hot encoding of corrective tokens.
86 |         corrective_tokens_tensor = tf.constant(corrective_tokens_mask if
87 |                                                corrective_tokens_mask else
88 |                                                np.zeros(self.target_vocab_size),
89 |                                                shape=[self.target_vocab_size],
90 |                                                dtype=tf.float32)
91 |         batched_corrective_tokens = tf.pack(
92 |             [corrective_tokens_tensor] * self.batch_size)
93 |         self.batch_corrective_tokens_mask = batch_corrective_tokens_mask = \
94 |             tf.placeholder(
95 |                 tf.float32,
96 |                 shape=[None, None],
97 |                 name="corrective_tokens")
98 | 
99 |         # Our targets are decoder inputs shifted by one.
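        # Illustrative shift, assuming a decoder sequence
        # [GO, w1, w2, EOS, PAD]: targets become [w1, w2, EOS, PAD], so each
        # position is trained to predict the next decoder input. This is also
        # why one extra decoder placeholder (buckets[-1][1] + 1) was created
        # above.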
100 |         targets = [self.decoder_inputs[i + 1]
101 |                    for i in range(len(self.decoder_inputs) - 1)]
102 |         # If we use sampled softmax, we need an output projection.
103 |         output_projection = None
104 |         softmax_loss_function = None
105 |         # Sampled softmax only makes sense if we sample less than vocabulary
106 |         # size.
107 |         if 0 < num_samples < self.target_vocab_size:
108 |             w = tf.get_variable("proj_w", [size, self.target_vocab_size])
109 |             w_t = tf.transpose(w)
110 |             b = tf.get_variable("proj_b", [self.target_vocab_size])
111 | 
112 |             output_projection = (w, b)
113 | 
114 |             def sampled_loss(inputs, labels):
115 |                 labels = tf.reshape(labels, [-1, 1])
116 |                 return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels,
117 |                                                   num_samples,
118 |                                                   self.target_vocab_size)
119 |             softmax_loss_function = sampled_loss
120 | 
121 |         # Create the internal multi-layer cell for our RNN.
122 |         single_cell = tf.nn.rnn_cell.GRUCell(size)
123 |         if use_lstm:
124 |             single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
125 |         cell = single_cell
126 |         if num_layers > 1:
127 |             cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)
128 | 
129 |         # The seq2seq function: we use embedding for the input and attention.
130 |         def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
131 |             """
132 | 
133 |             :param encoder_inputs: list of length equal to the input bucket
134 |                 length of 1-D tensors (of length equal to the batch size) whose
135 |                 elements consist of the token index of each sample in the batch
136 |                 at a given index in the input.
137 |             :param decoder_inputs: analogous list of 1-D tensors for the decoder.
138 |             :param do_decode: if True, feed previous outputs back in (decoding).
139 |             :return: (outputs, state) from embedding_attention_seq2seq.
140 |             """
141 | 
142 |             if do_decode:
143 |                 # Modify bias here to bias the model towards selecting words
144 |                 # present in the input sentence.
145 |                 input_bias = self.build_input_bias(encoder_inputs,
146 |                                                    batch_corrective_tokens_mask)
147 | 
148 |                 # Redefined seq2seq to allow for the injection of a special
149 |                 # decoding function that biases toward input/corrective tokens.
150 |                 return seq2seq.embedding_attention_seq2seq(
151 |                     encoder_inputs, decoder_inputs, cell,
152 |                     num_encoder_symbols=source_vocab_size,
153 |                     num_decoder_symbols=target_vocab_size,
154 |                     embedding_size=size,
155 |                     output_projection=output_projection,
156 |                     feed_previous=do_decode,
157 |                     loop_fn_factory=
158 |                     apply_input_bias_and_extract_argmax_fn_factory(input_bias))
159 |             else:
160 |                 return seq2seq.embedding_attention_seq2seq(
161 |                     encoder_inputs, decoder_inputs, cell,
162 |                     num_encoder_symbols=source_vocab_size,
163 |                     num_decoder_symbols=target_vocab_size,
164 |                     embedding_size=size,
165 |                     output_projection=output_projection,
166 |                     feed_previous=do_decode)
167 | 
168 |         # Training outputs and losses.
169 |         if forward_only:
170 |             self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets(
171 |                 self.encoder_inputs, self.decoder_inputs, targets,
172 |                 self.target_weights, buckets,
173 |                 lambda x, y: seq2seq_f(x, y, True),
174 |                 softmax_loss_function=softmax_loss_function)
175 | 
176 |             if output_projection is not None:
177 |                 for b in range(len(buckets)):
178 |                     # We need to apply the same input bias used during model
179 |                     # evaluation when decoding.
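                    # project_and_apply_input_bias (defined at the bottom of
                    # this file) projects each raw decoder output through
                    # (w, b), applies a softmax, and multiplies by the bias
                    # mask, zeroing every token that appears neither in the
                    # source sentence nor in the corrective-token set.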
180 | input_bias = self.build_input_bias( 181 | self.encoder_inputs[:buckets[b][0]], 182 | batch_corrective_tokens_mask) 183 | self.outputs[b] = [ 184 | project_and_apply_input_bias(output, output_projection, 185 | input_bias) 186 | for output in self.outputs[b]] 187 | else: 188 | self.outputs, self.losses = tf.nn.seq2seq.model_with_buckets( 189 | self.encoder_inputs, self.decoder_inputs, targets, 190 | self.target_weights, buckets, 191 | lambda x, y: seq2seq_f(x, y, False), 192 | softmax_loss_function=softmax_loss_function) 193 | 194 | # Gradients and SGD update operation for training the model. 195 | params = tf.trainable_variables() 196 | if not forward_only: 197 | self.gradient_norms = [] 198 | self.updates = [] 199 | opt = tf.train.RMSPropOptimizer(0.001) if self.config.use_rms_prop \ 200 | else tf.train.GradientDescentOptimizer(self.learning_rate) 201 | # opt = tf.train.AdamOptimizer() 202 | 203 | for b in range(len(buckets)): 204 | gradients = tf.gradients(self.losses[b], params) 205 | clipped_gradients, norm = tf.clip_by_global_norm( 206 | gradients, max_gradient_norm) 207 | self.gradient_norms.append(norm) 208 | self.updates.append(opt.apply_gradients( 209 | zip(clipped_gradients, params), 210 | global_step=self.global_step)) 211 | 212 | self.saver = tf.train.Saver(tf.all_variables()) 213 | 214 | def build_input_bias(self, encoder_inputs, batch_corrective_tokens_mask): 215 | packed_one_hot_inputs = tf.one_hot(indices=tf.pack( 216 | encoder_inputs, axis=1), depth=self.target_vocab_size) 217 | return tf.maximum(batch_corrective_tokens_mask, 218 | tf.reduce_max(packed_one_hot_inputs, 219 | reduction_indices=1)) 220 | 221 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 222 | bucket_id, forward_only, corrective_tokens=None): 223 | """Run a step of the model feeding the given inputs. 224 | 225 | Args: 226 | session: tensorflow session to use. 227 | encoder_inputs: list of numpy int vectors to feed as encoder inputs. 228 | decoder_inputs: list of numpy int vectors to feed as decoder inputs. 229 | target_weights: list of numpy float vectors to feed as target weights. 230 | bucket_id: which bucket of the model to use. 231 | forward_only: whether to do the backward step or only forward. 232 | 233 | Returns: 234 | A triple consisting of gradient norm (or None if we did not do 235 | backward), average perplexity, and the outputs. 236 | 237 | Raises: 238 | ValueError: if length of encoder_inputs, decoder_inputs, or 239 | target_weights disagrees with bucket size for the specified 240 | bucket_id. 241 | """ 242 | # Check if the sizes match. 243 | encoder_size, decoder_size = self.buckets[bucket_id] 244 | if len(encoder_inputs) != encoder_size: 245 | raise ValueError( 246 | "Encoder length must be equal to the one in bucket," 247 | " %d != %d." % (len(encoder_inputs), encoder_size)) 248 | if len(decoder_inputs) != decoder_size: 249 | raise ValueError( 250 | "Decoder length must be equal to the one in bucket," 251 | " %d != %d." % (len(decoder_inputs), decoder_size)) 252 | if len(target_weights) != decoder_size: 253 | raise ValueError( 254 | "Weights length must be equal to the one in bucket," 255 | " %d != %d." % (len(target_weights), decoder_size)) 256 | 257 | # Input feed: encoder inputs, decoder inputs, target_weights, 258 | # as provided. 
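        # Feed-dict sketch (illustrative), for a bucket of size (3, 4): the
        # feed maps encoder0..encoder2, decoder0..decoder3, weight0..weight3,
        # and the corrective-tokens mask to numpy values; an extra all-zero
        # decoder4 entry is added below because targets are the decoder
        # inputs shifted by one.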
259 | input_feed = {} 260 | for l in range(encoder_size): 261 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 262 | for l in range(decoder_size): 263 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 264 | input_feed[self.target_weights[l].name] = target_weights[l] 265 | 266 | # TODO: learn corrective tokens during training 267 | corrective_tokens_vector = (corrective_tokens 268 | if corrective_tokens is not None else 269 | np.zeros(self.target_vocab_size)) 270 | batch_corrective_tokens = np.repeat([corrective_tokens_vector], 271 | self.batch_size, axis=0) 272 | input_feed[self.batch_corrective_tokens_mask.name] = ( 273 | batch_corrective_tokens) 274 | 275 | # Since our targets are decoder inputs shifted by one, we need one more. 276 | last_target = self.decoder_inputs[decoder_size].name 277 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 278 | 279 | # Output feed: depends on whether we do a backward step or not. 280 | if not forward_only: 281 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 282 | self.gradient_norms[bucket_id], # Gradient norm. 283 | self.losses[bucket_id]] # Loss for this batch. 284 | else: 285 | output_feed = [self.losses[bucket_id]] # Loss for this batch. 286 | for l in range(decoder_size): # Output logits. 287 | output_feed.append(self.outputs[bucket_id][l]) 288 | 289 | outputs = session.run(output_feed, input_feed) 290 | if not forward_only: 291 | # Gradient norm, loss, no outputs. 292 | return outputs[1], outputs[2], None 293 | else: 294 | # No gradient norm, loss, outputs. 295 | return None, outputs[0], outputs[1:] 296 | 297 | def get_batch(self, data, bucket_id): 298 | """Get a random batch of data from the specified bucket, prepare for 299 | step. 300 | 301 | To feed data in step(..) it must be a list of batch-major vectors, while 302 | data here contains single length-major cases. So the main logic of this 303 | function is to re-index data cases to be in the proper format for 304 | feeding. 305 | 306 | Args: 307 | data: a tuple of size len(self.buckets) in which each element contains 308 | lists of pairs of input and output data that we use to create a 309 | batch. 310 | bucket_id: integer, which bucket to get the batch for. 311 | 312 | Returns: 313 | The triple (encoder_inputs, decoder_inputs, target_weights) for 314 | the constructed batch that has the proper format to call step(...) 315 | later. 316 | """ 317 | encoder_size, decoder_size = self.buckets[bucket_id] 318 | encoder_inputs, decoder_inputs = [], [] 319 | 320 | # Get a random batch of encoder and decoder inputs from data, 321 | # pad them if needed, reverse encoder inputs and add GO to decoder. 322 | for _ in range(self.batch_size): 323 | encoder_input, decoder_input = random.choice(data[bucket_id]) 324 | 325 | # Encoder inputs are padded and then reversed. 326 | encoder_pad = [PAD_ID] * ( 327 | encoder_size - len(encoder_input)) 328 | encoder_inputs.append(list(reversed(encoder_input + encoder_pad))) 329 | 330 | # Decoder inputs get an extra "GO" symbol, and are padded then. 331 | decoder_pad_size = decoder_size - len(decoder_input) - 1 332 | decoder_inputs.append([GO_ID] + decoder_input + 333 | [PAD_ID] * decoder_pad_size) 334 | 335 | # Now we create batch-major vectors from the data selected above. 336 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] 337 | 338 | # Batch encoder inputs are just re-indexed encoder_inputs. 
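        # Transposition sketch (illustrative), with batch_size=2 and
        # encoder_size=3: length-major [[4, 5, 6], [7, 8, 9]] becomes
        # batch-major [array([4, 7]), array([5, 8]), array([6, 9])] -- one
        # int32 vector per time-step, matching the placeholders fed in step().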
339 |         for length_idx in range(encoder_size):
340 |             batch_encoder_inputs.append(
341 |                 np.array([encoder_inputs[batch_idx][length_idx]
342 |                           for batch_idx in range(self.batch_size)],
343 |                          dtype=np.int32))
344 | 
345 |         # Batch decoder inputs are re-indexed decoder_inputs; we create weights.
346 |         for length_idx in range(decoder_size):
347 |             batch_decoder_inputs.append(
348 |                 np.array([decoder_inputs[batch_idx][length_idx]
349 |                           for batch_idx in range(self.batch_size)],
350 |                          dtype=np.int32))
351 | 
352 |             # Create target_weights to be 0 for targets that are padding.
353 |             batch_weight = np.ones(self.batch_size, dtype=np.float32)
354 |             for batch_idx in range(self.batch_size):
355 |                 # We set weight to 0 if the corresponding target is a PAD
356 |                 # symbol. The corresponding target is decoder_input shifted by 1
357 |                 # forward.
358 |                 if length_idx < decoder_size - 1:
359 |                     target = decoder_inputs[batch_idx][length_idx + 1]
360 |                 if length_idx == decoder_size - 1 or target == PAD_ID:
361 |                     batch_weight[batch_idx] = 0.0
362 |             batch_weights.append(batch_weight)
363 |         return batch_encoder_inputs, batch_decoder_inputs, batch_weights
364 | 
365 | 
366 | def project_and_apply_input_bias(logits, output_projection, input_bias):
367 |     if output_projection is not None:
368 |         logits = nn_ops.xw_plus_b(
369 |             logits, output_projection[0], output_projection[1])
370 | 
371 |     # Apply softmax to ensure all tokens have a positive value.
372 |     probs = tf.nn.softmax(logits)
373 | 
374 |     # Apply input bias, which is a mask of shape [batch, vocab len]
375 |     # where each token from the input in addition to all "corrective"
376 |     # tokens are set to 1.0.
377 |     return tf.mul(probs, input_bias)
378 | 
379 | 
380 | def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
381 |     """
382 | 
383 |     :param input_bias: 2-D tensor of shape [batch_size, vocab size] whose
384 |         entries are 1.0 for tokens that appear in the input sentence or in
385 |         the corrective-token set and 0.0 elsewhere; see build_input_bias
386 |         above.
387 |     :return: a loop_function factory that applies this bias while decoding.
388 |     """
389 | 
390 |     def fn_factory(embedding, output_projection=None, update_embedding=True):
391 |         """Get a loop_function that extracts the previous symbol and embeds it.
392 | 
393 |         Args:
394 |             embedding: embedding tensor for symbols.
395 |             output_projection: None or a pair (W, B). If provided, each fed previous
396 |                 output will first be multiplied by W and have B added.
397 |             update_embedding: Boolean; if False, the gradients will not propagate
398 |                 through the embeddings.
399 | 
400 |         Returns:
401 |             A loop function.
402 |         """
403 |         def loop_function(prev, _):
404 |             prev = project_and_apply_input_bias(prev, output_projection,
405 |                                                 input_bias)
406 | 
407 |             prev_symbol = math_ops.argmax(prev, 1)
408 |             # Note that gradients will not propagate through the second
409 |             # parameter of embedding_lookup.
410 |             emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
411 |             if not update_embedding:
412 |                 emb_prev = array_ops.stop_gradient(emb_prev)
413 |             return emb_prev, prev_symbol
414 |         return loop_function
415 | 
416 |     return fn_factory
417 | 
418 | 
--------------------------------------------------------------------------------