├── LICENSE
├── README.md
├── basic
│   ├── __init__.py
│   ├── cli.py
│   ├── ensemble.py
│   ├── ensemble_fast.py
│   ├── evaluator.py
│   ├── graph_handler.py
│   ├── main.py
│   ├── model.py
│   ├── read_data.py
│   ├── run_ensemble.sh
│   ├── run_single.sh
│   ├── templates
│   │   └── visualizer.html
│   ├── trainer.py
│   └── visualizer.py
├── basic_cnn
│   ├── __init__.py
│   ├── cli.py
│   ├── evaluator.py
│   ├── graph_handler.py
│   ├── main.py
│   ├── model.py
│   ├── read_data.py
│   ├── superhighway.py
│   ├── templates
│   │   └── visualizer.html
│   ├── trainer.py
│   └── visualizer.py
├── cnn_dm
│   ├── __init__.py
│   ├── eda.ipynb
│   ├── evaluate.py
│   └── prepro.py
├── download.sh
├── my
│   ├── __init__.py
│   ├── corenlp_interface.py
│   ├── nltk_utils.py
│   ├── tensorflow
│   │   ├── __init__.py
│   │   ├── general.py
│   │   ├── nn.py
│   │   ├── rnn.py
│   │   └── rnn_cell.py
│   ├── utils.py
│   └── zip_save.py
├── requirements.txt
├── squad
│   ├── __init__.py
│   ├── aug_squad.py
│   ├── eda_aug_dev.ipynb
│   ├── eda_aug_train.ipynb
│   ├── evaluate-v1.1.py
│   ├── evaluate.py
│   ├── prepro.py
│   ├── prepro_aug.py
│   └── utils.py
├── tree
│   ├── __init__.py
│   ├── cli.py
│   ├── evaluator.py
│   ├── graph_handler.py
│   ├── main.py
│   ├── model.py
│   ├── read_data.py
│   ├── templates
│   │   └── visualizer.html
│   ├── test.ipynb
│   ├── trainer.py
│   └── visualizer.py
└── visualization
    └── compare_models.py
/README.md:
--------------------------------------------------------------------------------
1 | # Bi-directional Attention Flow for Machine Comprehension
2 |
3 | - This is the original implementation of [Bi-directional Attention Flow for Machine Comprehension][paper].
4 | - The CodaLab worksheet for the [SQuAD Leaderboard][squad] submission is available [here][worksheet].
5 | - For a TensorFlow v1.2-compatible version, see the [dev][dev] branch.
6 | - Please contact [Minjoon Seo][minjoon] ([@seominjoon][minjoon-github]) for questions and suggestions.
7 |
8 | ## 0. Requirements
9 | #### General
10 | - Python (verified on 3.5.2; issues have been reported with Python 2!)
11 | - unzip, wget (for running `download.sh` only)
12 |
13 | #### Python Packages
14 | - tensorflow (deep learning library, only works on r0.11)
15 | - nltk (NLP tools, verified on 3.2.1)
16 | - tqdm (progress bar, verified on 4.7.4)
17 | - jinja2 (for visualization; not needed if you only train and test)
18 |
19 | ## 1. Pre-processing
20 | First, prepare the data. Download the SQuAD dataset, the GloVe vectors, and the nltk corpus
21 | (~850 MB; this will download files to `$HOME/data`):
22 | ```
23 | chmod +x download.sh; ./download.sh
24 | ```
25 |
26 | Second, preprocess the Stanford QA dataset (along with the GloVe vectors) and save the result in `$PWD/data/squad` (~5 minutes):
27 | ```
28 | python -m squad.prepro
29 | ```
30 |
31 | ## 2. Training
32 | The model has ~2.5M parameters.
33 | The model was trained with NVidia Titan X (Pascal Architecture, 2016).
34 | The model requires at least 12GB of GPU RAM.
35 | If your GPU has less than 12GB of RAM, you can either decrease the batch size (performance might degrade)
36 | or use multiple GPUs (see below).
37 | Training converges at ~18k steps and takes ~4s per step (i.e. ~20 hours in total).
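38 | For example, to train with a batch size of 30 instead of the default 60 (a minimal illustration; the actual memory savings depend on your setup):
39 | ```
40 | python -m basic.cli --mode train --noload --batch_size 30
41 | ```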
38 |
39 | Before training, it is recommended to run the following command first to verify that everything works and that memory is sufficient:
40 | ```
41 | python -m basic.cli --mode train --noload --debug
42 | ```
43 |
44 | Then to fully train, run:
45 | ```
46 | python -m basic.cli --mode train --noload
47 | ```
48 |
49 | You can speed up the training process with optimization flags:
50 | ```
51 | python -m basic.cli --mode train --noload --len_opt --cluster
52 | ```
53 | You can still omit them, but training will be much slower.
54 |
55 | Note that the EM and F1 scores printed during the occasional in-training evaluation are not the same as the scores from the official SQuAD evaluation script.
56 | The printed scores are not official (our scoring scheme is a bit harsher).
57 | To obtain the official numbers, use the official evaluator (copied into the `squad` folder as `squad/evaluate-v1.1.py`). For more information, see Section 3 (Test).
58 |
59 |
60 | ## 3. Test
61 | To test, run:
62 | ```
63 | python -m basic.cli
64 | ```
65 |
66 | As with training, you can pass the optimization flags to speed up testing (~5 minutes on dev data):
67 | ```
68 | python -m basic.cli --len_opt --cluster
69 | ```
70 |
71 | This command loads the model most recently saved during training and begins testing on the test data.
72 | When the process ends, it prints the F1 and EM scores and outputs a json file (`$PWD/out/basic/00/answer/test-####.json`,
73 | where `####` is the step at which the model was saved).
74 | Note that the printed scores are not official (our scoring scheme is a bit harsher).
75 | To obtain the official numbers, use the official evaluator (copied into the `squad` folder) together with the output json file:
76 |
77 | ```
78 | python squad/evaluate-v1.1.py $HOME/data/squad/dev-v1.1.json out/basic/00/answer/test-####.json
79 | ```
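80 | 
81 | The answer json is a flat mapping from question id to predicted answer string. A hypothetical excerpt (the id and answer below are illustrative, not real output):
82 | ```
83 | {"56be4db0acb8001400a502ec": "Denver Broncos", ...}
84 | ```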
80 |
81 | ### 3.1 Loading from pre-trained weights
82 | Instead of training the model yourself, you can use the pre-trained weights that were used for the [SQuAD Leaderboard][squad] submission.
83 | Refer to [this worksheet][worksheet] in CodaLab to reproduce the results.
84 | If you are unfamiliar with CodaLab, follow these simple steps (assuming you have met all the prerequisites above):
85 |
86 | 1. Download `save.zip` from the [worksheet][worksheet] and unzip it in the current directory.
87 | 2. Copy `glove.6B.100d.txt` from your glove data folder (`$HOME/data/glove/`) to the current directory.
88 | 3. To reproduce the single model:
89 |
90 | ```
91 | basic/run_single.sh $HOME/data/squad/dev-v1.1.json single.json
92 | ```
93 |
94 | This writes the answers to `single.json` in the current directory. You can then use the official evaluator to obtain the EM and F1 scores. If you want to run on GPU (~5 mins), change the value of the `batch_size` flag in the shell file to a higher number (60 for 12GB of GPU RAM).
95 | 4. Similarly, to reproduce the ensemble method:
96 |
97 | ```
98 | basic/run_ensemble.sh $HOME/data/squad/dev-v1.1.json ensemble.json
99 | ```
100 | If you want to run on GPU, either run the script sequentially by removing the `&` in the for loop, or specify a different GPU for each iteration of the loop.
101 |
102 | ## Results
103 |
104 | ### Dev Data
105 |
106 | Note that these scores are from the official evaluator (copied into the `squad` folder as `squad/evaluate-v1.1.py`); for more information, see Section 3 (Test).
107 | The scores that appear during training can be lower than the scores from the official evaluator.
108 |
109 | | | EM (%) | F1 (%) |
110 | | -------- |:------:|:------:|
111 | | single | 67.7 | 77.3 |
112 | | ensemble | 72.6 | 80.7 |
113 |
114 | ### Test Data
115 |
116 | | | EM (%) | F1 (%) |
117 | | -------- |:------:|:------:|
118 | | single | 68.0 | 77.3 |
119 | | ensemble | 73.3 | 81.1 |
120 |
121 | Refer to [our paper][paper] for more details.
122 | See [SQuAD Leaderboard][squad] to compare with other models.
123 |
124 |
138 | ## Multi-GPU Training & Testing
139 | Our model supports multi-GPU training.
140 | We follow the parallelization paradigm described in [TensorFlow Tutorial][multi-gpu].
141 | In short, if you want an effective batch size of 60 (the default) but have 3 GPUs with 4GB of RAM each,
142 | then you give each GPU a batch size of 20 and combine the gradients on the CPU.
143 | This can be easily done by running:
144 | ```
145 | python -m basic.cli --mode train --noload --num_gpus 3 --batch_size 20
146 | ```
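147 | 
148 | For intuition, here is a minimal NumPy sketch of the combine-on-CPU step described above (this is not the repo's TensorFlow code; the actual averaging is done by `average_gradients` in `my/tensorflow`, used by `basic/trainer.py`):
149 | 
150 | ```
151 | import numpy as np
152 | 
153 | def average_gradients(grads_per_gpu):
154 |     # grads_per_gpu: one list of per-variable gradient arrays per GPU,
155 |     # aligned by variable. Returns the element-wise mean per variable,
156 |     # which is what gets applied in the single weight update on the CPU.
157 |     return [np.mean(np.stack(g), axis=0) for g in zip(*grads_per_gpu)]
158 | 
159 | # 3 GPUs, each contributing gradients for 2 variables:
160 | grads = [[np.full((2, 2), g), np.full(3, g)] for g in (1.0, 2.0, 3.0)]
161 | print(average_gradients(grads)[0])  # 2x2 array filled with 2.0
162 | ```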
147 |
148 | Similarly, you can speed up testing with:
149 | ```
150 | python -m basic.cli --num_gpus 3 --batch_size 20
151 | ```
152 |
153 | ## Demo
154 | For now, please refer to the `demo` branch of this repository.
155 |
156 |
157 | [multi-gpu]: https://www.tensorflow.org/versions/r0.11/tutorials/deep_cnn/index.html#training-a-model-using-multiple-gpu-cards
158 | [squad]: http://stanford-qa.com
159 | [paper]: https://arxiv.org/abs/1611.01603
160 | [worksheet]: https://worksheets.codalab.org/worksheets/0x37a9b8c44f6845c28866267ef941c89d/
161 | [minjoon]: https://seominjoon.github.io
162 | [minjoon-github]: https://github.com/seominjoon
163 | [dev]: https://github.com/allenai/bi-att-flow/tree/dev
164 |
--------------------------------------------------------------------------------
/basic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/basic/__init__.py
--------------------------------------------------------------------------------
/basic/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 |
5 | from basic.main import main as m
6 |
7 | flags = tf.app.flags
8 |
9 | # Names and directories
10 | flags.DEFINE_string("model_name", "basic", "Model name [basic]")
11 | flags.DEFINE_string("data_dir", "data/squad", "Data dir [data/squad]")
12 | flags.DEFINE_string("run_id", "0", "Run ID [0]")
13 | flags.DEFINE_string("out_base_dir", "out", "out base dir [out]")
14 | flags.DEFINE_string("forward_name", "single", "Forward name [single]")
15 | flags.DEFINE_string("answer_path", "", "Answer path []")
16 | flags.DEFINE_string("eval_path", "", "Eval path []")
17 | flags.DEFINE_string("load_path", "", "Load path []")
18 | flags.DEFINE_string("shared_path", "", "Shared path []")
19 |
20 | # Device placement
21 | flags.DEFINE_string("device", "/cpu:0", "default device for summing gradients. [/cpu:0]")
22 | flags.DEFINE_string("device_type", "gpu", "device for computing gradients (parallelization). cpu | gpu [gpu]")
23 | flags.DEFINE_integer("num_gpus", 1, "num of gpus or cpus for computing gradients [1]")
24 |
25 | # Essential training and test options
26 | flags.DEFINE_string("mode", "test", "trains | test | forward [test]")
27 | flags.DEFINE_boolean("load", True, "load saved data? [True]")
28 | flags.DEFINE_bool("single", False, "supervise only the answer sentence? [False]")
29 | flags.DEFINE_boolean("debug", False, "Debugging mode? [False]")
30 | flags.DEFINE_bool('load_ema', True, "load exponential average of variables when testing? [True]")
31 | flags.DEFINE_bool("eval", True, "eval? [True]")
32 |
33 | # Training / test parameters
34 | flags.DEFINE_integer("batch_size", 60, "Batch size [60]")
35 | flags.DEFINE_integer("val_num_batches", 100, "validation num batches [100]")
36 | flags.DEFINE_integer("test_num_batches", 0, "test num batches [0]")
37 | flags.DEFINE_integer("num_epochs", 12, "Total number of epochs for training [12]")
38 | flags.DEFINE_integer("num_steps", 20000, "Number of steps [20000]")
39 | flags.DEFINE_integer("load_step", 0, "load step [0]")
40 | flags.DEFINE_float("init_lr", 0.5, "Initial learning rate [0.5]")
41 | flags.DEFINE_float("input_keep_prob", 0.8, "Input keep prob for the dropout of LSTM weights [0.8]")
42 | flags.DEFINE_float("keep_prob", 0.8, "Keep prob for the dropout of Char-CNN weights [0.8]")
43 | flags.DEFINE_float("wd", 0.0, "L2 weight decay for regularization [0.0]")
44 | flags.DEFINE_integer("hidden_size", 100, "Hidden size [100]")
45 | flags.DEFINE_integer("char_out_size", 100, "char-level word embedding size [100]")
46 | flags.DEFINE_integer("char_emb_size", 8, "Char emb size [8]")
47 | flags.DEFINE_string("out_channel_dims", "100", "Out channel dims of Char-CNN, separated by commas [100]")
48 | flags.DEFINE_string("filter_heights", "5", "Filter heights of Char-CNN, separated by commas [5]")
49 | flags.DEFINE_bool("finetune", False, "Finetune word embeddings? [False]")
50 | flags.DEFINE_bool("highway", True, "Use highway? [True]")
51 | flags.DEFINE_integer("highway_num_layers", 2, "highway num layers [2]")
52 | flags.DEFINE_bool("share_cnn_weights", True, "Share Char-CNN weights [True]")
53 | flags.DEFINE_bool("share_lstm_weights", True, "Share pre-processing (phrase-level) LSTM weights [True]")
54 | flags.DEFINE_float("var_decay", 0.999, "Exponential moving average decay for variables [0.999]")
55 |
56 | # Optimizations
57 | flags.DEFINE_bool("cluster", False, "Cluster data for faster training [False]")
58 | flags.DEFINE_bool("len_opt", False, "Length optimization? [False]")
59 | flags.DEFINE_bool("cpu_opt", False, "CPU optimization? GPU computation can be slower [False]")
60 |
61 | # Logging and saving options
62 | flags.DEFINE_boolean("progress", True, "Show progress? [True]")
63 | flags.DEFINE_integer("log_period", 100, "Log period [100]")
64 | flags.DEFINE_integer("eval_period", 1000, "Eval period [1000]")
65 | flags.DEFINE_integer("save_period", 1000, "Save Period [1000]")
66 | flags.DEFINE_integer("max_to_keep", 20, "Max recent saves to keep [20]")
67 | flags.DEFINE_bool("dump_eval", True, "dump eval? [True]")
68 | flags.DEFINE_bool("dump_answer", True, "dump answer? [True]")
69 | flags.DEFINE_bool("vis", False, "output visualization numbers? [False]")
70 | flags.DEFINE_bool("dump_pickle", True, "Dump pickle instead of json? [True]")
71 | flags.DEFINE_float("decay", 0.9, "Exponential moving average decay for logging values [0.9]")
72 |
73 | # Thresholds for speed and less memory usage
74 | flags.DEFINE_integer("word_count_th", 10, "word count th [100]")
75 | flags.DEFINE_integer("char_count_th", 50, "char count th [500]")
76 | flags.DEFINE_integer("sent_size_th", 400, "sent size th [64]")
77 | flags.DEFINE_integer("num_sents_th", 8, "num sents th [8]")
78 | flags.DEFINE_integer("ques_size_th", 30, "ques size th [32]")
79 | flags.DEFINE_integer("word_size_th", 16, "word size th [16]")
80 | flags.DEFINE_integer("para_size_th", 256, "para size th [256]")
81 |
82 | # Advanced training options
83 | flags.DEFINE_bool("lower_word", True, "lower word [True]")
84 | flags.DEFINE_bool("squash", False, "squash the sentences into one? [False]")
85 | flags.DEFINE_bool("swap_memory", True, "swap memory? [True]")
86 | flags.DEFINE_string("data_filter", "max", "max | valid | semi [max]")
87 | flags.DEFINE_bool("use_glove_for_unk", True, "use glove for unk [False]")
88 | flags.DEFINE_bool("known_if_glove", True, "consider as known if present in glove [False]")
89 | flags.DEFINE_string("logit_func", "tri_linear", "logit func [tri_linear]")
90 | flags.DEFINE_string("answer_func", "linear", "answer logit func [linear]")
91 | flags.DEFINE_string("sh_logit_func", "tri_linear", "sh logit func [tri_linear]")
92 |
93 | # Ablation options
94 | flags.DEFINE_bool("use_char_emb", True, "use char emb? [True]")
95 | flags.DEFINE_bool("use_word_emb", True, "use word embedding? [True]")
96 | flags.DEFINE_bool("q2c_att", True, "question-to-context attention? [True]")
97 | flags.DEFINE_bool("c2q_att", True, "context-to-question attention? [True]")
98 | flags.DEFINE_bool("dynamic_att", False, "Dynamic attention [False]")
99 |
100 |
101 | def main(_):
102 | config = flags.FLAGS
103 |
104 | config.out_dir = os.path.join(config.out_base_dir, config.model_name, str(config.run_id).zfill(2))
105 |
106 | m(config)
107 |
108 | if __name__ == "__main__":
109 | tf.app.run()
110 |
--------------------------------------------------------------------------------
/basic/ensemble.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import gzip
4 | import json
5 | import pickle
6 | from collections import defaultdict
7 | from operator import mul
8 |
9 | from tqdm import tqdm
10 | from squad.utils import get_phrase, get_best_span
11 |
12 |
13 | def get_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('paths', nargs='+')
16 | parser.add_argument('-o', '--out', default='ensemble.json')
17 | parser.add_argument("--data_path", default="data/squad/data_test.json")
18 | parser.add_argument("--shared_path", default="data/squad/shared_test.json")
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | def ensemble(args):
24 | e_list = []
25 | for path in tqdm(args.paths):
26 | with gzip.open(path, 'r') as fh:
27 | e = pickle.load(fh)
28 | e_list.append(e)
29 |
30 | with open(args.data_path, 'r') as fh:
31 | data = json.load(fh)
32 |
33 | with open(args.shared_path, 'r') as fh:
34 | shared = json.load(fh)
35 |
36 | out = {}
37 | for idx, (id_, rx) in tqdm(enumerate(zip(data['ids'], data['*x'])), total=len(e['yp'])):
38 | if idx >= len(e['yp']):
39 |             # for debugging purposes
40 | break
41 | context = shared['p'][rx[0]][rx[1]]
42 | wordss = shared['x'][rx[0]][rx[1]]
43 | yp_list = [e['yp'][idx] for e in e_list]
44 | yp2_list = [e['yp2'][idx] for e in e_list]
45 | answer = ensemble3(context, wordss, yp_list, yp2_list)
46 | out[id_] = answer
47 |
48 | with open(args.out, 'w') as fh:
49 | json.dump(out, fh)
50 |
51 |
52 | def ensemble1(context, wordss, y1_list, y2_list):
53 |     """Combine the per-model probabilities, then take the best span.
54 | 
55 |     :param context: Original context
56 |     :param wordss: tokenized words (nested 2D list)
57 |     :param y1_list: list of start index probs (each element corresponds to probs from a single model)
58 |     :param y2_list: list of stop index probs
59 |     :return: the phrase spanned by the best (start, stop) span
60 |     """
61 | sum_y1 = combine_y_list(y1_list)
62 | sum_y2 = combine_y_list(y2_list)
63 | span, score = get_best_span(sum_y1, sum_y2)
64 | return get_phrase(context, wordss, span)
65 |
66 |
67 | def ensemble2(context, wordss, y1_list, y2_list):
68 | start_dict = defaultdict(float)
69 | stop_dict = defaultdict(float)
70 | for y1, y2 in zip(y1_list, y2_list):
71 | span, score = get_best_span(y1, y2)
72 | start_dict[span[0]] += y1[span[0][0]][span[0][1]]
73 | stop_dict[span[1]] += y2[span[1][0]][span[1][1]]
74 | start = max(start_dict.items(), key=lambda pair: pair[1])[0]
75 | stop = max(stop_dict.items(), key=lambda pair: pair[1])[0]
76 | best_span = (start, stop)
77 | return get_phrase(context, wordss, best_span)
78 |
79 |
80 | def ensemble3(context, wordss, y1_list, y2_list):
81 | d = defaultdict(float)
82 | for y1, y2 in zip(y1_list, y2_list):
83 | span, score = get_best_span(y1, y2)
84 | phrase = get_phrase(context, wordss, span)
85 | d[phrase] += score
86 | return max(d.items(), key=lambda pair: pair[1])[0]
87 |
88 |
89 | def combine_y_list(y_list, op='*'):
90 | if op == '+':
91 | func = sum
92 | elif op == '*':
93 | def func(l): return functools.reduce(mul, l)
94 | else:
95 | func = op
96 | return [[func(yij_list) for yij_list in zip(*yi_list)] for yi_list in zip(*y_list)]
97 |
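 98 | # Illustrative example of combine_y_list (made-up numbers): combining two
 99 | # models' probability grids with the default op='*' multiplies per position:
100 | #   combine_y_list([[[0.2, 0.8]], [[0.5, 0.5]]]) -> [[0.1, 0.4]]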
98 |
99 | def main():
100 | args = get_args()
101 | ensemble(args)
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
106 |
107 |
--------------------------------------------------------------------------------
/basic/ensemble_fast.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | from collections import Counter, defaultdict
4 | import re
5 |
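 6 | # Usage (inferred from the argv handling below):
 7 | #   python -m basic.ensemble_fast OUT_JSON ANSWER_JSON [ANSWER_JSON ...]
 8 | # Each input json maps question id -> answer string and carries a 'scores'
 9 | # entry (question id -> confidence); answers are merged per question by
10 | # summing the confidences of identical answer strings and taking the argmax.
11 | 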
6 | def key_func(pair):
7 | return pair[1]
8 |
9 |
10 | def get_func(vals, probs):
11 | counter = Counter(vals)
12 | # return max(zip(vals, probs), key=lambda pair: pair[1])[0]
13 | # return max(zip(vals, probs), key=lambda pair: pair[1] * counter[pair[0]] / len(counter) - 999 * (len(pair[0]) == 0) )[0]
14 | # return max(zip(vals, probs), key=lambda pair: pair[1] + 0.7 * counter[pair[0]] / len(counter) - 999 * (len(pair[0]) == 0) )[0]
15 | d = defaultdict(float)
16 | for val, prob in zip(vals, probs):
17 | d[val] += prob
18 | d[''] = 0
19 | return max(d.items(), key=lambda pair: pair[1])[0]
20 |
21 | third_path = sys.argv[1]
22 | other_paths = sys.argv[2:]
23 |
24 | others = [json.load(open(path, 'r')) for path in other_paths]
25 |
26 |
27 | c = {}
28 |
29 | assert min(map(len, others)) == max(map(len, others)), list(map(len, others))
30 |
31 | for key in others[0].keys():
32 | if key == 'scores':
33 | continue
34 | probs = [other['scores'][key] for other in others]
35 | vals = [other[key] for other in others]
36 | largest_val = get_func(vals, probs)
37 | c[key] = largest_val
38 |
39 | json.dump(c, open(third_path, 'w'))
--------------------------------------------------------------------------------
/basic/graph_handler.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import json
3 | from json import encoder
4 | import os
5 |
6 | import tensorflow as tf
7 |
8 | from basic.evaluator import Evaluation, F1Evaluation
9 | from my.utils import short_floats
10 |
11 | import pickle
12 |
13 |
14 | class GraphHandler(object):
15 | def __init__(self, config, model):
16 | self.config = config
17 | self.model = model
18 | self.saver = tf.train.Saver(max_to_keep=config.max_to_keep)
19 | self.writer = None
20 | self.save_path = os.path.join(config.save_dir, config.model_name)
21 |
22 | def initialize(self, sess):
23 | sess.run(tf.initialize_all_variables())
24 | if self.config.load:
25 | self._load(sess)
26 |
27 | if self.config.mode == 'train':
28 | self.writer = tf.train.SummaryWriter(self.config.log_dir, graph=tf.get_default_graph())
29 |
30 | def save(self, sess, global_step=None):
31 |         # Reuse the saver built in __init__ instead of constructing a new
32 |         # tf.train.Saver (and its graph ops) on every save.
33 |         self.saver.save(sess, self.save_path, global_step=global_step)
33 |
34 | def _load(self, sess):
35 | config = self.config
36 | vars_ = {var.name.split(":")[0]: var for var in tf.all_variables()}
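37 |         # With load_ema, remap each trainable variable to its exponential-
38 |         # moving-average name so the saver restores the EMA (shadow) weights
39 |         # instead of the raw ones.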
37 | if config.load_ema:
38 | ema = self.model.var_ema
39 | for var in tf.trainable_variables():
40 | del vars_[var.name.split(":")[0]]
41 | vars_[ema.average_name(var)] = var
42 | saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep)
43 |
44 | if config.load_path:
45 | save_path = config.load_path
46 | elif config.load_step > 0:
47 | save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
48 | else:
49 | save_dir = config.save_dir
50 | checkpoint = tf.train.get_checkpoint_state(save_dir)
51 | assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
52 | save_path = checkpoint.model_checkpoint_path
53 | print("Loading saved model from {}".format(save_path))
54 | saver.restore(sess, save_path)
55 |
56 | def add_summary(self, summary, global_step):
57 | self.writer.add_summary(summary, global_step)
58 |
59 | def add_summaries(self, summaries, global_step):
60 | for summary in summaries:
61 | self.add_summary(summary, global_step)
62 |
63 | def dump_eval(self, e, precision=2, path=None):
64 | assert isinstance(e, Evaluation)
65 | if self.config.dump_pickle:
66 | path = path or os.path.join(self.config.eval_dir, "{}-{}.pklz".format(e.data_type, str(e.global_step).zfill(6)))
67 | with gzip.open(path, 'wb', compresslevel=3) as fh:
68 | pickle.dump(e.dict, fh)
69 | else:
70 | path = path or os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6)))
71 | with open(path, 'w') as fh:
72 | json.dump(short_floats(e.dict, precision), fh)
73 |
74 | def dump_answer(self, e, path=None):
75 | assert isinstance(e, Evaluation)
76 | path = path or os.path.join(self.config.answer_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6)))
77 | with open(path, 'w') as fh:
78 | json.dump(e.id2answer_dict, fh)
79 |
80 |
--------------------------------------------------------------------------------
/basic/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 | import os
5 | import shutil
6 | from pprint import pprint
7 |
8 | import tensorflow as tf
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | from basic.evaluator import ForwardEvaluator, MultiGPUF1Evaluator
13 | from basic.graph_handler import GraphHandler
14 | from basic.model import get_multi_gpu_models
15 | from basic.trainer import MultiGPUTrainer
16 | from basic.read_data import read_data, get_squad_data_filter, update_config
17 |
18 |
19 | def main(config):
20 | set_dirs(config)
21 | with tf.device(config.device):
22 | if config.mode == 'train':
23 | _train(config)
24 | elif config.mode == 'test':
25 | _test(config)
26 | elif config.mode == 'forward':
27 | _forward(config)
28 | else:
29 | raise ValueError("invalid value for 'mode': {}".format(config.mode))
30 |
31 |
32 | def set_dirs(config):
33 | # create directories
34 | assert config.load or config.mode == 'train', "config.load must be True if not training"
35 | if not config.load and os.path.exists(config.out_dir):
36 | shutil.rmtree(config.out_dir)
37 |
38 | config.save_dir = os.path.join(config.out_dir, "save")
39 | config.log_dir = os.path.join(config.out_dir, "log")
40 | config.eval_dir = os.path.join(config.out_dir, "eval")
41 | config.answer_dir = os.path.join(config.out_dir, "answer")
42 | if not os.path.exists(config.out_dir):
43 | os.makedirs(config.out_dir)
44 | if not os.path.exists(config.save_dir):
45 | os.mkdir(config.save_dir)
46 | if not os.path.exists(config.log_dir):
47 | os.mkdir(config.log_dir)
48 | if not os.path.exists(config.answer_dir):
49 | os.mkdir(config.answer_dir)
50 | if not os.path.exists(config.eval_dir):
51 | os.mkdir(config.eval_dir)
52 |
53 |
54 | def _config_debug(config):
55 | if config.debug:
56 | config.num_steps = 2
57 | config.eval_period = 1
58 | config.log_period = 1
59 | config.save_period = 1
60 | config.val_num_batches = 2
61 | config.test_num_batches = 2
62 |
63 |
64 | def _train(config):
65 | data_filter = get_squad_data_filter(config)
66 | train_data = read_data(config, 'train', config.load, data_filter=data_filter)
67 | dev_data = read_data(config, 'dev', True, data_filter=data_filter)
68 | update_config(config, [train_data, dev_data])
69 |
70 | _config_debug(config)
71 |
72 | word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec']
73 | word2idx_dict = train_data.shared['word2idx']
74 | idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict}
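75 |     # Initial word-embedding matrix: use the GloVe vector when the word has
76 |     # one; otherwise draw a random vector from N(0, I) of the embedding size.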
75 | emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict
76 | else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size))
77 | for idx in range(config.word_vocab_size)])
78 | config.emb_mat = emb_mat
79 |
80 | # construct model graph and variables (using default graph)
81 | pprint(config.__flags, indent=2)
82 | models = get_multi_gpu_models(config)
83 | model = models[0]
84 | trainer = MultiGPUTrainer(config, models)
85 | evaluator = MultiGPUF1Evaluator(config, models, tensor_dict=model.tensor_dict if config.vis else None)
86 | graph_handler = GraphHandler(config, model) # controls all tensors and variables in the graph, including loading /saving
87 |
88 | # Variables
89 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
90 | graph_handler.initialize(sess)
91 |
92 | # Begin training
93 | num_steps = config.num_steps or int(math.ceil(train_data.num_examples / (config.batch_size * config.num_gpus))) * config.num_epochs
94 | global_step = 0
95 | for batches in tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus,
96 | num_steps=num_steps, shuffle=True, cluster=config.cluster), total=num_steps):
97 | global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after step
98 | get_summary = global_step % config.log_period == 0
99 | loss, summary, train_op = trainer.step(sess, batches, get_summary=get_summary)
100 | if get_summary:
101 | graph_handler.add_summary(summary, global_step)
102 |
103 | # occasional saving
104 | if global_step % config.save_period == 0:
105 | graph_handler.save(sess, global_step=global_step)
106 |
107 | if not config.eval:
108 | continue
109 | # Occasional evaluation
110 | if global_step % config.eval_period == 0:
111 | num_steps = math.ceil(dev_data.num_examples / (config.batch_size * config.num_gpus))
112 | if 0 < config.val_num_batches < num_steps:
113 | num_steps = config.val_num_batches
114 | e_train = evaluator.get_evaluation_from_batches(
115 | sess, tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps)
116 | )
117 | graph_handler.add_summaries(e_train.summaries, global_step)
118 | e_dev = evaluator.get_evaluation_from_batches(
119 | sess, tqdm(dev_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps))
120 | graph_handler.add_summaries(e_dev.summaries, global_step)
121 |
122 | if config.dump_eval:
123 | graph_handler.dump_eval(e_dev)
124 | if config.dump_answer:
125 | graph_handler.dump_answer(e_dev)
126 | if global_step % config.save_period != 0:
127 | graph_handler.save(sess, global_step=global_step)
128 |
129 |
130 | def _test(config):
131 | test_data = read_data(config, 'test', True)
132 | update_config(config, [test_data])
133 |
134 | _config_debug(config)
135 |
136 | if config.use_glove_for_unk:
137 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec']
138 | new_word2idx_dict = test_data.shared['new_word2idx']
139 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
140 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
141 | config.new_emb_mat = new_emb_mat
142 |
143 | pprint(config.__flags, indent=2)
144 | models = get_multi_gpu_models(config)
145 | model = models[0]
146 | evaluator = MultiGPUF1Evaluator(config, models, tensor_dict=models[0].tensor_dict if config.vis else None)
147 | graph_handler = GraphHandler(config, model)
148 |
149 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
150 | graph_handler.initialize(sess)
151 | num_steps = math.ceil(test_data.num_examples / (config.batch_size * config.num_gpus))
152 | if 0 < config.test_num_batches < num_steps:
153 | num_steps = config.test_num_batches
154 |
155 | e = None
156 | for multi_batch in tqdm(test_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps, cluster=config.cluster), total=num_steps):
157 | ei = evaluator.get_evaluation(sess, multi_batch)
158 | e = ei if e is None else e + ei
159 | if config.vis:
160 | eval_subdir = os.path.join(config.eval_dir, "{}-{}".format(ei.data_type, str(ei.global_step).zfill(6)))
161 | if not os.path.exists(eval_subdir):
162 | os.mkdir(eval_subdir)
163 | path = os.path.join(eval_subdir, str(ei.idxs[0]).zfill(8))
164 | graph_handler.dump_eval(ei, path=path)
165 |
166 | print(e)
167 | if config.dump_answer:
168 | print("dumping answer ...")
169 | graph_handler.dump_answer(e)
170 | if config.dump_eval:
171 | print("dumping eval ...")
172 | graph_handler.dump_eval(e)
173 |
174 |
175 | def _forward(config):
176 | assert config.load
177 | test_data = read_data(config, config.forward_name, True)
178 | update_config(config, [test_data])
179 |
180 | _config_debug(config)
181 |
182 | if config.use_glove_for_unk:
183 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec']
184 | new_word2idx_dict = test_data.shared['new_word2idx']
185 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
186 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
187 | config.new_emb_mat = new_emb_mat
188 |
189 | pprint(config.__flags, indent=2)
190 | models = get_multi_gpu_models(config)
191 | model = models[0]
192 | evaluator = ForwardEvaluator(config, model)
193 | graph_handler = GraphHandler(config, model) # controls all tensors and variables in the graph, including loading /saving
194 |
195 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
196 | graph_handler.initialize(sess)
197 |
198 | num_batches = math.ceil(test_data.num_examples / config.batch_size)
199 | if 0 < config.test_num_batches < num_batches:
200 | num_batches = config.test_num_batches
201 | e = evaluator.get_evaluation_from_batches(sess, tqdm(test_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
202 | print(e)
203 | if config.dump_answer:
204 | print("dumping answer ...")
205 | graph_handler.dump_answer(e, path=config.answer_path)
206 | if config.dump_eval:
207 | print("dumping eval ...")
208 | graph_handler.dump_eval(e, path=config.eval_path)
209 |
210 |
211 | def _get_args():
212 | parser = argparse.ArgumentParser()
213 | parser.add_argument("config_path")
214 | return parser.parse_args()
215 |
216 |
217 | class Config(object):
218 | def __init__(self, **entries):
219 | self.__dict__.update(entries)
220 |
221 |
222 | def _run():
223 | args = _get_args()
224 | with open(args.config_path, 'r') as fh:
225 | config = Config(**json.load(fh))
226 | main(config)
227 |
228 |
229 | if __name__ == "__main__":
230 | _run()
231 |
--------------------------------------------------------------------------------
/basic/run_ensemble.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | source_path=$1
3 | target_path=$2
4 | inter_dir="inter_ensemble"
5 | root_dir="save"
6 |
7 | parg=""
8 | marg=""
9 | if [ "$3" = "debug" ]
10 | then
11 | parg="-d"
12 | marg="--debug"
13 | fi
14 |
15 | # Preprocess data
16 | python3 -m squad.prepro --mode single --single_path $source_path $parg --target_dir $inter_dir --glove_dir .
17 |
18 | eargs=""
19 | for num in 31 33 34 35 36 37 40 41 43 44 45 46; do
20 | load_path="$root_dir/$num/save"
21 | shared_path="$root_dir/$num/shared.json"
22 | eval_path="$inter_dir/eval-$num.json"
23 | eargs="$eargs $eval_path"
24 | python3 -m basic.cli --data_dir $inter_dir --eval_path $eval_path --nodump_answer --load_path $load_path --shared_path $shared_path $marg --eval_num_batches 0 --mode forward --batch_size 1 --len_opt --cluster --cpu_opt --load_ema &
25 | done
26 | wait
27 |
28 | # Ensemble
29 | python3 -m basic.ensemble --data_path $inter_dir/data_single.json --shared_path $inter_dir/shared_single.json -o $target_path $eargs
30 |
--------------------------------------------------------------------------------
/basic/run_single.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | source_path=$1
3 | target_path=$2
4 | inter_dir="inter_single"
5 | root_dir="save"
6 |
7 | parg=""
8 | marg=""
9 | if [ "$3" = "debug" ]
10 | then
11 | parg="-d"
12 | marg="--debug"
13 | fi
14 |
15 | # Preprocess data
16 | python3 -m squad.prepro --mode single --single_path $source_path $parg --target_dir $inter_dir --glove_dir .
17 |
18 | num=37
19 | load_path="$root_dir/$num/save"
20 | shared_path="$root_dir/$num/shared.json"
21 | eval_path="$inter_dir/eval.json"
22 | python3 -m basic.cli --data_dir $inter_dir --eval_path $eval_path --nodump_answer --load_path $load_path --shared_path $shared_path $marg --eval_num_batches 0 --mode forward --batch_size 1 --len_opt --cluster --cpu_opt --load_ema
23 |
24 | # Ensemble (for single run, just one input)
25 | python3 -m basic.ensemble --data_path $inter_dir/data_single.json --shared_path $inter_dir/shared_single.json -o $target_path $eval_path
26 |
27 |
28 |
--------------------------------------------------------------------------------
/basic/templates/visualizer.html:
--------------------------------------------------------------------------------
 1 | <!DOCTYPE html>
 2 | <html>
 3 | <head>
 4 |     <title>{{ title }}</title>
 5 |     <style>
 6 |         table, th, td { border: 1px solid black; border-collapse: collapse; }
 7 |         th, td { padding: 4px; }
 8 |     </style>
 9 | </head>
10 | <body>
11 | <h2>{{ title }}</h2>
12 | <table>
13 |     <tr>
14 |         <th>ID</th>
15 |         <th>Question</th>
16 |         <th>Answers</th>
17 |         <th>Predicted</th>
18 |         <th>Score</th>
19 |         <th>Paragraph</th>
20 |     </tr>
21 |     {% for row in rows %}
22 |     <tr>
23 |         <td>{{ row.id }}</td>
24 |         <td>
25 |             {% for qj in row.ques %}
26 |             {{ qj }}
27 |             {% endfor %}
28 |         </td>
29 |         <td>
30 |             {% for aa in row.a %}
31 |             {{ aa }}
32 |             {% endfor %}
33 |         </td>
34 |         <td>{{ row.ap }}</td>
35 |         <td>{{ row.score }}</td>
36 |         <td>
37 |             <table>
38 |                 {% for xj, ypj, yp2j in zip(row.para, row.yp, row.yp2) %}
39 |                 <!-- start-probability row: cell shading follows ypjk -->
40 |                 <tr>
41 |                     {% set rowloop = loop %}
42 |                     {% for xjk, ypjk in zip(xj, ypj) %}
43 |                     <td style="background-color: rgba(0, 0, 255, {{ ypjk }})">
44 |                         {% if row.y[0][0] == rowloop.index0 and row.y[0][1] <= loop.index0 <= row.y[1][1] %}
45 |                         <b>{{ xjk }}</b>
46 |                         {% else %}
47 |                         {{ xjk }}
48 |                         {% endif %}
49 |                     </td>
50 |                     {% endfor %}
51 |                 </tr>
52 |                 <!-- stop-probability row: cell shading follows yp2jk -->
53 |                 <tr>
54 |                     {% for xjk, yp2jk in zip(xj, yp2j) %}
55 |                     <td style="background-color: rgba(0, 0, 255, {{ yp2jk }})">-</td>
56 |                     {% endfor %}
57 |                 </tr>
58 |                 {% endfor %}
59 |             </table>
60 |         </td>
61 |     </tr>
62 |     {% endfor %}
63 | </table>
64 | </body>
65 | </html>
--------------------------------------------------------------------------------
/basic/trainer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from basic.model import Model
4 | from my.tensorflow import average_gradients
5 |
6 |
7 | class Trainer(object):
8 | def __init__(self, config, model):
9 | assert isinstance(model, Model)
10 | self.config = config
11 | self.model = model
12 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr)
13 | self.loss = model.get_loss()
14 | self.var_list = model.get_var_list()
15 | self.global_step = model.get_global_step()
16 | self.summary = model.summary
17 | self.grads = self.opt.compute_gradients(self.loss, var_list=self.var_list)
18 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step)
19 |
20 | def get_train_op(self):
21 | return self.train_op
22 |
23 | def step(self, sess, batch, get_summary=False):
24 | assert isinstance(sess, tf.Session)
25 | _, ds = batch
26 | feed_dict = self.model.get_feed_dict(ds, True)
27 | if get_summary:
28 | loss, summary, train_op = \
29 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict)
30 | else:
31 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict)
32 | summary = None
33 | return loss, summary, train_op
34 |
35 |
36 | class MultiGPUTrainer(object):
37 | def __init__(self, config, models):
38 | model = models[0]
39 | assert isinstance(model, Model)
40 | self.config = config
41 | self.model = model
42 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr)
43 | self.var_list = model.get_var_list()
44 | self.global_step = model.get_global_step()
45 | self.summary = model.summary
46 | self.models = models
47 | losses = []
48 | grads_list = []
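49 |         # One tower per GPU: each model replica computes its loss and
50 |         # gradients for its own sub-batch on its own device; the per-variable
51 |         # gradients are averaged below into a single update.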
49 | for gpu_idx, model in enumerate(models):
50 | with tf.name_scope("grads_{}".format(gpu_idx)), tf.device("/{}:{}".format(config.device_type, gpu_idx)):
51 | loss = model.get_loss()
52 | grads = self.opt.compute_gradients(loss, var_list=self.var_list)
53 | losses.append(loss)
54 | grads_list.append(grads)
55 |
56 | self.loss = tf.add_n(losses)/len(losses)
57 | self.grads = average_gradients(grads_list)
58 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step)
59 |
60 | def step(self, sess, batches, get_summary=False):
61 | assert isinstance(sess, tf.Session)
62 | feed_dict = {}
63 | for batch, model in zip(batches, self.models):
64 | _, ds = batch
65 | feed_dict.update(model.get_feed_dict(ds, True))
66 |
67 | if get_summary:
68 | loss, summary, train_op = \
69 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict)
70 | else:
71 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict)
72 | summary = None
73 | return loss, summary, train_op
74 |
--------------------------------------------------------------------------------
/basic/visualizer.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from collections import OrderedDict
3 | import http.server
4 | import socketserver
5 | import argparse
6 | import json
7 | import os
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from jinja2 import Environment, FileSystemLoader
12 |
13 | from squad.utils import get_best_span, get_span_score_pairs
15 |
16 |
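17 | # Renders dumped eval json files as static HTML pages (via the jinja2
18 | # template in basic/templates) and serves them over HTTP, e.g.:
19 | #   python -m basic.visualizer --data_type dev --step 5000 --run_id 0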
17 | def bool_(string):
18 | if string == 'True':
19 | return True
20 | elif string == 'False':
21 | return False
22 |     else:
23 |         raise ValueError("expected 'True' or 'False', got {!r}".format(string))
24 |
25 | def get_args():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument("--model_name", type=str, default='basic')
28 | parser.add_argument("--data_type", type=str, default='dev')
29 | parser.add_argument("--step", type=int, default=5000)
30 | parser.add_argument("--template_name", type=str, default="visualizer.html")
31 | parser.add_argument("--num_per_page", type=int, default=100)
32 | parser.add_argument("--data_dir", type=str, default="data/squad")
33 | parser.add_argument("--port", type=int, default=8000)
34 | parser.add_argument("--host", type=str, default="0.0.0.0")
35 | parser.add_argument("--open", type=str, default='False')
36 | parser.add_argument("--run_id", type=str, default="0")
37 |
38 | args = parser.parse_args()
39 | return args
40 |
41 |
42 | def _decode(decoder, sent):
43 | return " ".join(decoder[idx] for idx in sent)
44 |
45 |
46 | def accuracy2_visualizer(args):
47 | model_name = args.model_name
48 | data_type = args.data_type
49 | num_per_page = args.num_per_page
50 | data_dir = args.data_dir
51 | run_id = args.run_id.zfill(2)
52 | step = args.step
53 |
54 |     eval_path = os.path.join("out", model_name, run_id, "eval", "{}-{}.json".format(data_type, str(step).zfill(6)))
55 | print("loading {}".format(eval_path))
56 | eval_ = json.load(open(eval_path, 'r'))
57 |
58 | _id = 0
59 | html_dir = "/tmp/list_results%d" % _id
60 | while os.path.exists(html_dir):
61 | _id += 1
62 | html_dir = "/tmp/list_results%d" % _id
63 |
64 | if os.path.exists(html_dir):
65 | shutil.rmtree(html_dir)
66 | os.mkdir(html_dir)
67 |
68 | cur_dir = os.path.dirname(os.path.realpath(__file__))
69 | templates_dir = os.path.join(cur_dir, 'templates')
70 | env = Environment(loader=FileSystemLoader(templates_dir))
71 | env.globals.update(zip=zip, reversed=reversed)
72 | template = env.get_template(args.template_name)
73 |
74 | data_path = os.path.join(data_dir, "data_{}.json".format(data_type))
75 | shared_path = os.path.join(data_dir, "shared_{}.json".format(data_type))
76 | print("loading {}".format(data_path))
77 | data = json.load(open(data_path, 'r'))
78 | print("loading {}".format(shared_path))
79 | shared = json.load(open(shared_path, 'r'))
80 |
81 | rows = []
82 | for i, (idx, yi, ypi, yp2i) in tqdm(enumerate(zip(*[eval_[key] for key in ('idxs', 'y', 'yp', 'yp2')])), total=len(eval_['idxs'])):
83 | id_, q, rx, answers = (data[key][idx] for key in ('ids', 'q', '*x', 'answerss'))
84 | x = shared['x'][rx[0]][rx[1]]
85 | ques = [" ".join(q)]
86 | para = [[word for word in sent] for sent in x]
87 | span = get_best_span(ypi, yp2i)
88 | ap = get_segment(para, span)
89 | score = "{:.3f}".format(ypi[span[0][0]][span[0][1]] * yp2i[span[1][0]][span[1][1]-1])
90 |
91 | row = {
92 | 'id': id_,
93 | 'title': "Hello world!",
94 | 'ques': ques,
95 | 'para': para,
96 | 'y': yi[0][0],
97 | 'y2': yi[0][1],
98 | 'yp': ypi,
99 | 'yp2': yp2i,
100 | 'a': answers,
101 | 'ap': ap,
102 | 'score': score
103 | }
104 | rows.append(row)
105 |
106 | if i % num_per_page == 0:
107 | html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8))
108 |
109 | if (i + 1) % num_per_page == 0 or (i + 1) == len(eval_['y']):
110 | var_dict = {'title': "Accuracy Visualization",
111 | 'rows': rows
112 | }
113 | with open(html_path, "wb") as f:
114 | f.write(template.render(**var_dict).encode('UTF-8'))
115 | rows = []
116 |
117 | os.chdir(html_dir)
118 | port = args.port
119 | host = args.host
120 | # Overriding to suppress log message
121 | class MyHandler(http.server.SimpleHTTPRequestHandler):
122 | def log_message(self, format, *args):
123 | pass
124 | handler = MyHandler
125 | httpd = socketserver.TCPServer((host, port), handler)
126 | if args.open == 'True':
127 | os.system("open http://%s:%d" % (args.host, args.port))
128 | print("serving at %s:%d" % (host, port))
129 | httpd.serve_forever()
130 |
131 |
132 | def get_segment(para, span):
133 | return " ".join(para[span[0][0]][span[0][1]:span[1][1]])
134 |
135 |
136 | if __name__ == "__main__":
137 | ARGS = get_args()
138 | accuracy2_visualizer(ARGS)
--------------------------------------------------------------------------------
/basic_cnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/basic_cnn/__init__.py
--------------------------------------------------------------------------------
/basic_cnn/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 |
5 | from basic_cnn.main import main as m
6 |
7 | flags = tf.app.flags
8 |
9 | flags.DEFINE_string("model_name", "basic_cnn", "Model name [basic]")
10 | flags.DEFINE_string("data_dir", "data/cnn", "Data dir [data/cnn]")
11 | flags.DEFINE_string("root_dir", "/Users/minjoons/data/cnn/questions", "root dir [~/data/cnn/questions]")
12 | flags.DEFINE_string("run_id", "0", "Run ID [0]")
13 | flags.DEFINE_string("out_base_dir", "out", "out base dir [out]")
14 |
15 | flags.DEFINE_integer("batch_size", 60, "Batch size [60]")
16 | flags.DEFINE_float("init_lr", 0.5, "Initial learning rate [0.5]")
17 | flags.DEFINE_integer("num_epochs", 50, "Total number of epochs for training [50]")
18 | flags.DEFINE_integer("num_steps", 20000, "Number of steps [20000]")
19 | flags.DEFINE_integer("eval_num_batches", 100, "eval num batches [100]")
20 | flags.DEFINE_integer("load_step", 0, "load step [0]")
21 | flags.DEFINE_integer("early_stop", 4, "early stop [4]")
22 |
23 | flags.DEFINE_string("mode", "test", "train | dev | test | forward [test]")
24 | flags.DEFINE_boolean("load", True, "load saved data? [True]")
25 | flags.DEFINE_boolean("progress", True, "Show progress? [True]")
26 | flags.DEFINE_integer("log_period", 100, "Log period [100]")
27 | flags.DEFINE_integer("eval_period", 1000, "Eval period [1000]")
28 | flags.DEFINE_integer("save_period", 1000, "Save Period [1000]")
29 | flags.DEFINE_float("decay", 0.9, "Exponential moving average decay [0.9]")
30 |
31 | flags.DEFINE_boolean("draft", False, "Draft for quick testing? [False]")
32 |
33 | flags.DEFINE_integer("hidden_size", 100, "Hidden size [100]")
34 | flags.DEFINE_integer("char_out_size", 100, "Char out size [100]")
35 | flags.DEFINE_float("input_keep_prob", 0.8, "Input keep prob [0.8]")
36 | flags.DEFINE_integer("char_emb_size", 8, "Char emb size [8]")
37 | flags.DEFINE_integer("char_filter_height", 5, "Char filter height [5]")
38 | flags.DEFINE_float("wd", 0.0, "Weight decay [0.0]")
39 | flags.DEFINE_bool("lower_word", True, "lower word [True]")
40 | flags.DEFINE_bool("dump_eval", False, "dump eval? [True]")
41 | flags.DEFINE_bool("dump_answer", True, "dump answer? [True]")
42 | flags.DEFINE_string("model", "2", "config 1 |2 [2]")
43 | flags.DEFINE_bool("squash", False, "squash the sentences into one? [False]")
44 | flags.DEFINE_bool("single", False, "supervise only the answer sentence? [False]")
45 |
46 | flags.DEFINE_integer("word_count_th", 10, "word count th [100]")
47 | flags.DEFINE_integer("char_count_th", 50, "char count th [500]")
48 | flags.DEFINE_integer("sent_size_th", 60, "sent size th [64]")
49 | flags.DEFINE_integer("num_sents_th", 200, "num sents th [8]")
50 | flags.DEFINE_integer("ques_size_th", 30, "ques size th [32]")
51 | flags.DEFINE_integer("word_size_th", 16, "word size th [16]")
52 | flags.DEFINE_integer("para_size_th", 256, "para size th [256]")
53 |
54 | flags.DEFINE_bool("swap_memory", True, "swap memory? [True]")
55 | flags.DEFINE_string("data_filter", "max", "max | valid | semi [max]")
56 | flags.DEFINE_bool("finetune", False, "finetune? [False]")
57 | flags.DEFINE_bool("feed_gt", False, "feed gt prev token during training [False]")
58 | flags.DEFINE_bool("feed_hard", False, "feed hard argmax prev token during testing [False]")
59 | flags.DEFINE_bool("use_glove_for_unk", True, "use glove for unk [False]")
60 | flags.DEFINE_bool("known_if_glove", True, "consider as known if present in glove [False]")
61 | flags.DEFINE_bool("eval", True, "eval? [True]")
62 | flags.DEFINE_integer("highway_num_layers", 2, "highway num layers [2]")
63 | flags.DEFINE_bool("use_word_emb", True, "use word embedding? [True]")
64 |
65 | flags.DEFINE_string("forward_name", "single", "Forward name [single]")
66 | flags.DEFINE_string("answer_path", "", "Answer path []")
67 | flags.DEFINE_string("load_path", "", "Load path []")
68 | flags.DEFINE_string("shared_path", "", "Shared path []")
69 | flags.DEFINE_string("device", "/cpu:0", "default device [/cpu:0]")
70 | flags.DEFINE_integer("num_gpus", 1, "num of gpus [1]")
71 |
72 | flags.DEFINE_string("out_channel_dims", "100", "Out channel dims, separated by commas [100]")
73 | flags.DEFINE_string("filter_heights", "5", "Filter heights, separated by commas [5]")
74 |
75 | flags.DEFINE_bool("share_cnn_weights", True, "Share CNN weights [False]")
76 | flags.DEFINE_bool("share_lstm_weights", True, "Share LSTM weights [True]")
77 | flags.DEFINE_bool("two_prepro_layers", False, "Use two layers for preprocessing? [False]")
78 | flags.DEFINE_bool("aug_att", False, "Augment attention layers with more features? [False]")
79 | flags.DEFINE_integer("max_to_keep", 20, "Max recent saves to keep [20]")
80 | flags.DEFINE_bool("vis", False, "output visualization numbers? [False]")
81 | flags.DEFINE_bool("dump_pickle", True, "Dump pickle instead of json? [True]")
82 | flags.DEFINE_float("keep_prob", 1.0, "keep prob [1.0]")
83 | flags.DEFINE_string("prev_mode", "a", "prev mode gy | y | a [a]")
84 | flags.DEFINE_string("logit_func", "tri_linear", "logit func [tri_linear]")
85 | flags.DEFINE_bool("sh", False, "use superhighway [False]")
86 | flags.DEFINE_string("answer_func", "linear", "answer logit func [linear]")
87 | flags.DEFINE_bool("cluster", False, "Cluster data for faster training [False]")
88 | flags.DEFINE_bool("len_opt", False, "Length optimization? [False]")
89 | flags.DEFINE_string("sh_logit_func", "tri_linear", "sh logit func [tri_linear]")
90 | flags.DEFINE_float("filter_ratio", 1.0, "filter ratio [1.0]")
91 | flags.DEFINE_bool("bi", False, "bi-directional attention? [False]")
92 | flags.DEFINE_integer("width", 5, "width around entity [5]")
93 |
94 |
95 | def main(_):
96 | config = flags.FLAGS
97 |
98 | config.out_dir = os.path.join(config.out_base_dir, config.model_name, str(config.run_id).zfill(2))
99 |
100 | m(config)
101 |
102 | if __name__ == "__main__":
103 | tf.app.run()
104 |
--------------------------------------------------------------------------------
/basic_cnn/graph_handler.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import json
3 | from json import encoder
4 | import os
5 |
6 | import tensorflow as tf
7 |
8 | from basic_cnn.evaluator import Evaluation, F1Evaluation
9 | from my.utils import short_floats
10 |
11 | import pickle
12 |
13 |
14 | class GraphHandler(object):
15 | def __init__(self, config):
16 | self.config = config
17 | self.saver = tf.train.Saver(max_to_keep=config.max_to_keep)
18 | self.writer = None
19 | self.save_path = os.path.join(config.save_dir, config.model_name)
20 |
21 | def initialize(self, sess):
22 | if self.config.load:
23 | self._load(sess)
24 | else:
25 | sess.run(tf.initialize_all_variables())
26 |
27 | if self.config.mode == 'train':
28 | self.writer = tf.train.SummaryWriter(self.config.log_dir, graph=tf.get_default_graph())
29 |
30 | def save(self, sess, global_step=None):
31 | self.saver.save(sess, self.save_path, global_step=global_step)
32 |
33 | def _load(self, sess):
34 | config = self.config
35 | if config.load_path:
36 | save_path = config.load_path
37 | elif config.load_step > 0:
38 | save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
39 | else:
40 | save_dir = config.save_dir
41 | checkpoint = tf.train.get_checkpoint_state(save_dir)
42 | assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
43 | save_path = checkpoint.model_checkpoint_path
44 | print("Loading saved model from {}".format(save_path))
45 | self.saver.restore(sess, save_path)
46 |
47 | def add_summary(self, summary, global_step):
48 | self.writer.add_summary(summary, global_step)
49 |
50 | def add_summaries(self, summaries, global_step):
51 | for summary in summaries:
52 | self.add_summary(summary, global_step)
53 |
54 | def dump_eval(self, e, precision=2, path=None):
55 | assert isinstance(e, Evaluation)
56 | if self.config.dump_pickle:
57 | path = path or os.path.join(self.config.eval_dir, "{}-{}.pklz".format(e.data_type, str(e.global_step).zfill(6)))
58 | with gzip.open(path, 'wb', compresslevel=3) as fh:
59 | pickle.dump(e.dict, fh)
60 | else:
61 | path = path or os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6)))
62 | with open(path, 'w') as fh:
63 | json.dump(short_floats(e.dict, precision), fh)
64 |
65 | def dump_answer(self, e, path=None):
66 | assert isinstance(e, Evaluation)
67 | path = path or os.path.join(self.config.answer_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6)))
68 | with open(path, 'w') as fh:
69 | json.dump(e.id2answer_dict, fh)
70 |
71 |
--------------------------------------------------------------------------------
/basic_cnn/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 | import os
5 | import shutil
6 | from pprint import pprint
7 |
8 | import tensorflow as tf
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | from basic_cnn.evaluator import F1Evaluator, Evaluator, ForwardEvaluator, MultiGPUF1Evaluator, CNNAccuracyEvaluator, \
13 | MultiGPUCNNAccuracyEvaluator
14 | from basic_cnn.graph_handler import GraphHandler
15 | from basic_cnn.model import Model, get_multi_gpu_models
16 | from basic_cnn.trainer import Trainer, MultiGPUTrainer
17 |
18 | from basic_cnn.read_data import read_data, get_cnn_data_filter, update_config
19 |
20 |
21 | def main(config):
22 | set_dirs(config)
23 | with tf.device(config.device):
24 | if config.mode == 'train':
25 | _train(config)
26 | elif config.mode == 'test' or config.mode == 'dev':
27 | _test(config)
28 | elif config.mode == 'forward':
29 | _forward(config)
30 | else:
31 | raise ValueError("invalid value for 'mode': {}".format(config.mode))
32 |
33 |
34 | def _config_draft(config):
35 | if config.draft:
36 | config.num_steps = 2
37 | config.eval_period = 1
38 | config.log_period = 1
39 | config.save_period = 1
40 | config.eval_num_batches = 1
41 |
42 |
43 | def _train(config):
44 | # load_metadata(config, 'train') # this updates the config file according to metadata file
45 |
46 | data_filter = get_cnn_data_filter(config)
47 | train_data = read_data(config, 'train', config.load, data_filter=data_filter)
48 | dev_data = read_data(config, 'dev', True, data_filter=data_filter)
49 | # test_data = read_data(config, 'test', True, data_filter=data_filter)
50 | update_config(config, [train_data, dev_data])
51 |
52 | _config_draft(config)
53 |
54 | word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec']
55 | word2idx_dict = train_data.shared['word2idx']
56 | idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict}
57 | print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict)))
58 | emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict
59 | else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size))
60 | for idx in range(config.word_vocab_size)])
61 | config.emb_mat = emb_mat
62 |
63 | # construct model graph and variables (using default graph)
64 | pprint(config.__flags, indent=2)
65 | # model = Model(config)
66 | models = get_multi_gpu_models(config)
67 | model = models[0]
68 | trainer = MultiGPUTrainer(config, models)
69 | evaluator = MultiGPUCNNAccuracyEvaluator(config, models, tensor_dict=model.tensor_dict if config.vis else None)
70 | graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading /saving
71 |
72 | # Variables
73 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
74 | graph_handler.initialize(sess)
75 |
76 | # begin training
77 | print(train_data.num_examples)
78 | num_steps = config.num_steps or int(math.ceil(train_data.num_examples / (config.batch_size * config.num_gpus))) * config.num_epochs
79 | global_step = 0
80 | for batches in tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus,
81 | num_steps=num_steps, shuffle=True, cluster=config.cluster), total=num_steps):
82 | global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after step
83 | get_summary = global_step % config.log_period == 0
84 | loss, summary, train_op = trainer.step(sess, batches, get_summary=get_summary)
85 | if get_summary:
86 | graph_handler.add_summary(summary, global_step)
87 |
88 | # occasional saving
89 | if global_step % config.save_period == 0:
90 | graph_handler.save(sess, global_step=global_step)
91 |
92 | if not config.eval:
93 | continue
94 | # Occasional evaluation
95 | if global_step % config.eval_period == 0:
96 | eval_num_steps = math.ceil(dev_data.num_examples / (config.batch_size * config.num_gpus))  # separate name so the training-loop num_steps is not clobbered
97 | if 0 < config.eval_num_batches < eval_num_steps:
98 | eval_num_steps = config.eval_num_batches
99 | e_train = evaluator.get_evaluation_from_batches(
100 | sess, tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=eval_num_steps), total=eval_num_steps)
101 | )
102 | graph_handler.add_summaries(e_train.summaries, global_step)
103 | e_dev = evaluator.get_evaluation_from_batches(
104 | sess, tqdm(dev_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=eval_num_steps), total=eval_num_steps))
105 | graph_handler.add_summaries(e_dev.summaries, global_step)
106 |
107 | if config.dump_eval:
108 | graph_handler.dump_eval(e_dev)
109 | if config.dump_answer:
110 | graph_handler.dump_answer(e_dev)
111 | if global_step % config.save_period != 0:
112 | graph_handler.save(sess, global_step=global_step)
113 |
114 |
115 | def _test(config):
116 | assert config.load
117 | test_data = read_data(config, config.mode, True)
118 | update_config(config, [test_data])
119 |
120 | _config_draft(config)
121 |
122 | if config.use_glove_for_unk:
123 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec']
124 | new_word2idx_dict = test_data.shared['new_word2idx']
125 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
126 | # print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict)))
127 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
128 | config.new_emb_mat = new_emb_mat
129 |
130 | pprint(config.__flags, indent=2)
131 | models = get_multi_gpu_models(config)
132 | evaluator = MultiGPUCNNAccuracyEvaluator(config, models, tensor_dict=models[0].tensor_dict if config.vis else None)
133 | graph_handler = GraphHandler(config)  # controls all tensors and variables in the graph, including loading/saving
134 |
135 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
136 | graph_handler.initialize(sess)
137 | num_steps = math.ceil(test_data.num_examples / (config.batch_size * config.num_gpus))
138 | if 0 < config.eval_num_batches < num_steps:
139 | num_steps = config.eval_num_batches
140 |
141 | e = None
142 | for multi_batch in tqdm(test_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps, cluster=config.cluster), total=num_steps):
143 | ei = evaluator.get_evaluation(sess, multi_batch)
144 | e = ei if e is None else e + ei
145 | if config.vis:
146 | eval_subdir = os.path.join(config.eval_dir, "{}-{}".format(ei.data_type, str(ei.global_step).zfill(6)))
147 | if not os.path.exists(eval_subdir):
148 | os.mkdir(eval_subdir)
149 | path = os.path.join(eval_subdir, str(ei.idxs[0]).zfill(8))
150 | graph_handler.dump_eval(ei, path=path)
151 |
152 | print(e)
153 | if config.dump_answer:
154 | print("dumping answer ...")
155 | graph_handler.dump_answer(e)
156 | if config.dump_eval:
157 | print("dumping eval ...")
158 | graph_handler.dump_eval(e)
159 |
160 |
161 | def _forward(config):
162 | assert config.load
163 | test_data = read_data(config, config.forward_name, True)
164 | update_config(config, [test_data])
165 |
166 | _config_draft(config)
167 |
168 | if config.use_glove_for_unk:
169 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec']
170 | new_word2idx_dict = test_data.shared['new_word2idx']
171 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()}
172 | # print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict)))
173 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32')
174 | config.new_emb_mat = new_emb_mat
175 |
176 | pprint(config.__flags, indent=2)
177 | models = get_multi_gpu_models(config)
178 | model = models[0]
179 | evaluator = ForwardEvaluator(config, model)
180 | graph_handler = GraphHandler(config)  # controls all tensors and variables in the graph, including loading/saving
181 |
182 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
183 | graph_handler.initialize(sess)
184 |
185 | num_batches = math.ceil(test_data.num_examples / config.batch_size)
186 | if 0 < config.eval_num_batches < num_batches:
187 | num_batches = config.eval_num_batches
188 | e = evaluator.get_evaluation_from_batches(sess, tqdm(test_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
189 | print(e)
190 | if config.dump_answer:
191 | print("dumping answer ...")
192 | graph_handler.dump_answer(e, path=config.answer_path)
193 | if config.dump_eval:
194 | print("dumping eval ...")
195 | graph_handler.dump_eval(e)
196 |
197 |
198 | def set_dirs(config):
199 | # create directories
200 | if not config.load and os.path.exists(config.out_dir):
201 | shutil.rmtree(config.out_dir)
202 |
203 | config.save_dir = os.path.join(config.out_dir, "save")
204 | config.log_dir = os.path.join(config.out_dir, "log")
205 | config.eval_dir = os.path.join(config.out_dir, "eval")
206 | config.answer_dir = os.path.join(config.out_dir, "answer")
207 | if not os.path.exists(config.out_dir):
208 | os.makedirs(config.out_dir)
209 | if not os.path.exists(config.save_dir):
210 | os.mkdir(config.save_dir)
211 | if not os.path.exists(config.log_dir):
212 | os.mkdir(config.log_dir)
213 | if not os.path.exists(config.answer_dir):
214 | os.mkdir(config.answer_dir)
215 | if not os.path.exists(config.eval_dir):
216 | os.mkdir(config.eval_dir)
217 |
218 |
219 | def _get_args():
220 | parser = argparse.ArgumentParser()
221 | parser.add_argument("config_path")
222 | return parser.parse_args()
223 |
224 |
225 | class Config(object):
226 | def __init__(self, **entries):
227 | self.__dict__.update(entries)
228 |
229 |
230 | def _run():
231 | args = _get_args()
232 | with open(args.config_path, 'r') as fh:
233 | config = Config(**json.load(fh))
234 | main(config)
235 |
236 |
237 | if __name__ == "__main__":
238 | _run()
239 |
--------------------------------------------------------------------------------
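
In `_train` above, the embedding matrix gives every word that has a GloVe vector its pretrained embedding and samples the remaining rows from a standard normal. A minimal numpy sketch of that construction, with toy sizes and hypothetical vectors:

```
import numpy as np

word_emb_size = 4   # toy; the real config uses the GloVe dimension, e.g. 100
word_vocab_size = 5
# hypothetical words that have GloVe vectors, keyed by vocabulary index
idx2vec_dict = {0: np.ones(word_emb_size), 2: np.full(word_emb_size, 0.5)}

# known indices keep their GloVe vector; the rest are drawn from N(0, I),
# mirroring the np.random.multivariate_normal call in _train
emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict
                    else np.random.multivariate_normal(np.zeros(word_emb_size), np.eye(word_emb_size))
                    for idx in range(word_vocab_size)])
print(emb_mat.shape)  # (5, 4)
```
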
/basic_cnn/superhighway.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.ops.rnn_cell import RNNCell
3 |
4 | from my.tensorflow.nn import linear
5 |
6 |
7 | class SHCell(RNNCell):
8 | """
9 | Super-Highway Cell
10 | """
11 | def __init__(self, input_size, logit_func='tri_linear', scalar=False):
12 | self._state_size = input_size
13 | self._output_size = input_size
14 | self._logit_func = logit_func
15 | self._scalar = scalar
16 |
17 | @property
18 | def state_size(self):
19 | return self._state_size
20 |
21 | @property
22 | def output_size(self):
23 | return self._output_size
24 |
25 | def __call__(self, inputs, state, scope=None):
26 | with tf.variable_scope(scope or "SHCell"):
27 | a_size = 1 if self._scalar else self._state_size
28 | h, u = tf.split(1, 2, inputs)
29 | if self._logit_func == 'mul_linear':
30 | args = [h * u, state * u]
31 | a = tf.nn.sigmoid(linear(args, a_size, True))
32 | elif self._logit_func == 'linear':
33 | args = [h, u, state]
34 | a = tf.nn.sigmoid(linear(args, a_size, True))
35 | elif self._logit_func == 'tri_linear':
36 | args = [h, u, state, h * u, state * u]
37 | a = tf.nn.sigmoid(linear(args, a_size, True))
38 | elif self._logit_func == 'double':
39 | args = [h, u, state]
40 | a = tf.nn.sigmoid(linear(tf.tanh(linear(args, a_size, True)), self._state_size, True))
41 |
42 | else:
43 | raise ValueError("invalid value for 'logit_func': {}".format(self._logit_func))
44 | new_state = a * state + (1 - a) * h
45 | outputs = state
46 | return outputs, new_state
47 |
48 |
--------------------------------------------------------------------------------
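
`SHCell` is a highway-style gate over the recurrent state: it computes a gate `a` from the split inputs `h`, `u` and the previous `state`, then interpolates `new_state = a * state + (1 - a) * h` while emitting the previous state as its output. A numpy sketch of one step with the `'linear'` logit variant (random stand-in weights; the real cell learns them through `my.tensorflow.nn.linear`):

```
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

d = 3
h, u, state = np.random.randn(d), np.random.randn(d), np.random.randn(d)

W = np.random.randn(3 * d, d)  # stand-in for the learned linear map
b = np.zeros(d)
a = sigmoid(np.concatenate([h, u, state]) @ W + b)  # per-dimension gate

new_state = a * state + (1 - a) * h  # carry old state vs. accept new input
output = state                       # note: the cell outputs the *previous* state
```
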
/basic_cnn/templates/visualizer.html:
--------------------------------------------------------------------------------
[Jinja2/HTML template; the markup was stripped during extraction. Recoverable structure: a results table with columns ID, Question, Answers, Predicted, Score, and Paragraph, looping over `rows`, highlighting the gold answer span (`row.y`) inside each paragraph, and emitting per-token cells from the predicted start/end probabilities (`row.yp`, `row.yp2`).]
--------------------------------------------------------------------------------
/basic_cnn/trainer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from basic_cnn.model import Model
4 | from my.tensorflow import average_gradients
5 |
6 |
7 | class Trainer(object):
8 | def __init__(self, config, model):
9 | assert isinstance(model, Model)
10 | self.config = config
11 | self.model = model
12 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr)
13 | self.loss = model.get_loss()
14 | self.var_list = model.get_var_list()
15 | self.global_step = model.get_global_step()
16 | self.summary = model.summary
17 | self.grads = self.opt.compute_gradients(self.loss, var_list=self.var_list)
18 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step)
19 |
20 | def get_train_op(self):
21 | return self.train_op
22 |
23 | def step(self, sess, batch, get_summary=False):
24 | assert isinstance(sess, tf.Session)
25 | _, ds = batch
26 | feed_dict = self.model.get_feed_dict(ds, True)
27 | if get_summary:
28 | loss, summary, train_op = \
29 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict)
30 | else:
31 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict)
32 | summary = None
33 | return loss, summary, train_op
34 |
35 |
36 | class MultiGPUTrainer(object):
37 | def __init__(self, config, models):
38 | model = models[0]
39 | assert isinstance(model, Model)
40 | self.config = config
41 | self.model = model
42 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr)
43 | self.var_list = model.get_var_list()
44 | self.global_step = model.get_global_step()
45 | self.summary = model.summary
46 | self.models = models
47 | losses = []
48 | grads_list = []
49 | for gpu_idx, model in enumerate(models):
50 | with tf.name_scope("grads_{}".format(gpu_idx)), tf.device("/gpu:{}".format(gpu_idx)):
51 | loss = model.get_loss()
52 | grads = self.opt.compute_gradients(loss, var_list=self.var_list)
53 | losses.append(loss)
54 | grads_list.append(grads)
55 |
56 | self.loss = tf.add_n(losses)/len(losses)
57 | self.grads = average_gradients(grads_list)
58 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step)
59 |
60 | def step(self, sess, batches, get_summary=False):
61 | assert isinstance(sess, tf.Session)
62 | feed_dict = {}
63 | for batch, model in zip(batches, self.models):
64 | _, ds = batch
65 | feed_dict.update(model.get_feed_dict(ds, True))
66 |
67 | if get_summary:
68 | loss, summary, train_op = \
69 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict)
70 | else:
71 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict)
72 | summary = None
73 | return loss, summary, train_op
74 |
--------------------------------------------------------------------------------
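
`MultiGPUTrainer` computes one loss and one gradient list per GPU tower, averages the per-variable gradients with `average_gradients` (defined in my/tensorflow/general.py below), and applies the result once. A plain-numpy sketch of the averaging step on toy gradients:

```
import numpy as np

# two towers, two shared variables; each entry is a (gradient, variable) pair
tower_grads = [
    [(np.array([1.0, 2.0]), "w"), (np.array([0.0]), "b")],  # tower 0
    [(np.array([3.0, 4.0]), "w"), (np.array([2.0]), "b")],  # tower 1
]

average_grads = []
for grad_and_vars in zip(*tower_grads):          # group the pairs by variable
    grads = np.stack([g for g, _ in grad_and_vars])
    avg = grads.mean(axis=0)                     # mean over the tower axis
    average_grads.append((avg, grad_and_vars[0][1]))  # variables are shared

print(average_grads)  # [(array([2., 3.]), 'w'), (array([1.]), 'b')]
```
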
/basic_cnn/visualizer.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from collections import OrderedDict
3 | import http.server
4 | import socketserver
5 | import argparse
6 | import json
7 | import os
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from jinja2 import Environment, FileSystemLoader
12 |
13 | from basic_cnn.evaluator import get_span_score_pairs, get_best_span
14 |
15 |
16 | def bool_(string):
17 | if string == 'True':
18 | return True
19 | elif string == 'False':
20 | return False
21 | else:
22 | raise ValueError("expected 'True' or 'False', got {!r}".format(string))
23 |
24 | def get_args():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--model_name", type=str, default='basic')
27 | parser.add_argument("--data_type", type=str, default='dev')
28 | parser.add_argument("--step", type=int, default=5000)
29 | parser.add_argument("--template_name", type=str, default="visualizer.html")
30 | parser.add_argument("--num_per_page", type=int, default=100)
31 | parser.add_argument("--data_dir", type=str, default="data/squad")
32 | parser.add_argument("--port", type=int, default=8000)
33 | parser.add_argument("--host", type=str, default="0.0.0.0")
34 | parser.add_argument("--open", type=str, default='False')
35 | parser.add_argument("--run_id", type=str, default="0")
36 |
37 | args = parser.parse_args()
38 | return args
39 |
40 |
41 | def _decode(decoder, sent):
42 | return " ".join(decoder[idx] for idx in sent)
43 |
44 |
45 | def accuracy2_visualizer(args):
46 | model_name = args.model_name
47 | data_type = args.data_type
48 | num_per_page = args.num_per_page
49 | data_dir = args.data_dir
50 | run_id = args.run_id.zfill(2)
51 | step = args.step
52 |
53 | eval_path = os.path.join("out", model_name, run_id, "eval", "{}-{}.json".format(data_type, str(step).zfill(6)))
54 | print("loading {}".format(eval_path))
55 | eval_ = json.load(open(eval_path, 'r'))
56 |
57 | _id = 0
58 | html_dir = "/tmp/list_results%d" % _id
59 | while os.path.exists(html_dir):
60 | _id += 1
61 | html_dir = "/tmp/list_results%d" % _id
62 |
63 | if os.path.exists(html_dir):
64 | shutil.rmtree(html_dir)
65 | os.mkdir(html_dir)
66 |
67 | cur_dir = os.path.dirname(os.path.realpath(__file__))
68 | templates_dir = os.path.join(cur_dir, 'templates')
69 | env = Environment(loader=FileSystemLoader(templates_dir))
70 | env.globals.update(zip=zip, reversed=reversed)
71 | template = env.get_template(args.template_name)
72 |
73 | data_path = os.path.join(data_dir, "data_{}.json".format(data_type))
74 | shared_path = os.path.join(data_dir, "shared_{}.json".format(data_type))
75 | print("loading {}".format(data_path))
76 | data = json.load(open(data_path, 'r'))
77 | print("loading {}".format(shared_path))
78 | shared = json.load(open(shared_path, 'r'))
79 |
80 | rows = []
81 | for i, (idx, yi, ypi, yp2i) in tqdm(enumerate(zip(*[eval_[key] for key in ('idxs', 'y', 'yp', 'yp2')])), total=len(eval_['idxs'])):
82 | id_, q, rx, answers = (data[key][idx] for key in ('ids', 'q', '*x', 'answerss'))
83 | x = shared['x'][rx[0]][rx[1]]
84 | ques = [" ".join(q)]
85 | para = [[word for word in sent] for sent in x]
86 | span = get_best_span(ypi, yp2i)
87 | ap = get_segment(para, span)
88 | score = "{:.3f}".format(ypi[span[0][0]][span[0][1]] * yp2i[span[1][0]][span[1][1]-1])
89 |
90 | row = {
91 | 'id': id_,
92 | 'title': "Hello world!",
93 | 'ques': ques,
94 | 'para': para,
95 | 'y': yi[0][0],
96 | 'y2': yi[0][1],
97 | 'yp': ypi,
98 | 'yp2': yp2i,
99 | 'a': answers,
100 | 'ap': ap,
101 | 'score': score
102 | }
103 | rows.append(row)
104 |
105 | if i % num_per_page == 0:
106 | html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8))
107 |
108 | if (i + 1) % num_per_page == 0 or (i + 1) == len(eval_['y']):
109 | var_dict = {'title': "Accuracy Visualization",
110 | 'rows': rows
111 | }
112 | with open(html_path, "wb") as f:
113 | f.write(template.render(**var_dict).encode('UTF-8'))
114 | rows = []
115 |
116 | os.chdir(html_dir)
117 | port = args.port
118 | host = args.host
119 | # Overriding to suppress log message
120 | class MyHandler(http.server.SimpleHTTPRequestHandler):
121 | def log_message(self, format, *args):
122 | pass
123 | handler = MyHandler
124 | httpd = socketserver.TCPServer((host, port), handler)
125 | if args.open == 'True':
126 | os.system("open http://%s:%d" % (args.host, args.port))
127 | print("serving at %s:%d" % (host, port))
128 | httpd.serve_forever()
129 |
130 |
131 | def get_segment(para, span):
132 | return " ".join(para[span[0][0]][span[0][1]:span[1][1]])
133 |
134 |
135 | if __name__ == "__main__":
136 | ARGS = get_args()
137 | accuracy2_visualizer(ARGS)
--------------------------------------------------------------------------------
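
The visualizer scores each prediction as the product of the start probability `ypi[...]` and the end probability `yp2i[...]`. `get_best_span` itself is imported from `basic_cnn/evaluator.py` (not shown in this section); a hedged sketch of the product criterion it appears to implement, reduced to a single flat sentence:

```
import numpy as np

yp = np.array([0.1, 0.6, 0.2, 0.1])   # toy start-position probabilities
yp2 = np.array([0.1, 0.2, 0.5, 0.2])  # toy end-position probabilities

# choose (start, end) with start <= end maximizing yp[start] * yp2[end]
best, best_score = None, -1.0
for start in range(len(yp)):
    for end in range(start, len(yp2)):
        if yp[start] * yp2[end] > best_score:
            best, best_score = (start, end), yp[start] * yp2[end]

print(best, best_score)  # (1, 2) 0.3
```
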
/cnn_dm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/cnn_dm/__init__.py
--------------------------------------------------------------------------------
/cnn_dm/evaluate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 |
5 | root_dir = sys.argv[1]
6 | answer_path = sys.argv[2]
7 | file_names = os.listdir(root_dir)
8 |
9 | num_correct = 0
10 | num_wrong = 0
11 |
12 | with open(answer_path, 'r') as fh:
13 | id2answer_dict = json.load(fh)
14 |
15 | for file_name in file_names:
16 | if not file_name.endswith(".question"):
17 | continue
18 | with open(os.path.join(root_dir, file_name), 'r') as fh:
19 | url = fh.readline().strip()
20 | _ = fh.readline()
21 | para = fh.readline().strip()
22 | _ = fh.readline()
23 | ques = fh.readline().strip()
24 | _ = fh.readline()
25 | answer = fh.readline().strip()
26 | _ = fh.readline()
27 | if file_name in id2answer_dict:
28 | pred = id2answer_dict[file_name]
29 | if pred == answer:
30 | num_correct += 1
31 | else:
32 | num_wrong += 1
33 | else:
34 | num_wrong += 1
35 |
36 | total = num_correct + num_wrong
37 | acc = float(num_correct) / total
38 | print("{} = {} / {}".format(acc, num_correct, total))
--------------------------------------------------------------------------------
/cnn_dm/prepro.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | # data: q, cq, (dq), (pq), y, *x, *cx
5 | # shared: x, cx, (dx), (px), word_counter, char_counter, word2vec
6 | # no metadata
7 | from collections import Counter
8 |
9 | from tqdm import tqdm
10 |
11 | from my.utils import process_tokens
12 | from squad.utils import get_word_span
13 |
14 |
15 | def bool_(arg):
16 | if arg == 'True':
17 | return True
18 | elif arg == 'False':
19 | return False
20 | raise Exception(arg)
21 |
22 |
23 | def main():
24 | args = get_args()
25 | prepro(args)
26 |
27 |
28 | def get_args():
29 | parser = argparse.ArgumentParser()
30 | home = os.path.expanduser("~")
31 | source_dir = os.path.join(home, "data", "cnn", 'questions')
32 | target_dir = "data/cnn"
33 | glove_dir = os.path.join(home, "data", "glove")
34 | parser.add_argument("--source_dir", default=source_dir)
35 | parser.add_argument("--target_dir", default=target_dir)
36 | parser.add_argument("--glove_dir", default=glove_dir)
37 | parser.add_argument("--glove_corpus", default='6B')
38 | parser.add_argument("--glove_vec_size", default=100, type=int)
39 | parser.add_argument("--debug", default=False, type=bool_)
40 | parser.add_argument("--num_sents_th", default=200, type=int)
41 | parser.add_argument("--ques_size_th", default=30, type=int)
42 | parser.add_argument("--width", default=5, type=int)
43 | # TODO : put more args here
44 | return parser.parse_args()
45 |
46 |
47 | def prepro(args):
48 | prepro_each(args, 'train')
49 | prepro_each(args, 'dev')
50 | prepro_each(args, 'test')
51 |
52 |
53 | def para2sents(para, width):
54 | """
55 | Turn para into a double array of words (wordss),
56 | where each pseudo-sentence is the window of up to `width` words on each side of an @-entity token
57 | :param para: paragraph string
58 | :return: list of pseudo-sentences (word lists)
59 | """
60 | words = para.split(" ")
61 | sents = []
62 | for i, word in enumerate(words):
63 | if word.startswith("@"):
64 | start = max(i - width, 0)
65 | stop = min(i + width + 1, len(words))
66 | sent = words[start:stop]
67 | sents.append(sent)
68 | return sents
69 |
70 |
71 | def get_word2vec(args, word_counter):
72 | glove_path = os.path.join(args.glove_dir, "glove.{}.{}d.txt".format(args.glove_corpus, args.glove_vec_size))
73 | sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '2B': int(1.2e6)}
74 | total = sizes[args.glove_corpus]
75 | word2vec_dict = {}
76 | with open(glove_path, 'r', encoding='utf-8') as fh:
77 | for line in tqdm(fh, total=total):
78 | array = line.lstrip().rstrip().split(" ")
79 | word = array[0]
80 | vector = list(map(float, array[1:]))
81 | if word in word_counter:
82 | word2vec_dict[word] = vector
83 | elif word.capitalize() in word_counter:
84 | word2vec_dict[word.capitalize()] = vector
85 | elif word.lower() in word_counter:
86 | word2vec_dict[word.lower()] = vector
87 | elif word.upper() in word_counter:
88 | word2vec_dict[word.upper()] = vector
89 |
90 | print("{}/{} of word vocab have corresponding vectors in {}".format(len(word2vec_dict), len(word_counter), glove_path))
91 | return word2vec_dict
92 |
93 |
94 | def prepro_each(args, mode):
95 | source_dir = os.path.join(args.source_dir, mode)
96 | word_counter = Counter()
97 | lower_word_counter = Counter()
98 | ent_counter = Counter()
99 | char_counter = Counter()
100 | max_sent_size = 0
101 | max_word_size = 0
102 | max_ques_size = 0
103 | max_num_sents = 0
104 |
105 | file_names = list(os.listdir(source_dir))
106 | if args.debug:
107 | file_names = file_names[:1000]
108 | lens = []
109 |
110 | out_file_names = []
111 | for file_name in tqdm(file_names, total=len(file_names)):
112 | if file_name.endswith(".question"):
113 | with open(os.path.join(source_dir, file_name), 'r') as fh:
114 | url = fh.readline().strip()
115 | _ = fh.readline()
116 | para = fh.readline().strip()
117 | _ = fh.readline()
118 | ques = fh.readline().strip()
119 | _ = fh.readline()
120 | answer = fh.readline().strip()
121 | _ = fh.readline()
122 | cands = list(line.strip() for line in fh)
123 | cand_ents = list(cand.split(":")[0] for cand in cands)
124 | sents = para2sents(para, args.width)
125 | ques_words = ques.split(" ")
126 |
127 | # Filtering
128 | if len(sents) > args.num_sents_th or len(ques_words) > args.ques_size_th:
129 | continue
130 |
131 | max_sent_size = max(max(map(len, sents)), max_sent_size)
132 | max_ques_size = max(len(ques_words), max_ques_size)
133 | max_word_size = max(max(len(word) for sent in sents for word in sent), max_word_size)
134 | max_num_sents = max(len(sents), max_num_sents)
135 |
136 | for word in ques_words:
137 | if word.startswith("@"):
138 | ent_counter[word] += 1
139 | word_counter[word] += 1
140 | else:
141 | word_counter[word] += 1
142 | lower_word_counter[word.lower()] += 1
143 | for c in word:
144 | char_counter[c] += 1
145 | for sent in sents:
146 | for word in sent:
147 | if word.startswith("@"):
148 | ent_counter[word] += 1
149 | word_counter[word] += 1
150 | else:
151 | word_counter[word] += 1
152 | lower_word_counter[word.lower()] += 1
153 | for c in word:
154 | char_counter[c] += 1
155 |
156 | out_file_names.append(file_name)
157 | lens.append(len(sents))
158 | num_examples = len(out_file_names)
159 |
160 | assert len(out_file_names) == len(lens)
161 | sorted_file_names, lens = zip(*sorted(zip(out_file_names, lens), key=lambda each: each[1]))
162 | assert lens[-1] == max_num_sents
163 |
164 | word2vec_dict = get_word2vec(args, word_counter)
165 | lower_word2vec_dict = get_word2vec(args, lower_word_counter)
166 |
167 | shared = {'word_counter': word_counter, 'ent_counter': ent_counter, 'char_counter': char_counter,
168 | 'lower_word_counter': lower_word_counter,
169 | 'max_num_sents': max_num_sents, 'max_sent_size': max_sent_size, 'max_word_size': max_word_size,
170 | 'max_ques_size': max_ques_size,
171 | 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict, 'sorted': sorted_file_names,
172 | 'num_examples': num_examples}
173 |
174 | print("max num sents: {}".format(max_num_sents))
175 | print("max ques size: {}".format(max_ques_size))
176 |
177 | if not os.path.exists(args.target_dir):
178 | os.makedirs(args.target_dir)
179 | shared_path = os.path.join(args.target_dir, "shared_{}.json".format(mode))
180 | with open(shared_path, 'w') as fh:
181 | json.dump(shared, fh)
182 |
183 |
184 | if __name__ == "__main__":
185 | main()
186 |
--------------------------------------------------------------------------------
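
`para2sents` never splits on sentence boundaries: it emits one pseudo-sentence per `@entity` token, namely the window of up to `width` words on each side of it. A quick demonstration (assumes the repo root is on `PYTHONPATH`):

```
from cnn_dm.prepro import para2sents

para = "yesterday @entity1 met @entity2 at the summit"
# "@entity1" (index 1) -> window words[0:4]; "@entity2" (index 3) -> words[1:6]
print(para2sents(para, width=2))
# [['yesterday', '@entity1', 'met', '@entity2'],
#  ['@entity1', 'met', '@entity2', 'at', 'the']]
```
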
/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | DATA_DIR=$HOME/data
4 | mkdir -p $DATA_DIR
5 |
6 | # Download SQuAD
7 | SQUAD_DIR=$DATA_DIR/squad
8 | mkdir -p $SQUAD_DIR
9 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $SQUAD_DIR/train-v1.1.json
10 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $SQUAD_DIR/dev-v1.1.json
11 |
12 |
13 | # Download CNN and DailyMail
14 | # Download at: http://cs.nyu.edu/~kcho/DMQA/
15 |
16 |
17 | # Download GloVe
18 | GLOVE_DIR=$DATA_DIR/glove
19 | mkdir -p $GLOVE_DIR
20 | wget http://nlp.stanford.edu/data/glove.6B.zip -O $GLOVE_DIR/glove.6B.zip
21 | unzip $GLOVE_DIR/glove.6B.zip -d $GLOVE_DIR
22 |
23 | # Download NLTK (for tokenizer)
24 | # Make sure that nltk is installed!
25 | python3 -m nltk.downloader -d $HOME/nltk_data punkt
26 |
--------------------------------------------------------------------------------
/my/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/my/__init__.py
--------------------------------------------------------------------------------
/my/corenlp_interface.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import requests
4 | import nltk
5 | import json
6 | import networkx as nx
7 | import time
8 |
9 |
10 | class CoreNLPInterface(object):
11 | def __init__(self, url, port):
12 | self._url = url
13 | self._port = port
14 |
15 | def get(self, type_, in_, num_max_requests=100):
16 | in_ = in_.encode("utf-8")
17 | url = "http://{}:{}/{}".format(self._url, self._port, type_)
18 | out = None
19 | for _ in range(num_max_requests):
20 | try:
21 | r = requests.post(url, data=in_)
22 | out = r.content.decode('utf-8')
23 | if out == 'error':
24 | out = None
25 | break
26 | except Exception:  # network hiccup; wait and retry
27 | time.sleep(1)
28 | return out
29 |
30 | def split_doc(self, doc):
31 | out = self.get("doc", doc)
32 | return out if out is None else json.loads(out)
33 |
34 | def split_sent(self, sent):
35 | out = self.get("sent", sent)
36 | return out if out is None else json.loads(out)
37 |
38 | def get_dep(self, sent):
39 | out = self.get("dep", sent)
40 | return out if out is None else json.loads(out)
41 |
42 | def get_const(self, sent):
43 | out = self.get("const", sent)
44 | return out
45 |
46 | def get_const_tree(self, sent):
47 | out = self.get_const(sent)
48 | return out if out is None else nltk.tree.Tree.fromstring(out)
49 |
50 | @staticmethod
51 | def dep2tree(dep):
52 | tree = nx.DiGraph()
53 | for dep_, i, gov, j, label in dep:  # avoid shadowing the dep argument
54 | tree.add_edge(gov, dep_, label=label)
55 | return tree
56 |
--------------------------------------------------------------------------------
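
`CoreNLPInterface` assumes an HTTP wrapper service (not part of this repo) that answers POST requests at `/doc`, `/sent`, `/dep`, and `/const` with JSON or a constituency-tree string, and it retries on failures. A hedged usage sketch with a hypothetical host and port:

```
from my.corenlp_interface import CoreNLPInterface

iface = CoreNLPInterface("localhost", 8000)  # hypothetical wrapper server
sents = iface.split_doc("First sentence. Second sentence.")
if sents is not None:  # None means the server kept returning errors
    print(sents)
```
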
/my/nltk_utils.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | import numpy as np
3 |
4 |
5 | def _set_span(t, i):
6 | if isinstance(t[0], str):
7 | t.span = (i, i+len(t))
8 | else:
9 | first = True
10 | for c in t:
11 | cur_span = _set_span(c, i)
12 | i = cur_span[1]
13 | if first:
14 | min_ = cur_span[0]
15 | first = False
16 | max_ = cur_span[1]
17 | t.span = (min_, max_)
18 | return t.span
19 |
20 |
21 | def set_span(t):
22 | assert isinstance(t, nltk.tree.Tree)
23 | try:
24 | return _set_span(t, 0)
25 | except Exception:
26 | print(t)
27 | exit()
28 |
29 |
30 | def tree_contains_span(tree, span):
31 | """
32 | Assumes that tree span has been set with set_span
33 | Returns True if any subtree of tree has exactly the given span
34 | :param tree:
35 | :param span:
36 | :return bool:
37 | """
38 | return span in set(t.span for t in tree.subtrees())
39 |
40 |
41 | def span_len(span):
42 | return span[1] - span[0]
43 |
44 |
45 | def span_overlap(s1, s2):
46 | start = max(s1[0], s2[0])
47 | stop = min(s1[1], s2[1])
48 | if stop > start:
49 | return start, stop
50 | return None
51 |
52 |
53 | def span_prec(true_span, pred_span):
54 | overlap = span_overlap(true_span, pred_span)
55 | if overlap is None:
56 | return 0
57 | return span_len(overlap) / span_len(pred_span)
58 |
59 |
60 | def span_recall(true_span, pred_span):
61 | overlap = span_overlap(true_span, pred_span)
62 | if overlap is None:
63 | return 0
64 | return span_len(overlap) / span_len(true_span)
65 |
66 |
67 | def span_f1(true_span, pred_span):
68 | p = span_prec(true_span, pred_span)
69 | r = span_recall(true_span, pred_span)
70 | if p == 0 or r == 0:
71 | return 0.0
72 | return 2 * p * r / (p + r)
73 |
74 |
75 | def find_max_f1_span(tree, span):
76 | return find_max_f1_subtree(tree, span).span
77 |
78 |
79 | def find_max_f1_subtree(tree, span):
80 | return max(((t, span_f1(span, t.span)) for t in tree.subtrees()), key=lambda p: p[1])[0]
81 |
82 |
83 | def tree2matrix(tree, node2num, row_size=None, col_size=None, dtype='int32'):
84 | set_span(tree)
85 | D = tree.height() - 1
86 | B = len(tree.leaves())
87 | row_size = row_size or D
88 | col_size = col_size or B
89 | matrix = np.zeros([row_size, col_size], dtype=dtype)
90 | mask = np.zeros([row_size, col_size, col_size], dtype='bool')
91 |
92 | for subtree in tree.subtrees():
93 | row = subtree.height() - 2
94 | col = subtree.span[0]
95 | matrix[row, col] = node2num(subtree)
96 | for subsub in subtree.subtrees():
97 | if isinstance(subsub, nltk.tree.Tree):
98 | mask[row, col, subsub.span[0]] = True
99 | if not isinstance(subsub[0], nltk.tree.Tree):
100 | c = subsub.span[0]
101 | for r in range(row):
102 | mask[r, c, c] = True
103 | else:
104 | mask[row, col, col] = True
105 |
106 | return matrix, mask
107 |
108 |
109 | def load_compressed_tree(s):
110 |
111 | def compress_tree(tree):
112 | assert not isinstance(tree, str)
113 | if len(tree) == 1:
114 | if isinstance(tree[0], nltk.tree.Tree):
115 | return compress_tree(tree[0])
116 | else:
117 | return tree
118 | else:
119 | for i, t in enumerate(tree):
120 | if isinstance(t, nltk.tree.Tree):
121 | tree[i] = compress_tree(t)
122 | else:
123 | tree[i] = t
124 | return tree
125 |
126 | return compress_tree(nltk.tree.Tree.fromstring(s))
127 |
128 |
129 |
130 |
--------------------------------------------------------------------------------
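
A worked example of the span metrics above: with `true_span = (2, 6)` and `pred_span = (4, 8)` the overlap is `(4, 6)`, so precision and recall are both 2/4 and F1 is 0.5.

```
from my.nltk_utils import span_prec, span_recall, span_f1

true_span, pred_span = (2, 6), (4, 8)
print(span_prec(true_span, pred_span))    # 0.5 (overlap 2 / predicted length 4)
print(span_recall(true_span, pred_span))  # 0.5 (overlap 2 / true length 4)
print(span_f1(true_span, pred_span))      # 0.5
```
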
/my/tensorflow/__init__.py:
--------------------------------------------------------------------------------
1 | from my.tensorflow.general import *
--------------------------------------------------------------------------------
/my/tensorflow/general.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 |
3 | import tensorflow as tf
4 | from functools import reduce
5 | from operator import mul
6 | import numpy as np
7 |
8 | VERY_BIG_NUMBER = 1e30
9 | VERY_SMALL_NUMBER = 1e-30
10 | VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER
11 | VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER
12 |
13 |
14 | def get_initializer(matrix):
15 | def _initializer(shape, dtype=None, partition_info=None, **kwargs): return matrix
16 | return _initializer
17 |
18 |
19 | def variable_on_cpu(name, shape, initializer):
20 | """Helper to create a Variable stored on CPU memory.
21 |
22 | Args:
23 | name: name of the variable
24 | shape: list of ints
25 | initializer: initializer for Variable
26 |
27 | Returns:
28 | Variable Tensor
29 | """
30 | with tf.device('/cpu:0'):
31 | var = tf.get_variable(name, shape, initializer=initializer)
32 | return var
33 |
34 |
35 | def variable_with_weight_decay(name, shape, stddev, wd):
36 | """Helper to create an initialized Variable with weight decay.
37 |
38 | Note that the Variable is initialized with a truncated normal distribution.
39 | A weight decay is added only if one is specified.
40 |
41 | Args:
42 | name: name of the variable
43 | shape: list of ints
44 | stddev: standard deviation of a truncated Gaussian
45 | wd: add L2Loss weight decay multiplied by this float. If None, weight
46 | decay is not added for this Variable.
47 |
48 | Returns:
49 | Variable Tensor
50 | """
51 | var = variable_on_cpu(name, shape,
52 | tf.truncated_normal_initializer(stddev=stddev))
53 | if wd:
54 | weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
55 | tf.add_to_collection('losses', weight_decay)
56 | return var
57 |
58 |
59 | def average_gradients(tower_grads):
60 | """Calculate the average gradient for each shared variable across all towers.
61 |
62 | Note that this function provides a synchronization point across all towers.
63 |
64 | Args:
65 | tower_grads: List of lists of (gradient, variable) tuples. The outer list
66 | is over individual gradients. The inner list is over the gradient
67 | calculation for each tower.
68 | Returns:
69 | List of pairs of (gradient, variable) where the gradient has been averaged
70 | across all towers.
71 | """
72 | average_grads = []
73 | for grad_and_vars in zip(*tower_grads):
74 | # Note that each grad_and_vars looks like the following:
75 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
76 | grads = []
77 | for g, var in grad_and_vars:
78 | # Add 0 dimension to the gradients to represent the tower.
79 | assert g is not None, var.name
80 | expanded_g = tf.expand_dims(g, 0)
81 |
82 | # Append on a 'tower' dimension which we will average over below.
83 | grads.append(expanded_g)
84 |
85 | # Average over the 'tower' dimension.
86 | grad = tf.concat(0, grads)
87 | grad = tf.reduce_mean(grad, 0)
88 |
89 | # Keep in mind that the Variables are redundant because they are shared
90 | # across towers. So .. we will just return the first tower's pointer to
91 | # the Variable.
92 | v = grad_and_vars[0][1]
93 | grad_and_var = (grad, v)
94 | average_grads.append(grad_and_var)
95 | return average_grads
96 |
97 |
98 | def mask(val, mask, name=None):
99 | if name is None:
100 | name = 'mask'
101 | return tf.mul(val, tf.cast(mask, 'float'), name=name)
102 |
103 |
104 | def exp_mask(val, mask, name=None):
105 | """Give very negative number to unmasked elements in val.
106 | For example, [-3, -2, 10], [True, True, False] -> [-3, -2, -1e9].
107 | Typically, this effectively masks in exponential space (e.g. softmax)
108 | Args:
109 | val: values to be masked
110 | mask: masking boolean tensor, same shape as val
111 | name: name for output tensor
112 |
113 | Returns:
114 | Same shape as val, where some elements are very small (exponentially zero)
115 | """
116 | if name is None:
117 | name = "exp_mask"
118 | return tf.add(val, (1 - tf.cast(mask, 'float')) * VERY_NEGATIVE_NUMBER, name=name)
119 |
120 |
121 | def flatten(tensor, keep):
122 | fixed_shape = tensor.get_shape().as_list()
123 | start = len(fixed_shape) - keep
124 | left = reduce(mul, [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start)])
125 | out_shape = [left] + [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start, len(fixed_shape))]
126 | flat = tf.reshape(tensor, out_shape)
127 | return flat
128 |
129 |
130 | def reconstruct(tensor, ref, keep):
131 | ref_shape = ref.get_shape().as_list()
132 | tensor_shape = tensor.get_shape().as_list()
133 | ref_stop = len(ref_shape) - keep
134 | tensor_start = len(tensor_shape) - keep
135 | pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(ref_stop)]
136 | keep_shape = [tensor_shape[i] or tf.shape(tensor)[i] for i in range(tensor_start, len(tensor_shape))]
137 | # pre_shape = [tf.shape(ref)[i] for i in range(len(ref.get_shape().as_list()[:-keep]))]
138 | # keep_shape = tensor.get_shape().as_list()[-keep:]
139 | target_shape = pre_shape + keep_shape
140 | out = tf.reshape(tensor, target_shape)
141 | return out
142 |
143 |
144 | def add_wd(wd, scope=None):
145 | scope = scope or tf.get_variable_scope().name
146 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
147 | with tf.name_scope("weight_decay"):
148 | for var in variables:
149 | weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name="{}/wd".format(var.op.name))
150 | tf.add_to_collection('losses', weight_decay)
151 |
152 |
153 | def grouper(iterable, n, fillvalue=None, shorten=False, num_groups=None):
154 | args = [iter(iterable)] * n
155 | out = zip_longest(*args, fillvalue=fillvalue)
156 | out = list(out)
157 | if num_groups is not None:
158 | default = (fillvalue, ) * n
159 | assert isinstance(num_groups, int)
160 | out = list(each for each, _ in zip_longest(out, range(num_groups), fillvalue=default))
161 | if shorten:
162 | assert fillvalue is None
163 | out = (tuple(e for e in each if e is not None) for each in out)
164 | return out
165 |
166 | def padded_reshape(tensor, shape, mode='CONSTANT', name=None):
167 | paddings = [[0, shape[i] - tf.shape(tensor)[i]] for i in range(len(shape))]
168 | return tf.pad(tensor, paddings, mode=mode, name=name)
--------------------------------------------------------------------------------
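
`flatten(tensor, keep)` collapses all but the last `keep` dimensions into one, and `reconstruct` restores the leading dimensions from a reference tensor; the wrappers in `my/tensorflow/rnn.py` use this pair to feed nested batch dimensions through TF's RNN ops. A numpy sketch of the shape round trip on toy shapes:

```
import numpy as np

x = np.zeros((2, 3, 4, 5))          # e.g. [N, M, J, d]

# flatten(x, keep=2): collapse the leading dims -> [N*M, J, d]
flat = x.reshape((-1,) + x.shape[-2:])
print(flat.shape)                   # (6, 4, 5)

# reconstruct(flat, x, keep=2): restore the leading dims from the reference
out = flat.reshape(x.shape[:-2] + flat.shape[-2:])
print(out.shape)                    # (2, 3, 4, 5)
```
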
/my/tensorflow/nn.py:
--------------------------------------------------------------------------------
1 | from tensorflow.python.ops.rnn_cell import _linear
2 | from tensorflow.python.util import nest
3 | import tensorflow as tf
4 |
5 | from my.tensorflow import flatten, reconstruct, add_wd, exp_mask
6 |
7 |
8 | def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False, wd=0.0, input_keep_prob=1.0,
9 | is_train=None):
10 | if args is None or (nest.is_sequence(args) and not args):
11 | raise ValueError("`args` must be specified")
12 | if not nest.is_sequence(args):
13 | args = [args]
14 |
15 | flat_args = [flatten(arg, 1) for arg in args]
16 | if input_keep_prob < 1.0:
17 | assert is_train is not None
18 | flat_args = [tf.cond(is_train, lambda: tf.nn.dropout(arg, input_keep_prob), lambda: arg)
19 | for arg in flat_args]
20 | flat_out = _linear(flat_args, output_size, bias, bias_start=bias_start, scope=scope)
21 | out = reconstruct(flat_out, args[0], 1)
22 | if squeeze:
23 | out = tf.squeeze(out, [len(args[0].get_shape().as_list())-1])
24 | if wd:
25 | add_wd(wd)
26 |
27 | return out
28 |
29 |
30 | def dropout(x, keep_prob, is_train, noise_shape=None, seed=None, name=None):
31 | with tf.name_scope(name or "dropout"):
32 | if keep_prob < 1.0:
33 | d = tf.nn.dropout(x, keep_prob, noise_shape=noise_shape, seed=seed)
34 | out = tf.cond(is_train, lambda: d, lambda: x)
35 | return out
36 | return x
37 |
38 |
39 | def softmax(logits, mask=None, scope=None):
40 | with tf.name_scope(scope or "Softmax"):
41 | if mask is not None:
42 | logits = exp_mask(logits, mask)
43 | flat_logits = flatten(logits, 1)
44 | flat_out = tf.nn.softmax(flat_logits)
45 | out = reconstruct(flat_out, logits, 1)
46 |
47 | return out
48 |
49 |
50 | def softsel(target, logits, mask=None, scope=None):
51 | """
52 |
53 | :param target: [ ..., J, d] dtype=float
54 | :param logits: [ ..., J], dtype=float
55 | :param mask: [ ..., J], dtype=bool
56 | :param scope:
57 | :return: [..., d], dtype=float
58 | """
59 | with tf.name_scope(scope or "Softsel"):
60 | a = softmax(logits, mask=mask)
61 | target_rank = len(target.get_shape().as_list())
62 | out = tf.reduce_sum(tf.expand_dims(a, -1) * target, target_rank - 2)
63 | return out
64 |
65 |
66 | def double_linear_logits(args, size, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None):
67 | with tf.variable_scope(scope or "Double_Linear_Logits"):
68 | first = tf.tanh(linear(args, size, bias, bias_start=bias_start, scope='first',
69 | wd=wd, input_keep_prob=input_keep_prob, is_train=is_train))
70 | second = linear(first, 1, bias, bias_start=bias_start, squeeze=True, scope='second',
71 | wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
72 | if mask is not None:
73 | second = exp_mask(second, mask)
74 | return second
75 |
76 |
77 | def linear_logits(args, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None):
78 | with tf.variable_scope(scope or "Linear_Logits"):
79 | logits = linear(args, 1, bias, bias_start=bias_start, squeeze=True, scope='first',
80 | wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
81 | if mask is not None:
82 | logits = exp_mask(logits, mask)
83 | return logits
84 |
85 |
86 | def sum_logits(args, mask=None, name=None):
87 | with tf.name_scope(name or "sum_logits"):
88 | if args is None or (nest.is_sequence(args) and not args):
89 | raise ValueError("`args` must be specified")
90 | if not nest.is_sequence(args):
91 | args = [args]
92 | rank = len(args[0].get_shape())
93 | logits = sum(tf.reduce_sum(arg, rank-1) for arg in args)
94 | if mask is not None:
95 | logits = exp_mask(logits, mask)
96 | return logits
97 |
98 |
99 | def get_logits(args, size, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None, func=None):
100 | if func is None:
101 | func = "sum"
102 | if func == 'sum':
103 | return sum_logits(args, mask=mask, name=scope)
104 | elif func == 'linear':
105 | return linear_logits(args, bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob,
106 | is_train=is_train)
107 | elif func == 'double':
108 | return double_linear_logits(args, size, bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob,
109 | is_train=is_train)
110 | elif func == 'dot':
111 | assert len(args) == 2
112 | arg = args[0] * args[1]
113 | return sum_logits([arg], mask=mask, name=scope)
114 | elif func == 'mul_linear':
115 | assert len(args) == 2
116 | arg = args[0] * args[1]
117 | return linear_logits([arg], bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob,
118 | is_train=is_train)
119 | elif func == 'proj':
120 | assert len(args) == 2
121 | d = args[1].get_shape()[-1]
122 | proj = linear([args[0]], d, False, bias_start=bias_start, scope=scope, wd=wd, input_keep_prob=input_keep_prob,
123 | is_train=is_train)
124 | return sum_logits([proj * args[1]], mask=mask)
125 | elif func == 'tri_linear':
126 | assert len(args) == 2
127 | new_arg = args[0] * args[1]
128 | return linear_logits([args[0], args[1], new_arg], bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob,
129 | is_train=is_train)
130 | else:
131 | raise Exception()
132 |
133 |
134 | def highway_layer(arg, bias, bias_start=0.0, scope=None, wd=0.0, input_keep_prob=1.0, is_train=None):
135 | with tf.variable_scope(scope or "highway_layer"):
136 | d = arg.get_shape()[-1]
137 | trans = linear([arg], d, bias, bias_start=bias_start, scope='trans', wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
138 | trans = tf.nn.relu(trans)
139 | gate = linear([arg], d, bias, bias_start=bias_start, scope='gate', wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)
140 | gate = tf.nn.sigmoid(gate)
141 | out = gate * trans + (1 - gate) * arg
142 | return out
143 |
144 |
145 | def highway_network(arg, num_layers, bias, bias_start=0.0, scope=None, wd=0.0, input_keep_prob=1.0, is_train=None):
146 | with tf.variable_scope(scope or "highway_network"):
147 | prev = arg
148 | cur = None
149 | for layer_idx in range(num_layers):
150 | cur = highway_layer(prev, bias, bias_start=bias_start, scope="layer_{}".format(layer_idx), wd=wd,
151 | input_keep_prob=input_keep_prob, is_train=is_train)
152 | prev = cur
153 | return cur
154 |
155 |
156 | def conv1d(in_, filter_size, height, padding, is_train=None, keep_prob=1.0, scope=None):
157 | with tf.variable_scope(scope or "conv1d"):
158 | num_channels = in_.get_shape()[-1]
159 | filter_ = tf.get_variable("filter", shape=[1, height, num_channels, filter_size], dtype='float')
160 | bias = tf.get_variable("bias", shape=[filter_size], dtype='float')
161 | strides = [1, 1, 1, 1]
162 | if is_train is not None and keep_prob < 1.0:
163 | in_ = dropout(in_, keep_prob, is_train)
164 | xxc = tf.nn.conv2d(in_, filter_, strides, padding) + bias # [N*M, JX, W/filter_stride, d]
165 | out = tf.reduce_max(tf.nn.relu(xxc), 2) # [-1, JX, d]
166 | return out
167 |
168 |
169 | def multi_conv1d(in_, filter_sizes, heights, padding, is_train=None, keep_prob=1.0, scope=None):
170 | with tf.variable_scope(scope or "multi_conv1d"):
171 | assert len(filter_sizes) == len(heights)
172 | outs = []
173 | for filter_size, height in zip(filter_sizes, heights):
174 | if filter_size == 0:
175 | continue
176 | out = conv1d(in_, filter_size, height, padding, is_train=is_train, keep_prob=keep_prob, scope="conv1d_{}".format(height))
177 | outs.append(out)
178 | concat_out = tf.concat(2, outs)
179 | return concat_out
180 |
--------------------------------------------------------------------------------
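
The `'tri_linear'` branch of `get_logits` is the BiDAF similarity function: a single linear map over the concatenation `[h; u; h * u]`, giving the scalar logit `s = w · [h; u; h∘u] + b`. A numpy sketch for one context/question vector pair (random stand-in weights; the real version is batched over all position pairs):

```
import numpy as np

d = 4
h = np.random.randn(d)      # a context vector
u = np.random.randn(d)      # a question vector

w = np.random.randn(3 * d)  # learned by linear_logits in the real model
b = 0.0

s = w @ np.concatenate([h, u, h * u]) + b  # scalar similarity logit
print(s)
```
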
/my/tensorflow/rnn.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.ops.rnn import dynamic_rnn as _dynamic_rnn, \
3 | bidirectional_dynamic_rnn as _bidirectional_dynamic_rnn
4 | from tensorflow.python.ops.rnn import bidirectional_rnn as _bidirectional_rnn
5 |
6 | from my.tensorflow import flatten, reconstruct
7 |
8 |
9 | def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
10 | dtype=None, parallel_iterations=None, swap_memory=False,
11 | time_major=False, scope=None):
12 | assert not time_major # TODO : to be implemented later!
13 | flat_inputs = flatten(inputs, 2) # [-1, J, d]
14 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
15 |
16 | flat_outputs, final_state = _dynamic_rnn(cell, flat_inputs, sequence_length=flat_len,
17 | initial_state=initial_state, dtype=dtype,
18 | parallel_iterations=parallel_iterations, swap_memory=swap_memory,
19 | time_major=time_major, scope=scope)
20 |
21 | outputs = reconstruct(flat_outputs, inputs, 2)
22 | return outputs, final_state
23 |
24 |
25 | def bw_dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
26 | dtype=None, parallel_iterations=None, swap_memory=False,
27 | time_major=False, scope=None):
28 | assert not time_major # TODO : to be implemented later!
29 |
30 | flat_inputs = flatten(inputs, 2) # [-1, J, d]
31 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
32 |
33 | flat_inputs = tf.reverse(flat_inputs, 1) if sequence_length is None \
34 | else tf.reverse_sequence(flat_inputs, sequence_length, 1)
35 | flat_outputs, final_state = _dynamic_rnn(cell, flat_inputs, sequence_length=flat_len,
36 | initial_state=initial_state, dtype=dtype,
37 | parallel_iterations=parallel_iterations, swap_memory=swap_memory,
38 | time_major=time_major, scope=scope)
39 | flat_outputs = tf.reverse(flat_outputs, 1) if sequence_length is None \
40 | else tf.reverse_sequence(flat_outputs, sequence_length, 1)
41 |
42 | outputs = reconstruct(flat_outputs, inputs, 2)
43 | return outputs, final_state
44 |
45 |
46 | def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
47 | initial_state_fw=None, initial_state_bw=None,
48 | dtype=None, parallel_iterations=None,
49 | swap_memory=False, time_major=False, scope=None):
50 | assert not time_major
51 |
52 | flat_inputs = flatten(inputs, 2) # [-1, J, d]
53 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
54 |
55 | (flat_fw_outputs, flat_bw_outputs), final_state = \
56 | _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
57 | initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
58 | dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
59 | time_major=time_major, scope=scope)
60 |
61 | fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
62 | bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
63 | # FIXME : final state is not reshaped!
64 | return (fw_outputs, bw_outputs), final_state
65 |
66 |
67 | def bidirectional_rnn(cell_fw, cell_bw, inputs,
68 | initial_state_fw=None, initial_state_bw=None,
69 | dtype=None, sequence_length=None, scope=None):
70 |
71 | flat_inputs = flatten(inputs, 2) # [-1, J, d]
72 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')
73 |
74 | (flat_fw_outputs, flat_bw_outputs), final_state = \
75 | _bidirectional_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
76 | initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
77 | dtype=dtype, scope=scope)
78 |
79 | fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
80 | bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
81 | # FIXME : final state is not reshaped!
82 | return (fw_outputs, bw_outputs), final_state
83 |
--------------------------------------------------------------------------------
/my/tensorflow/rnn_cell.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.ops.rnn_cell import DropoutWrapper, RNNCell, LSTMStateTuple
3 |
4 | from my.tensorflow import exp_mask, flatten
5 | from my.tensorflow.nn import linear, softsel, double_linear_logits
6 |
7 |
8 | class SwitchableDropoutWrapper(DropoutWrapper):
9 | def __init__(self, cell, is_train, input_keep_prob=1.0, output_keep_prob=1.0,
10 | seed=None):
11 | super(SwitchableDropoutWrapper, self).__init__(cell, input_keep_prob=input_keep_prob, output_keep_prob=output_keep_prob,
12 | seed=seed)
13 | self.is_train = is_train
14 |
15 | def __call__(self, inputs, state, scope=None):
16 | outputs_do, new_state_do = super(SwitchableDropoutWrapper, self).__call__(inputs, state, scope=scope)
17 | tf.get_variable_scope().reuse_variables()
18 | outputs, new_state = self._cell(inputs, state, scope)
19 | outputs = tf.cond(self.is_train, lambda: outputs_do, lambda: outputs)
20 | if isinstance(state, tuple):
21 | new_state = state.__class__(*[tf.cond(self.is_train, lambda: new_state_do_i, lambda: new_state_i)
22 | for new_state_do_i, new_state_i in zip(new_state_do, new_state)])
23 | else:
24 | new_state = tf.cond(self.is_train, lambda: new_state_do, lambda: new_state)
25 | return outputs, new_state
26 |
27 |
28 | class TreeRNNCell(RNNCell):
29 | def __init__(self, cell, input_size, reduce_func):
30 | self._cell = cell
31 | self._input_size = input_size
32 | self._reduce_func = reduce_func
33 |
34 | def __call__(self, inputs, state, scope=None):
35 | """
36 | :param inputs: [N*B, I + B]
37 | :param state: [N*B, d]
38 | :param scope:
39 | :return: [N*B, d]
40 | """
41 | with tf.variable_scope(scope or self.__class__.__name__):
42 | d = self.state_size
43 | x = tf.slice(inputs, [0, 0], [-1, self._input_size]) # [N*B, I]
44 | mask = tf.slice(inputs, [0, self._input_size], [-1, -1]) # [N*B, B]
45 | B = tf.shape(mask)[1]
46 | prev_state = tf.expand_dims(tf.reshape(state, [-1, B, d]), 1) # [N, B, d] -> [N, 1, B, d]
47 | mask = tf.tile(tf.expand_dims(tf.reshape(mask, [-1, B, B]), -1), [1, 1, 1, d]) # [N, B, B, d]
48 | # prev_state = self._reduce_func(tf.tile(prev_state, [1, B, 1, 1]), 2)
49 | prev_state = self._reduce_func(exp_mask(prev_state, mask), 2) # [N, B, d]
50 | prev_state = tf.reshape(prev_state, [-1, d]) # [N*B, d]
51 | return self._cell(x, prev_state)
52 |
53 | @property
54 | def state_size(self):
55 | return self._cell.state_size
56 |
57 | @property
58 | def output_size(self):
59 | return self._cell.output_size
60 |
61 |
62 | class NoOpCell(RNNCell):
63 | def __init__(self, num_units):
64 | self._num_units = num_units
65 |
66 | def __call__(self, inputs, state, scope=None):
67 | return state, state
68 |
69 | @property
70 | def state_size(self):
71 | return self._num_units
72 |
73 | @property
74 | def output_size(self):
75 | return self._num_units
76 |
77 |
78 | class MatchCell(RNNCell):
79 | def __init__(self, cell, input_size, q_len):
80 | self._cell = cell
81 | self._input_size = input_size
82 | # FIXME : This won't be needed with good shape guessing
83 | self._q_len = q_len
84 |
85 | @property
86 | def state_size(self):
87 | return self._cell.state_size
88 |
89 | @property
90 | def output_size(self):
91 | return self._cell.output_size
92 |
93 | def __call__(self, inputs, state, scope=None):
94 | """
95 |
96 | :param inputs: [N, d + JQ + JQ * d]
97 | :param state: [N, d]
98 | :param scope:
99 | :return:
100 | """
101 | with tf.variable_scope(scope or self.__class__.__name__):
102 | c_prev, h_prev = state
103 | x = tf.slice(inputs, [0, 0], [-1, self._input_size])
104 | q_mask = tf.slice(inputs, [0, self._input_size], [-1, self._q_len]) # [N, JQ]
105 | qs = tf.slice(inputs, [0, self._input_size + self._q_len], [-1, -1])
106 | qs = tf.reshape(qs, [-1, self._q_len, self._input_size]) # [N, JQ, d]
107 | x_tiled = tf.tile(tf.expand_dims(x, 1), [1, self._q_len, 1]) # [N, JQ, d]
108 | h_prev_tiled = tf.tile(tf.expand_dims(h_prev, 1), [1, self._q_len, 1]) # [N, JQ, d]
109 | f = tf.tanh(linear([qs, x_tiled, h_prev_tiled], self._input_size, True, scope='f')) # [N, JQ, d]
110 | a = tf.nn.softmax(exp_mask(linear(f, 1, True, squeeze=True, scope='a'), q_mask)) # [N, JQ]
111 | q = tf.reduce_sum(qs * tf.expand_dims(a, -1), 1)
112 | z = tf.concat(1, [x, q]) # [N, 2d]
113 | return self._cell(z, state)
114 |
115 |
116 | class AttentionCell(RNNCell):
117 | def __init__(self, cell, memory, mask=None, controller=None, mapper=None, input_keep_prob=1.0, is_train=None):
118 | """
119 | Early fusion attention cell: uses the (inputs, state) to control the current attention.
120 |
121 | :param cell:
122 | :param memory: [N, M, m]
123 | :param mask:
124 | :param controller: (inputs, prev_state, memory) -> memory_logits
125 | """
126 | self._cell = cell
127 | self._memory = memory
128 | self._mask = mask
129 | self._flat_memory = flatten(memory, 2)
130 | self._flat_mask = flatten(mask, 1)
131 | if controller is None:
132 | controller = AttentionCell.get_linear_controller(True, is_train=is_train)
133 | self._controller = controller
134 | if mapper is None:
135 | mapper = AttentionCell.get_concat_mapper()
136 | elif mapper == 'sim':
137 | mapper = AttentionCell.get_sim_mapper()
138 | self._mapper = mapper
139 |
140 | @property
141 | def state_size(self):
142 | return self._cell.state_size
143 |
144 | @property
145 | def output_size(self):
146 | return self._cell.output_size
147 |
148 | def __call__(self, inputs, state, scope=None):
149 | with tf.variable_scope(scope or "AttentionCell"):
150 | memory_logits = self._controller(inputs, state, self._flat_memory)
151 | sel_mem = softsel(self._flat_memory, memory_logits, mask=self._flat_mask) # [N, m]
152 | new_inputs, new_state = self._mapper(inputs, state, sel_mem)
153 | return self._cell(new_inputs, state)
154 |
155 | @staticmethod
156 | def get_double_linear_controller(size, bias, input_keep_prob=1.0, is_train=None):
157 | def double_linear_controller(inputs, state, memory):
158 | """
159 |
160 | :param inputs: [N, i]
161 | :param state: [N, d]
162 | :param memory: [N, M, m]
163 | :return: [N, M]
164 | """
165 | rank = len(memory.get_shape())
166 | _memory_size = tf.shape(memory)[rank-2]
167 | tiled_inputs = tf.tile(tf.expand_dims(inputs, 1), [1, _memory_size, 1])
168 | if isinstance(state, tuple):
169 | tiled_states = [tf.tile(tf.expand_dims(each, 1), [1, _memory_size, 1])
170 | for each in state]
171 | else:
172 | tiled_states = [tf.tile(tf.expand_dims(state, 1), [1, _memory_size, 1])]
173 |
174 | # [N, M, d]
175 | in_ = tf.concat(2, [tiled_inputs] + tiled_states + [memory])
176 | out = double_linear_logits(in_, size, bias, input_keep_prob=input_keep_prob,
177 | is_train=is_train)
178 | return out
179 | return double_linear_controller
180 |
181 | @staticmethod
182 | def get_linear_controller(bias, input_keep_prob=1.0, is_train=None):
183 | def linear_controller(inputs, state, memory):
184 | rank = len(memory.get_shape())
185 | _memory_size = tf.shape(memory)[rank-2]
186 | tiled_inputs = tf.tile(tf.expand_dims(inputs, 1), [1, _memory_size, 1])
187 | if isinstance(state, tuple):
188 | tiled_states = [tf.tile(tf.expand_dims(each, 1), [1, _memory_size, 1])
189 | for each in state]
190 | else:
191 | tiled_states = [tf.tile(tf.expand_dims(state, 1), [1, _memory_size, 1])]
192 |
193 | # [N, M, d]
194 | in_ = tf.concat(2, [tiled_inputs] + tiled_states + [memory])
195 | out = linear(in_, 1, bias, squeeze=True, input_keep_prob=input_keep_prob, is_train=is_train)
196 | return out
197 | return linear_controller
198 |
199 | @staticmethod
200 | def get_concat_mapper():
201 | def concat_mapper(inputs, state, sel_mem):
202 | """
203 |
204 | :param inputs: [N, i]
205 | :param state: [N, d]
206 | :param sel_mem: [N, m]
207 | :return: (new_inputs, new_state) tuple
208 | """
209 | return tf.concat(1, [inputs, sel_mem]), state
210 | return concat_mapper
211 |
212 | @staticmethod
213 | def get_sim_mapper():
214 | def sim_mapper(inputs, state, sel_mem):
215 | """
216 | Assume that inputs and sel_mem are the same size
217 | :param inputs: [N, i]
218 | :param state: [N, d]
219 | :param sel_mem: [N, i]
220 | :return: (new_inputs, new_state) tuple
221 | """
222 | return tf.concat(1, [inputs, sel_mem, inputs * sel_mem, tf.abs(inputs - sel_mem)]), state
223 | return sim_mapper
224 |
--------------------------------------------------------------------------------
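A minimal usage sketch for the `AttentionCell` above (not part of the repository). The constructor signature `AttentionCell(cell, memory, mask, controller=None, mapper=None, is_train=None)` is inferred from the `__init__` body, and the TensorFlow r0.11 API is assumed:
```
# Sketch only: the AttentionCell signature is inferred, not confirmed by this file.
import tensorflow as tf
from my.tensorflow.rnn_cell import AttentionCell

N, JX, M, d = 32, 40, 10, 100                    # batch, seq length, memory slots, hidden size
inputs = tf.placeholder(tf.float32, [N, JX, d])  # sequence fed to the wrapped cell
memory = tf.placeholder(tf.float32, [N, M, d])   # external memory [N, M, m]
mask = tf.placeholder(tf.bool, [N, M])           # which memory slots are valid
is_train = tf.placeholder(tf.bool, [])

cell = tf.nn.rnn_cell.BasicLSTMCell(d, state_is_tuple=True)
# Defaults: the linear controller scores each memory slot from (inputs, state),
# and the concat mapper appends the attended memory vector to the cell input.
att_cell = AttentionCell(cell, memory, mask, is_train=is_train)
outputs, final_state = tf.nn.dynamic_rnn(att_cell, inputs, dtype=tf.float32)
```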
/my/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import deque
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 |
8 | def mytqdm(list_, desc="", show=True):
9 | if show:
10 | pbar = tqdm(list_)
11 | pbar.set_description(desc)
12 | return pbar
13 | return list_
14 |
15 |
16 | def json_pretty_dump(obj, fh):
17 | return json.dump(obj, fh, sort_keys=True, indent=2, separators=(',', ': '))
18 |
19 |
20 | def index(l, i):
21 | return index(l[i[0]], i[1:]) if len(i) > 1 else l[i[0]]
22 |
23 |
24 | def fill(l, shape, dtype=None):
25 | out = np.zeros(shape, dtype=dtype)
26 | stack = deque()
27 | stack.appendleft(((), l))
28 | while len(stack) > 0:
29 | indices, cur = stack.pop()
30 |         if len(indices) < len(shape):
31 | for i, sub in enumerate(cur):
32 | stack.appendleft([indices + (i,), sub])
33 | else:
34 | out[indices] = cur
35 | return out
36 |
37 |
38 | def short_floats(o, precision):
39 | class ShortFloat(float):
40 | def __repr__(self):
41 | return '%.{}g'.format(precision) % self
42 |
43 | def _short_floats(obj):
44 | if isinstance(obj, float):
45 | return ShortFloat(obj)
46 | elif isinstance(obj, dict):
47 | return dict((k, _short_floats(v)) for k, v in obj.items())
48 | elif isinstance(obj, (list, tuple)):
49 | return tuple(map(_short_floats, obj))
50 | return obj
51 |
52 | return _short_floats(o)
53 |
54 |
55 | def argmax(x):
56 | return np.unravel_index(x.argmax(), x.shape)
57 |
58 |
59 |
--------------------------------------------------------------------------------
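Two of the helpers above in action (a sketch, not part of the repository): `fill` pads a ragged nested list into a fixed-shape numpy array, and `short_floats` recursively wraps floats so they print with limited precision:
```
from my.utils import fill, short_floats

ragged = [[1, 2, 3], [4]]
print(fill(ragged, (2, 3), dtype='int32'))
# [[1 2 3]
#  [4 0 0]]

print(short_floats({'f1': 0.123456789}, 4))
# {'f1': 0.1235}
```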
/my/zip_save.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import shutil
5 | from zipfile import ZipFile
6 |
7 | from tqdm import tqdm
8 |
9 |
10 | def get_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('paths', nargs='+')
13 | parser.add_argument('-o', '--out', default='save.zip')
14 | args = parser.parse_args()
15 | return args
16 |
17 |
18 | def zip_save(args):
19 | temp_dir = "."
20 | save_dir = os.path.join(temp_dir, "save")
21 | if not os.path.exists(save_dir):
22 | os.makedirs(save_dir)
23 | for save_source_path in tqdm(args.paths):
24 | # path = "out/basic/30/save/basic-18000"
25 | # target_path = "save_dir/30/save"
26 |         # also write the full source path name to "save_dir/30/readme.txt"
27 | # need to also extract "out/basic/30/shared.json"
28 | temp, _ = os.path.split(save_source_path) # "out/basic/30/save", _
29 |         model_dir, _ = os.path.split(temp)  # "out/basic/30", _
30 | _, model_name = os.path.split(model_dir)
31 | cur_dir = os.path.join(save_dir, model_name)
32 | if not os.path.exists(cur_dir):
33 | os.makedirs(cur_dir)
34 | save_target_path = os.path.join(cur_dir, "save")
35 | shared_target_path = os.path.join(cur_dir, "shared.json")
36 | readme_path = os.path.join(cur_dir, "readme.txt")
37 | shared_source_path = os.path.join(model_dir, "shared.json")
38 | shutil.copy(save_source_path, save_target_path)
39 | shutil.copy(shared_source_path, shared_target_path)
40 | with open(readme_path, 'w') as fh:
41 | fh.write(save_source_path)
42 |
43 | os.system("zip {} -r {}".format(args.out, save_dir))
44 |
45 | def main():
46 | args = get_args()
47 | zip_save(args)
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
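Given checkpoints laid out as in the comments above, `zip_save` bundles each model's checkpoint, its `shared.json`, and a `readme.txt` recording the source path into one zip. An illustrative invocation (paths are examples):
```
python -m my.zip_save out/basic/30/save/basic-18000 -o save.zip
```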
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow>=0.11
2 | nltk
3 | tqdm
4 | jinja2
--------------------------------------------------------------------------------
/squad/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/squad/__init__.py
--------------------------------------------------------------------------------
/squad/aug_squad.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 | from tqdm import tqdm
5 |
6 | from my.corenlp_interface import CoreNLPInterface
7 |
8 | in_path = sys.argv[1]
9 | out_path = sys.argv[2]
10 | url = sys.argv[3]
11 | port = int(sys.argv[4])
12 | data = json.load(open(in_path, 'r'))
13 |
14 | h = CoreNLPInterface(url, port)
15 |
16 |
17 | def find_all(a_str, sub):
18 | start = 0
19 | while True:
20 | start = a_str.find(sub, start)
21 | if start == -1: return
22 | yield start
23 | start += len(sub) # use start += 1 to find overlapping matches
24 |
25 |
26 | def to_hex(s):
27 | return " ".join(map(hex, map(ord, s)))
28 |
29 |
30 | def handle_nobreak(cand, text):
31 | if cand == text:
32 | return cand
33 | if cand.replace(u'\u00A0', ' ') == text:
34 | return cand
35 | elif cand == text.replace(u'\u00A0', ' '):
36 | return text
37 | raise Exception("{} '{}' {} '{}'".format(cand, to_hex(cand), text, to_hex(text)))
38 |
39 |
40 | # resolving unicode complication
41 |
42 | wrong_loc_count = 0
43 | loc_diffs = []
44 |
45 | for article in data['data']:
46 | for para in article['paragraphs']:
47 | para['context'] = para['context'].replace(u'\u000A', '')
48 | para['context'] = para['context'].replace(u'\u00A0', ' ')
49 | context = para['context']
50 | for qa in para['qas']:
51 | for answer in qa['answers']:
52 | answer['text'] = answer['text'].replace(u'\u00A0', ' ')
53 | text = answer['text']
54 | answer_start = answer['answer_start']
55 | if context[answer_start:answer_start + len(text)] == text:
56 | if text.lstrip() == text:
57 | pass
58 | else:
59 | answer_start += len(text) - len(text.lstrip())
60 | answer['answer_start'] = answer_start
61 | text = text.lstrip()
62 | answer['text'] = text
63 | else:
64 | wrong_loc_count += 1
65 | text = text.lstrip()
66 | answer['text'] = text
67 | starts = list(find_all(context, text))
68 | if len(starts) == 1:
69 | answer_start = starts[0]
70 | elif len(starts) > 1:
71 | new_answer_start = min(starts, key=lambda s: abs(s - answer_start))
72 | loc_diffs.append(abs(new_answer_start - answer_start))
73 | answer_start = new_answer_start
74 | else:
75 |                             raise Exception("answer text '{}' not found in context".format(text))
76 | answer['answer_start'] = answer_start
77 |
78 | answer_stop = answer_start + len(text)
79 | answer['answer_stop'] = answer_stop
80 | assert para['context'][answer_start:answer_stop] == answer['text'], "{} {}".format(
81 | para['context'][answer_start:answer_stop], answer['text'])
82 |
83 | print(wrong_loc_count, loc_diffs)
84 |
85 | mismatch_count = 0
86 | dep_fail_count = 0
87 | no_answer_count = 0
88 |
89 | size = sum(len(article['paragraphs']) for article in data['data'])
90 | pbar = tqdm(range(size))
91 |
92 | for ai, article in enumerate(data['data']):
93 | for pi, para in enumerate(article['paragraphs']):
94 | context = para['context']
95 | sents = h.split_doc(context)
96 | words = h.split_sent(context)
97 | sent_starts = []
98 | ref_idx = 0
99 | for sent in sents:
100 | new_idx = context.find(sent, ref_idx)
101 | sent_starts.append(new_idx)
102 | ref_idx = new_idx + len(sent)
103 | para['sents'] = sents
104 | para['words'] = words
105 | para['sent_starts'] = sent_starts
106 |
107 | consts = list(map(h.get_const, sents))
108 | para['consts'] = consts
109 | deps = list(map(h.get_dep, sents))
110 | para['deps'] = deps
111 |
112 | for qa in para['qas']:
113 | question = qa['question']
114 | question_const = h.get_const(question)
115 | qa['const'] = question_const
116 | question_dep = h.get_dep(question)
117 | qa['dep'] = question_dep
118 | qa['words'] = h.split_sent(question)
119 |
120 | for answer in qa['answers']:
121 | answer_start = answer['answer_start']
122 | text = answer['text']
123 | answer_stop = answer_start + len(text)
124 | # answer_words = h.split_sent(text)
125 | word_idxs = []
126 | answer_words = []
127 | for sent_idx, (sent, sent_start, dep) in enumerate(zip(sents, sent_starts, deps)):
128 | if dep is None:
129 | print("dep parse failed at {} {} {}".format(ai, pi, sent_idx))
130 | dep_fail_count += 1
131 | continue
132 | nodes, edges = dep
133 | words = [node[0] for node in nodes]
134 |
135 | for word_idx, (word, _, _, start, _) in enumerate(nodes):
136 | global_start = sent_start + start
137 | global_stop = global_start + len(word)
138 | if answer_start <= global_start < answer_stop or answer_start < global_stop <= answer_stop:
139 | word_idxs.append((sent_idx, word_idx))
140 | answer_words.append(word)
141 | if len(word_idxs) > 0:
142 | answer['answer_word_start'] = word_idxs[0]
143 | answer['answer_word_stop'] = word_idxs[-1][0], word_idxs[-1][1] + 1
144 | if not text.startswith(answer_words[0]):
145 | print("'{}' '{}'".format(text, ' '.join(answer_words)))
146 | mismatch_count += 1
147 | else:
148 | answer['answer_word_start'] = None
149 | answer['answer_word_stop'] = None
150 | no_answer_count += 1
151 | pbar.update(1)
152 | pbar.close()
153 |
154 | print(mismatch_count, dep_fail_count, no_answer_count)
155 |
156 | print("saving...")
157 | json.dump(data, open(out_path, 'w'))
--------------------------------------------------------------------------------
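The script reads four positional arguments: input SQuAD JSON, output path, and the URL and port of a running CoreNLP server. An illustrative run (server address is a placeholder):
```
python -m squad.aug_squad $HOME/data/squad/dev-v1.1.json $HOME/data/squad/dev-v1.0-aug.json localhost 8000
```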
/squad/eda_aug_dev.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import json\n",
12 | "\n",
13 | "aug_data_path = \"/Users/minjoons/data/squad/dev-v1.0-aug.json\"\n",
14 | "aug_data = json.load(open(aug_data_path, 'r'))"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 17,
20 | "metadata": {
21 | "collapsed": false
22 | },
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "(['Denver', 'Broncos'], 'Denver Broncos')\n",
29 | "(['Denver', 'Broncos'], 'Denver Broncos')\n",
30 | "(['Denver', 'Broncos'], 'Denver Broncos ')\n",
31 | "(['Carolina', 'Panthers'], 'Carolina Panthers')\n"
32 | ]
33 | }
34 | ],
35 | "source": [
36 | "def compare_answers():\n",
37 | " for article in aug_data['data']:\n",
38 | " for para in article['paragraphs']:\n",
39 | " deps = para['deps']\n",
40 | " nodess = []\n",
41 | " for dep in deps:\n",
42 | "                if dep is not None:\n",
43 | "                    nodes, edges = dep\n",
44 | " nodess.append(nodes)\n",
45 | " else:\n",
46 | " nodess.append([])\n",
47 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n",
48 | " for qa in para['qas']:\n",
49 | " for answer in qa['answers']:\n",
50 | " text = answer['text']\n",
51 | " word_start = answer['answer_word_start']\n",
52 | " word_stop = answer['answer_word_stop']\n",
53 | " answer_words = wordss[word_start[0]][word_start[1]:word_stop[1]]\n",
54 | " yield answer_words, text\n",
55 | "\n",
56 | "ca = compare_answers()\n",
57 | "print(next(ca))\n",
58 | "print(next(ca))\n",
59 | "print(next(ca))\n",
60 | "print(next(ca))"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 18,
66 | "metadata": {
67 | "collapsed": false
68 | },
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "8\n"
75 | ]
76 | }
77 | ],
78 | "source": [
79 | "def counter():\n",
80 | " count = 0\n",
81 | " for article in aug_data['data']:\n",
82 | " for para in article['paragraphs']:\n",
83 | " deps = para['deps']\n",
84 | " nodess = []\n",
85 | " for dep in deps:\n",
86 | " if dep is None:\n",
87 | " count += 1\n",
88 | " print(count)\n",
89 | "counter()\n"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 19,
95 | "metadata": {
96 | "collapsed": false
97 | },
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "0\n"
104 | ]
105 | }
106 | ],
107 | "source": [
108 | "def bad_node_counter():\n",
109 | " count = 0\n",
110 | " for article in aug_data['data']:\n",
111 | " for para in article['paragraphs']:\n",
112 | " sents = para['sents']\n",
113 | " deps = para['deps']\n",
114 | " nodess = []\n",
115 | " for dep in deps:\n",
116 | " if dep is not None:\n",
117 | " nodes, edges = dep\n",
118 | " for node in nodes:\n",
119 | " if len(node) != 5:\n",
120 | " count += 1\n",
121 | " print(count)\n",
122 | "bad_node_counter() "
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 20,
128 | "metadata": {
129 | "collapsed": false
130 | },
131 | "outputs": [
132 | {
133 | "name": "stdout",
134 | "output_type": "stream",
135 | "text": [
136 | "7\n"
137 | ]
138 | }
139 | ],
140 | "source": [
141 | "def noanswer_counter():\n",
142 | " count = 0\n",
143 | " for article in aug_data['data']:\n",
144 | " for para in article['paragraphs']:\n",
145 | " deps = para['deps']\n",
146 | " nodess = []\n",
147 | " for dep in deps:\n",
148 | " if dep is not None:\n",
149 | " nodes, edges = dep\n",
150 | " nodess.append(nodes)\n",
151 | " else:\n",
152 | " nodess.append([])\n",
153 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n",
154 | " for qa in para['qas']:\n",
155 | " for answer in qa['answers']:\n",
156 | " text = answer['text']\n",
157 | " word_start = answer['answer_word_start']\n",
158 | " word_stop = answer['answer_word_stop']\n",
159 | " if word_start is None:\n",
160 | " count += 1\n",
161 | " print(count)\n",
162 | "noanswer_counter()"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 22,
168 | "metadata": {
169 | "collapsed": false
170 | },
171 | "outputs": [
172 | {
173 | "name": "stdout",
174 | "output_type": "stream",
175 | "text": [
176 | "10600\n"
177 | ]
178 | }
179 | ],
180 | "source": [
181 | "print(sum(len(para['qas']) for a in aug_data['data'] for para in a['paragraphs']))"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 5,
187 | "metadata": {
188 | "collapsed": false
189 | },
190 | "outputs": [
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "10348\n"
196 | ]
197 | }
198 | ],
199 | "source": [
200 | "import nltk\n",
201 | "\n",
202 | "def _set_span(t, i):\n",
203 | " if isinstance(t[0], str):\n",
204 | " t.span = (i, i+len(t))\n",
205 | " else:\n",
206 | " first = True\n",
207 | " for c in t:\n",
208 | " cur_span = _set_span(c, i)\n",
209 | " i = cur_span[1]\n",
210 | " if first:\n",
211 | " min_ = cur_span[0]\n",
212 | " first = False\n",
213 | " max_ = cur_span[1]\n",
214 | " t.span = (min_, max_)\n",
215 | " return t.span\n",
216 | "\n",
217 | "\n",
218 | "def set_span(t):\n",
219 | " assert isinstance(t, nltk.tree.Tree)\n",
220 | " try:\n",
221 | " return _set_span(t, 0)\n",
222 | " except:\n",
223 | " print(t)\n",
224 | " exit()\n",
225 | "\n",
226 | "def same_span_counter():\n",
227 | " count = 0\n",
228 | " for article in aug_data['data']:\n",
229 | " for para in article['paragraphs']:\n",
230 | " consts = para['consts']\n",
231 | " for const in consts:\n",
232 | " tree = nltk.tree.Tree.fromstring(const)\n",
233 | " set_span(tree)\n",
234 | " if len(list(tree.subtrees())) > len(set(t.span for t in tree.subtrees())):\n",
235 | " count += 1\n",
236 | " print(count)\n",
237 | "same_span_counter()"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {
244 | "collapsed": true
245 | },
246 | "outputs": [],
247 | "source": []
248 | }
249 | ],
250 | "metadata": {
251 | "kernelspec": {
252 | "display_name": "Python 3",
253 | "language": "python",
254 | "name": "python3"
255 | },
256 | "language_info": {
257 | "codemirror_mode": {
258 | "name": "ipython",
259 | "version": 3
260 | },
261 | "file_extension": ".py",
262 | "mimetype": "text/x-python",
263 | "name": "python",
264 | "nbconvert_exporter": "python",
265 | "pygments_lexer": "ipython3",
266 | "version": "3.5.1"
267 | }
268 | },
269 | "nbformat": 4,
270 | "nbformat_minor": 0
271 | }
272 |
--------------------------------------------------------------------------------
/squad/eda_aug_train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import json\n",
12 | "\n",
13 | "aug_data_path = \"/Users/minjoons/data/squad/train-v1.0-aug.json\"\n",
14 | "aug_data = json.load(open(aug_data_path, 'r'))"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {
21 | "collapsed": false
22 | },
23 | "outputs": [
24 | {
25 | "name": "stdout",
26 | "output_type": "stream",
27 | "text": [
28 | "(['Saint', 'Bernadette', 'Soubirous'], 'Saint Bernadette Soubirous')\n",
29 | "(['a', 'copper', 'statue', 'of', 'Christ'], 'a copper statue of Christ')\n",
30 | "(['the', 'Main', 'Building'], 'the Main Building')\n",
31 | "(['a', 'Marian', 'place', 'of', 'prayer', 'and', 'reflection'], 'a Marian place of prayer and reflection')\n"
32 | ]
33 | }
34 | ],
35 | "source": [
36 | "def compare_answers():\n",
37 | " for article in aug_data['data']:\n",
38 | " for para in article['paragraphs']:\n",
39 | " deps = para['deps']\n",
40 | " nodess = []\n",
41 | " for dep in deps:\n",
42 | "                if dep is not None:\n",
43 | "                    nodes, edges = dep\n",
44 | " nodess.append(nodes)\n",
45 | " else:\n",
46 | " nodess.append([])\n",
47 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n",
48 | " for qa in para['qas']:\n",
49 | " for answer in qa['answers']:\n",
50 | " text = answer['text']\n",
51 | " word_start = answer['answer_word_start']\n",
52 | " word_stop = answer['answer_word_stop']\n",
53 | " answer_words = wordss[word_start[0]][word_start[1]:word_stop[1]]\n",
54 | " yield answer_words, text\n",
55 | "\n",
56 | "ca = compare_answers()\n",
57 | "print(next(ca))\n",
58 | "print(next(ca))\n",
59 | "print(next(ca))\n",
60 | "print(next(ca))"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 11,
66 | "metadata": {
67 | "collapsed": false
68 | },
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "x: .\n",
75 | "x: .\n",
76 | "x: .\n",
77 | "x: .\n",
78 | "x: .\n",
79 | "x: .\n",
80 | "x: .\n",
81 | "x: .\n",
82 | "q: k\n",
83 | "q: j\n",
84 | "q: n\n",
85 | "q: b\n",
86 | "q: v\n",
87 | "x: .\n",
88 | "x: :208\n",
89 | "x: .\n",
90 | "x: .\n",
91 | "x: .\n",
92 | "x: .\n",
93 | "x: .\n",
94 | "x: .\n",
95 | "x: .\n",
96 | "x: .\n",
97 | "x: .\n",
98 | "x: .\n",
99 | "x: .\n",
100 | "q: dd\n",
101 | "q: dd\n",
102 | "q: dd\n",
103 | "q: dd\n",
104 | "q: d\n",
105 | "x: .\n",
106 | "x: .\n",
107 | "x: .\n",
108 | "x: .\n",
109 | "x: .\n",
110 | "x: .\n",
111 | "x: .\n",
112 | "x: .\n",
113 | "x: :411\n",
114 | "x: .\n",
115 | "x: .\n",
116 | "x: .\n",
117 | "x: .\n",
118 | "x: .\n",
119 | "x: .\n",
120 | "x: :40\n",
121 | "x: .\n",
122 | "x: *\n",
123 | "x: :14\n",
124 | "x: .\n",
125 | "x: .\n",
126 | "x: .\n",
127 | "x: :131\n",
128 | "x: .\n",
129 | "x: .\n",
130 | "x: .\n",
131 | "x: .\n",
132 | "x: .\n",
133 | "x: .\n",
134 | "x: .\n",
135 | "x: .\n",
136 | "x: .\n",
137 | "53 10\n"
138 | ]
139 | }
140 | ],
141 | "source": [
142 | "def nodep_counter():\n",
143 | " x_count = 0\n",
144 | " q_count = 0\n",
145 | " for article in aug_data['data']:\n",
146 | " for para in article['paragraphs']:\n",
147 | " deps = para['deps']\n",
148 | " nodess = []\n",
149 | " for sent, dep in zip(para['sents'], deps):\n",
150 | " if dep is None:\n",
151 | " print(\"x:\", sent)\n",
152 | " x_count += 1\n",
153 | " for qa in para['qas']:\n",
154 | " if qa['dep'] is None:\n",
155 | " print(\"q:\", qa['question'])\n",
156 | " q_count += 1\n",
157 | " print(x_count, q_count)\n",
158 | "nodep_counter()\n"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 4,
164 | "metadata": {
165 | "collapsed": false
166 | },
167 | "outputs": [
168 | {
169 | "name": "stdout",
170 | "output_type": "stream",
171 | "text": [
172 | "0\n"
173 | ]
174 | }
175 | ],
176 | "source": [
177 | "def bad_node_counter():\n",
178 | " count = 0\n",
179 | " for article in aug_data['data']:\n",
180 | " for para in article['paragraphs']:\n",
181 | " sents = para['sents']\n",
182 | " deps = para['deps']\n",
183 | " nodess = []\n",
184 | " for dep in deps:\n",
185 | " if dep is not None:\n",
186 | " nodes, edges = dep\n",
187 | " for node in nodes:\n",
188 | " if len(node) != 5:\n",
189 | " count += 1\n",
190 | " print(count)\n",
191 | "bad_node_counter() "
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": 5,
197 | "metadata": {
198 | "collapsed": false
199 | },
200 | "outputs": [
201 | {
202 | "name": "stdout",
203 | "output_type": "stream",
204 | "text": [
205 | "36\n"
206 | ]
207 | }
208 | ],
209 | "source": [
210 | "def noanswer_counter():\n",
211 | " count = 0\n",
212 | " for article in aug_data['data']:\n",
213 | " for para in article['paragraphs']:\n",
214 | " deps = para['deps']\n",
215 | " nodess = []\n",
216 | " for dep in deps:\n",
217 | " if dep is not None:\n",
218 | " nodes, edges = dep\n",
219 | " nodess.append(nodes)\n",
220 | " else:\n",
221 | " nodess.append([])\n",
222 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n",
223 | " for qa in para['qas']:\n",
224 | " for answer in qa['answers']:\n",
225 | " text = answer['text']\n",
226 | " word_start = answer['answer_word_start']\n",
227 | " word_stop = answer['answer_word_stop']\n",
228 | " if word_start is None:\n",
229 | " count += 1\n",
230 | " print(count)\n",
231 | "noanswer_counter()"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": 14,
237 | "metadata": {
238 | "collapsed": false
239 | },
240 | "outputs": [
241 | {
242 | "name": "stdout",
243 | "output_type": "stream",
244 | "text": [
245 | "106\n"
246 | ]
247 | }
248 | ],
249 | "source": [
250 | "def mult_sent_answer_counter():\n",
251 | " count = 0\n",
252 | " for article in aug_data['data']:\n",
253 | " for para in article['paragraphs']:\n",
254 | " for qa in para['qas']:\n",
255 | " for answer in qa['answers']:\n",
256 | " text = answer['text']\n",
257 | " word_start = answer['answer_word_start']\n",
258 | " word_stop = answer['answer_word_stop']\n",
259 | " if word_start is not None and word_start[0] != word_stop[0]:\n",
260 | " count += 1\n",
261 | " print(count)\n",
262 | "mult_sent_answer_counter()"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": null,
268 | "metadata": {
269 | "collapsed": true
270 | },
271 | "outputs": [],
272 | "source": []
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": null,
277 | "metadata": {
278 | "collapsed": true
279 | },
280 | "outputs": [],
281 | "source": []
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": null,
286 | "metadata": {
287 | "collapsed": true
288 | },
289 | "outputs": [],
290 | "source": []
291 | }
292 | ],
293 | "metadata": {
294 | "kernelspec": {
295 | "display_name": "Python 3",
296 | "language": "python",
297 | "name": "python3"
298 | },
299 | "language_info": {
300 | "codemirror_mode": {
301 | "name": "ipython",
302 | "version": 3
303 | },
304 | "file_extension": ".py",
305 | "mimetype": "text/x-python",
306 | "name": "python",
307 | "nbconvert_exporter": "python",
308 | "pygments_lexer": "ipython3",
309 | "version": "3.5.1"
310 | }
311 | },
312 | "nbformat": 4,
313 | "nbformat_minor": 0
314 | }
315 |
--------------------------------------------------------------------------------
/squad/evaluate-v1.1.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import argparse
7 | import json
8 | import sys
9 |
10 |
11 | def normalize_answer(s):
12 | """Lower text and remove punctuation, articles and extra whitespace."""
13 | def remove_articles(text):
14 | return re.sub(r'\b(a|an|the)\b', ' ', text)
15 |
16 | def white_space_fix(text):
17 | return ' '.join(text.split())
18 |
19 | def remove_punc(text):
20 | exclude = set(string.punctuation)
21 | return ''.join(ch for ch in text if ch not in exclude)
22 |
23 | def lower(text):
24 | return text.lower()
25 |
26 | return white_space_fix(remove_articles(remove_punc(lower(s))))
27 |
28 |
29 | def f1_score(prediction, ground_truth):
30 | prediction_tokens = normalize_answer(prediction).split()
31 | ground_truth_tokens = normalize_answer(ground_truth).split()
32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
33 | num_same = sum(common.values())
34 | if num_same == 0:
35 | return 0
36 | precision = 1.0 * num_same / len(prediction_tokens)
37 | recall = 1.0 * num_same / len(ground_truth_tokens)
38 | f1 = (2 * precision * recall) / (precision + recall)
39 | return f1
40 |
41 |
42 | def exact_match_score(prediction, ground_truth):
43 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
44 |
45 |
46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
47 | scores_for_ground_truths = []
48 | for ground_truth in ground_truths:
49 | score = metric_fn(prediction, ground_truth)
50 | scores_for_ground_truths.append(score)
51 | return max(scores_for_ground_truths)
52 |
53 |
54 | def evaluate(dataset, predictions):
55 | f1 = exact_match = total = 0
56 | for article in dataset:
57 | for paragraph in article['paragraphs']:
58 | for qa in paragraph['qas']:
59 | total += 1
60 | if qa['id'] not in predictions:
61 | message = 'Unanswered question ' + qa['id'] + \
62 | ' will receive score 0.'
63 | print(message, file=sys.stderr)
64 | continue
65 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
66 | prediction = predictions[qa['id']]
67 | exact_match += metric_max_over_ground_truths(
68 | exact_match_score, prediction, ground_truths)
69 | f1 += metric_max_over_ground_truths(
70 | f1_score, prediction, ground_truths)
71 |
72 | exact_match = 100.0 * exact_match / total
73 | f1 = 100.0 * f1 / total
74 |
75 | return {'exact_match': exact_match, 'f1': f1}
76 |
77 |
78 | if __name__ == '__main__':
79 | expected_version = '1.1'
80 | parser = argparse.ArgumentParser(
81 | description='Evaluation for SQuAD ' + expected_version)
82 | parser.add_argument('dataset_file', help='Dataset file')
83 | parser.add_argument('prediction_file', help='Prediction File')
84 | args = parser.parse_args()
85 | with open(args.dataset_file) as dataset_file:
86 | dataset_json = json.load(dataset_file)
87 | if (dataset_json['version'] != expected_version):
88 | print('Evaluation expects v-' + expected_version +
89 | ', but got dataset with v-' + dataset_json['version'],
90 | file=sys.stderr)
91 | dataset = dataset_json['data']
92 | with open(args.prediction_file) as prediction_file:
93 | predictions = json.load(prediction_file)
94 | print(json.dumps(evaluate(dataset, predictions)))
95 |
--------------------------------------------------------------------------------
/squad/evaluate.py:
--------------------------------------------------------------------------------
1 | """ Official evaluation script for v1.1 of the SQuAD dataset. [Changed name for external importing]"""
2 | from __future__ import print_function
3 | from collections import Counter
4 | import string
5 | import re
6 | import argparse
7 | import json
8 | import sys
9 |
10 |
11 | def normalize_answer(s):
12 | """Lower text and remove punctuation, articles and extra whitespace."""
13 | def remove_articles(text):
14 | return re.sub(r'\b(a|an|the)\b', ' ', text)
15 |
16 | def white_space_fix(text):
17 | return ' '.join(text.split())
18 |
19 | def remove_punc(text):
20 | exclude = set(string.punctuation)
21 | return ''.join(ch for ch in text if ch not in exclude)
22 |
23 | def lower(text):
24 | return text.lower()
25 |
26 | return white_space_fix(remove_articles(remove_punc(lower(s))))
27 |
28 |
29 | def f1_score(prediction, ground_truth):
30 | prediction_tokens = normalize_answer(prediction).split()
31 | ground_truth_tokens = normalize_answer(ground_truth).split()
32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
33 | num_same = sum(common.values())
34 | if num_same == 0:
35 | return 0
36 | precision = 1.0 * num_same / len(prediction_tokens)
37 | recall = 1.0 * num_same / len(ground_truth_tokens)
38 | f1 = (2 * precision * recall) / (precision + recall)
39 | return f1
40 |
41 |
42 | def exact_match_score(prediction, ground_truth):
43 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
44 |
45 |
46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
47 | scores_for_ground_truths = []
48 | for ground_truth in ground_truths:
49 | score = metric_fn(prediction, ground_truth)
50 | scores_for_ground_truths.append(score)
51 | return max(scores_for_ground_truths)
52 |
53 |
54 | def evaluate(dataset, predictions):
55 | f1 = exact_match = total = 0
56 | for article in dataset:
57 | for paragraph in article['paragraphs']:
58 | for qa in paragraph['qas']:
59 | total += 1
60 | if qa['id'] not in predictions:
61 | message = 'Unanswered question ' + qa['id'] + \
62 | ' will receive score 0.'
63 | print(message, file=sys.stderr)
64 | continue
65 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
66 | prediction = predictions[qa['id']]
67 | exact_match += metric_max_over_ground_truths(
68 | exact_match_score, prediction, ground_truths)
69 | f1 += metric_max_over_ground_truths(
70 | f1_score, prediction, ground_truths)
71 |
72 | exact_match = 100.0 * exact_match / total
73 | f1 = 100.0 * f1 / total
74 |
75 | return {'exact_match': exact_match, 'f1': f1}
76 |
77 |
78 | if __name__ == '__main__':
79 | expected_version = '1.1'
80 | parser = argparse.ArgumentParser(
81 | description='Evaluation for SQuAD ' + expected_version)
82 | parser.add_argument('dataset_file', help='Dataset file')
83 | parser.add_argument('prediction_file', help='Prediction File')
84 | args = parser.parse_args()
85 | with open(args.dataset_file) as dataset_file:
86 | dataset_json = json.load(dataset_file)
87 | if (dataset_json['version'] != expected_version):
88 | print('Evaluation expects v-' + expected_version +
89 | ', but got dataset with v-' + dataset_json['version'],
90 | file=sys.stderr)
91 | dataset = dataset_json['data']
92 | with open(args.prediction_file) as prediction_file:
93 | predictions = json.load(prediction_file)
94 | print(json.dumps(evaluate(dataset, predictions)))
95 |
--------------------------------------------------------------------------------
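Since this copy was renamed for importing (per its docstring), `evaluate` can also be called programmatically. A sketch with placeholder paths; the prediction file is assumed to map question ids to answer strings:
```
import json
from squad.evaluate import evaluate

with open('data/squad/dev-v1.1.json') as fh:
    dataset = json.load(fh)['data']
with open('predictions.json') as fh:   # hypothetical {question_id: answer_text} file
    predictions = json.load(fh)
print(evaluate(dataset, predictions))  # {'exact_match': ..., 'f1': ...}
```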
/squad/prepro.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | # data: q, cq, (dq), (pq), y, *x, *cx
5 | # shared: x, cx, (dx), (px), word_counter, char_counter, word2vec
6 | # no metadata
7 | from collections import Counter
8 |
9 | from tqdm import tqdm
10 |
11 | from squad.utils import get_word_span, get_word_idx, process_tokens
12 |
13 |
14 | def main():
15 | args = get_args()
16 | prepro(args)
17 |
18 |
19 | def get_args():
20 | parser = argparse.ArgumentParser()
21 | home = os.path.expanduser("~")
22 | source_dir = os.path.join(home, "data", "squad")
23 | target_dir = "data/squad"
24 | glove_dir = os.path.join(home, "data", "glove")
25 | parser.add_argument('-s', "--source_dir", default=source_dir)
26 | parser.add_argument('-t', "--target_dir", default=target_dir)
27 | parser.add_argument('-d', "--debug", action='store_true')
28 |     parser.add_argument("--train_ratio", default=0.9, type=float)
29 | parser.add_argument("--glove_corpus", default="6B")
30 | parser.add_argument("--glove_dir", default=glove_dir)
31 | parser.add_argument("--glove_vec_size", default=100, type=int)
32 | parser.add_argument("--mode", default="full", type=str)
33 | parser.add_argument("--single_path", default="", type=str)
34 | parser.add_argument("--tokenizer", default="PTB", type=str)
35 | parser.add_argument("--url", default="vision-server2.corp.ai2", type=str)
36 | parser.add_argument("--port", default=8000, type=int)
37 | parser.add_argument("--split", action='store_true')
38 | # TODO : put more args here
39 | return parser.parse_args()
40 |
41 |
42 | def create_all(args):
43 | out_path = os.path.join(args.source_dir, "all-v1.1.json")
44 | if os.path.exists(out_path):
45 | return
46 | train_path = os.path.join(args.source_dir, "train-v1.1.json")
47 | train_data = json.load(open(train_path, 'r'))
48 | dev_path = os.path.join(args.source_dir, "dev-v1.1.json")
49 | dev_data = json.load(open(dev_path, 'r'))
50 | train_data['data'].extend(dev_data['data'])
51 | print("dumping all data ...")
52 | json.dump(train_data, open(out_path, 'w'))
53 |
54 |
55 | def prepro(args):
56 | if not os.path.exists(args.target_dir):
57 | os.makedirs(args.target_dir)
58 |
59 | if args.mode == 'full':
60 | prepro_each(args, 'train', out_name='train')
61 | prepro_each(args, 'dev', out_name='dev')
62 | prepro_each(args, 'dev', out_name='test')
63 | elif args.mode == 'all':
64 | create_all(args)
65 | prepro_each(args, 'dev', 0.0, 0.0, out_name='dev')
66 | prepro_each(args, 'dev', 0.0, 0.0, out_name='test')
67 | prepro_each(args, 'all', out_name='train')
68 | elif args.mode == 'single':
69 | assert len(args.single_path) > 0
70 | prepro_each(args, "NULL", out_name="single", in_path=args.single_path)
71 | else:
72 | prepro_each(args, 'train', 0.0, args.train_ratio, out_name='train')
73 | prepro_each(args, 'train', args.train_ratio, 1.0, out_name='dev')
74 | prepro_each(args, 'dev', out_name='test')
75 |
76 |
77 | def save(args, data, shared, data_type):
78 | data_path = os.path.join(args.target_dir, "data_{}.json".format(data_type))
79 | shared_path = os.path.join(args.target_dir, "shared_{}.json".format(data_type))
80 | json.dump(data, open(data_path, 'w'))
81 | json.dump(shared, open(shared_path, 'w'))
82 |
83 |
84 | def get_word2vec(args, word_counter):
85 | glove_path = os.path.join(args.glove_dir, "glove.{}.{}d.txt".format(args.glove_corpus, args.glove_vec_size))
86 | sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '2B': int(1.2e6)}
87 | total = sizes[args.glove_corpus]
88 | word2vec_dict = {}
89 | with open(glove_path, 'r', encoding='utf-8') as fh:
90 | for line in tqdm(fh, total=total):
91 | array = line.lstrip().rstrip().split(" ")
92 | word = array[0]
93 | vector = list(map(float, array[1:]))
94 | if word in word_counter:
95 | word2vec_dict[word] = vector
96 | elif word.capitalize() in word_counter:
97 | word2vec_dict[word.capitalize()] = vector
98 | elif word.lower() in word_counter:
99 | word2vec_dict[word.lower()] = vector
100 | elif word.upper() in word_counter:
101 | word2vec_dict[word.upper()] = vector
102 |
103 | print("{}/{} of word vocab have corresponding vectors in {}".format(len(word2vec_dict), len(word_counter), glove_path))
104 | return word2vec_dict
105 |
106 |
107 | def prepro_each(args, data_type, start_ratio=0.0, stop_ratio=1.0, out_name="default", in_path=None):
108 | if args.tokenizer == "PTB":
109 | import nltk
110 | sent_tokenize = nltk.sent_tokenize
111 | def word_tokenize(tokens):
112 | return [token.replace("''", '"').replace("``", '"') for token in nltk.word_tokenize(tokens)]
113 | elif args.tokenizer == 'Stanford':
114 | from my.corenlp_interface import CoreNLPInterface
115 | interface = CoreNLPInterface(args.url, args.port)
116 | sent_tokenize = interface.split_doc
117 | word_tokenize = interface.split_sent
118 | else:
119 | raise Exception()
120 |
121 | if not args.split:
122 | sent_tokenize = lambda para: [para]
123 |
124 | source_path = in_path or os.path.join(args.source_dir, "{}-v1.1.json".format(data_type))
125 | source_data = json.load(open(source_path, 'r'))
126 |
127 | q, cq, y, rx, rcx, ids, idxs = [], [], [], [], [], [], []
128 | cy = []
129 | x, cx = [], []
130 | answerss = []
131 | p = []
132 | word_counter, char_counter, lower_word_counter = Counter(), Counter(), Counter()
133 | start_ai = int(round(len(source_data['data']) * start_ratio))
134 | stop_ai = int(round(len(source_data['data']) * stop_ratio))
135 | for ai, article in enumerate(tqdm(source_data['data'][start_ai:stop_ai])):
136 | xp, cxp = [], []
137 | pp = []
138 | x.append(xp)
139 | cx.append(cxp)
140 | p.append(pp)
141 | for pi, para in enumerate(article['paragraphs']):
142 | # wordss
143 | context = para['context']
144 | context = context.replace("''", '" ')
145 | context = context.replace("``", '" ')
146 | xi = list(map(word_tokenize, sent_tokenize(context)))
147 | xi = [process_tokens(tokens) for tokens in xi] # process tokens
148 | # given xi, add chars
149 | cxi = [[list(xijk) for xijk in xij] for xij in xi]
150 | xp.append(xi)
151 | cxp.append(cxi)
152 | pp.append(context)
153 |
154 | for xij in xi:
155 | for xijk in xij:
156 | word_counter[xijk] += len(para['qas'])
157 | lower_word_counter[xijk.lower()] += len(para['qas'])
158 | for xijkl in xijk:
159 | char_counter[xijkl] += len(para['qas'])
160 |
161 | rxi = [ai, pi]
162 | assert len(x) - 1 == ai
163 | assert len(x[ai]) - 1 == pi
164 | for qa in para['qas']:
165 | # get words
166 | qi = word_tokenize(qa['question'])
167 | cqi = [list(qij) for qij in qi]
168 | yi = []
169 | cyi = []
170 | answers = []
171 | for answer in qa['answers']:
172 | answer_text = answer['text']
173 | answers.append(answer_text)
174 | answer_start = answer['answer_start']
175 | answer_stop = answer_start + len(answer_text)
176 | # TODO : put some function that gives word_start, word_stop here
177 | yi0, yi1 = get_word_span(context, xi, answer_start, answer_stop)
178 | # yi0 = answer['answer_word_start'] or [0, 0]
179 | # yi1 = answer['answer_word_stop'] or [0, 1]
180 | assert len(xi[yi0[0]]) > yi0[1]
181 | assert len(xi[yi1[0]]) >= yi1[1]
182 | w0 = xi[yi0[0]][yi0[1]]
183 | w1 = xi[yi1[0]][yi1[1]-1]
184 | i0 = get_word_idx(context, xi, yi0)
185 | i1 = get_word_idx(context, xi, (yi1[0], yi1[1]-1))
186 | cyi0 = answer_start - i0
187 | cyi1 = answer_stop - i1 - 1
188 | # print(answer_text, w0[cyi0:], w1[:cyi1+1])
189 | assert answer_text[0] == w0[cyi0], (answer_text, w0, cyi0)
190 | assert answer_text[-1] == w1[cyi1]
191 | assert cyi0 < 32, (answer_text, w0)
192 | assert cyi1 < 32, (answer_text, w1)
193 |
194 | yi.append([yi0, yi1])
195 | cyi.append([cyi0, cyi1])
196 |
197 | for qij in qi:
198 | word_counter[qij] += 1
199 | lower_word_counter[qij.lower()] += 1
200 | for qijk in qij:
201 | char_counter[qijk] += 1
202 |
203 | q.append(qi)
204 | cq.append(cqi)
205 | y.append(yi)
206 | cy.append(cyi)
207 | rx.append(rxi)
208 | rcx.append(rxi)
209 | ids.append(qa['id'])
210 | idxs.append(len(idxs))
211 | answerss.append(answers)
212 |
213 | if args.debug:
214 | break
215 |
216 | word2vec_dict = get_word2vec(args, word_counter)
217 | lower_word2vec_dict = get_word2vec(args, lower_word_counter)
218 |
219 | # add context here
220 | data = {'q': q, 'cq': cq, 'y': y, '*x': rx, '*cx': rcx, 'cy': cy,
221 | 'idxs': idxs, 'ids': ids, 'answerss': answerss, '*p': rx}
222 | shared = {'x': x, 'cx': cx, 'p': p,
223 | 'word_counter': word_counter, 'char_counter': char_counter, 'lower_word_counter': lower_word_counter,
224 | 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict}
225 |
226 | print("saving ...")
227 | save(args, data, shared, out_name)
228 |
229 |
230 |
231 | if __name__ == "__main__":
232 | main()
--------------------------------------------------------------------------------
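`save` writes each split as a `data_*.json` / `shared_*.json` pair under the target directory; the keys below follow the dicts built at the end of `prepro_each`. A quick inspection sketch:
```
import json

data = json.load(open('data/squad/data_train.json'))
shared = json.load(open('data/squad/shared_train.json'))
print(sorted(data.keys()))
# ['*cx', '*p', '*x', 'answerss', 'cq', 'cy', 'ids', 'idxs', 'q', 'y']
ai, pi = data['*x'][0]              # each question references its paragraph by (article, paragraph) index
print(shared['x'][ai][pi][0][:10])  # first ten tokens of that paragraph's first sentence
```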
/squad/prepro_aug.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | # data: q, cq, (dq), (pq), y, *x, *cx
5 | # shared: x, cx, (dx), (px), word_counter, char_counter, word2vec
6 | # no metadata
7 | from collections import Counter
8 |
9 | import nltk
10 | from tqdm import tqdm
11 |
12 | from my.nltk_utils import load_compressed_tree
13 |
14 |
15 | def bool_(arg):
16 | if arg == 'True':
17 | return True
18 | elif arg == 'False':
19 | return False
20 | raise Exception()
21 |
22 |
23 | def main():
24 | args = get_args()
25 | prepro(args)
26 |
27 |
28 | def get_args():
29 | parser = argparse.ArgumentParser()
30 | home = os.path.expanduser("~")
31 | source_dir = os.path.join(home, "data", "squad")
32 | target_dir = "data/squad"
33 | glove_dir = os.path.join(home, "data", "glove")
34 | parser.add_argument("--source_dir", default=source_dir)
35 | parser.add_argument("--target_dir", default=target_dir)
36 | parser.add_argument("--debug", default=False, type=bool_)
37 |     parser.add_argument("--train_ratio", default=0.9, type=float)
38 | parser.add_argument("--glove_corpus", default="6B")
39 | parser.add_argument("--glove_dir", default=glove_dir)
40 | parser.add_argument("--glove_vec_size", default=100, type=int)
41 | parser.add_argument("--full_train", default=False, type=bool_)
42 | # TODO : put more args here
43 | return parser.parse_args()
44 |
45 |
46 | def prepro(args):
47 | if not os.path.exists(args.target_dir):
48 | os.makedirs(args.target_dir)
49 |
50 | if args.full_train:
51 | data_train, shared_train = prepro_each(args, 'train')
52 | data_dev, shared_dev = prepro_each(args, 'dev')
53 | else:
54 | data_train, shared_train = prepro_each(args, 'train', 0.0, args.train_ratio)
55 | data_dev, shared_dev = prepro_each(args, 'train', args.train_ratio, 1.0)
56 | data_test, shared_test = prepro_each(args, 'dev')
57 |
58 | print("saving ...")
59 | save(args, data_train, shared_train, 'train')
60 | save(args, data_dev, shared_dev, 'dev')
61 | save(args, data_test, shared_test, 'test')
62 |
63 |
64 | def save(args, data, shared, data_type):
65 | data_path = os.path.join(args.target_dir, "data_{}.json".format(data_type))
66 | shared_path = os.path.join(args.target_dir, "shared_{}.json".format(data_type))
67 | json.dump(data, open(data_path, 'w'))
68 | json.dump(shared, open(shared_path, 'w'))
69 |
70 |
71 | def get_word2vec(args, word_counter):
72 | glove_path = os.path.join(args.glove_dir, "glove.{}.{}d.txt".format(args.glove_corpus, args.glove_vec_size))
73 | sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '2B': int(1.2e6)}
74 | total = sizes[args.glove_corpus]
75 | word2vec_dict = {}
76 | with open(glove_path, 'r') as fh:
77 | for line in tqdm(fh, total=total):
78 | array = line.lstrip().rstrip().split(" ")
79 | word = array[0]
80 | vector = list(map(float, array[1:]))
81 | if word in word_counter:
82 | word2vec_dict[word] = vector
83 | elif word.capitalize() in word_counter:
84 | word2vec_dict[word.capitalize()] = vector
85 | elif word.lower() in word_counter:
86 | word2vec_dict[word.lower()] = vector
87 | elif word.upper() in word_counter:
88 | word2vec_dict[word.upper()] = vector
89 |
90 | print("{}/{} of word vocab have corresponding vectors in {}".format(len(word2vec_dict), len(word_counter), glove_path))
91 | return word2vec_dict
92 |
93 |
94 | def prepro_each(args, data_type, start_ratio=0.0, stop_ratio=1.0):
95 | source_path = os.path.join(args.source_dir, "{}-v1.0-aug.json".format(data_type))
96 | source_data = json.load(open(source_path, 'r'))
97 |
98 | q, cq, y, rx, rcx, ids, idxs = [], [], [], [], [], [], []
99 | x, cx, tx, stx = [], [], [], []
100 | answerss = []
101 | word_counter, char_counter, lower_word_counter = Counter(), Counter(), Counter()
102 | pos_counter = Counter()
103 | start_ai = int(round(len(source_data['data']) * start_ratio))
104 | stop_ai = int(round(len(source_data['data']) * stop_ratio))
105 | for ai, article in enumerate(tqdm(source_data['data'][start_ai:stop_ai])):
106 | xp, cxp, txp, stxp = [], [], [], []
107 | x.append(xp)
108 | cx.append(cxp)
109 | tx.append(txp)
110 | stx.append(stxp)
111 | for pi, para in enumerate(article['paragraphs']):
112 | xi = []
113 | for dep in para['deps']:
114 | if dep is None:
115 | xi.append([])
116 | else:
117 | xi.append([node[0] for node in dep[0]])
118 | cxi = [[list(xijk) for xijk in xij] for xij in xi]
119 | xp.append(xi)
120 | cxp.append(cxi)
121 | txp.append(para['consts'])
122 | stxp.append([str(load_compressed_tree(s)) for s in para['consts']])
123 | trees = map(nltk.tree.Tree.fromstring, para['consts'])
124 | for tree in trees:
125 | for subtree in tree.subtrees():
126 | pos_counter[subtree.label()] += 1
127 |
128 | for xij in xi:
129 | for xijk in xij:
130 | word_counter[xijk] += len(para['qas'])
131 | lower_word_counter[xijk.lower()] += len(para['qas'])
132 | for xijkl in xijk:
133 | char_counter[xijkl] += len(para['qas'])
134 |
135 | rxi = [ai, pi]
136 | assert len(x) - 1 == ai
137 | assert len(x[ai]) - 1 == pi
138 | for qa in para['qas']:
139 | dep = qa['dep']
140 | qi = [] if dep is None else [node[0] for node in dep[0]]
141 | cqi = [list(qij) for qij in qi]
142 | yi = []
143 | answers = []
144 | for answer in qa['answers']:
145 | answers.append(answer['text'])
146 | yi0 = answer['answer_word_start'] or [0, 0]
147 | yi1 = answer['answer_word_stop'] or [0, 1]
148 | assert len(xi[yi0[0]]) > yi0[1]
149 | assert len(xi[yi1[0]]) >= yi1[1]
150 | yi.append([yi0, yi1])
151 |
152 | for qij in qi:
153 | word_counter[qij] += 1
154 | lower_word_counter[qij.lower()] += 1
155 | for qijk in qij:
156 | char_counter[qijk] += 1
157 |
158 | q.append(qi)
159 | cq.append(cqi)
160 | y.append(yi)
161 | rx.append(rxi)
162 | rcx.append(rxi)
163 | ids.append(qa['id'])
164 | idxs.append(len(idxs))
165 | answerss.append(answers)
166 |
167 | if args.debug:
168 | break
169 |
170 | word2vec_dict = get_word2vec(args, word_counter)
171 | lower_word2vec_dict = get_word2vec(args, lower_word_counter)
172 |
173 | data = {'q': q, 'cq': cq, 'y': y, '*x': rx, '*cx': rcx, '*tx': rx, '*stx': rx,
174 | 'idxs': idxs, 'ids': ids, 'answerss': answerss}
175 | shared = {'x': x, 'cx': cx, 'tx': tx, 'stx': stx,
176 | 'word_counter': word_counter, 'char_counter': char_counter, 'lower_word_counter': lower_word_counter,
177 | 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict, 'pos_counter': pos_counter}
178 |
179 | return data, shared
180 |
181 |
182 | if __name__ == "__main__":
183 | main()
--------------------------------------------------------------------------------
/squad/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | def get_2d_spans(text, tokenss):
5 | spanss = []
6 | cur_idx = 0
7 | for tokens in tokenss:
8 | spans = []
9 | for token in tokens:
10 | if text.find(token, cur_idx) < 0:
11 | print(tokens)
12 | print("{} {} {}".format(token, cur_idx, text))
13 | raise Exception()
14 | cur_idx = text.find(token, cur_idx)
15 | spans.append((cur_idx, cur_idx + len(token)))
16 | cur_idx += len(token)
17 | spanss.append(spans)
18 | return spanss
19 |
20 |
21 | def get_word_span(context, wordss, start, stop):
22 | spanss = get_2d_spans(context, wordss)
23 | idxs = []
24 | for sent_idx, spans in enumerate(spanss):
25 | for word_idx, span in enumerate(spans):
26 | if not (stop <= span[0] or start >= span[1]):
27 | idxs.append((sent_idx, word_idx))
28 |
29 | assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop)
30 | return idxs[0], (idxs[-1][0], idxs[-1][1] + 1)
31 |
32 |
33 | def get_phrase(context, wordss, span):
34 | """
35 |     Obtain the phrase as a substring of the context, given word-level start and stop indices.
36 |     :param context: raw context string
37 |     :param wordss: list of token lists, one per sentence
38 |     :param span: (start, stop) pair, where each element
39 |         is a (sent_idx, word_idx) index into wordss
40 | :return:
41 | """
42 | start, stop = span
43 | flat_start = get_flat_idx(wordss, start)
44 | flat_stop = get_flat_idx(wordss, stop)
45 | words = sum(wordss, [])
46 | char_idx = 0
47 | char_start, char_stop = None, None
48 | for word_idx, word in enumerate(words):
49 | char_idx = context.find(word, char_idx)
50 | assert char_idx >= 0
51 | if word_idx == flat_start:
52 | char_start = char_idx
53 | char_idx += len(word)
54 | if word_idx == flat_stop - 1:
55 | char_stop = char_idx
56 | assert char_start is not None
57 | assert char_stop is not None
58 | return context[char_start:char_stop]
59 |
60 |
61 | def get_flat_idx(wordss, idx):
62 | return sum(len(words) for words in wordss[:idx[0]]) + idx[1]
63 |
64 |
65 | def get_word_idx(context, wordss, idx):
66 | spanss = get_2d_spans(context, wordss)
67 | return spanss[idx[0]][idx[1]][0]
68 |
69 |
70 | def process_tokens(temp_tokens):
71 | tokens = []
72 | for token in temp_tokens:
73 | flag = False
74 | l = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'", "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0")
75 |         # \u2013 is an en-dash, used in number-to-number ranges
76 | # l = ("-", "\u2212", "\u2014", "\u2013")
77 | # l = ("\u2013",)
78 | tokens.extend(re.split("([{}])".format("".join(l)), token))
79 | return tokens
80 |
81 |
82 | def get_best_span(ypi, yp2i):
83 | max_val = 0
84 | best_word_span = (0, 1)
85 | best_sent_idx = 0
86 | for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
87 | argmax_j1 = 0
88 | for j in range(len(ypif)):
89 | val1 = ypif[argmax_j1]
90 | if val1 < ypif[j]:
91 | val1 = ypif[j]
92 | argmax_j1 = j
93 |
94 | val2 = yp2if[j]
95 | if val1 * val2 > max_val:
96 | best_word_span = (argmax_j1, j)
97 | best_sent_idx = f
98 | max_val = val1 * val2
99 | return ((best_sent_idx, best_word_span[0]), (best_sent_idx, best_word_span[1] + 1)), float(max_val)
100 |
101 |
102 | def get_span_score_pairs(ypi, yp2i):
103 | span_score_pairs = []
104 | for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
105 | for j in range(len(ypif)):
106 | for k in range(j, len(yp2if)):
107 | span = ((f, j), (f, k+1))
108 | score = ypif[j] * yp2if[k]
109 | span_score_pairs.append((span, score))
110 | return span_score_pairs
111 |
112 |
113 |
--------------------------------------------------------------------------------
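A small worked example of the span helpers above (a sketch): `get_word_span` maps a character span to (sent_idx, word_idx) coordinates, and `get_phrase` maps such a span back to the context substring:
```
from squad.utils import get_word_span, get_phrase

context = "Super Bowl 50 was an American football game."
wordss = [["Super", "Bowl", "50", "was", "an", "American", "football", "game", "."]]
start = context.find("American")
stop = start + len("American football")
span = get_word_span(context, wordss, start, stop)
print(span)                               # ((0, 5), (0, 7))
print(get_phrase(context, wordss, span))  # American football
```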
/tree/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/tree/__init__.py
--------------------------------------------------------------------------------
/tree/cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pprint import pprint
3 |
4 | import tensorflow as tf
5 |
6 | from tree.main import main as m
7 |
8 | flags = tf.app.flags
9 |
10 | flags.DEFINE_string("model_name", "tree", "Model name [tree]")
11 | flags.DEFINE_string("data_dir", "data/squad", "Data dir [data/squad]")
12 | flags.DEFINE_integer("run_id", 0, "Run ID [0]")
13 |
14 | flags.DEFINE_integer("batch_size", 128, "Batch size [128]")
15 | flags.DEFINE_float("init_lr", 0.5, "Initial learning rate [0.5]")
16 | flags.DEFINE_integer("num_epochs", 50, "Total number of epochs for training [50]")
17 | flags.DEFINE_integer("num_steps", 0, "Number of steps [0]")
18 | flags.DEFINE_integer("eval_num_batches", 100, "eval num batches [100]")
19 | flags.DEFINE_integer("load_step", 0, "load step [0]")
20 | flags.DEFINE_integer("early_stop", 4, "early stop [4]")
21 |
22 | flags.DEFINE_string("mode", "test", "train | test | forward [test]")
23 | flags.DEFINE_boolean("load", True, "load saved data? [True]")
24 | flags.DEFINE_boolean("progress", True, "Show progress? [True]")
25 | flags.DEFINE_integer("log_period", 100, "Log period [100]")
26 | flags.DEFINE_integer("eval_period", 1000, "Eval period [1000]")
27 | flags.DEFINE_integer("save_period", 1000, "Save Period [1000]")
28 | flags.DEFINE_float("decay", 0.9, "Exponential moving average decay [0.9]")
29 |
30 | flags.DEFINE_boolean("draft", False, "Draft for quick testing? [False]")
31 |
32 | flags.DEFINE_integer("hidden_size", 32, "Hidden size [32]")
33 | flags.DEFINE_float("input_keep_prob", 0.5, "Input keep prob [0.5]")
34 | flags.DEFINE_integer("char_emb_size", 8, "Char emb size [8]")
35 | flags.DEFINE_integer("char_filter_height", 5, "Char filter height [5]")
36 | flags.DEFINE_float("wd", 0.0001, "Weight decay [0.0001]")
37 | flags.DEFINE_bool("lower_word", True, "lower word [True]")
38 | flags.DEFINE_bool("dump_eval", True, "dump eval? [True]")
39 |
40 | flags.DEFINE_integer("word_count_th", 100, "word count th [100]")
41 | flags.DEFINE_integer("char_count_th", 500, "char count th [500]")
42 | flags.DEFINE_integer("sent_size_th", 64, "sent size th [64]")
43 | flags.DEFINE_integer("num_sents_th", 8, "num sents th [8]")
44 | flags.DEFINE_integer("ques_size_th", 64, "ques size th [64]")
45 | flags.DEFINE_integer("word_size_th", 16, "word size th [16]")
46 | flags.DEFINE_integer("tree_height_th", 16, "tree height th [16]")
47 |
48 |
49 | def main(_):
50 | config = flags.FLAGS
51 |
52 | config.out_dir = os.path.join("out", config.model_name, str(config.run_id).zfill(2))
53 |
54 | m(config)
55 |
56 | if __name__ == "__main__":
57 | tf.app.run()
58 |
--------------------------------------------------------------------------------
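By analogy with the `basic.cli` commands in the README, the tree model is driven through this CLI; for example, a fresh training run could be started as follows (`--noload` is the standard negation of the `load` boolean flag):
```
python -m tree.cli --mode train --noload
```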
/tree/evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from tree.read_data import DataSet
5 | from my.nltk_utils import span_f1
6 |
7 |
8 | class Evaluation(object):
9 | def __init__(self, data_type, global_step, idxs, yp):
10 | self.data_type = data_type
11 | self.global_step = global_step
12 | self.idxs = idxs
13 | self.yp = yp
14 | self.num_examples = len(yp)
15 | self.dict = {'data_type': data_type,
16 | 'global_step': global_step,
17 | 'yp': yp,
18 | 'idxs': idxs,
19 | 'num_examples': self.num_examples}
20 | self.summaries = None
21 |
22 | def __repr__(self):
23 | return "{} step {}".format(self.data_type, self.global_step)
24 |
25 | def __add__(self, other):
26 | if other == 0:
27 | return self
28 | assert self.data_type == other.data_type
29 | assert self.global_step == other.global_step
30 | new_yp = self.yp + other.yp
31 | new_idxs = self.idxs + other.idxs
32 | return Evaluation(self.data_type, self.global_step, new_idxs, new_yp)
33 |
34 | def __radd__(self, other):
35 | return self.__add__(other)
36 |
37 |
38 | class LabeledEvaluation(Evaluation):
39 | def __init__(self, data_type, global_step, idxs, yp, y):
40 | super(LabeledEvaluation, self).__init__(data_type, global_step, idxs, yp)
41 | self.y = y
42 | self.dict['y'] = y
43 |
44 | def __add__(self, other):
45 | if other == 0:
46 | return self
47 | assert self.data_type == other.data_type
48 | assert self.global_step == other.global_step
49 | new_yp = self.yp + other.yp
50 | new_y = self.y + other.y
51 | new_idxs = self.idxs + other.idxs
52 | return LabeledEvaluation(self.data_type, self.global_step, new_idxs, new_yp, new_y)
53 |
54 |
55 | class AccuracyEvaluation(LabeledEvaluation):
56 | def __init__(self, data_type, global_step, idxs, yp, y, correct, loss):
57 | super(AccuracyEvaluation, self).__init__(data_type, global_step, idxs, yp, y)
58 | self.loss = loss
59 | self.correct = correct
60 | self.acc = sum(correct) / len(correct)
61 | self.dict['loss'] = loss
62 | self.dict['correct'] = correct
63 | self.dict['acc'] = self.acc
64 | loss_summary = tf.Summary(value=[tf.Summary.Value(tag='dev/loss', simple_value=self.loss)])
65 | acc_summary = tf.Summary(value=[tf.Summary.Value(tag='dev/acc', simple_value=self.acc)])
66 | self.summaries = [loss_summary, acc_summary]
67 |
68 | def __repr__(self):
69 | return "{} step {}: accuracy={}, loss={}".format(self.data_type, self.global_step, self.acc, self.loss)
70 |
71 | def __add__(self, other):
72 | if other == 0:
73 | return self
74 | assert self.data_type == other.data_type
75 | assert self.global_step == other.global_step
76 | new_idxs = self.idxs + other.idxs
77 | new_yp = self.yp + other.yp
78 | new_y = self.y + other.y
79 | new_correct = self.correct + other.correct
80 | new_loss = (self.loss * self.num_examples + other.loss * other.num_examples) / len(new_correct)
81 | return AccuracyEvaluation(self.data_type, self.global_step, new_idxs, new_yp, new_y, new_correct, new_loss)
82 |
83 |
84 | class Evaluator(object):
85 | def __init__(self, config, model):
86 | self.config = config
87 | self.model = model
88 |
89 | def get_evaluation(self, sess, batch):
90 | idxs, data_set = batch
91 | feed_dict = self.model.get_feed_dict(data_set, False, supervised=False)
92 | global_step, yp = sess.run([self.model.global_step, self.model.yp], feed_dict=feed_dict)
93 | yp = yp[:data_set.num_examples]
94 | e = Evaluation(data_set.data_type, int(global_step), idxs, yp.tolist())
95 | return e
96 |
97 | def get_evaluation_from_batches(self, sess, batches):
98 | e = sum(self.get_evaluation(sess, batch) for batch in batches)
99 | return e
100 |
101 |
102 | class LabeledEvaluator(Evaluator):
103 | def get_evaluation(self, sess, batch):
104 | idxs, data_set = batch
105 | feed_dict = self.model.get_feed_dict(data_set, False, supervised=False)
106 | global_step, yp = sess.run([self.model.global_step, self.model.yp], feed_dict=feed_dict)
107 | yp = yp[:data_set.num_examples]
108 | y = feed_dict[self.model.y]
109 | e = LabeledEvaluation(data_set.data_type, int(global_step), idxs, yp.tolist(), y.tolist())
110 | return e
111 |
112 |
113 | class AccuracyEvaluator(LabeledEvaluator):
114 | def get_evaluation(self, sess, batch):
115 | idxs, data_set = batch
116 | assert isinstance(data_set, DataSet)
117 | feed_dict = self.model.get_feed_dict(data_set, False)
118 | global_step, yp, loss = sess.run([self.model.global_step, self.model.yp, self.model.loss], feed_dict=feed_dict)
119 | y = feed_dict[self.model.y]
120 | yp = yp[:data_set.num_examples]
121 | correct = [self.__class__.compare(yi, ypi) for yi, ypi in zip(y, yp)]
122 | e = AccuracyEvaluation(data_set.data_type, int(global_step), idxs, yp.tolist(), y.tolist(), correct, float(loss))
123 | return e
124 |
125 | @staticmethod
126 | def compare(yi, ypi):
127 | return int(np.argmax(yi)) == int(np.argmax(ypi))
128 |
129 |
130 | class AccuracyEvaluator2(AccuracyEvaluator):
131 | @staticmethod
132 | def compare(yi, ypi):
133 | i = int(np.argmax(yi.flatten()))
134 | j = int(np.argmax(ypi.flatten()))
135 | # print(i, j, i == j)
136 | return i == j
137 |
138 |
139 | class TempEvaluation(AccuracyEvaluation):
140 | def __init__(self, data_type, global_step, idxs, yp, yp2, y, y2, correct, loss, f1s):
141 | super(TempEvaluation, self).__init__(data_type, global_step, idxs, yp, y, correct, loss)
142 | self.y2 = y2
143 | self.yp2 = yp2
144 | self.f1s = f1s
145 | self.f1 = float(np.mean(f1s))
146 | self.dict['y2'] = y2
147 | self.dict['yp2'] = yp2
148 | self.dict['f1s'] = f1s
149 | self.dict['f1'] = self.f1
150 | f1_summary = tf.Summary(value=[tf.Summary.Value(tag='dev/f1', simple_value=self.f1)])
151 | self.summaries.append(f1_summary)
152 |
153 | def __add__(self, other):
154 | if other == 0:
155 | return self
156 | assert self.data_type == other.data_type
157 | assert self.global_step == other.global_step
158 | new_idxs = self.idxs + other.idxs
159 | new_yp = self.yp + other.yp
160 | new_yp2 = self.yp2 + other.yp2
161 | new_y = self.y + other.y
162 | new_y2 = self.y2 + other.y2
163 | new_correct = self.correct + other.correct
164 | new_f1s = self.f1s + other.f1s
165 | new_loss = (self.loss * self.num_examples + other.loss * other.num_examples) / len(new_correct)
166 | return TempEvaluation(self.data_type, self.global_step, new_idxs, new_yp, new_yp2, new_y, new_y2, new_correct, new_loss, new_f1s)
167 |
168 |
169 | class TempEvaluator(LabeledEvaluator):
170 | def get_evaluation(self, sess, batch):
171 | idxs, data_set = batch
172 | assert isinstance(data_set, DataSet)
173 | feed_dict = self.model.get_feed_dict(data_set, False)
174 | global_step, yp, yp2, loss = sess.run([self.model.global_step, self.model.yp, self.model.yp2, self.model.loss], feed_dict=feed_dict)
175 | y, y2 = feed_dict[self.model.y], feed_dict[self.model.y2]
176 | yp, yp2 = yp[:data_set.num_examples], yp2[:data_set.num_examples]
177 | correct = [self.__class__.compare(yi, y2i, ypi, yp2i) for yi, y2i, ypi, yp2i in zip(y, y2, yp, yp2)]
178 | f1s = [self.__class__.span_f1(yi, y2i, ypi, yp2i) for yi, y2i, ypi, yp2i in zip(y, y2, yp, yp2)]
179 | e = TempEvaluation(data_set.data_type, int(global_step), idxs, yp.tolist(), yp2.tolist(), y.tolist(), y2.tolist(), correct, float(loss), f1s)
180 | return e
181 |
182 | @staticmethod
183 | def compare(yi, y2i, ypi, yp2i):
184 | i = int(np.argmax(yi.flatten()))
185 | j = int(np.argmax(ypi.flatten()))
186 | k = int(np.argmax(y2i.flatten()))
187 | l = int(np.argmax(yp2i.flatten()))
188 | # print(i, j, i == j)
189 | return i == j and k == l
190 |
191 | @staticmethod
192 | def span_f1(yi, y2i, ypi, yp2i):
193 | true_span = (np.argmax(yi.flatten()), np.argmax(y2i.flatten())+1)
194 | pred_span = (np.argmax(ypi.flatten()), np.argmax(yp2i.flatten())+1)
195 | f1 = span_f1(true_span, pred_span)
196 | return f1
197 |
198 |
--------------------------------------------------------------------------------
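
The `Evaluation` classes above are built to work with Python's built-in `sum()`: `__radd__` absorbs the initial `0`, and `__add__` concatenates the per-example fields while re-weighting the batch-mean losses by batch size. A minimal self-contained sketch of that pattern (toy `Eval` class and made-up numbers, not the repo's actual classes):

```
class Eval:
    """Toy stand-in for Evaluation/AccuracyEvaluation to show the sum() pattern."""
    def __init__(self, idxs, correct, loss):
        self.idxs = idxs        # example indices covered by this batch
        self.correct = correct  # per-example 0/1 correctness
        self.loss = loss        # mean loss over this batch

    @property
    def num_examples(self):
        return len(self.idxs)

    def __add__(self, other):
        if other == 0:  # identity, so sum() can start from 0
            return self
        new_correct = self.correct + other.correct
        # batch-mean losses are re-weighted by batch size before averaging
        new_loss = (self.loss * self.num_examples
                    + other.loss * other.num_examples) / len(new_correct)
        return Eval(self.idxs + other.idxs, new_correct, new_loss)

    __radd__ = __add__


batches = [Eval([0, 1], [1, 0], 0.6), Eval([2], [1], 0.3)]
total = sum(batches)  # works because __radd__ handles the leading 0
print(total.num_examples, total.loss)  # 3 0.5
```
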
/tree/graph_handler.py:
--------------------------------------------------------------------------------
1 | import json
2 | from json import encoder
3 | import os
4 |
5 | import tensorflow as tf
6 |
7 | from tree.evaluator import Evaluation
8 | from my.utils import short_floats
9 |
10 |
11 | class GraphHandler(object):
12 | def __init__(self, config):
13 | self.config = config
14 | self.saver = tf.train.Saver()
15 | self.writer = None
16 | self.save_path = os.path.join(config.save_dir, config.model_name)
17 |
18 | def initialize(self, sess):
19 | if self.config.load:
20 | self._load(sess)
21 | else:
22 | sess.run(tf.initialize_all_variables())
23 |
24 | if self.config.mode == 'train':
25 | self.writer = tf.train.SummaryWriter(self.config.log_dir, graph=tf.get_default_graph())
26 |
27 | def save(self, sess, global_step=None):
28 | self.saver.save(sess, self.save_path, global_step=global_step)
29 |
30 | def _load(self, sess):
31 | config = self.config
32 | if config.load_step > 0:
33 | save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
34 | else:
35 | save_dir = config.save_dir
36 | checkpoint = tf.train.get_checkpoint_state(save_dir)
37 | assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
38 | save_path = checkpoint.model_checkpoint_path
39 | print("Loading saved model from {}".format(save_path))
40 | self.saver.restore(sess, save_path)
41 |
42 | def add_summary(self, summary, global_step):
43 | self.writer.add_summary(summary, global_step)
44 |
45 | def add_summaries(self, summaries, global_step):
46 | for summary in summaries:
47 | self.add_summary(summary, global_step)
48 |
49 | def dump_eval(self, e, precision=2):
50 | assert isinstance(e, Evaluation)
51 | path = os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6)))
52 | with open(path, 'w') as fh:
53 | json.dump(short_floats(e.dict, precision), fh)
54 |
55 |
--------------------------------------------------------------------------------
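
`GraphHandler` relies on two filename conventions: `tf.train.Saver` writes checkpoints named `<model_name>-<global_step>` under `save_dir`, and `dump_eval` writes `<data_type>-<zero-padded step>.json` under `eval_dir`. A quick sketch of both (the paths and step values here are made up for illustration):

```
import os

save_dir, model_name, load_step = "out/save", "tree", 18000  # hypothetical values
print(os.path.join(save_dir, "{}-{}".format(model_name, load_step)))
# out/save/tree-18000

data_type, global_step = "dev", 18000
print("{}-{}.json".format(data_type, str(global_step).zfill(6)))
# dev-018000.json
```
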
/tree/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 | import os
5 | import shutil
6 | from pprint import pprint
7 |
8 | import tensorflow as tf
9 | from tqdm import tqdm
10 | import numpy as np
11 |
12 | from tree.evaluator import AccuracyEvaluator2, Evaluator
13 | from tree.graph_handler import GraphHandler
14 | from tree.model import Model
15 | from tree.trainer import Trainer
16 |
17 | from tree.read_data import load_metadata, read_data, get_squad_data_filter, update_config
18 |
19 |
20 | def main(config):
21 | set_dirs(config)
22 | if config.mode == 'train':
23 | _train(config)
24 | elif config.mode == 'test':
25 | _test(config)
26 | elif config.mode == 'forward':
27 | _forward(config)
28 | else:
29 | raise ValueError("invalid value for 'mode': {}".format(config.mode))
30 |
31 |
32 | def _config_draft(config):
33 | if config.draft:
34 | config.num_steps = 10
35 | config.eval_period = 10
36 | config.log_period = 1
37 | config.save_period = 10
38 | config.eval_num_batches = 1
39 |
40 |
41 | def _train(config):
42 | # load_metadata(config, 'train') # this updates the config file according to metadata file
43 |
44 | data_filter = get_squad_data_filter(config)
45 | train_data = read_data(config, 'train', config.load, data_filter=data_filter)
46 | dev_data = read_data(config, 'dev', True, data_filter=data_filter)
47 | update_config(config, [train_data, dev_data])
48 |
49 | _config_draft(config)
50 |
51 | word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec']
52 | word2idx_dict = train_data.shared['word2idx']
53 | idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict}
54 | print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict)))
55 | emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict
56 | else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size))
57 | for idx in range(config.word_vocab_size)])
58 | config.emb_mat = emb_mat
59 |
60 | # construct model graph and variables (using default graph)
61 | pprint(config.__flags, indent=2)
62 | model = Model(config)
63 | trainer = Trainer(config, model)
64 | evaluator = AccuracyEvaluator2(config, model)
65 |     graph_handler = GraphHandler(config)  # controls all tensors and variables in the graph, including loading and saving
66 |
67 | # Variables
68 | sess = tf.Session()
69 | graph_handler.initialize(sess)
70 |
71 | # begin training
72 | num_steps = config.num_steps or int(config.num_epochs * train_data.num_examples / config.batch_size)
73 | max_acc = 0
74 | noupdate_count = 0
75 | global_step = 0
76 | for _, batch in tqdm(train_data.get_batches(config.batch_size, num_batches=num_steps, shuffle=True), total=num_steps):
77 | global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after step
78 | get_summary = global_step % config.log_period == 0
79 | loss, summary, train_op = trainer.step(sess, batch, get_summary=get_summary)
80 | if get_summary:
81 | graph_handler.add_summary(summary, global_step)
82 |
83 | # Occasional evaluation and saving
84 | if global_step % config.save_period == 0:
85 | graph_handler.save(sess, global_step=global_step)
86 | if global_step % config.eval_period == 0:
87 | num_batches = math.ceil(dev_data.num_examples / config.batch_size)
88 | if 0 < config.eval_num_batches < num_batches:
89 | num_batches = config.eval_num_batches
90 | e = evaluator.get_evaluation_from_batches(
91 | sess, tqdm(dev_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
92 | graph_handler.add_summaries(e.summaries, global_step)
93 | if e.acc > max_acc:
94 | max_acc = e.acc
95 | noupdate_count = 0
96 | else:
97 | noupdate_count += 1
98 | if noupdate_count == config.early_stop:
99 | break
100 | if config.dump_eval:
101 | graph_handler.dump_eval(e)
102 | if global_step % config.save_period != 0:
103 | graph_handler.save(sess, global_step=global_step)
104 |
105 |
106 | def _test(config):
107 | test_data = read_data(config, 'test', True)
108 | update_config(config, [test_data])
109 |
110 | _config_draft(config)
111 |
112 | pprint(config.__flags, indent=2)
113 | model = Model(config)
114 | evaluator = AccuracyEvaluator2(config, model)
115 |     graph_handler = GraphHandler(config)  # controls all tensors and variables in the graph, including loading and saving
116 |
117 | sess = tf.Session()
118 | graph_handler.initialize(sess)
119 |
120 | num_batches = math.ceil(test_data.num_examples / config.batch_size)
121 | if 0 < config.eval_num_batches < num_batches:
122 | num_batches = config.eval_num_batches
123 | e = evaluator.get_evaluation_from_batches(sess, tqdm(test_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
124 | print(e)
125 | if config.dump_eval:
126 | graph_handler.dump_eval(e)
127 |
128 |
129 | def _forward(config):
130 |
131 | forward_data = read_data(config, 'forward', True)
132 |
133 | _config_draft(config)
134 |
135 |     pprint(config.__flags, indent=2)
136 | model = Model(config)
137 | evaluator = Evaluator(config, model)
138 |     graph_handler = GraphHandler(config)  # controls all tensors and variables in the graph, including loading and saving
139 |
140 | sess = tf.Session()
141 | graph_handler.initialize(sess)
142 |
143 | num_batches = math.ceil(forward_data.num_examples / config.batch_size)
144 | if 0 < config.eval_num_batches < num_batches:
145 | num_batches = config.eval_num_batches
146 | e = evaluator.get_evaluation_from_batches(sess, tqdm(forward_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches))
147 | print(e)
148 | if config.dump_eval:
149 | graph_handler.dump_eval(e)
150 |
151 |
152 | def set_dirs(config):
153 | # create directories
154 | if not config.load and os.path.exists(config.out_dir):
155 | shutil.rmtree(config.out_dir)
156 |
157 | config.save_dir = os.path.join(config.out_dir, "save")
158 | config.log_dir = os.path.join(config.out_dir, "log")
159 | config.eval_dir = os.path.join(config.out_dir, "eval")
160 | if not os.path.exists(config.out_dir):
161 | os.makedirs(config.out_dir)
162 |     # create any missing output directories
163 |     for d in (config.save_dir, config.log_dir, config.eval_dir):
164 |         if not os.path.exists(d):
165 |             os.mkdir(d)
166 |
167 |
168 | def _get_args():
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument("config_path")
171 | return parser.parse_args()
172 |
173 |
174 | class Config(object):
175 | def __init__(self, **entries):
176 | self.__dict__.update(entries)
177 |
178 |
179 | def _run():
180 | args = _get_args()
181 | with open(args.config_path, 'r') as fh:
182 | config = Config(**json.load(fh))
183 | main(config)
184 |
185 |
186 | if __name__ == "__main__":
187 | _run()
188 |
--------------------------------------------------------------------------------
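
When `tree/main.py` is run directly, it expects a JSON file whose keys become attributes of a `Config` object via `__dict__.update`. A minimal sketch of that loading path (the keys shown are an illustrative subset, not the full flag set):

```
import json

class Config(object):
    def __init__(self, **entries):
        self.__dict__.update(entries)

raw = '{"mode": "train", "batch_size": 60, "load": false}'  # illustrative subset
config = Config(**json.loads(raw))
print(config.mode, config.batch_size, config.load)  # train 60 False
```
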
/tree/read_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import itertools
5 | import math
6 |
7 | import nltk
8 |
9 | from my.nltk_utils import load_compressed_tree
10 | from my.utils import index
11 |
12 |
13 | class DataSet(object):
14 | def __init__(self, data, data_type, shared=None, valid_idxs=None):
15 | total_num_examples = len(next(iter(data.values())))
16 | self.data = data # e.g. {'X': [0, 1, 2], 'Y': [2, 3, 4]}
17 | self.data_type = data_type
18 | self.shared = shared
19 | self.valid_idxs = range(total_num_examples) if valid_idxs is None else valid_idxs
20 | self.num_examples = len(self.valid_idxs)
21 |
22 | def get_batches(self, batch_size, num_batches=None, shuffle=False):
23 | num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size))
24 | if num_batches is None:
25 | num_batches = num_batches_per_epoch
26 | num_epochs = int(math.ceil(num_batches / num_batches_per_epoch))
27 |
28 | idxs = itertools.chain.from_iterable(random.sample(self.valid_idxs, len(self.valid_idxs))
29 | if shuffle else self.valid_idxs
30 | for _ in range(num_epochs))
31 | for _ in range(num_batches):
32 | batch_idxs = tuple(itertools.islice(idxs, batch_size))
33 | batch_data = {}
34 | for key, val in self.data.items():
35 | if key.startswith('*'):
36 | assert self.shared is not None
37 | shared_key = key[1:]
38 | batch_data[shared_key] = [index(self.shared[shared_key], val[idx]) for idx in batch_idxs]
39 | else:
40 | batch_data[key] = list(map(val.__getitem__, batch_idxs))
41 |
42 | batch_ds = DataSet(batch_data, self.data_type, shared=self.shared)
43 | yield batch_idxs, batch_ds
44 |
45 |
46 | class SquadDataSet(DataSet):
47 | def __init__(self, data, data_type, shared=None, valid_idxs=None):
48 | super(SquadDataSet, self).__init__(data, data_type, shared=shared, valid_idxs=valid_idxs)
49 |
50 |
51 | def load_metadata(config, data_type):
52 | metadata_path = os.path.join(config.data_dir, "metadata_{}.json".format(data_type))
53 | with open(metadata_path, 'r') as fh:
54 | metadata = json.load(fh)
55 | for key, val in metadata.items():
56 | config.__setattr__(key, val)
57 | return metadata
58 |
59 |
60 | def read_data(config, data_type, ref, data_filter=None):
61 | data_path = os.path.join(config.data_dir, "data_{}.json".format(data_type))
62 | shared_path = os.path.join(config.data_dir, "shared_{}.json".format(data_type))
63 | with open(data_path, 'r') as fh:
64 | data = json.load(fh)
65 | with open(shared_path, 'r') as fh:
66 | shared = json.load(fh)
67 |
68 | num_examples = len(next(iter(data.values())))
69 | if data_filter is None:
70 | valid_idxs = range(num_examples)
71 | else:
72 | mask = []
73 | keys = data.keys()
74 | values = data.values()
75 | for vals in zip(*values):
76 | each = {key: val for key, val in zip(keys, vals)}
77 | mask.append(data_filter(each, shared))
78 | valid_idxs = [idx for idx in range(len(mask)) if mask[idx]]
79 |
80 | print("Loaded {}/{} examples from {}".format(len(valid_idxs), num_examples, data_type))
81 |
82 | shared_path = os.path.join(config.out_dir, "shared.json")
83 | if not ref:
84 | word_counter = shared['lower_word_counter'] if config.lower_word else shared['word_counter']
85 | char_counter = shared['char_counter']
86 | pos_counter = shared['pos_counter']
87 | shared['word2idx'] = {word: idx + 2 for idx, word in
88 | enumerate(word for word, count in word_counter.items()
89 | if count > config.word_count_th)}
90 | shared['char2idx'] = {char: idx + 2 for idx, char in
91 | enumerate(char for char, count in char_counter.items()
92 | if count > config.char_count_th)}
93 | shared['pos2idx'] = {pos: idx + 2 for idx, pos in enumerate(pos_counter.keys())}
94 | NULL = "-NULL-"
95 | UNK = "-UNK-"
96 | shared['word2idx'][NULL] = 0
97 | shared['word2idx'][UNK] = 1
98 | shared['char2idx'][NULL] = 0
99 | shared['char2idx'][UNK] = 1
100 | shared['pos2idx'][NULL] = 0
101 | shared['pos2idx'][UNK] = 1
102 | json.dump({'word2idx': shared['word2idx'], 'char2idx': shared['char2idx'],
103 | 'pos2idx': shared['pos2idx']}, open(shared_path, 'w'))
104 | else:
105 | new_shared = json.load(open(shared_path, 'r'))
106 | for key, val in new_shared.items():
107 | shared[key] = val
108 |
109 | data_set = DataSet(data, data_type, shared=shared, valid_idxs=valid_idxs)
110 | return data_set
111 |
112 |
113 | def get_squad_data_filter(config):
114 | def data_filter(data_point, shared):
115 | assert shared is not None
116 | rx, rcx, q, cq, y = (data_point[key] for key in ('*x', '*cx', 'q', 'cq', 'y'))
117 | x, cx, stx = shared['x'], shared['cx'], shared['stx']
118 | if len(q) > config.ques_size_th:
119 | return False
120 | xi = x[rx[0]][rx[1]]
121 | if len(xi) > config.num_sents_th:
122 | return False
123 | if any(len(xij) > config.sent_size_th for xij in xi):
124 | return False
125 | stxi = stx[rx[0]][rx[1]]
126 | if any(nltk.tree.Tree.fromstring(s).height() > config.tree_height_th for s in stxi):
127 | return False
128 | return True
129 | return data_filter
130 |
131 |
132 | def update_config(config, data_sets):
133 | config.max_num_sents = 0
134 | config.max_sent_size = 0
135 | config.max_ques_size = 0
136 | config.max_word_size = 0
137 | config.max_tree_height = 0
138 | for data_set in data_sets:
139 | data = data_set.data
140 | shared = data_set.shared
141 | for idx in data_set.valid_idxs:
142 | rx = data['*x'][idx]
143 | q = data['q'][idx]
144 | sents = shared['x'][rx[0]][rx[1]]
145 | trees = map(nltk.tree.Tree.fromstring, shared['stx'][rx[0]][rx[1]])
146 | config.max_tree_height = max(config.max_tree_height, max(tree.height() for tree in trees))
147 | config.max_num_sents = max(config.max_num_sents, len(sents))
148 | config.max_sent_size = max(config.max_sent_size, max(map(len, sents)))
149 | config.max_word_size = max(config.max_word_size, max(len(word) for sent in sents for word in sent))
150 | if len(q) > 0:
151 | config.max_ques_size = max(config.max_ques_size, len(q))
152 | config.max_word_size = max(config.max_word_size, max(len(word) for word in q))
153 |
154 | config.max_word_size = min(config.max_word_size, config.word_size_th)
155 |
156 | config.char_vocab_size = len(data_sets[0].shared['char2idx'])
157 | config.word_emb_size = len(next(iter(data_sets[0].shared['word2vec'].values())))
158 | config.word_vocab_size = len(data_sets[0].shared['word2idx'])
159 | config.pos_vocab_size = len(data_sets[0].shared['pos2idx'])
160 |
--------------------------------------------------------------------------------
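
The key trick in `DataSet.get_batches` above is the `*`-prefixed keys: instead of storing each example's paragraph directly, `data['*x']` stores an index path into the `shared` store, which `my.utils.index` follows. A self-contained sketch with toy data (the local `index` mirrors what `my.utils.index` is assumed to do; the real store nests article -> paragraph -> sentences):

```
from functools import reduce

def index(obj, path):
    # follow a list of indices into nested lists (assumed behavior of my.utils.index)
    return reduce(lambda o, i: o[i], path, obj)

shared = {'x': [[['Sam', 'slept', '.'], ['I', 'am', 'Sam', '.']]]}  # toy nesting
data = {'*x': [[0, 0], [0, 1]],  # index paths into shared['x']
        'q': [['who', 'slept', '?'], ['who', 'am', 'i', '?']]}

batch_idxs = (0, 1)
batch = {}
for key, val in data.items():
    if key.startswith('*'):
        shared_key = key[1:]
        batch[shared_key] = [index(shared[shared_key], val[i]) for i in batch_idxs]
    else:
        batch[key] = [val[i] for i in batch_idxs]
print(batch['x'][0])  # ['Sam', 'slept', '.']
```
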
/tree/templates/visualizer.html:
--------------------------------------------------------------------------------
1 | <!DOCTYPE html>
2 | <html>
3 | <head>
4 |     <meta charset="UTF-8">
5 |     <title>{{ title }}</title>
6 |     <style>
7 |         /* table styling lost in extraction */
8 |     </style>
9 | </head>
10 | <body>
11 | <h2>{{ title }}</h2>
12 | <table>
13 |     <tr>
14 |         <th>ID</th>
15 |         <th>Question</th>
16 |         <th>Answer</th>
17 |         <th>Paragraph</th>
18 |     </tr>
19 |     {% for row in rows %}
20 |     <tr>
21 |         <td>{{ row.id }}</td>
22 |         <td>
23 |             {% for qj in row.ques %}
24 |             {{ qj }}
25 |             {% endfor %}
26 |         </td>
27 |         <td>{{ row.a }}</td>
28 |         <td>
29 |             <table>
30 |                 {% for xj, yj, y2j, ypj, yp2j in zip(row.para, row.y, row.y2, row.yp, row.yp2) %}
31 |                 <tr>
32 |                     {% for xjk, yjk, y2jk, ypjk in zip(xj, yj, y2j, ypj) %}
33 |                     <td>
34 |                         {% if yjk or y2jk %}
35 |                         <b>{{ xjk }}</b>
36 |                         {% else %}
37 |                         {{ xjk }}
38 |                         {% endif %}
39 |                     </td>
40 |                     {% endfor %}
41 |                 </tr>
42 |                 <tr>
43 |                     {% for xjk, yp2jk in zip(xj, yp2j) %}
44 |                     <td>-</td>
45 |                     {% endfor %}
46 |                 </tr>
47 |                 {% endfor %}
48 |             </table>
49 |         </td>
50 |     </tr>
51 |     {% endfor %}
52 | </table>
53 | </body>
54 | </html>
--------------------------------------------------------------------------------
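
Note that this template calls `zip(...)` inside its `{% for %}` loops; Jinja2 does not expose `zip` by default, so the visualizer injects it into the environment globals (see `env.globals.update(zip=zip, reversed=reversed)` in `tree/visualizer.py` below). A tiny demonstration of that mechanism:

```
from jinja2 import Environment, DictLoader

env = Environment(loader=DictLoader(
    {'t': '{% for a, b in zip(xs, ys) %}{{ a }}={{ b }} {% endfor %}'}))
env.globals.update(zip=zip)  # make the builtin available inside templates
print(env.get_template('t').render(xs=[1, 2], ys=['a', 'b']))  # 1=a 2=b
```
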
/tree/test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import nltk\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "%matplotlib inline"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 10,
19 | "metadata": {
20 | "collapsed": false
21 | },
22 | "outputs": [
23 | {
24 | "name": "stdout",
25 | "output_type": "stream",
26 | "text": [
27 | "(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n",
28 | "(PRP I)\n",
29 | "(VP (VBP am) (NNP Sam))\n",
30 | "(VBP am)\n",
31 | "(NNP Sam)\n",
32 | "(. .)\n",
33 | "(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n"
34 | ]
35 | }
36 | ],
37 | "source": [
38 | "string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))\"\n",
39 | "tree = nltk.tree.Tree.fromstring(string)\n",
40 | "\n",
41 | "def load_compressed_tree(s):\n",
42 | "\n",
43 | " def compress_tree(tree):\n",
44 | " if len(tree) == 1:\n",
45 | " if isinstance(tree[0], nltk.tree.Tree):\n",
46 | " return compress_tree(tree[0])\n",
47 | " else:\n",
48 | " return tree\n",
49 | " else:\n",
50 | " for i, t in enumerate(tree):\n",
51 | " tree[i] = compress_tree(t)\n",
52 | " return tree\n",
53 | "\n",
54 | " return compress_tree(nltk.tree.Tree.fromstring(s))\n",
55 | "tree = load_compressed_tree(string)\n",
56 | "for t in tree.subtrees():\n",
57 | " print(t)\n",
58 | " \n",
59 | "print(str(tree))"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 3,
65 | "metadata": {
66 | "collapsed": false
67 | },
68 | "outputs": [
69 | {
70 | "name": "stdout",
71 | "output_type": "stream",
72 | "text": [
73 | "(ROOT I am Sam .)\n"
74 | ]
75 | }
76 | ],
77 | "source": [
78 | "print(tree.flatten())"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 10,
84 | "metadata": {
85 | "collapsed": false
86 | },
87 | "outputs": [
88 | {
89 | "name": "stdout",
90 | "output_type": "stream",
91 | "text": [
92 | "['ROOT', 'S', 'NP', 'PRP', 'VP', 'VBP', 'NP', 'NNP', '.']\n"
93 | ]
94 | }
95 | ],
96 | "source": [
97 | "print(list(t.label() for t in tree.subtrees()))"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 11,
103 | "metadata": {
104 | "collapsed": true
105 | },
106 | "outputs": [],
107 | "source": [
108 | "import json\n",
109 | "d = json.load(open(\"data/squad/shared_dev.json\", 'r'))"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 12,
115 | "metadata": {
116 | "collapsed": false
117 | },
118 | "outputs": [
119 | {
120 | "data": {
121 | "text/plain": [
122 | "73"
123 | ]
124 | },
125 | "execution_count": 12,
126 | "metadata": {},
127 | "output_type": "execute_result"
128 | }
129 | ],
130 | "source": [
131 | "len(d['pos_counter'])"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 13,
137 | "metadata": {
138 | "collapsed": false
139 | },
140 | "outputs": [
141 | {
142 | "data": {
143 | "text/plain": [
144 | "{'#': 6,\n",
145 | " '$': 80,\n",
146 | " \"''\": 1291,\n",
147 | " ',': 14136,\n",
148 | " '-LRB-': 1926,\n",
149 | " '-RRB-': 1925,\n",
150 | " '.': 9505,\n",
151 | " ':': 1455,\n",
152 | " 'ADJP': 3426,\n",
153 | " 'ADVP': 4936,\n",
154 | " 'CC': 9300,\n",
155 | " 'CD': 6216,\n",
156 | " 'CONJP': 191,\n",
157 | " 'DT': 26286,\n",
158 | " 'EX': 288,\n",
159 | " 'FRAG': 107,\n",
160 | " 'FW': 96,\n",
161 | " 'IN': 32564,\n",
162 | " 'INTJ': 12,\n",
163 | " 'JJ': 21452,\n",
164 | " 'JJR': 563,\n",
165 | " 'JJS': 569,\n",
166 | " 'LS': 7,\n",
167 | " 'LST': 1,\n",
168 | " 'MD': 1051,\n",
169 | " 'NAC': 19,\n",
170 | " 'NN': 34750,\n",
171 | " 'NNP': 28392,\n",
172 | " 'NNPS': 1400,\n",
173 | " 'NNS': 16716,\n",
174 | " 'NP': 91636,\n",
175 | " 'NP-TMP': 236,\n",
176 | " 'NX': 108,\n",
177 | " 'PDT': 89,\n",
178 | " 'POS': 1451,\n",
179 | " 'PP': 33278,\n",
180 | " 'PRN': 2085,\n",
181 | " 'PRP': 2320,\n",
182 | " 'PRP$': 1959,\n",
183 | " 'PRT': 450,\n",
184 | " 'QP': 838,\n",
185 | " 'RB': 7611,\n",
186 | " 'RBR': 301,\n",
187 | " 'RBS': 252,\n",
188 | " 'ROOT': 9587,\n",
189 | " 'RP': 454,\n",
190 | " 'RRC': 19,\n",
191 | " 'S': 21557,\n",
192 | " 'SBAR': 5009,\n",
193 | " 'SBARQ': 6,\n",
194 | " 'SINV': 135,\n",
195 | " 'SQ': 5,\n",
196 | " 'SYM': 17,\n",
197 | " 'TO': 5167,\n",
198 | " 'UCP': 143,\n",
199 | " 'UH': 15,\n",
200 | " 'VB': 4197,\n",
201 | " 'VBD': 8377,\n",
202 | " 'VBG': 3570,\n",
203 | " 'VBN': 7218,\n",
204 | " 'VBP': 2897,\n",
205 | " 'VBZ': 4146,\n",
206 | " 'VP': 33696,\n",
207 | " 'WDT': 1368,\n",
208 | " 'WHADJP': 5,\n",
209 | " 'WHADVP': 439,\n",
210 | " 'WHNP': 1927,\n",
211 | " 'WHPP': 153,\n",
212 | " 'WP': 482,\n",
213 | " 'WP$': 50,\n",
214 | " 'WRB': 442,\n",
215 | " 'X': 23,\n",
216 | " '``': 1269}"
217 | ]
218 | },
219 | "execution_count": 13,
220 | "metadata": {},
221 | "output_type": "execute_result"
222 | }
223 | ],
224 | "source": [
225 | "d['pos_counter']"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 3,
231 | "metadata": {
232 | "collapsed": false
233 | },
234 | "outputs": [
235 | {
236 | "name": "stdout",
237 | "output_type": "stream",
238 | "text": [
239 | "[[False False False False]\n",
240 | " [False True False False]\n",
241 | " [False False False False]]\n",
242 | "[[0 2 2 0]\n",
243 | " [2 2 0 2]\n",
244 | " [2 0 0 0]]\n"
245 | ]
246 | }
247 | ],
248 | "source": [
249 | "from my.nltk_utils import tree2matrix, load_compressed_tree, find_max_f1_subtree, set_span\n",
250 | "string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))\"\n",
251 | "tree = load_compressed_tree(string)\n",
252 | "span = (1, 3)\n",
253 | "set_span(tree)\n",
254 | "subtree = find_max_f1_subtree(tree, span)\n",
255 | "f = lambda t: t == subtree\n",
256 | "g = lambda t: 1 if isinstance(t, str) else 2\n",
257 | "a, b = tree2matrix(tree, f, dtype='bool')\n",
258 | "c, d = tree2matrix(tree, g, dtype='int32')\n",
259 | "print(a)\n",
260 | "print(c)"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {
267 | "collapsed": true
268 | },
269 | "outputs": [],
270 | "source": []
271 | }
272 | ],
273 | "metadata": {
274 | "kernelspec": {
275 | "display_name": "Python 3",
276 | "language": "python",
277 | "name": "python3"
278 | },
279 | "language_info": {
280 | "codemirror_mode": {
281 | "name": "ipython",
282 | "version": 3
283 | },
284 | "file_extension": ".py",
285 | "mimetype": "text/x-python",
286 | "name": "python",
287 | "nbconvert_exporter": "python",
288 | "pygments_lexer": "ipython3",
289 | "version": "3.5.1"
290 | }
291 | },
292 | "nbformat": 4,
293 | "nbformat_minor": 0
294 | }
295 |
--------------------------------------------------------------------------------
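
The notebook's `len(d['pos_counter'])` result (73 POS/constituent labels in `shared_dev.json`) connects to `tree/read_data.py` above: `pos2idx` reserves ids 0 and 1 for `-NULL-` and `-UNK-` and offsets every real tag by 2, so `pos_vocab_size = len(pos_counter) + 2` (75 for that file). A toy sketch of the mapping, using a truncated sample counter:

```
pos_counter = {'NN': 34750, 'NP': 91636, 'VP': 33696}  # truncated toy sample
pos2idx = {pos: idx + 2 for idx, pos in enumerate(pos_counter.keys())}
pos2idx['-NULL-'] = 0
pos2idx['-UNK-'] = 1
print(len(pos2idx))  # len(pos_counter) + 2 -> 5
```
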
/tree/trainer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tree.model import Model
4 |
5 |
6 | class Trainer(object):
7 | def __init__(self, config, model):
8 | assert isinstance(model, Model)
9 | self.config = config
10 | self.model = model
11 | self.opt = tf.train.AdagradOptimizer(config.init_lr)
12 | self.loss = model.get_loss()
13 | self.var_list = model.get_var_list()
14 | self.global_step = model.get_global_step()
15 | self.ema_op = model.ema_op
16 | self.summary = model.summary
17 | self.grads = self.opt.compute_gradients(self.loss, var_list=self.var_list)
18 | opt_op = self.opt.apply_gradients(self.grads, global_step=self.global_step)
19 |
20 | # Define train op
21 | with tf.control_dependencies([opt_op]):
22 | self.train_op = tf.group(self.ema_op)
23 |
24 | def get_train_op(self):
25 | return self.train_op
26 |
27 | def step(self, sess, batch, get_summary=False):
28 | assert isinstance(sess, tf.Session)
29 | feed_dict = self.model.get_feed_dict(batch, True)
30 | if get_summary:
31 | loss, summary, train_op = \
32 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict)
33 | else:
34 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict)
35 | summary = None
36 | return loss, summary, train_op
37 |
--------------------------------------------------------------------------------
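
`Trainer` builds its `train_op` so that the exponential-moving-average update can only run after the Adagrad step has been applied. A stripped-down sketch of that `control_dependencies`/`group` pattern, using the same TF 0.x/1.x graph-mode API as the rest of this repo (it will not run on TF 2.x; `GradientDescentOptimizer` stands in for `AdagradOptimizer`):

```
import tensorflow as tf

x = tf.Variable(3.0)
loss = tf.square(x)
opt = tf.train.GradientDescentOptimizer(0.1)  # stand-in for AdagradOptimizer
grads = opt.compute_gradients(loss, var_list=[x])
opt_op = opt.apply_gradients(grads)

ema = tf.train.ExponentialMovingAverage(decay=0.999)
with tf.control_dependencies([opt_op]):
    # the EMA update is forced to run after the parameter update
    train_op = tf.group(ema.apply([x]))

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())  # r0.11-era initializer, as in graph_handler.py
    sess.run(train_op)
```
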
/tree/visualizer.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from collections import OrderedDict
3 | import http.server
4 | import socketserver
5 | import argparse
6 | import json
7 | import os
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from jinja2 import Environment, FileSystemLoader
12 |
13 |
14 | def bool_(string):
15 | if string == 'True':
16 | return True
17 | elif string == 'False':
18 | return False
19 |     else:
20 |         raise ValueError("expected 'True' or 'False', got {!r}".format(string))
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--model_name", type=str, default='basic')
25 | parser.add_argument("--data_type", type=str, default='dev')
26 | parser.add_argument("--step", type=int, default=5000)
27 | parser.add_argument("--template_name", type=str, default="visualizer.html")
28 | parser.add_argument("--num_per_page", type=int, default=100)
29 | parser.add_argument("--data_dir", type=str, default="data/squad")
30 | parser.add_argument("--port", type=int, default=8000)
31 | parser.add_argument("--host", type=str, default="0.0.0.0")
32 | parser.add_argument("--open", type=str, default='False')
33 | parser.add_argument("--run_id", type=str, default="0")
34 |
35 | args = parser.parse_args()
36 | return args
37 |
38 |
39 | def _decode(decoder, sent):
40 | return " ".join(decoder[idx] for idx in sent)
41 |
42 |
43 | def accuracy2_visualizer(args):
44 | model_name = args.model_name
45 | data_type = args.data_type
46 | num_per_page = args.num_per_page
47 | data_dir = args.data_dir
48 | run_id = args.run_id.zfill(2)
49 | step = args.step
50 |
51 |     eval_path = os.path.join("out", model_name, run_id, "eval", "{}-{}.json".format(data_type, str(step).zfill(6)))
52 | eval_ = json.load(open(eval_path, 'r'))
53 |
54 | _id = 0
55 | html_dir = "/tmp/list_results%d" % _id
56 | while os.path.exists(html_dir):
57 | _id += 1
58 | html_dir = "/tmp/list_results%d" % _id
59 |
60 | if os.path.exists(html_dir):
61 | shutil.rmtree(html_dir)
62 | os.mkdir(html_dir)
63 |
64 | cur_dir = os.path.dirname(os.path.realpath(__file__))
65 | templates_dir = os.path.join(cur_dir, 'templates')
66 | env = Environment(loader=FileSystemLoader(templates_dir))
67 | env.globals.update(zip=zip, reversed=reversed)
68 | template = env.get_template(args.template_name)
69 |
70 | data_path = os.path.join(data_dir, "data_{}.json".format(data_type))
71 | shared_path = os.path.join(data_dir, "shared_{}.json".format(data_type))
72 | data = json.load(open(data_path, 'r'))
73 | shared = json.load(open(shared_path, 'r'))
74 |
75 | rows = []
76 | for i, (idx, yi, ypi) in enumerate(zip(*[eval_[key] for key in ('idxs', 'y', 'yp')])):
77 | id_, q, rx = (data[key][idx] for key in ('ids', 'q', '*x'))
78 | x = shared['x'][rx[0]][rx[1]]
79 | ques = [" ".join(q)]
80 | para = [[word for word in sent] for sent in x]
81 | row = {
82 | 'id': id_,
83 | 'title': "Hello world!",
84 | 'ques': ques,
85 | 'para': para,
86 | 'y': yi,
87 | 'y2': yi,
88 | 'yp': ypi,
89 | 'yp2': ypi,
90 | 'a': ""
91 | }
92 | rows.append(row)
93 |
94 | if i % num_per_page == 0:
95 | html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8))
96 |
97 | if (i + 1) % num_per_page == 0 or (i + 1) == len(eval_['y']):
98 | var_dict = {'title': "Accuracy Visualization",
99 | 'rows': rows
100 | }
101 | with open(html_path, "wb") as f:
102 | f.write(template.render(**var_dict).encode('UTF-8'))
103 | rows = []
104 |
105 | os.chdir(html_dir)
106 | port = args.port
107 | host = args.host
108 | # Overriding to suppress log message
109 | class MyHandler(http.server.SimpleHTTPRequestHandler):
110 | def log_message(self, format, *args):
111 | pass
112 | handler = MyHandler
113 | httpd = socketserver.TCPServer((host, port), handler)
114 | if args.open == 'True':
115 | os.system("open http://%s:%d" % (args.host, args.port))
116 | print("serving at %s:%d" % (host, port))
117 | httpd.serve_forever()
118 |
119 |
120 | if __name__ == "__main__":
121 | ARGS = get_args()
122 | accuracy2_visualizer(ARGS)
--------------------------------------------------------------------------------
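
For reference, the paging logic in `accuracy2_visualizer` picks a new output filename every `num_per_page` rows and flushes the accumulated rows when a page fills up or the data runs out. A minimal sketch of that control flow with toy values:

```
num_per_page = 3
items = list(range(8))  # stand-in for the eval rows

rows, pages = [], []
for i, item in enumerate(items):
    if i % num_per_page == 0:
        page_name = "%s.html" % str(i).zfill(8)  # same naming as the visualizer
    rows.append(item)
    if (i + 1) % num_per_page == 0 or (i + 1) == len(items):
        pages.append((page_name, rows))
        rows = []

print([name for name, _ in pages])
# ['00000000.html', '00000003.html', '00000006.html']
```
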