├── .gitignore ├── LICENSE ├── README.md ├── babi_runner.py ├── bechmarks └── README.md ├── config.py ├── data └── README.md ├── demo ├── __init__.py ├── qa.py └── web │ ├── __init__.py │ ├── static │ ├── script.js │ └── style.css │ ├── templates │ └── index.html │ └── webapp.py ├── memn2n ├── __init__.py ├── memory.py └── nn.py ├── requirements.txt ├── train_test.py ├── trained_model ├── README.md └── memn2n_model.pklz └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.pyc 3 | data/tasks_1-20_v1-2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | Copyright (c) 2016, Vinh Khuc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, 7 | are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name Facebook nor the names of its contributors may be used to 17 | endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 21 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 22 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 24 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 25 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 26 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 27 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 29 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## End-To-End Memory Networks for Question Answering 2 | This is an implementation of MemN2N model in Python for the [bAbI question-answering tasks](http://fb.ai/babi) 3 | as shown in the Section 4 of the paper "[End-To-End Memory Networks](http://arxiv.org/abs/1503.08895)". It is based on 4 | Facebook's [Matlab code](https://github.com/facebook/MemNN/tree/master/MemN2N-babi-matlab). 5 | 6 | ![Web-based Demo](http://i.imgur.com/mKtZ7kB.gif) 7 | 8 | ## Requirements 9 | * Python 2.7 10 | * Numpy, Flask (only for web-based demo) can be installed via pip: 11 | ``` 12 | $ sudo pip install -r requirements.txt 13 | ``` 14 | * [bAbI dataset](http://fb.ai/babi) should be downloaded to `data/tasks_1-20_v1-2`: 15 | ``` 16 | $ wget -qO- http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz | tar xvz -C data 17 | ``` 18 | 19 | ## Usage 20 | * To run on a single task, use `babi_runner.py` with `-t` followed by task's id. For example, 21 | ``` 22 | python babi_runner.py -t 1 23 | ``` 24 | The output will look like: 25 | ``` 26 | Using data from data/tasks_1-20_v1-2/en 27 | Train and test for task 1 ... 
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)  # fixed seeds so that runs are reproducible


def _fit(story, questions, qstory, memory, model, loss, general_config):
    """
    Train the model, using the linear-start schedule when the configuration asks for it.
    """
    trainer = train_linear_start if general_config.linear_start else train
    trainer(story, questions, qstory, memory, model, loss, general_config)


def run_task(data_dir, task_id):
    """
    Train a model on one bAbI task and evaluate it on that task's test set.

    data_dir is the directory holding the qa*_train.txt / qa*_test.txt files and
    task_id is the number embedded in the file names (qa<task_id>_...).
    """
    print("Train and test for task %d ..." % task_id)

    # Locate and parse the data files of this task
    train_files = glob.glob('%s/qa%d_*_train.txt' % (data_dir, task_id))
    test_files = glob.glob('%s/qa%d_*_test.txt' % (data_dir, task_id))

    # The same dictionary object is handed to both parse calls
    dictionary = {"nil": 0}
    train_story, train_questions, train_qstory = parse_babi_task(train_files, dictionary, False)
    test_story, test_questions, test_qstory = parse_babi_task(test_files, dictionary, False)

    general_config = BabiConfig(train_story, train_questions, dictionary)
    memory, model, loss = build_model(general_config)

    _fit(train_story, train_questions, train_qstory, memory, model, loss, general_config)
    test(test_story, test_questions, test_qstory, memory, model, loss, general_config)


def run_all_tasks(data_dir):
    """
    Run run_task for tasks 1 to 20, one after another.
    """
    print("Training and testing for all tasks ...")
    for task_id in range(1, 21):
        run_task(data_dir, task_id=task_id)


def run_joint_tasks(data_dir):
    """
    Train one model on the training data of all 20 tasks, then test it on each task separately.
    """
    print("Jointly train and test for all tasks ...")
    task_ids = range(1, 21)

    # Collect the training files of every task and parse them in one go
    train_data_path = []
    for task_id in task_ids:
        train_data_path.extend(glob.glob('%s/qa%d_*_train.txt' % (data_dir, task_id)))

    dictionary = {"nil": 0}
    train_story, train_questions, train_qstory = parse_babi_task(train_data_path, dictionary, False)

    # Parse the test data once up front, only so that the dictionary covers
    # all words before the model is built; the parsed output is discarded.
    for task_id in task_ids:
        parse_babi_task(glob.glob('%s/qa%d_*_test.txt' % (data_dir, task_id)), dictionary, False)

    general_config = BabiConfigJoint(train_story, train_questions, dictionary)
    memory, model, loss = build_model(general_config)

    _fit(train_story, train_questions, train_qstory, memory, model, loss, general_config)

    # Evaluate task by task
    for task_id in task_ids:
        print("Testing for task %d ..." % task_id)
        test_data_path = glob.glob('%s/qa%d_*_test.txt' % (data_dir, task_id))
        dict_size = len(dictionary)
        test_story, test_questions, test_qstory = parse_babi_task(test_data_path, dictionary, False)
        assert dict_size == len(dictionary)  # make sure that the dictionary already covers all words

        test(test_story, test_questions, test_qstory, memory, model, loss, general_config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data-dir", default="data/tasks_1-20_v1-2/en",
                        help="path to dataset directory (default: %(default)s)")
    # The three run modes exclude each other
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-t", "--task", default="1", type=int,
                       help="train and test for a single task (default: %(default)s)")
    group.add_argument("-a", "--all-tasks", action="store_true",
                       help="train and test for all tasks (one by one) (default: %(default)s)")
    group.add_argument("-j", "--joint-tasks", action="store_true",
                       help="train and test for all tasks (all together) (default: %(default)s)")
    args = parser.parse_args()

    # Check if data is available
    data_dir = args.data_dir
    if not os.path.exists(data_dir):
        print("The data directory '%s' does not exist. Please download it first." % data_dir)
        sys.exit(1)

    print("Using data from %s" % args.data_dir)
    if args.all_tasks:
        run_all_tasks(data_dir)
    elif args.joint_tasks:
        run_joint_tasks(data_dir)
    else:
        run_task(data_dir, task_id=args.task)
