├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── data ├── .gitignore ├── dev-v1.1.json └── train-v1.1.json ├── layers ├── Argmax.py ├── PointerGRU.py ├── QuestionAttnGRU.py ├── QuestionPooling.py ├── SelfAttnGRU.py ├── SharedWeight.py ├── Slice.py ├── VariationalDropout.py ├── WrappedGRU.py ├── __init__.py └── helpers.py ├── lib └── .gitignore ├── model.py ├── models └── .gitignore ├── parse_data.py ├── predict.py ├── preprocessing.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pkl 3 | .vscode/ 4 | .* 5 | !.gitignore 6 | 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 YerevaNN Foundation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # R-NET implementation in Keras 2 | 3 | This repository is an attempt to reproduce the results presented in the [technical report by Microsoft Research Asia](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf). The report describes a complex neural network called [R-NET](https://www.microsoft.com/en-us/research/publication/mrc/) designed for question answering. 4 | 5 | **[This blogpost](http://yerevann.github.io/2017/08/25/challenges-of-reproducing-r-net-neural-network-using-keras/) describes the details.** 6 | 7 | R-NET is currently (August 25, 2017) the best single model on the Stanford QA dataset: [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/). The SQuAD dataset uses two performance metrics: exact match (EM) and F1-score (F1). Human performance is estimated to be EM=82.3% and F1=91.2% on the test set. 8 | 9 | The report describes two versions of R-NET: 10 | 1. The first one is called `R-NET (Wang et al., 2017)` (which refers to a paper that is not yet available online) and reaches EM=71.3% and F1=79.7% on the test set. It consists of input encoders, a modified version of Match-LSTM, a self-matching attention layer (the main contribution of the paper) and a pointer network. 11 | 2. 
The second version, called `R-NET (March 2017)`, has one additional BiGRU between the self-matching attention layer and the pointer network and reaches EM=72.3% and F1=80.7%. 12 | 13 | The current best single model on the SQuAD leaderboard has a higher score, which means that R-NET development continued after March 2017. Ensemble models reach higher scores. 14 | 15 | This repository contains an implementation of the first version, but we cannot yet reproduce the reported results. The best performance we got so far was EM=57.52% and F1=67.42% on the dev set. We are aware of a few differences between our implementation and the network described in the paper: 16 | 17 | 1. The first formula in (11) of the [report](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf) contains a strange summand W_v^Q V_r^Q. Both tensors are trainable and are not used anywhere else in the network. We have replaced this product with a single trainable vector. 18 | 2. The size of the hidden layer should be 75 according to the report, but we get better results with a lower number. Overfitting is severe with 75 neurons. 19 | 3. We are not sure whether we applied dropout correctly. 20 | 4. There is nothing about weight initialization or batch generation in the report. 21 | 5. Question-aware passage representation generation should (probably) be done by a bidirectional GRU. 22 | 23 | On the other hand, we can't rule out that we have bugs in our code. 24 | 25 | ## Instructions (make sure you are running Keras version 2.0.6) 26 | 27 | 1. We need to parse and split the data 28 | ```sh 29 | python parse_data.py data/train-v1.1.json --train_ratio 0.9 --outfile data/train_parsed.json --outfile_valid data/valid_parsed.json 30 | python parse_data.py data/dev-v1.1.json --outfile data/dev_parsed.json 31 | ``` 32 | 33 | 2. Preprocess the data 34 | ```sh 35 | python preprocessing.py data/train_parsed.json --outfile data/train_data_str.pkl --include_str 36 | python preprocessing.py data/valid_parsed.json --outfile data/valid_data_str.pkl --include_str 37 | python preprocessing.py data/dev_parsed.json --outfile data/dev_data_str.pkl --include_str 38 | ``` 39 | 40 | 3. Train the model 41 | ```sh 42 | python train.py --hdim 45 --batch_size 50 --nb_epochs 50 --optimizer adadelta --lr 1 --dropout 0.2 --char_level_embeddings --train_data data/train_data_str.pkl --valid_data data/valid_data_str.pkl 43 | ``` 44 | 45 | 4. 
Predict on dev/test set samples 46 | ```sh 47 | python predict.py --batch_size 100 --dev_data data/dev_data_str.pkl models/31-t3.05458271443-v3.27696280528.model prediction.json 48 | ``` 49 | 50 | Our best model can be downloaded from Release v0.1: https://github.com/YerevaNN/R-NET-in-Keras/releases/download/v0.1/31-t3.05458271443-v3.27696280528.model 51 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import cPickle as pickle 7 | 8 | from keras import backend as K 9 | from keras.utils import np_utils 10 | from keras.preprocessing import sequence 11 | 12 | from random import shuffle 13 | import itertools 14 | 15 | def load_dataset(filename): 16 | with open(filename, 'rb') as f: 17 | return pickle.load(f) 18 | 19 | def padded_batch_input(input, indices=None, dtype=K.floatx(), maxlen=None): 20 | if indices is None: 21 | indices = np.arange(len(input)) 22 | 23 | batch_input = [input[i] for i in indices] 24 | return sequence.pad_sequences(batch_input, maxlen, dtype, padding='post') 25 | 26 | def categorical_batch_target(target, classes, indices=None, dtype=K.floatx()): 27 | if indices is None: 28 | indices = np.arange(len(target)) 29 | 30 | batch_target = [min(target[i], classes-1) for i in indices] 31 | return np_utils.to_categorical(batch_target, classes).astype(dtype) 32 | 33 | def lengthGroup(length): 34 | if length < 150: 35 | return 0 36 | if length < 240: 37 | return 1 38 | if length < 380: 39 | return 2 40 | if length < 520: 41 | return 3 42 | if length < 660: 43 | return 4 44 | return 5 45 | 46 | class BatchGen(object): 47 | def __init__(self, inputs, targets=None, batch_size=None, stop=False, 48 | shuffle=True, balance=False, dtype=K.floatx(), 49 | flatten_targets=False, sort_by_length=False, 50 | group=False, maxlen=None): 51 | assert len(set([len(i) for i in inputs])) == 1 52 | assert(not shuffle or not sort_by_length) 53 | self.inputs = inputs 54 | self.nb_samples = len(inputs[0]) 55 | 56 | self.batch_size = batch_size if batch_size else self.nb_samples 57 | 58 | self.dtype = dtype 59 | 60 | self.stop = stop 61 | self.shuffle = shuffle 62 | self.balance = balance 63 | self.targets = targets 64 | self.flatten_targets = flatten_targets 65 | if isinstance(maxlen, (list, tuple)): 66 | self.maxlen = maxlen 67 | else: 68 | self.maxlen = [maxlen] * len(inputs) 69 | 70 | self.sort_by_length = None 71 | if sort_by_length: 72 | self.sort_by_length = np.argsort([-len(p) for p in inputs[0]]) 73 | 74 | # if self.targets and self.balance: 75 | # self.class_weight = class_weight(self.targets) 76 | 77 | self.generator = self._generator() 78 | self._steps = -(-self.nb_samples // self.batch_size) # round up 79 | 80 | self.groups = None 81 | if group is not False: 82 | indices = np.arange(self.nb_samples) 83 | 84 | ff = lambda i: lengthGroup(len(inputs[0][i])) 85 | 86 | indices = np.argsort([ff(i) for i in indices]) 87 | 88 | self.groups = itertools.groupby(indices, ff) 89 | 90 | self.groups = {k: np.array(list(v)) for k, v in self.groups} 91 | 92 | def _generator(self): 93 | while True: 94 | if self.shuffle: 95 | permutation = np.random.permutation(self.nb_samples) 96 | elif self.sort_by_length is not None: 97 | permutation = self.sort_by_length 98 | elif self.groups is not None: 99 | # permutation = np.arange(self.nb_samples) 
100 | # tmp = permutation.copy() 101 | # for id in self.group_ids: 102 | # mask = (self.groups==id) 103 | # tmp[mask] = np.random.permutation(permutation[mask]) 104 | # permutation = tmp 105 | # import ipdb 106 | # ipdb.set_trace() 107 | 108 | for k, v in self.groups.items(): 109 | np.random.shuffle(v) 110 | 111 | tmp = np.concatenate(self.groups.values()) 112 | batches = np.array_split(tmp, self._steps) 113 | 114 | remainder = [] 115 | if len(batches[-1]) < self._steps: 116 | remainder = batches[-1:] 117 | batches = batches[:-1] 118 | 119 | shuffle(batches) 120 | batches += remainder 121 | permutation = np.concatenate(batches) 122 | 123 | else: 124 | permutation = np.arange(self.nb_samples) 125 | 126 | i = 0 127 | longest = 767 128 | 129 | while i < self.nb_samples: 130 | if self.sort_by_length is not None: 131 | bs = self.batch_size * 767 // self.inputs[0][permutation[i]].shape[0] 132 | else: 133 | bs = self.batch_size 134 | 135 | indices = permutation[i : i + bs] 136 | i = i + bs 137 | 138 | # for i in range(0, self.nb_samples, self.batch_size): 139 | # indices = permutation[i : i + self.batch_size] 140 | 141 | batch_X = [padded_batch_input(x, indices, self.dtype, maxlen) 142 | for x, maxlen in zip(self.inputs, self.maxlen)] 143 | 144 | P = batch_X[0].shape[1] 145 | 146 | if not self.targets: 147 | yield batch_X 148 | continue 149 | 150 | batch_Y = [categorical_batch_target(target, P, 151 | indices, self.dtype) 152 | for target in self.targets] 153 | 154 | if self.flatten_targets: 155 | batch_Y = np.concatenate(batch_Y, axis=-1) 156 | 157 | if not self.balance: 158 | yield (batch_X, batch_Y) 159 | continue 160 | 161 | # batch_W = np.array([self.class_weight[y] for y in batch_targets]) 162 | batch_W = np.array([bs / self.batch_size for x in batch_X[0]]).astype(self.dtype) 163 | yield (batch_X, batch_Y, batch_W) 164 | 165 | if self.stop: 166 | raise StopIteration 167 | 168 | def __iter__(self): 169 | return self.generator 170 | 171 | def next(self): 172 | return self.generator.next() 173 | 174 | def __next__(self): 175 | return self.generator.__next__() 176 | 177 | def steps(self): 178 | if self.sort_by_length is None: 179 | return self._steps 180 | 181 | print("Steps was called") 182 | if self.shuffle: 183 | permutation = np.random.permutation(self.nb_samples) 184 | elif self.sort_by_length is not None: 185 | permutation = self.sort_by_length 186 | else: 187 | permutation = np.arange(self.nb_samples) 188 | 189 | i = 0 190 | longest = 767 191 | 192 | self._steps = 0 193 | while i < self.nb_samples: 194 | bs = self.batch_size * 767 // self.inputs[0][permutation[i]].shape[0] 195 | i = i + bs 196 | self._steps += 1 197 | 198 | return self._steps 199 | 200 | batch_gen = BatchGen # for backward compatibility 201 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | 6 | -------------------------------------------------------------------------------- /layers/Argmax.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras import backend as K 3 | from keras.engine import Layer, InputSpec 4 | 5 | class Argmax(Layer): 6 | def __init__(self, axis=-1, **kwargs): 7 | super(Argmax, self).__init__(**kwargs) 8 | self.supports_masking = True 9 | self.axis = axis 10 | 11 | def call(self, inputs, mask=None): 12 | return 
K.argmax(inputs, axis=self.axis) 13 | 14 | def compute_output_shape(self, input_shape): 15 | input_shape = list(input_shape) 16 | del input_shape[self.axis] 17 | return tuple(input_shape) 18 | 19 | def compute_mask(self, x, mask): 20 | return None 21 | 22 | def get_config(self): 23 | config = {'axis': self.axis} 24 | base_config = super(Argmax, self).get_config() 25 | return dict(list(base_config.items()) + list(config.items())) 26 | -------------------------------------------------------------------------------- /layers/PointerGRU.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from keras import backend as K 6 | from keras.layers import Layer 7 | from keras.layers.wrappers import TimeDistributed 8 | 9 | from WrappedGRU import WrappedGRU 10 | from helpers import compute_mask, softmax 11 | 12 | class PointerGRU(WrappedGRU): 13 | 14 | def build(self, input_shape): 15 | H = self.units // 2 16 | assert(isinstance(input_shape, list)) 17 | 18 | nb_inputs = len(input_shape) 19 | assert(nb_inputs >= 6) 20 | 21 | assert(len(input_shape[0]) >= 2) 22 | B, T = input_shape[0][:2] 23 | 24 | 25 | assert(len(input_shape[1]) == 3) 26 | B, P, H_ = input_shape[1] 27 | assert(H_ == 2 * H) 28 | 29 | self.input_spec = [None] 30 | super(PointerGRU, self).build(input_shape=(B, T, 2 * H)) 31 | self.GRU_input_spec = self.input_spec 32 | self.input_spec = [None] * nb_inputs # TODO TODO TODO 33 | 34 | def step(self, inputs, states): 35 | # input 36 | ha_tm1 = states[0] # (B, 2H) 37 | _ = states[1:3] # ignore internal dropout/masks 38 | hP, WP_h, Wa_h, v = states[3:7] # (B, P, 2H) 39 | hP_mask, = states[7:8] 40 | 41 | WP_h_Dot = K.dot(hP, WP_h) # (B, P, H) 42 | Wa_h_Dot = K.dot(K.expand_dims(ha_tm1, axis=1), Wa_h) # (B, 1, H) 43 | 44 | s_t_hat = K.tanh(WP_h_Dot + Wa_h_Dot) # (B, P, H) 45 | s_t = K.dot(s_t_hat, v) # (B, P, 1) 46 | s_t = K.batch_flatten(s_t) # (B, P) 47 | a_t = softmax(s_t, mask=hP_mask, axis=1) # (B, P) 48 | c_t = K.batch_dot(hP, a_t, axes=[1, 1]) # (B, 2H) 49 | 50 | GRU_inputs = c_t 51 | ha_t, (ha_t_,) = super(PointerGRU, self).step(GRU_inputs, states) 52 | 53 | return a_t, [ha_t] 54 | 55 | def compute_output_shape(self, input_shape): 56 | assert(isinstance(input_shape, list)) 57 | 58 | nb_inputs = len(input_shape) 59 | assert(nb_inputs >= 5) 60 | 61 | assert(len(input_shape[0]) >= 2) 62 | B, T = input_shape[0][:2] 63 | 64 | assert(len(input_shape[1]) == 3) 65 | B, P, H_ = input_shape[1] 66 | 67 | if self.return_sequences: 68 | return (B, T, P) 69 | else: 70 | return (B, P) 71 | 72 | def compute_mask(self, inputs, mask=None): 73 | return None # TODO 74 | -------------------------------------------------------------------------------- /layers/QuestionAttnGRU.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from keras import backend as K 6 | 7 | from WrappedGRU import WrappedGRU 8 | from helpers import compute_mask, softmax 9 | 10 | class QuestionAttnGRU(WrappedGRU): 11 | 12 | def build(self, input_shape): 13 | H = self.units 14 | assert(isinstance(input_shape, list)) 15 | 16 | nb_inputs = len(input_shape) 17 | assert(nb_inputs >= 2) 18 | 19 | assert(len(input_shape[0]) == 3) 20 | B, P, H_ = input_shape[0] 21 | assert(H_ == 2 * H) 22 | 23 | assert(len(input_shape[1]) == 3) 
24 | B, Q, H_ = input_shape[1] 25 | assert(H_ == 2 * H) 26 | 27 | self.input_spec = [None] 28 | super(QuestionAttnGRU, self).build(input_shape=(B, P, 4 * H)) 29 | self.GRU_input_spec = self.input_spec 30 | self.input_spec = [None] * nb_inputs 31 | 32 | def step(self, inputs, states): 33 | uP_t = inputs 34 | vP_tm1 = states[0] 35 | _ = states[1:3] # ignore internal dropout/masks 36 | uQ, WQ_u, WP_v, WP_u, v, W_g1 = states[3:9] 37 | uQ_mask, = states[9:10] 38 | 39 | WQ_u_Dot = K.dot(uQ, WQ_u) #WQ_u 40 | WP_v_Dot = K.dot(K.expand_dims(vP_tm1, axis=1), WP_v) #WP_v 41 | WP_u_Dot = K.dot(K.expand_dims(uP_t, axis=1), WP_u) # WP_u 42 | 43 | s_t_hat = K.tanh(WQ_u_Dot + WP_v_Dot + WP_u_Dot) 44 | 45 | s_t = K.dot(s_t_hat, v) # v 46 | s_t = K.batch_flatten(s_t) 47 | a_t = softmax(s_t, mask=uQ_mask, axis=1) 48 | c_t = K.batch_dot(a_t, uQ, axes=[1, 1]) 49 | 50 | GRU_inputs = K.concatenate([uP_t, c_t]) 51 | g = K.sigmoid(K.dot(GRU_inputs, W_g1)) # W_g1 52 | GRU_inputs = g * GRU_inputs 53 | vP_t, s = super(QuestionAttnGRU, self).step(GRU_inputs, states) 54 | 55 | return vP_t, s 56 | -------------------------------------------------------------------------------- /layers/QuestionPooling.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from keras import backend as K 6 | from keras.layers import Layer 7 | from keras.layers.wrappers import TimeDistributed 8 | 9 | from helpers import compute_mask, softmax 10 | 11 | class QuestionPooling(Layer): 12 | 13 | def __init__(self, **kwargs): 14 | super(QuestionPooling, self).__init__(**kwargs) 15 | self.supports_masking = True 16 | 17 | def compute_output_shape(self, input_shape): 18 | assert(isinstance(input_shape, list) and len(input_shape) == 5) 19 | 20 | input_shape = input_shape[0] 21 | B, Q, H = input_shape 22 | 23 | return (B, H) 24 | 25 | def build(self, input_shape): 26 | assert(isinstance(input_shape, list) and len(input_shape) == 5) 27 | input_shape = input_shape[0] 28 | 29 | B, Q, H_ = input_shape 30 | H = H_ // 2 31 | 32 | def call(self, inputs, mask=None): 33 | assert(isinstance(inputs, list) and len(inputs) == 5) 34 | uQ, WQ_u, WQ_v, v, VQ_r = inputs 35 | uQ_mask = mask[0] if mask is not None else None 36 | 37 | ones = K.ones_like(K.sum(uQ, axis=1, keepdims=True)) # (B, 1, 2H) 38 | s_hat = K.dot(uQ, WQ_u) 39 | s_hat += K.dot(ones, K.dot(WQ_v, VQ_r)) 40 | s_hat = K.tanh(s_hat) 41 | s = K.dot(s_hat, v) 42 | s = K.batch_flatten(s) 43 | 44 | a = softmax(s, mask=uQ_mask, axis=1) 45 | 46 | rQ = K.batch_dot(uQ, a, axes=[1, 1]) 47 | 48 | return rQ 49 | 50 | def compute_mask(self, input, mask=None): 51 | return None 52 | -------------------------------------------------------------------------------- /layers/SelfAttnGRU.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from keras import backend as K 6 | 7 | from WrappedGRU import WrappedGRU 8 | from helpers import compute_mask, softmax 9 | 10 | class SelfAttnGRU(WrappedGRU): 11 | 12 | def build(self, input_shape): 13 | H = self.units 14 | assert(isinstance(input_shape, list)) 15 | 16 | nb_inputs = len(input_shape) 17 | assert(nb_inputs >= 2) 18 | 19 | assert(len(input_shape[0]) == 3) 20 | B, P, H_ = input_shape[0] 21 | assert(H_ == H) 22 | 23 | 24 | assert(len(input_shape[1]) == 3) 25 
| B, P_, H_ = input_shape[1] 26 | assert(P_ == P) 27 | assert(H_ == H) 28 | 29 | self.input_spec = [None] 30 | super(SelfAttnGRU, self).build(input_shape=(B, P, 2 * H)) 31 | self.GRU_input_spec = self.input_spec 32 | self.input_spec = [None] * nb_inputs 33 | 34 | def step(self, inputs, states): 35 | vP_t = inputs 36 | hP_tm1 = states[0] 37 | _ = states[1:3] # ignore internal dropout/masks 38 | vP, WP_v, WPP_v, v, W_g2 = states[3:8] 39 | vP_mask, = states[8:] 40 | 41 | WP_v_Dot = K.dot(vP, WP_v) 42 | WPP_v_Dot = K.dot(K.expand_dims(vP_t, axis=1), WPP_v) 43 | 44 | s_t_hat = K.tanh(WPP_v_Dot + WP_v_Dot) 45 | s_t = K.dot(s_t_hat, v) 46 | s_t = K.batch_flatten(s_t) 47 | 48 | a_t = softmax(s_t, mask=vP_mask, axis=1) 49 | 50 | c_t = K.batch_dot(a_t, vP, axes=[1, 1]) 51 | 52 | GRU_inputs = K.concatenate([vP_t, c_t]) 53 | g = K.sigmoid(K.dot(GRU_inputs, W_g2)) 54 | GRU_inputs = g * GRU_inputs 55 | 56 | hP_t, s = super(SelfAttnGRU, self).step(GRU_inputs, states) 57 | 58 | return hP_t, s 59 | -------------------------------------------------------------------------------- /layers/SharedWeight.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | 3 | from keras import initializers 4 | from keras import regularizers 5 | 6 | from keras.engine.topology import Node 7 | from keras.layers import Layer, InputLayer 8 | 9 | class SharedWeightLayer(InputLayer): 10 | def __init__(self, 11 | size, 12 | initializer='glorot_uniform', 13 | regularizer=None, 14 | name=None, 15 | **kwargs): 16 | self.size = tuple(size) 17 | self.initializer = initializers.get(initializer) 18 | self.regularizer = regularizers.get(regularizer) 19 | 20 | if not name: 21 | prefix = 'shared_weight' 22 | name = prefix + '_' + str(K.get_uid(prefix)) 23 | 24 | Layer.__init__(self, name=name, **kwargs) 25 | 26 | with K.name_scope(self.name): 27 | self.kernel = self.add_weight(shape=self.size, 28 | initializer=self.initializer, 29 | name='kernel', 30 | regularizer=self.regularizer) 31 | 32 | 33 | self.trainable = True 34 | self.built = True 35 | # self.sparse = sparse 36 | 37 | input_tensor = self.kernel * 1.0 38 | 39 | self.is_placeholder = False 40 | input_tensor._keras_shape = self.size 41 | 42 | input_tensor._uses_learning_phase = False 43 | input_tensor._keras_history = (self, 0, 0) 44 | 45 | Node(self, 46 | inbound_layers=[], 47 | node_indices=[], 48 | tensor_indices=[], 49 | input_tensors=[input_tensor], 50 | output_tensors=[input_tensor], 51 | input_masks=[None], 52 | output_masks=[None], 53 | input_shapes=[self.size], 54 | output_shapes=[self.size]) 55 | 56 | def get_config(self): 57 | config = { 58 | 'size': self.size, 59 | 'initializer': initializers.serialize(self.initializer), 60 | 'regularizer': regularizers.serialize(self.regularizer) 61 | } 62 | base_config = Layer.get_config(self) 63 | return dict(list(base_config.items()) + list(config.items())) 64 | 65 | def SharedWeight(**kwargs): 66 | input_layer = SharedWeightLayer(**kwargs) 67 | 68 | outputs = input_layer.inbound_nodes[0].output_tensors 69 | if len(outputs) == 1: 70 | return outputs[0] 71 | else: 72 | return outputs 73 | -------------------------------------------------------------------------------- /layers/Slice.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras import backend as K 3 | from keras.engine import Layer, InputSpec 4 | 5 | class Slice(Layer): 6 | def __init__(self, indices, axis=1, **kwargs): 7 | self.supports_masking 
= True 8 | self.axis = axis 9 | 10 | if isinstance(indices, slice): 11 | self.indices = (indices.start, indices.stop, indices.step) 12 | else: 13 | self.indices = indices 14 | 15 | self.slices = [ slice(None) ] * self.axis 16 | 17 | if isinstance(self.indices, int): 18 | self.slices.append(self.indices) 19 | elif isinstance(self.indices, (list, tuple)): 20 | self.slices.append(slice(*self.indices)) 21 | else: 22 | raise TypeError("indices must be int or slice") 23 | 24 | super(Slice, self).__init__(**kwargs) 25 | 26 | def call(self, inputs, mask=None): 27 | return inputs[self.slices] 28 | 29 | def compute_output_shape(self, input_shape): 30 | input_shape = list(input_shape) 31 | for i, slice in enumerate(self.slices): 32 | if i == self.axis: 33 | continue 34 | start = slice.start or 0 35 | stop = slice.stop or input_shape[i] 36 | step = slice.step or 1 37 | input_shape[i] = None if stop is None else (stop - start) // step 38 | del input_shape[self.axis] 39 | 40 | return tuple(input_shape) 41 | 42 | def compute_mask(self, x, mask=None): 43 | if mask is None: 44 | return mask 45 | if self.axis == 1: 46 | return mask[self.slices] 47 | else: 48 | return mask 49 | 50 | def get_config(self): 51 | config = {'axis': self.axis, 52 | 'indices': self.indices} 53 | base_config = super(Slice, self).get_config() 54 | return dict(list(base_config.items()) + list(config.items())) 55 | -------------------------------------------------------------------------------- /layers/VariationalDropout.py: -------------------------------------------------------------------------------- 1 | from keras import backend as K 2 | from keras.engine.topology import Layer 3 | 4 | 5 | class VariationalDropout(Layer): 6 | 7 | def __init__(self, rate, noise_shape=None, seed=None, **kwargs): 8 | super(VariationalDropout, self).__init__(**kwargs) 9 | self.rate = min(1., max(0., rate)) 10 | self.noise_shape = noise_shape 11 | self.seed = seed 12 | self.supports_masking = True 13 | 14 | def call(self, inputs, training=None): 15 | if 0. 
< self.rate < 1.: 16 | symbolic_shape = K.shape(inputs) 17 | noise_shape = [shape if shape > 0 else symbolic_shape[axis] 18 | for axis, shape in enumerate(self.noise_shape)] 19 | noise_shape = tuple(noise_shape) 20 | 21 | def dropped_inputs(): 22 | return K.dropout(inputs, self.rate, noise_shape, seed=self.seed) 23 | 24 | return K.in_train_phase(dropped_inputs, inputs, training=training) 25 | 26 | return inputs 27 | 28 | def get_config(self): 29 | config = {'rate': self.rate, 30 | 'noise_shape': self.noise_shape, 31 | 'seed': self.seed} 32 | base_config = super(VariationalDropout, self).get_config() 33 | return dict(list(base_config.items()) + list(config.items())) 34 | -------------------------------------------------------------------------------- /layers/WrappedGRU.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from keras import backend as K 6 | from keras.layers import Layer 7 | from keras.layers.wrappers import TimeDistributed 8 | from keras.layers.recurrent import GRU 9 | 10 | class WrappedGRU(GRU): 11 | 12 | def __init__(self, initial_state_provided=False, **kwargs): 13 | kwargs['implementation'] = kwargs.get('implementation', 2) 14 | assert(kwargs['implementation'] == 2) 15 | 16 | super(WrappedGRU, self).__init__(**kwargs) 17 | self.input_spec = None 18 | self.initial_state_provided = initial_state_provided 19 | 20 | 21 | def call(self, inputs, mask=None, training=None, initial_state=None): 22 | if self.initial_state_provided: 23 | initial_state = inputs[-1:] 24 | inputs = inputs[:-1] 25 | 26 | initial_state_mask = mask[-1:] 27 | mask = mask[:-1] if mask is not None else None 28 | 29 | self._non_sequences = inputs[1:] 30 | inputs = inputs[:1] 31 | 32 | self._mask_non_sequences = [] 33 | if mask is not None: 34 | self._mask_non_sequences = mask[1:] 35 | mask = mask[:1] 36 | self._mask_non_sequences = [mask for mask in self._mask_non_sequences 37 | if mask is not None] 38 | 39 | if self.initial_state_provided: 40 | assert(len(inputs) == len(initial_state)) 41 | inputs += initial_state 42 | 43 | if len(inputs) == 1: 44 | inputs = inputs[0] 45 | 46 | if isinstance(mask, list) and len(mask) == 1: 47 | mask = mask[0] 48 | 49 | return super(WrappedGRU, self).call(inputs, mask, training) 50 | 51 | def get_constants(self, inputs, training=None): 52 | constants = super(WrappedGRU, self).get_constants(inputs, training=training) 53 | constants += self._non_sequences 54 | constants += self._mask_non_sequences 55 | return constants 56 | 57 | def get_config(self): 58 | config = {'initial_state_provided': self.initial_state_provided} 59 | base_config = super(WrappedGRU, self).get_config() 60 | return dict(list(base_config.items()) + list(config.items())) 61 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from QuestionAttnGRU import QuestionAttnGRU 2 | from SelfAttnGRU import SelfAttnGRU 3 | from PointerGRU import PointerGRU 4 | from QuestionPooling import QuestionPooling 5 | from Argmax import Argmax 6 | from Slice import Slice 7 | from SharedWeight import SharedWeightLayer, SharedWeight 8 | from VariationalDropout import VariationalDropout 9 | -------------------------------------------------------------------------------- /layers/helpers.py: 
-------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from keras import backend as K 6 | 7 | def softmax(x, axis, mask=None): 8 | if mask is None: 9 | mask = K.constant(True) 10 | mask = K.cast(mask, K.floatx()) 11 | if K.ndim(x) is K.ndim(mask) + 1: 12 | mask = K.expand_dims(mask) 13 | 14 | m = K.max(x, axis=axis, keepdims=True) 15 | e = K.exp(x - m) * mask 16 | s = K.sum(e, axis=axis, keepdims=True) 17 | s += K.cast(K.cast(s < K.epsilon(), K.floatx()) * K.epsilon(), K.floatx()) 18 | return e / s 19 | 20 | def compute_mask(x, mask_value=0): 21 | boolean_mask = K.any(K.not_equal(x, mask_value), axis=-1, keepdims=False) 22 | return K.cast(boolean_mask, K.floatx()) -------------------------------------------------------------------------------- /lib/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | 6 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from keras import backend as K 6 | from keras.models import Model, Sequential 7 | from keras.layers import Input, InputLayer 8 | from keras.layers.core import Dense, RepeatVector, Masking, Dropout 9 | from keras.layers.merge import Concatenate 10 | from keras.layers.wrappers import Bidirectional, TimeDistributed 11 | from keras.layers.recurrent import GRU 12 | from keras.layers.embeddings import Embedding 13 | from keras.layers.pooling import GlobalMaxPooling1D 14 | 15 | from layers import QuestionAttnGRU 16 | from layers import SelfAttnGRU 17 | from layers import PointerGRU 18 | from layers import QuestionPooling 19 | from layers import VariationalDropout 20 | from layers import Slice 21 | from layers import SharedWeight 22 | 23 | class RNet(Model): 24 | def __init__(self, inputs=None, outputs=None, 25 | N=None, M=None, C=25, unroll=False, 26 | hdim=75, word2vec_dim=300, 27 | dropout_rate=0, 28 | char_level_embeddings=False, 29 | **kwargs): 30 | # Load model from config 31 | if inputs is not None and outputs is not None: 32 | super(RNet, self).__init__(inputs=inputs, 33 | outputs=outputs, 34 | **kwargs) 35 | return 36 | 37 | '''Dimensions''' 38 | B = None 39 | H = hdim 40 | W = word2vec_dim 41 | 42 | v = SharedWeight(size=(H, 1), name='v') 43 | WQ_u = SharedWeight(size=(2 * H, H), name='WQ_u') 44 | WP_u = SharedWeight(size=(2 * H, H), name='WP_u') 45 | WP_v = SharedWeight(size=(H, H), name='WP_v') 46 | W_g1 = SharedWeight(size=(4 * H, 4 * H), name='W_g1') 47 | W_g2 = SharedWeight(size=(2 * H, 2 * H), name='W_g2') 48 | WP_h = SharedWeight(size=(2 * H, H), name='WP_h') 49 | Wa_h = SharedWeight(size=(2 * H, H), name='Wa_h') 50 | WQ_v = SharedWeight(size=(2 * H, H), name='WQ_v') 51 | WPP_v = SharedWeight(size=(H, H), name='WPP_v') 52 | VQ_r = SharedWeight(size=(H, H), name='VQ_r') 53 | 54 | shared_weights = [v, WQ_u, WP_u, WP_v, W_g1, W_g2, WP_h, Wa_h, WQ_v, WPP_v, VQ_r] 55 | 56 | P_vecs = Input(shape=(N, W), name='P_vecs') 57 | Q_vecs = Input(shape=(M, W), name='Q_vecs') 58 | 59 | if char_level_embeddings: 60 | P_str = Input(shape=(N, C), dtype='int32', name='P_str') 61 | Q_str = Input(shape=(M, C), 
dtype='int32', name='Q_str') 62 | input_placeholders = [P_vecs, P_str, Q_vecs, Q_str] 63 | 64 | char_embedding_layer = TimeDistributed(Sequential([ 65 | InputLayer(input_shape=(C,), dtype='int32'), 66 | Embedding(input_dim=127, output_dim=H, mask_zero=True), 67 | Bidirectional(GRU(units=H)) 68 | ])) 69 | 70 | # char_embedding_layer.build(input_shape=(None, None, C)) 71 | 72 | P_char_embeddings = char_embedding_layer(P_str) 73 | Q_char_embeddings = char_embedding_layer(Q_str) 74 | 75 | P = Concatenate() ([P_vecs, P_char_embeddings]) 76 | Q = Concatenate() ([Q_vecs, Q_char_embeddings]) 77 | 78 | else: 79 | P = P_vecs 80 | Q = Q_vecs 81 | input_placeholders = [P_vecs, Q_vecs] 82 | 83 | uP = Masking() (P) 84 | for i in range(3): 85 | uP = Bidirectional(GRU(units=H, 86 | return_sequences=True, 87 | dropout=dropout_rate, 88 | unroll=unroll)) (uP) 89 | uP = VariationalDropout(rate=dropout_rate, noise_shape=(None, 1, 2 * H), name='uP') (uP) 90 | 91 | uQ = Masking() (Q) 92 | for i in range(3): 93 | uQ = Bidirectional(GRU(units=H, 94 | return_sequences=True, 95 | dropout=dropout_rate, 96 | unroll=unroll)) (uQ) 97 | uQ = VariationalDropout(rate=dropout_rate, noise_shape=(None, 1, 2 * H), name='uQ') (uQ) 98 | 99 | vP = QuestionAttnGRU(units=H, 100 | return_sequences=True, 101 | unroll=unroll) ([ 102 | uP, uQ, 103 | WQ_u, WP_v, WP_u, v, W_g1 104 | ]) 105 | vP = VariationalDropout(rate=dropout_rate, noise_shape=(None, 1, H), name='vP') (vP) 106 | 107 | hP = Bidirectional(SelfAttnGRU(units=H, 108 | return_sequences=True, 109 | unroll=unroll)) ([ 110 | vP, vP, 111 | WP_v, WPP_v, v, W_g2 112 | ]) 113 | 114 | hP = VariationalDropout(rate=dropout_rate, noise_shape=(None, 1, 2 * H), name='hP') (hP) 115 | 116 | gP = Bidirectional(GRU(units=H, 117 | return_sequences=True, 118 | unroll=unroll)) (hP) 119 | 120 | rQ = QuestionPooling() ([uQ, WQ_u, WQ_v, v, VQ_r]) 121 | rQ = Dropout(rate=dropout_rate, name='rQ') (rQ) 122 | 123 | fake_input = GlobalMaxPooling1D() (P) 124 | fake_input = RepeatVector(n=2, name='fake_input') (fake_input) 125 | 126 | ps = PointerGRU(units=2 * H, 127 | return_sequences=True, 128 | initial_state_provided=True, 129 | name='ps', 130 | unroll=unroll) ([ 131 | fake_input, gP, 132 | WP_h, Wa_h, v, 133 | rQ 134 | ]) 135 | 136 | answer_start = Slice(0, name='answer_start') (ps) 137 | answer_end = Slice(1, name='answer_end') (ps) 138 | 139 | inputs = input_placeholders + shared_weights 140 | outputs = [answer_start, answer_end] 141 | 142 | super(RNet, self).__init__(inputs=inputs, 143 | outputs=outputs, 144 | **kwargs) 145 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | 6 | -------------------------------------------------------------------------------- /parse_data.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import json 6 | import argparse 7 | import random 8 | 9 | if __name__ == '__main__': 10 | random.seed(42) 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('data', type=str, help='Path to the dataset file') 14 | parser.add_argument('--outfile', default='data/train_parsed.json', 15 | type=str, help='Desired path to output train json') 16 | parser.add_argument('--outfile_valid', 
default='data/valid_parsed.json', 17 | type=str, help='Desired path to output valid json') 18 | parser.add_argument('--train_ratio', default=1., type=float, 19 | help='ratio for train/val split') 20 | args = parser.parse_args() 21 | 22 | with open(args.data, 'r') as f: 23 | data = json.load(f) 24 | 25 | data = data['data'] 26 | 27 | # Lists containing ContextQuestionAnswerS 28 | train_cqas = [] 29 | valid_cqas = [] 30 | 31 | for topic in data: 32 | cqas = [{'context': paragraph['context'], 33 | 'id': qa['id'], 34 | 'question': qa['question'], 35 | 'answer': qa['answers'][0]['text'], 36 | 'answer_start': qa['answers'][0]['answer_start'], 37 | 'answer_end': qa['answers'][0]['answer_start'] + \ 38 | len(qa['answers'][0]['text']) - 1, 39 | 'topic': topic['title'] } 40 | for paragraph in topic['paragraphs'] 41 | for qa in paragraph['qas']] 42 | 43 | if random.random() < args.train_ratio: 44 | train_cqas += cqas 45 | else: 46 | valid_cqas += cqas 47 | 48 | if args.train_ratio == 1.: 49 | print('Writing to file {}...'.format(args.outfile), end='') 50 | with open(args.outfile, 'w') as fd: 51 | json.dump(train_cqas, fd) 52 | print('Done!') 53 | else: 54 | print('Train/Val ratio is {}'.format(len(train_cqas) / len(valid_cqas))) 55 | print('Writing to files {}, {}...'.format(args.outfile, 56 | args.outfile_valid), end='') 57 | with open(args.outfile, 'w') as fd: 58 | json.dump(train_cqas, fd) 59 | with open(args.outfile_valid, 'w') as fd: 60 | json.dump(valid_cqas, fd) 61 | print('Done!') 62 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import argparse 7 | import json 8 | import os 9 | 10 | from tqdm import tqdm 11 | 12 | from keras import backend as K 13 | from keras.models import Model, load_model 14 | 15 | from layers import Argmax 16 | from data import BatchGen, load_dataset 17 | from utils import custom_objects 18 | 19 | from preprocessing import CoreNLP_tokenizer 20 | 21 | np.random.seed(10) 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--batch_size', default=70, type=int, help='Batch size') 25 | parser.add_argument('--dev_data', default='data/dev_data.pkl', type=str, 26 | help='Validation Set') 27 | parser.add_argument('model', type=str, help='Model to run') 28 | parser.add_argument('prediction', type=str, default='pred.json', 29 | help='Outfile to save predictions') 30 | args = parser.parse_args() 31 | 32 | print('Preparing model...', end='') 33 | model = load_model(args.model, custom_objects()) 34 | 35 | inputs = model.inputs 36 | outputs = [ Argmax() (output) for output in model.outputs ] 37 | 38 | predicting_model = Model(inputs, outputs) 39 | print('Done!') 40 | 41 | print('Loading data...', end='') 42 | dev_data = load_dataset(args.dev_data) 43 | char_level_embeddings = len(dev_data[0]) is 4 44 | maxlen = [300, 300, 30, 30] if char_level_embeddings else [300, 30] 45 | dev_data_gen = BatchGen(*dev_data, batch_size=args.batch_size, shuffle=False, group=False, maxlen=maxlen) 46 | 47 | with open('data/dev_parsed.json') as f: 48 | samples = json.load(f) 49 | print('Done!') 50 | 51 | print('Running predicting model...', end='') 52 | predictions = predicting_model.predict_generator(generator=dev_data_gen, 53 | steps=dev_data_gen.steps(), 54 | verbose=1) 55 | print('Done!') 56 | 57 | 
print('Initiating CoreNLP service connection... ', end='') 58 | tokenize = CoreNLP_tokenizer() 59 | print('Done!') 60 | 61 | print('Preparing prediction file...', end='') 62 | contexts = [sample['context'] for sample in samples] 63 | 64 | answers = {} 65 | for sample, context, start, end in tqdm(zip(samples, contexts, *predictions)): 66 | id = sample['id'] 67 | context_tokens, _ = tokenize(context) 68 | answer = ' '.join(context_tokens[start : end+1]) 69 | answers[id] = answer 70 | print('Done!') 71 | 72 | print('Writing predictions to file {}...'.format(args.prediction), end='') 73 | with open(args.prediction, 'w') as f: 74 | json.dump(answers, f) 75 | print('Done!') 76 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import json 7 | import os 8 | import argparse 9 | import cPickle as pickle 10 | 11 | from os import path 12 | from gensim.scripts.glove2word2vec import glove2word2vec 13 | from tqdm import tqdm 14 | from unidecode import unidecode 15 | 16 | from utils import CoreNLP_path, get_glove_file_path 17 | from stanford_corenlp_pywrapper import CoreNLP 18 | from gensim.models import KeyedVectors 19 | from keras.preprocessing.sequence import pad_sequences 20 | 21 | 22 | def CoreNLP_tokenizer(): 23 | proc = CoreNLP(configdict={'annotators': 'tokenize,ssplit'}, 24 | corenlp_jars=[path.join(CoreNLP_path(), '*')]) 25 | 26 | def tokenize_context(context): 27 | parsed = proc.parse_doc(context) 28 | tokens = [] 29 | char_offsets = [] 30 | for sentence in parsed['sentences']: 31 | tokens += sentence['tokens'] 32 | char_offsets += sentence['char_offsets'] 33 | 34 | return tokens, char_offsets 35 | 36 | return tokenize_context 37 | 38 | 39 | def word2vec(word2vec_path): 40 | # Download word2vec data if it's not present yet 41 | if not path.exists(word2vec_path): 42 | glove_file_path = get_glove_file_path() 43 | print('Converting Glove to word2vec...', end='') 44 | glove2word2vec(glove_file_path, word2vec_path) # Convert glove to word2vec 45 | os.remove(glove_file_path) # Remove glove file and keep only word2vec 46 | print('Done') 47 | 48 | print('Reading word2vec data... ', end='') 49 | model = KeyedVectors.load_word2vec_format(word2vec_path) 50 | print('Done') 51 | 52 | def get_word_vector(word): 53 | try: 54 | return model[word] 55 | except KeyError: 56 | return np.zeros(model.vector_size) 57 | 58 | return get_word_vector 59 | 60 | 61 | if __name__ == '__main__': 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--word2vec_path', type=str, 64 | default='data/word2vec_from_glove_300.vec', 65 | help='Word2Vec vectors file path') 66 | parser.add_argument('--outfile', type=str, default='data/tmp.pkl', 67 | help='Desired path to output pickle') 68 | parser.add_argument('--include_str', action='store_true', 69 | help='Include strings') 70 | parser.add_argument('data', type=str, help='Data json') 71 | args = parser.parse_args() 72 | 73 | if not args.outfile.endswith('.pkl'): 74 | args.outfile += '.pkl' 75 | 76 | print('Reading SQuAD data... ', end='') 77 | with open(args.data) as fd: 78 | samples = json.load(fd) 79 | print('Done!') 80 | 81 | print('Initiating CoreNLP service connection... 
', end='') 82 | tokenize = CoreNLP_tokenizer() 83 | print('Done!') 84 | 85 | word_vector = word2vec(args.word2vec_path) 86 | 87 | def parse_sample(context, question, answer_start, answer_end, **kwargs): 88 | inputs = [] 89 | targets = [] 90 | 91 | tokens, char_offsets = tokenize(context) 92 | try: 93 | answer_start = [s <= answer_start < e 94 | for s, e in char_offsets].index(True) 95 | targets.append(answer_start) 96 | answer_end = [s <= answer_end < e 97 | for s, e in char_offsets].index(True) 98 | targets.append(answer_end) 99 | except ValueError: 100 | return None 101 | 102 | tokens = [unidecode(token) for token in tokens] 103 | 104 | context_vecs = [word_vector(token) for token in tokens] 105 | context_vecs = np.vstack(context_vecs).astype(np.float32) 106 | inputs.append(context_vecs) 107 | 108 | if args.include_str: 109 | context_str = [np.fromstring(token, dtype=np.uint8).astype(np.int32) 110 | for token in tokens] 111 | context_str = pad_sequences(context_str, maxlen=25) 112 | inputs.append(context_str) 113 | 114 | tokens, char_offsets = tokenize(question) 115 | tokens = [unidecode(token) for token in tokens] 116 | 117 | question_vecs = [word_vector(token) for token in tokens] 118 | question_vecs = np.vstack(question_vecs).astype(np.float32) 119 | inputs.append(question_vecs) 120 | 121 | if args.include_str: 122 | question_str = [np.fromstring(token, dtype=np.uint8).astype(np.int32) 123 | for token in tokens] 124 | question_str = pad_sequences(question_str, maxlen=25) 125 | inputs.append(question_str) 126 | 127 | return [inputs, targets] 128 | 129 | print('Parsing samples... ', end='') 130 | samples = [parse_sample(**sample) for sample in tqdm(samples)] 131 | samples = [sample for sample in samples if sample is not None] 132 | print('Done!') 133 | 134 | # Transpose 135 | def transpose(x): 136 | return map(list, zip(*x)) 137 | 138 | data = [transpose(input) for input in transpose(samples)] 139 | 140 | 141 | print('Writing to file {}... 
'.format(args.outfile), end='') 142 | with open(args.outfile, 'wb') as fd: 143 | pickle.dump(data, fd, protocol=pickle.HIGHEST_PROTOCOL) 144 | print('Done!') 145 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import argparse 7 | 8 | import keras 9 | from keras.callbacks import ModelCheckpoint 10 | 11 | from model import RNet 12 | from data import BatchGen, load_dataset 13 | 14 | import sys 15 | sys.setrecursionlimit(100000) 16 | 17 | np.random.seed(10) 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--hdim', default=75, help='Model to evaluate', type=int) 21 | parser.add_argument('--batch_size', default=70, help='Batch size', type=int) 22 | parser.add_argument('--nb_epochs', default=50, help='Number of Epochs', type=int) 23 | parser.add_argument('--optimizer', default='Adadelta', help='Optimizer', type=str) 24 | parser.add_argument('--lr', default=None, help='Learning rate', type=float) 25 | parser.add_argument('--name', default='', help='Model dump name prefix', type=str) 26 | parser.add_argument('--loss', default='categorical_crossentropy', help='Loss', type=str) 27 | 28 | parser.add_argument('--dropout', default=0, type=float) 29 | parser.add_argument('--char_level_embeddings', action='store_true') 30 | 31 | parser.add_argument('--train_data', default='data/train_data.pkl', help='Train Set', type=str) 32 | parser.add_argument('--valid_data', default='data/valid_data.pkl', help='Validation Set', type=str) 33 | 34 | # parser.add_argument('model', help='Model to evaluate', type=str) 35 | args = parser.parse_args() 36 | 37 | print('Creating the model...', end='') 38 | model = RNet(hdim=args.hdim, dropout_rate=args.dropout, N=None, M=None, 39 | char_level_embeddings=args.char_level_embeddings) 40 | print('Done!') 41 | 42 | print('Compiling Keras model...', end='') 43 | optimizer_config = {'class_name': args.optimizer, 44 | 'config': {'lr': args.lr} if args.lr else {}} 45 | model.compile(optimizer=optimizer_config, 46 | loss=args.loss, 47 | metrics=['accuracy']) 48 | print('Done!') 49 | 50 | print('Loading datasets...', end='') 51 | train_data = load_dataset(args.train_data) 52 | valid_data = load_dataset(args.valid_data) 53 | print('Done!') 54 | 55 | print('Preparing generators...', end='') 56 | maxlen = [300, 300, 30, 30] if args.char_level_embeddings else [300, 30] 57 | 58 | train_data_gen = BatchGen(*train_data, batch_size=args.batch_size, shuffle=False, group=True, maxlen=maxlen) 59 | valid_data_gen = BatchGen(*valid_data, batch_size=args.batch_size, shuffle=False, group=True, maxlen=maxlen) 60 | print('Done!') 61 | 62 | print('Training...', end='') 63 | 64 | path = 'models/' + args.name + '{epoch}-t{loss}-v{val_loss}.model' 65 | 66 | model.fit_generator(generator=train_data_gen, 67 | steps_per_epoch=train_data_gen.steps(), 68 | validation_data=valid_data_gen, 69 | validation_steps=valid_data_gen.steps(), 70 | epochs=args.nb_epochs, 71 | callbacks=[ 72 | ModelCheckpoint(path, verbose=1, save_best_only=True) 73 | ]) 74 | print('Done!') 75 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # from __future__ import absolute_import 2 | from __future__ import 
print_function 3 | from __future__ import division 4 | 5 | import os 6 | from os import path 7 | from keras.utils.data_utils import get_file 8 | 9 | def custom_objects(): 10 | from layers import * 11 | from model import * 12 | return locals() 13 | 14 | def CoreNLP_path(): 15 | SERVER = 'http://nlp.stanford.edu/software/' 16 | VERSION = 'stanford-corenlp-full-2017-06-09' 17 | 18 | origin = '{server}{version}.zip'.format(server=SERVER, version=VERSION) 19 | lib_dir = path.join(path.abspath(path.dirname(__file__)), 'lib') 20 | 21 | get_file('/tmp/stanford-corenlp.zip', 22 | origin=origin, 23 | cache_dir=lib_dir, 24 | cache_subdir='', 25 | extract=True) 26 | 27 | return path.join(lib_dir, VERSION) 28 | 29 | 30 | def get_glove_file_path(): 31 | SERVER = 'http://nlp.stanford.edu/data/' 32 | VERSION = 'glove.840B.300d' 33 | 34 | origin = '{server}{version}.zip'.format(server=SERVER, version=VERSION) 35 | cache_dir = path.join(path.abspath(path.dirname(__file__)), 'data') 36 | 37 | fname = '/tmp/glove.zip' 38 | get_file(fname, 39 | origin=origin, 40 | cache_dir=cache_dir, 41 | cache_subdir='', 42 | extract=True) 43 | 44 | # Remove unnecessary .zip file and keep only extracted .txt version 45 | os.remove(fname) 46 | return path.join(cache_dir, VERSION) + '.txt' 47 | --------------------------------------------------------------------------------
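The snippet below is a minimal, untested sketch of how the pieces above fit together at inference time, mirroring what `predict.py` already does: `utils.custom_objects()` collects the custom layers so that `keras.models.load_model` can rebuild a saved R-NET, and `Argmax` turns the two probability outputs into start/end token indices. The checkpoint path and batch size are placeholders, not values shipped with the repository.

```python
# Minimal inference sketch (assumes a checkpoint written by train.py and the
# preprocessed dev set produced by the README instructions; paths are placeholders).
from keras.models import Model, load_model

from data import BatchGen, load_dataset
from layers import Argmax
from utils import custom_objects

# custom_objects() exposes QuestionAttnGRU, SelfAttnGRU, PointerGRU, SharedWeightLayer, etc.,
# which Keras needs in order to deserialize the saved model graph.
model = load_model('models/example-checkpoint.model', custom_objects())

# Wrap each softmax output with Argmax to obtain start/end token indices directly.
predicting_model = Model(model.inputs, [Argmax()(output) for output in model.outputs])

dev_data = load_dataset('data/dev_data_str.pkl')
dev_gen = BatchGen(*dev_data, batch_size=70, shuffle=False, group=False,
                   maxlen=[300, 300, 30, 30])  # four inputs because --include_str was used
starts, ends = predicting_model.predict_generator(generator=dev_gen,
                                                  steps=dev_gen.steps(),
                                                  verbose=1)
```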