├── .gitignore ├── LICENSE ├── README.md └── qa.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Stephen Merity 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Keras question answering for bAbi](http://i.imgur.com/RuINg4m.jpg)](https://www.flickr.com/photos/isherwoodchris/6917253693/in/photolist) 2 | 3 | # Keras question answering for bAbi 4 | 5 | **Note: This code has been merged as an example into Keras -- see [babi_rnn.py](https://github.com/fchollet/keras/blob/master/examples/babi_rnn.py)** 6 | 7 | This repository contains Keras code to train two recurrent neural networks based upon a story and a question. 8 | The resulting merged vector is then queried to answer a range of bAbI tasks. 9 | 10 | An example from the first task, QA1, is below: 11 | 12 | 1 Mary moved to the bathroom. 13 | 2 John went to the hallway. 14 | 3 Where is Mary? bathroom 1 15 | 4 Daniel went back to the hallway. 16 | 5 Sandra moved to the garden. 17 | 6 Where is Daniel? hallway 4 18 | 19 | The results are comparable (or superior) to those for the LSTM baseline provided in Weston et al.'s [Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks](http://arxiv.org/abs/1502.05698) given only 1000 samples and without any hyperparameter tuning. 
20 | 21 | Task Number | FB LSTM Baseline | Keras QA 22 | --- | --- | --- 23 | QA1 - Single Supporting Fact | 50 | 52.1 24 | QA2 - Two Supporting Facts | 20 | 37.0 25 | QA3 - Three Supporting Facts | 20 | 20.5 26 | QA4 - Two Arg. Relations | 61 | 62.9 27 | QA5 - Three Arg. Relations | 70 | 61.9 28 | QA6 - Yes/No Questions | 48 | 50.7 29 | QA7 - Counting | 49 | 78.9 30 | QA8 - Lists/Sets | 45 | 77.2 31 | QA9 - Simple Negation | 64 | 64.0 32 | QA10 - Indefinite Knowledge | 44 | 47.7 33 | QA11 - Basic Coreference | 72 | 74.9 34 | QA12 - Conjunction | 74 | 76.4 35 | QA13 - Compound Coreference | 94 | 94.4 36 | QA14 - Time Reasoning | 27 | 34.8 37 | QA15 - Basic Deduction | 21 | 32.4 38 | QA16 - Basic Induction | 23 | 50.6 39 | QA17 - Positional Reasoning | 51 | 49.1 40 | QA18 - Size Reasoning | 52 | 90.8 41 | QA19 - Path Finding | 8 | 9.0 42 | QA20 - Agent's Motivations | 91 | 90.7 43 | 44 | For the resources related to the bAbI project, refer to the [Facebook AI Research bAbI project page](https://research.facebook.com/researchers/1543934539189348). 45 | 46 | # License 47 | 48 | MIT License, as per `LICENSE` 49 | -------------------------------------------------------------------------------- /qa.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | import re 4 | import tarfile 5 | 6 | import numpy as np 7 | np.random.seed(1337) # for reproducibility 8 | 9 | from keras.datasets.data_utils import get_file 10 | from keras.layers.embeddings import Embedding 11 | from keras.layers.core import Dense, Merge 12 | from keras.layers import recurrent 13 | from keras.models import Sequential 14 | from keras.preprocessing.sequence import pad_sequences 15 | 16 | ''' 17 | Trains two recurrent neural networks based upon a story and a question. 18 | The resulting merged vector is then queried to answer a range of bAbI tasks. 
19 | 20 | The results are comparable to those for an LSTM model provided in Weston et al.: 21 | "Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks" 22 | http://arxiv.org/abs/1502.05698 23 | 24 | For the resources related to the bAbI project, refer to: 25 | https://research.facebook.com/researchers/1543934539189348 26 | 27 | Notes: 28 | 29 | - With default word, sentence, and query vector sizes, the GRU model achieves: 30 | - 52.1% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU) 31 | - 37.0% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU) 32 | In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline. 33 | 34 | - The task does not traditionally parse the question separately. This likely 35 | improves accuracy and is a good example of merging two RNNs. 36 | 37 | - The word vector embeddings are not shared between the story and question RNNs. 38 | 39 | - See how the accuracy changes given 10,000 training samples (en-10k) instead 40 | of only 1000. 1000 was used in order to be comparable to the original paper. 41 | 42 | - Experiment with GRU, LSTM, and JZS1-3 as they give subtly different results. 43 | 44 | - The length and noise (i.e. 'useless' story components) impact the ability for 45 | LSTMs / GRUs to provide the correct answer. Given only the supporting facts, 46 | these RNNs can achieve 100% accuracy on many tasks. Memory networks and neural 47 | networks that use attentional processes can efficiently search through this 48 | noise to find the relevant statements, improving performance substantially. 49 | This becomes especially obvious on QA2 and QA3, both far longer than QA1. 50 | ''' 51 | 52 | 53 | def tokenize(sent): 54 | '''Return the tokens of a sentence including punctuation. 55 | 56 | >>> tokenize('Bob dropped the apple. 
Where is the apple?') 57 | ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?'] 58 | ''' 59 | return [x.strip() for x in re.split('(\W+)?', sent) if x.strip()] 60 | 61 | 62 | def parse_stories(lines, only_supporting=False): 63 | '''Parse stories provided in the bAbi tasks format 64 | 65 | If only_supporting is true, only the sentences that support the answer are kept. 66 | ''' 67 | data = [] 68 | story = [] 69 | for line in lines: 70 | line = line.strip() 71 | nid, line = line.split(' ', 1) 72 | nid = int(nid) 73 | if nid == 1: 74 | story = [] 75 | if '\t' in line: 76 | q, a, supporting = line.split('\t') 77 | q = tokenize(q) 78 | substory = None 79 | if only_supporting: 80 | # Only select the related substory 81 | supporting = map(int, supporting.split()) 82 | substory = [story[i - 1] for i in supporting] 83 | else: 84 | # Provide all the substories 85 | substory = [x for x in story if x] 86 | data.append((substory, q, a)) 87 | story.append('') 88 | else: 89 | sent = tokenize(line) 90 | story.append(sent) 91 | return data 92 | 93 | 94 | def get_stories(f, only_supporting=False, max_length=None): 95 | '''Given a file name, read the file, retrieve the stories, and then convert the sentences into a single story. 96 | 97 | If max_length is supplied, any stories longer than max_length tokens will be discarded. 
98 | ''' 99 | data = parse_stories(f.readlines(), only_supporting=only_supporting) 100 | flatten = lambda data: reduce(lambda x, y: x + y, data) 101 | data = [(flatten(story), q, answer) for story, q, answer in data if not max_length or len(flatten(story)) < max_length] 102 | return data 103 | 104 | 105 | def vectorize_stories(data): 106 | X = [] 107 | Xq = [] 108 | Y = [] 109 | for story, query, answer in data: 110 | x = [word_idx[w] for w in story] 111 | xq = [word_idx[w] for w in query] 112 | y = np.zeros(vocab_size) 113 | y[word_idx[answer]] = 1 114 | X.append(x) 115 | Xq.append(xq) 116 | Y.append(y) 117 | return pad_sequences(X, maxlen=story_maxlen), pad_sequences(Xq, maxlen=query_maxlen), np.array(Y) 118 | 119 | RNN = recurrent.GRU 120 | EMBED_HIDDEN_SIZE = 50 121 | SENT_HIDDEN_SIZE = 100 122 | QUERY_HIDDEN_SIZE = 100 123 | BATCH_SIZE = 32 124 | EPOCHS = 20 125 | print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN, EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, QUERY_HIDDEN_SIZE)) 126 | 127 | path = get_file('babi-tasks-v1-2.tar.gz', origin='http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz') 128 | tar = tarfile.open(path) 129 | # Default QA1 with 1000 samples 130 | # challenge = 'tasks_1-20_v1-2/en/qa1_single-supporting-fact_{}.txt' 131 | # QA1 with 10,000 samples 132 | # challenge = 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt' 133 | # QA2 with 1000 samples 134 | challenge = 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt' 135 | # QA2 with 10,000 samples 136 | # challenge = 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt' 137 | train = get_stories(tar.extractfile(challenge.format('train'))) 138 | test = get_stories(tar.extractfile(challenge.format('test'))) 139 | 140 | vocab = sorted(reduce(lambda x, y: x | y, (set(story + q + [answer]) for story, q, answer in train + test))) 141 | # Reserve 0 for masking via pad_sequences 142 | vocab_size = len(vocab) + 1 143 | word_idx = dict((c, i + 1) for i, c in 
enumerate(vocab)) 144 | story_maxlen = max(map(len, (x for x, _, _ in train + test))) 145 | query_maxlen = max(map(len, (x for _, x, _ in train + test))) 146 | 147 | X, Xq, Y = vectorize_stories(train) 148 | tX, tXq, tY = vectorize_stories(test) 149 | 150 | print('vocab = {}'.format(vocab)) 151 | print('X.shape = {}'.format(X.shape)) 152 | print('Xq.shape = {}'.format(Xq.shape)) 153 | print('Y.shape = {}'.format(Y.shape)) 154 | print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen)) 155 | 156 | print('Build model...') 157 | 158 | sentrnn = Sequential() 159 | sentrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE, mask_zero=True)) 160 | sentrnn.add(RNN(EMBED_HIDDEN_SIZE, SENT_HIDDEN_SIZE, return_sequences=False)) 161 | 162 | qrnn = Sequential() 163 | qrnn.add(Embedding(vocab_size, EMBED_HIDDEN_SIZE)) 164 | qrnn.add(RNN(EMBED_HIDDEN_SIZE, QUERY_HIDDEN_SIZE, return_sequences=False)) 165 | 166 | model = Sequential() 167 | model.add(Merge([sentrnn, qrnn], mode='concat')) 168 | model.add(Dense(SENT_HIDDEN_SIZE + QUERY_HIDDEN_SIZE, vocab_size, activation='softmax')) 169 | 170 | model.compile(optimizer='adam', loss='categorical_crossentropy', class_mode='categorical') 171 | 172 | print('Training') 173 | model.fit([X, Xq], Y, batch_size=BATCH_SIZE, nb_epoch=EPOCHS, validation_split=0.05, show_accuracy=True) 174 | loss, acc = model.evaluate([tX, tXq], tY, batch_size=BATCH_SIZE, show_accuracy=True) 175 | print('Test loss / test accuracy = {:.4f} / {:.4f}'.format(loss, acc)) 176 | --------------------------------------------------------------------------------