├── LICENSE
├── README.md
├── config.py
├── download.sh
├── evaluate-v1.1.py
├── func.py
├── img
│   ├── em.jpg
│   └── f1.jpg
├── inference.py
├── main.py
├── model.py
├── prepro.py
└── util.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 HKUST-KnowComp
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # R-Net
2 | * A TensorFlow implementation of [R-NET: MACHINE READING COMPREHENSION WITH SELF-MATCHING NETWORKS](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf), designed specifically for the [SQuAD](https://arxiv.org/pdf/1606.05250.pdf) dataset.
3 | * Should you have any questions, please contact Wenxuan Zhou (wzhouad@connect.ust.hk).
4 |
5 | ## Requirements
6 |
7 | Many of the reported problems are caused by mismatched software versions. Please check your versions before opening an issue or emailing me.
8 |
9 | #### General
10 | * Python >= 3.4
11 | * unzip, wget
12 | #### Python Packages
13 | * tensorflow-gpu >= 1.5.0
14 | * spaCy >= 2.0.0
15 | * tqdm
16 | * ujson
17 |
18 | ## Usage
19 |
20 | To download and preprocess the data, run
21 |
22 | ```bash
23 | # download SQuAD and GloVe
24 | sh download.sh
25 | # preprocess the data
26 | python config.py --mode prepro
27 | ```
28 |
29 | Hyperparameters are stored in `config.py`. To debug/train/test the model, run
30 |
31 | ```bash
32 | python config.py --mode debug/train/test
33 | ```
34 |
35 | To get the official score, run
36 |
37 | ```bash
38 | python evaluate-v1.1.py ~/data/squad/dev-v1.1.json log/answer/answer.json
39 | ```
40 |
41 | The default directory for TensorBoard log files is `log/event`.
42 |
43 | See the GitHub releases page for a trained model.
44 |
45 | ## Detailed Implementation
46 |
47 | * The original paper uses additive attention, which consumes a lot of memory. This project instead adopts the scaled multiplicative attention presented in [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (see the first sketch after this list).
48 | * This project adopts the variational dropout presented in [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](https://arxiv.org/abs/1512.05287) (sketched after this list).
49 | * To mitigate the degradation problem in stacked RNNs, the outputs of all layers are concatenated to form the final output.
50 | * When the loss on the dev set fails to improve for `patience` consecutive evaluations, the learning rate is halved.
51 | * During prediction, the project adopts the answer-span search method presented in [Machine Comprehension Using Match-LSTM and Answer Pointer](https://arxiv.org/abs/1608.07905) (sketched after this list).
52 | * For efficiency, this implementation uses bucketing (contributed by xiongyifan) and CudnnGRU. Bucketing speeds up training but lowers the F1 score by about 0.3%.
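
A minimal NumPy sketch of the scaled multiplicative attention (shapes illustrative; `W_i` and `W_m` stand in for the trained projection weights, and the real `dot_attention` in `func.py` additionally applies variational dropout and a sigmoid output gate):

```python
import numpy as np

def scaled_dot_attention(inputs, memory, mask, hidden):
    # inputs: [batch, c_len, dim], memory: [batch, q_len, dim]
    # mask:   [batch, q_len], 1.0 for real tokens, 0.0 for padding
    rng = np.random.default_rng(0)
    dim = inputs.shape[-1]
    W_i = rng.normal(size=(dim, hidden))  # illustrative stand-ins for
    W_m = rng.normal(size=(dim, hidden))  # the trained projections
    proj_in = np.maximum(inputs @ W_i, 0.0)   # ReLU projections
    proj_mem = np.maximum(memory @ W_m, 0.0)
    # multiplicative scores, scaled by sqrt(hidden) as in the Transformer
    scores = proj_in @ proj_mem.transpose(0, 2, 1) / np.sqrt(hidden)
    scores = scores - 1e30 * (1.0 - mask[:, None, :])  # mask out padding
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights = weights / weights.sum(-1, keepdims=True)  # softmax over memory
    context = weights @ memory      # question-aware context representation
    return np.concatenate([inputs, context], axis=-1)
```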
53 |
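The variational dropout used by the `dropout` helper in `func.py` draws one mask per sequence and shares it across all time steps (noise shape `[batch, 1, dim]`) instead of resampling at every step. A minimal sketch:

```python
import numpy as np

def variational_dropout(x, keep_prob, rng=None):
    # x: [batch, time, dim]; one inverted-dropout mask per sequence,
    # broadcast over the time axis
    rng = rng or np.random.default_rng(0)
    batch, _, dim = x.shape
    mask = (rng.random((batch, 1, dim)) < keep_prob) / keep_prob
    return x * mask
```

The span search matches the `predict` scope in `model.py`: form the outer product of the start and end distributions, keep only entries with `end >= start` and `end - start <= 15`, then take the argmax. A simplified single-example sketch:

```python
import numpy as np

def best_span(p_start, p_end, max_len=15):
    # p_start, p_end: [c_len] start/end probability vectors
    outer = np.outer(p_start, p_end)
    # band part: keep entries with 0 <= end - start <= max_len
    outer = np.triu(outer) - np.triu(outer, k=max_len + 1)
    start = int(outer.max(axis=1).argmax())
    end = int(outer.max(axis=0).argmax())
    return start, end
```
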
54 | ## Performance
55 |
56 | #### Score
57 |
58 | ||EM|F1|
59 | |---|---|---|
60 | |original paper|71.1|79.5|
61 | |this project|71.07|79.51|
62 |
63 | ![Exact Match](img/em.jpg)
64 |
65 | ![F1](img/f1.jpg)
66 |
67 | #### Training Time (s/it)
68 |
69 | ||Native|Native + Bucket|Cudnn|Cudnn + Bucket|
70 | |---|---|---|---|---|
71 | |E5-2640|6.21|3.56|-|-|
72 | |TITAN X|2.56|1.31|0.41|0.28|
73 |
74 | ## Extensions
75 |
76 | These settings may increase the score but are not enabled by default. You can turn them on in `config.py` or on the command line (see the example below).
77 |
78 | * [Pretrained GloVe character embedding](https://github.com/minimaxir/char-embeddings). Contributed by yanghanxy.
79 | * [Fasttext Embedding](https://fasttext.cc/docs/en/english-vectors.html). Contributed by xiongyifan. May increase the F1 by 1% (reported by xiongyifan).
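
For example, to train with the fastText vectors (assuming the corresponding lines in `download.sh` have been uncommented and run first; the flag names come from `config.py`):

```bash
python config.py --mode prepro --fasttext=True
python config.py --mode train --fasttext=True
```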
80 |
81 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 |
4 | from prepro import prepro
5 | from main import train, test
6 |
7 | flags = tf.flags
8 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
9 |
10 | home = os.path.expanduser("~")
11 | train_file = os.path.join(home, "data", "squad", "train-v1.1.json")
12 | dev_file = os.path.join(home, "data", "squad", "dev-v1.1.json")
13 | test_file = os.path.join(home, "data", "squad", "dev-v1.1.json")
14 | glove_word_file = os.path.join(home, "data", "glove", "glove.840B.300d.txt")
15 |
16 | target_dir = "data"
17 | log_dir = "log/event"
18 | save_dir = "log/model"
19 | answer_dir = "log/answer"
20 | train_record_file = os.path.join(target_dir, "train.tfrecords")
21 | dev_record_file = os.path.join(target_dir, "dev.tfrecords")
22 | test_record_file = os.path.join(target_dir, "test.tfrecords")
23 | word_emb_file = os.path.join(target_dir, "word_emb.json")
24 | char_emb_file = os.path.join(target_dir, "char_emb.json")
25 | train_eval = os.path.join(target_dir, "train_eval.json")
26 | dev_eval = os.path.join(target_dir, "dev_eval.json")
27 | test_eval = os.path.join(target_dir, "test_eval.json")
28 | dev_meta = os.path.join(target_dir, "dev_meta.json")
29 | test_meta = os.path.join(target_dir, "test_meta.json")
30 | word2idx_file = os.path.join(target_dir, "word2idx.json")
31 | char2idx_file = os.path.join(target_dir, "char2idx.json")
32 | answer_file = os.path.join(answer_dir, "answer.json")
33 |
34 | if not os.path.exists(target_dir):
35 | os.makedirs(target_dir)
36 | if not os.path.exists(log_dir):
37 | os.makedirs(log_dir)
38 | if not os.path.exists(save_dir):
39 | os.makedirs(save_dir)
40 | if not os.path.exists(answer_dir):
41 | os.makedirs(answer_dir)
42 |
43 | flags.DEFINE_string("mode", "train", "train/debug/test")
44 |
45 | flags.DEFINE_string("target_dir", target_dir, "")
46 | flags.DEFINE_string("log_dir", log_dir, "")
47 | flags.DEFINE_string("save_dir", save_dir, "")
48 | flags.DEFINE_string("train_file", train_file, "")
49 | flags.DEFINE_string("dev_file", dev_file, "")
50 | flags.DEFINE_string("test_file", test_file, "")
51 | flags.DEFINE_string("glove_word_file", glove_word_file, "")
52 |
53 | flags.DEFINE_string("train_record_file", train_record_file, "")
54 | flags.DEFINE_string("dev_record_file", dev_record_file, "")
55 | flags.DEFINE_string("test_record_file", test_record_file, "")
56 | flags.DEFINE_string("word_emb_file", word_emb_file, "")
57 | flags.DEFINE_string("char_emb_file", char_emb_file, "")
58 | flags.DEFINE_string("train_eval_file", train_eval, "")
59 | flags.DEFINE_string("dev_eval_file", dev_eval, "")
60 | flags.DEFINE_string("test_eval_file", test_eval, "")
61 | flags.DEFINE_string("dev_meta", dev_meta, "")
62 | flags.DEFINE_string("test_meta", test_meta, "")
63 | flags.DEFINE_string("word2idx_file", word2idx_file, "")
64 | flags.DEFINE_string("char2idx_file", char2idx_file, "")
65 | flags.DEFINE_string("answer_file", answer_file, "")
66 |
67 |
68 | flags.DEFINE_integer("glove_char_size", 94, "Corpus size for Glove")
69 | flags.DEFINE_integer("glove_word_size", int(2.2e6), "Corpus size for Glove")
70 | flags.DEFINE_integer("glove_dim", 300, "Embedding dimension for Glove")
71 | flags.DEFINE_integer("char_dim", 8, "Embedding dimension for char")
72 |
73 | flags.DEFINE_integer("para_limit", 400, "Limit length for paragraph")
74 | flags.DEFINE_integer("ques_limit", 50, "Limit length for question")
75 | flags.DEFINE_integer("test_para_limit", 1000,
76 | "Max length for paragraph in test")
77 | flags.DEFINE_integer("test_ques_limit", 100, "Max length of questions in test")
78 | flags.DEFINE_integer("char_limit", 16, "Limit length for character")
79 | flags.DEFINE_integer("word_count_limit", -1, "Min count for word")
80 | flags.DEFINE_integer("char_count_limit", -1, "Min count for char")
81 |
82 | flags.DEFINE_integer("capacity", 15000, "Batch size of dataset shuffle")
83 | flags.DEFINE_integer("num_threads", 4, "Number of threads in input pipeline")
84 | flags.DEFINE_boolean("use_cudnn", True, "Whether to use cudnn (only for GPU)")
85 | flags.DEFINE_boolean("is_bucket", False, "Whether to use bucketing")
86 | flags.DEFINE_list("bucket_range", [40, 361, 40], "range of bucket")
87 |
88 | flags.DEFINE_integer("batch_size", 64, "Batch size")
89 | flags.DEFINE_integer("num_steps", 60000, "Number of steps")
90 | flags.DEFINE_integer("checkpoint", 1000, "checkpoint for evaluation")
91 | flags.DEFINE_integer("period", 100, "period to save batch loss")
92 | flags.DEFINE_integer("val_num_batches", 150, "Num of batches for evaluation")
93 | flags.DEFINE_float("init_lr", 0.5, "Initial lr for Adadelta")
94 | flags.DEFINE_float("keep_prob", 0.7, "Keep prob in rnn")
95 | flags.DEFINE_float("ptr_keep_prob", 0.7, "Keep prob for pointer network")
96 | flags.DEFINE_float("grad_clip", 5.0, "Global Norm gradient clipping rate")
97 | flags.DEFINE_integer("hidden", 75, "Hidden size")
98 | flags.DEFINE_integer("char_hidden", 100, "GRU dim for char")
99 | flags.DEFINE_integer("patience", 3, "Patience for lr decay")
100 |
101 | # Extensions (uncomment the corresponding lines in download.sh to download the required data)
102 | glove_char_file = os.path.join(
103 | home, "data", "glove", "glove.840B.300d-char.txt")
104 | flags.DEFINE_string("glove_char_file", glove_char_file,
105 | "Glove character embedding")
106 | flags.DEFINE_boolean("pretrained_char", False,
107 | "Whether to use pretrained char embedding")
108 |
109 | fasttext_file = os.path.join(home, "data", "fasttext", "wiki-news-300d-1M.vec")
110 | flags.DEFINE_string("fasttext_file", fasttext_file, "Fasttext word embedding")
111 | flags.DEFINE_boolean("fasttext", False, "Whether to use fasttext")
112 |
113 |
114 | def main(_):
115 | config = flags.FLAGS
116 | if config.mode == "train":
117 | train(config)
118 | elif config.mode == "prepro":
119 | prepro(config)
120 | elif config.mode == "debug":
121 | config.num_steps = 2
122 | config.val_num_batches = 1
123 | config.checkpoint = 1
124 | config.period = 1
125 | train(config)
126 | elif config.mode == "test":
127 | test(config)
128 | else:
129 | print("Unknown mode")
130 |         exit(1)
131 |
132 |
133 | if __name__ == "__main__":
134 | tf.app.run()
135 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Download SQuAD
4 | SQUAD_DIR=~/data/squad
5 | mkdir -p $SQUAD_DIR
6 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $SQUAD_DIR/train-v1.1.json
7 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $SQUAD_DIR/dev-v1.1.json
8 |
9 | # Download GloVe
10 | GLOVE_DIR=~/data/glove
11 | mkdir -p $GLOVE_DIR
12 | wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O $GLOVE_DIR/glove.840B.300d.zip
13 | unzip $GLOVE_DIR/glove.840B.300d.zip -d $GLOVE_DIR
14 |
15 | # Download GloVe character embedding
16 | # wget https://raw.githubusercontent.com/minimaxir/char-embeddings/master/glove.840B.300d-char.txt -O $GLOVE_DIR/glove.840B.300d-char.txt
17 |
18 | # Download fasttext
19 | # FASTTEXT_DIR=~/data/fasttext
20 | # mkdir -p $FASTTEXT_DIR
21 | # wget https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki-news-300d-1M.vec.zip -O $FASTTEXT_DIR/wiki-news-300d-1M.vec.zip
22 | # unzip $FASTTEXT_DIR/wiki-news-300d-1M.vec.zip -d $FASTTEXT_DIR
23 |
24 | # Download Spacy language models
25 | python3 -m spacy download en
26 |
--------------------------------------------------------------------------------
/evaluate-v1.1.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import argparse
7 | import json
8 | import sys
9 |
10 |
11 | def normalize_answer(s):
12 | """Lower text and remove punctuation, articles and extra whitespace."""
13 | def remove_articles(text):
14 | return re.sub(r'\b(a|an|the)\b', ' ', text)
15 |
16 | def white_space_fix(text):
17 | return ' '.join(text.split())
18 |
19 | def remove_punc(text):
20 | exclude = set(string.punctuation)
21 | return ''.join(ch for ch in text if ch not in exclude)
22 |
23 | def lower(text):
24 | return text.lower()
25 |
26 | return white_space_fix(remove_articles(remove_punc(lower(s))))
27 |
28 |
29 | def f1_score(prediction, ground_truth):
30 | prediction_tokens = normalize_answer(prediction).split()
31 | ground_truth_tokens = normalize_answer(ground_truth).split()
32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
33 | num_same = sum(common.values())
34 | if num_same == 0:
35 | return 0
36 | precision = 1.0 * num_same / len(prediction_tokens)
37 | recall = 1.0 * num_same / len(ground_truth_tokens)
38 | f1 = (2 * precision * recall) / (precision + recall)
39 | return f1
40 |
41 |
42 | def exact_match_score(prediction, ground_truth):
43 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
44 |
45 |
46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
47 | scores_for_ground_truths = []
48 | for ground_truth in ground_truths:
49 | score = metric_fn(prediction, ground_truth)
50 | scores_for_ground_truths.append(score)
51 | return max(scores_for_ground_truths)
52 |
53 |
54 | def evaluate(dataset, predictions):
55 | f1 = exact_match = total = 0
56 | for article in dataset:
57 | for paragraph in article['paragraphs']:
58 | for qa in paragraph['qas']:
59 | total += 1
60 | if qa['id'] not in predictions:
61 | message = 'Unanswered question ' + qa['id'] + \
62 | ' will receive score 0.'
63 | print(message, file=sys.stderr)
64 | continue
65 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
66 | prediction = predictions[qa['id']]
67 | exact_match += metric_max_over_ground_truths(
68 | exact_match_score, prediction, ground_truths)
69 | f1 += metric_max_over_ground_truths(
70 | f1_score, prediction, ground_truths)
71 |
72 | exact_match = 100.0 * exact_match / total
73 | f1 = 100.0 * f1 / total
74 |
75 | return {'exact_match': exact_match, 'f1': f1}
76 |
77 |
78 | if __name__ == '__main__':
79 | expected_version = '1.1'
80 | parser = argparse.ArgumentParser(
81 | description='Evaluation for SQuAD ' + expected_version)
82 | parser.add_argument('dataset_file', help='Dataset file')
83 | parser.add_argument('prediction_file', help='Prediction File')
84 | args = parser.parse_args()
85 | with open(args.dataset_file) as dataset_file:
86 | dataset_json = json.load(dataset_file)
87 | if (dataset_json['version'] != expected_version):
88 | print('Evaluation expects v-' + expected_version +
89 | ', but got dataset with v-' + dataset_json['version'],
90 | file=sys.stderr)
91 | dataset = dataset_json['data']
92 | with open(args.prediction_file) as prediction_file:
93 | predictions = json.load(prediction_file)
94 | print(json.dumps(evaluate(dataset, predictions)))
95 |
--------------------------------------------------------------------------------
/func.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | INF = 1e30
4 |
5 |
6 | class cudnn_gru:
7 |
8 | def __init__(self, num_layers, num_units, batch_size, input_size, keep_prob=1.0, is_train=None, scope=None):
9 | self.num_layers = num_layers
10 | self.grus = []
11 | self.inits = []
12 | self.dropout_mask = []
13 | for layer in range(num_layers):
14 | input_size_ = input_size if layer == 0 else 2 * num_units
15 | gru_fw = tf.contrib.cudnn_rnn.CudnnGRU(1, num_units)
16 | gru_bw = tf.contrib.cudnn_rnn.CudnnGRU(1, num_units)
17 | init_fw = tf.tile(tf.Variable(
18 | tf.zeros([1, 1, num_units])), [1, batch_size, 1])
19 | init_bw = tf.tile(tf.Variable(
20 | tf.zeros([1, 1, num_units])), [1, batch_size, 1])
21 | mask_fw = dropout(tf.ones([1, batch_size, input_size_], dtype=tf.float32),
22 | keep_prob=keep_prob, is_train=is_train, mode=None)
23 | mask_bw = dropout(tf.ones([1, batch_size, input_size_], dtype=tf.float32),
24 | keep_prob=keep_prob, is_train=is_train, mode=None)
25 | self.grus.append((gru_fw, gru_bw, ))
26 | self.inits.append((init_fw, init_bw, ))
27 | self.dropout_mask.append((mask_fw, mask_bw, ))
28 |
29 | def __call__(self, inputs, seq_len, keep_prob=1.0, is_train=None, concat_layers=True):
30 | outputs = [tf.transpose(inputs, [1, 0, 2])]
31 | for layer in range(self.num_layers):
32 | gru_fw, gru_bw = self.grus[layer]
33 | init_fw, init_bw = self.inits[layer]
34 | mask_fw, mask_bw = self.dropout_mask[layer]
35 | with tf.variable_scope("fw_{}".format(layer)):
36 | out_fw, _ = gru_fw(
37 | outputs[-1] * mask_fw, initial_state=(init_fw, ))
38 | with tf.variable_scope("bw_{}".format(layer)):
39 | inputs_bw = tf.reverse_sequence(
40 | outputs[-1] * mask_bw, seq_lengths=seq_len, seq_dim=0, batch_dim=1)
41 | out_bw, _ = gru_bw(inputs_bw, initial_state=(init_bw, ))
42 | out_bw = tf.reverse_sequence(
43 | out_bw, seq_lengths=seq_len, seq_dim=0, batch_dim=1)
44 | outputs.append(tf.concat([out_fw, out_bw], axis=2))
45 | if concat_layers:
46 | res = tf.concat(outputs[1:], axis=2)
47 | else:
48 | res = outputs[-1]
49 | res = tf.transpose(res, [1, 0, 2])
50 | return res
51 |
52 |
53 | class native_gru:
54 |
55 | def __init__(self, num_layers, num_units, batch_size, input_size, keep_prob=1.0, is_train=None, scope="native_gru"):
56 | self.num_layers = num_layers
57 | self.grus = []
58 | self.inits = []
59 | self.dropout_mask = []
60 | self.scope = scope
61 | for layer in range(num_layers):
62 | input_size_ = input_size if layer == 0 else 2 * num_units
63 | gru_fw = tf.contrib.rnn.GRUCell(num_units)
64 | gru_bw = tf.contrib.rnn.GRUCell(num_units)
65 | init_fw = tf.tile(tf.Variable(
66 | tf.zeros([1, num_units])), [batch_size, 1])
67 | init_bw = tf.tile(tf.Variable(
68 | tf.zeros([1, num_units])), [batch_size, 1])
69 | mask_fw = dropout(tf.ones([batch_size, 1, input_size_], dtype=tf.float32),
70 | keep_prob=keep_prob, is_train=is_train, mode=None)
71 | mask_bw = dropout(tf.ones([batch_size, 1, input_size_], dtype=tf.float32),
72 | keep_prob=keep_prob, is_train=is_train, mode=None)
73 | self.grus.append((gru_fw, gru_bw, ))
74 | self.inits.append((init_fw, init_bw, ))
75 | self.dropout_mask.append((mask_fw, mask_bw, ))
76 |
77 | def __call__(self, inputs, seq_len, keep_prob=1.0, is_train=None, concat_layers=True):
78 | outputs = [inputs]
79 | with tf.variable_scope(self.scope):
80 | for layer in range(self.num_layers):
81 | gru_fw, gru_bw = self.grus[layer]
82 | init_fw, init_bw = self.inits[layer]
83 | mask_fw, mask_bw = self.dropout_mask[layer]
84 | with tf.variable_scope("fw_{}".format(layer)):
85 | out_fw, _ = tf.nn.dynamic_rnn(
86 | gru_fw, outputs[-1] * mask_fw, seq_len, initial_state=init_fw, dtype=tf.float32)
87 | with tf.variable_scope("bw_{}".format(layer)):
88 | inputs_bw = tf.reverse_sequence(
89 | outputs[-1] * mask_bw, seq_lengths=seq_len, seq_dim=1, batch_dim=0)
90 | out_bw, _ = tf.nn.dynamic_rnn(
91 | gru_bw, inputs_bw, seq_len, initial_state=init_bw, dtype=tf.float32)
92 | out_bw = tf.reverse_sequence(
93 | out_bw, seq_lengths=seq_len, seq_dim=1, batch_dim=0)
94 | outputs.append(tf.concat([out_fw, out_bw], axis=2))
95 | if concat_layers:
96 | res = tf.concat(outputs[1:], axis=2)
97 | else:
98 | res = outputs[-1]
99 | return res
100 |
101 |
102 | class ptr_net:
103 | def __init__(self, batch, hidden, keep_prob=1.0, is_train=None, scope="ptr_net"):
104 | self.gru = tf.contrib.rnn.GRUCell(hidden)
105 | self.batch = batch
106 | self.scope = scope
107 | self.keep_prob = keep_prob
108 | self.is_train = is_train
109 | self.dropout_mask = dropout(tf.ones(
110 | [batch, hidden], dtype=tf.float32), keep_prob=keep_prob, is_train=is_train)
111 |
112 | def __call__(self, init, match, d, mask):
113 | with tf.variable_scope(self.scope):
114 | d_match = dropout(match, keep_prob=self.keep_prob,
115 | is_train=self.is_train)
116 | inp, logits1 = pointer(d_match, init * self.dropout_mask, d, mask)
117 | d_inp = dropout(inp, keep_prob=self.keep_prob,
118 | is_train=self.is_train)
119 | _, state = self.gru(d_inp, init)
120 | tf.get_variable_scope().reuse_variables()
121 | _, logits2 = pointer(d_match, state * self.dropout_mask, d, mask)
122 | return logits1, logits2
123 |
124 |
125 | def dropout(args, keep_prob, is_train, mode="recurrent"):
126 | if keep_prob < 1.0:
127 | noise_shape = None
128 | scale = 1.0
129 | shape = tf.shape(args)
130 | if mode == "embedding":
131 | noise_shape = [shape[0], 1]
132 | scale = keep_prob
133 | if mode == "recurrent" and len(args.get_shape().as_list()) == 3:
134 | noise_shape = [shape[0], 1, shape[-1]]
135 | args = tf.cond(is_train, lambda: tf.nn.dropout(
136 | args, keep_prob, noise_shape=noise_shape) * scale, lambda: args)
137 | return args
138 |
139 |
140 | def softmax_mask(val, mask):
141 | return -INF * (1 - tf.cast(mask, tf.float32)) + val
142 |
143 |
144 | def pointer(inputs, state, hidden, mask, scope="pointer"):
145 | with tf.variable_scope(scope):
146 | u = tf.concat([tf.tile(tf.expand_dims(state, axis=1), [
147 | 1, tf.shape(inputs)[1], 1]), inputs], axis=2)
148 | s0 = tf.nn.tanh(dense(u, hidden, use_bias=False, scope="s0"))
149 | s = dense(s0, 1, use_bias=False, scope="s")
150 | s1 = softmax_mask(tf.squeeze(s, [2]), mask)
151 | a = tf.expand_dims(tf.nn.softmax(s1), axis=2)
152 | res = tf.reduce_sum(a * inputs, axis=1)
153 | return res, s1
154 |
155 |
156 | def summ(memory, hidden, mask, keep_prob=1.0, is_train=None, scope="summ"):
157 | with tf.variable_scope(scope):
158 | d_memory = dropout(memory, keep_prob=keep_prob, is_train=is_train)
159 | s0 = tf.nn.tanh(dense(d_memory, hidden, scope="s0"))
160 | s = dense(s0, 1, use_bias=False, scope="s")
161 | s1 = softmax_mask(tf.squeeze(s, [2]), mask)
162 | a = tf.expand_dims(tf.nn.softmax(s1), axis=2)
163 | res = tf.reduce_sum(a * memory, axis=1)
164 | return res
165 |
166 |
167 | def dot_attention(inputs, memory, mask, hidden, keep_prob=1.0, is_train=None, scope="dot_attention"):
168 | with tf.variable_scope(scope):
169 |
170 | d_inputs = dropout(inputs, keep_prob=keep_prob, is_train=is_train)
171 | d_memory = dropout(memory, keep_prob=keep_prob, is_train=is_train)
172 | JX = tf.shape(inputs)[1]
173 |
174 | with tf.variable_scope("attention"):
175 | inputs_ = tf.nn.relu(
176 | dense(d_inputs, hidden, use_bias=False, scope="inputs"))
177 | memory_ = tf.nn.relu(
178 | dense(d_memory, hidden, use_bias=False, scope="memory"))
179 | outputs = tf.matmul(inputs_, tf.transpose(
180 | memory_, [0, 2, 1])) / (hidden ** 0.5)
181 | mask = tf.tile(tf.expand_dims(mask, axis=1), [1, JX, 1])
182 | logits = tf.nn.softmax(softmax_mask(outputs, mask))
183 | outputs = tf.matmul(logits, memory)
184 | res = tf.concat([inputs, outputs], axis=2)
185 |
186 | with tf.variable_scope("gate"):
187 | dim = res.get_shape().as_list()[-1]
188 | d_res = dropout(res, keep_prob=keep_prob, is_train=is_train)
189 | gate = tf.nn.sigmoid(dense(d_res, dim, use_bias=False))
190 | return res * gate
191 |
192 |
193 | def dense(inputs, hidden, use_bias=True, scope="dense"):
194 | with tf.variable_scope(scope):
195 | shape = tf.shape(inputs)
196 | dim = inputs.get_shape().as_list()[-1]
197 | out_shape = [shape[idx] for idx in range(
198 | len(inputs.get_shape().as_list()) - 1)] + [hidden]
199 | flat_inputs = tf.reshape(inputs, [-1, dim])
200 | W = tf.get_variable("W", [dim, hidden])
201 | res = tf.matmul(flat_inputs, W)
202 | if use_bias:
203 | b = tf.get_variable(
204 | "b", [hidden], initializer=tf.constant_initializer(0.))
205 | res = tf.nn.bias_add(res, b)
206 | res = tf.reshape(res, out_shape)
207 | return res
208 |
--------------------------------------------------------------------------------
/img/em.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-KnowComp/R-Net/1efc5ef18067a84d44fa8d31c309a19c52958e92/img/em.jpg
--------------------------------------------------------------------------------
/img/f1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HKUST-KnowComp/R-Net/1efc5ef18067a84d44fa8d31c309a19c52958e92/img/f1.jpg
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import spacy
3 | import os
4 | import numpy as np
5 | import ujson as json
6 |
7 |
8 | from func import cudnn_gru, native_gru, dot_attention, summ, ptr_net
9 | from prepro import word_tokenize, convert_idx
10 |
11 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
12 |
13 | # Must be consistent with training
14 | char_limit = 16
15 | hidden = 75
16 | char_dim = 8
17 | char_hidden = 100
18 | use_cudnn = True
19 |
20 | # File path
21 | target_dir = "data"
22 | save_dir = "log/model"
23 | word_emb_file = os.path.join(target_dir, "word_emb.json")
24 | char_emb_file = os.path.join(target_dir, "char_emb.json")
25 | word2idx_file = os.path.join(target_dir, "word2idx.json")
26 | char2idx_file = os.path.join(target_dir, "char2idx.json")
27 |
28 |
29 | class InfModel(object):
30 | # Used to zero elements in the probability matrix that correspond to answer
31 | # spans that are longer than the number of tokens specified here.
32 | max_answer_tokens = 15
33 |
34 | def __init__(self, word_mat, char_mat):
35 | self.c = tf.placeholder(tf.int32, [1, None])
36 | self.q = tf.placeholder(tf.int32, [1, None])
37 | self.ch = tf.placeholder(tf.int32, [1, None, char_limit])
38 | self.qh = tf.placeholder(tf.int32, [1, None, char_limit])
39 | self.tokens_in_context = tf.placeholder(tf.int64)
40 |
41 | self.word_mat = tf.get_variable("word_mat", initializer=tf.constant(
42 | word_mat, dtype=tf.float32), trainable=False)
43 | self.char_mat = tf.get_variable(
44 | "char_mat", initializer=tf.constant(char_mat, dtype=tf.float32))
45 |
46 | self.c_mask = tf.cast(self.c, tf.bool)
47 | self.q_mask = tf.cast(self.q, tf.bool)
48 | self.c_len = tf.reduce_sum(tf.cast(self.c_mask, tf.int32), axis=1)
49 | self.q_len = tf.reduce_sum(tf.cast(self.q_mask, tf.int32), axis=1)
50 |
51 | self.c_maxlen = tf.reduce_max(self.c_len)
52 | self.q_maxlen = tf.reduce_max(self.q_len)
53 |
54 | self.ch_len = tf.reshape(tf.reduce_sum(
55 | tf.cast(tf.cast(self.ch, tf.bool), tf.int32), axis=2), [-1])
56 | self.qh_len = tf.reshape(tf.reduce_sum(
57 | tf.cast(tf.cast(self.qh, tf.bool), tf.int32), axis=2), [-1])
58 |
59 | self.ready()
60 |
61 | def ready(self):
62 | N, PL, QL, CL, d, dc, dg = \
63 | 1, self.c_maxlen, self.q_maxlen, char_limit, hidden, char_dim, \
64 | char_hidden
65 | gru = cudnn_gru if use_cudnn else native_gru
66 |
67 | with tf.variable_scope("emb"):
68 | with tf.variable_scope("char"):
69 | ch_emb = tf.reshape(tf.nn.embedding_lookup(
70 | self.char_mat, self.ch), [N * PL, CL, dc])
71 | qh_emb = tf.reshape(tf.nn.embedding_lookup(
72 | self.char_mat, self.qh), [N * QL, CL, dc])
73 | cell_fw = tf.contrib.rnn.GRUCell(dg)
74 | cell_bw = tf.contrib.rnn.GRUCell(dg)
75 | _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
76 | cell_fw, cell_bw, ch_emb, self.ch_len, dtype=tf.float32)
77 | ch_emb = tf.concat([state_fw, state_bw], axis=1)
78 | _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
79 | cell_fw, cell_bw, qh_emb, self.qh_len, dtype=tf.float32)
80 | qh_emb = tf.concat([state_fw, state_bw], axis=1)
81 | qh_emb = tf.reshape(qh_emb, [N, QL, 2 * dg])
82 | ch_emb = tf.reshape(ch_emb, [N, PL, 2 * dg])
83 |
84 | with tf.name_scope("word"):
85 | c_emb = tf.nn.embedding_lookup(self.word_mat, self.c)
86 | q_emb = tf.nn.embedding_lookup(self.word_mat, self.q)
87 |
88 | c_emb = tf.concat([c_emb, ch_emb], axis=2)
89 | q_emb = tf.concat([q_emb, qh_emb], axis=2)
90 |
91 | with tf.variable_scope("encoding"):
92 | rnn = gru(num_layers=3, num_units=d, batch_size=N,
93 | input_size=c_emb.get_shape().as_list()[-1])
94 | c = rnn(c_emb, seq_len=self.c_len)
95 | q = rnn(q_emb, seq_len=self.q_len)
96 |
97 | with tf.variable_scope("attention"):
98 | qc_att = dot_attention(c, q, mask=self.q_mask, hidden=d)
99 | rnn = gru(num_layers=1, num_units=d, batch_size=N,
100 | input_size=qc_att.get_shape().as_list()[-1])
101 | att = rnn(qc_att, seq_len=self.c_len)
102 |
103 | with tf.variable_scope("match"):
104 | self_att = dot_attention(att, att, mask=self.c_mask, hidden=d)
105 | rnn = gru(num_layers=1, num_units=d, batch_size=N,
106 | input_size=self_att.get_shape().as_list()[-1])
107 | match = rnn(self_att, seq_len=self.c_len)
108 |
109 | with tf.variable_scope("pointer"):
110 | init = summ(q[:, :, -2 * d:], d, mask=self.q_mask)
111 | pointer = ptr_net(batch=N, hidden=init.get_shape().as_list()[-1])
112 | logits1, logits2 = pointer(init, match, d, self.c_mask)
113 |
114 | with tf.variable_scope("predict"):
115 | outer = tf.matmul(tf.expand_dims(tf.nn.softmax(logits1), axis=2),
116 | tf.expand_dims(tf.nn.softmax(logits2), axis=1))
117 | outer = tf.cond(
118 | self.tokens_in_context < self.max_answer_tokens,
119 | lambda: tf.matrix_band_part(outer, 0, -1),
120 | lambda: tf.matrix_band_part(outer, 0, self.max_answer_tokens))
121 | self.yp1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
122 | self.yp2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
123 |
124 |
125 | class Inference(object):
126 |
127 | def __init__(self):
128 | with open(word_emb_file, "r") as fh:
129 | self.word_mat = np.array(json.load(fh), dtype=np.float32)
130 | with open(char_emb_file, "r") as fh:
131 | self.char_mat = np.array(json.load(fh), dtype=np.float32)
132 | with open(word2idx_file, "r") as fh:
133 | self.word2idx_dict = json.load(fh)
134 | with open(char2idx_file, "r") as fh:
135 | self.char2idx_dict = json.load(fh)
136 | self.model = InfModel(self.word_mat, self.char_mat)
137 | sess_config = tf.ConfigProto(allow_soft_placement=True)
138 | sess_config.gpu_options.allow_growth = True
139 | self.sess = tf.Session(config=sess_config)
140 | saver = tf.train.Saver()
141 | saver.restore(self.sess, tf.train.latest_checkpoint(save_dir))
142 |
143 | def response(self, context, question):
144 | sess = self.sess
145 | model = self.model
146 | span, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs = \
147 | self.prepro(context, question)
148 | yp1, yp2 = \
149 | sess.run(
150 | [model.yp1, model.yp2],
151 | feed_dict={
152 | model.c: context_idxs, model.q: ques_idxs,
153 | model.ch: context_char_idxs, model.qh: ques_char_idxs,
154 | model.tokens_in_context: len(span)})
155 | start_idx = span[yp1[0]][0]
156 | end_idx = span[yp2[0]][1]
157 | return context[start_idx: end_idx]
158 |
159 | def prepro(self, context, question):
160 | context = context.replace("''", '" ').replace("``", '" ')
161 | context_tokens = word_tokenize(context)
162 | context_chars = [list(token) for token in context_tokens]
163 | spans = convert_idx(context, context_tokens)
164 | ques = question.replace("''", '" ').replace("``", '" ')
165 | ques_tokens = word_tokenize(ques)
166 | ques_chars = [list(token) for token in ques_tokens]
167 |
168 | context_idxs = np.zeros([1, len(context_tokens)], dtype=np.int32)
169 | context_char_idxs = np.zeros(
170 | [1, len(context_tokens), char_limit], dtype=np.int32)
171 | ques_idxs = np.zeros([1, len(ques_tokens)], dtype=np.int32)
172 | ques_char_idxs = np.zeros(
173 | [1, len(ques_tokens), char_limit], dtype=np.int32)
174 |
175 | def _get_word(word):
176 | for each in (word, word.lower(), word.capitalize(), word.upper()):
177 | if each in self.word2idx_dict:
178 | return self.word2idx_dict[each]
179 | return 1
180 |
181 | def _get_char(char):
182 | if char in self.char2idx_dict:
183 | return self.char2idx_dict[char]
184 | return 1
185 |
186 | for i, token in enumerate(context_tokens):
187 | context_idxs[0, i] = _get_word(token)
188 |
189 | for i, token in enumerate(ques_tokens):
190 | ques_idxs[0, i] = _get_word(token)
191 |
192 | for i, token in enumerate(context_chars):
193 | for j, char in enumerate(token):
194 | if j == char_limit:
195 | break
196 | context_char_idxs[0, i, j] = _get_char(char)
197 |
198 | for i, token in enumerate(ques_chars):
199 | for j, char in enumerate(token):
200 | if j == char_limit:
201 | break
202 | ques_char_idxs[0, i, j] = _get_char(char)
203 | return spans, context_idxs, ques_idxs, context_char_idxs, ques_char_idxs
204 |
205 |
206 | # Demo, example from paper "SQuAD: 100,000+ Questions for Machine Comprehension of Text"
207 | if __name__ == "__main__":
208 | infer = Inference()
209 | context = "In meteorology, precipitation is any product of the condensation " \
210 | "of atmospheric water vapor that falls under gravity. The main forms " \
211 |     "of precipitation include drizzle, rain, sleet, snow, graupel and hail. " \
212 | "Precipitation forms as smaller droplets coalesce via collision with other " \
213 | "rain drops or ice crystals within a cloud. Short, intense periods of rain " \
214 | "in scattered locations are called “showers”."
215 | ques1 = "What causes precipitation to fall?"
216 | ques2 = "What is another main form of precipitation besides drizzle, rain, snow, sleet and hail?"
217 | ques3 = "Where do water droplets collide with ice crystals to form precipitation?"
218 |
219 | # Correct: gravity, Output: drizzle, rain, sleet, snow, graupel and hail
220 | ans1 = infer.response(context, ques1)
221 | print("Answer 1: {}".format(ans1))
222 |
223 | # Correct: graupel, Output: graupel
224 | ans2 = infer.response(context, ques2)
225 | print("Answer 2: {}".format(ans2))
226 |
227 | # Correct: within a cloud, Output: within a cloud
228 | ans3 = infer.response(context, ques3)
229 | print("Answer 3: {}".format(ans3))
230 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import ujson as json
3 | import numpy as np
4 | from tqdm import tqdm
5 | import os
6 |
7 | from model import Model
8 | from util import get_record_parser, convert_tokens, evaluate, get_batch_dataset, get_dataset
9 |
10 |
11 | def train(config):
12 | with open(config.word_emb_file, "r") as fh:
13 | word_mat = np.array(json.load(fh), dtype=np.float32)
14 | with open(config.char_emb_file, "r") as fh:
15 | char_mat = np.array(json.load(fh), dtype=np.float32)
16 | with open(config.train_eval_file, "r") as fh:
17 | train_eval_file = json.load(fh)
18 | with open(config.dev_eval_file, "r") as fh:
19 | dev_eval_file = json.load(fh)
20 | with open(config.dev_meta, "r") as fh:
21 | meta = json.load(fh)
22 |
23 | dev_total = meta["total"]
24 |
25 | print("Building model...")
26 | parser = get_record_parser(config)
27 | train_dataset = get_batch_dataset(config.train_record_file, parser, config)
28 | dev_dataset = get_dataset(config.dev_record_file, parser, config)
29 | handle = tf.placeholder(tf.string, shape=[])
30 | iterator = tf.data.Iterator.from_string_handle(
31 | handle, train_dataset.output_types, train_dataset.output_shapes)
32 | train_iterator = train_dataset.make_one_shot_iterator()
33 | dev_iterator = dev_dataset.make_one_shot_iterator()
34 |
35 | model = Model(config, iterator, word_mat, char_mat)
36 |
37 | sess_config = tf.ConfigProto(allow_soft_placement=True)
38 | sess_config.gpu_options.allow_growth = True
39 |
40 | loss_save = 100.0
41 | patience = 0
42 | lr = config.init_lr
43 |
44 | with tf.Session(config=sess_config) as sess:
45 | writer = tf.summary.FileWriter(config.log_dir)
46 | sess.run(tf.global_variables_initializer())
47 | saver = tf.train.Saver()
48 | train_handle = sess.run(train_iterator.string_handle())
49 | dev_handle = sess.run(dev_iterator.string_handle())
50 | sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool)))
51 | sess.run(tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))
52 |
53 | for _ in tqdm(range(1, config.num_steps + 1)):
54 | global_step = sess.run(model.global_step) + 1
55 |             loss, _ = sess.run([model.loss, model.train_op], feed_dict={
56 | handle: train_handle})
57 | if global_step % config.period == 0:
58 | loss_sum = tf.Summary(value=[tf.Summary.Value(
59 | tag="model/loss", simple_value=loss), ])
60 | writer.add_summary(loss_sum, global_step)
61 | if global_step % config.checkpoint == 0:
62 | sess.run(tf.assign(model.is_train,
63 | tf.constant(False, dtype=tf.bool)))
64 | _, summ = evaluate_batch(
65 | model, config.val_num_batches, train_eval_file, sess, "train", handle, train_handle)
66 | for s in summ:
67 | writer.add_summary(s, global_step)
68 |
69 | metrics, summ = evaluate_batch(
70 | model, dev_total // config.batch_size + 1, dev_eval_file, sess, "dev", handle, dev_handle)
71 | sess.run(tf.assign(model.is_train,
72 | tf.constant(True, dtype=tf.bool)))
73 |
74 | dev_loss = metrics["loss"]
75 | if dev_loss < loss_save:
76 | loss_save = dev_loss
77 | patience = 0
78 | else:
79 | patience += 1
80 | if patience >= config.patience:
81 | lr /= 2.0
82 | loss_save = dev_loss
83 | patience = 0
84 | sess.run(tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))
85 | for s in summ:
86 | writer.add_summary(s, global_step)
87 | writer.flush()
88 | filename = os.path.join(
89 | config.save_dir, "model_{}.ckpt".format(global_step))
90 | saver.save(sess, filename)
91 |
92 |
93 | def evaluate_batch(model, num_batches, eval_file, sess, data_type, handle, str_handle):
94 | answer_dict = {}
95 | losses = []
96 | for _ in tqdm(range(1, num_batches + 1)):
97 |         qa_id, loss, yp1, yp2 = sess.run(
98 | [model.qa_id, model.loss, model.yp1, model.yp2], feed_dict={handle: str_handle})
99 | answer_dict_, _ = convert_tokens(
100 | eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist())
101 | answer_dict.update(answer_dict_)
102 | losses.append(loss)
103 | loss = np.mean(losses)
104 | metrics = evaluate(eval_file, answer_dict)
105 | metrics["loss"] = loss
106 | loss_sum = tf.Summary(value=[tf.Summary.Value(
107 | tag="{}/loss".format(data_type), simple_value=metrics["loss"]), ])
108 | f1_sum = tf.Summary(value=[tf.Summary.Value(
109 | tag="{}/f1".format(data_type), simple_value=metrics["f1"]), ])
110 | em_sum = tf.Summary(value=[tf.Summary.Value(
111 | tag="{}/em".format(data_type), simple_value=metrics["exact_match"]), ])
112 | return metrics, [loss_sum, f1_sum, em_sum]
113 |
114 |
115 | def test(config):
116 | with open(config.word_emb_file, "r") as fh:
117 | word_mat = np.array(json.load(fh), dtype=np.float32)
118 | with open(config.char_emb_file, "r") as fh:
119 | char_mat = np.array(json.load(fh), dtype=np.float32)
120 | with open(config.test_eval_file, "r") as fh:
121 | eval_file = json.load(fh)
122 | with open(config.test_meta, "r") as fh:
123 | meta = json.load(fh)
124 |
125 | total = meta["total"]
126 |
127 | print("Loading model...")
128 | test_batch = get_dataset(config.test_record_file, get_record_parser(
129 | config, is_test=True), config).make_one_shot_iterator()
130 |
131 | model = Model(config, test_batch, word_mat, char_mat, trainable=False)
132 |
133 | sess_config = tf.ConfigProto(allow_soft_placement=True)
134 | sess_config.gpu_options.allow_growth = True
135 |
136 | with tf.Session(config=sess_config) as sess:
137 | sess.run(tf.global_variables_initializer())
138 | saver = tf.train.Saver()
139 | saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
140 | sess.run(tf.assign(model.is_train, tf.constant(False, dtype=tf.bool)))
141 | losses = []
142 | answer_dict = {}
143 | remapped_dict = {}
144 | for step in tqdm(range(total // config.batch_size + 1)):
145 | qa_id, loss, yp1, yp2 = sess.run(
146 | [model.qa_id, model.loss, model.yp1, model.yp2])
147 | answer_dict_, remapped_dict_ = convert_tokens(
148 | eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist())
149 | answer_dict.update(answer_dict_)
150 | remapped_dict.update(remapped_dict_)
151 | losses.append(loss)
152 | loss = np.mean(losses)
153 | metrics = evaluate(eval_file, answer_dict)
154 | with open(config.answer_file, "w") as fh:
155 | json.dump(remapped_dict, fh)
156 | print("Exact Match: {}, F1: {}".format(
157 | metrics['exact_match'], metrics['f1']))
158 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from func import cudnn_gru, native_gru, dot_attention, summ, dropout, ptr_net
3 |
4 |
5 | class Model(object):
6 | def __init__(self, config, batch, word_mat=None, char_mat=None, trainable=True, opt=True):
7 | self.config = config
8 | self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32,
9 | initializer=tf.constant_initializer(0), trainable=False)
10 | self.c, self.q, self.ch, self.qh, self.y1, self.y2, self.qa_id = batch.get_next()
11 | self.is_train = tf.get_variable(
12 | "is_train", shape=[], dtype=tf.bool, trainable=False)
13 | self.word_mat = tf.get_variable("word_mat", initializer=tf.constant(
14 | word_mat, dtype=tf.float32), trainable=False)
15 | self.char_mat = tf.get_variable(
16 | "char_mat", initializer=tf.constant(char_mat, dtype=tf.float32))
17 |
18 | self.c_mask = tf.cast(self.c, tf.bool)
19 | self.q_mask = tf.cast(self.q, tf.bool)
20 | self.c_len = tf.reduce_sum(tf.cast(self.c_mask, tf.int32), axis=1)
21 | self.q_len = tf.reduce_sum(tf.cast(self.q_mask, tf.int32), axis=1)
22 |
23 | if opt:
24 | N, CL = config.batch_size, config.char_limit
25 | self.c_maxlen = tf.reduce_max(self.c_len)
26 | self.q_maxlen = tf.reduce_max(self.q_len)
27 | self.c = tf.slice(self.c, [0, 0], [N, self.c_maxlen])
28 | self.q = tf.slice(self.q, [0, 0], [N, self.q_maxlen])
29 | self.c_mask = tf.slice(self.c_mask, [0, 0], [N, self.c_maxlen])
30 | self.q_mask = tf.slice(self.q_mask, [0, 0], [N, self.q_maxlen])
31 | self.ch = tf.slice(self.ch, [0, 0, 0], [N, self.c_maxlen, CL])
32 | self.qh = tf.slice(self.qh, [0, 0, 0], [N, self.q_maxlen, CL])
33 | self.y1 = tf.slice(self.y1, [0, 0], [N, self.c_maxlen])
34 | self.y2 = tf.slice(self.y2, [0, 0], [N, self.c_maxlen])
35 | else:
36 | self.c_maxlen, self.q_maxlen = config.para_limit, config.ques_limit
37 |
38 | self.ch_len = tf.reshape(tf.reduce_sum(
39 | tf.cast(tf.cast(self.ch, tf.bool), tf.int32), axis=2), [-1])
40 | self.qh_len = tf.reshape(tf.reduce_sum(
41 | tf.cast(tf.cast(self.qh, tf.bool), tf.int32), axis=2), [-1])
42 |
43 | self.ready()
44 |
45 | if trainable:
46 | self.lr = tf.get_variable(
47 | "lr", shape=[], dtype=tf.float32, trainable=False)
48 | self.opt = tf.train.AdadeltaOptimizer(
49 | learning_rate=self.lr, epsilon=1e-6)
50 | grads = self.opt.compute_gradients(self.loss)
51 | gradients, variables = zip(*grads)
52 | capped_grads, _ = tf.clip_by_global_norm(
53 | gradients, config.grad_clip)
54 | self.train_op = self.opt.apply_gradients(
55 | zip(capped_grads, variables), global_step=self.global_step)
56 |
57 | def ready(self):
58 | config = self.config
59 | N, PL, QL, CL, d, dc, dg = config.batch_size, self.c_maxlen, self.q_maxlen, config.char_limit, config.hidden, config.char_dim, config.char_hidden
60 | gru = cudnn_gru if config.use_cudnn else native_gru
61 |
62 | with tf.variable_scope("emb"):
63 | with tf.variable_scope("char"):
64 | ch_emb = tf.reshape(tf.nn.embedding_lookup(
65 | self.char_mat, self.ch), [N * PL, CL, dc])
66 | qh_emb = tf.reshape(tf.nn.embedding_lookup(
67 | self.char_mat, self.qh), [N * QL, CL, dc])
68 | ch_emb = dropout(
69 | ch_emb, keep_prob=config.keep_prob, is_train=self.is_train)
70 | qh_emb = dropout(
71 | qh_emb, keep_prob=config.keep_prob, is_train=self.is_train)
72 | cell_fw = tf.contrib.rnn.GRUCell(dg)
73 | cell_bw = tf.contrib.rnn.GRUCell(dg)
74 | _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
75 | cell_fw, cell_bw, ch_emb, self.ch_len, dtype=tf.float32)
76 | ch_emb = tf.concat([state_fw, state_bw], axis=1)
77 | _, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
78 | cell_fw, cell_bw, qh_emb, self.qh_len, dtype=tf.float32)
79 | qh_emb = tf.concat([state_fw, state_bw], axis=1)
80 | qh_emb = tf.reshape(qh_emb, [N, QL, 2 * dg])
81 | ch_emb = tf.reshape(ch_emb, [N, PL, 2 * dg])
82 |
83 | with tf.name_scope("word"):
84 | c_emb = tf.nn.embedding_lookup(self.word_mat, self.c)
85 | q_emb = tf.nn.embedding_lookup(self.word_mat, self.q)
86 |
87 | c_emb = tf.concat([c_emb, ch_emb], axis=2)
88 | q_emb = tf.concat([q_emb, qh_emb], axis=2)
89 |
90 | with tf.variable_scope("encoding"):
91 | rnn = gru(num_layers=3, num_units=d, batch_size=N, input_size=c_emb.get_shape(
92 | ).as_list()[-1], keep_prob=config.keep_prob, is_train=self.is_train)
93 | c = rnn(c_emb, seq_len=self.c_len)
94 | q = rnn(q_emb, seq_len=self.q_len)
95 |
96 | with tf.variable_scope("attention"):
97 | qc_att = dot_attention(c, q, mask=self.q_mask, hidden=d,
98 | keep_prob=config.keep_prob, is_train=self.is_train)
99 | rnn = gru(num_layers=1, num_units=d, batch_size=N, input_size=qc_att.get_shape(
100 | ).as_list()[-1], keep_prob=config.keep_prob, is_train=self.is_train)
101 | att = rnn(qc_att, seq_len=self.c_len)
102 |
103 | with tf.variable_scope("match"):
104 | self_att = dot_attention(
105 | att, att, mask=self.c_mask, hidden=d, keep_prob=config.keep_prob, is_train=self.is_train)
106 | rnn = gru(num_layers=1, num_units=d, batch_size=N, input_size=self_att.get_shape(
107 | ).as_list()[-1], keep_prob=config.keep_prob, is_train=self.is_train)
108 | match = rnn(self_att, seq_len=self.c_len)
109 |
110 | with tf.variable_scope("pointer"):
111 | init = summ(q[:, :, -2 * d:], d, mask=self.q_mask,
112 | keep_prob=config.ptr_keep_prob, is_train=self.is_train)
113 | pointer = ptr_net(batch=N, hidden=init.get_shape().as_list(
114 | )[-1], keep_prob=config.ptr_keep_prob, is_train=self.is_train)
115 | logits1, logits2 = pointer(init, match, d, self.c_mask)
116 |
117 | with tf.variable_scope("predict"):
118 | outer = tf.matmul(tf.expand_dims(tf.nn.softmax(logits1), axis=2),
119 | tf.expand_dims(tf.nn.softmax(logits2), axis=1))
120 | outer = tf.matrix_band_part(outer, 0, 15)
121 | self.yp1 = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
122 | self.yp2 = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
123 | losses = tf.nn.softmax_cross_entropy_with_logits_v2(
124 | logits=logits1, labels=tf.stop_gradient(self.y1))
125 | losses2 = tf.nn.softmax_cross_entropy_with_logits_v2(
126 | logits=logits2, labels=tf.stop_gradient(self.y2))
127 | self.loss = tf.reduce_mean(losses + losses2)
128 |
129 | def get_loss(self):
130 | return self.loss
131 |
132 | def get_global_step(self):
133 | return self.global_step
134 |
--------------------------------------------------------------------------------
/prepro.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import random
3 | from tqdm import tqdm
4 | import spacy
5 | import ujson as json
6 | from collections import Counter
7 | import numpy as np
8 | import os.path
9 |
10 | nlp = spacy.blank("en")
11 |
12 |
13 | def word_tokenize(sent):
14 | doc = nlp(sent)
15 | return [token.text for token in doc]
16 |
17 |
18 | def convert_idx(text, tokens):
19 | current = 0
20 | spans = []
21 | for token in tokens:
22 | current = text.find(token, current)
23 | if current < 0:
24 | print("Token {} cannot be found".format(token))
25 | raise Exception()
26 | spans.append((current, current + len(token)))
27 | current += len(token)
28 | return spans
29 |
30 |
31 | def process_file(filename, data_type, word_counter, char_counter):
32 | print("Generating {} examples...".format(data_type))
33 | examples = []
34 | eval_examples = {}
35 | total = 0
36 | with open(filename, "r") as fh:
37 | source = json.load(fh)
38 | for article in tqdm(source["data"]):
39 | for para in article["paragraphs"]:
40 | context = para["context"].replace(
41 | "''", '" ').replace("``", '" ')
42 | context_tokens = word_tokenize(context)
43 | context_chars = [list(token) for token in context_tokens]
44 | spans = convert_idx(context, context_tokens)
45 | for token in context_tokens:
46 | word_counter[token] += len(para["qas"])
47 | for char in token:
48 | char_counter[char] += len(para["qas"])
49 | for qa in para["qas"]:
50 | total += 1
51 | ques = qa["question"].replace(
52 | "''", '" ').replace("``", '" ')
53 | ques_tokens = word_tokenize(ques)
54 | ques_chars = [list(token) for token in ques_tokens]
55 | for token in ques_tokens:
56 | word_counter[token] += 1
57 | for char in token:
58 | char_counter[char] += 1
59 | y1s, y2s = [], []
60 | answer_texts = []
61 | for answer in qa["answers"]:
62 | answer_text = answer["text"]
63 | answer_start = answer['answer_start']
64 | answer_end = answer_start + len(answer_text)
65 | answer_texts.append(answer_text)
66 | answer_span = []
67 | for idx, span in enumerate(spans):
68 | if not (answer_end <= span[0] or answer_start >= span[1]):
69 | answer_span.append(idx)
70 | y1, y2 = answer_span[0], answer_span[-1]
71 | y1s.append(y1)
72 | y2s.append(y2)
73 | example = {"context_tokens": context_tokens, "context_chars": context_chars, "ques_tokens": ques_tokens,
74 | "ques_chars": ques_chars, "y1s": y1s, "y2s": y2s, "id": total}
75 | examples.append(example)
76 | eval_examples[str(total)] = {
77 | "context": context, "spans": spans, "answers": answer_texts, "uuid": qa["id"]}
78 | random.shuffle(examples)
79 | print("{} questions in total".format(len(examples)))
80 | return examples, eval_examples
81 |
82 |
83 | def get_embedding(counter, data_type, limit=-1, emb_file=None, size=None, vec_size=None, token2idx_dict=None):
84 | print("Generating {} embedding...".format(data_type))
85 | embedding_dict = {}
86 | filtered_elements = [k for k, v in counter.items() if v > limit]
87 | if emb_file is not None:
88 | assert size is not None
89 | assert vec_size is not None
90 | with open(emb_file, "r", encoding="utf-8") as fh:
91 | for line in tqdm(fh, total=size):
92 | array = line.split()
93 | word = "".join(array[0:-vec_size])
94 | vector = list(map(float, array[-vec_size:]))
95 | if word in counter and counter[word] > limit:
96 | embedding_dict[word] = vector
97 | print("{} / {} tokens have corresponding {} embedding vector".format(
98 | len(embedding_dict), len(filtered_elements), data_type))
99 | else:
100 | assert vec_size is not None
101 | for token in filtered_elements:
102 | embedding_dict[token] = [np.random.normal(
103 | scale=0.01) for _ in range(vec_size)]
104 | print("{} tokens have corresponding embedding vector".format(
105 | len(filtered_elements)))
106 |
107 | NULL = "--NULL--"
108 | OOV = "--OOV--"
109 | token2idx_dict = {token: idx for idx, token in enumerate(
110 | embedding_dict.keys(), 2)} if token2idx_dict is None else token2idx_dict
111 | token2idx_dict[NULL] = 0
112 | token2idx_dict[OOV] = 1
113 | embedding_dict[NULL] = [0. for _ in range(vec_size)]
114 | embedding_dict[OOV] = [0. for _ in range(vec_size)]
115 | idx2emb_dict = {idx: embedding_dict[token]
116 | for token, idx in token2idx_dict.items()}
117 | emb_mat = [idx2emb_dict[idx] for idx in range(len(idx2emb_dict))]
118 | return emb_mat, token2idx_dict
119 |
120 |
121 | def build_features(config, examples, data_type, out_file, word2idx_dict, char2idx_dict, is_test=False):
122 |
123 | para_limit = config.test_para_limit if is_test else config.para_limit
124 | ques_limit = config.test_ques_limit if is_test else config.ques_limit
125 | char_limit = config.char_limit
126 |
127 | def filter_func(example, is_test=False):
128 | return len(example["context_tokens"]) > para_limit or len(example["ques_tokens"]) > ques_limit
129 |
130 | print("Processing {} examples...".format(data_type))
131 | writer = tf.python_io.TFRecordWriter(out_file)
132 | total = 0
133 | total_ = 0
134 | meta = {}
135 | for example in tqdm(examples):
136 | total_ += 1
137 |
138 | if filter_func(example, is_test):
139 | continue
140 |
141 | total += 1
142 | context_idxs = np.zeros([para_limit], dtype=np.int32)
143 | context_char_idxs = np.zeros([para_limit, char_limit], dtype=np.int32)
144 | ques_idxs = np.zeros([ques_limit], dtype=np.int32)
145 | ques_char_idxs = np.zeros([ques_limit, char_limit], dtype=np.int32)
146 | y1 = np.zeros([para_limit], dtype=np.float32)
147 | y2 = np.zeros([para_limit], dtype=np.float32)
148 |
149 | def _get_word(word):
150 | for each in (word, word.lower(), word.capitalize(), word.upper()):
151 | if each in word2idx_dict:
152 | return word2idx_dict[each]
153 | return 1
154 |
155 | def _get_char(char):
156 | if char in char2idx_dict:
157 | return char2idx_dict[char]
158 | return 1
159 |
160 | for i, token in enumerate(example["context_tokens"]):
161 | context_idxs[i] = _get_word(token)
162 |
163 | for i, token in enumerate(example["ques_tokens"]):
164 | ques_idxs[i] = _get_word(token)
165 |
166 | for i, token in enumerate(example["context_chars"]):
167 | for j, char in enumerate(token):
168 | if j == char_limit:
169 | break
170 | context_char_idxs[i, j] = _get_char(char)
171 |
172 | for i, token in enumerate(example["ques_chars"]):
173 | for j, char in enumerate(token):
174 | if j == char_limit:
175 | break
176 | ques_char_idxs[i, j] = _get_char(char)
177 |
178 | start, end = example["y1s"][-1], example["y2s"][-1]
179 | y1[start], y2[end] = 1.0, 1.0
180 |
181 | record = tf.train.Example(features=tf.train.Features(feature={
182 | "context_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[context_idxs.tostring()])),
183 | "ques_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[ques_idxs.tostring()])),
184 | "context_char_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[context_char_idxs.tostring()])),
185 | "ques_char_idxs": tf.train.Feature(bytes_list=tf.train.BytesList(value=[ques_char_idxs.tostring()])),
186 | "y1": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y1.tostring()])),
187 | "y2": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y2.tostring()])),
188 | "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[example["id"]]))
189 | }))
190 | writer.write(record.SerializeToString())
191 | print("Build {} / {} instances of features in total".format(total, total_))
192 | meta["total"] = total
193 | writer.close()
194 | return meta
195 |
196 |
197 | def save(filename, obj, message=None):
198 | if message is not None:
199 | print("Saving {}...".format(message))
200 | with open(filename, "w") as fh:
201 | json.dump(obj, fh)
202 |
203 |
204 | def prepro(config):
205 | word_counter, char_counter = Counter(), Counter()
206 | train_examples, train_eval = process_file(
207 | config.train_file, "train", word_counter, char_counter)
208 | dev_examples, dev_eval = process_file(
209 | config.dev_file, "dev", word_counter, char_counter)
210 | test_examples, test_eval = process_file(
211 | config.test_file, "test", word_counter, char_counter)
212 |
213 | word_emb_file = config.fasttext_file if config.fasttext else config.glove_word_file
214 | char_emb_file = config.glove_char_file if config.pretrained_char else None
215 | char_emb_size = config.glove_char_size if config.pretrained_char else None
216 | char_emb_dim = config.glove_dim if config.pretrained_char else config.char_dim
217 |
218 | word2idx_dict = None
219 | if os.path.isfile(config.word2idx_file):
220 | with open(config.word2idx_file, "r") as fh:
221 | word2idx_dict = json.load(fh)
222 | word_emb_mat, word2idx_dict = get_embedding(word_counter, "word", emb_file=word_emb_file,
223 | size=config.glove_word_size, vec_size=config.glove_dim, token2idx_dict=word2idx_dict)
224 |
225 | char2idx_dict = None
226 | if os.path.isfile(config.char2idx_file):
227 | with open(config.char2idx_file, "r") as fh:
228 | char2idx_dict = json.load(fh)
229 | char_emb_mat, char2idx_dict = get_embedding(
230 | char_counter, "char", emb_file=char_emb_file, size=char_emb_size, vec_size=char_emb_dim, token2idx_dict=char2idx_dict)
231 |
232 | build_features(config, train_examples, "train",
233 | config.train_record_file, word2idx_dict, char2idx_dict)
234 | dev_meta = build_features(config, dev_examples, "dev",
235 | config.dev_record_file, word2idx_dict, char2idx_dict)
236 | test_meta = build_features(config, test_examples, "test",
237 | config.test_record_file, word2idx_dict, char2idx_dict, is_test=True)
238 |
239 | save(config.word_emb_file, word_emb_mat, message="word embedding")
240 | save(config.char_emb_file, char_emb_mat, message="char embedding")
241 | save(config.train_eval_file, train_eval, message="train eval")
242 | save(config.dev_eval_file, dev_eval, message="dev eval")
243 | save(config.test_eval_file, test_eval, message="test eval")
244 | save(config.dev_meta, dev_meta, message="dev meta")
245 | save(config.word2idx_file, word2idx_dict, message="word2idx")
246 | save(config.char2idx_file, char2idx_dict, message="char2idx")
247 | save(config.test_meta, test_meta, message="test meta")
248 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import re
4 | from collections import Counter
5 | import string
6 |
7 |
8 | def get_record_parser(config, is_test=False):
9 | def parse(example):
10 | para_limit = config.test_para_limit if is_test else config.para_limit
11 | ques_limit = config.test_ques_limit if is_test else config.ques_limit
12 | char_limit = config.char_limit
13 | features = tf.parse_single_example(example,
14 | features={
15 | "context_idxs": tf.FixedLenFeature([], tf.string),
16 | "ques_idxs": tf.FixedLenFeature([], tf.string),
17 | "context_char_idxs": tf.FixedLenFeature([], tf.string),
18 | "ques_char_idxs": tf.FixedLenFeature([], tf.string),
19 | "y1": tf.FixedLenFeature([], tf.string),
20 | "y2": tf.FixedLenFeature([], tf.string),
21 | "id": tf.FixedLenFeature([], tf.int64)
22 | })
23 | context_idxs = tf.reshape(tf.decode_raw(
24 | features["context_idxs"], tf.int32), [para_limit])
25 | ques_idxs = tf.reshape(tf.decode_raw(
26 | features["ques_idxs"], tf.int32), [ques_limit])
27 | context_char_idxs = tf.reshape(tf.decode_raw(
28 | features["context_char_idxs"], tf.int32), [para_limit, char_limit])
29 | ques_char_idxs = tf.reshape(tf.decode_raw(
30 | features["ques_char_idxs"], tf.int32), [ques_limit, char_limit])
31 | y1 = tf.reshape(tf.decode_raw(
32 | features["y1"], tf.float32), [para_limit])
33 | y2 = tf.reshape(tf.decode_raw(
34 | features["y2"], tf.float32), [para_limit])
35 | qa_id = features["id"]
36 | return context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, y1, y2, qa_id
37 | return parse
38 |
39 |
40 | def get_batch_dataset(record_file, parser, config):
41 | num_threads = tf.constant(config.num_threads, dtype=tf.int32)
42 | dataset = tf.data.TFRecordDataset(record_file).map(
43 | parser, num_parallel_calls=num_threads).shuffle(config.capacity).repeat()
44 | if config.is_bucket:
45 | buckets = [tf.constant(num) for num in range(*config.bucket_range)]
46 |
47 | def key_func(context_idxs, ques_idxs, context_char_idxs, ques_char_idxs, y1, y2, qa_id):
48 | c_len = tf.reduce_sum(
49 | tf.cast(tf.cast(context_idxs, tf.bool), tf.int32))
50 | buckets_min = [np.iinfo(np.int32).min] + buckets
51 | buckets_max = buckets + [np.iinfo(np.int32).max]
52 | conditions_c = tf.logical_and(
53 | tf.less(buckets_min, c_len), tf.less_equal(c_len, buckets_max))
54 | bucket_id = tf.reduce_min(tf.where(conditions_c))
55 | return bucket_id
56 |
57 | def reduce_func(key, elements):
58 | return elements.batch(config.batch_size)
59 |
60 | dataset = dataset.apply(tf.contrib.data.group_by_window(
61 | key_func, reduce_func, window_size=5 * config.batch_size)).shuffle(len(buckets) * 25)
62 | else:
63 | dataset = dataset.batch(config.batch_size)
64 | return dataset
65 |
66 |
67 | def get_dataset(record_file, parser, config):
68 | num_threads = tf.constant(config.num_threads, dtype=tf.int32)
69 | dataset = tf.data.TFRecordDataset(record_file).map(
70 | parser, num_parallel_calls=num_threads).repeat().batch(config.batch_size)
71 | return dataset
72 |
73 |
74 | def convert_tokens(eval_file, qa_id, pp1, pp2):
75 | answer_dict = {}
76 | remapped_dict = {}
77 | for qid, p1, p2 in zip(qa_id, pp1, pp2):
78 | context = eval_file[str(qid)]["context"]
79 | spans = eval_file[str(qid)]["spans"]
80 | uuid = eval_file[str(qid)]["uuid"]
81 | start_idx = spans[p1][0]
82 | end_idx = spans[p2][1]
83 | answer_dict[str(qid)] = context[start_idx: end_idx]
84 | remapped_dict[uuid] = context[start_idx: end_idx]
85 | return answer_dict, remapped_dict
86 |
87 |
88 | def evaluate(eval_file, answer_dict):
89 | f1 = exact_match = total = 0
90 | for key, value in answer_dict.items():
91 | total += 1
92 | ground_truths = eval_file[key]["answers"]
93 | prediction = value
94 | exact_match += metric_max_over_ground_truths(
95 | exact_match_score, prediction, ground_truths)
96 | f1 += metric_max_over_ground_truths(f1_score,
97 | prediction, ground_truths)
98 | exact_match = 100.0 * exact_match / total
99 | f1 = 100.0 * f1 / total
100 | return {'exact_match': exact_match, 'f1': f1}
101 |
102 |
103 | def normalize_answer(s):
104 |
105 | def remove_articles(text):
106 | return re.sub(r'\b(a|an|the)\b', ' ', text)
107 |
108 | def white_space_fix(text):
109 | return ' '.join(text.split())
110 |
111 | def remove_punc(text):
112 | exclude = set(string.punctuation)
113 | return ''.join(ch for ch in text if ch not in exclude)
114 |
115 | def lower(text):
116 | return text.lower()
117 |
118 | return white_space_fix(remove_articles(remove_punc(lower(s))))
119 |
120 |
121 | def f1_score(prediction, ground_truth):
122 | prediction_tokens = normalize_answer(prediction).split()
123 | ground_truth_tokens = normalize_answer(ground_truth).split()
124 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
125 | num_same = sum(common.values())
126 | if num_same == 0:
127 | return 0
128 | precision = 1.0 * num_same / len(prediction_tokens)
129 | recall = 1.0 * num_same / len(ground_truth_tokens)
130 | f1 = (2 * precision * recall) / (precision + recall)
131 | return f1
132 |
133 |
134 | def exact_match_score(prediction, ground_truth):
135 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
136 |
137 |
138 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
139 | scores_for_ground_truths = []
140 | for ground_truth in ground_truths:
141 | score = metric_fn(prediction, ground_truth)
142 | scores_for_ground_truths.append(score)
143 | return max(scores_for_ground_truths)
144 |
--------------------------------------------------------------------------------