class BabiConfig(object):
    """
    Configuration for bAbI (one model trained per task).

    Arguments:
        train_story:     array whose first axis length is used as the maximum
                         number of words and whose second axis length bounds
                         the memory size ("sz").
        train_questions: array whose second axis length is the number of
                         training questions.
        dictionary:      mapping from word to index; its length is the
                         vocabulary size.

    The last 10% of the questions (in order) are held out for validation.
    """
    # Settings that differ between the single-task and the joint setup.
    # Subclasses override these instead of duplicating __init__.
    _NEPOCHS = 100
    _LRATE_DECAY_STEP = 25        # reduce learning rate by half every 25 epochs
    _LS_NEPOCHS = 20              # number of linear-start epochs
    _LS_LRATE_DECAY_STEP = 21     # larger than _LS_NEPOCHS
    _EMBEDDING_DIM = 20           # used for both "in_dim" and "out_dim"

    def __init__(self, train_story, train_questions, dictionary):
        self.dictionary = dictionary
        self.batch_size = 32
        self.nhops = 3
        self.nepochs = self._NEPOCHS
        self.lrate_decay_step = self._LRATE_DECAY_STEP

        # Use 10% of training data for validation
        nb_questions = train_questions.shape[1]
        nb_train_questions = int(nb_questions * 0.9)
        self.train_range, self.val_range = self._split_questions(nb_questions, nb_train_questions)

        self.enable_time = True      # add time embeddings
        self.use_bow = False         # use Bag-of-Words instead of Position-Encoding
        self.linear_start = True
        self.share_type = 1          # 1: adjacent, 2: layer-wise weight tying
        self.randomize_time = 0.1    # amount of noise injected into time index
        self.add_proj = False        # add linear layer between internal states
        self.add_nonlin = False      # add non-linearity to internal states

        if self.linear_start:
            self.ls_nepochs = self._LS_NEPOCHS
            self.ls_lrate_decay_step = self._LS_LRATE_DECAY_STEP
            self.ls_init_lrate = 0.01 / 2

        # Training configuration
        self.train_config = {
            "init_lrate"   : 0.01,
            "max_grad_norm": 40,
            "in_dim"       : self._EMBEDDING_DIM,
            "out_dim"      : self._EMBEDDING_DIM,
            "sz"           : min(50, train_story.shape[1]),  # number of sentences
            "voc_sz"       : len(self.dictionary),
            "bsz"          : self.batch_size,
            "max_words"    : len(train_story),
            "weight"       : None
        }

        if self.linear_start:
            self.train_config["init_lrate"] = 0.01 / 2

        if self.enable_time:
            # One extra vocabulary entry per memory slot for the time words
            self.train_config.update({
                "voc_sz"   : self.train_config["voc_sz"] + self.train_config["sz"],
                "max_words": self.train_config["max_words"] + 1  # Add 1 for time words
            })

    def _split_questions(self, nb_questions, nb_train_questions):
        """
        Return (train_range, val_range) as integer index arrays.

        The first nb_train_questions questions are used for training and the
        remaining ones for validation. np.arange keeps the arrays integer-typed
        even when a range is empty.
        """
        return np.arange(nb_train_questions), np.arange(nb_train_questions, nb_questions)


class BabiConfigJoint(BabiConfig):
    """
    Joint configuration for bAbI (one model trained on all tasks together).

    Same as BabiConfig except for the training schedule, the embedding size
    and a random (instead of ordered) training/validation split. The split
    draws one permutation from NumPy's global random state.
    """
    _NEPOCHS = 60
    _LRATE_DECAY_STEP = 15        # reduce learning rate by half every 15 epochs
    _LS_NEPOCHS = 30
    _LS_LRATE_DECAY_STEP = 31     # larger than _LS_NEPOCHS
    _EMBEDDING_DIM = 50

    def _split_questions(self, nb_questions, nb_train_questions):
        """
        Randomly split the question indices into training and validation sets.
        """
        rp = np.random.permutation(nb_questions)
        return rp[:nb_train_questions], rp[nb_train_questions:]
