├── LICENSE
├── README.md
├── config.py
├── download.sh
├── evaluate-v1.1.py
├── func.py
├── img
│   ├── em.jpg
│   └── f1.jpg
├── inference.py
├── main.py
├── model.py
├── prepro.py
└── util.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 HKUST-KnowComp
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # R-Net
2 | * A TensorFlow implementation of [R-NET: MACHINE READING COMPREHENSION WITH SELF-MATCHING NETWORKS](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf). This project is designed specifically for the [SQuAD](https://arxiv.org/pdf/1606.05250.pdf) dataset.
3 | * Should you have any questions, please contact Wenxuan Zhou (wzhouad@connect.ust.hk).
4 |
5 | ## Requirements
6 |
7 | Many of the reported problems have been caused by mismatched software versions. Please check your versions before opening an issue or emailing me.
8 |
9 | #### General
10 | * Python >= 3.4
11 | * unzip, wget
12 | #### Python Packages
13 | * tensorflow-gpu >= 1.5.0
14 | * spaCy >= 2.0.0
15 | * tqdm
16 | * ujson
17 |
18 | ## Usage
19 |
20 | To download and preprocess the data, run
21 |
22 | ```bash
23 | # download SQuAD and GloVe
24 | sh download.sh
25 | # preprocess the data
26 | python config.py --mode prepro
27 | ```
28 |
29 | Hyperparameters are stored in `config.py`. To debug/train/test the model, run
30 |
31 | ```bash
32 | python config.py --mode debug/train/test
33 | ```
34 |
35 | To get the official score, run
36 |
37 | ```bash
38 | python evaluate-v1.1.py ~/data/squad/dev-v1.1.json log/answer/answer.json
39 | ```
40 | The script prints a JSON object with the `exact_match` and `f1` scores.
41 | The default directory for TensorBoard log files is `log/event`.
42 |
43 | See the releases page for a trained model.
44 |
45 | ## Detailed Implementation
46 |
47 | * The original paper uses additive attention, which consumes lots of memory. This project adopts the scaled multiplicative attention presented in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
48 | * This project adopts the variational dropout presented in [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](https://arxiv.org/abs/1512.05287).
49 | * To solve the degradation problem in stacked RNNs, the outputs of each layer are concatenated to produce the final output.
50 | * When the loss on the dev set fails to improve for `patience` consecutive evaluations, the learning rate is halved.
51 | * During prediction, this project adopts the search method presented in [Machine Comprehension Using Match-LSTM and Answer Pointer](https://arxiv.org/abs/1608.07905); a sketch is given at the end of this README.
52 | * To address efficiency issues, this implementation uses the bucketing method (contributed by xiongyifan) and CudnnGRU. Bucketing can speed up training, but lowers the F1 score by about 0.3%.
53 |
54 | ## Performance
55 |
56 | #### Score
57 |
58 | ||EM|F1|
59 | |---|---|---|
60 | |original paper|71.1|79.5|
61 | |this project|71.07|79.51|
62 |
63 | ![Exact Match](img/em.jpg)
64 |
65 | ![F1](img/f1.jpg)
66 |
67 | #### Training Time (s/it)
68 |
69 | ||Native|Native + Bucket|Cudnn|Cudnn + Bucket|
70 | |---|---|---|---|---|
71 | |E5-2640|6.21|3.56|-|-|
72 | |TITAN X|2.56|1.31|0.41|0.28|
73 |
74 | ## Extensions
75 |
76 | These settings may increase the score but are not used in the model by default. You can turn them on in `config.py`.
77 |
78 | * [Pretrained GloVe character embedding](https://github.com/minimaxir/char-embeddings). Contributed by yanghanxy.
79 | * [fastText embedding](https://fasttext.cc/docs/en/english-vectors.html). Contributed by xiongyifan. May increase the F1 by 1% (reported by xiongyifan).
80 |
81 |
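
#### Prediction span search (sketch)

At prediction time the answer span is chosen to maximize P(start = i) * P(end = j) subject to i <= j <= i + 15; `model.py` implements this by masking the outer product of the start and end distributions with `tf.matrix_band_part`. The NumPy sketch below illustrates the same computation; it is not code from this repository, and the name `best_span` is made up for the example.

```python
import numpy as np

def best_span(p_start, p_end, max_answer_tokens=15):
    """Return (i, j) maximizing p_start[i] * p_end[j] with i <= j <= i + max_answer_tokens."""
    outer = np.outer(p_start, p_end)             # outer[i, j] = P(start=i) * P(end=j)
    outer = np.triu(outer)                       # keep spans with j >= i
    outer = np.tril(outer, k=max_answer_tokens)  # keep spans with j - i <= max_answer_tokens
    return np.unravel_index(outer.argmax(), outer.shape)
```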
flags.DEFINE_string("glove_word_file", glove_word_file, "") 52 | 53 | flags.DEFINE_string("train_record_file", train_record_file, "") 54 | flags.DEFINE_string("dev_record_file", dev_record_file, "") 55 | flags.DEFINE_string("test_record_file", test_record_file, "") 56 | flags.DEFINE_string("word_emb_file", word_emb_file, "") 57 | flags.DEFINE_string("char_emb_file", char_emb_file, "") 58 | flags.DEFINE_string("train_eval_file", train_eval, "") 59 | flags.DEFINE_string("dev_eval_file", dev_eval, "") 60 | flags.DEFINE_string("test_eval_file", test_eval, "") 61 | flags.DEFINE_string("dev_meta", dev_meta, "") 62 | flags.DEFINE_string("test_meta", test_meta, "") 63 | flags.DEFINE_string("word2idx_file", word2idx_file, "") 64 | flags.DEFINE_string("char2idx_file", char2idx_file, "") 65 | flags.DEFINE_string("answer_file", answer_file, "") 66 | 67 | 68 | flags.DEFINE_integer("glove_char_size", 94, "Corpus size for Glove") 69 | flags.DEFINE_integer("glove_word_size", int(2.2e6), "Corpus size for Glove") 70 | flags.DEFINE_integer("glove_dim", 300, "Embedding dimension for Glove") 71 | flags.DEFINE_integer("char_dim", 8, "Embedding dimension for char") 72 | 73 | flags.DEFINE_integer("para_limit", 400, "Limit length for paragraph") 74 | flags.DEFINE_integer("ques_limit", 50, "Limit length for question") 75 | flags.DEFINE_integer("test_para_limit", 1000, 76 | "Max length for paragraph in test") 77 | flags.DEFINE_integer("test_ques_limit", 100, "Max length of questions in test") 78 | flags.DEFINE_integer("char_limit", 16, "Limit length for character") 79 | flags.DEFINE_integer("word_count_limit", -1, "Min count for word") 80 | flags.DEFINE_integer("char_count_limit", -1, "Min count for char") 81 | 82 | flags.DEFINE_integer("capacity", 15000, "Batch size of dataset shuffle") 83 | flags.DEFINE_integer("num_threads", 4, "Number of threads in input pipeline") 84 | flags.DEFINE_boolean("use_cudnn", True, "Whether to use cudnn (only for GPU)") 85 | flags.DEFINE_boolean("is_bucket", False, "Whether to use bucketing") 86 | flags.DEFINE_list("bucket_range", [40, 361, 40], "range of bucket") 87 | 88 | flags.DEFINE_integer("batch_size", 64, "Batch size") 89 | flags.DEFINE_integer("num_steps", 60000, "Number of steps") 90 | flags.DEFINE_integer("checkpoint", 1000, "checkpoint for evaluation") 91 | flags.DEFINE_integer("period", 100, "period to save batch loss") 92 | flags.DEFINE_integer("val_num_batches", 150, "Num of batches for evaluation") 93 | flags.DEFINE_float("init_lr", 0.5, "Initial lr for Adadelta") 94 | flags.DEFINE_float("keep_prob", 0.7, "Keep prob in rnn") 95 | flags.DEFINE_float("ptr_keep_prob", 0.7, "Keep prob for pointer network") 96 | flags.DEFINE_float("grad_clip", 5.0, "Global Norm gradient clipping rate") 97 | flags.DEFINE_integer("hidden", 75, "Hidden size") 98 | flags.DEFINE_integer("char_hidden", 100, "GRU dim for char") 99 | flags.DEFINE_integer("patience", 3, "Patience for lr decay") 100 | 101 | # Extensions (Uncomment corresponding line in download.sh to download the required data) 102 | glove_char_file = os.path.join( 103 | home, "data", "glove", "glove.840B.300d-char.txt") 104 | flags.DEFINE_string("glove_char_file", glove_char_file, 105 | "Glove character embedding") 106 | flags.DEFINE_boolean("pretrained_char", False, 107 | "Whether to use pretrained char embedding") 108 | 109 | fasttext_file = os.path.join(home, "data", "fasttext", "wiki-news-300d-1M.vec") 110 | flags.DEFINE_string("fasttext_file", fasttext_file, "Fasttext word embedding") 111 | flags.DEFINE_boolean("fasttext", 
False, "Whether to use fasttext") 112 | 113 | 114 | def main(_): 115 | config = flags.FLAGS 116 | if config.mode == "train": 117 | train(config) 118 | elif config.mode == "prepro": 119 | prepro(config) 120 | elif config.mode == "debug": 121 | config.num_steps = 2 122 | config.val_num_batches = 1 123 | config.checkpoint = 1 124 | config.period = 1 125 | train(config) 126 | elif config.mode == "test": 127 | test(config) 128 | else: 129 | print("Unknown mode") 130 | exit(0) 131 | 132 | 133 | if __name__ == "__main__": 134 | tf.app.run() 135 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Download SQuAD 4 | SQUAD_DIR=~/data/squad 5 | mkdir -p $SQUAD_DIR 6 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $SQUAD_DIR/train-v1.1.json 7 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $SQUAD_DIR/dev-v1.1.json 8 | 9 | # Download GloVe 10 | GLOVE_DIR=~/data/glove 11 | mkdir -p $GLOVE_DIR 12 | wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O $GLOVE_DIR/glove.840B.300d.zip 13 | unzip $GLOVE_DIR/glove.840B.300d.zip -d $GLOVE_DIR 14 | 15 | # Download Glove Character Embedding 16 | # wget https://raw.githubusercontent.com/minimaxir/char-embeddings/master/glove.840B.300d-char.txt -O $GLOVE_DIR/glove.840B.300d-char.txt 17 | 18 | # Download fasttext 19 | # FASTTEXT_DIR=~/data/fasttext 20 | # mkdir -p $FASTTEXT_DIR 21 | # wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki-news-300d-1M.vec.zip -O $FASTTEXT_DIR/wiki-news-300d-1M.vec.zip 22 | # unzip $FASTTEXT_DIR/wiki-news-300d-1M.vec.zip -d $FASTTEXT_DIR 23 | 24 | # Download Spacy language models 25 | python3 -m spacy download en 26 | -------------------------------------------------------------------------------- /evaluate-v1.1.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. 
""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | f1 = exact_match = total = 0 56 | for article in dataset: 57 | for paragraph in article['paragraphs']: 58 | for qa in paragraph['qas']: 59 | total += 1 60 | if qa['id'] not in predictions: 61 | message = 'Unanswered question ' + qa['id'] + \ 62 | ' will receive score 0.' 
63 | print(message, file=sys.stderr) 64 | continue 65 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 66 | prediction = predictions[qa['id']] 67 | exact_match += metric_max_over_ground_truths( 68 | exact_match_score, prediction, ground_truths) 69 | f1 += metric_max_over_ground_truths( 70 | f1_score, prediction, ground_truths) 71 | 72 | exact_match = 100.0 * exact_match / total 73 | f1 = 100.0 * f1 / total 74 | 75 | return {'exact_match': exact_match, 'f1': f1} 76 | 77 | 78 | if __name__ == '__main__': 79 | expected_version = '1.1' 80 | parser = argparse.ArgumentParser( 81 | description='Evaluation for SQuAD ' + expected_version) 82 | parser.add_argument('dataset_file', help='Dataset file') 83 | parser.add_argument('prediction_file', help='Prediction File') 84 | args = parser.parse_args() 85 | with open(args.dataset_file) as dataset_file: 86 | dataset_json = json.load(dataset_file) 87 | if (dataset_json['version'] != expected_version): 88 | print('Evaluation expects v-' + expected_version + 89 | ', but got dataset with v-' + dataset_json['version'], 90 | file=sys.stderr) 91 | dataset = dataset_json['data'] 92 | with open(args.prediction_file) as prediction_file: 93 | predictions = json.load(prediction_file) 94 | print(json.dumps(evaluate(dataset, predictions))) 95 | -------------------------------------------------------------------------------- /func.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | INF = 1e30 4 | 5 | 6 | class cudnn_gru: 7 | 8 | def __init__(self, num_layers, num_units, batch_size, input_size, keep_prob=1.0, is_train=None, scope=None): 9 | self.num_layers = num_layers 10 | self.grus = [] 11 | self.inits = [] 12 | self.dropout_mask = [] 13 | for layer in range(num_layers): 14 | input_size_ = input_size if layer == 0 else 2 * num_units 15 | gru_fw = tf.contrib.cudnn_rnn.CudnnGRU(1, num_units) 16 | gru_bw = tf.contrib.cudnn_rnn.CudnnGRU(1, num_units) 17 | init_fw = tf.tile(tf.Variable( 18 | tf.zeros([1, 1, num_units])), [1, batch_size, 1]) 19 | init_bw = tf.tile(tf.Variable( 20 | tf.zeros([1, 1, num_units])), [1, batch_size, 1]) 21 | mask_fw = dropout(tf.ones([1, batch_size, input_size_], dtype=tf.float32), 22 | keep_prob=keep_prob, is_train=is_train, mode=None) 23 | mask_bw = dropout(tf.ones([1, batch_size, input_size_], dtype=tf.float32), 24 | keep_prob=keep_prob, is_train=is_train, mode=None) 25 | self.grus.append((gru_fw, gru_bw, )) 26 | self.inits.append((init_fw, init_bw, )) 27 | self.dropout_mask.append((mask_fw, mask_bw, )) 28 | 29 | def __call__(self, inputs, seq_len, keep_prob=1.0, is_train=None, concat_layers=True): 30 | outputs = [tf.transpose(inputs, [1, 0, 2])] 31 | for layer in range(self.num_layers): 32 | gru_fw, gru_bw = self.grus[layer] 33 | init_fw, init_bw = self.inits[layer] 34 | mask_fw, mask_bw = self.dropout_mask[layer] 35 | with tf.variable_scope("fw_{}".format(layer)): 36 | out_fw, _ = gru_fw( 37 | outputs[-1] * mask_fw, initial_state=(init_fw, )) 38 | with tf.variable_scope("bw_{}".format(layer)): 39 | inputs_bw = tf.reverse_sequence( 40 | outputs[-1] * mask_bw, seq_lengths=seq_len, seq_dim=0, batch_dim=1) 41 | out_bw, _ = gru_bw(inputs_bw, initial_state=(init_bw, )) 42 | out_bw = tf.reverse_sequence( 43 | out_bw, seq_lengths=seq_len, seq_dim=0, batch_dim=1) 44 | outputs.append(tf.concat([out_fw, out_bw], axis=2)) 45 | if concat_layers: 46 | res = tf.concat(outputs[1:], axis=2) 47 | else: 48 | res = outputs[-1] 49 | res = tf.transpose(res, [1, 0, 2]) 50 | return 
res 51 | 52 | 53 | class native_gru: 54 | 55 | def __init__(self, num_layers, num_units, batch_size, input_size, keep_prob=1.0, is_train=None, scope="native_gru"): 56 | self.num_layers = num_layers 57 | self.grus = [] 58 | self.inits = [] 59 | self.dropout_mask = [] 60 | self.scope = scope 61 | for layer in range(num_layers): 62 | input_size_ = input_size if layer == 0 else 2 * num_units 63 | gru_fw = tf.contrib.rnn.GRUCell(num_units) 64 | gru_bw = tf.contrib.rnn.GRUCell(num_units) 65 | init_fw = tf.tile(tf.Variable( 66 | tf.zeros([1, num_units])), [batch_size, 1]) 67 | init_bw = tf.tile(tf.Variable( 68 | tf.zeros([1, num_units])), [batch_size, 1]) 69 | mask_fw = dropout(tf.ones([batch_size, 1, input_size_], dtype=tf.float32), 70 | keep_prob=keep_prob, is_train=is_train, mode=None) 71 | mask_bw = dropout(tf.ones([batch_size, 1, input_size_], dtype=tf.float32), 72 | keep_prob=keep_prob, is_train=is_train, mode=None) 73 | self.grus.append((gru_fw, gru_bw, )) 74 | self.inits.append((init_fw, init_bw, )) 75 | self.dropout_mask.append((mask_fw, mask_bw, )) 76 | 77 | def __call__(self, inputs, seq_len, keep_prob=1.0, is_train=None, concat_layers=True): 78 | outputs = [inputs] 79 | with tf.variable_scope(self.scope): 80 | for layer in range(self.num_layers): 81 | gru_fw, gru_bw = self.grus[layer] 82 | init_fw, init_bw = self.inits[layer] 83 | mask_fw, mask_bw = self.dropout_mask[layer] 84 | with tf.variable_scope("fw_{}".format(layer)): 85 | out_fw, _ = tf.nn.dynamic_rnn( 86 | gru_fw, outputs[-1] * mask_fw, seq_len, initial_state=init_fw, dtype=tf.float32) 87 | with tf.variable_scope("bw_{}".format(layer)): 88 | inputs_bw = tf.reverse_sequence( 89 | outputs[-1] * mask_bw, seq_lengths=seq_len, seq_dim=1, batch_dim=0) 90 | out_bw, _ = tf.nn.dynamic_rnn( 91 | gru_bw, inputs_bw, seq_len, initial_state=init_bw, dtype=tf.float32) 92 | out_bw = tf.reverse_sequence( 93 | out_bw, seq_lengths=seq_len, seq_dim=1, batch_dim=0) 94 | outputs.append(tf.concat([out_fw, out_bw], axis=2)) 95 | if concat_layers: 96 | res = tf.concat(outputs[1:], axis=2) 97 | else: 98 | res = outputs[-1] 99 | return res 100 | 101 | 102 | class ptr_net: 103 | def __init__(self, batch, hidden, keep_prob=1.0, is_train=None, scope="ptr_net"): 104 | self.gru = tf.contrib.rnn.GRUCell(hidden) 105 | self.batch = batch 106 | self.scope = scope 107 | self.keep_prob = keep_prob 108 | self.is_train = is_train 109 | self.dropout_mask = dropout(tf.ones( 110 | [batch, hidden], dtype=tf.float32), keep_prob=keep_prob, is_train=is_train) 111 | 112 | def __call__(self, init, match, d, mask): 113 | with tf.variable_scope(self.scope): 114 | d_match = dropout(match, keep_prob=self.keep_prob, 115 | is_train=self.is_train) 116 | inp, logits1 = pointer(d_match, init * self.dropout_mask, d, mask) 117 | d_inp = dropout(inp, keep_prob=self.keep_prob, 118 | is_train=self.is_train) 119 | _, state = self.gru(d_inp, init) 120 | tf.get_variable_scope().reuse_variables() 121 | _, logits2 = pointer(d_match, state * self.dropout_mask, d, mask) 122 | return logits1, logits2 123 | 124 | 125 | def dropout(args, keep_prob, is_train, mode="recurrent"): 126 | if keep_prob < 1.0: 127 | noise_shape = None 128 | scale = 1.0 129 | shape = tf.shape(args) 130 | if mode == "embedding": 131 | noise_shape = [shape[0], 1] 132 | scale = keep_prob 133 | if mode == "recurrent" and len(args.get_shape().as_list()) == 3: 134 | noise_shape = [shape[0], 1, shape[-1]] 135 | args = tf.cond(is_train, lambda: tf.nn.dropout( 136 | args, keep_prob, noise_shape=noise_shape) * scale, lambda: args) 
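    # Note on the modes above: "recurrent" with 3-D input uses noise_shape
    # [batch, 1, dim], so a single dropout mask is sampled per sequence and
    # shared across all timesteps (the variational dropout cited in the README);
    # "embedding" drops whole rows and multiplies by scale=keep_prob to undo
    # the 1/keep_prob rescaling applied by tf.nn.dropout.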
137 | return args 138 | 139 | 140 | def softmax_mask(val, mask): 141 | return -INF * (1 - tf.cast(mask, tf.float32)) + val 142 | 143 | 144 | def pointer(inputs, state, hidden, mask, scope="pointer"): 145 | with tf.variable_scope(scope): 146 | u = tf.concat([tf.tile(tf.expand_dims(state, axis=1), [ 147 | 1, tf.shape(inputs)[1], 1]), inputs], axis=2) 148 | s0 = tf.nn.tanh(dense(u, hidden, use_bias=False, scope="s0")) 149 | s = dense(s0, 1, use_bias=False, scope="s") 150 | s1 = softmax_mask(tf.squeeze(s, [2]), mask) 151 | a = tf.expand_dims(tf.nn.softmax(s1), axis=2) 152 | res = tf.reduce_sum(a * inputs, axis=1) 153 | return res, s1 154 | 155 | 156 | def summ(memory, hidden, mask, keep_prob=1.0, is_train=None, scope="summ"): 157 | with tf.variable_scope(scope): 158 | d_memory = dropout(memory, keep_prob=keep_prob, is_train=is_train) 159 | s0 = tf.nn.tanh(dense(d_memory, hidden, scope="s0")) 160 | s = dense(s0, 1, use_bias=False, scope="s") 161 | s1 = softmax_mask(tf.squeeze(s, [2]), mask) 162 | a = tf.expand_dims(tf.nn.softmax(s1), axis=2) 163 | res = tf.reduce_sum(a * memory, axis=1) 164 | return res 165 | 166 | 167 | def dot_attention(inputs, memory, mask, hidden, keep_prob=1.0, is_train=None, scope="dot_attention"): 168 | with tf.variable_scope(scope): 169 | 170 | d_inputs = dropout(inputs, keep_prob=keep_prob, is_train=is_train) 171 | d_memory = dropout(memory, keep_prob=keep_prob, is_train=is_train) 172 | JX = tf.shape(inputs)[1] 173 | 174 | with tf.variable_scope("attention"): 175 | inputs_ = tf.nn.relu( 176 | dense(d_inputs, hidden, use_bias=False, scope="inputs")) 177 | memory_ = tf.nn.relu( 178 | dense(d_memory, hidden, use_bias=False, scope="memory")) 179 | outputs = tf.matmul(inputs_, tf.transpose( 180 | memory_, [0, 2, 1])) / (hidden ** 0.5) 181 | mask = tf.tile(tf.expand_dims(mask, axis=1), [1, JX, 1]) 182 | logits = tf.nn.softmax(softmax_mask(outputs, mask)) 183 | outputs = tf.matmul(logits, memory) 184 | res = tf.concat([inputs, outputs], axis=2) 185 | 186 | with tf.variable_scope("gate"): 187 | dim = res.get_shape().as_list()[-1] 188 | d_res = dropout(res, keep_prob=keep_prob, is_train=is_train) 189 | gate = tf.nn.sigmoid(dense(d_res, dim, use_bias=False)) 190 | return res * gate 191 | 192 | 193 | def dense(inputs, hidden, use_bias=True, scope="dense"): 194 | with tf.variable_scope(scope): 195 | shape = tf.shape(inputs) 196 | dim = inputs.get_shape().as_list()[-1] 197 | out_shape = [shape[idx] for idx in range( 198 | len(inputs.get_shape().as_list()) - 1)] + [hidden] 199 | flat_inputs = tf.reshape(inputs, [-1, dim]) 200 | W = tf.get_variable("W", [dim, hidden]) 201 | res = tf.matmul(flat_inputs, W) 202 | if use_bias: 203 | b = tf.get_variable( 204 | "b", [hidden], initializer=tf.constant_initializer(0.)) 205 | res = tf.nn.bias_add(res, b) 206 | res = tf.reshape(res, out_shape) 207 | return res 208 | -------------------------------------------------------------------------------- /img/em.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-KnowComp/R-Net/1efc5ef18067a84d44fa8d31c309a19c52958e92/img/em.jpg -------------------------------------------------------------------------------- /img/f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HKUST-KnowComp/R-Net/1efc5ef18067a84d44fa8d31c309a19c52958e92/img/f1.jpg -------------------------------------------------------------------------------- /inference.py: 
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import spacy
3 | import os
4 | import numpy as np
5 | import ujson as json
6 |
7 |
8 | from func import cudnn_gru, native_gru, dot_attention, summ, ptr_net
9 | from prepro import word_tokenize, convert_idx
10 |
11 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
12 |
13 | # Must be consistent with training
14 | char_limit = 16
15 | hidden = 75
16 | char_dim = 8
17 | char_hidden = 100
18 | use_cudnn = True
19 |
20 | # File paths
21 | target_dir = "data"
22 | save_dir = "log/model"
23 | word_emb_file = os.path.join(target_dir, "word_emb.json")
24 | char_emb_file = os.path.join(target_dir, "char_emb.json")
25 | word2idx_file = os.path.join(target_dir, "word2idx.json")
26 | char2idx_file = os.path.join(target_dir, "char2idx.json")
27 |
28 |
29 | class InfModel(object):
30 |     # Used to zero elements in the probability matrix that correspond to answer
31 |     # spans that are longer than the number of tokens specified here.
32 |     max_answer_tokens = 15
33 |
34 |     def __init__(self, word_mat, char_mat):
35 |         self.c = tf.placeholder(tf.int32, [1, None])
36 |         self.q = tf.placeholder(tf.int32, [1, None])
37 |         self.ch = tf.placeholder(tf.int32, [1, None, char_limit])
38 |         self.qh = tf.placeholder(tf.int32, [1, None, char_limit])
39 |         self.tokens_in_context = tf.placeholder(tf.int64)
40 |
41 |         self.word_mat = tf.get_variable("word_mat", initializer=tf.constant(
42 |             word_mat, dtype=tf.float32), trainable=False)
43 |         self.char_mat = tf.get_variable(
44 |             "char_mat", initializer=tf.constant(char_mat, dtype=tf.float32))
45 |
46 |         self.c_mask = tf.cast(self.c, tf.bool)
47 |         self.q_mask = tf.cast(self.q, tf.bool)
48 |         self.c_len = tf.reduce_sum(tf.cast(self.c_mask, tf.int32), axis=1)
49 |         self.q_len = tf.reduce_sum(tf.cast(self.q_mask, tf.int32), axis=1)
50 |
51 |         self.c_maxlen = tf.reduce_max(self.c_len)
52 |         self.q_maxlen = tf.reduce_max(self.q_len)
53 |
54 |         self.ch_len = tf.reshape(tf.reduce_sum(
55 |             tf.cast(tf.cast(self.ch, tf.bool), tf.int32), axis=2), [-1])
56 |         self.qh_len = tf.reshape(tf.reduce_sum(
57 |             tf.cast(tf.cast(self.qh, tf.bool), tf.int32), axis=2), [-1])
58 |
59 |         self.ready()
60 |
61 |     def ready(self):
62 |         N, PL, QL, CL, d, dc, dg = \
63 |             1, self.c_maxlen, self.q_maxlen, char_limit, hidden, char_dim, \
64 |             char_hidden
65 |         gru = cudnn_gru if use_cudnn else native_gru
66 |
67 |         with tf.variable_scope("emb"):
68 |             with tf.variable_scope("char"):
69 |                 ch_emb = tf.reshape(tf.nn.embedding_lookup(
70 |                     self.char_mat, self.ch), [N * PL, CL, dc])
71 |                 qh_emb = tf.reshape(tf.nn.embedding_lookup(
72 |                     self.char_mat, self.qh), [N * QL, CL, dc])
73 |                 cell_fw = tf.contrib.rnn.GRUCell(dg)
74 |                 cell_bw = tf.contrib.rnn.GRUCell(dg)
75 |                 _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
76 |                     cell_fw, cell_bw, ch_emb, self.ch_len, dtype=tf.float32)
77 |                 ch_emb = tf.concat([state_fw, state_bw], axis=1)
78 |                 _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
79 |                     cell_fw, cell_bw, qh_emb, self.qh_len, dtype=tf.float32)
80 |                 qh_emb = tf.concat([state_fw, state_bw], axis=1)
81 |                 qh_emb = tf.reshape(qh_emb, [N, QL, 2 * dg])
82 |                 ch_emb = tf.reshape(ch_emb, [N, PL, 2 * dg])
83 |
84 |             with tf.name_scope("word"):
85 |                 c_emb = tf.nn.embedding_lookup(self.word_mat, self.c)
86 |                 q_emb = tf.nn.embedding_lookup(self.word_mat, self.q)
87 |
88 |             c_emb = tf.concat([c_emb, ch_emb], axis=2)
89 |             q_emb = tf.concat([q_emb, qh_emb], axis=2)
90 |
91 |         with tf.variable_scope("encoding"):
92 |             rnn =
gru(num_layers=3, num_units=d, batch_size=N, 93 | input_size=c_emb.get_shape().as_list()[-1]) 94 | c = rnn(c_emb, seq_len=self.c_len) 95 | q = rnn(q_emb, seq_len=self.q_len) 96 | 97 | with tf.variable_scope("attention"): 98 | qc_att = dot_attention(c, q, mask=self.q_mask, hidden=d) 99 | rnn = gru(num_layers=1, num_units=d, batch_size=N, 100 | input_size=qc_att.get_shape().as_list()[-1]) 101 | att = rnn(qc_att, seq_len=self.c_len) 102 | 103 | with tf.variable_scope("match"): 104 | self_att = dot_attention(att, att, mask=self.c_mask, hidden=d) 105 | rnn = gru(num_layers=1, num_units=d, batch_size=N, 106 | input_size=self_att.get_shape().as_list()[-1]) 107 | match = rnn(self_att, seq_len=self.c_len) 108 | 109 | with tf.variable_scope("pointer"): 110 | init = summ(q[:, :, -2 * d:], d, mask=self.q_mask) 111 | pointer = ptr_net(batch=N, hidden=init.get_shape().as_list()[-1]) 112 | logits1, logits2 = pointer(init, match, d, self.c_mask) 113 | 114 | with tf.variable_scope("predict"): 115 | outer = tf.matmul(tf.expand_dims(tf.nn.softmax(logits1), axis=2), 116 | tf.expand_dims(tf.nn.softmax(logits2), axis=1)) 117 | outer = tf.cond( 118 | self.tokens_in_context < self.max_answer_tokens, 119 | lambda: tf.matrix_band_part(outer, 0, -1), 120 | lambda: tf.matrix_band_part(outer, 0, self.max_answer_tokens)) 121 | self.yp1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1) 122 | self.yp2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1) 123 | 124 | 125 | class Inference(object): 126 | 127 | def __init__(self): 128 | with open(word_emb_file, "r") as fh: 129 | self.word_mat = np.array(json.load(fh), dtype=np.float32) 130 | with open(char_emb_file, "r") as fh: 131 | self.char_mat = np.array(json.load(fh), dtype=np.float32) 132 | with open(word2idx_file, "r") as fh: 133 | self.word2idx_dict = json.load(fh) 134 | with open(char2idx_file, "r") as fh: 135 | self.char2idx_dict = json.load(fh) 136 | self.model = InfModel(self.word_mat, self.char_mat) 137 | sess_config = tf.ConfigProto(allow_soft_placement=True) 138 | sess_config.gpu_options.allow_growth = True 139 | self.sess = tf.Session(config=sess_config) 140 | saver = tf.train.Saver() 141 | saver.restore(self.sess, tf.train.latest_checkpoint(save_dir)) 142 | 143 | def response(self, context, question): 144 | sess = self.sess 145 | model = self.model 146 | span, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs = \ 147 | self.prepro(context, question) 148 | yp1, yp2 = \ 149 | sess.run( 150 | [model.yp1, model.yp2], 151 | feed_dict={ 152 | model.c: context_idxs, model.q: ques_idxs, 153 | model.ch: context_char_idxs, model.qh: ques_char_idxs, 154 | model.tokens_in_context: len(span)}) 155 | start_idx = span[yp1[0]][0] 156 | end_idx = span[yp2[0]][1] 157 | return context[start_idx: end_idx] 158 | 159 | def prepro(self, context, question): 160 | context = context.replace("''", '" ').replace("``", '" ') 161 | context_tokens = word_tokenize(context) 162 | context_chars = [list(token) for token in context_tokens] 163 | spans = convert_idx(context, context_tokens) 164 | ques = question.replace("''", '" ').replace("``", '" ') 165 | ques_tokens = word_tokenize(ques) 166 | ques_chars = [list(token) for token in ques_tokens] 167 | 168 | context_idxs = np.zeros([1, len(context_tokens)], dtype=np.int32) 169 | context_char_idxs = np.zeros( 170 | [1, len(context_tokens), char_limit], dtype=np.int32) 171 | ques_idxs = np.zeros([1, len(ques_tokens)], dtype=np.int32) 172 | ques_char_idxs = np.zeros( 173 | [1, len(ques_tokens), char_limit], dtype=np.int32) 174 | 175 | def 
_get_word(word):
176 |             for each in (word, word.lower(), word.capitalize(), word.upper()):
177 |                 if each in self.word2idx_dict:
178 |                     return self.word2idx_dict[each]
179 |             return 1
180 |
181 |         def _get_char(char):
182 |             if char in self.char2idx_dict:
183 |                 return self.char2idx_dict[char]
184 |             return 1
185 |
186 |         for i, token in enumerate(context_tokens):
187 |             context_idxs[0, i] = _get_word(token)
188 |
189 |         for i, token in enumerate(ques_tokens):
190 |             ques_idxs[0, i] = _get_word(token)
191 |
192 |         for i, token in enumerate(context_chars):
193 |             for j, char in enumerate(token):
194 |                 if j == char_limit:
195 |                     break
196 |                 context_char_idxs[0, i, j] = _get_char(char)
197 |
198 |         for i, token in enumerate(ques_chars):
199 |             for j, char in enumerate(token):
200 |                 if j == char_limit:
201 |                     break
202 |                 ques_char_idxs[0, i, j] = _get_char(char)
203 |         return spans, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs
204 |
205 |
206 | # Demo, example from the paper "SQuAD: 100,000+ Questions for Machine Comprehension of Text"
207 | if __name__ == "__main__":
208 |     infer = Inference()
209 |     context = "In meteorology, precipitation is any product of the condensation " \
210 |               "of atmospheric water vapor that falls under gravity. The main forms " \
211 |               "of precipitation include drizzle, rain, sleet, snow, graupel and hail. " \
212 |               "Precipitation forms as smaller droplets coalesce via collision with other " \
213 |               "rain drops or ice crystals within a cloud. Short, intense periods of rain " \
214 |               "in scattered locations are called “showers”."
215 |     ques1 = "What causes precipitation to fall?"
216 |     ques2 = "What is another main form of precipitation besides drizzle, rain, snow, sleet and hail?"
217 |     ques3 = "Where do water droplets collide with ice crystals to form precipitation?"
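
    # Note: Inference() builds the graph and restores the checkpoint once, so
    # response() can be called repeatedly with new context/question pairs.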
218 | 219 | # Correct: gravity, Output: drizzle, rain, sleet, snow, graupel and hail 220 | ans1 = infer.response(context, ques1) 221 | print("Answer 1: {}".format(ans1)) 222 | 223 | # Correct: graupel, Output: graupel 224 | ans2 = infer.response(context, ques2) 225 | print("Answer 2: {}".format(ans2)) 226 | 227 | # Correct: within a cloud, Output: within a cloud 228 | ans3 = infer.response(context, ques3) 229 | print("Answer 3: {}".format(ans3)) 230 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import ujson as json 3 | import numpy as np 4 | from tqdm import tqdm 5 | import os 6 | 7 | from model import Model 8 | from util import get_record_parser, convert_tokens, evaluate, get_batch_dataset, get_dataset 9 | 10 | 11 | def train(config): 12 | with open(config.word_emb_file, "r") as fh: 13 | word_mat = np.array(json.load(fh), dtype=np.float32) 14 | with open(config.char_emb_file, "r") as fh: 15 | char_mat = np.array(json.load(fh), dtype=np.float32) 16 | with open(config.train_eval_file, "r") as fh: 17 | train_eval_file = json.load(fh) 18 | with open(config.dev_eval_file, "r") as fh: 19 | dev_eval_file = json.load(fh) 20 | with open(config.dev_meta, "r") as fh: 21 | meta = json.load(fh) 22 | 23 | dev_total = meta["total"] 24 | 25 | print("Building model...") 26 | parser = get_record_parser(config) 27 | train_dataset = get_batch_dataset(config.train_record_file, parser, config) 28 | dev_dataset = get_dataset(config.dev_record_file, parser, config) 29 | handle = tf.placeholder(tf.string, shape=[]) 30 | iterator = tf.data.Iterator.from_string_handle( 31 | handle, train_dataset.output_types, train_dataset.output_shapes) 32 | train_iterator = train_dataset.make_one_shot_iterator() 33 | dev_iterator = dev_dataset.make_one_shot_iterator() 34 | 35 | model = Model(config, iterator, word_mat, char_mat) 36 | 37 | sess_config = tf.ConfigProto(allow_soft_placement=True) 38 | sess_config.gpu_options.allow_growth = True 39 | 40 | loss_save = 100.0 41 | patience = 0 42 | lr = config.init_lr 43 | 44 | with tf.Session(config=sess_config) as sess: 45 | writer = tf.summary.FileWriter(config.log_dir) 46 | sess.run(tf.global_variables_initializer()) 47 | saver = tf.train.Saver() 48 | train_handle = sess.run(train_iterator.string_handle()) 49 | dev_handle = sess.run(dev_iterator.string_handle()) 50 | sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool))) 51 | sess.run(tf.assign(model.lr, tf.constant(lr, dtype=tf.float32))) 52 | 53 | for _ in tqdm(range(1, config.num_steps + 1)): 54 | global_step = sess.run(model.global_step) + 1 55 | loss, train_op = sess.run([model.loss, model.train_op], feed_dict={ 56 | handle: train_handle}) 57 | if global_step % config.period == 0: 58 | loss_sum = tf.Summary(value=[tf.Summary.Value( 59 | tag="model/loss", simple_value=loss), ]) 60 | writer.add_summary(loss_sum, global_step) 61 | if global_step % config.checkpoint == 0: 62 | sess.run(tf.assign(model.is_train, 63 | tf.constant(False, dtype=tf.bool))) 64 | _, summ = evaluate_batch( 65 | model, config.val_num_batches, train_eval_file, sess, "train", handle, train_handle) 66 | for s in summ: 67 | writer.add_summary(s, global_step) 68 | 69 | metrics, summ = evaluate_batch( 70 | model, dev_total // config.batch_size + 1, dev_eval_file, sess, "dev", handle, dev_handle) 71 | sess.run(tf.assign(model.is_train, 72 | tf.constant(True, dtype=tf.bool))) 73 | 74 | 
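                # Learning-rate schedule described in the README: halve the
                # learning rate when the dev loss has not improved for
                # `patience` consecutive evaluations.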
dev_loss = metrics["loss"] 75 | if dev_loss < loss_save: 76 | loss_save = dev_loss 77 | patience = 0 78 | else: 79 | patience += 1 80 | if patience >= config.patience: 81 | lr /= 2.0 82 | loss_save = dev_loss 83 | patience = 0 84 | sess.run(tf.assign(model.lr, tf.constant(lr, dtype=tf.float32))) 85 | for s in summ: 86 | writer.add_summary(s, global_step) 87 | writer.flush() 88 | filename = os.path.join( 89 | config.save_dir, "model_{}.ckpt".format(global_step)) 90 | saver.save(sess, filename) 91 | 92 | 93 | def evaluate_batch(model, num_batches, eval_file, sess, data_type, handle, str_handle): 94 | answer_dict = {} 95 | losses = [] 96 | for _ in tqdm(range(1, num_batches + 1)): 97 | qa_id, loss, yp1, yp2, = sess.run( 98 | [model.qa_id, model.loss, model.yp1, model.yp2], feed_dict={handle: str_handle}) 99 | answer_dict_, _ = convert_tokens( 100 | eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist()) 101 | answer_dict.update(answer_dict_) 102 | losses.append(loss) 103 | loss = np.mean(losses) 104 | metrics = evaluate(eval_file, answer_dict) 105 | metrics["loss"] = loss 106 | loss_sum = tf.Summary(value=[tf.Summary.Value( 107 | tag="{}/loss".format(data_type), simple_value=metrics["loss"]), ]) 108 | f1_sum = tf.Summary(value=[tf.Summary.Value( 109 | tag="{}/f1".format(data_type), simple_value=metrics["f1"]), ]) 110 | em_sum = tf.Summary(value=[tf.Summary.Value( 111 | tag="{}/em".format(data_type), simple_value=metrics["exact_match"]), ]) 112 | return metrics, [loss_sum, f1_sum, em_sum] 113 | 114 | 115 | def test(config): 116 | with open(config.word_emb_file, "r") as fh: 117 | word_mat = np.array(json.load(fh), dtype=np.float32) 118 | with open(config.char_emb_file, "r") as fh: 119 | char_mat = np.array(json.load(fh), dtype=np.float32) 120 | with open(config.test_eval_file, "r") as fh: 121 | eval_file = json.load(fh) 122 | with open(config.test_meta, "r") as fh: 123 | meta = json.load(fh) 124 | 125 | total = meta["total"] 126 | 127 | print("Loading model...") 128 | test_batch = get_dataset(config.test_record_file, get_record_parser( 129 | config, is_test=True), config).make_one_shot_iterator() 130 | 131 | model = Model(config, test_batch, word_mat, char_mat, trainable=False) 132 | 133 | sess_config = tf.ConfigProto(allow_soft_placement=True) 134 | sess_config.gpu_options.allow_growth = True 135 | 136 | with tf.Session(config=sess_config) as sess: 137 | sess.run(tf.global_variables_initializer()) 138 | saver = tf.train.Saver() 139 | saver.restore(sess, tf.train.latest_checkpoint(config.save_dir)) 140 | sess.run(tf.assign(model.is_train, tf.constant(False, dtype=tf.bool))) 141 | losses = [] 142 | answer_dict = {} 143 | remapped_dict = {} 144 | for step in tqdm(range(total // config.batch_size + 1)): 145 | qa_id, loss, yp1, yp2 = sess.run( 146 | [model.qa_id, model.loss, model.yp1, model.yp2]) 147 | answer_dict_, remapped_dict_ = convert_tokens( 148 | eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist()) 149 | answer_dict.update(answer_dict_) 150 | remapped_dict.update(remapped_dict_) 151 | losses.append(loss) 152 | loss = np.mean(losses) 153 | metrics = evaluate(eval_file, answer_dict) 154 | with open(config.answer_file, "w") as fh: 155 | json.dump(remapped_dict, fh) 156 | print("Exact Match: {}, F1: {}".format( 157 | metrics['exact_match'], metrics['f1'])) 158 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from func import cudnn_gru, 
native_gru, dot_attention, summ, dropout, ptr_net 3 | 4 | 5 | class Model(object): 6 | def __init__(self, config, batch, word_mat=None, char_mat=None, trainable=True, opt=True): 7 | self.config = config 8 | self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, 9 | initializer=tf.constant_initializer(0), trainable=False) 10 | self.c, self.q, self.ch, self.qh, self.y1, self.y2, self.qa_id = batch.get_next() 11 | self.is_train = tf.get_variable( 12 | "is_train", shape=[], dtype=tf.bool, trainable=False) 13 | self.word_mat = tf.get_variable("word_mat", initializer=tf.constant( 14 | word_mat, dtype=tf.float32), trainable=False) 15 | self.char_mat = tf.get_variable( 16 | "char_mat", initializer=tf.constant(char_mat, dtype=tf.float32)) 17 | 18 | self.c_mask = tf.cast(self.c, tf.bool) 19 | self.q_mask = tf.cast(self.q, tf.bool) 20 | self.c_len = tf.reduce_sum(tf.cast(self.c_mask, tf.int32), axis=1) 21 | self.q_len = tf.reduce_sum(tf.cast(self.q_mask, tf.int32), axis=1) 22 | 23 | if opt: 24 | N, CL = config.batch_size, config.char_limit 25 | self.c_maxlen = tf.reduce_max(self.c_len) 26 | self.q_maxlen = tf.reduce_max(self.q_len) 27 | self.c = tf.slice(self.c, [0, 0], [N, self.c_maxlen]) 28 | self.q = tf.slice(self.q, [0, 0], [N, self.q_maxlen]) 29 | self.c_mask = tf.slice(self.c_mask, [0, 0], [N, self.c_maxlen]) 30 | self.q_mask = tf.slice(self.q_mask, [0, 0], [N, self.q_maxlen]) 31 | self.ch = tf.slice(self.ch, [0, 0, 0], [N, self.c_maxlen, CL]) 32 | self.qh = tf.slice(self.qh, [0, 0, 0], [N, self.q_maxlen, CL]) 33 | self.y1 = tf.slice(self.y1, [0, 0], [N, self.c_maxlen]) 34 | self.y2 = tf.slice(self.y2, [0, 0], [N, self.c_maxlen]) 35 | else: 36 | self.c_maxlen, self.q_maxlen = config.para_limit, config.ques_limit 37 | 38 | self.ch_len = tf.reshape(tf.reduce_sum( 39 | tf.cast(tf.cast(self.ch, tf.bool), tf.int32), axis=2), [-1]) 40 | self.qh_len = tf.reshape(tf.reduce_sum( 41 | tf.cast(tf.cast(self.qh, tf.bool), tf.int32), axis=2), [-1]) 42 | 43 | self.ready() 44 | 45 | if trainable: 46 | self.lr = tf.get_variable( 47 | "lr", shape=[], dtype=tf.float32, trainable=False) 48 | self.opt = tf.train.AdadeltaOptimizer( 49 | learning_rate=self.lr, epsilon=1e-6) 50 | grads = self.opt.compute_gradients(self.loss) 51 | gradients, variables = zip(*grads) 52 | capped_grads, _ = tf.clip_by_global_norm( 53 | gradients, config.grad_clip) 54 | self.train_op = self.opt.apply_gradients( 55 | zip(capped_grads, variables), global_step=self.global_step) 56 | 57 | def ready(self): 58 | config = self.config 59 | N, PL, QL, CL, d, dc, dg = config.batch_size, self.c_maxlen, self.q_maxlen, config.char_limit, config.hidden, config.char_dim, config.char_hidden 60 | gru = cudnn_gru if config.use_cudnn else native_gru 61 | 62 | with tf.variable_scope("emb"): 63 | with tf.variable_scope("char"): 64 | ch_emb = tf.reshape(tf.nn.embedding_lookup( 65 | self.char_mat, self.ch), [N * PL, CL, dc]) 66 | qh_emb = tf.reshape(tf.nn.embedding_lookup( 67 | self.char_mat, self.qh), [N * QL, CL, dc]) 68 | ch_emb = dropout( 69 | ch_emb, keep_prob=config.keep_prob, is_train=self.is_train) 70 | qh_emb = dropout( 71 | qh_emb, keep_prob=config.keep_prob, is_train=self.is_train) 72 | cell_fw = tf.contrib.rnn.GRUCell(dg) 73 | cell_bw = tf.contrib.rnn.GRUCell(dg) 74 | _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn( 75 | cell_fw, cell_bw, ch_emb, self.ch_len, dtype=tf.float32) 76 | ch_emb = tf.concat([state_fw, state_bw], axis=1) 77 | _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn( 78 | cell_fw, cell_bw, 
qh_emb, self.qh_len, dtype=tf.float32) 79 | qh_emb = tf.concat([state_fw, state_bw], axis=1) 80 | qh_emb = tf.reshape(qh_emb, [N, QL, 2 * dg]) 81 | ch_emb = tf.reshape(ch_emb, [N, PL, 2 * dg]) 82 | 83 | with tf.name_scope("word"): 84 | c_emb = tf.nn.embedding_lookup(self.word_mat, self.c) 85 | q_emb = tf.nn.embedding_lookup(self.word_mat, self.q) 86 | 87 | c_emb = tf.concat([c_emb, ch_emb], axis=2) 88 | q_emb = tf.concat([q_emb, qh_emb], axis=2) 89 | 90 | with tf.variable_scope("encoding"): 91 | rnn = gru(num_layers=3, num_units=d, batch_size=N, input_size=c_emb.get_shape( 92 | ).as_list()[-1], keep_prob=config.keep_prob, is_train=self.is_train) 93 | c = rnn(c_emb, seq_len=self.c_len) 94 | q = rnn(q_emb, seq_len=self.q_len) 95 | 96 | with tf.variable_scope("attention"): 97 | qc_att = dot_attention(c, q, mask=self.q_mask, hidden=d, 98 | keep_prob=config.keep_prob, is_train=self.is_train) 99 | rnn = gru(num_layers=1, num_units=d, batch_size=N, input_size=qc_att.get_shape( 100 | ).as_list()[-1], keep_prob=config.keep_prob, is_train=self.is_train) 101 | att = rnn(qc_att, seq_len=self.c_len) 102 | 103 | with tf.variable_scope("match"): 104 | self_att = dot_attention( 105 | att, att, mask=self.c_mask, hidden=d, keep_prob=config.keep_prob, is_train=self.is_train) 106 | rnn = gru(num_layers=1, num_units=d, batch_size=N, input_size=self_att.get_shape( 107 | ).as_list()[-1], keep_prob=config.keep_prob, is_train=self.is_train) 108 | match = rnn(self_att, seq_len=self.c_len) 109 | 110 | with tf.variable_scope("pointer"): 111 | init = summ(q[:, :, -2 * d:], d, mask=self.q_mask, 112 | keep_prob=config.ptr_keep_prob, is_train=self.is_train) 113 | pointer = ptr_net(batch=N, hidden=init.get_shape().as_list( 114 | )[-1], keep_prob=config.ptr_keep_prob, is_train=self.is_train) 115 | logits1, logits2 = pointer(init, match, d, self.c_mask) 116 | 117 | with tf.variable_scope("predict"): 118 | outer = tf.matmul(tf.expand_dims(tf.nn.softmax(logits1), axis=2), 119 | tf.expand_dims(tf.nn.softmax(logits2), axis=1)) 120 | outer = tf.matrix_band_part(outer, 0, 15) 121 | self.yp1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1) 122 | self.yp2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1) 123 | losses = tf.nn.softmax_cross_entropy_with_logits_v2( 124 | logits=logits1, labels=tf.stop_gradient(self.y1)) 125 | losses2 = tf.nn.softmax_cross_entropy_with_logits_v2( 126 | logits=logits2, labels=tf.stop_gradient(self.y2)) 127 | self.loss = tf.reduce_mean(losses + losses2) 128 | 129 | def get_loss(self): 130 | return self.loss 131 | 132 | def get_global_step(self): 133 | return self.global_step 134 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import random 3 | from tqdm import tqdm 4 | import spacy 5 | import ujson as json 6 | from collections import Counter 7 | import numpy as np 8 | import os.path 9 | 10 | nlp = spacy.blank("en") 11 | 12 | 13 | def word_tokenize(sent): 14 | doc = nlp(sent) 15 | return [token.text for token in doc] 16 | 17 | 18 | def convert_idx(text, tokens): 19 | current = 0 20 | spans = [] 21 | for token in tokens: 22 | current = text.find(token, current) 23 | if current < 0: 24 | print("Token {} cannot be found".format(token)) 25 | raise Exception() 26 | spans.append((current, current + len(token))) 27 | current += len(token) 28 | return spans 29 | 30 | 31 | def process_file(filename, data_type, word_counter, char_counter): 32 | 
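    # Tokenize every paragraph and question, record the character span of each
    # context token (used later to map predicted token spans back to the raw
    # text), and accumulate word/char counts for the embedding vocabularies.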
print("Generating {} examples...".format(data_type)) 33 | examples = [] 34 | eval_examples = {} 35 | total = 0 36 | with open(filename, "r") as fh: 37 | source = json.load(fh) 38 | for article in tqdm(source["data"]): 39 | for para in article["paragraphs"]: 40 | context = para["context"].replace( 41 | "''", '" ').replace("``", '" ') 42 | context_tokens = word_tokenize(context) 43 | context_chars = [list(token) for token in context_tokens] 44 | spans = convert_idx(context, context_tokens) 45 | for token in context_tokens: 46 | word_counter[token] += len(para["qas"]) 47 | for char in token: 48 | char_counter[char] += len(para["qas"]) 49 | for qa in para["qas"]: 50 | total += 1 51 | ques = qa["question"].replace( 52 | "''", '" ').replace("``", '" ') 53 | ques_tokens = word_tokenize(ques) 54 | ques_chars = [list(token) for token in ques_tokens] 55 | for token in ques_tokens: 56 | word_counter[token] += 1 57 | for char in token: 58 | char_counter[char] += 1 59 | y1s, y2s = [], [] 60 | answer_texts = [] 61 | for answer in qa["answers"]: 62 | answer_text = answer["text"] 63 | answer_start = answer['answer_start'] 64 | answer_end = answer_start + len(answer_text) 65 | answer_texts.append(answer_text) 66 | answer_span = [] 67 | for idx, span in enumerate(spans): 68 | if not (answer_end <= span[0] or answer_start >= span[1]): 69 | answer_span.append(idx) 70 | y1, y2 = answer_span[0], answer_span[-1] 71 | y1s.append(y1) 72 | y2s.append(y2) 73 | example = {"context_tokens": context_tokens, "context_chars": context_chars, "ques_tokens": ques_tokens, 74 | "ques_chars": ques_chars, "y1s": y1s, "y2s": y2s, "id": total} 75 | examples.append(example) 76 | eval_examples[str(total)] = { 77 | "context": context, "spans": spans, "answers": answer_texts, "uuid": qa["id"]} 78 | random.shuffle(examples) 79 | print("{} questions in total".format(len(examples))) 80 | return examples, eval_examples 81 | 82 | 83 | def get_embedding(counter, data_type, limit=-1, emb_file=None, size=None, vec_size=None, token2idx_dict=None): 84 | print("Generating {} embedding...".format(data_type)) 85 | embedding_dict = {} 86 | filtered_elements = [k for k, v in counter.items() if v > limit] 87 | if emb_file is not None: 88 | assert size is not None 89 | assert vec_size is not None 90 | with open(emb_file, "r", encoding="utf-8") as fh: 91 | for line in tqdm(fh, total=size): 92 | array = line.split() 93 | word = "".join(array[0:-vec_size]) 94 | vector = list(map(float, array[-vec_size:])) 95 | if word in counter and counter[word] > limit: 96 | embedding_dict[word] = vector 97 | print("{} / {} tokens have corresponding {} embedding vector".format( 98 | len(embedding_dict), len(filtered_elements), data_type)) 99 | else: 100 | assert vec_size is not None 101 | for token in filtered_elements: 102 | embedding_dict[token] = [np.random.normal( 103 | scale=0.01) for _ in range(vec_size)] 104 | print("{} tokens have corresponding embedding vector".format( 105 | len(filtered_elements))) 106 | 107 | NULL = "--NULL--" 108 | OOV = "--OOV--" 109 | token2idx_dict = {token: idx for idx, token in enumerate( 110 | embedding_dict.keys(), 2)} if token2idx_dict is None else token2idx_dict 111 | token2idx_dict[NULL] = 0 112 | token2idx_dict[OOV] = 1 113 | embedding_dict[NULL] = [0. for _ in range(vec_size)] 114 | embedding_dict[OOV] = [0. 
for _ in range(vec_size)] 115 | idx2emb_dict = {idx: embedding_dict[token] 116 | for token, idx in token2idx_dict.items()} 117 | emb_mat = [idx2emb_dict[idx] for idx in range(len(idx2emb_dict))] 118 | return emb_mat, token2idx_dict 119 | 120 | 121 | def build_features(config, examples, data_type, out_file, word2idx_dict, char2idx_dict, is_test=False): 122 | 123 | para_limit = config.test_para_limit if is_test else config.para_limit 124 | ques_limit = config.test_ques_limit if is_test else config.ques_limit 125 | char_limit = config.char_limit 126 | 127 | def filter_func(example, is_test=False): 128 | return len(example["context_tokens"]) > para_limit or len(example["ques_tokens"]) > ques_limit 129 | 130 | print("Processing {} examples...".format(data_type)) 131 | writer = tf.python_io.TFRecordWriter(out_file) 132 | total = 0 133 | total_ = 0 134 | meta = {} 135 | for example in tqdm(examples): 136 | total_ += 1 137 | 138 | if filter_func(example, is_test): 139 | continue 140 | 141 | total += 1 142 | context_idxs = np.zeros([para_limit], dtype=np.int32) 143 | context_char_idxs = np.zeros([para_limit, char_limit], dtype=np.int32) 144 | ques_idxs = np.zeros([ques_limit], dtype=np.int32) 145 | ques_char_idxs = np.zeros([ques_limit, char_limit], dtype=np.int32) 146 | y1 = np.zeros([para_limit], dtype=np.float32) 147 | y2 = np.zeros([para_limit], dtype=np.float32) 148 | 149 | def _get_word(word): 150 | for each in (word, word.lower(), word.capitalize(), word.upper()): 151 | if each in word2idx_dict: 152 | return word2idx_dict[each] 153 | return 1 154 | 155 | def _get_char(char): 156 | if char in char2idx_dict: 157 | return char2idx_dict[char] 158 | return 1 159 | 160 | for i, token in enumerate(example["context_tokens"]): 161 | context_idxs[i] = _get_word(token) 162 | 163 | for i, token in enumerate(example["ques_tokens"]): 164 | ques_idxs[i] = _get_word(token) 165 | 166 | for i, token in enumerate(example["context_chars"]): 167 | for j, char in enumerate(token): 168 | if j == char_limit: 169 | break 170 | context_char_idxs[i, j] = _get_char(char) 171 | 172 | for i, token in enumerate(example["ques_chars"]): 173 | for j, char in enumerate(token): 174 | if j == char_limit: 175 | break 176 | ques_char_idxs[i, j] = _get_char(char) 177 | 178 | start, end = example["y1s"][-1], example["y2s"][-1] 179 | y1[start], y2[end] = 1.0, 1.0 180 | 181 | record = tf.train.Example(features=tf.train.Features(feature={ 182 | "context_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[context_idxs.tostring()])), 183 | "ques_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[ques_idxs.tostring()])), 184 | "context_char_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[context_char_idxs.tostring()])), 185 | "ques_char_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[ques_char_idxs.tostring()])), 186 | "y1": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y1.tostring()])), 187 | "y2": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y2.tostring()])), 188 | "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[example["id"]])) 189 | })) 190 | writer.write(record.SerializeToString()) 191 | print("Build {} / {} instances of features in total".format(total, total_)) 192 | meta["total"] = total 193 | writer.close() 194 | return meta 195 | 196 | 197 | def save(filename, obj, message=None): 198 | if message is not None: 199 | print("Saving {}...".format(message)) 200 | with open(filename, "w") as fh: 201 | json.dump(obj, fh) 202 | 203 | 204 | def prepro(config): 205 
| word_counter, char_counter = Counter(), Counter() 206 | train_examples, train_eval = process_file( 207 | config.train_file, "train", word_counter, char_counter) 208 | dev_examples, dev_eval = process_file( 209 | config.dev_file, "dev", word_counter, char_counter) 210 | test_examples, test_eval = process_file( 211 | config.test_file, "test", word_counter, char_counter) 212 | 213 | word_emb_file = config.fasttext_file if config.fasttext else config.glove_word_file 214 | char_emb_file = config.glove_char_file if config.pretrained_char else None 215 | char_emb_size = config.glove_char_size if config.pretrained_char else None 216 | char_emb_dim = config.glove_dim if config.pretrained_char else config.char_dim 217 | 218 | word2idx_dict = None 219 | if os.path.isfile(config.word2idx_file): 220 | with open(config.word2idx_file, "r") as fh: 221 | word2idx_dict = json.load(fh) 222 | word_emb_mat, word2idx_dict = get_embedding(word_counter, "word", emb_file=word_emb_file, 223 | size=config.glove_word_size, vec_size=config.glove_dim, token2idx_dict=word2idx_dict) 224 | 225 | char2idx_dict = None 226 | if os.path.isfile(config.char2idx_file): 227 | with open(config.char2idx_file, "r") as fh: 228 | char2idx_dict = json.load(fh) 229 | char_emb_mat, char2idx_dict = get_embedding( 230 | char_counter, "char", emb_file=char_emb_file, size=char_emb_size, vec_size=char_emb_dim, token2idx_dict=char2idx_dict) 231 | 232 | build_features(config, train_examples, "train", 233 | config.train_record_file, word2idx_dict, char2idx_dict) 234 | dev_meta = build_features(config, dev_examples, "dev", 235 | config.dev_record_file, word2idx_dict, char2idx_dict) 236 | test_meta = build_features(config, test_examples, "test", 237 | config.test_record_file, word2idx_dict, char2idx_dict, is_test=True) 238 | 239 | save(config.word_emb_file, word_emb_mat, message="word embedding") 240 | save(config.char_emb_file, char_emb_mat, message="char embedding") 241 | save(config.train_eval_file, train_eval, message="train eval") 242 | save(config.dev_eval_file, dev_eval, message="dev eval") 243 | save(config.test_eval_file, test_eval, message="test eval") 244 | save(config.dev_meta, dev_meta, message="dev meta") 245 | save(config.word2idx_file, word2idx_dict, message="word2idx") 246 | save(config.char2idx_file, char2idx_dict, message="char2idx") 247 | save(config.test_meta, test_meta, message="test meta") 248 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import re 4 | from collections import Counter 5 | import string 6 | 7 | 8 | def get_record_parser(config, is_test=False): 9 | def parse(example): 10 | para_limit = config.test_para_limit if is_test else config.para_limit 11 | ques_limit = config.test_ques_limit if is_test else config.ques_limit 12 | char_limit = config.char_limit 13 | features = tf.parse_single_example(example, 14 | features={ 15 | "context_idxs": tf.FixedLenFeature([], tf.string), 16 | "ques_idxs": tf.FixedLenFeature([], tf.string), 17 | "context_char_idxs": tf.FixedLenFeature([], tf.string), 18 | "ques_char_idxs": tf.FixedLenFeature([], tf.string), 19 | "y1": tf.FixedLenFeature([], tf.string), 20 | "y2": tf.FixedLenFeature([], tf.string), 21 | "id": tf.FixedLenFeature([], tf.int64) 22 | }) 23 | context_idxs = tf.reshape(tf.decode_raw( 24 | features["context_idxs"], tf.int32), [para_limit]) 25 | ques_idxs = tf.reshape(tf.decode_raw( 26 | 
features["ques_idxs"], tf.int32), [ques_limit]) 27 | context_char_idxs = tf.reshape(tf.decode_raw( 28 | features["context_char_idxs"], tf.int32), [para_limit, char_limit]) 29 | ques_char_idxs = tf.reshape(tf.decode_raw( 30 | features["ques_char_idxs"], tf.int32), [ques_limit, char_limit]) 31 | y1 = tf.reshape(tf.decode_raw( 32 | features["y1"], tf.float32), [para_limit]) 33 | y2 = tf.reshape(tf.decode_raw( 34 | features["y2"], tf.float32), [para_limit]) 35 | qa_id = features["id"] 36 | return context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, y1, y2, qa_id 37 | return parse 38 | 39 | 40 | def get_batch_dataset(record_file, parser, config): 41 | num_threads = tf.constant(config.num_threads, dtype=tf.int32) 42 | dataset = tf.data.TFRecordDataset(record_file).map( 43 | parser, num_parallel_calls=num_threads).shuffle(config.capacity).repeat() 44 | if config.is_bucket: 45 | buckets = [tf.constant(num) for num in range(*config.bucket_range)] 46 | 47 | def key_func(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, y1, y2, qa_id): 48 | c_len = tf.reduce_sum( 49 | tf.cast(tf.cast(context_idxs, tf.bool), tf.int32)) 50 | buckets_min = [np.iinfo(np.int32).min] + buckets 51 | buckets_max = buckets + [np.iinfo(np.int32).max] 52 | conditions_c = tf.logical_and( 53 | tf.less(buckets_min, c_len), tf.less_equal(c_len, buckets_max)) 54 | bucket_id = tf.reduce_min(tf.where(conditions_c)) 55 | return bucket_id 56 | 57 | def reduce_func(key, elements): 58 | return elements.batch(config.batch_size) 59 | 60 | dataset = dataset.apply(tf.contrib.data.group_by_window( 61 | key_func, reduce_func, window_size=5 * config.batch_size)).shuffle(len(buckets) * 25) 62 | else: 63 | dataset = dataset.batch(config.batch_size) 64 | return dataset 65 | 66 | 67 | def get_dataset(record_file, parser, config): 68 | num_threads = tf.constant(config.num_threads, dtype=tf.int32) 69 | dataset = tf.data.TFRecordDataset(record_file).map( 70 | parser, num_parallel_calls=num_threads).repeat().batch(config.batch_size) 71 | return dataset 72 | 73 | 74 | def convert_tokens(eval_file, qa_id, pp1, pp2): 75 | answer_dict = {} 76 | remapped_dict = {} 77 | for qid, p1, p2 in zip(qa_id, pp1, pp2): 78 | context = eval_file[str(qid)]["context"] 79 | spans = eval_file[str(qid)]["spans"] 80 | uuid = eval_file[str(qid)]["uuid"] 81 | start_idx = spans[p1][0] 82 | end_idx = spans[p2][1] 83 | answer_dict[str(qid)] = context[start_idx: end_idx] 84 | remapped_dict[uuid] = context[start_idx: end_idx] 85 | return answer_dict, remapped_dict 86 | 87 | 88 | def evaluate(eval_file, answer_dict): 89 | f1 = exact_match = total = 0 90 | for key, value in answer_dict.items(): 91 | total += 1 92 | ground_truths = eval_file[key]["answers"] 93 | prediction = value 94 | exact_match += metric_max_over_ground_truths( 95 | exact_match_score, prediction, ground_truths) 96 | f1 += metric_max_over_ground_truths(f1_score, 97 | prediction, ground_truths) 98 | exact_match = 100.0 * exact_match / total 99 | f1 = 100.0 * f1 / total 100 | return {'exact_match': exact_match, 'f1': f1} 101 | 102 | 103 | def normalize_answer(s): 104 | 105 | def remove_articles(text): 106 | return re.sub(r'\b(a|an|the)\b', ' ', text) 107 | 108 | def white_space_fix(text): 109 | return ' '.join(text.split()) 110 | 111 | def remove_punc(text): 112 | exclude = set(string.punctuation) 113 | return ''.join(ch for ch in text if ch not in exclude) 114 | 115 | def lower(text): 116 | return text.lower() 117 | 118 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 119 | 120 | 121 | 
def f1_score(prediction, ground_truth): 122 | prediction_tokens = normalize_answer(prediction).split() 123 | ground_truth_tokens = normalize_answer(ground_truth).split() 124 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 125 | num_same = sum(common.values()) 126 | if num_same == 0: 127 | return 0 128 | precision = 1.0 * num_same / len(prediction_tokens) 129 | recall = 1.0 * num_same / len(ground_truth_tokens) 130 | f1 = (2 * precision * recall) / (precision + recall) 131 | return f1 132 | 133 | 134 | def exact_match_score(prediction, ground_truth): 135 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 136 | 137 | 138 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 139 | scores_for_ground_truths = [] 140 | for ground_truth in ground_truths: 141 | score = metric_fn(prediction, ground_truth) 142 | scores_for_ground_truths.append(score) 143 | return max(scores_for_ground_truths) 144 | --------------------------------------------------------------------------------