├── .gitignore
├── LICENSE
├── README.md
├── babi_runner.py
├── bechmarks
│   └── README.md
├── config.py
├── data
│   └── README.md
├── demo
│   ├── __init__.py
│   ├── qa.py
│   └── web
│       ├── __init__.py
│       ├── static
│       │   ├── script.js
│       │   └── style.css
│       ├── templates
│       │   └── index.html
│       └── webapp.py
├── memn2n
│   ├── __init__.py
│   ├── memory.py
│   └── nn.py
├── requirements.txt
├── train_test.py
├── trained_model
│   ├── README.md
│   └── memn2n_model.pklz
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.pyc
3 | data/tasks_1-20_v1-2
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD License
2 |
3 | Copyright (c) 2016, Vinh Khuc.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification,
7 | are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name Facebook nor the names of its contributors may be used to
17 | endorse or promote products derived from this software without specific
18 | prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
21 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
22 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
24 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
25 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
26 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
27 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## End-To-End Memory Networks for Question Answering
2 | This is an implementation of the MemN2N model in Python for the [bAbI question-answering tasks](http://fb.ai/babi)
3 | as shown in Section 4 of the paper "[End-To-End Memory Networks](http://arxiv.org/abs/1503.08895)". It is based on
4 | Facebook's [Matlab code](https://github.com/facebook/MemNN/tree/master/MemN2N-babi-matlab).
5 |
6 | 
7 |
8 | ## Requirements
9 | * Python 2.7
10 | * Numpy, Flask (only for web-based demo) can be installed via pip:
11 | ```
12 | $ sudo pip install -r requirements.txt
13 | ```
14 | * [bAbI dataset](http://fb.ai/babi) should be downloaded to `data/tasks_1-20_v1-2`:
15 | ```
16 | $ wget -qO- http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz | tar xvz -C data
17 | ```
18 |
19 | ## Usage
20 | * To run on a single task, use `babi_runner.py` with `-t` followed by the task's id. For example,
21 | ```
22 | python babi_runner.py -t 1
23 | ```
24 | The output will look like:
25 | ```
26 | Using data from data/tasks_1-20_v1-2/en
27 | Train and test for task 1 ...
28 | 1 | train error: 0.876116 | val error: 0.75
29 | |=================================== | 71% 0.5s
30 | ```
31 | * To run on all 20 tasks, one by one:
32 | ```
33 | python babi_runner.py -a
34 | ```
35 | * To train using all training data from 20 tasks, use the joint mode:
36 | ```
37 | python babi_runner.py -j
38 | ```
39 |
40 | ## Question Answering Demo
41 | * In order to run the Web-based demo using the pretrained model `memn2n_model.pklz` in `trained_model/`, run:
42 | ```
43 | python -m demo.qa
44 | ```
45 |
46 | * Alternatively, you can try the console-based demo:
47 | ```
48 | python -m demo.qa -console
49 | ```
50 |
51 | * The pretrained model `memn2n_model.pklz` can be created by running:
52 | ```
53 | python -m demo.qa -train
54 | ```
55 |
56 | * To show all options, run `python -m demo.qa -h`
57 |
58 | ## Benchmarks
59 | See the results [here](https://github.com/vinhkhuc/MemN2N-babi-python/tree/master/bechmarks).
60 |
61 | ### Author
62 | Vinh Khuc
63 |
64 | ### Future Plans
65 | * Port to TensorFlow/Keras
66 | * Support Python 3
67 |
68 | ### References
69 | * Sainbayar Sukhbaatar, Arthur Szlam, Jason Weston, Rob Fergus,
70 | "[End-To-End Memory Networks](http://arxiv.org/abs/1503.08895)",
71 | *arXiv:1503.08895 [cs.NE]*.
--------------------------------------------------------------------------------
/babi_runner.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 | import sys
5 |
6 | import argparse
7 | import numpy as np
8 |
9 | from config import BabiConfig, BabiConfigJoint
10 | from train_test import train, train_linear_start, test
11 | from util import parse_babi_task, build_model
12 |
13 | seed_val = 42
14 | random.seed(seed_val)
15 | np.random.seed(seed_val) # for reproducing
16 |
17 |
18 | def run_task(data_dir, task_id):
19 | """
20 | Train and test for each task
21 | """
22 | print("Train and test for task %d ..." % task_id)
23 |
24 | # Parse data
25 | train_files = glob.glob('%s/qa%d_*_train.txt' % (data_dir, task_id))
26 | test_files = glob.glob('%s/qa%d_*_test.txt' % (data_dir, task_id))
27 |
28 | dictionary = {"nil": 0}
29 | train_story, train_questions, train_qstory = parse_babi_task(train_files, dictionary, False)
30 | test_story, test_questions, test_qstory = parse_babi_task(test_files, dictionary, False)
31 |
32 | general_config = BabiConfig(train_story, train_questions, dictionary)
33 |
34 | memory, model, loss = build_model(general_config)
35 |
36 | if general_config.linear_start:
37 | train_linear_start(train_story, train_questions, train_qstory, memory, model, loss, general_config)
38 | else:
39 | train(train_story, train_questions, train_qstory, memory, model, loss, general_config)
40 |
41 | test(test_story, test_questions, test_qstory, memory, model, loss, general_config)
42 |
43 |
44 | def run_all_tasks(data_dir):
45 | """
46 | Train and test for all tasks
47 | """
48 | print("Training and testing for all tasks ...")
49 | for t in range(20):
50 | run_task(data_dir, task_id=t + 1)
51 |
52 |
53 | def run_joint_tasks(data_dir):
54 | """
55 | Train and test for all tasks but the trained model is built using training data from all tasks.
56 | """
57 | print("Jointly train and test for all tasks ...")
58 | tasks = range(20)
59 |
60 | # Parse training data
61 | train_data_path = []
62 | for t in tasks:
63 | train_data_path += glob.glob('%s/qa%d_*_train.txt' % (data_dir, t + 1))
64 |
65 | dictionary = {"nil": 0}
66 | train_story, train_questions, train_qstory = parse_babi_task(train_data_path, dictionary, False)
67 |
68 | # Parse test data for each task so that the dictionary covers all words before training
69 | for t in tasks:
70 | test_data_path = glob.glob('%s/qa%d_*_test.txt' % (data_dir, t + 1))
71 | parse_babi_task(test_data_path, dictionary, False) # ignore output for now
72 |
73 | general_config = BabiConfigJoint(train_story, train_questions, dictionary)
74 | memory, model, loss = build_model(general_config)
75 |
76 | if general_config.linear_start:
77 | train_linear_start(train_story, train_questions, train_qstory, memory, model, loss, general_config)
78 | else:
79 | train(train_story, train_questions, train_qstory, memory, model, loss, general_config)
80 |
81 | # Test on each task
82 | for t in tasks:
83 | print("Testing for task %d ..." % (t + 1))
84 | test_data_path = glob.glob('%s/qa%d_*_test.txt' % (data_dir, t + 1))
85 | dc = len(dictionary)
86 | test_story, test_questions, test_qstory = parse_babi_task(test_data_path, dictionary, False)
87 | assert dc == len(dictionary) # make sure that the dictionary already covers all words
88 |
89 | test(test_story, test_questions, test_qstory, memory, model, loss, general_config)
90 |
91 | if __name__ == "__main__":
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument("-d", "--data-dir", default="data/tasks_1-20_v1-2/en",
94 | help="path to dataset directory (default: %(default)s)")
95 | group = parser.add_mutually_exclusive_group()
96 | group.add_argument("-t", "--task", default=1, type=int,
97 | help="train and test for a single task (default: %(default)s)")
98 | group.add_argument("-a", "--all-tasks", action="store_true",
99 | help="train and test for all tasks (one by one) (default: %(default)s)")
100 | group.add_argument("-j", "--joint-tasks", action="store_true",
101 | help="train and test for all tasks (all together) (default: %(default)s)")
102 | args = parser.parse_args()
103 |
104 | # Check if data is available
105 | data_dir = args.data_dir
106 | if not os.path.exists(data_dir):
107 | print("The data directory '%s' does not exist. Please download it first." % data_dir)
108 | sys.exit(1)
109 |
110 | print("Using data from %s" % args.data_dir)
111 | if args.all_tasks:
112 | run_all_tasks(data_dir)
113 | elif args.joint_tasks:
114 | run_joint_tasks(data_dir)
115 | else:
116 | run_task(data_dir, task_id=args.task)
117 |
--------------------------------------------------------------------------------
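The joint runner above parses every test file before training so that the shared `dictionary` already contains every word, and later asserts that its size does not change. A minimal sketch of that idea follows, assuming a hypothetical `index_words` helper; the real parsing lives in `util.parse_babi_task`, which is not shown here, so this illustrates the vocabulary bookkeeping only.

```python
def index_words(sentences, dictionary):
    """Map each word to an integer id, growing `dictionary` in place.

    Hypothetical helper for illustration; not the repository's parser.
    """
    encoded = []
    for sentence in sentences:
        ids = []
        for word in sentence.lower().split():
            if word not in dictionary:
                dictionary[word] = len(dictionary)
            ids.append(dictionary[word])
        encoded.append(ids)
    return encoded


dictionary = {"nil": 0}
train = ["Mary went to the kitchen", "John went to the garden"]
test = ["Where is Sandra"]

index_words(train, dictionary)
index_words(test, dictionary)       # first pass only extends the vocabulary
size_before = len(dictionary)

encoded_test = index_words(test, dictionary)
# The second pass must not add any new words.
assert size_before == len(dictionary)
```

Fixing the vocabulary before `build_model` is called matters because the embedding size (`voc_sz`) is derived from `len(dictionary)` in the configuration.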
/bechmarks/README.md:
--------------------------------------------------------------------------------
1 | This page contains benchmark results to compare this Python implementation and the
2 | [original Matlab code](https://github.com/facebook/MemNN/tree/master/MemN2N-babi-matlab)
3 | on the [bAbI tasks](http://fb.ai/babi).
4 |
5 | These results are test error rates (%) with the default configuration: 3 hops, position encoding (PE),
6 | linear start training (LS), random noise (RN) and adjacent weight tying.
7 |
8 |
9 | | Task | PE LS RN (matlab) | PE LS RN JOINT (matlab) | PE LS RN (this repo) | PE LS RN JOINT (this repo) |
10 | |:--------:|:---------------------:|:---------------------------:|:-------------------------:|:------------------------------:|
11 | | 1 | 0.1 | 0.0 | 0.5 | 0.1 |
12 | | 2 | 8.2 | 13.1 | 8.9 | 16.6 |
13 | | 3 | 41.8 | 23.4 | 41.4 | 26.3 |
14 | | 4 | 4.4 | 5.9 | 7.3 | 11.3 |
15 | | 5 | 13.7 | 12.9 | 12.2 | 14.4 |
16 | | 6 | 7.9 | 3.7 | 7.7 | 2.8 |
17 | | 7 | 20.3 | 22.9 | 19.8 | 16.0 |
18 | | 8 | 11.7 | 9.1 | 12.9 | 10.1 |
19 | | 9 | 13.6 | 2.7 | 13.9 | 2.3 |
20 | | 10 | 9.7 | 7.2 | 18.9 | 6.5 |
21 | | 11 | 0.7 | 0.8 | 0.5 | 1.2 |
22 | | 12 | 0.0 | 0.1 | 0.2 | 0.2 |
23 | | 13 | 0.3 | 0.1 | 0.9 | 0.5 |
24 | | 14 | 0.9 | 4.2 | 9.0 | 5.5 |
25 | | 15 | 0.0 | 0.0 | 0.0 | 0.3 |
26 | | 16 | 0.4 | 1.2 | 0.6 | 2.1 |
27 | | 17 | 51.0 | 43.6 | 49.1 | 42.6 |
28 | | 18 | 10.5 | 10.5 | 10.4 | 9.0 |
29 | | 19 | 81.4 | 86.7 | 90.8 | 90.2 |
30 | | 20 | 0.0 | 0.0 | 0.0 | 0.2 |
31 | | **Mean** | **13.8** | **12.4** | **15.2** | **12.9** |
32 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class BabiConfig(object):
4 | """
5 | Configuration for bAbI
6 | """
7 | def __init__(self, train_story, train_questions, dictionary):
8 | self.dictionary = dictionary
9 | self.batch_size = 32
10 | self.nhops = 3
11 | self.nepochs = 100
12 | self.lrate_decay_step = 25 # reduce learning rate by half every 25 epochs
13 |
14 | # Use 10% of training data for validation
15 | nb_questions = train_questions.shape[1]
16 | nb_train_questions = int(nb_questions * 0.9)
17 |
18 | self.train_range = np.array(range(nb_train_questions))
19 | self.val_range = np.array(range(nb_train_questions, nb_questions))
20 | self.enable_time = True # add time embeddings
21 | self.use_bow = False # use Bag-of-Words instead of Position-Encoding
22 | self.linear_start = True
23 | self.share_type = 1 # 1: adjacent, 2: layer-wise weight tying
24 | self.randomize_time = 0.1 # amount of noise injected into time index
25 | self.add_proj = False # add linear layer between internal states
26 | self.add_nonlin = False # add non-linearity to internal states
27 |
28 | if self.linear_start:
29 | self.ls_nepochs = 20
30 | self.ls_lrate_decay_step = 21
31 | self.ls_init_lrate = 0.01 / 2
32 |
33 | # Training configuration
34 | self.train_config = {
35 | "init_lrate" : 0.01,
36 | "max_grad_norm": 40,
37 | "in_dim" : 20,
38 | "out_dim" : 20,
39 | "sz" : min(50, train_story.shape[1]), # number of sentences
40 | "voc_sz" : len(self.dictionary),
41 | "bsz" : self.batch_size,
42 | "max_words" : len(train_story),
43 | "weight" : None
44 | }
45 |
46 | if self.linear_start:
47 | self.train_config["init_lrate"] = 0.01 / 2
48 |
49 | if self.enable_time:
50 | self.train_config.update({
51 | "voc_sz" : self.train_config["voc_sz"] + self.train_config["sz"],
52 | "max_words": self.train_config["max_words"] + 1 # Add 1 for time words
53 | })
54 |
55 |
56 | class BabiConfigJoint(object):
57 | """
58 | Joint configuration for bAbI
59 | """
60 | def __init__(self, train_story, train_questions, dictionary):
61 |
62 | # TODO: Inherit from BabiConfig
63 | self.dictionary = dictionary
64 | self.batch_size = 32
65 | self.nhops = 3
66 | self.nepochs = 60
67 |
68 | self.lrate_decay_step = 15 # reduce learning rate by half every 15 epochs # XXX:
69 |
70 | # Use 10% of training data for validation # XXX
71 | nb_questions = train_questions.shape[1]
72 | nb_train_questions = int(nb_questions * 0.9)
73 |
74 | # Randomly split to training and validation sets
75 | rp = np.random.permutation(nb_questions)
76 | self.train_range = rp[:nb_train_questions]
77 | self.val_range = rp[nb_train_questions:]
78 |
79 | self.enable_time = True # add time embeddings
80 | self.use_bow = False # use Bag-of-Words instead of Position-Encoding
81 | self.linear_start = True
82 | self.share_type = 1 # 1: adjacent, 2: layer-wise weight tying
83 | self.randomize_time = 0.1 # amount of noise injected into time index
84 | self.add_proj = False # add linear layer between internal states
85 | self.add_nonlin = False # add non-linearity to internal states
86 |
87 | if self.linear_start:
88 | self.ls_nepochs = 30 # XXX:
89 | self.ls_lrate_decay_step = 31 # XXX:
90 | self.ls_init_lrate = 0.01 / 2
91 |
92 | # Training configuration
93 | self.train_config = {
94 | "init_lrate" : 0.01,
95 | "max_grad_norm": 40,
96 | "in_dim" : 50, # XXX:
97 | "out_dim" : 50, # XXX:
98 | "sz" : min(50, train_story.shape[1]),
99 | "voc_sz" : len(self.dictionary),
100 | "bsz" : self.batch_size,
101 | "max_words" : len(train_story),
102 | "weight" : None
103 | }
104 |
105 | if self.linear_start:
106 | self.train_config["init_lrate"] = 0.01 / 2
107 |
108 | if self.enable_time:
109 | self.train_config.update({
110 | "voc_sz" : self.train_config["voc_sz"] + self.train_config["sz"],
111 | "max_words": self.train_config["max_words"] + 1 # Add 1 for time words
112 | })
113 |
--------------------------------------------------------------------------------
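Both configurations describe a step decay: the learning rate is halved every `lrate_decay_step` epochs. The training loop in `train_test.py` is not shown here, so the sketch below is only one plausible reading of that comment. `step_decay_lrate` is a hypothetical name, and the assumption that the first halving happens exactly at epoch `decay_step` (with 1-based epochs) may not match the actual loop.

```python
def step_decay_lrate(init_lrate, epoch, decay_step):
    """Learning rate for a 1-based `epoch`, halved every `decay_step` epochs.

    Hypothetical helper; assumes the first halving happens at epoch == decay_step.
    """
    return init_lrate * 0.5 ** (epoch // decay_step)


# Values taken from BabiConfig: init_lrate = 0.01, lrate_decay_step = 25.
for epoch in (1, 24, 25, 50, 100):
    print(epoch, step_decay_lrate(0.01, epoch, 25))
```

Under this reading, the linear-start settings (`ls_nepochs = 20`, `ls_lrate_decay_step = 21`) would never trigger a halving, so the linear-start phase would run at a constant rate.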
/data/README.md:
--------------------------------------------------------------------------------
1 | This folder should contain the bAbI dataset v1.2 from [fb.ai/babi](http://fb.ai/babi).
2 |
--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/demo/qa.py:
--------------------------------------------------------------------------------
1 | """
2 | Demo of using Memory Network for question answering
3 | """
4 | import glob
5 | import os
6 | import gzip
7 | import sys
8 | import pickle
9 |
10 | import argparse
11 | import numpy as np
12 |
13 | from config import BabiConfigJoint
14 | from train_test import train, train_linear_start
15 | from util import parse_babi_task, build_model
16 |
17 |
18 | class MemN2N(object):
19 | """
20 | MemN2N class
21 | """
22 | def __init__(self, data_dir, model_file):
23 | self.data_dir = data_dir
24 | self.model_file = model_file
25 | self.reversed_dict = None
26 | self.memory = None
27 | self.model = None
28 | self.loss = None
29 | self.general_config = None
30 |
31 | def save_model(self):
32 | with gzip.open(self.model_file, "wb") as f:
33 | print("Saving model to file %s ..." % self.model_file)
34 | pickle.dump((self.reversed_dict, self.memory, self.model, self.loss, self.general_config), f)
35 |
36 | def load_model(self):
37 | # Check if model was loaded
38 | if self.reversed_dict is None or self.memory is None or \
39 | self.model is None or self.loss is None or self.general_config is None:
40 | print("Loading model from file %s ..." % self.model_file)
41 | with gzip.open(self.model_file, "rb") as f:
42 | self.reversed_dict, self.memory, self.model, self.loss, self.general_config = pickle.load(f)
43 |
44 | def train(self):
45 | """
46 | Train MemN2N model using training data for tasks.
47 | """
48 | np.random.seed(42) # for reproducing
49 | assert self.data_dir is not None, "data_dir is not specified."
50 | print("Reading data from %s ..." % self.data_dir)
51 |
52 | # Parse training data
53 | train_data_path = glob.glob('%s/qa*_*_train.txt' % self.data_dir)
54 | dictionary = {"nil": 0}
55 | train_story, train_questions, train_qstory = parse_babi_task(train_data_path, dictionary, False)
56 |
57 | # Parse test data just to expand the dictionary so that it covers all words in the test data too
58 | test_data_path = glob.glob('%s/qa*_*_test.txt' % self.data_dir)
59 | parse_babi_task(test_data_path, dictionary, False)
60 |
61 | # Get reversed dictionary mapping index to word
62 | self.reversed_dict = dict((ix, w) for w, ix in dictionary.items())
63 |
64 | # Construct model
65 | self.general_config = BabiConfigJoint(train_story, train_questions, dictionary)
66 | self.memory, self.model, self.loss = build_model(self.general_config)
67 |
68 | # Train model
69 | if self.general_config.linear_start:
70 | train_linear_start(train_story, train_questions, train_qstory,
71 | self.memory, self.model, self.loss, self.general_config)
72 | else:
73 | train(train_story, train_questions, train_qstory,
74 | self.memory, self.model, self.loss, self.general_config)
75 |
76 | # Save model
77 | self.save_model()
78 |
79 | def get_story_texts(self, test_story, test_questions, test_qstory,
80 | question_idx, story_idx, last_sentence_idx):
81 | """
82 | Get text of question, its corresponding fact statements.
83 | """
84 | train_config = self.general_config.train_config
85 | enable_time = self.general_config.enable_time
86 | max_words = train_config["max_words"] \
87 | if not enable_time else train_config["max_words"] - 1
88 |
89 | story = [[self.reversed_dict[test_story[word_pos, sent_idx, story_idx]]
90 | for word_pos in range(max_words)]
91 | for sent_idx in range(last_sentence_idx + 1)]
92 |
93 | question = [self.reversed_dict[test_qstory[word_pos, question_idx]]
94 | for word_pos in range(max_words)]
95 |
96 | story_txt = [" ".join([w for w in sent if w != "nil"]) for sent in story]
97 | question_txt = " ".join([w for w in question if w != "nil"])
98 | correct_answer = self.reversed_dict[test_questions[2, question_idx]]
99 |
100 | return story_txt, question_txt, correct_answer
101 |
102 | def predict_answer(self, test_story, test_questions, test_qstory,
103 | question_idx, story_idx, last_sentence_idx,
104 | user_question=''):
105 | # Get configuration
106 | nhops = self.general_config.nhops
107 | train_config = self.general_config.train_config
108 | batch_size = self.general_config.batch_size
109 | dictionary = self.general_config.dictionary
110 | enable_time = self.general_config.enable_time
111 |
112 | max_words = train_config["max_words"] \
113 | if not enable_time else train_config["max_words"] - 1
114 |
115 | input_data = np.zeros((max_words, batch_size), np.float32)
116 | input_data[:] = dictionary["nil"]
117 | self.memory[0].data[:] = dictionary["nil"]
118 |
119 | # Check if user provides questions and it's different from suggested question
120 | _, suggested_question, _ = self.get_story_texts(test_story, test_questions, test_qstory,
121 | question_idx, story_idx, last_sentence_idx)
122 | user_question_provided = user_question != '' and user_question != suggested_question
123 | encoded_user_question = None
124 | if user_question_provided:
125 | # print("User question = '%s'" % user_question)
126 | user_question = user_question.strip()
127 | if user_question.endswith('?'):
128 | user_question = user_question[:-1]
129 | qwords = user_question.rstrip().lower().split() # skip '?'
130 |
131 | # Encoding
132 | encoded_user_question = np.zeros(max_words)
133 | encoded_user_question[:] = dictionary["nil"]
134 | for ix, w in enumerate(qwords[:max_words]): # ignore words beyond max_words
135 | if w in dictionary:
136 | encoded_user_question[ix] = dictionary[w]
137 | else:
138 | print("WARNING - The word '%s' is not in dictionary." % w)
139 |
140 | # Input data and data for the 1st memory cell
141 | # Here we duplicate input_data to fill the whole batch
142 | for b in range(batch_size):
143 | d = test_story[:, :(1 + last_sentence_idx), story_idx]
144 |
145 | offset = max(0, d.shape[1] - train_config["sz"])
146 | d = d[:, offset:]
147 |
148 | self.memory[0].data[:d.shape[0], :d.shape[1], b] = d
149 |
150 | if enable_time:
151 | self.memory[0].data[-1, :d.shape[1], b] = \
152 | np.arange(d.shape[1])[::-1] + len(dictionary) # time words
153 |
154 | if user_question_provided:
155 | input_data[:test_qstory.shape[0], b] = encoded_user_question
156 | else:
157 | input_data[:test_qstory.shape[0], b] = test_qstory[:, question_idx]
158 |
159 | # Data for the rest memory cells
160 | for i in range(1, nhops):
161 | self.memory[i].data = self.memory[0].data
162 |
163 | # Run model to predict answer
164 | out = self.model.fprop(input_data)
165 | memory_probs = np.array([self.memory[i].probs[:(last_sentence_idx + 1), 0] for i in range(nhops)])
166 |
167 | # Get answer for the 1st question since all are the same
168 | pred_answer_idx = out[:, 0].argmax()
169 | pred_prob = out[pred_answer_idx, 0]
170 |
171 | return pred_answer_idx, pred_prob, memory_probs
172 |
173 |
174 | def train_model(data_dir, model_file):
175 | memn2n = MemN2N(data_dir, model_file)
176 | memn2n.train()
177 |
178 |
179 | def run_console_demo(data_dir, model_file):
180 | """
181 | Console-based demo
182 | """
183 | memn2n = MemN2N(data_dir, model_file)
184 |
185 | # Try to load model
186 | memn2n.load_model()
187 |
188 | # Read test data
189 | print("Reading test data from %s ..." % memn2n.data_dir)
190 | test_data_path = glob.glob('%s/qa*_*_test.txt' % memn2n.data_dir)
191 | test_story, test_questions, test_qstory = \
192 | parse_babi_task(test_data_path, memn2n.general_config.dictionary, False)
193 |
194 | while True:
195 | # Pick a random question
196 | question_idx = np.random.randint(test_questions.shape[1])
197 | story_idx = test_questions[0, question_idx]
198 | last_sentence_idx = test_questions[1, question_idx]
199 |
200 | # Get story and question
201 | story_txt, question_txt, correct_answer = memn2n.get_story_texts(test_story, test_questions, test_qstory,
202 | question_idx, story_idx, last_sentence_idx)
203 | print("* Story:")
204 | print("\n\t".join(story_txt))
205 | print("\n* Suggested question:\n\t%s?" % question_txt)
206 |
207 | while True:
208 | user_question = raw_input("Your question (press Enter to use the suggested question):\n\t")
209 |
210 | pred_answer_idx, pred_prob, memory_probs = \
211 | memn2n.predict_answer(test_story, test_questions, test_qstory,
212 | question_idx, story_idx, last_sentence_idx,
213 | user_question)
214 |
215 | pred_answer = memn2n.reversed_dict[pred_answer_idx]
216 |
217 | print("* Answer: '%s', confidence score = %.2f%%" % (pred_answer, 100. * pred_prob))
218 | if user_question == '':
219 | if pred_answer == correct_answer:
220 | print(" Correct!")
221 | else:
222 | print(" Wrong. The correct answer is '%s'" % correct_answer)
223 |
224 | print("\n* Explanation:")
225 | print("\t".join(["Memory %d" % (i + 1) for i in range(len(memory_probs))]) + "\tText")
226 | for sent_idx, sent_txt in enumerate(story_txt):
227 | prob_output = "\t".join(["%.3f" % mem_prob for mem_prob in memory_probs[:, sent_idx]])
228 | print("%s\t%s" % (prob_output, sent_txt))
229 |
230 | asking_another_question = raw_input("\nDo you want to ask another question? [y/N] ")
231 | if asking_another_question == '' or asking_another_question.lower() == 'n': break
232 |
233 | will_continue = raw_input("Do you want to continue? [Y/n] ")
234 | if will_continue != '' and will_continue.lower() != 'y': break
235 | print("=" * 70)
236 |
237 |
238 | def run_web_demo(data_dir, model_file):
239 | from demo.web import webapp
240 | webapp.init(data_dir, model_file)
241 | webapp.run()
242 |
243 | if __name__ == "__main__":
244 | parser = argparse.ArgumentParser()
245 | parser.add_argument("-d", "--data-dir", default="data/tasks_1-20_v1-2/en",
246 | help="path to dataset directory (default: %(default)s)")
247 | parser.add_argument("-m", "--model-file", default="trained_model/memn2n_model.pklz",
248 | help="model file (default: %(default)s)")
249 | group = parser.add_mutually_exclusive_group()
250 | group.add_argument("-train", "--train", action="store_true",
251 | help="train model (default: %(default)s)")
252 | group.add_argument("-console", "--console-demo", action="store_true",
253 | help="run console-based demo (default: %(default)s)")
254 | group.add_argument("-web", "--web-demo", action="store_true", default=True,
255 | help="run web-based demo (default: %(default)s)")
256 | args = parser.parse_args()
257 |
258 | if not os.path.exists(args.data_dir):
259 | print("The data directory '%s' does not exist. Please download it first." % args.data_dir)
260 | sys.exit(1)
261 |
262 | if args.train:
263 | train_model(args.data_dir, args.model_file)
264 | elif args.console_demo:
265 | run_console_demo(args.data_dir, args.model_file)
266 | else:
267 | run_web_demo(args.data_dir, args.model_file)
268 |
--------------------------------------------------------------------------------
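`predict_answer` turns a free-form question into a fixed-length vector of word ids padded with the `nil` index. A simplified, NumPy-free sketch of that step follows. `encode_question` is a hypothetical name, and returning the list of unknown words (instead of printing a warning) is a choice made for this sketch. The sketch also truncates over-long questions to `max_words`.

```python
def encode_question(question, dictionary, max_words):
    """Encode `question` as `max_words` ids, padding with dictionary["nil"]."""
    text = question.strip()
    if text.endswith("?"):
        text = text[:-1]
    words = text.lower().split()

    encoded = [dictionary["nil"]] * max_words
    unknown = []
    for ix, word in enumerate(words[:max_words]):
        if word in dictionary:
            encoded[ix] = dictionary[word]
        else:
            unknown.append(word)    # position stays "nil" in the encoding
    return encoded, unknown


dictionary = {"nil": 0, "where": 1, "is": 2, "mary": 3}
encoded, unknown = encode_question("Where is Mary?", dictionary, max_words=5)
print(encoded)   # [1, 2, 3, 0, 0]
print(unknown)   # []
```

Words missing from the dictionary keep the `nil` id, so the model sees them as padding.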
/demo/web/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/demo/web/static/script.js:
--------------------------------------------------------------------------------
1 | $(function() {
2 | var $story = $('#story'),
3 | $question = $('#question'),
4 | $answer = $('#answer'),
5 | $getAnswer = $('#get_answer'),
6 | $getStory = $('#get_story'),
7 | $explainTable = $('#explanation');
8 |
9 | getStory();
10 |
11 | // Activate tooltip
12 | $('.qa-container').find('.glyphicon-info-sign').tooltip();
13 |
14 | $getAnswer.on('click', function(e) {
15 | e.preventDefault();
16 | getAnswer();
17 | });
18 |
19 | $getStory.on('click', function(e) {
20 | e.preventDefault();
21 | getStory();
22 | });
23 |
24 | function getStory() {
25 | $.get('/get/story', function(json) {
26 | $story.val(json["story"]);
27 | $question.val(json["question"]);
28 | $question.data('question_idx', json["question_idx"]);
29 | $question.data('suggested_question', json["question"]); // Save suggested question
30 | $answer.val('');
31 | $answer.data('correct_answer', json["correct_answer"]);
32 | //$explainTable.find('tbody').empty();
33 | });
34 | }
35 |
36 | function getAnswer() {
37 | var questionIdx = $question.data('question_idx'),
38 | correctAnswer = $answer.data('correct_answer'),
39 | suggestedQuestion = $question.data('suggested_question'),
40 | question = $question.val();
41 |
42 | var userQuestion = suggestedQuestion !== question? question : '';
43 | var url = '/get/answer?question_idx=' + questionIdx +
44 | '&user_question=' + encodeURIComponent(userQuestion);
45 |
46 | $.get(url, function(json) {
47 | var predAnswer = json["pred_answer"],
48 | predProb = json["pred_prob"],
49 | memProbs = json["memory_probs"];
50 |
51 | var outputMessage = "Answer = '" + predAnswer + "'" +
52 | "\nConfidence score = " + (predProb * 100).toFixed(2) + "%";
53 |
54 | // Show answer's feedback only if suggested question was used
55 | if (userQuestion === '') {
56 | if (predAnswer === correctAnswer)
57 | outputMessage += "\nCorrect!";
58 | else
59 | outputMessage += "\nWrong. The correct answer is '" + correctAnswer + "'";
60 | }
61 | $answer.val(outputMessage);
62 |
63 | // Explain answer
64 | var explanationHtml = [];
65 | var sentenceList = $story.val().split('\n');
66 | var maxLatestSents = memProbs.length;
67 | var numSents = sentenceList.length;
68 |
69 | for (var i = Math.max(0, numSents - maxLatestSents); i < numSents; i++) {
70 | var rowHtml = [];
71 | rowHtml.push('<tr>');
72 | rowHtml.push('<td>' + sentenceList[i] + '</td>');
73 | for (var j = 0; j < 3; j++) {
74 | var val = memProbs[i][j].toFixed(2);
75 | if (val > 0) {
76 | rowHtml.push('<td>' + val + '</td>');
78 | } else {
79 | rowHtml.push('<td>' + val + '</td>');
80 | }
81 | }
82 | rowHtml.push('</tr>');
83 | explanationHtml.push(rowHtml.join('\n'));
84 | }
85 | $explainTable.find('tbody').html(explanationHtml);
86 | });
87 | }
88 | });
89 |
--------------------------------------------------------------------------------
/demo/web/static/style.css:
--------------------------------------------------------------------------------
1 | body {
2 | padding-top: 10px;
3 | font-family: 'Roboto', sans-serif;
4 | }
5 |
6 | div.shadow {
7 | border: 1px solid #ccc;
8 | border-radius: 10px;
9 | box-shadow: 0 0 24px #bbb;
10 | background: #fff;
11 | color: #666;
12 | }
13 |
14 | .qa-container, .qa-explain-container {
15 | width: 550px;
16 | /*max-height: 650px;*/
17 | padding: 20px 50px 40px;
18 | margin-top: 10px;
19 | margin-left: 30px;
20 | }
21 |
22 | .qa-explain-container table th {
23 | font-size: 16px;
24 | font-weight: bold;
25 | }
26 |
27 | .container > h3 {
28 | text-align: center;
29 | }
30 |
31 | .qa-form {
32 | max-width: 600px;
33 | margin: 0 auto;
34 | }
35 |
36 | .qa-form .btn {
37 | width: 200px;
38 | }
39 |
40 | .qa-form .form-control[readonly] {
41 | background-color: #fff;
42 | }
43 |
44 | /* Font */
45 | .app-header, .qa-form label, .qa-explain-container label, .qa-explain-container th {
46 | font-family: 'Shadows Into Light Two', sans-serif;
47 | }
48 | /*.qa-form input, .qa-form textarea, .qa-form button {*/
49 | /*font-family: 'Roboto', sans-serif;*/
50 | /*}*/
51 |
--------------------------------------------------------------------------------
/demo/web/templates/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | MemN2N for bAbI Tasks
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/demo/web/webapp.py:
--------------------------------------------------------------------------------
1 | """
2 | Web-based demo
3 | """
4 | import glob
5 | import flask
6 | import numpy as np
7 |
8 | from demo.qa import MemN2N
9 | from util import parse_babi_task
10 |
11 | app = flask.Flask(__name__)
12 | memn2n = None
13 | test_story, test_questions, test_qstory = None, None, None
14 |
15 |
16 | def init(data_dir, model_file):
17 | """ Initialize web app """
18 | global memn2n, test_story, test_questions, test_qstory
19 |
20 | # Try to load model
21 | memn2n = MemN2N(data_dir, model_file)
22 | memn2n.load_model()
23 |
24 | # Read test data
25 | print("Reading test data from %s ..." % memn2n.data_dir)
26 | test_data_path = glob.glob('%s/qa*_*_test.txt' % memn2n.data_dir)
27 | test_story, test_questions, test_qstory = \
28 | parse_babi_task(test_data_path, memn2n.general_config.dictionary, False)
29 |
30 |
31 | def run():
32 | app.run()
33 |
34 |
35 | @app.route('/')
36 | def index():
37 | return flask.render_template("index.html")
38 |
39 |
40 | @app.route('/get/story', methods=['GET'])
41 | def get_story():
42 | question_idx = np.random.randint(test_questions.shape[1])
43 | story_idx = test_questions[0, question_idx]
44 | last_sentence_idx = test_questions[1, question_idx]
45 |
46 | story_txt, question_txt, correct_answer = memn2n.get_story_texts(test_story, test_questions, test_qstory,
47 | question_idx, story_idx, last_sentence_idx)
48 | # Format text
49 | story_txt = "\n".join(story_txt)
50 | question_txt += "?"
51 |
52 | return flask.jsonify({
53 | "question_idx": question_idx,
54 | "story": story_txt,
55 | "question": question_txt,
56 | "correct_answer": correct_answer
57 | })
58 |
59 |
60 | @app.route('/get/answer', methods=['GET'])
61 | def get_answer():
62 | question_idx = int(flask.request.args.get('question_idx'))
63 | user_question = flask.request.args.get('user_question', '')
64 |
65 | story_idx = test_questions[0, question_idx]
66 | last_sentence_idx = test_questions[1, question_idx]
67 |
68 | pred_answer_idx, pred_prob, memory_probs = memn2n.predict_answer(test_story, test_questions, test_qstory,
69 | question_idx, story_idx, last_sentence_idx,
70 | user_question)
71 | pred_answer = memn2n.reversed_dict[pred_answer_idx]
72 |
73 | return flask.jsonify({
74 | "pred_answer" : pred_answer,
75 | "pred_prob" : pred_prob,
76 | "memory_probs": memory_probs.T.tolist()
77 | })
78 |
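The `/get/answer` handler above converts `memory_probs` with `.T.tolist()` before calling `flask.jsonify`, but passes `pred_prob` through unchanged. Whether `predict_answer` returns a NumPy scalar is not visible in this file, so the following is only a sketch of a defensive conversion, using the standard `json` module in place of Flask. The helper `to_jsonable` and the sample values are hypothetical.

```python
import json

import numpy as np


def to_jsonable(value):
    """Convert NumPy arrays and scalars to plain Python types (hypothetical helper)."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


# Made-up attention weights with shape (memory slots, hops).
memory_probs = np.array([[0.75, 0.25], [0.25, 0.75], [0.0, 0.0]], dtype=np.float32)

payload = {
    "pred_answer": "kitchen",                      # made-up answer word
    "pred_prob": to_jsonable(np.float32(0.5)),     # NumPy scalar -> Python float
    "memory_probs": to_jsonable(memory_probs.T),   # one row per hop
}
encoded = json.dumps(payload)
decoded = json.loads(encoded)
```

Passing `np.float32(0.5)` straight to `json.dumps` raises a `TypeError`, which is why the scalar is converted first.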
--------------------------------------------------------------------------------
/memn2n/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/memn2n/memory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from memn2n.nn import ElemMult, Identity, Sequential, LookupTable, Module
4 | from memn2n.nn import Sum, Parallel, Softmax, MatVecProd
5 |
6 |
7 | class Memory(Module):
8 | """
9 | Memory:
10 | Query module = Parallel(LookupTable + Identity) + MatVecProd with transpose + Softmax
11 | Output module = Parallel(LookupTable + Identity) + MatVecProd
12 | """
13 | def __init__(self, train_config):
14 | super(Memory, self).__init__()
15 |
16 | self.sz = train_config["sz"]
17 | self.voc_sz = train_config["voc_sz"]
18 | self.in_dim = train_config["in_dim"]
19 | self.out_dim = train_config["out_dim"]
20 |
21 | # TODO: Mark self.nil_word and self.data as None since they will be overridden eventually
22 | # In build.model.py, memory[i].nil_word = dictionary['nil']
23 | self.nil_word = train_config["voc_sz"]
24 | self.config = train_config
25 | self.data = np.zeros((self.sz, train_config["bsz"]), np.float32)
26 |
27 | self.emb_query = None
28 | self.emb_out = None
29 | self.mod_query = None
30 | self.mod_out = None
31 | self.probs = None
32 |
33 | self.init_query_module()
34 | self.init_output_module()
35 |
36 | def init_query_module(self):
37 | self.emb_query = LookupTable(self.voc_sz, self.in_dim)
38 | p = Parallel()
39 | p.add(self.emb_query)
40 | p.add(Identity())
41 |
42 | self.mod_query = Sequential()
43 | self.mod_query.add(p)
44 | self.mod_query.add(MatVecProd(True))
45 | self.mod_query.add(Softmax())
46 |
47 | def init_output_module(self):
48 | self.emb_out = LookupTable(self.voc_sz, self.out_dim)
49 | p = Parallel()
50 | p.add(self.emb_out)
51 | p.add(Identity())
52 |
53 | self.mod_out = Sequential()
54 | self.mod_out.add(p)
55 | self.mod_out.add(MatVecProd(False))
56 |
57 | def reset(self):
58 | self.data[:] = self.nil_word
59 |
60 | def put(self, data_row):
61 | self.data[1:, :] = self.data[:-1, :] # shift rows down
62 | self.data[0, :] = data_row # add the new data row on top
63 |
64 | def fprop(self, input_data):
65 | self.probs = self.mod_query.fprop([self.data, input_data])
66 | self.output = self.mod_out.fprop([self.data, self.probs])
67 | return self.output
68 |
69 | def bprop(self, input_data, grad_output):
70 | g1 = self.mod_out.bprop([self.data, self.probs], grad_output)
71 | g2 = self.mod_query.bprop([self.data, input_data], g1[1])
72 | self.grad_input = g2[1]
73 | return self.grad_input
74 |
75 | def update(self, params):
76 | self.mod_out.update(params)
77 | self.mod_query.update(params)
78 | self.emb_out.weight.D[:, self.nil_word] = 0
79 |
80 | def share(self, m):
81 | pass
82 |
83 |
84 | class MemoryBoW(Memory):
85 | """
86 | MemoryBoW:
87 | Query module = Parallel((LookupTable + Sum(1)) + Identity) + MatVecProd with transpose + Softmax
88 | Output module = Parallel((LookupTable + Sum(1)) + Identity) + MatVecProd
89 | """
90 | def __init__(self, config):
91 | super(MemoryBoW, self).__init__(config)
92 | self.data = np.zeros((config["max_words"], self.sz, config["bsz"]), np.float32)
93 |
94 | def init_query_module(self):
95 | self.emb_query = LookupTable(self.voc_sz, self.in_dim)
96 | s = Sequential()
97 | s.add(self.emb_query)
98 | s.add(Sum(dim=1))
99 |
100 | p = Parallel()
101 | p.add(s)
102 | p.add(Identity())
103 |
104 | self.mod_query = Sequential()
105 | self.mod_query.add(p)
106 | self.mod_query.add(MatVecProd(True))
107 | self.mod_query.add(Softmax())
108 |
109 | def init_output_module(self):
110 | self.emb_out = LookupTable(self.voc_sz, self.out_dim)
111 | s = Sequential()
112 | s.add(self.emb_out)
113 | s.add(Sum(dim=1))
114 |
115 | p = Parallel()
116 | p.add(s)
117 | p.add(Identity())
118 |
119 | self.mod_out = Sequential()
120 | self.mod_out.add(p)
121 | self.mod_out.add(MatVecProd(False))
122 |
123 |
124 | class MemoryL(Memory):
125 | """
126 | MemoryL:
127 | Query module = Parallel((LookupTable + ElemMult + Sum(1)) + Identity) + MatVecProd with transpose + Softmax
128 | Output module = Parallel((LookupTable + ElemMult + Sum(1)) + Identity) + MatVecProd
129 | """
130 | def __init__(self, train_config):
131 | super(MemoryL, self).__init__(train_config)
132 | self.data = np.zeros((train_config["max_words"], self.sz, train_config["bsz"]), np.float32)
133 |
134 | def init_query_module(self):
135 | self.emb_query = LookupTable(self.voc_sz, self.in_dim)
136 | s = Sequential()
137 | s.add(self.emb_query)
138 | s.add(ElemMult(self.config["weight"]))
139 | s.add(Sum(dim=1))
140 |
141 | p = Parallel()
142 | p.add(s)
143 | p.add(Identity())
144 |
145 | self.mod_query = Sequential()
146 | self.mod_query.add(p)
147 | self.mod_query.add(MatVecProd(True))
148 | self.mod_query.add(Softmax())
149 |
150 | def init_output_module(self):
151 | self.emb_out = LookupTable(self.voc_sz, self.out_dim)
152 | s = Sequential()
153 | s.add(self.emb_out)
154 | s.add(ElemMult(self.config["weight"]))
155 | s.add(Sum(dim=1))
156 |
157 | p = Parallel()
158 | p.add(s)
159 | p.add(Identity())
160 |
161 | self.mod_out = Sequential()
162 | self.mod_out.add(p)
163 | self.mod_out.add(MatVecProd(False))
164 |
165 |
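The docstrings in this file describe each memory as a query module (embedding, transposed matrix-vector product, softmax) followed by an output module (embedding, matrix-vector product). A minimal single-example sketch of that computation, assuming the memory slots have already been embedded into two dense matrices, is shown below. The function name `memory_hop` and the toy matrices are illustrative and not part of the repository.

```python
import numpy as np


def memory_hop(mem_in, mem_out, u):
    """One attention step over already-embedded memory slots.

    mem_in:  (in_dim, n_slots)  query-side embeddings of the memory slots
    mem_out: (out_dim, n_slots) output-side embeddings of the same slots
    u:       (in_dim,)          embedded question / controller state
    """
    scores = mem_in.T @ u                # MatVecProd with transpose
    scores = scores - scores.max()       # stabilise the exponentials
    p = np.exp(scores)
    p = p / p.sum()                      # Softmax over memory slots
    o = mem_out @ p                      # MatVecProd without transpose
    return p, o


mem_in = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
mem_out = np.array([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])
u = np.array([0.0, 5.0])

p, o = memory_hop(mem_in, mem_out, u)
```

Here the second slot matches the question most strongly, so `p` concentrates on it and `o` is close to the second column of `mem_out`.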
--------------------------------------------------------------------------------
/memn2n/nn.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | import numpy as np
4 |
5 | # Ignore division by zero, which will happen in CrossEntropyLoss.fprop()
6 | # when Softmax is not included as the final layer.
7 | np.seterr(divide='ignore')
8 |
9 |
10 | class Module(object):
11 | """
12 | Abstract Module class for neural net
13 | """
14 | __metaclass__ = ABCMeta
15 |
16 | def __init__(self):
17 | self.output = None
18 | self.grad_input = None
19 |
20 | @abstractmethod
21 | def fprop(self, input_data):
22 | self.output = input_data
23 | return self.output
24 |
25 | @abstractmethod
26 | def bprop(self, input_data, grad_output):
27 | self.grad_input = grad_output
28 | return self.grad_input
29 |
30 | @abstractmethod
31 | def update(self, params):
32 | pass
33 |
34 | @abstractmethod
35 | def share(self, m):
36 | pass
37 |
38 |
39 | class Container(Module):
40 | """
41 | Container which holds an ordered list of child modules
42 | """
43 | def __init__(self):
44 | super(Container, self).__init__()
45 | self.modules = []
46 |
47 | def add(self, m):
48 | self.modules.append(m)
49 |
50 | def update(self, params):
51 | for module in self.modules:
52 | module.update(params)
53 |
54 | def share(self, m):
55 | for c_module, m_module in zip(self.modules, m.modules):
56 | c_module.share(m_module)
57 |
58 |
59 | class Sequential(Container):
60 |
61 | def fprop(self, input_data):
62 | temp = input_data
63 | for module in self.modules:
64 | temp = module.fprop(temp)
65 |
66 | self.output = temp
67 | return self.output
68 |
69 | def bprop(self, input_data, grad_output):
70 | for i in range(len(self.modules) - 1, 0, -1):
71 | grad_input = self.modules[i].bprop(self.modules[i - 1].output, grad_output)
72 | grad_output = grad_input
73 | grad_input = self.modules[0].bprop(input_data, grad_output)
74 |
75 | self.grad_input = grad_input
76 | return self.grad_input
77 |
78 |
79 | class Parallel(Container):
80 | """
81 | Computes forward and backward propagations for all modules at once.
82 | """
83 | def fprop(self, input_data):
84 | self.output = [module.fprop(input_elem)
85 | for module, input_elem in zip(self.modules, input_data)]
86 | return self.output
87 |
88 | def bprop(self, input_data, grad_output):
89 | self.grad_input = [module.bprop(input_elem, grad_output_elem)
90 | for module, input_elem, grad_output_elem
91 | in zip(self.modules, input_data, grad_output)]
92 | return self.grad_input
93 |
94 |
95 | class AddTable(Module):
96 | """
97 | Module for sum operator which sums up all elements in input data
98 | """
99 | def __init__(self):
100 | super(AddTable, self).__init__()
101 |
102 | def fprop(self, input_data):
103 | self.output = input_data[0]
104 | for elem in input_data[1:]:
105 | # Expand to the same ndim as self.output
106 | # TODO: Code improvement
107 | if elem.ndim == self.output.ndim - 1:
108 | elem = np.expand_dims(elem, axis=elem.ndim) # append a trailing axis
109 | self.output += elem
110 | return self.output
111 |
112 | def bprop(self, input_data, grad_output):
113 | self.grad_input = [grad_output for _ in range(len(input_data))]
114 | return self.grad_input
115 |
116 | def share(self, m):
117 | pass
118 |
119 | def update(self, params):
120 | pass
121 |
122 |
123 | class ConstMult(Module):
124 | """
125 | Module for multiplying with a constant
126 | """
127 | def __init__(self, c):
128 | super(ConstMult, self).__init__()
129 | self.c = c
130 |
131 | def fprop(self, input_data):
132 | self.output = self.c * input_data
133 | return self.output
134 |
135 | def bprop(self, input_data, grad_output):
136 | self.grad_input = self.c * grad_output
137 | return self.grad_input
138 |
139 | def share(self, m):
140 | pass
141 |
142 | def update(self, params):
143 | pass
144 |
145 |
146 | class Duplicate(Module):
147 | """
148 | Duplicate module which essentially makes a clone for input data
149 | """
150 | def __init__(self):
151 | super(Duplicate, self).__init__()
152 |
153 | def fprop(self, input_data):
154 | self.output = [input_data, input_data]
155 | return self.output
156 |
157 | def bprop(self, input_data, grad_output):
158 | self.grad_input = grad_output[0] + grad_output[1]
159 | return self.grad_input
160 |
161 | def share(self, m):
162 | pass
163 |
164 | def update(self, params):
165 | pass
166 |
167 |
168 | class ElemMult(Module):
169 | """
170 | Module for element-wise product
171 | """
172 | def __init__(self, weight):
173 | super(ElemMult, self).__init__()
174 | self.weight = weight
175 |
176 | def fprop(self, input_data):
177 | # TODO: Rewrite these checks.
178 | if input_data.ndim == 2:
179 | self.output = input_data * self.weight
180 | elif input_data.ndim == 3:
181 | self.output = input_data * self.weight[:, :, None] # broadcasting
182 | elif input_data.ndim == 4:
183 | self.output = input_data * self.weight[:, :, None, None] # broadcasting
184 | else:
185 | raise Exception("input_data has large dimension = %d" % input_data.ndim)
186 | return self.output
187 |
188 | def bprop(self, input_data, grad_output):
189 | # TODO: Same as above.
190 | if input_data.ndim == 2:
191 | self.grad_input = grad_output * self.weight
192 | elif input_data.ndim == 3:
193 | self.grad_input = grad_output * self.weight[:, :, None] # broadcasting
194 | elif input_data.ndim == 4:
195 | self.grad_input = grad_output * self.weight[:, :, None, None] # broadcasting
196 | else:
197 | raise Exception("input_data has large dimension = %d" % input_data.ndim)
198 | return self.grad_input
199 |
200 | def share(self, m):
201 | pass
202 |
203 | def update(self, params):
204 | pass
205 |
206 |
207 | class Identity(Module):
208 | """
209 | Identical forward and backward propagations
210 | """
211 | def __init__(self):
212 | super(Identity, self).__init__()
213 |
214 | def fprop(self, input_data):
215 | self.output = input_data
216 | return self.output
217 |
218 | def bprop(self, input_data, grad_output):
219 | self.grad_input = grad_output
220 | return self.grad_input
221 |
222 | def share(self, m):
223 | pass
224 |
225 | def update(self, params):
226 | pass
227 |
228 |
229 | class Linear(Module):
230 | """
231 | Linear Layer
232 | """
233 | def __init__(self, in_dim, out_dim):
234 | super(Linear, self).__init__()
235 | self.in_dim = in_dim
236 | self.out_dim = out_dim
237 | self.weight = Weight((out_dim, in_dim))
238 | self.bias = Weight((out_dim, 1))
239 |
240 | def fprop(self, input_data):
241 | high_dimension_input = input_data.ndim > 2
242 |
243 | # Reshape input
244 | if high_dimension_input:
245 | input_data = input_data.reshape(input_data.shape[0], -1)
246 |
247 | self.output = np.dot(self.weight.D, input_data) + self.bias.D
248 |
249 | # Reshape output
250 | if high_dimension_input:
251 | self.output = self.output.reshape(self.output.shape[0], -1)
252 |
253 | return self.output
254 |
255 | def bprop(self, input_data, grad_output):
256 | orig_input_data_shape = input_data.shape
257 | high_dimension_input = input_data.ndim > 2
258 |
259 | # Reshape input and grad_output
260 | if high_dimension_input:
261 | input_data = input_data.reshape(input_data.shape[0], -1)
262 | grad_output = grad_output.reshape(grad_output.shape[0], -1)
263 |
264 | self.weight.grad = self.weight.grad + np.dot(grad_output, input_data.T)
265 | self.bias.grad = self.bias.grad + grad_output.sum(axis=1, keepdims=True) # keep shape (out_dim, 1)
266 | self.grad_input = np.dot(self.weight.D.T, grad_output)
267 |
268 | if high_dimension_input:
269 | self.grad_input = self.grad_input.reshape(orig_input_data_shape)
270 |
271 | return self.grad_input
272 |
273 | def update(self, params):
274 | self.weight.update(params)
275 | self.bias.update(params)
276 |
277 | def share(self, m):
278 | self.weight = m.weight
279 | self.bias = m.bias
280 |
281 |
282 | class LinearNB(Module):
283 | """
284 | Linear layer with no bias
285 | """
286 | def __init__(self, in_dim, out_dim, do_transpose=False):
287 | super(LinearNB, self).__init__()
288 | self.in_dim = in_dim
289 | self.out_dim = out_dim
290 | self.do_transpose = do_transpose
291 |
292 | if do_transpose:
293 | self.weight = Weight((in_dim, out_dim))
294 | else:
295 | self.weight = Weight((out_dim, in_dim))
296 |
297 | def fprop(self, input_data):
298 | high_dimension_input = input_data.ndim > 2
299 |
300 | if high_dimension_input:
301 | input_data = input_data.reshape(input_data.shape[0], -1)
302 |
303 | if self.do_transpose:
304 | self.output = np.dot(self.weight.D.T, input_data)
305 | else:
306 | self.output = np.dot(self.weight.D, input_data)
307 |
308 | if high_dimension_input:
309 | self.output = self.output.reshape(self.output.shape[0], -1)
310 |
311 | return self.output
312 |
313 | def bprop(self, input_data, grad_output):
314 | orig_input_data_shape = input_data.shape
315 | high_dimension_input = input_data.ndim > 2
316 |
317 | # Reshape input and grad_output
318 | if high_dimension_input:
319 | input_data = input_data.reshape(input_data.shape[0], -1)
320 | grad_output = grad_output.reshape(grad_output.shape[0], -1)
321 |
322 | if self.do_transpose:
323 | self.weight.grad = self.weight.grad + np.dot(input_data, grad_output.T)
324 | self.grad_input = np.dot(self.weight.D, grad_output)
325 | else:
326 | self.weight.grad = self.weight.grad + np.dot(grad_output, input_data.T)
327 | self.grad_input = np.dot(self.weight.D.T, grad_output)
328 |
329 | if high_dimension_input:
330 | self.grad_input = self.grad_input.reshape(orig_input_data_shape)
331 |
332 | return self.grad_input
333 |
334 | def update(self, params):
335 | self.weight.update(params)
336 |
337 | def share(self, m):
338 | self.weight = m.weight
339 |
340 |
341 | class LookupTable(Module):
342 | """
343 | Lookup table
344 | """
345 | def __init__(self, voc_sz, out_dim):
346 | """
347 | Constructor
348 |
349 | Args:
350 | voc_sz (int): vocabulary size
351 | out_dim (int): output dimension
352 | """
353 | super(LookupTable, self).__init__()
354 | self.sz = voc_sz
355 | self.out_dim = out_dim
356 | self.weight = Weight((out_dim, voc_sz))
357 |
358 | def fprop(self, input_data):
359 | self.output = self.weight.D[:, input_data.T.astype(int).flatten()]
360 | # Matlab's reshape uses Fortran order (i.e. column first)
361 | self.output = np.squeeze(self.output.reshape((self.out_dim,) + input_data.shape, order='F'))
362 | return self.output
363 |
364 | def bprop(self, input_data, grad_output):
365 | # Make sure input_data has one dim lower than grad_output (see the index below)
366 | if input_data.ndim == grad_output.ndim:
367 | input_data = np.squeeze(input_data) # TODO: Seems clumsy!
368 |
369 | input_data = input_data.astype(int)
370 | c = np.unique(input_data.flatten())
371 | for i in np.nditer(c):
372 | self.weight.grad[:, i] += np.sum(grad_output[:, input_data == i], axis=1)
373 |
374 | self.grad_input = []
375 | return self.grad_input
376 |
377 | def update(self, params):
378 | self.weight.update(params)
379 |
380 | def share(self, m):
381 | self.weight = m.weight
382 |
383 |
384 | class MatVecProd(Module):
385 | """
386 | Product of matrix and vector in batch, where
387 | matrix's shape is [:, :, batch] and vectors is [:, batch]
388 | Result is a vector of size [:, batch]
389 | """
390 | def __init__(self, do_transpose):
391 | super(MatVecProd, self).__init__()
392 | self.do_transpose = do_transpose
393 |
394 | def fprop(self, input_data):
395 | M = input_data[0]
396 | V = input_data[1]
397 |
398 | # Expand M to 3-dimension and V to 2-dimension
399 | if M.ndim == 2:
400 | M = np.expand_dims(M, axis=2)
401 | if V.ndim == 1:
402 | V = np.expand_dims(V, axis=1)
403 |
404 | batch_size = M.shape[2]
405 |
406 | if self.do_transpose:
407 | self.output = np.zeros((M.shape[1], batch_size), np.float32)
408 | for i in range(batch_size):
409 | self.output[:, i] = np.dot(M[:, :, i].T, V[:, i])
410 | else:
411 | self.output = np.zeros((M.shape[0], batch_size), np.float32)
412 | for i in range(batch_size):
413 | self.output[:, i] = np.dot(M[:, :, i], V[:, i])
414 |
415 | return self.output
416 |
417 | def bprop(self, input_data, grad_output):
418 | M = input_data[0]
419 | V = input_data[1]
420 |
421 | # Expand M to 3-dimension and V to 2-dimension
422 | if M.ndim == 2:
423 | M = np.expand_dims(M, axis=2)
424 | if V.ndim == 1:
425 | V = np.expand_dims(V, axis=1)
426 |
427 | batch_size = M.shape[2]
428 |
429 | grad_M = np.zeros_like(M, np.float32)
430 | grad_V = np.zeros_like(V, np.float32)
431 |
432 | for i in range(batch_size):
433 | if self.do_transpose:
434 | grad_M[:, :, i] = np.dot(V[:, [i]], grad_output[:, [i]].T)
435 | grad_V[:, i] = np.dot(M[:, :, i], grad_output[:, i])
436 | else:
437 | grad_M[:, :, i] = np.dot(grad_output[:, [i]], V[:, [i]].T)
438 | grad_V[:, i] = np.dot(M[:, :, i].T, grad_output[:, i])
439 |
440 | self.grad_input = (grad_M, grad_V)
441 | return self.grad_input
442 |
443 | def update(self, params):
444 | pass
445 |
446 | def share(self, m):
447 | pass
448 |
449 |
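`MatVecProd.fprop` above loops over the batch dimension in Python. As a sketch only, the same batched products can be written with `np.einsum`; the function names below are illustrative, and the loop version is included purely to check the two against each other.

```python
import numpy as np


def mat_vec_prod_loop(M, V, do_transpose):
    """Reference version: one np.dot per batch element, as in fprop above."""
    batch_size = M.shape[2]
    rows = M.shape[1] if do_transpose else M.shape[0]
    out = np.zeros((rows, batch_size))
    for i in range(batch_size):
        mat = M[:, :, i].T if do_transpose else M[:, :, i]
        out[:, i] = np.dot(mat, V[:, i])
    return out


def mat_vec_prod_einsum(M, V, do_transpose):
    """Vectorised equivalent; M is (rows, cols, batch), V is (rows or cols, batch)."""
    if do_transpose:
        return np.einsum('ijb,ib->jb', M, V)   # M[:, :, b].T @ V[:, b]
    return np.einsum('ijb,jb->ib', M, V)       # M[:, :, b] @ V[:, b]


rng = np.random.RandomState(0)
M = rng.standard_normal((4, 3, 2))
V_t = rng.standard_normal((4, 2))   # used with do_transpose=True
V_n = rng.standard_normal((3, 2))   # used with do_transpose=False
```

The einsum form avoids the Python-level loop, which mainly matters when the batch size is large.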
450 | class ReLU(Module):
451 | """ ReLU module """
452 |
453 | def fprop(self, input_data):
454 | self.output = np.multiply(input_data, input_data > 0)
455 | return self.output
456 |
457 | def bprop(self, input_data, grad_output):
458 | self.grad_input = np.multiply(grad_output, input_data > 0)
459 | return self.grad_input
460 |
461 | def update(self, params):
462 | pass
463 |
464 | def share(self, m):
465 | pass
466 |
467 |
468 | class SelectTable(Module):
469 | """ SelectTable which slices input data in a specific dimension """
470 |
471 | def __init__(self, index):
472 | super(SelectTable, self).__init__()
473 | self.index = index
474 |
475 | def fprop(self, input_data):
476 | self.output = input_data[self.index]
477 | return self.output
478 |
479 | def bprop(self, input_data, grad_output):
480 | self.grad_input = [grad_output if i == self.index
481 | else np.zeros_like(input_elem, np.float32)
482 | for i, input_elem in enumerate(input_data)]
483 | return self.grad_input
484 |
485 | def update(self, params):
486 | pass
487 |
488 | def share(self, m):
489 | pass
490 |
491 |
492 | class Sigmoid(Module):
493 |
494 | def fprop(self, input_data):
495 | self.output = 1. / (1 + np.exp(-input_data))
496 | return self.output
497 |
498 | def bprop(self, input_data, grad_output):
499 | return grad_output * self.output * (1. - self.output)
500 |
501 | def update(self, params):
502 | pass
503 |
504 | def share(self, m):
505 | pass
506 |
507 |
508 | class Softmax(Module):
509 |
510 | def __init__(self, skip_bprop=False):
511 | super(Softmax, self).__init__()
512 | self.skip_bprop = skip_bprop # for the output module
513 |
514 | def fprop(self, input_data):
515 | input_data -= np.max(input_data, axis=0)
516 | input_data += 1.0
517 |
518 | a = np.exp(input_data)
519 | sum_a = a.sum(axis=0)
520 |
521 | self.output = a / sum_a[None, :] # divide by row
522 | return self.output
523 |
524 | def bprop(self, input_data, grad_output):
525 | if not self.skip_bprop:
526 | z = grad_output - np.sum(self.output * grad_output, axis=0)
527 | self.grad_input = self.output * z
528 | else:
529 | self.grad_input = grad_output
530 |
531 | return self.grad_input
532 |
533 | def update(self, params):
534 | pass
535 |
536 | def share(self, m):
537 | pass
538 |
539 |
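`Softmax.fprop` above subtracts the column maximum and then adds `1.0` before exponentiating, and it does so in place on `input_data`. A constant shift applied to every entry of a column cancels in the normalisation, so only the max-subtraction matters for numerical stability. The sketch below (with a hypothetical function name) illustrates this and works on a copy so the caller's array is left untouched.

```python
import numpy as np


def softmax_columns(x, shift=0.0):
    """Column-wise softmax; `shift` is an arbitrary constant added to every entry."""
    z = x - np.max(x, axis=0) + shift   # new array, x itself is not modified
    a = np.exp(z)
    return a / a.sum(axis=0)


x = np.array([[1000.0, 1.0],
              [1001.0, 2.0]])
x_before = x.copy()

p = softmax_columns(x)               # shift of 0
p_shifted = softmax_columns(x, 1.0)  # shift of 1, as in fprop above
```

Both calls give the same probabilities, and the large inputs in the first column do not overflow because the exponent never exceeds `shift`.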
540 | class Sum(Module):
541 | """
542 | Sum module which sums up input data at specified dimension
543 | """
544 | def __init__(self, dim):
545 | super(Sum, self).__init__()
546 | self.dim = dim
547 |
548 | def fprop(self, input_data):
549 | self.output = np.squeeze(np.sum(input_data, axis=self.dim))
550 | return self.output
551 |
552 | def bprop(self, input_data, grad_output):
553 | # TODO: Seems clumsy!
554 | sz = np.array(input_data.shape)
555 | sz[self.dim] = 1
556 | grad_output = grad_output.reshape(sz)
557 | sz[:] = 1
558 | sz[self.dim] = input_data.shape[self.dim]
559 | self.grad_input = np.tile(grad_output, sz)
560 | return self.grad_input
561 |
562 | def update(self, params):
563 | pass
564 |
565 | def share(self, m):
566 | pass
567 |
568 |
569 | class Weight(object):
570 | def __init__(self, sz):
571 | """
572 | Initialize weight
573 | Args:
574 | sz (tuple): shape
575 | """
576 | self.sz = sz
577 | self.D = 0.1 * np.random.standard_normal(sz)
578 | self.grad = np.zeros(sz, np.float32)
579 |
580 | def update(self, params):
581 | """
582 | Update weights
583 | """
584 | max_grad_norm = params.get('max_grad_norm')
585 | if max_grad_norm and max_grad_norm > 0:
586 | grad_norm = np.linalg.norm(self.grad, 2)
587 | if grad_norm > max_grad_norm:
588 | self.grad = self.grad * max_grad_norm / grad_norm
589 |
590 | self.D -= params['lrate'] * self.grad
591 | self.grad[:] = 0
592 |
593 | def clone(self):
594 | m = Weight(self.sz)
595 | m.D = np.copy(self.D)
596 | m.grad = np.copy(self.grad)
597 | return m
598 |
599 |
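`Weight.update` above rescales the gradient when its norm exceeds `max_grad_norm`, using `np.linalg.norm(self.grad, 2)`. For a 2-D array, NumPy interprets `ord=2` as the spectral norm (largest singular value), whereas the default `ord=None` gives the Frobenius norm. Which one was intended is not stated in the source, so the sketch below simply shows both on a small matrix. The function name `clip_by_norm` is illustrative.

```python
import numpy as np


def clip_by_norm(grad, max_norm, norm_ord=None):
    """Rescale `grad` so that its norm does not exceed `max_norm`.

    norm_ord=None uses the Frobenius norm for matrices;
    norm_ord=2 uses the spectral norm, as in Weight.update above.
    """
    norm = np.linalg.norm(grad, norm_ord)
    if max_norm and max_norm > 0 and norm > max_norm:
        return grad * (max_norm / norm)
    return grad


g = np.array([[3.0, 0.0],
              [0.0, 4.0]])

spectral = np.linalg.norm(g, 2)    # largest singular value
frobenius = np.linalg.norm(g)      # sqrt(3**2 + 4**2)

clipped = clip_by_norm(g, 2.5)     # Frobenius norm rescaled to 2.5
```

Since the spectral norm never exceeds the Frobenius norm, clipping on the spectral norm triggers less often for the same threshold.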
600 | class Loss(object):
601 | """ Abstract Loss class """
602 | __metaclass__ = ABCMeta
603 |
604 | @abstractmethod
605 | def fprop(self, input_data, target_data):
606 | """ Abstract function for forward propagation """
607 | pass
608 |
609 | @abstractmethod
610 | def bprop(self, input_data, target_data):
611 | """ Abstract function for back-propagation """
612 | pass
613 |
614 |
615 | class CrossEntropyLoss(Loss):
616 |
617 | def __init__(self):
618 | self.do_softmax_bprop = False
619 | self.eps = 1e-7
620 | self.size_average = True
621 |
622 | def fprop(self, input_data, target_data):
623 | tmp = [(t, i) for i, t in enumerate(target_data)]
624 | z = tuple(zip(*tmp)) # (row indices, column indices); a tuple so indexing also works on Python 3
625 | cost = np.sum(-np.log(input_data[z]))
626 | if self.size_average:
627 | cost /= input_data.shape[1]
628 |
629 | return cost
630 |
631 | def bprop(self, input_data, target_data):
632 | tmp = [(t, i) for i, t in enumerate(target_data)]
633 | z = tuple(zip(*tmp)) # (row indices, column indices); a tuple so indexing also works on Python 3
634 |
635 | if self.do_softmax_bprop:
636 | grad_input = input_data
637 | grad_input[z] -= 1
638 | else:
639 | grad_input = np.zeros_like(input_data, np.float32)
640 | grad_input[z] = -1. / (input_data[z] + self.eps)
641 |
642 | if self.size_average:
643 | grad_input /= input_data.shape[1]
644 |
645 | return grad_input
646 |
647 | def get_error(self, input_data, target_data):
648 | y = input_data.argmax(axis=0)
649 | return np.sum(y != target_data)
650 |
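`CrossEntropyLoss.fprop` above selects, for each column of the prediction matrix, the probability assigned to the target class, and sums the negative logs. A minimal sketch of that selection with explicit index arrays is given below. It assumes `probs` has shape `(classes, batch)` and that the targets are integer class indices; the function name is illustrative.

```python
import math

import numpy as np


def cross_entropy(probs, targets, size_average=True):
    """Mean (or summed) negative log-probability of the target classes."""
    targets = np.asarray(targets, dtype=int)
    cols = np.arange(probs.shape[1])
    picked = probs[targets, cols]        # probs[target_i, i] for every column i
    cost = -np.log(picked).sum()
    if size_average:
        cost /= probs.shape[1]
    return cost


probs = np.array([[0.5, 0.25],
                  [0.5, 0.75]])
targets = [0, 1]

cost = cross_entropy(probs, targets)
expected = -(math.log(0.5) + math.log(0.75)) / 2
```

Indexing with the pair `(targets, cols)` is the same element selection that the `zip(*tmp)` construction in `fprop` builds.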
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Flask
3 |
--------------------------------------------------------------------------------
/train_test.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import math
4 | import numpy as np
5 |
6 | from memn2n.nn import Softmax
7 | from util import Progress
8 |
9 |
10 | def train(train_story, train_questions, train_qstory, memory, model, loss, general_config):
11 |
12 | train_config = general_config.train_config
13 | dictionary = general_config.dictionary
14 | nepochs = general_config.nepochs
15 | nhops = general_config.nhops
16 | batch_size = general_config.batch_size
17 | enable_time = general_config.enable_time
18 | randomize_time = general_config.randomize_time
19 | lrate_decay_step = general_config.lrate_decay_step
20 |
21 | train_range = general_config.train_range # indices of training questions
22 | val_range = general_config.val_range # indices of validation questions
23 | train_len = len(train_range)
24 | val_len = len(val_range)
25 |
26 | params = {
27 | "lrate": train_config["init_lrate"],
28 | "max_grad_norm": train_config["max_grad_norm"]
29 | }
30 |
31 | for ep in range(nepochs):
32 | # Decrease learning rate after every decay step
33 | if (ep + 1) % lrate_decay_step == 0:
34 | params["lrate"] *= 0.5
35 |
36 | total_err = 0.
37 | total_cost = 0.
38 | total_num = 0
39 | for _ in Progress(range(int(math.floor(train_len / batch_size)))):
40 | # Question batch
41 | batch = train_range[np.random.randint(train_len, size=batch_size)]
42 |
43 | input_data = np.zeros((train_story.shape[0], batch_size), np.float32) # words of training questions
44 | target_data = train_questions[2, batch] # indices of training answers
45 |
46 | memory[0].data[:] = dictionary["nil"]
47 |
48 | # Compose batch of training data
49 | for b in range(batch_size):
50 | # NOTE: +1 since train_questions[1, :] is the index of the sentence right before the training question.
51 | # d is a batch of [word indices in sentence, sentence indices from batch] for this story
52 | d = train_story[:, :(1 + train_questions[1, batch[b]]), train_questions[0, batch[b]]]
53 |
54 | # Pick a fixed number of latest sentences (before the question) from the story
55 | offset = max(0, d.shape[1] - train_config["sz"])
56 | d = d[:, offset:]
57 |
58 | # Training data for the 1st memory cell
59 | memory[0].data[:d.shape[0], :d.shape[1], b] = d
60 |
61 | if enable_time:
62 | # Inject noise into time index (i.e. word index)
63 | if randomize_time > 0:
64 | # Random number of blanks (must be < total sentences until the training question?)
65 | nblank = np.random.randint(int(math.ceil(d.shape[1] * randomize_time)))
66 | rt = np.random.permutation(d.shape[1] + nblank)
67 |
68 | rt[rt >= train_config["sz"]] = train_config["sz"] - 1 # put the cap
69 |
70 | # Add random time (must be > dictionary's length) into the time word (decreasing order)
71 | memory[0].data[-1, :d.shape[1], b] = np.sort(rt[:d.shape[1]])[::-1] + len(dictionary)
72 |
73 | else:
74 | memory[0].data[-1, :d.shape[1], b] = \
75 | np.arange(d.shape[1])[::-1] + len(dictionary)
76 |
77 | input_data[:, b] = train_qstory[:, batch[b]]
78 |
79 | for i in range(1, nhops):
80 | memory[i].data = memory[0].data
81 |
82 | out = model.fprop(input_data)
83 | total_cost += loss.fprop(out, target_data)
84 | total_err += loss.get_error(out, target_data)
85 | total_num += batch_size
86 |
87 | grad = loss.bprop(out, target_data)
88 | model.bprop(input_data, grad)
89 | model.update(params)
90 |
91 | for i in range(nhops):
92 | memory[i].emb_query.weight.D[:, 0] = 0
93 |
94 | # Validation
95 | total_val_err = 0.
96 | total_val_cost = 0.
97 | total_val_num = 0
98 |
99 | for k in range(int(math.floor(val_len / batch_size))):
100 | batch = val_range[np.arange(k * batch_size, (k + 1) * batch_size)]
101 | input_data = np.zeros((train_story.shape[0], batch_size), np.float32)
102 | target_data = train_questions[2, batch]
103 |
104 | memory[0].data[:] = dictionary["nil"]
105 |
106 | for b in range(batch_size):
107 | d = train_story[:, :(1 + train_questions[1, batch[b]]), train_questions[0, batch[b]]]
108 |
109 | offset = max(0, d.shape[1] - train_config["sz"])
110 | d = d[:, offset:]
111 |
112 | # Data for the 1st memory cell
113 | memory[0].data[:d.shape[0], :d.shape[1], b] = d
114 |
115 | if enable_time:
116 | memory[0].data[-1, :d.shape[1], b] = np.arange(d.shape[1])[::-1] + len(dictionary)
117 |
118 | input_data[:, b] = train_qstory[:, batch[b]]
119 |
120 | for i in range(1, nhops):
121 | memory[i].data = memory[0].data
122 |
123 | out = model.fprop(input_data)
124 | total_val_cost += loss.fprop(out, target_data)
125 | total_val_err += loss.get_error(out, target_data)
126 | total_val_num += batch_size
127 |
128 | train_error = total_err / total_num
129 | val_error = total_val_err / total_val_num
130 |
131 | print("%d | train error: %g | val error: %g" % (ep + 1, train_error, val_error))
132 |
133 |
134 | def train_linear_start(train_story, train_questions, train_qstory, memory, model, loss, general_config):
135 |
136 | train_config = general_config.train_config
137 |
138 | # Remove softmax from memory
139 | for i in range(general_config.nhops):
140 | memory[i].mod_query.modules.pop()
141 |
142 | # Save settings
143 | nepochs2 = general_config.nepochs
144 | lrate_decay_step2 = general_config.lrate_decay_step
145 | init_lrate2 = train_config["init_lrate"]
146 |
147 | # Add new settings
148 | general_config.nepochs = general_config.ls_nepochs
149 | general_config.lrate_decay_step = general_config.ls_lrate_decay_step
150 | train_config["init_lrate"] = general_config.ls_init_lrate
151 |
152 | # Train with new settings
153 | train(train_story, train_questions, train_qstory, memory, model, loss, general_config)
154 |
155 | # Add softmax back
156 | for i in range(general_config.nhops):
157 | memory[i].mod_query.add(Softmax())
158 |
159 | # Restore old settings
160 | general_config.nepochs = nepochs2
161 | general_config.lrate_decay_step = lrate_decay_step2
162 | train_config["init_lrate"] = init_lrate2
163 |
164 | # Train with old settings
165 | train(train_story, train_questions, train_qstory, memory, model, loss, general_config)
166 |
167 |
168 | def test(test_story, test_questions, test_qstory, memory, model, loss, general_config):
169 |     total_test_err = 0.
170 |     total_test_num = 0
171 |
172 |     nhops = general_config.nhops
173 |     train_config = general_config.train_config
174 |     batch_size = general_config.batch_size
175 |     dictionary = general_config.dictionary
176 |     enable_time = general_config.enable_time
177 |
178 |     max_words = train_config["max_words"] \
179 |         if not enable_time else train_config["max_words"] - 1
180 |
181 |     for k in range(int(math.floor(test_questions.shape[1] / batch_size))):
182 |         batch = np.arange(k * batch_size, (k + 1) * batch_size)
183 |
184 |         input_data = np.zeros((max_words, batch_size), np.float32)
185 |         target_data = test_questions[2, batch]
186 |
187 |         input_data[:] = dictionary["nil"]
188 |         memory[0].data[:] = dictionary["nil"]
189 |
190 |         for b in range(batch_size):
191 |             d = test_story[:, :(1 + test_questions[1, batch[b]]), test_questions[0, batch[b]]]
192 |
193 |             offset = max(0, d.shape[1] - train_config["sz"])
194 |             d = d[:, offset:]
195 |
196 |             memory[0].data[:d.shape[0], :d.shape[1], b] = d
197 |
198 |             if enable_time:
199 |                 memory[0].data[-1, :d.shape[1], b] = np.arange(d.shape[1])[::-1] + len(dictionary)  # time words
200 |
201 |             input_data[:test_qstory.shape[0], b] = test_qstory[:, batch[b]]
202 |
203 |         for i in range(1, nhops):
204 |             memory[i].data = memory[0].data
205 |
206 |         out = model.fprop(input_data)
207 |         # cost = loss.fprop(out, target_data)
208 |         total_test_err += loss.get_error(out, target_data)
209 |         total_test_num += batch_size
210 |
211 |     test_error = total_test_err / total_test_num
212 |     print("Test error: %f" % test_error)
213 |
--------------------------------------------------------------------------------
/trained_model/README.md:
--------------------------------------------------------------------------------
1 | The model `memn2n_model.pklz` was trained jointly using data from all 20 tasks (1k training questions for
2 | each task).
--------------------------------------------------------------------------------
/trained_model/memn2n_model.pklz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinhkhuc/MemN2N-babi-python/84f635138782b02db0f96013bb9189ccf9ca1e06/trained_model/memn2n_model.pklz
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import sys
4 | import time
5 |
6 | import numpy as np
7 |
8 | from memn2n.memory import MemoryL, MemoryBoW
9 | from memn2n.nn import AddTable, CrossEntropyLoss, Duplicate, ElemMult, LinearNB
10 | from memn2n.nn import Identity, ReLU, Sequential, LookupTable, Sum, Parallel, Softmax
11 |
12 |
13 | def parse_babi_task(data_files, dictionary, include_question):
14 |     """ Parse bAbI data.
15 |
16 |     Args:
17 |         data_files (list): a list of data file paths.
18 |         dictionary (dict): mapping from word to index; extended in place with new words.
19 |         include_question (bool): whether to also store each question as a story sentence.
20 |
21 |     Returns:
22 |         A tuple of (story, questions, qstory):
23 |             story (3-D array)
24 |                 [position of word in sentence, sentence index, story index] = index of word in dictionary
25 |             questions (2-D array)
26 |                 [0-13, question index], in which the first component is encoded as follows:
27 |                     0 - story index
28 |                     1 - index of the last sentence before the question
29 |                     2 - index of the answer word in dictionary
30 |                     3 to 12 - indices of supporting sentences
31 |                     13 - line index
32 |             qstory (2-D array) words of each question
33 |                 [index of word in question, question index] = index of word in dictionary
34 |     """
35 |     # Preallocate arrays (large enough for both the 1k and 10k data sets)
36 |     # maximum number of words in sentence = 20
37 |     story = np.zeros((20, 500, len(data_files) * 3500), np.int16)
38 |     questions = np.zeros((14, len(data_files) * 10000), np.int16)
39 |     qstory = np.zeros((20, len(data_files) * 10000), np.int16)
40 |
41 |     # NOTE: question indices are not reset when a new story starts
42 |     story_idx, question_idx, sentence_idx, max_words, max_sentences = -1, -1, -1, 0, 0
43 |
44 |     # Maps a line number (within a story) to its sentence index (to support the flag include_question)
45 |     mapping = None
46 |
47 |     for fp in data_files:
48 |         with open(fp) as f:
49 |             for line_idx, line in enumerate(f):
50 |                 line = line.rstrip().lower()
51 |                 words = line.split()
52 |
53 |                 # Story begins
54 |                 if words[0] == '1':
55 |                     story_idx += 1
56 |                     sentence_idx = -1
57 |                     mapping = []
58 |
59 |                 # FIXME: Detecting questions by the presence of '?' makes the code fragile!
60 |                 if '?' not in line:
61 |                     is_question = False
62 |                     sentence_idx += 1
63 |                 else:
64 |                     is_question = True
65 |                     question_idx += 1
66 |                     questions[0, question_idx] = story_idx
67 |                     questions[1, question_idx] = sentence_idx
68 |                     if include_question:
69 |                         sentence_idx += 1
70 |
71 |                 mapping.append(sentence_idx)
72 |
73 |                 # Skip the leading line number (words[0])
74 |                 for k in range(1, len(words)):
75 |                     w = words[k]
76 |
77 |                     if w.endswith('.') or w.endswith('?'):
78 |                         w = w[:-1]
79 |
80 |                     if w not in dictionary:
81 |                         dictionary[w] = len(dictionary)
82 |
83 |                     if max_words < k:
84 |                         max_words = k
85 |
86 |                     if not is_question:
87 |                         story[k - 1, sentence_idx, story_idx] = dictionary[w]
88 |                     else:
89 |                         qstory[k - 1, question_idx] = dictionary[w]
90 |                         if include_question:
91 |                             story[k - 1, sentence_idx, story_idx] = dictionary[w]
92 |
93 |                         # NOTE: Punctuation is already removed from w, so test the original token
94 |                         if words[k].endswith('?'):
95 |                             answer = words[k + 1]
96 |                             if answer not in dictionary:
97 |                                 dictionary[answer] = len(dictionary)
98 |
99 |                             questions[2, question_idx] = dictionary[answer]
100 |
101 |                             # Indices of supporting sentences
102 |                             for h in range(k + 2, len(words)):
103 |                                 questions[1 + h - k, question_idx] = mapping[int(words[h]) - 1]
104 |
105 |                             questions[-1, question_idx] = line_idx
106 |                             break
107 |
108 |                 if max_sentences < sentence_idx + 1:
109 |                     max_sentences = sentence_idx + 1
110 |
111 |     story = story[:max_words, :max_sentences, :(story_idx + 1)]
112 |     questions = questions[:, :(question_idx + 1)]
113 |     qstory = qstory[:max_words, :(question_idx + 1)]
114 |
115 |     return story, questions, qstory
116 |
117 |
118 | def build_model(general_config):
119 |     """
120 |     Build model
121 |
122 |     NOTE: (for default config)
123 |     1) Model's architecture (embedding B)
124 |         LookupTable -> ElemMult -> Sum -> [ Duplicate -> { Parallel -> Memory -> Identity } -> AddTable ] -> LinearNB -> Softmax
125 |
126 |     2) Memory's architecture
127 |         a) Query module (embedding A)
128 |             Parallel -> { LookupTable + ElemMult + Sum } -> Identity -> MatVecProd -> Softmax
129 |
130 |         b) Output module (embedding C)
131 |             Parallel -> { LookupTable + ElemMult + Sum } -> Identity -> MatVecProd
132 |     """
133 |     train_config = general_config.train_config
134 |     dictionary = general_config.dictionary
135 |     use_bow = general_config.use_bow
136 |     nhops = general_config.nhops
137 |     add_proj = general_config.add_proj
138 |     share_type = general_config.share_type
139 |     enable_time = general_config.enable_time
140 |     add_nonlin = general_config.add_nonlin
141 |
142 |     in_dim = train_config["in_dim"]
143 |     out_dim = train_config["out_dim"]
144 |     max_words = train_config["max_words"]
145 |     voc_sz = train_config["voc_sz"]
146 |
147 |     if not use_bow:
148 |         train_config["weight"] = np.ones((in_dim, max_words), np.float32)
149 |         for i in range(in_dim):
150 |             for j in range(max_words):
151 |                 train_config["weight"][i][j] = (i + 1 - (in_dim + 1) / 2) * \
152 |                                                (j + 1 - (max_words + 1) / 2)
153 |         train_config["weight"] = \
154 |             1 + 4 * train_config["weight"] / (in_dim * max_words)
155 |
156 |     memory = {}
157 |     model = Sequential()
158 |     model.add(LookupTable(voc_sz, in_dim))
159 |     if not use_bow:
160 |         if enable_time:
161 |             model.add(ElemMult(train_config["weight"][:, :-1]))
162 |         else:
163 |             model.add(ElemMult(train_config["weight"]))
164 |
165 |     model.add(Sum(dim=1))
166 |
167 |     proj = {}
168 |     for i in range(nhops):
169 |         if use_bow:
170 |             memory[i] = MemoryBoW(train_config)
171 |         else:
172 |             memory[i] = MemoryL(train_config)
173 |
174 |         # Override nil_word which is initialized in "self.nil_word = train_config["voc_sz"]"
175 |         memory[i].nil_word = dictionary['nil']
176 |         model.add(Duplicate())
177 |         p = Parallel()
178 |         p.add(memory[i])
179 |
180 |         if add_proj:
181 |             proj[i] = LinearNB(in_dim, in_dim)
182 |             p.add(proj[i])
183 |         else:
184 |             p.add(Identity())
185 |
186 |         model.add(p)
187 |         model.add(AddTable())
188 |         if add_nonlin:
189 |             model.add(ReLU())
190 |
191 |     model.add(LinearNB(out_dim, voc_sz, True))
192 |     model.add(Softmax())
193 |
194 |     # Share weights
195 |     if share_type == 1:
196 |         # Type 1: adjacent weight tying
197 |         memory[0].emb_query.share(model.modules[0])
198 |         for i in range(1, nhops):
199 |             memory[i].emb_query.share(memory[i - 1].emb_out)
200 |
201 |         model.modules[-2].share(memory[len(memory) - 1].emb_out)
202 |
203 |     elif share_type == 2:
204 |         # Type 2: layer-wise weight tying
205 |         for i in range(1, nhops):
206 |             memory[i].emb_query.share(memory[0].emb_query)
207 |             memory[i].emb_out.share(memory[0].emb_out)
208 |
209 |     if add_proj:
210 |         for i in range(1, nhops):
211 |             proj[i].share(proj[0])
212 |
213 |     # Cost
214 |     loss = CrossEntropyLoss()
215 |     loss.size_average = False
216 |     loss.do_softmax_bprop = True
217 |     model.modules[-1].skip_bprop = True
218 |
219 |     return memory, model, loss
220 |
221 |
222 | class Progress(object):
223 |     """
224 |     Progress bar
225 |     """
226 |
227 |     def __init__(self, iterable, bar_length=50):
228 |         self.iterable = iterable
229 |         self.bar_length = bar_length
230 |         self.total_length = len(iterable)
231 |         self.start_time = time.time()
232 |         self.count = 0
233 |
234 |     def __iter__(self):
235 |         for obj in self.iterable:
236 |             yield obj
237 |             self.count += 1
238 |             percent = self.count / self.total_length
239 |             print_length = int(percent * self.bar_length)
240 |             progress = "=" * print_length + " " * (self.bar_length - print_length)
241 |             elapsed_time = time.time() - self.start_time
242 |             print_msg = "\r|%s| %.0f%% %.1fs" % (progress, percent * 100, elapsed_time)
243 |             sys.stdout.write(print_msg)
244 |             if self.count == self.total_length:
245 |                 sys.stdout.write("\r" + " " * len(print_msg) + "\r")
246 |             sys.stdout.flush()
247 |
--------------------------------------------------------------------------------
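
The token-level rules that `parse_babi_task` in `util.py` applies to each input line can be illustrated on a single line. The sketch below is not part of the repository: `parse_babi_line` and its returned keys are hypothetical names, and it assumes the usual bAbI layout in which a question line carries the question, the answer, and the supporting line numbers separated by whitespace. Like the original, it treats any line containing `?` as a question.

```python
def _strip_punct(word):
    """Drop a single trailing '.' or '?', as parse_babi_task does."""
    return word[:-1] if word.endswith(('.', '?')) else word


def parse_babi_line(line):
    """Split one bAbI line into line number, tokens, answer and supporting lines."""
    words = line.rstrip().lower().split()
    line_no = int(words[0])  # number of the line within its story

    if '?' not in line:
        # Statement: every word after the line number belongs to the sentence.
        return {"line_no": line_no, "is_question": False,
                "tokens": [_strip_punct(w) for w in words[1:]],
                "answer": None, "supporting": []}

    # Question: the token ending in '?' closes the question text; the next
    # token is the answer and the remaining ones are supporting line numbers.
    q_end = next(i for i, w in enumerate(words) if w.endswith('?'))
    return {"line_no": line_no, "is_question": True,
            "tokens": [_strip_punct(w) for w in words[1:q_end + 1]],
            "answer": words[q_end + 1],
            "supporting": [int(w) for w in words[q_end + 2:]]}
```

For example, `parse_babi_line("3 Where is Mary? \tbathroom\t1")` yields the tokens `['where', 'is', 'mary']`, the answer `'bathroom'` and the supporting lines `[1]`. The full parser additionally maps each token to a dictionary index and each supporting line number to a sentence index.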