class MemN2N(object):
    """
    MemN2N class: wraps training, (de)serialization and question answering
    for a memory network built with build_model().
    """
    def __init__(self, data_dir, model_file):
        self.data_dir = data_dir
        self.model_file = model_file
        self.reversed_dict = None    # index -> word
        self.memory = None
        self.model = None
        self.loss = None
        self.general_config = None

    def save_model(self):
        """
        Pickle the dictionary, network and configuration to the gzipped model file.
        """
        with gzip.open(self.model_file, "wb") as f:
            print("Saving model to file %s ..." % self.model_file)
            pickle.dump((self.reversed_dict, self.memory, self.model, self.loss, self.general_config), f)

    def load_model(self):
        """
        Load the model from the model file unless every part is already in memory.
        """
        parts = (self.reversed_dict, self.memory, self.model, self.loss, self.general_config)
        if any(part is None for part in parts):
            print("Loading model from file %s ..." % self.model_file)
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # only load model files that come from a trusted source.
            with gzip.open(self.model_file, "rb") as f:
                self.reversed_dict, self.memory, self.model, self.loss, self.general_config = pickle.load(f)

    def train(self):
        """
        Train MemN2N model using training data for tasks, then save it.
        """
        np.random.seed(42)  # for reproducing
        assert self.data_dir is not None, "data_dir is not specified."
        print("Reading data from %s ..." % self.data_dir)

        # Parse training data
        train_data_path = glob.glob('%s/qa*_*_train.txt' % self.data_dir)
        dictionary = {"nil": 0}
        train_story, train_questions, train_qstory = parse_babi_task(train_data_path, dictionary, False)

        # Parse test data just to expand the dictionary so that it covers all words in the test data too
        test_data_path = glob.glob('%s/qa*_*_test.txt' % self.data_dir)
        parse_babi_task(test_data_path, dictionary, False)

        # Get reversed dictionary mapping index to word
        self.reversed_dict = dict((ix, w) for w, ix in dictionary.items())

        # Construct model
        self.general_config = BabiConfigJoint(train_story, train_questions, dictionary)
        self.memory, self.model, self.loss = build_model(self.general_config)

        # Train model
        if self.general_config.linear_start:
            train_linear_start(train_story, train_questions, train_qstory,
                               self.memory, self.model, self.loss, self.general_config)
        else:
            train(train_story, train_questions, train_qstory,
                  self.memory, self.model, self.loss, self.general_config)

        # Save model
        self.save_model()

    def _max_sentence_words(self):
        """
        Return the number of real word slots per sentence.

        When time embeddings are enabled the configured "max_words" includes
        one extra slot for the time word, which is excluded here.
        """
        max_words = self.general_config.train_config["max_words"]
        if self.general_config.enable_time:
            max_words -= 1
        return max_words

    def get_story_texts(self, test_story, test_questions, test_qstory,
                        question_idx, story_idx, last_sentence_idx):
        """
        Get text of question, its corresponding fact statements.

        Returns a tuple (story_txt, question_txt, correct_answer) where
        story_txt is a list with one string per sentence (sentences 0 to
        last_sentence_idx of the story), question_txt has no trailing '?' and
        correct_answer is the word for test_questions[2, question_idx].
        """
        max_words = self._max_sentence_words()
        to_word = self.reversed_dict

        story = [[to_word[test_story[word_pos, sent_idx, story_idx]]
                  for word_pos in range(max_words)]
                 for sent_idx in range(last_sentence_idx + 1)]

        question = [to_word[test_qstory[word_pos, question_idx]]
                    for word_pos in range(max_words)]

        # "nil" is the padding word and is left out of the text
        story_txt = [" ".join([w for w in sent if w != "nil"]) for sent in story]
        question_txt = " ".join([w for w in question if w != "nil"])
        correct_answer = to_word[test_questions[2, question_idx]]

        return story_txt, question_txt, correct_answer

    def predict_answer(self, test_story, test_questions, test_qstory,
                       question_idx, story_idx, last_sentence_idx,
                       user_question=''):
        """
        Run the model on one story and return its answer.

        If user_question is empty, contains no words (e.g. only whitespace or
        a lone '?'), or equals the suggested question, the question stored in
        test_qstory is used. Otherwise user_question is lower-cased, split on
        whitespace and encoded with the dictionary; unknown words are reported
        and encoded as "nil", and words beyond the maximum sentence length are
        reported and ignored.

        Returns a tuple (pred_answer_idx, pred_prob, memory_probs):
        the index of the highest-scoring output row, its score, and an array
        with one row per hop holding the attention over the first
        last_sentence_idx + 1 memory slots.
        """
        # Get configuration
        config = self.general_config
        nhops = config.nhops
        train_config = config.train_config
        batch_size = config.batch_size
        dictionary = config.dictionary
        nil_idx = dictionary["nil"]
        max_words = self._max_sentence_words()

        input_data = np.zeros((max_words, batch_size), np.float32)
        input_data[:] = nil_idx
        first_memory = self.memory[0]
        first_memory.data[:] = nil_idx

        # Check if user provides questions and it's different from suggested question
        _, suggested_question, _ = self.get_story_texts(test_story, test_questions, test_qstory,
                                                        question_idx, story_idx, last_sentence_idx)
        qwords = []
        if user_question != '' and user_question != suggested_question:
            cleaned = user_question.strip()
            if cleaned.endswith('?'):  # safe on an empty string, unlike cleaned[-1]
                cleaned = cleaned[:-1]
            qwords = cleaned.lower().split()

        if qwords:
            if len(qwords) > max_words:
                print("WARNING - The question has more than %d words; the extra words are ignored." % max_words)
                qwords = qwords[:max_words]

            # Encoding
            question_vec = np.zeros(max_words)
            question_vec[:] = nil_idx
            for ix, w in enumerate(qwords):
                if w in dictionary:
                    question_vec[ix] = dictionary[w]
                else:
                    print("WARNING - The word '%s' is not in dictionary." % w)
        else:
            question_vec = test_qstory[:, question_idx]
        nb_question_rows = len(question_vec)

        # Keep only the latest "sz" sentences of the story
        d = test_story[:, :(1 + last_sentence_idx), story_idx]
        offset = max(0, d.shape[1] - train_config["sz"])
        d = d[:, offset:]
        nb_rows, nb_sents = d.shape
        time_words = np.arange(nb_sents)[::-1] + len(dictionary)

        # Input data and data for the 1st memory cell
        # Here we duplicate input_data to fill the whole batch
        for b in range(batch_size):
            first_memory.data[:nb_rows, :nb_sents, b] = d
            if config.enable_time:
                first_memory.data[-1, :nb_sents, b] = time_words  # time words
            input_data[:nb_question_rows, b] = question_vec

        # Data for the rest memory cells
        for i in range(1, nhops):
            self.memory[i].data = first_memory.data

        # Run model to predict answer
        out = self.model.fprop(input_data)
        memory_probs = np.array([self.memory[i].probs[:(last_sentence_idx + 1), 0] for i in range(nhops)])

        # Get answer for the 1st question since all are the same
        pred_answer_idx = out[:, 0].argmax()
        pred_prob = out[pred_answer_idx, 0]

        return pred_answer_idx, pred_prob, memory_probs


def train_model(data_dir, model_file):
    """
    Train a joint model on the data in data_dir and save it to model_file.
    """
    memn2n = MemN2N(data_dir, model_file)
    memn2n.train()


def run_console_demo(data_dir, model_file):
    """
    Console-based demo
    """
    # raw_input is the Python 2 name; fall back to input where it does not exist
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    memn2n = MemN2N(data_dir, model_file)

    # Try to load model
    memn2n.load_model()

    # Read test data
    print("Reading test data from %s ..." % memn2n.data_dir)
    test_data_path = glob.glob('%s/qa*_*_test.txt' % memn2n.data_dir)
    test_story, test_questions, test_qstory = \
        parse_babi_task(test_data_path, memn2n.general_config.dictionary, False)

    while True:
        # Pick a random question
        question_idx = np.random.randint(test_questions.shape[1])
        story_idx = test_questions[0, question_idx]
        last_sentence_idx = test_questions[1, question_idx]

        # Get story and question
        story_txt, question_txt, correct_answer = memn2n.get_story_texts(test_story, test_questions, test_qstory,
                                                                         question_idx, story_idx, last_sentence_idx)
        print("* Story:")
        print("\n\t".join(story_txt))
        print("\n* Suggested question:\n\t%s?" % question_txt)

        while True:
            user_question = read_line("Your question (press Enter to use the suggested question):\n\t")

            pred_answer_idx, pred_prob, memory_probs = \
                memn2n.predict_answer(test_story, test_questions, test_qstory,
                                      question_idx, story_idx, last_sentence_idx,
                                      user_question)

            pred_answer = memn2n.reversed_dict[pred_answer_idx]

            print("* Answer: '%s', confidence score = %.2f%%" % (pred_answer, 100. * pred_prob))
            # The correct answer is only known for the suggested question
            if user_question == '':
                if pred_answer == correct_answer:
                    print("  Correct!")
                else:
                    print("  Wrong. The correct answer is '%s'" % correct_answer)

            print("\n* Explanation:")
            print("\t".join(["Memory %d" % (i + 1) for i in range(len(memory_probs))]) + "\tText")
            for sent_idx, sent_txt in enumerate(story_txt):
                prob_output = "\t".join(["%.3f" % mem_prob for mem_prob in memory_probs[:, sent_idx]])
                print("%s\t%s" % (prob_output, sent_txt))

            asking_another_question = read_line("\nDo you want to ask another question? [y/N] ")
            if asking_another_question == '' or asking_another_question.lower() == 'n':
                break

        will_continue = read_line("Do you want to continue? [Y/n] ")
        if will_continue != '' and will_continue.lower() != 'y':
            break
        print("=" * 70)


def run_web_demo(data_dir, model_file):
    """
    Start the Flask-based web demo.
    """
    # Imported here so that Flask is only needed for the web demo
    from demo.web import webapp
    webapp.init(data_dir, model_file)
    webapp.run()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--data-dir", default="data/tasks_1-20_v1-2/en",
                        help="path to dataset directory (default: %(default)s)")
    parser.add_argument("-m", "--model-file", default="trained_model/memn2n_model.pklz",
                        help="model file (default: %(default)s)")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-train", "--train", action="store_true",
                       help="train model (default: %(default)s)")
    group.add_argument("-console", "--console-demo", action="store_true",
                       help="run console-based demo (default: %(default)s)")
    group.add_argument("-web", "--web-demo", action="store_true", default=True,
                       help="run web-based demo (default: %(default)s)")
    args = parser.parse_args()

    if not os.path.exists(args.data_dir):
        print("The data directory '%s' does not exist. Please download it first." % args.data_dir)
        sys.exit(1)

    # The web demo is the fallback when neither -train nor -console is given
    if args.train:
        train_model(args.data_dir, args.model_file)
    elif args.console_demo:
        run_console_demo(args.data_dir, args.model_file)
    else:
        run_web_demo(args.data_dir, args.model_file)
question : ''; 43 | var url = '/get/answer?question_idx=' + questionIdx + 44 | '&user_question=' + encodeURIComponent(userQuestion); 45 | 46 | $.get(url, function(json) { 47 | var predAnswer = json["pred_answer"], 48 | predProb = json["pred_prob"], 49 | memProbs = json["memory_probs"]; 50 | 51 | var outputMessage = "Answer = '" + predAnswer + "'" + 52 | "\nConfidence score = " + (predProb * 100).toFixed(2) + "%"; 53 | 54 | // Show answer's feedback only if suggested question was used 55 | if (userQuestion === '') { 56 | if (predAnswer === correctAnswer) 57 | outputMessage += "\nCorrect!"; 58 | else 59 | outputMessage += "\nWrong. The correct answer is '" + correctAnswer + "'"; 60 | } 61 | $answer.val(outputMessage); 62 | 63 | // Explain answer 64 | var explanationHtml = []; 65 | var sentenceList = $story.val().split('\n'); 66 | var maxLatestSents = memProbs.length; 67 | var numSents = sentenceList.length; 68 | 69 | for (var i = Math.max(0, numSents - maxLatestSents); i < numSents; i++) { 70 | var rowHtml = []; 71 | rowHtml.push(''); 72 | rowHtml.push('' + sentenceList[i] + ''); 73 | for (var j = 0; j < 3; j++) { 74 | var val = memProbs[i][j].toFixed(2); 75 | if (val > 0) { 76 | rowHtml.push('' + val + ''); 78 | } else { 79 | rowHtml.push('' + val + ''); 80 | } 81 | } 82 | rowHtml.push(''); 83 | explanationHtml.push(rowHtml.join('\n')); 84 | } 85 | $explainTable.find('tbody').html(explanationHtml); 86 | }); 87 | } 88 | }); 89 | -------------------------------------------------------------------------------- /demo/web/static/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | padding-top: 10px; 3 | font-family: 'Roboto', sans-serif; 4 | } 5 | 6 | div.shadow { 7 | border: 1px solid #ccc; 8 | border-radius: 10px; 9 | box-shadow: 0 0 24px #bbb; 10 | background: #fff; 11 | color: #666; 12 | } 13 | 14 | .qa-container, .qa-explain-container { 15 | width: 550px; 16 | /*max-height: 650px;*/ 17 | padding: 20px 50px 40px; 
18 | margin-top: 10px; 19 | margin-left: 30px; 20 | } 21 | 22 | .qa-explain-container table th { 23 | font-size: 16px; 24 | font-weight: bold; 25 | } 26 | 27 | .container > h3 { 28 | text-align: center; 29 | } 30 | 31 | .qa-form { 32 | max-width: 600px; 33 | margin: 0 auto; 34 | } 35 | 36 | .qa-form .btn { 37 | width: 200px; 38 | } 39 | 40 | .qa-form .form-control[readonly] { 41 | background-color: #fff; 42 | } 43 | 44 | /* Font */ 45 | .app-header, .qa-form label, .qa-explain-container label, .qa-explain-container th { 46 | font-family: 'Shadows Into Light Two', sans-serif; 47 | } 48 | /*.qa-form input, .qa-form textarea, .qa-form button {*/ 49 | /*font-family: 'Roboto', sans-serif;*/ 50 | /*}*/ 51 | -------------------------------------------------------------------------------- /demo/web/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | MemN2N for bAbI Tasks 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 |

End-To-End Memory Network for bAbI Tasks

21 |
22 |
23 |
24 |
25 |
26 |
27 |

28 | 29 |
30 |
31 |

32 | 33 | 35 |

36 | 38 |
39 |
40 |

41 | 42 |
43 |
44 |
45 | 46 | 47 |
48 |
49 |
50 |
51 |
52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
TextMem 1Mem 2Mem 3
64 |
65 |
66 |
67 |
68 | 69 | 70 | -------------------------------------------------------------------------------- /demo/web/webapp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Web-based demo 3 | """ 4 | import glob 5 | import flask 6 | import numpy as np 7 | 8 | from demo.qa import MemN2N 9 | from util import parse_babi_task 10 | 11 | app = flask.Flask(__name__) 12 | memn2n = None 13 | test_story, test_questions, test_qstory = None, None, None 14 | 15 | 16 | def init(data_dir, model_file): 17 | """ Initialize web app """ 18 | global memn2n, test_story, test_questions, test_qstory 19 | 20 | # Try to load model 21 | memn2n = MemN2N(data_dir, model_file) 22 | memn2n.load_model() 23 | 24 | # Read test data 25 | print("Reading test data from %s ..." % memn2n.data_dir) 26 | test_data_path = glob.glob('%s/qa*_*_test.txt' % memn2n.data_dir) 27 | test_story, test_questions, test_qstory = \ 28 | parse_babi_task(test_data_path, memn2n.general_config.dictionary, False) 29 | 30 | 31 | def run(): 32 | app.run() 33 | 34 | 35 | @app.route('/') 36 | def index(): 37 | return flask.render_template("index.html") 38 | 39 | 40 | @app.route('/get/story', methods=['GET']) 41 | def get_story(): 42 | question_idx = np.random.randint(test_questions.shape[1]) 43 | story_idx = test_questions[0, question_idx] 44 | last_sentence_idx = test_questions[1, question_idx] 45 | 46 | story_txt, question_txt, correct_answer = memn2n.get_story_texts(test_story, test_questions, test_qstory, 47 | question_idx, story_idx, last_sentence_idx) 48 | # Format text 49 | story_txt = "\n".join(story_txt) 50 | question_txt += "?" 
@app.route('/get/answer', methods=['GET'])
def get_answer():
    """ Predict the answer for one of the parsed test questions.

    Query parameters:
        question_idx: column index into ``test_questions`` (required).
        user_question: optional text, forwarded unchanged to
            ``memn2n.predict_answer`` (defaults to an empty string).

    Returns:
        JSON object with ``pred_answer``, ``pred_prob`` and ``memory_probs``.
        Responds with HTTP 400 when ``question_idx`` is missing, is not an
        integer, or does not address a column of ``test_questions``.
    """
    # With type=int the lookup yields None for a missing or non-numeric value
    # instead of letting int() raise and surface as a 500.
    question_idx = flask.request.args.get('question_idx', type=int)
    if question_idx is None or not 0 <= question_idx < test_questions.shape[1]:
        flask.abort(400)
    user_question = flask.request.args.get('user_question', '')

    story_idx = test_questions[0, question_idx]
    last_sentence_idx = test_questions[1, question_idx]

    pred_answer_idx, pred_prob, memory_probs = memn2n.predict_answer(test_story, test_questions, test_qstory,
                                                                     question_idx, story_idx, last_sentence_idx,
                                                                     user_question)
    pred_answer = memn2n.reversed_dict[pred_answer_idx]

    return flask.jsonify({
        "pred_answer" : pred_answer,
        "pred_prob" : pred_prob,
        "memory_probs": memory_probs.T.tolist()
    })
class MemoryBoW(Memory):
    """
    MemoryBoW (bag-of-words memory):
        Query module  = Parallel((LookupTable + Sum(1)) + Identity) + MatVecProd with transpose + Softmax
        Output module = Parallel((LookupTable + Sum(1)) + Identity) + MatVecProd

    Compared with Memory, each embedding lookup is followed by Sum(dim=1),
    and the data buffer gets a leading ``max_words`` axis.
    """
    def __init__(self, config):
        super(MemoryBoW, self).__init__(config)
        # Replace the buffer created by Memory.__init__ with a 3-D one
        data_shape = (config["max_words"], self.sz, config["bsz"])
        self.data = np.zeros(data_shape, np.float32)

    def init_query_module(self):
        """ Build self.emb_query and self.mod_query """
        self.emb_query = LookupTable(self.voc_sz, self.in_dim)

        bow = Sequential()
        for layer in (self.emb_query, Sum(dim=1)):
            bow.add(layer)

        branches = Parallel()
        for branch in (bow, Identity()):
            branches.add(branch)

        self.mod_query = Sequential()
        for stage in (branches, MatVecProd(True), Softmax()):
            self.mod_query.add(stage)

    def init_output_module(self):
        """ Build self.emb_out and self.mod_out """
        self.emb_out = LookupTable(self.voc_sz, self.out_dim)

        bow = Sequential()
        for layer in (self.emb_out, Sum(dim=1)):
            bow.add(layer)

        branches = Parallel()
        for branch in (bow, Identity()):
            branches.add(branch)

        self.mod_out = Sequential()
        for stage in (branches, MatVecProd(False)):
            self.mod_out.add(stage)
class Module(object):
    """
    Abstract Module class for neural net

    Every module caches the result of its last forward pass in ``output``
    and of its last backward pass in ``grad_input``.
    """
    # NOTE: Python 2 spelling; Python 3 ignores this attribute, so the
    # abstract methods are only enforced under Python 2.
    __metaclass__ = ABCMeta

    def __init__(self):
        self.output = None
        self.grad_input = None

    @abstractmethod
    def fprop(self, input_data):
        """ Forward propagation; default is the identity """
        self.output = input_data
        return self.output

    @abstractmethod
    def bprop(self, input_data, grad_output):
        """ Back-propagation; default passes the gradient through """
        self.grad_input = grad_output
        return self.grad_input

    @abstractmethod
    def update(self, params):
        """ Apply accumulated gradients using the settings in ``params`` """
        pass

    @abstractmethod
    def share(self, m):
        """ Share parameters with module ``m`` """
        pass


class Container(Module):
    """
    Container: holds an ordered list of sub-modules and forwards
    ``update`` / ``share`` to each of them.
    """
    def __init__(self):
        super(Container, self).__init__()
        self.modules = []

    def add(self, m):
        """ Append module ``m`` to the container """
        self.modules.append(m)

    def update(self, params):
        for module in self.modules:
            module.update(params)

    def share(self, m):
        # Pair sub-modules by position with those of the other container
        for c_module, m_module in zip(self.modules, m.modules):
            c_module.share(m_module)


class Sequential(Container):
    """
    Chains modules: the output of each one is the input of the next.
    """
    def fprop(self, input_data):
        temp = input_data
        for module in self.modules:
            temp = module.fprop(temp)

        self.output = temp
        return self.output

    def bprop(self, input_data, grad_output):
        # Walk backwards; every module except the first receives the cached
        # output of its predecessor as its input.
        for i in range(len(self.modules) - 1, 0, -1):
            grad_input = self.modules[i].bprop(self.modules[i - 1].output, grad_output)
            grad_output = grad_input
        grad_input = self.modules[0].bprop(input_data, grad_output)

        self.grad_input = grad_input
        return self.grad_input


class Parallel(Container):
    """
    Computes forward and backward propagations for all modules at once.
    The i-th module is applied to the i-th element of the input list.
    """
    def fprop(self, input_data):
        self.output = [module.fprop(input_elem)
                       for module, input_elem in zip(self.modules, input_data)]
        return self.output

    def bprop(self, input_data, grad_output):
        self.grad_input = [module.bprop(input_elem, grad_output_elem)
                           for module, input_elem, grad_output_elem
                           in zip(self.modules, input_data, grad_output)]
        return self.grad_input


class AddTable(Module):
    """
    Module for sum operator which sums up all elements in input data
    """
    def __init__(self):
        super(AddTable, self).__init__()

    def fprop(self, input_data):
        """
        Args:
            input_data (list): arrays to add together.

        Returns:
            A new array holding the sum; none of the inputs is modified.
        """
        # Accumulate into a copy: summing straight into input_data[0] would
        # overwrite the cached output of whichever module produced it.
        self.output = np.array(input_data[0])
        for elem in input_data[1:]:
            # An operand with one dimension fewer gets a trailing axis so it
            # broadcasts. axis=-1 appends that axis; the former
            # `elem.ndim + 1` is out of range and rejected by current NumPy.
            if elem.ndim == self.output.ndim - 1:
                elem = np.expand_dims(elem, axis=-1)
            self.output += elem
        return self.output

    def bprop(self, input_data, grad_output):
        # The gradient of a sum is the incoming gradient for every operand
        self.grad_input = [grad_output for _ in range(len(input_data))]
        return self.grad_input

    def share(self, m):
        pass

    def update(self, params):
        pass
class Duplicate(Module):
    """
    Fan-out module: hands the same input to two consumers.
    No copy is made; both output entries refer to the input object itself.
    """
    def __init__(self):
        super(Duplicate, self).__init__()

    def fprop(self, input_data):
        self.output = [input_data] * 2
        return self.output

    def bprop(self, input_data, grad_output):
        # Gradients coming back from the two consumers are added together
        first, second = grad_output[0], grad_output[1]
        self.grad_input = first + second
        return self.grad_input

    def share(self, m):
        pass

    def update(self, params):
        pass
class Linear(Module):
    """
    Linear Layer: output = weight.D . input + bias.D

    ``weight`` has shape (out_dim, in_dim) and ``bias`` has shape (out_dim, 1).
    Inputs with more than two dimensions are flattened to
    (input.shape[0], -1) before the product.
    """
    def __init__(self, in_dim, out_dim):
        super(Linear, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.weight = Weight((out_dim, in_dim))
        self.bias = Weight((out_dim, 1))

    def fprop(self, input_data):
        high_dimension_input = input_data.ndim > 2

        # Reshape input
        if high_dimension_input:
            input_data = input_data.reshape(input_data.shape[0], -1)

        self.output = np.dot(self.weight.D, input_data) + self.bias.D

        # Reshape output (kept parallel to LinearNB; the product is already 2-D)
        if high_dimension_input:
            self.output = self.output.reshape(self.output.shape[0], -1)

        # Return the activations; this used to return the integer self.out_dim
        return self.output

    def bprop(self, input_data, grad_output):
        orig_input_data_shape = input_data.shape
        high_dimension_input = input_data.ndim > 2

        # Reshape input and grad_output
        if high_dimension_input:
            input_data = input_data.reshape(input_data.shape[0], -1)
            grad_output = grad_output.reshape(grad_output.shape[0], -1)

        self.weight.grad = self.weight.grad + np.dot(grad_output, input_data.T)
        # keepdims keeps the sum as a column (out_dim, 1) like bias.grad;
        # a flat (out_dim,) sum would broadcast the gradient to (out_dim, out_dim).
        self.bias.grad = self.bias.grad + grad_output.sum(axis=1, keepdims=True)
        self.grad_input = np.dot(self.weight.D.T, grad_output)

        if high_dimension_input:
            self.grad_input = self.grad_input.reshape(orig_input_data_shape)

        return self.grad_input

    def update(self, params):
        self.weight.update(params)
        self.bias.update(params)

    def share(self, m):
        # Share by reference: both layers point at the same Weight objects
        self.weight = m.weight
        self.bias = m.bias
class LookupTable(Module):
    """
    Lookup table: maps word indices to columns of an embedding matrix
    """
    def __init__(self, voc_sz, out_dim):
        """
        Constructor

        Args:
            voc_sz (int): vocabulary size
            out_dim (int): output dimension
        """
        super(LookupTable, self).__init__()
        self.sz = voc_sz
        self.out_dim = out_dim
        self.weight = Weight((out_dim, voc_sz))

    def fprop(self, input_data):
        """
        Args:
            input_data (ndarray): word indices; cast to int before use.

        Returns:
            Embedding columns arranged as (out_dim,) + input_data.shape,
            with singleton dimensions squeezed out.
        """
        # Builtin int replaces the np.int alias, which NumPy 1.24 removed
        indices = input_data.T.astype(int).flatten()
        self.output = self.weight.D[:, indices]
        # Matlab's reshape uses Fortran order (i.e. column first)
        self.output = np.squeeze(self.output.reshape((self.out_dim,) + input_data.shape, order='F'))
        return self.output

    def bprop(self, input_data, grad_output):
        """
        Accumulates into ``weight.grad``, for every word index present in
        ``input_data``, the sum of the gradient columns at its positions.

        Returns:
            An empty list: indices have no gradient.
        """
        # Make sure input_data has one dim lower than grad_output (see the index below)
        if input_data.ndim == grad_output.ndim:
            input_data = np.squeeze(input_data)

        input_data = input_data.astype(int)
        # Iterating the unique indices directly also copes with an empty
        # input, which np.nditer refuses to iterate.
        for i in np.unique(input_data):
            self.weight.grad[:, i] += np.sum(grad_output[:, input_data == i], axis=1)

        self.grad_input = []
        return self.grad_input

    def update(self, params):
        self.weight.update(params)

    def share(self, m):
        # Share by reference: both tables point at the same Weight object
        self.weight = m.weight
class ReLU(Module):
    """ ReLU module: keeps entries greater than zero, zeroes the others """

    def fprop(self, input_data):
        positive = input_data > 0
        self.output = np.multiply(input_data, positive)
        return self.output

    def bprop(self, input_data, grad_output):
        # Gradient flows only where the forward input was greater than zero
        gate = input_data > 0
        self.grad_input = np.multiply(grad_output, gate)
        return self.grad_input

    def update(self, params):
        pass

    def share(self, m):
        pass
class Softmax(Module):
    """
    Softmax over axis 0 (each column is normalized to sum to one)
    """
    def __init__(self, skip_bprop=False):
        super(Softmax, self).__init__()
        self.skip_bprop = skip_bprop  # for the output module

    def fprop(self, input_data):
        """
        Args:
            input_data (ndarray): scores; left unmodified.

        Returns:
            Array of the same shape whose entries along axis 0 sum to one.
        """
        # Shift into a new array: the former in-place `-=` / `+=` overwrote
        # the caller's array, i.e. the cached output of the previous module.
        shifted = input_data - np.max(input_data, axis=0)
        # Constant offset kept from the original code; softmax is invariant
        # to a constant shift, so it does not alter the result mathematically.
        shifted = shifted + 1.0

        a = np.exp(shifted)
        # keepdims broadcasts for any rank, including 1-D input
        self.output = a / a.sum(axis=0, keepdims=True)
        return self.output

    def bprop(self, input_data, grad_output):
        if not self.skip_bprop:
            z = grad_output - np.sum(self.output * grad_output, axis=0)
            self.grad_input = self.output * z
        else:
            # Pass the gradient through untouched
            self.grad_input = grad_output

        return self.grad_input

    def update(self, params):
        pass

    def share(self, m):
        pass
class Loss(object):
    """ Abstract Loss class """
    # NOTE: Python 2 spelling; Python 3 ignores this attribute.
    __metaclass__ = ABCMeta

    @abstractmethod
    def fprop(self, input_data, target_data):
        """ Abstract function for forward propagation """
        pass

    @abstractmethod
    def bprop(self, input_data, target_data):
        """ Abstract function for back-propagation """
        pass


class CrossEntropyLoss(Loss):
    """
    Cross-entropy loss.

    ``input_data`` is indexed as [class, sample]; ``target_data`` holds one
    class index per sample (column).
    """
    def __init__(self):
        self.do_softmax_bprop = False
        self.eps = 1e-7
        self.size_average = True

    @staticmethod
    def _target_index(target_data):
        """ Return (rows, cols) selecting each sample's target-class entry """
        rows = np.asarray(target_data).astype(int).ravel()
        cols = np.arange(rows.shape[0])
        # An explicit tuple of index arrays works on Python 2 and 3 and on
        # every NumPy version, unlike indexing with the result of zip().
        return rows, cols

    def fprop(self, input_data, target_data):
        """
        Returns:
            Sum of -log(input_data[target, sample]) over the samples, divided
            by the number of columns when ``size_average`` is set.
        """
        idx = self._target_index(target_data)
        cost = np.sum(-np.log(input_data[idx]))
        if self.size_average:
            cost /= input_data.shape[1]

        return cost

    def bprop(self, input_data, target_data):
        """
        Returns:
            Gradient array with the shape of ``input_data``; ``input_data``
            itself is left unmodified.
        """
        idx = self._target_index(target_data)

        if self.do_softmax_bprop:
            # Copy first: subtracting in place would overwrite the model's
            # output array with the gradient.
            grad_input = np.array(input_data)
            grad_input[idx] -= 1
        else:
            grad_input = np.zeros_like(input_data, np.float32)
            # eps guards against division by a zero probability
            grad_input[idx] = -1. / (input_data[idx] + self.eps)

        if self.size_average:
            grad_input /= input_data.shape[1]

        return grad_input

    def get_error(self, input_data, target_data):
        """ Number of samples whose arg-max class differs from the target """
        y = input_data.argmax(axis=0)
        return np.sum(y != target_data)
train_config["init_lrate"], 28 | "max_grad_norm": train_config["max_grad_norm"] 29 | } 30 | 31 | for ep in range(nepochs): 32 | # Decrease learning rate after every decay step 33 | if (ep + 1) % lrate_decay_step == 0: 34 | params["lrate"] *= 0.5 35 | 36 | total_err = 0. 37 | total_cost = 0. 38 | total_num = 0 39 | for _ in Progress(range(int(math.floor(train_len / batch_size)))): 40 | # Question batch 41 | batch = train_range[np.random.randint(train_len, size=batch_size)] 42 | 43 | input_data = np.zeros((train_story.shape[0], batch_size), np.float32) # words of training questions 44 | target_data = train_questions[2, batch] # indices of training answers 45 | 46 | memory[0].data[:] = dictionary["nil"] 47 | 48 | # Compose batch of training data 49 | for b in range(batch_size): 50 | # NOTE: +1 since train_questions[1, :] is the index of the sentence right before the training question. 51 | # d is a batch of [word indices in sentence, sentence indices from batch] for this story 52 | d = train_story[:, :(1 + train_questions[1, batch[b]]), train_questions[0, batch[b]]] 53 | 54 | # Pick a fixed number of latest sentences (before the question) from the story 55 | offset = max(0, d.shape[1] - train_config["sz"]) 56 | d = d[:, offset:] 57 | 58 | # Training data for the 1st memory cell 59 | memory[0].data[:d.shape[0], :d.shape[1], b] = d 60 | 61 | if enable_time: 62 | # Inject noise into time index (i.e. word index) 63 | if randomize_time > 0: 64 | # Random number of blank (must be < total sentences until the training question?) 
65 | nblank = np.random.randint(int(math.ceil(d.shape[1] * randomize_time))) 66 | rt = np.random.permutation(d.shape[1] + nblank) 67 | 68 | rt[rt >= train_config["sz"]] = train_config["sz"] - 1 # put the cap 69 | 70 | # Add random time (must be > dictionary's length) into the time word (decreasing order) 71 | memory[0].data[-1, :d.shape[1], b] = np.sort(rt[:d.shape[1]])[::-1] + len(dictionary) 72 | 73 | else: 74 | memory[0].data[-1, :d.shape[1], b] = \ 75 | np.arange(d.shape[1])[::-1] + len(dictionary) 76 | 77 | input_data[:, b] = train_qstory[:, batch[b]] 78 | 79 | for i in range(1, nhops): 80 | memory[i].data = memory[0].data 81 | 82 | out = model.fprop(input_data) 83 | total_cost += loss.fprop(out, target_data) 84 | total_err += loss.get_error(out, target_data) 85 | total_num += batch_size 86 | 87 | grad = loss.bprop(out, target_data) 88 | model.bprop(input_data, grad) 89 | model.update(params) 90 | 91 | for i in range(nhops): 92 | memory[i].emb_query.weight.D[:, 0] = 0 93 | 94 | # Validation 95 | total_val_err = 0. 96 | total_val_cost = 0. 
def train_linear_start(train_story, train_questions, train_qstory, memory, model, loss, general_config):
    """ Train in two phases ("linear start").

    Phase 1 trains with the last module of every memory's query chain
    removed (the Softmax) and with the ``ls_*`` settings of
    ``general_config``. Phase 2 puts a Softmax back, restores the regular
    settings and trains again.

    The Softmax layers and the settings are restored even if phase 1
    raises, so ``memory`` and ``general_config`` are not left in the
    linear-start configuration after a failure.
    """
    train_config = general_config.train_config

    # Save settings
    nepochs2 = general_config.nepochs
    lrate_decay_step2 = general_config.lrate_decay_step
    init_lrate2 = train_config["init_lrate"]

    # Remove softmax from memory
    for i in range(general_config.nhops):
        memory[i].mod_query.modules.pop()

    try:
        # Add new settings
        general_config.nepochs = general_config.ls_nepochs
        general_config.lrate_decay_step = general_config.ls_lrate_decay_step
        train_config["init_lrate"] = general_config.ls_init_lrate

        # Train with new settings
        train(train_story, train_questions, train_qstory, memory, model, loss, general_config)
    finally:
        # Add softmax back
        for i in range(general_config.nhops):
            memory[i].mod_query.add(Softmax())

        # Restore old settings
        general_config.nepochs = nepochs2
        general_config.lrate_decay_step = lrate_decay_step2
        train_config["init_lrate"] = init_lrate2

    # Train with old settings
    train(train_story, train_questions, train_qstory, memory, model, loss, general_config)
def parse_babi_task(data_files, dictionary, include_question):
    """ Parse bAbI data.

    Args:
        data_files (list): a list of data file's paths.
        dictionary (dict): word's dictionary (word -> index); every new word
            found in the files is added to it in place.
        include_question (bool): whether count question toward input sentence.

    Returns:
        A tuple of (story, questions, qstory):
            story (3-D array)
                [position of word in sentence, sentence index, story index] = index of word in dictionary
            questions (2-D array)
                [0-13, question index], in which the first component is encoded as follows:
                    0 - story index
                    1 - index of the last sentence before the question
                    2 - index of the answer word in dictionary
                    3 to 12 - indices of supporting sentences
                    13 - line index (within its data file)
            qstory (2-D array) question's indices within a story
                [index of word in question, question index] = index of word in dictionary
    """
    # Try to reserve spaces beforehand (large matrices for both 1k and 10k data sets)
    # maximum number of words in sentence = 20
    # NOTE: all arrays are int16, so stored values must fit in that range.
    story = np.zeros((20, 500, len(data_files) * 3500), np.int16)
    questions = np.zeros((14, len(data_files) * 10000), np.int16)
    qstory = np.zeros((20, len(data_files) * 10000), np.int16)

    # NOTE: question's indices are not reset when going through a new story
    story_idx, question_idx, sentence_idx, max_words, max_sentences = -1, -1, -1, 0, 0

    # Mapping line number (within a story) to sentence's index (to support the flag include_question)
    mapping = None

    for fp in data_files:
        with open(fp) as f:
            for line_idx, line in enumerate(f):
                line = line.rstrip().lower()
                words = line.split()

                # Blank lines (e.g. a trailing empty line) carry no data;
                # skipping them avoids an IndexError on words[0] below.
                if not words:
                    continue

                # Story begins
                if words[0] == '1':
                    story_idx += 1
                    sentence_idx = -1
                    mapping = []

                # FIXME: This condition makes the code more fragile!
                if '?' not in line:
                    is_question = False
                    sentence_idx += 1
                else:
                    is_question = True
                    question_idx += 1
                    questions[0, question_idx] = story_idx
                    questions[1, question_idx] = sentence_idx
                    if include_question:
                        sentence_idx += 1

                mapping.append(sentence_idx)

                # Skip substory index
                for k in range(1, len(words)):
                    w = words[k]

                    if w.endswith(('.', '?')):
                        w = w[:-1]

                    if w not in dictionary:
                        dictionary[w] = len(dictionary)

                    if max_words < k:
                        max_words = k

                    if not is_question:
                        story[k - 1, sentence_idx, story_idx] = dictionary[w]
                    else:
                        qstory[k - 1, question_idx] = dictionary[w]
                        if include_question:
                            story[k - 1, sentence_idx, story_idx] = dictionary[w]

                        # NOTE: Punctuation is already removed from w
                        if words[k].endswith('?'):
                            # The token right after the question is the answer,
                            # followed by the supporting line numbers.
                            answer = words[k + 1]
                            if answer not in dictionary:
                                dictionary[answer] = len(dictionary)

                            questions[2, question_idx] = dictionary[answer]

                            # Indices of supporting sentences (stored from row 3 on)
                            for h in range(k + 2, len(words)):
                                questions[1 + h - k, question_idx] = mapping[int(words[h]) - 1]

                            questions[-1, question_idx] = line_idx
                            break

                if max_sentences < sentence_idx + 1:
                    max_sentences = sentence_idx + 1

    # Trim the pre-allocated arrays down to the part that was actually filled
    story = story[:max_words, :max_sentences, :(story_idx + 1)]
    questions = questions[:, :(question_idx + 1)]
    qstory = qstory[:max_words, :(question_idx + 1)]

    return story, questions, qstory
Identity -> MatVecProd -> Softmax 129 | 130 | b) Output module (embedding C) 131 | Parallel -> { LookupTable + ElemMult + Sum } -> Identity -> MatVecProd 132 | """ 133 | train_config = general_config.train_config 134 | dictionary = general_config.dictionary 135 | use_bow = general_config.use_bow 136 | nhops = general_config.nhops 137 | add_proj = general_config.add_proj 138 | share_type = general_config.share_type 139 | enable_time = general_config.enable_time 140 | add_nonlin = general_config.add_nonlin 141 | 142 | in_dim = train_config["in_dim"] 143 | out_dim = train_config["out_dim"] 144 | max_words = train_config["max_words"] 145 | voc_sz = train_config["voc_sz"] 146 | 147 | if not use_bow: 148 | train_config["weight"] = np.ones((in_dim, max_words), np.float32) 149 | for i in range(in_dim): 150 | for j in range(max_words): 151 | train_config["weight"][i][j] = (i + 1 - (in_dim + 1) / 2) * \ 152 | (j + 1 - (max_words + 1) / 2) 153 | train_config["weight"] = \ 154 | 1 + 4 * train_config["weight"] / (in_dim * max_words) 155 | 156 | memory = {} 157 | model = Sequential() 158 | model.add(LookupTable(voc_sz, in_dim)) 159 | if not use_bow: 160 | if enable_time: 161 | model.add(ElemMult(train_config["weight"][:, :-1])) 162 | else: 163 | model.add(ElemMult(train_config["weight"])) 164 | 165 | model.add(Sum(dim=1)) 166 | 167 | proj = {} 168 | for i in range(nhops): 169 | if use_bow: 170 | memory[i] = MemoryBoW(train_config) 171 | else: 172 | memory[i] = MemoryL(train_config) 173 | 174 | # Override nil_word which is initialized in "self.nil_word = train_config["voc_sz"]" 175 | memory[i].nil_word = dictionary['nil'] 176 | model.add(Duplicate()) 177 | p = Parallel() 178 | p.add(memory[i]) 179 | 180 | if add_proj: 181 | proj[i] = LinearNB(in_dim, in_dim) 182 | p.add(proj[i]) 183 | else: 184 | p.add(Identity()) 185 | 186 | model.add(p) 187 | model.add(AddTable()) 188 | if add_nonlin: 189 | model.add(ReLU()) 190 | 191 | model.add(LinearNB(out_dim, voc_sz, True)) 192 | 
class Progress(object):
    """
    Progress bar

    Wraps a sized iterable and, while it is being iterated, writes a
    one-line text progress bar (percentage and elapsed seconds) to stdout.
    The bar is redrawn after the consumer has finished with each item and
    is erased once the last item has been processed.

    Args:
        iterable: any iterable supporting ``len()``.
        bar_length (int): width of the bar in characters.
    """

    def __init__(self, iterable, bar_length=50):
        self.iterable = iterable
        self.bar_length = bar_length
        self.total_length = len(iterable)
        self.start_time = time.time()
        self.count = 0

    def __iter__(self):
        # Restart the bookkeeping on every pass so that iterating the same
        # object again does not report more than 100% or a stale timer.
        self.count = 0
        self.start_time = time.time()

        for obj in self.iterable:
            yield obj
            # Runs when the consumer asks for the next item, i.e. after
            # it has finished processing ``obj``.
            self.count += 1
            percent = self.count / self.total_length
            print_length = int(percent * self.bar_length)
            progress = "=" * print_length + " " * (self.bar_length - print_length)
            elapsed_time = time.time() - self.start_time
            print_msg = "\r|%s| %.0f%% %.1fs" % (progress, percent * 100, elapsed_time)
            sys.stdout.write(print_msg)
            if self.count == self.total_length:
                # Done: overwrite the bar with spaces and return the cursor
                sys.stdout.write("\r" + " " * len(print_msg) + "\r")
            sys.stdout.flush()