├── LICENSE
├── README.md
├── basic
│   ├── __init__.py
│   ├── cli.py
│   ├── ensemble.py
│   ├── ensemble_fast.py
│   ├── evaluator.py
│   ├── graph_handler.py
│   ├── main.py
│   ├── model.py
│   ├── read_data.py
│   ├── run_ensemble.sh
│   ├── run_single.sh
│   ├── templates
│   │   └── visualizer.html
│   ├── trainer.py
│   └── visualizer.py
├── basic_cnn
│   ├── __init__.py
│   ├── cli.py
│   ├── evaluator.py
│   ├── graph_handler.py
│   ├── main.py
│   ├── model.py
│   ├── read_data.py
│   ├── superhighway.py
│   ├── templates
│   │   └── visualizer.html
│   ├── trainer.py
│   └── visualizer.py
├── cnn_dm
│   ├── __init__.py
│   ├── eda.ipynb
│   ├── evaluate.py
│   └── prepro.py
├── download.sh
├── my
│   ├── __init__.py
│   ├── corenlp_interface.py
│   ├── nltk_utils.py
│   ├── tensorflow
│   │   ├── __init__.py
│   │   ├── general.py
│   │   ├── nn.py
│   │   ├── rnn.py
│   │   └── rnn_cell.py
│   ├── utils.py
│   └── zip_save.py
├── requirements.txt
├── squad
│   ├── __init__.py
│   ├── aug_squad.py
│   ├── eda_aug_dev.ipynb
│   ├── eda_aug_train.ipynb
│   ├── evaluate-v1.1.py
│   ├── evaluate.py
│   ├── prepro.py
│   ├── prepro_aug.py
│   └── utils.py
├── tree
│   ├── __init__.py
│   ├── cli.py
│   ├── evaluator.py
│   ├── graph_handler.py
│   ├── main.py
│   ├── model.py
│   ├── read_data.py
│   ├── templates
│   │   └── visualizer.html
│   ├── test.ipynb
│   ├── trainer.py
│   └── visualizer.py
└── visualization
    └── compare_models.py

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

# Bi-directional Attention Flow for Machine Comprehension

- This is the original implementation of [Bi-directional Attention Flow for Machine Comprehension][paper].
- The CodaLab worksheet for the [SQuAD Leaderboard][squad] submission is available [here][worksheet].
- For a TensorFlow v1.2 compatible version, see the [dev][dev] branch.
- Please contact [Minjoon Seo][minjoon] ([@seominjoon][minjoon-github]) for questions and suggestions.

## 0. Requirements
#### General
- Python (verified on 3.5.2; issues have been reported with Python 2!)
- unzip, wget (for running `download.sh` only)

#### Python Packages
- tensorflow (deep learning library, only works on r0.11)
- nltk (NLP tools, verified on 3.2.1)
- tqdm (progress bar, verified on 4.7.4)
- jinja2 (for visualization; not needed if you only train and test)

## 1. Pre-processing
First, prepare the data. Download the SQuAD data, GloVe vectors, and nltk corpus
(~850 MB; this downloads files to `$HOME/data`):
```
chmod +x download.sh; ./download.sh
```

Second, preprocess the Stanford QA dataset (along with the GloVe vectors) and save it in `$PWD/data/squad` (~5 minutes):
```
python -m squad.prepro
```

## 2. Training
The model has ~2.5M parameters.
It was trained on an NVIDIA Titan X (Pascal architecture, 2016) and requires at least 12 GB of GPU RAM.
If your GPU RAM is smaller than 12 GB, you can either decrease the batch size (performance might degrade)
or use multiple GPUs (see below).
Training converges at ~18k steps and takes ~4 s per step (i.e., ~20 hours).

Before training, it is recommended to first run the following command to verify that everything works and memory is sufficient:
```
python -m basic.cli --mode train --noload --debug
```

Then, to fully train, run:
```
python -m basic.cli --mode train --noload
```

You can speed up the training process with optimization flags:
```
python -m basic.cli --mode train --noload --len_opt --cluster
```
You can still omit them, but training will be much slower.
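What `--cluster` (and the related `--len_opt`) buy you is less padding: batching examples of similar length together means each batch is padded only to its own longest example rather than to the global maximum. A conceptual sketch of the idea (illustrative only — the repo's actual batching lives in `basic/read_data.py`, and the field name `x` here is an assumption):
```python
# Conceptual sketch of length-clustered batching, not the repo's actual code.
def cluster_batches(examples, batch_size):
    # Sort by tokenized-context length so similar-length examples share a batch;
    # each batch then pads only to its own maximum length.
    ordered = sorted(examples, key=lambda ex: len(ex["x"]))
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]
```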
Note that during training, the EM and F1 scores from the occasional evaluation are not the same as the scores from the official SQuAD evaluation script.
The printed scores are not official (our scoring scheme is a bit harsher).
To obtain the official numbers, use the official evaluator (copied into the `squad` folder as `squad/evaluate-v1.1.py`). For more information, see Section 3 (Test).


## 3. Test
To test, run:
```
python -m basic.cli
```

Similarly to training, you can pass the optimization flags to speed up testing (~5 minutes on dev data):
```
python -m basic.cli --len_opt --cluster
```

This command loads the model most recently saved during training and evaluates it on the test data.
After the process ends, it prints the F1 and EM scores and also outputs a json file (`$PWD/out/basic/00/answer/test-####.json`,
where `####` is the step at which the model was saved).
Note that the printed scores are not official (our scoring scheme is a bit harsher).
To obtain the official numbers, use the official evaluator (copied into the `squad` folder) together with the output json file:

```
python squad/evaluate-v1.1.py $HOME/data/squad/dev-v1.1.json out/basic/00/answer/test-####.json
```

### 3.1 Loading from pre-trained weights
Instead of training the model yourself, you can use the pre-trained weights that were used for the [SQuAD Leaderboard][squad] submission.
Refer to [this worksheet][worksheet] in CodaLab to reproduce the results.
If you are unfamiliar with CodaLab, follow these simple steps (assuming you have met all the requirements above):

1. Download `save.zip` from the [worksheet][worksheet] and unzip it in the current directory.
2. Copy `glove.6B.100d.txt` from your GloVe data folder (`$HOME/data/glove/`) to the current directory.
3. To reproduce the single model:

   ```
   basic/run_single.sh $HOME/data/squad/dev-v1.1.json single.json
   ```

   This writes the answers to `single.json` in the current directory. You can then use the official evaluator to obtain the EM and F1 scores. If you want to run on a GPU (~5 minutes), change the value of the `batch_size` flag in the shell file to a higher number (60 for 12 GB of GPU RAM).
4. Similarly, to reproduce the ensemble method:

   ```
   basic/run_ensemble.sh $HOME/data/squad/dev-v1.1.json ensemble.json
   ```

   If you want to run on a GPU, either run the models sequentially by removing the `&` in the for loop, or specify a different GPU for each run of the loop.

## Results

### Dev Data

These scores are from the official evaluator (copied into the `squad` folder as `squad/evaluate-v1.1.py`); for more information, see Section 3 (Test).
The scores printed during training can be lower than those from the official evaluator.

|          | EM (%) | F1 (%) |
| -------- |:------:|:------:|
| single   | 67.7   | 77.3   |
| ensemble | 72.6   | 80.7   |

### Test Data

|          | EM (%) | F1 (%) |
| -------- |:------:|:------:|
| single   | 68.0   | 77.3   |
| ensemble | 73.3   | 81.1   |

Refer to [our paper][paper] for more details, and see the [SQuAD Leaderboard][squad] to compare with other models.


## Multi-GPU Training & Testing
Our model supports multi-GPU training.
We follow the parallelization paradigm described in the [TensorFlow Tutorial][multi-gpu]: each GPU computes the gradients for its own slice of the batch, and the gradients are then combined on the CPU.
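The key step in that paradigm is averaging the per-tower gradients on the CPU before applying them; `basic/trainer.py` imports an `average_gradients` helper from `my.tensorflow` for this. A minimal sketch of such a helper, modeled on the TensorFlow CIFAR-10 multi-GPU tutorial and written against the r0.11 API this repo targets (the repo's own implementation in `my/tensorflow` may differ in detail):
```python
import tensorflow as tf

def average_gradients(tower_grads):
    # tower_grads: one list of (gradient, variable) pairs per GPU tower.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars pairs up the same variable's gradient from every tower.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(0, grads), 0)  # r0.11 signature: tf.concat(dim, values)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
```
With the gradients combined this way, 3 GPUs at batch size 20 behave like one GPU at batch size 60.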
In short, if you want an effective batch size of 60 (the default) but have 3 GPUs with 4 GB of RAM each, you initialize each GPU with a batch size of 20 and let the gradients be combined on the CPU.
This can be done by running:
```
python -m basic.cli --mode train --noload --num_gpus 3 --batch_size 20
```

Similarly, you can speed up your testing with:
```
python -m basic.cli --num_gpus 3 --batch_size 20
```

## Demo
For now, please refer to the `demo` branch of this repository.


[multi-gpu]: https://www.tensorflow.org/versions/r0.11/tutorials/deep_cnn/index.html#training-a-model-using-multiple-gpu-cards
[squad]: http://stanford-qa.com
[paper]: https://arxiv.org/abs/1611.01603
[worksheet]: https://worksheets.codalab.org/worksheets/0x37a9b8c44f6845c28866267ef941c89d/
[minjoon]: https://seominjoon.github.io
[minjoon-github]: https://github.com/seominjoon
[dev]: https://github.com/allenai/bi-att-flow/tree/dev

-------------------------------------------------------------------------------- /basic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/basic/__init__.py -------------------------------------------------------------------------------- /basic/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from basic.main import main as m 6 | 7 | flags = tf.app.flags 8 | 9 | # Names and directories 10 | flags.DEFINE_string("model_name", "basic", "Model name [basic]") 11 | flags.DEFINE_string("data_dir", "data/squad", "Data dir [data/squad]") 12 | flags.DEFINE_string("run_id", "0", "Run ID [0]") 13 | flags.DEFINE_string("out_base_dir", "out", "out base dir [out]") 14 | flags.DEFINE_string("forward_name", "single", "Forward name [single]") 15 | flags.DEFINE_string("answer_path", "", "Answer path []") 16 | flags.DEFINE_string("eval_path", "", "Eval path []") 17 | flags.DEFINE_string("load_path", "", "Load path []") 18 | flags.DEFINE_string("shared_path", "", "Shared path []") 19 | 20 | # Device placement 21 | flags.DEFINE_string("device", "/cpu:0", "default device for summing gradients. [/cpu:0]") 22 | flags.DEFINE_string("device_type", "gpu", "device for computing gradients (parallelization). cpu | gpu [gpu]") 23 | flags.DEFINE_integer("num_gpus", 1, "num of gpus or cpus for computing gradients [1]") 24 | 25 | # Essential training and test options 26 | flags.DEFINE_string("mode", "test", "train | test | forward [test]") 27 | flags.DEFINE_boolean("load", True, "load saved data? [True]") 28 | flags.DEFINE_bool("single", False, "supervise only the answer sentence? [False]") 29 | flags.DEFINE_boolean("debug", False, "Debugging mode? [False]") 30 | flags.DEFINE_bool('load_ema', True, "load exponential average of variables when testing? [True]") 31 | flags.DEFINE_bool("eval", True, "eval? 
[True]") 32 | 33 | # Training / test parameters 34 | flags.DEFINE_integer("batch_size", 60, "Batch size [60]") 35 | flags.DEFINE_integer("val_num_batches", 100, "validation num batches [100]") 36 | flags.DEFINE_integer("test_num_batches", 0, "test num batches [0]") 37 | flags.DEFINE_integer("num_epochs", 12, "Total number of epochs for training [12]") 38 | flags.DEFINE_integer("num_steps", 20000, "Number of steps [20000]") 39 | flags.DEFINE_integer("load_step", 0, "load step [0]") 40 | flags.DEFINE_float("init_lr", 0.5, "Initial learning rate [0.5]") 41 | flags.DEFINE_float("input_keep_prob", 0.8, "Input keep prob for the dropout of LSTM weights [0.8]") 42 | flags.DEFINE_float("keep_prob", 0.8, "Keep prob for the dropout of Char-CNN weights [0.8]") 43 | flags.DEFINE_float("wd", 0.0, "L2 weight decay for regularization [0.0]") 44 | flags.DEFINE_integer("hidden_size", 100, "Hidden size [100]") 45 | flags.DEFINE_integer("char_out_size", 100, "char-level word embedding size [100]") 46 | flags.DEFINE_integer("char_emb_size", 8, "Char emb size [8]") 47 | flags.DEFINE_string("out_channel_dims", "100", "Out channel dims of Char-CNN, separated by commas [100]") 48 | flags.DEFINE_string("filter_heights", "5", "Filter heights of Char-CNN, separated by commas [5]") 49 | flags.DEFINE_bool("finetune", False, "Finetune word embeddings? [False]") 50 | flags.DEFINE_bool("highway", True, "Use highway? [True]") 51 | flags.DEFINE_integer("highway_num_layers", 2, "highway num layers [2]") 52 | flags.DEFINE_bool("share_cnn_weights", True, "Share Char-CNN weights [True]") 53 | flags.DEFINE_bool("share_lstm_weights", True, "Share pre-processing (phrase-level) LSTM weights [True]") 54 | flags.DEFINE_float("var_decay", 0.999, "Exponential moving average decay for variables [0.999]") 55 | 56 | # Optimizations 57 | flags.DEFINE_bool("cluster", False, "Cluster data for faster training [False]") 58 | flags.DEFINE_bool("len_opt", False, "Length optimization? [False]") 59 | flags.DEFINE_bool("cpu_opt", False, "CPU optimization? GPU computation can be slower [False]") 60 | 61 | # Logging and saving options 62 | flags.DEFINE_boolean("progress", True, "Show progress? [True]") 63 | flags.DEFINE_integer("log_period", 100, "Log period [100]") 64 | flags.DEFINE_integer("eval_period", 1000, "Eval period [1000]") 65 | flags.DEFINE_integer("save_period", 1000, "Save Period [1000]") 66 | flags.DEFINE_integer("max_to_keep", 20, "Max recent saves to keep [20]") 67 | flags.DEFINE_bool("dump_eval", True, "dump eval? [True]") 68 | flags.DEFINE_bool("dump_answer", True, "dump answer? [True]") 69 | flags.DEFINE_bool("vis", False, "output visualization numbers? [False]") 70 | flags.DEFINE_bool("dump_pickle", True, "Dump pickle instead of json? 
[True]") 71 | flags.DEFINE_float("decay", 0.9, "Exponential moving average decay for logging values [0.9]") 72 | 73 | # Thresholds for speed and less memory usage 74 | flags.DEFINE_integer("word_count_th", 10, "word count th [100]") 75 | flags.DEFINE_integer("char_count_th", 50, "char count th [500]") 76 | flags.DEFINE_integer("sent_size_th", 400, "sent size th [64]") 77 | flags.DEFINE_integer("num_sents_th", 8, "num sents th [8]") 78 | flags.DEFINE_integer("ques_size_th", 30, "ques size th [32]") 79 | flags.DEFINE_integer("word_size_th", 16, "word size th [16]") 80 | flags.DEFINE_integer("para_size_th", 256, "para size th [256]") 81 | 82 | # Advanced training options 83 | flags.DEFINE_bool("lower_word", True, "lower word [True]") 84 | flags.DEFINE_bool("squash", False, "squash the sentences into one? [False]") 85 | flags.DEFINE_bool("swap_memory", True, "swap memory? [True]") 86 | flags.DEFINE_string("data_filter", "max", "max | valid | semi [max]") 87 | flags.DEFINE_bool("use_glove_for_unk", True, "use glove for unk [False]") 88 | flags.DEFINE_bool("known_if_glove", True, "consider as known if present in glove [False]") 89 | flags.DEFINE_string("logit_func", "tri_linear", "logit func [tri_linear]") 90 | flags.DEFINE_string("answer_func", "linear", "answer logit func [linear]") 91 | flags.DEFINE_string("sh_logit_func", "tri_linear", "sh logit func [tri_linear]") 92 | 93 | # Ablation options 94 | flags.DEFINE_bool("use_char_emb", True, "use char emb? [True]") 95 | flags.DEFINE_bool("use_word_emb", True, "use word embedding? [True]") 96 | flags.DEFINE_bool("q2c_att", True, "question-to-context attention? [True]") 97 | flags.DEFINE_bool("c2q_att", True, "context-to-question attention? [True]") 98 | flags.DEFINE_bool("dynamic_att", False, "Dynamic attention [False]") 99 | 100 | 101 | def main(_): 102 | config = flags.FLAGS 103 | 104 | config.out_dir = os.path.join(config.out_base_dir, config.model_name, str(config.run_id).zfill(2)) 105 | 106 | m(config) 107 | 108 | if __name__ == "__main__": 109 | tf.app.run() 110 | -------------------------------------------------------------------------------- /basic/ensemble.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import gzip 4 | import json 5 | import pickle 6 | from collections import defaultdict 7 | from operator import mul 8 | 9 | from tqdm import tqdm 10 | from squad.utils import get_phrase, get_best_span 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('paths', nargs='+') 16 | parser.add_argument('-o', '--out', default='ensemble.json') 17 | parser.add_argument("--data_path", default="data/squad/data_test.json") 18 | parser.add_argument("--shared_path", default="data/squad/shared_test.json") 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def ensemble(args): 24 | e_list = [] 25 | for path in tqdm(args.paths): 26 | with gzip.open(path, 'r') as fh: 27 | e = pickle.load(fh) 28 | e_list.append(e) 29 | 30 | with open(args.data_path, 'r') as fh: 31 | data = json.load(fh) 32 | 33 | with open(args.shared_path, 'r') as fh: 34 | shared = json.load(fh) 35 | 36 | out = {} 37 | for idx, (id_, rx) in tqdm(enumerate(zip(data['ids'], data['*x'])), total=len(e['yp'])): 38 | if idx >= len(e['yp']): 39 | # for debugging purpose 40 | break 41 | context = shared['p'][rx[0]][rx[1]] 42 | wordss = shared['x'][rx[0]][rx[1]] 43 | yp_list = [e['yp'][idx] for e in e_list] 44 | yp2_list = [e['yp2'][idx] for e in e_list] 45 | 
answer = ensemble3(context, wordss, yp_list, yp2_list) 46 | out[id_] = answer 47 | 48 | with open(args.out, 'w') as fh: 49 | json.dump(out, fh) 50 | 51 | 52 | def ensemble1(context, wordss, y1_list, y2_list): 53 | """ 54 | Combine the models' probability grids element-wise (product by default), then take the best span. 55 | :param context: Original context 56 | :param wordss: tokenized words (nested 2D list) 57 | :param y1_list: list of start index probs (each element corresponds to probs from a single model) 58 | :param y2_list: list of stop index probs 59 | :return: 60 | """ 61 | sum_y1 = combine_y_list(y1_list) 62 | sum_y2 = combine_y_list(y2_list) 63 | span, score = get_best_span(sum_y1, sum_y2) 64 | return get_phrase(context, wordss, span) 65 | 66 | 67 | def ensemble2(context, wordss, y1_list, y2_list): # vote on start and stop indices separately, weighted by each model's probability 68 | start_dict = defaultdict(float) 69 | stop_dict = defaultdict(float) 70 | for y1, y2 in zip(y1_list, y2_list): 71 | span, score = get_best_span(y1, y2) 72 | start_dict[span[0]] += y1[span[0][0]][span[0][1]] 73 | stop_dict[span[1]] += y2[span[1][0]][span[1][1]] 74 | start = max(start_dict.items(), key=lambda pair: pair[1])[0] 75 | stop = max(stop_dict.items(), key=lambda pair: pair[1])[0] 76 | best_span = (start, stop) 77 | return get_phrase(context, wordss, best_span) 78 | 79 | 80 | def ensemble3(context, wordss, y1_list, y2_list): # sum each candidate phrase's span score across models and pick the best phrase 81 | d = defaultdict(float) 82 | for y1, y2 in zip(y1_list, y2_list): 83 | span, score = get_best_span(y1, y2) 84 | phrase = get_phrase(context, wordss, span) 85 | d[phrase] += score 86 | return max(d.items(), key=lambda pair: pair[1])[0] 87 | 88 | 89 | def combine_y_list(y_list, op='*'): # element-wise combination across models: '*' (product), '+' (sum), or a custom callable 90 | if op == '+': 91 | func = sum 92 | elif op == '*': 93 | def func(l): return functools.reduce(mul, l) 94 | else: 95 | func = op 96 | return [[func(yij_list) for yij_list in zip(*yi_list)] for yi_list in zip(*y_list)] 97 | 98 | 99 | def main(): 100 | args = get_args() 101 | ensemble(args) 102 | 103 | if __name__ == "__main__": 104 | main() 105 | 106 | 107 | -------------------------------------------------------------------------------- /basic/ensemble_fast.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | from collections import Counter, defaultdict 4 | import re 5 | 6 | def key_func(pair): 7 | return pair[1] 8 | 9 | 10 | def get_func(vals, probs): # sum each distinct answer's probability across models; empty answers are zeroed out 11 | counter = Counter(vals) 12 | # return max(zip(vals, probs), key=lambda pair: pair[1])[0] 13 | # return max(zip(vals, probs), key=lambda pair: pair[1] * counter[pair[0]] / len(counter) - 999 * (len(pair[0]) == 0) )[0] 14 | # return max(zip(vals, probs), key=lambda pair: pair[1] + 0.7 * counter[pair[0]] / len(counter) - 999 * (len(pair[0]) == 0) )[0] 15 | d = defaultdict(float) 16 | for val, prob in zip(vals, probs): 17 | d[val] += prob 18 | d[''] = 0 19 | return max(d.items(), key=lambda pair: pair[1])[0] 20 | 21 | third_path = sys.argv[1] 22 | other_paths = sys.argv[2:] 23 | 24 | others = [json.load(open(path, 'r')) for path in other_paths] 25 | 26 | 27 | c = {} 28 | 29 | assert min(map(len, others)) == max(map(len, others)), list(map(len, others)) 30 | 31 | for key in others[0].keys(): 32 | if key == 'scores': 33 | continue 34 | probs = [other['scores'][key] for other in others] 35 | vals = [other[key] for other in others] 36 | largest_val = get_func(vals, probs) 37 | c[key] = largest_val 38 | 39 | json.dump(c, open(third_path, 'w')) -------------------------------------------------------------------------------- /basic/graph_handler.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import 
json 3 | from json import encoder 4 | import os 5 | 6 | import tensorflow as tf 7 | 8 | from basic.evaluator import Evaluation, F1Evaluation 9 | from my.utils import short_floats 10 | 11 | import pickle 12 | 13 | 14 | class GraphHandler(object): 15 | def __init__(self, config, model): 16 | self.config = config 17 | self.model = model 18 | self.saver = tf.train.Saver(max_to_keep=config.max_to_keep) 19 | self.writer = None 20 | self.save_path = os.path.join(config.save_dir, config.model_name) 21 | 22 | def initialize(self, sess): 23 | sess.run(tf.initialize_all_variables()) 24 | if self.config.load: 25 | self._load(sess) 26 | 27 | if self.config.mode == 'train': 28 | self.writer = tf.train.SummaryWriter(self.config.log_dir, graph=tf.get_default_graph()) 29 | 30 | def save(self, sess, global_step=None): 31 | saver = tf.train.Saver(max_to_keep=self.config.max_to_keep) 32 | saver.save(sess, self.save_path, global_step=global_step) 33 | 34 | def _load(self, sess): 35 | config = self.config 36 | vars_ = {var.name.split(":")[0]: var for var in tf.all_variables()} 37 | if config.load_ema: 38 | ema = self.model.var_ema 39 | for var in tf.trainable_variables(): 40 | del vars_[var.name.split(":")[0]] 41 | vars_[ema.average_name(var)] = var 42 | saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep) 43 | 44 | if config.load_path: 45 | save_path = config.load_path 46 | elif config.load_step > 0: 47 | save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step)) 48 | else: 49 | save_dir = config.save_dir 50 | checkpoint = tf.train.get_checkpoint_state(save_dir) 51 | assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir) 52 | save_path = checkpoint.model_checkpoint_path 53 | print("Loading saved model from {}".format(save_path)) 54 | saver.restore(sess, save_path) 55 | 56 | def add_summary(self, summary, global_step): 57 | self.writer.add_summary(summary, global_step) 58 | 59 | def add_summaries(self, summaries, global_step): 60 | for summary in summaries: 61 | self.add_summary(summary, global_step) 62 | 63 | def dump_eval(self, e, precision=2, path=None): 64 | assert isinstance(e, Evaluation) 65 | if self.config.dump_pickle: 66 | path = path or os.path.join(self.config.eval_dir, "{}-{}.pklz".format(e.data_type, str(e.global_step).zfill(6))) 67 | with gzip.open(path, 'wb', compresslevel=3) as fh: 68 | pickle.dump(e.dict, fh) 69 | else: 70 | path = path or os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6))) 71 | with open(path, 'w') as fh: 72 | json.dump(short_floats(e.dict, precision), fh) 73 | 74 | def dump_answer(self, e, path=None): 75 | assert isinstance(e, Evaluation) 76 | path = path or os.path.join(self.config.answer_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6))) 77 | with open(path, 'w') as fh: 78 | json.dump(e.id2answer_dict, fh) 79 | 80 | -------------------------------------------------------------------------------- /basic/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | import shutil 6 | from pprint import pprint 7 | 8 | import tensorflow as tf 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from basic.evaluator import ForwardEvaluator, MultiGPUF1Evaluator 13 | from basic.graph_handler import GraphHandler 14 | from basic.model import get_multi_gpu_models 15 | from basic.trainer import MultiGPUTrainer 16 | from basic.read_data import read_data, 
get_squad_data_filter, update_config 17 | 18 | 19 | def main(config): 20 | set_dirs(config) 21 | with tf.device(config.device): 22 | if config.mode == 'train': 23 | _train(config) 24 | elif config.mode == 'test': 25 | _test(config) 26 | elif config.mode == 'forward': 27 | _forward(config) 28 | else: 29 | raise ValueError("invalid value for 'mode': {}".format(config.mode)) 30 | 31 | 32 | def set_dirs(config): 33 | # create directories 34 | assert config.load or config.mode == 'train', "config.load must be True if not training" 35 | if not config.load and os.path.exists(config.out_dir): 36 | shutil.rmtree(config.out_dir) 37 | 38 | config.save_dir = os.path.join(config.out_dir, "save") 39 | config.log_dir = os.path.join(config.out_dir, "log") 40 | config.eval_dir = os.path.join(config.out_dir, "eval") 41 | config.answer_dir = os.path.join(config.out_dir, "answer") 42 | if not os.path.exists(config.out_dir): 43 | os.makedirs(config.out_dir) 44 | if not os.path.exists(config.save_dir): 45 | os.mkdir(config.save_dir) 46 | if not os.path.exists(config.log_dir): 47 | os.mkdir(config.log_dir) 48 | if not os.path.exists(config.answer_dir): 49 | os.mkdir(config.answer_dir) 50 | if not os.path.exists(config.eval_dir): 51 | os.mkdir(config.eval_dir) 52 | 53 | 54 | def _config_debug(config): 55 | if config.debug: 56 | config.num_steps = 2 57 | config.eval_period = 1 58 | config.log_period = 1 59 | config.save_period = 1 60 | config.val_num_batches = 2 61 | config.test_num_batches = 2 62 | 63 | 64 | def _train(config): 65 | data_filter = get_squad_data_filter(config) 66 | train_data = read_data(config, 'train', config.load, data_filter=data_filter) 67 | dev_data = read_data(config, 'dev', True, data_filter=data_filter) 68 | update_config(config, [train_data, dev_data]) 69 | 70 | _config_debug(config) 71 | 72 | word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec'] 73 | word2idx_dict = train_data.shared['word2idx'] 74 | idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict} 75 | emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict 76 | else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size)) 77 | for idx in range(config.word_vocab_size)]) 78 | config.emb_mat = emb_mat 79 | 80 | # construct model graph and variables (using default graph) 81 | pprint(config.__flags, indent=2) 82 | models = get_multi_gpu_models(config) 83 | model = models[0] 84 | trainer = MultiGPUTrainer(config, models) 85 | evaluator = MultiGPUF1Evaluator(config, models, tensor_dict=model.tensor_dict if config.vis else None) 86 | graph_handler = GraphHandler(config, model) # controls all tensors and variables in the graph, including loading /saving 87 | 88 | # Variables 89 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 90 | graph_handler.initialize(sess) 91 | 92 | # Begin training 93 | num_steps = config.num_steps or int(math.ceil(train_data.num_examples / (config.batch_size * config.num_gpus))) * config.num_epochs 94 | global_step = 0 95 | for batches in tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus, 96 | num_steps=num_steps, shuffle=True, cluster=config.cluster), total=num_steps): 97 | global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after step 98 | get_summary = global_step % config.log_period == 0 99 | loss, summary, train_op = trainer.step(sess, batches, get_summary=get_summary) 100 | if get_summary: 101 
| graph_handler.add_summary(summary, global_step) 102 | 103 | # occasional saving 104 | if global_step % config.save_period == 0: 105 | graph_handler.save(sess, global_step=global_step) 106 | 107 | if not config.eval: 108 | continue 109 | # Occasional evaluation 110 | if global_step % config.eval_period == 0: 111 | num_steps = math.ceil(dev_data.num_examples / (config.batch_size * config.num_gpus)) 112 | if 0 < config.val_num_batches < num_steps: 113 | num_steps = config.val_num_batches 114 | e_train = evaluator.get_evaluation_from_batches( 115 | sess, tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps) 116 | ) 117 | graph_handler.add_summaries(e_train.summaries, global_step) 118 | e_dev = evaluator.get_evaluation_from_batches( 119 | sess, tqdm(dev_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps)) 120 | graph_handler.add_summaries(e_dev.summaries, global_step) 121 | 122 | if config.dump_eval: 123 | graph_handler.dump_eval(e_dev) 124 | if config.dump_answer: 125 | graph_handler.dump_answer(e_dev) 126 | if global_step % config.save_period != 0: 127 | graph_handler.save(sess, global_step=global_step) 128 | 129 | 130 | def _test(config): 131 | test_data = read_data(config, 'test', True) 132 | update_config(config, [test_data]) 133 | 134 | _config_debug(config) 135 | 136 | if config.use_glove_for_unk: 137 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec'] 138 | new_word2idx_dict = test_data.shared['new_word2idx'] 139 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()} 140 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32') 141 | config.new_emb_mat = new_emb_mat 142 | 143 | pprint(config.__flags, indent=2) 144 | models = get_multi_gpu_models(config) 145 | model = models[0] 146 | evaluator = MultiGPUF1Evaluator(config, models, tensor_dict=models[0].tensor_dict if config.vis else None) 147 | graph_handler = GraphHandler(config, model) 148 | 149 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 150 | graph_handler.initialize(sess) 151 | num_steps = math.ceil(test_data.num_examples / (config.batch_size * config.num_gpus)) 152 | if 0 < config.test_num_batches < num_steps: 153 | num_steps = config.test_num_batches 154 | 155 | e = None 156 | for multi_batch in tqdm(test_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps, cluster=config.cluster), total=num_steps): 157 | ei = evaluator.get_evaluation(sess, multi_batch) 158 | e = ei if e is None else e + ei 159 | if config.vis: 160 | eval_subdir = os.path.join(config.eval_dir, "{}-{}".format(ei.data_type, str(ei.global_step).zfill(6))) 161 | if not os.path.exists(eval_subdir): 162 | os.mkdir(eval_subdir) 163 | path = os.path.join(eval_subdir, str(ei.idxs[0]).zfill(8)) 164 | graph_handler.dump_eval(ei, path=path) 165 | 166 | print(e) 167 | if config.dump_answer: 168 | print("dumping answer ...") 169 | graph_handler.dump_answer(e) 170 | if config.dump_eval: 171 | print("dumping eval ...") 172 | graph_handler.dump_eval(e) 173 | 174 | 175 | def _forward(config): 176 | assert config.load 177 | test_data = read_data(config, config.forward_name, True) 178 | update_config(config, [test_data]) 179 | 180 | _config_debug(config) 181 | 182 | if config.use_glove_for_unk: 183 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec'] 
184 | new_word2idx_dict = test_data.shared['new_word2idx'] 185 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()} 186 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32') 187 | config.new_emb_mat = new_emb_mat 188 | 189 | pprint(config.__flags, indent=2) 190 | models = get_multi_gpu_models(config) 191 | model = models[0] 192 | evaluator = ForwardEvaluator(config, model) 193 | graph_handler = GraphHandler(config, model) # controls all tensors and variables in the graph, including loading /saving 194 | 195 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 196 | graph_handler.initialize(sess) 197 | 198 | num_batches = math.ceil(test_data.num_examples / config.batch_size) 199 | if 0 < config.test_num_batches < num_batches: 200 | num_batches = config.test_num_batches 201 | e = evaluator.get_evaluation_from_batches(sess, tqdm(test_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches)) 202 | print(e) 203 | if config.dump_answer: 204 | print("dumping answer ...") 205 | graph_handler.dump_answer(e, path=config.answer_path) 206 | if config.dump_eval: 207 | print("dumping eval ...") 208 | graph_handler.dump_eval(e, path=config.eval_path) 209 | 210 | 211 | def _get_args(): 212 | parser = argparse.ArgumentParser() 213 | parser.add_argument("config_path") 214 | return parser.parse_args() 215 | 216 | 217 | class Config(object): 218 | def __init__(self, **entries): 219 | self.__dict__.update(entries) 220 | 221 | 222 | def _run(): 223 | args = _get_args() 224 | with open(args.config_path, 'r') as fh: 225 | config = Config(**json.load(fh)) 226 | main(config) 227 | 228 | 229 | if __name__ == "__main__": 230 | _run() 231 | -------------------------------------------------------------------------------- /basic/run_ensemble.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source_path=$1 3 | target_path=$2 4 | inter_dir="inter_ensemble" 5 | root_dir="save" 6 | 7 | parg="" 8 | marg="" 9 | if [ "$3" = "debug" ] 10 | then 11 | parg="-d" 12 | marg="--debug" 13 | fi 14 | 15 | # Preprocess data 16 | python3 -m squad.prepro --mode single --single_path $source_path $parg --target_dir $inter_dir --glove_dir . 17 | 18 | eargs="" 19 | for num in 31 33 34 35 36 37 40 41 43 44 45 46; do 20 | load_path="$root_dir/$num/save" 21 | shared_path="$root_dir/$num/shared.json" 22 | eval_path="$inter_dir/eval-$num.json" 23 | eargs="$eargs $eval_path" 24 | python3 -m basic.cli --data_dir $inter_dir --eval_path $eval_path --nodump_answer --load_path $load_path --shared_path $shared_path $marg --eval_num_batches 0 --mode forward --batch_size 1 --len_opt --cluster --cpu_opt --load_ema & 25 | done 26 | wait 27 | 28 | # Ensemble 29 | python3 -m basic.ensemble --data_path $inter_dir/data_single.json --shared_path $inter_dir/shared_single.json -o $target_path $eargs 30 | -------------------------------------------------------------------------------- /basic/run_single.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | source_path=$1 3 | target_path=$2 4 | inter_dir="inter_single" 5 | root_dir="save" 6 | 7 | parg="" 8 | marg="" 9 | if [ "$3" = "debug" ] 10 | then 11 | parg="-d" 12 | marg="--debug" 13 | fi 14 | 15 | # Preprocess data 16 | python3 -m squad.prepro --mode single --single_path $source_path $parg --target_dir $inter_dir --glove_dir . 
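# Note: squad.prepro --mode single writes the data_single.json / shared_single.json consumed by the forward pass and basic.ensemble below.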
17 | 18 | num=37 19 | load_path="$root_dir/$num/save" 20 | shared_path="$root_dir/$num/shared.json" 21 | eval_path="$inter_dir/eval.json" 22 | python3 -m basic.cli --data_dir $inter_dir --eval_path $eval_path --nodump_answer --load_path $load_path --shared_path $shared_path $marg --eval_num_batches 0 --mode forward --batch_size 1 --len_opt --cluster --cpu_opt --load_ema 23 | 24 | # Ensemble (for single run, just one input) 25 | python3 -m basic.ensemble --data_path $inter_dir/data_single.json --shared_path $inter_dir/shared_single.json -o $target_path $eval_path 26 | 27 | 28 | -------------------------------------------------------------------------------- /basic/templates/visualizer.html: -------------------------------------------------------------------------------- [Jinja2 template whose HTML markup was lost in this dump; only the template expressions survive. Recoverable structure: a page titled {{ title }} rendering a table with columns ID, Question, Answers, Predicted, Score, and Paragraph; each row shows {{ row.id }}, the question tokens ({% for qj in row.ques %}), the gold answers ({% for aa in row.a %}), the predicted phrase {{ row.ap }} and its {{ row.score }}, and the paragraph tokens shaded by the start (row.yp) and stop (row.yp2) probabilities, with the gold span ({% if row.y[0][0] == rowloop.index0 and row.y[0][1] <= loop.index0 <= row.y[1][1] %}) highlighted.]
    75 | 76 | -------------------------------------------------------------------------------- /basic/trainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from basic.model import Model 4 | from my.tensorflow import average_gradients 5 | 6 | 7 | class Trainer(object): 8 | def __init__(self, config, model): 9 | assert isinstance(model, Model) 10 | self.config = config 11 | self.model = model 12 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr) 13 | self.loss = model.get_loss() 14 | self.var_list = model.get_var_list() 15 | self.global_step = model.get_global_step() 16 | self.summary = model.summary 17 | self.grads = self.opt.compute_gradients(self.loss, var_list=self.var_list) 18 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step) 19 | 20 | def get_train_op(self): 21 | return self.train_op 22 | 23 | def step(self, sess, batch, get_summary=False): 24 | assert isinstance(sess, tf.Session) 25 | _, ds = batch 26 | feed_dict = self.model.get_feed_dict(ds, True) 27 | if get_summary: 28 | loss, summary, train_op = \ 29 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict) 30 | else: 31 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict) 32 | summary = None 33 | return loss, summary, train_op 34 | 35 | 36 | class MultiGPUTrainer(object): 37 | def __init__(self, config, models): 38 | model = models[0] 39 | assert isinstance(model, Model) 40 | self.config = config 41 | self.model = model 42 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr) 43 | self.var_list = model.get_var_list() 44 | self.global_step = model.get_global_step() 45 | self.summary = model.summary 46 | self.models = models 47 | losses = [] 48 | grads_list = [] 49 | for gpu_idx, model in enumerate(models): 50 | with tf.name_scope("grads_{}".format(gpu_idx)), tf.device("/{}:{}".format(config.device_type, gpu_idx)): 51 | loss = model.get_loss() 52 | grads = self.opt.compute_gradients(loss, var_list=self.var_list) 53 | losses.append(loss) 54 | grads_list.append(grads) 55 | 56 | self.loss = tf.add_n(losses)/len(losses) 57 | self.grads = average_gradients(grads_list) 58 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step) 59 | 60 | def step(self, sess, batches, get_summary=False): 61 | assert isinstance(sess, tf.Session) 62 | feed_dict = {} 63 | for batch, model in zip(batches, self.models): 64 | _, ds = batch 65 | feed_dict.update(model.get_feed_dict(ds, True)) 66 | 67 | if get_summary: 68 | loss, summary, train_op = \ 69 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict) 70 | else: 71 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict) 72 | summary = None 73 | return loss, summary, train_op 74 | -------------------------------------------------------------------------------- /basic/visualizer.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from collections import OrderedDict 3 | import http.server 4 | import socketserver 5 | import argparse 6 | import json 7 | import os 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from jinja2 import Environment, FileSystemLoader 12 | 13 | from basic.evaluator import get_span_score_pairs 14 | from squad.utils import get_best_span, get_span_score_pairs 15 | 16 | 17 | def bool_(string): 18 | if string == 'True': 19 | return True 20 | elif string == 'False': 21 | return False 22 | else: 23 | raise 
Exception() 24 | 25 | def get_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--model_name", type=str, default='basic') 28 | parser.add_argument("--data_type", type=str, default='dev') 29 | parser.add_argument("--step", type=int, default=5000) 30 | parser.add_argument("--template_name", type=str, default="visualizer.html") 31 | parser.add_argument("--num_per_page", type=int, default=100) 32 | parser.add_argument("--data_dir", type=str, default="data/squad") 33 | parser.add_argument("--port", type=int, default=8000) 34 | parser.add_argument("--host", type=str, default="0.0.0.0") 35 | parser.add_argument("--open", type=str, default='False') 36 | parser.add_argument("--run_id", type=str, default="0") 37 | 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def _decode(decoder, sent): 43 | return " ".join(decoder[idx] for idx in sent) 44 | 45 | 46 | def accuracy2_visualizer(args): 47 | model_name = args.model_name 48 | data_type = args.data_type 49 | num_per_page = args.num_per_page 50 | data_dir = args.data_dir 51 | run_id = args.run_id.zfill(2) 52 | step = args.step 53 | 54 | eval_path =os.path.join("out", model_name, run_id, "eval", "{}-{}.json".format(data_type, str(step).zfill(6))) 55 | print("loading {}".format(eval_path)) 56 | eval_ = json.load(open(eval_path, 'r')) 57 | 58 | _id = 0 59 | html_dir = "/tmp/list_results%d" % _id 60 | while os.path.exists(html_dir): 61 | _id += 1 62 | html_dir = "/tmp/list_results%d" % _id 63 | 64 | if os.path.exists(html_dir): 65 | shutil.rmtree(html_dir) 66 | os.mkdir(html_dir) 67 | 68 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 69 | templates_dir = os.path.join(cur_dir, 'templates') 70 | env = Environment(loader=FileSystemLoader(templates_dir)) 71 | env.globals.update(zip=zip, reversed=reversed) 72 | template = env.get_template(args.template_name) 73 | 74 | data_path = os.path.join(data_dir, "data_{}.json".format(data_type)) 75 | shared_path = os.path.join(data_dir, "shared_{}.json".format(data_type)) 76 | print("loading {}".format(data_path)) 77 | data = json.load(open(data_path, 'r')) 78 | print("loading {}".format(shared_path)) 79 | shared = json.load(open(shared_path, 'r')) 80 | 81 | rows = [] 82 | for i, (idx, yi, ypi, yp2i) in tqdm(enumerate(zip(*[eval_[key] for key in ('idxs', 'y', 'yp', 'yp2')])), total=len(eval_['idxs'])): 83 | id_, q, rx, answers = (data[key][idx] for key in ('ids', 'q', '*x', 'answerss')) 84 | x = shared['x'][rx[0]][rx[1]] 85 | ques = [" ".join(q)] 86 | para = [[word for word in sent] for sent in x] 87 | span = get_best_span(ypi, yp2i) 88 | ap = get_segment(para, span) 89 | score = "{:.3f}".format(ypi[span[0][0]][span[0][1]] * yp2i[span[1][0]][span[1][1]-1]) 90 | 91 | row = { 92 | 'id': id_, 93 | 'title': "Hello world!", 94 | 'ques': ques, 95 | 'para': para, 96 | 'y': yi[0][0], 97 | 'y2': yi[0][1], 98 | 'yp': ypi, 99 | 'yp2': yp2i, 100 | 'a': answers, 101 | 'ap': ap, 102 | 'score': score 103 | } 104 | rows.append(row) 105 | 106 | if i % num_per_page == 0: 107 | html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8)) 108 | 109 | if (i + 1) % num_per_page == 0 or (i + 1) == len(eval_['y']): 110 | var_dict = {'title': "Accuracy Visualization", 111 | 'rows': rows 112 | } 113 | with open(html_path, "wb") as f: 114 | f.write(template.render(**var_dict).encode('UTF-8')) 115 | rows = [] 116 | 117 | os.chdir(html_dir) 118 | port = args.port 119 | host = args.host 120 | # Overriding to suppress log message 121 | class MyHandler(http.server.SimpleHTTPRequestHandler): 122 | def 
log_message(self, format, *args): 123 | pass 124 | handler = MyHandler 125 | httpd = socketserver.TCPServer((host, port), handler) 126 | if args.open == 'True': 127 | os.system("open http://%s:%d" % (args.host, args.port)) 128 | print("serving at %s:%d" % (host, port)) 129 | httpd.serve_forever() 130 | 131 | 132 | def get_segment(para, span): 133 | return " ".join(para[span[0][0]][span[0][1]:span[1][1]]) 134 | 135 | 136 | if __name__ == "__main__": 137 | ARGS = get_args() 138 | accuracy2_visualizer(ARGS) -------------------------------------------------------------------------------- /basic_cnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/basic_cnn/__init__.py -------------------------------------------------------------------------------- /basic_cnn/cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | from basic_cnn.main import main as m 6 | 7 | flags = tf.app.flags 8 | 9 | flags.DEFINE_string("model_name", "basic_cnn", "Model name [basic]") 10 | flags.DEFINE_string("data_dir", "data/cnn", "Data dir [data/cnn]") 11 | flags.DEFINE_string("root_dir", "/Users/minjoons/data/cnn/questions", "root dir [~/data/cnn/questions]") 12 | flags.DEFINE_string("run_id", "0", "Run ID [0]") 13 | flags.DEFINE_string("out_base_dir", "out", "out base dir [out]") 14 | 15 | flags.DEFINE_integer("batch_size", 60, "Batch size [60]") 16 | flags.DEFINE_float("init_lr", 0.5, "Initial learning rate [0.5]") 17 | flags.DEFINE_integer("num_epochs", 50, "Total number of epochs for training [50]") 18 | flags.DEFINE_integer("num_steps", 20000, "Number of steps [20000]") 19 | flags.DEFINE_integer("eval_num_batches", 100, "eval num batches [100]") 20 | flags.DEFINE_integer("load_step", 0, "load step [0]") 21 | flags.DEFINE_integer("early_stop", 4, "early stop [4]") 22 | 23 | flags.DEFINE_string("mode", "test", "train | dev | test | forward [test]") 24 | flags.DEFINE_boolean("load", True, "load saved data? [True]") 25 | flags.DEFINE_boolean("progress", True, "Show progress? [True]") 26 | flags.DEFINE_integer("log_period", 100, "Log period [100]") 27 | flags.DEFINE_integer("eval_period", 1000, "Eval period [1000]") 28 | flags.DEFINE_integer("save_period", 1000, "Save Period [1000]") 29 | flags.DEFINE_float("decay", 0.9, "Exponential moving average decay [0.9]") 30 | 31 | flags.DEFINE_boolean("draft", False, "Draft for quick testing? [False]") 32 | 33 | flags.DEFINE_integer("hidden_size", 100, "Hidden size [100]") 34 | flags.DEFINE_integer("char_out_size", 100, "Char out size [100]") 35 | flags.DEFINE_float("input_keep_prob", 0.8, "Input keep prob [0.8]") 36 | flags.DEFINE_integer("char_emb_size", 8, "Char emb size [8]") 37 | flags.DEFINE_integer("char_filter_height", 5, "Char filter height [5]") 38 | flags.DEFINE_float("wd", 0.0, "Weight decay [0.0]") 39 | flags.DEFINE_bool("lower_word", True, "lower word [True]") 40 | flags.DEFINE_bool("dump_eval", False, "dump eval? [True]") 41 | flags.DEFINE_bool("dump_answer", True, "dump answer? [True]") 42 | flags.DEFINE_string("model", "2", "config 1 |2 [2]") 43 | flags.DEFINE_bool("squash", False, "squash the sentences into one? [False]") 44 | flags.DEFINE_bool("single", False, "supervise only the answer sentence? 
[False]") 45 | 46 | flags.DEFINE_integer("word_count_th", 10, "word count th [100]") 47 | flags.DEFINE_integer("char_count_th", 50, "char count th [500]") 48 | flags.DEFINE_integer("sent_size_th", 60, "sent size th [64]") 49 | flags.DEFINE_integer("num_sents_th", 200, "num sents th [8]") 50 | flags.DEFINE_integer("ques_size_th", 30, "ques size th [32]") 51 | flags.DEFINE_integer("word_size_th", 16, "word size th [16]") 52 | flags.DEFINE_integer("para_size_th", 256, "para size th [256]") 53 | 54 | flags.DEFINE_bool("swap_memory", True, "swap memory? [True]") 55 | flags.DEFINE_string("data_filter", "max", "max | valid | semi [max]") 56 | flags.DEFINE_bool("finetune", False, "finetune? [False]") 57 | flags.DEFINE_bool("feed_gt", False, "feed gt prev token during training [False]") 58 | flags.DEFINE_bool("feed_hard", False, "feed hard argmax prev token during testing [False]") 59 | flags.DEFINE_bool("use_glove_for_unk", True, "use glove for unk [False]") 60 | flags.DEFINE_bool("known_if_glove", True, "consider as known if present in glove [False]") 61 | flags.DEFINE_bool("eval", True, "eval? [True]") 62 | flags.DEFINE_integer("highway_num_layers", 2, "highway num layers [2]") 63 | flags.DEFINE_bool("use_word_emb", True, "use word embedding? [True]") 64 | 65 | flags.DEFINE_string("forward_name", "single", "Forward name [single]") 66 | flags.DEFINE_string("answer_path", "", "Answer path []") 67 | flags.DEFINE_string("load_path", "", "Load path []") 68 | flags.DEFINE_string("shared_path", "", "Shared path []") 69 | flags.DEFINE_string("device", "/cpu:0", "default device [/cpu:0]") 70 | flags.DEFINE_integer("num_gpus", 1, "num of gpus [1]") 71 | 72 | flags.DEFINE_string("out_channel_dims", "100", "Out channel dims, separated by commas [100]") 73 | flags.DEFINE_string("filter_heights", "5", "Filter heights, separated by commas [5]") 74 | 75 | flags.DEFINE_bool("share_cnn_weights", True, "Share CNN weights [False]") 76 | flags.DEFINE_bool("share_lstm_weights", True, "Share LSTM weights [True]") 77 | flags.DEFINE_bool("two_prepro_layers", False, "Use two layers for preprocessing? [False]") 78 | flags.DEFINE_bool("aug_att", False, "Augment attention layers with more features? [False]") 79 | flags.DEFINE_integer("max_to_keep", 20, "Max recent saves to keep [20]") 80 | flags.DEFINE_bool("vis", False, "output visualization numbers? [False]") 81 | flags.DEFINE_bool("dump_pickle", True, "Dump pickle instead of json? [True]") 82 | flags.DEFINE_float("keep_prob", 1.0, "keep prob [1.0]") 83 | flags.DEFINE_string("prev_mode", "a", "prev mode gy | y | a [a]") 84 | flags.DEFINE_string("logit_func", "tri_linear", "logit func [tri_linear]") 85 | flags.DEFINE_bool("sh", False, "use superhighway [False]") 86 | flags.DEFINE_string("answer_func", "linear", "answer logit func [linear]") 87 | flags.DEFINE_bool("cluster", False, "Cluster data for faster training [False]") 88 | flags.DEFINE_bool("len_opt", False, "Length optimization? [False]") 89 | flags.DEFINE_string("sh_logit_func", "tri_linear", "sh logit func [tri_linear]") 90 | flags.DEFINE_float("filter_ratio", 1.0, "filter ratio [1.0]") 91 | flags.DEFINE_bool("bi", False, "bi-directional attention? 
[False]") 92 | flags.DEFINE_integer("width", 5, "width around entity [5]") 93 | 94 | 95 | def main(_): 96 | config = flags.FLAGS 97 | 98 | config.out_dir = os.path.join(config.out_base_dir, config.model_name, str(config.run_id).zfill(2)) 99 | 100 | m(config) 101 | 102 | if __name__ == "__main__": 103 | tf.app.run() 104 | -------------------------------------------------------------------------------- /basic_cnn/graph_handler.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | from json import encoder 4 | import os 5 | 6 | import tensorflow as tf 7 | 8 | from basic_cnn.evaluator import Evaluation, F1Evaluation 9 | from my.utils import short_floats 10 | 11 | import pickle 12 | 13 | 14 | class GraphHandler(object): 15 | def __init__(self, config): 16 | self.config = config 17 | self.saver = tf.train.Saver(max_to_keep=config.max_to_keep) 18 | self.writer = None 19 | self.save_path = os.path.join(config.save_dir, config.model_name) 20 | 21 | def initialize(self, sess): 22 | if self.config.load: 23 | self._load(sess) 24 | else: 25 | sess.run(tf.initialize_all_variables()) 26 | 27 | if self.config.mode == 'train': 28 | self.writer = tf.train.SummaryWriter(self.config.log_dir, graph=tf.get_default_graph()) 29 | 30 | def save(self, sess, global_step=None): 31 | self.saver.save(sess, self.save_path, global_step=global_step) 32 | 33 | def _load(self, sess): 34 | config = self.config 35 | if config.load_path: 36 | save_path = config.load_path 37 | elif config.load_step > 0: 38 | save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step)) 39 | else: 40 | save_dir = config.save_dir 41 | checkpoint = tf.train.get_checkpoint_state(save_dir) 42 | assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir) 43 | save_path = checkpoint.model_checkpoint_path 44 | print("Loading saved model from {}".format(save_path)) 45 | self.saver.restore(sess, save_path) 46 | 47 | def add_summary(self, summary, global_step): 48 | self.writer.add_summary(summary, global_step) 49 | 50 | def add_summaries(self, summaries, global_step): 51 | for summary in summaries: 52 | self.add_summary(summary, global_step) 53 | 54 | def dump_eval(self, e, precision=2, path=None): 55 | assert isinstance(e, Evaluation) 56 | if self.config.dump_pickle: 57 | path = path or os.path.join(self.config.eval_dir, "{}-{}.pklz".format(e.data_type, str(e.global_step).zfill(6))) 58 | with gzip.open(path, 'wb', compresslevel=3) as fh: 59 | pickle.dump(e.dict, fh) 60 | else: 61 | path = path or os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6))) 62 | with open(path, 'w') as fh: 63 | json.dump(short_floats(e.dict, precision), fh) 64 | 65 | def dump_answer(self, e, path=None): 66 | assert isinstance(e, Evaluation) 67 | path = path or os.path.join(self.config.answer_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6))) 68 | with open(path, 'w') as fh: 69 | json.dump(e.id2answer_dict, fh) 70 | 71 | -------------------------------------------------------------------------------- /basic_cnn/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | import shutil 6 | from pprint import pprint 7 | 8 | import tensorflow as tf 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from basic_cnn.evaluator import F1Evaluator, Evaluator, ForwardEvaluator, MultiGPUF1Evaluator, 
CNNAccuracyEvaluator, \ 13 | MultiGPUCNNAccuracyEvaluator 14 | from basic_cnn.graph_handler import GraphHandler 15 | from basic_cnn.model import Model, get_multi_gpu_models 16 | from basic_cnn.trainer import Trainer, MultiGPUTrainer 17 | 18 | from basic_cnn.read_data import read_data, get_cnn_data_filter, update_config 19 | 20 | 21 | def main(config): 22 | set_dirs(config) 23 | with tf.device(config.device): 24 | if config.mode == 'train': 25 | _train(config) 26 | elif config.mode == 'test' or config.mode == 'dev': 27 | _test(config) 28 | elif config.mode == 'forward': 29 | _forward(config) 30 | else: 31 | raise ValueError("invalid value for 'mode': {}".format(config.mode)) 32 | 33 | 34 | def _config_draft(config): 35 | if config.draft: 36 | config.num_steps = 2 37 | config.eval_period = 1 38 | config.log_period = 1 39 | config.save_period = 1 40 | config.eval_num_batches = 1 41 | 42 | 43 | def _train(config): 44 | # load_metadata(config, 'train') # this updates the config file according to metadata file 45 | 46 | data_filter = get_cnn_data_filter(config) 47 | train_data = read_data(config, 'train', config.load, data_filter=data_filter) 48 | dev_data = read_data(config, 'dev', True, data_filter=data_filter) 49 | # test_data = read_data(config, 'test', True, data_filter=data_filter) 50 | update_config(config, [train_data, dev_data]) 51 | 52 | _config_draft(config) 53 | 54 | word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec'] 55 | word2idx_dict = train_data.shared['word2idx'] 56 | idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict} 57 | print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict))) 58 | emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict 59 | else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size)) 60 | for idx in range(config.word_vocab_size)]) 61 | config.emb_mat = emb_mat 62 | 63 | # construct model graph and variables (using default graph) 64 | pprint(config.__flags, indent=2) 65 | # model = Model(config) 66 | models = get_multi_gpu_models(config) 67 | model = models[0] 68 | trainer = MultiGPUTrainer(config, models) 69 | evaluator = MultiGPUCNNAccuracyEvaluator(config, models, tensor_dict=model.tensor_dict if config.vis else None) 70 | graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading /saving 71 | 72 | # Variables 73 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 74 | graph_handler.initialize(sess) 75 | 76 | # begin training 77 | print(train_data.num_examples) 78 | num_steps = config.num_steps or int(math.ceil(train_data.num_examples / (config.batch_size * config.num_gpus))) * config.num_epochs 79 | global_step = 0 80 | for batches in tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus, 81 | num_steps=num_steps, shuffle=True, cluster=config.cluster), total=num_steps): 82 | global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after step 83 | get_summary = global_step % config.log_period == 0 84 | loss, summary, train_op = trainer.step(sess, batches, get_summary=get_summary) 85 | if get_summary: 86 | graph_handler.add_summary(summary, global_step) 87 | 88 | # occasional saving 89 | if global_step % config.save_period == 0: 90 | graph_handler.save(sess, global_step=global_step) 91 | 92 | if not config.eval: 93 | continue 94 | # Occasional 
evaluation 95 | if global_step % config.eval_period == 0: 96 | num_steps = math.ceil(dev_data.num_examples / (config.batch_size * config.num_gpus)) 97 | if 0 < config.eval_num_batches < num_steps: 98 | num_steps = config.eval_num_batches 99 | e_train = evaluator.get_evaluation_from_batches( 100 | sess, tqdm(train_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps) 101 | ) 102 | graph_handler.add_summaries(e_train.summaries, global_step) 103 | e_dev = evaluator.get_evaluation_from_batches( 104 | sess, tqdm(dev_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps), total=num_steps)) 105 | graph_handler.add_summaries(e_dev.summaries, global_step) 106 | 107 | if config.dump_eval: 108 | graph_handler.dump_eval(e_dev) 109 | if config.dump_answer: 110 | graph_handler.dump_answer(e_dev) 111 | if global_step % config.save_period != 0: 112 | graph_handler.save(sess, global_step=global_step) 113 | 114 | 115 | def _test(config): 116 | assert config.load 117 | test_data = read_data(config, config.mode, True) 118 | update_config(config, [test_data]) 119 | 120 | _config_draft(config) 121 | 122 | if config.use_glove_for_unk: 123 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec'] 124 | new_word2idx_dict = test_data.shared['new_word2idx'] 125 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()} 126 | # print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict))) 127 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32') 128 | config.new_emb_mat = new_emb_mat 129 | 130 | pprint(config.__flags, indent=2) 131 | models = get_multi_gpu_models(config) 132 | evaluator = MultiGPUCNNAccuracyEvaluator(config, models, tensor_dict=models[0].tensor_dict if config.vis else None) 133 | graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading /saving 134 | 135 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 136 | graph_handler.initialize(sess) 137 | num_steps = math.ceil(test_data.num_examples / (config.batch_size * config.num_gpus)) 138 | if 0 < config.eval_num_batches < num_steps: 139 | num_steps = config.eval_num_batches 140 | 141 | e = None 142 | for multi_batch in tqdm(test_data.get_multi_batches(config.batch_size, config.num_gpus, num_steps=num_steps, cluster=config.cluster), total=num_steps): 143 | ei = evaluator.get_evaluation(sess, multi_batch) 144 | e = ei if e is None else e + ei 145 | if config.vis: 146 | eval_subdir = os.path.join(config.eval_dir, "{}-{}".format(ei.data_type, str(ei.global_step).zfill(6))) 147 | if not os.path.exists(eval_subdir): 148 | os.mkdir(eval_subdir) 149 | path = os.path.join(eval_subdir, str(ei.idxs[0]).zfill(8)) 150 | graph_handler.dump_eval(ei, path=path) 151 | 152 | print(e) 153 | if config.dump_answer: 154 | print("dumping answer ...") 155 | graph_handler.dump_answer(e) 156 | if config.dump_eval: 157 | print("dumping eval ...") 158 | graph_handler.dump_eval(e) 159 | 160 | 161 | def _forward(config): 162 | assert config.load 163 | test_data = read_data(config, config.forward_name, True) 164 | update_config(config, [test_data]) 165 | 166 | _config_draft(config) 167 | 168 | if config.use_glove_for_unk: 169 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec'] 170 | new_word2idx_dict = 
161 | def _forward(config): 162 | assert config.load 163 | test_data = read_data(config, config.forward_name, True) 164 | update_config(config, [test_data]) 165 | 166 | _config_draft(config) 167 | 168 | if config.use_glove_for_unk: 169 | word2vec_dict = test_data.shared['lower_word2vec'] if config.lower_word else test_data.shared['word2vec'] 170 | new_word2idx_dict = test_data.shared['new_word2idx'] 171 | idx2vec_dict = {idx: word2vec_dict[word] for word, idx in new_word2idx_dict.items()} 172 | # print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict))) 173 | new_emb_mat = np.array([idx2vec_dict[idx] for idx in range(len(idx2vec_dict))], dtype='float32') 174 | config.new_emb_mat = new_emb_mat 175 | 176 | pprint(config.__flags, indent=2) 177 | models = get_multi_gpu_models(config) 178 | model = models[0] 179 | evaluator = ForwardEvaluator(config, model) 180 | graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading/saving 181 | 182 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 183 | graph_handler.initialize(sess) 184 | 185 | num_batches = math.ceil(test_data.num_examples / config.batch_size) 186 | if 0 < config.eval_num_batches < num_batches: 187 | num_batches = config.eval_num_batches 188 | e = evaluator.get_evaluation_from_batches(sess, tqdm(test_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches)) 189 | print(e) 190 | if config.dump_answer: 191 | print("dumping answer ...") 192 | graph_handler.dump_answer(e, path=config.answer_path) 193 | if config.dump_eval: 194 | print("dumping eval ...") 195 | graph_handler.dump_eval(e) 196 | 197 | 198 | def set_dirs(config): 199 | # create directories 200 | if not config.load and os.path.exists(config.out_dir): 201 | shutil.rmtree(config.out_dir) 202 | 203 | config.save_dir = os.path.join(config.out_dir, "save") 204 | config.log_dir = os.path.join(config.out_dir, "log") 205 | config.eval_dir = os.path.join(config.out_dir, "eval") 206 | config.answer_dir = os.path.join(config.out_dir, "answer") 207 | if not os.path.exists(config.out_dir): 208 | os.makedirs(config.out_dir) 209 | if not os.path.exists(config.save_dir): 210 | os.mkdir(config.save_dir) 211 | if not os.path.exists(config.log_dir): 212 | os.mkdir(config.log_dir) 213 | if not os.path.exists(config.answer_dir): 214 | os.mkdir(config.answer_dir) 215 | if not os.path.exists(config.eval_dir): 216 | os.mkdir(config.eval_dir) 217 | 218 | 219 | def _get_args(): 220 | parser = argparse.ArgumentParser() 221 | parser.add_argument("config_path") 222 | return parser.parse_args() 223 | 224 | 225 | class Config(object): 226 | def __init__(self, **entries): 227 | self.__dict__.update(entries) 228 | 229 | 230 | def _run(): 231 | args = _get_args() 232 | with open(args.config_path, 'r') as fh: 233 | config = Config(**json.load(fh)) 234 | main(config) 235 | 236 | 237 | if __name__ == "__main__": 238 | _run() 239 |
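Note that `basic_cnn/main.py` above is normally driven by `basic_cnn/cli.py`, which builds the flags; `_run` itself just loads a JSON file whose keys become `Config` attributes. A minimal sketch of that entry point (the file name and flag values are illustrative, and a real config must contain every flag the code reads, e.g. `out_dir`, `batch_size`, `num_gpus`):

```python
# Hypothetical sketch: hand-writing a config for basic_cnn.main.
import json

config = {"mode": "train", "device": "/gpu:0", "load": False, "draft": True}
with open("config.json", "w") as fh:
    json.dump(config, fh)
# then: python -m basic_cnn.main config.json
```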
-------------------------------------------------------------------------------- /basic_cnn/superhighway.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops.rnn_cell import RNNCell 3 | 4 | from my.tensorflow.nn import linear 5 | 6 | 7 | class SHCell(RNNCell): 8 | """ 9 | Super-Highway Cell 10 | """ 11 | def __init__(self, input_size, logit_func='tri_linear', scalar=False): 12 | self._state_size = input_size 13 | self._output_size = input_size 14 | self._logit_func = logit_func 15 | self._scalar = scalar 16 | 17 | @property 18 | def state_size(self): 19 | return self._state_size 20 | 21 | @property 22 | def output_size(self): 23 | return self._output_size 24 | 25 | def __call__(self, inputs, state, scope=None): 26 | with tf.variable_scope(scope or "SHCell"): 27 | a_size = 1 if self._scalar else self._state_size 28 | h, u = tf.split(1, 2, inputs) 29 | if self._logit_func == 'mul_linear': 30 | args = [h * u, state * u] 31 | a = tf.nn.sigmoid(linear(args, a_size, True)) 32 | elif self._logit_func == 'linear': 33 | args = [h, u, state] 34 | a = tf.nn.sigmoid(linear(args, a_size, True)) 35 | elif self._logit_func == 'tri_linear': 36 | args = [h, u, state, h * u, state * u] 37 | a = tf.nn.sigmoid(linear(args, a_size, True)) 38 | elif self._logit_func == 'double': 39 | args = [h, u, state] 40 | a = tf.nn.sigmoid(linear(tf.tanh(linear(args, a_size, True)), self._state_size, True)) 41 | 42 | else: 43 | raise Exception("unknown logit_func: {}".format(self._logit_func)) 44 | new_state = a * state + (1 - a) * h 45 | outputs = state 46 | return outputs, new_state 47 | 48 |
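A minimal usage sketch for `SHCell` (shapes and names are illustrative; TF r0.11-style APIs to match this codebase): since `__call__` splits each timestep input back into `(h, u)`, the cell expects the two sequences concatenated along the feature axis.

```python
# Illustrative sketch only, assuming TF r0.11 (old tf.concat/tf.split signatures).
import tensorflow as tf
from basic_cnn.superhighway import SHCell

d = 100
h = tf.placeholder('float', [None, 50, d])  # e.g. context states, [N, J, d]
u = tf.placeholder('float', [None, 50, d])  # e.g. attended query vectors, [N, J, d]
cell = SHCell(d, logit_func='tri_linear')
# SHCell recovers h and u from the concatenated input at every step.
outputs, final_state = tf.nn.dynamic_rnn(cell, tf.concat(2, [h, u]), dtype='float')
```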
-------------------------------------------------------------------------------- /basic_cnn/templates/visualizer.html: --------------------------------------------------------------------------------
1 | <!DOCTYPE html> 2 | <html> 3 | <head> 4 | <meta charset="UTF-8"> 5 | <title>{{ title }}</title> 6 | 7 | 8 | <style> 19 | </style> 20 | </head> 23 | <body> 24 | <h2>{{ title }}</h2>
25 | <table> 26 | <tr> 27 | <th>ID</th> 28 | <th>Question</th> 29 | <th>Answers</th> 30 | <th>Predicted</th> 31 | <th>Score</th> 32 | <th>Paragraph</th> 33 | </tr>
34 | {% for row in rows %} 35 | <tr> 36 | <td>{{ row.id }}</td> 37 | <td> 38 | {% for qj in row.ques %} 39 | {{ qj }} 40 | {% endfor %} 41 | </td>
42 | <td> 43 | {% for aa in row.a %} 44 | &bull; {{ aa }} <br> 45 | {% endfor %} 46 | </td> 47 | <td>{{ row.ap }}</td> 48 | <td>{{ row.score }}</td>
49 | <td> 50 | <table> 51 | {% for xj, ypj, yp2j in zip(row.para, row.yp, row.yp2) %} 52 | <tr> 53 | {% set rowloop = loop %} 54 | {% for xjk, ypjk in zip(xj, ypj) %} 55 | <td>
56 | {% if row.y[0][0] == rowloop.index0 and row.y[0][1] <= loop.index0 <= row.y[1][1] %} 57 | <b>{{ xjk }}</b> 58 | {% else %} 59 | {{ xjk }} 60 | {% endif %} 61 | </td> 62 | {% endfor %} 63 | </tr>
64 | <tr> 65 | {% for xjk, yp2jk in zip(xj, yp2j) %} 66 | <td>-</td> 67 | {% endfor %} 68 | </tr> 69 | {% endfor %} 70 | </table> 71 | </td> 72 | </tr> 73 | {% endfor %} 74 | </table> 75 | </body> 76 | </html>
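One non-obvious detail about the template above: plain Jinja2 does not provide `zip`, so the `zip(...)` calls in the nested loops only work because the visualizer registers Python builtins as template globals before rendering, as `basic_cnn/visualizer.py` (below) does. A trimmed sketch:

```python
# Sketch of the rendering setup; mirrors env.globals.update(...) in visualizer.py.
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("basic_cnn/templates"))
env.globals.update(zip=zip, reversed=reversed)  # makes zip() legal in the template
template = env.get_template("visualizer.html")
html = template.render(title="Accuracy Visualization", rows=[])
```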
-------------------------------------------------------------------------------- /basic_cnn/trainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from basic_cnn.model import Model 4 | from my.tensorflow import average_gradients 5 | 6 | 7 | class Trainer(object): 8 | def __init__(self, config, model): 9 | assert isinstance(model, Model) 10 | self.config = config 11 | self.model = model 12 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr) 13 | self.loss = model.get_loss() 14 | self.var_list = model.get_var_list() 15 | self.global_step = model.get_global_step() 16 | self.summary = model.summary 17 | self.grads = self.opt.compute_gradients(self.loss, var_list=self.var_list) 18 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step) 19 | 20 | def get_train_op(self): 21 | return self.train_op 22 | 23 | def step(self, sess, batch, get_summary=False): 24 | assert isinstance(sess, tf.Session) 25 | _, ds = batch 26 | feed_dict = self.model.get_feed_dict(ds, True) 27 | if get_summary: 28 | loss, summary, train_op = \ 29 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict) 30 | else: 31 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict) 32 | summary = None 33 | return loss, summary, train_op 34 | 35 | 36 | class MultiGPUTrainer(object): 37 | def __init__(self, config, models): 38 | model = models[0] 39 | assert isinstance(model, Model) 40 | self.config = config 41 | self.model = model 42 | self.opt = tf.train.AdadeltaOptimizer(config.init_lr) 43 | self.var_list = model.get_var_list() 44 | self.global_step = model.get_global_step() 45 | self.summary = model.summary 46 | self.models = models 47 | losses = [] 48 | grads_list = [] 49 | for gpu_idx, model in enumerate(models): 50 | with tf.name_scope("grads_{}".format(gpu_idx)), tf.device("/gpu:{}".format(gpu_idx)): 51 | loss = model.get_loss() 52 | grads = self.opt.compute_gradients(loss, var_list=self.var_list) 53 | losses.append(loss) 54 | grads_list.append(grads) 55 | 56 | self.loss = tf.add_n(losses)/len(losses) 57 | self.grads = average_gradients(grads_list) 58 | self.train_op = self.opt.apply_gradients(self.grads, global_step=self.global_step) 59 | 60 | def step(self, sess, batches, get_summary=False): 61 | assert isinstance(sess, tf.Session) 62 | feed_dict = {} 63 | for batch, model in zip(batches, self.models): 64 | _, ds = batch 65 | feed_dict.update(model.get_feed_dict(ds, True)) 66 | 67 | if get_summary: 68 | loss, summary, train_op = \ 69 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict) 70 | else: 71 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict) 72 | summary = None 73 | return loss, summary, train_op 74 | -------------------------------------------------------------------------------- /basic_cnn/visualizer.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from collections import OrderedDict 3 | import http.server 4 | import socketserver 5 | import argparse 6 | import json 7 | import os 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from jinja2 import Environment, FileSystemLoader 12 | 13 | from basic_cnn.evaluator import get_span_score_pairs, get_best_span 14 | 15 | 16 | def bool_(string): 17 | if string == 'True': 18 | return True 19 | elif string == 'False': 20 | return False 21 | else: 22 | raise Exception() 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--model_name", type=str, default='basic') 27 | parser.add_argument("--data_type", type=str, default='dev') 28 | parser.add_argument("--step", type=int, default=5000) 29 | parser.add_argument("--template_name", type=str, default="visualizer.html") 30 | parser.add_argument("--num_per_page", type=int, default=100) 31 | parser.add_argument("--data_dir", type=str, default="data/squad") 32 | parser.add_argument("--port", type=int, default=8000) 33 | parser.add_argument("--host", type=str, default="0.0.0.0") 34 | parser.add_argument("--open", type=str, default='False') 35 | parser.add_argument("--run_id", type=str, default="0") 36 | 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def _decode(decoder, sent): 42 | return " ".join(decoder[idx] for idx in sent) 43 | 44 | 45 | def accuracy2_visualizer(args): 46 | model_name = args.model_name 47 | data_type = args.data_type 48 | num_per_page = args.num_per_page 49 | data_dir = args.data_dir 50 | run_id = args.run_id.zfill(2) 51 | step = args.step 52 | 53 | eval_path = os.path.join("out", model_name, run_id, "eval", "{}-{}.json".format(data_type, str(step).zfill(6))) 54 | print("loading {}".format(eval_path)) 55 | eval_ = json.load(open(eval_path, 'r')) 56 | 57 | _id = 0 58 | html_dir = "/tmp/list_results%d" % _id 59 | while os.path.exists(html_dir): 60 | _id += 1 61 | html_dir = "/tmp/list_results%d" % _id 62 | 63 | if os.path.exists(html_dir): 64 | shutil.rmtree(html_dir) 65 | os.mkdir(html_dir) 66 | 67 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 68 | templates_dir = os.path.join(cur_dir, 'templates') 69 | env = Environment(loader=FileSystemLoader(templates_dir)) 70 | env.globals.update(zip=zip, reversed=reversed) 71 | template = env.get_template(args.template_name) 72 | 73 | data_path = os.path.join(data_dir, "data_{}.json".format(data_type)) 74 | shared_path = os.path.join(data_dir, "shared_{}.json".format(data_type)) 75 | print("loading {}".format(data_path)) 76 | data = json.load(open(data_path, 'r')) 77 | print("loading {}".format(shared_path)) 78 | shared = json.load(open(shared_path, 'r')) 79 | 80 | rows = [] 81 | for i, (idx, yi, ypi, yp2i) in tqdm(enumerate(zip(*[eval_[key] for key in ('idxs', 'y', 'yp', 'yp2')])), total=len(eval_['idxs'])): 82 | id_, q, rx, answers = (data[key][idx] for key in ('ids', 'q', '*x', 'answerss')) 83 | x = shared['x'][rx[0]][rx[1]] 84 | ques = [" ".join(q)] 85 | para = [[word for word in sent] for sent in x] 86 | span = get_best_span(ypi, yp2i) 87 | ap = get_segment(para, span) 88 | score = "{:.3f}".format(ypi[span[0][0]][span[0][1]] * yp2i[span[1][0]][span[1][1]-1]) 89 | 90 | row = { 91 | 'id': id_, 92 | 'title': "Hello world!", 93 | 'ques': ques, 94 | 'para': para, 95 | 'y': yi[0][0], 96 | 'y2': yi[0][1], 97 | 'yp': ypi, 98 | 'yp2': yp2i, 99 | 'a': answers, 100 | 'ap': ap, 101 | 'score': score 102 | } 103 | rows.append(row) 104 | 105 | if i % num_per_page == 0: 106 | html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8)) 107 | 108 | if (i + 1) % num_per_page == 0 or (i + 1) == len(eval_['y']): 109 | var_dict = {'title': "Accuracy Visualization", 110 | 'rows': rows 111 | } 112 | with open(html_path, "wb") as f: 113 | f.write(template.render(**var_dict).encode('UTF-8')) 114 | rows = [] 115 | 116 | os.chdir(html_dir) 117 | port = args.port 118 | host = args.host 119 | # Overriding to suppress log message 120 | class MyHandler(http.server.SimpleHTTPRequestHandler): 121 | def log_message(self, format, *args): 122 | pass 123 | 
handler = MyHandler 124 | httpd = socketserver.TCPServer((host, port), handler) 125 | if args.open == 'True': 126 | os.system("open http://%s:%d" % (args.host, args.port)) 127 | print("serving at %s:%d" % (host, port)) 128 | httpd.serve_forever() 129 | 130 | 131 | def get_segment(para, span): 132 | return " ".join(para[span[0][0]][span[0][1]:span[1][1]]) 133 | 134 | 135 | if __name__ == "__main__": 136 | ARGS = get_args() 137 | accuracy2_visualizer(ARGS) -------------------------------------------------------------------------------- /cnn_dm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/cnn_dm/__init__.py -------------------------------------------------------------------------------- /cnn_dm/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | root_dir = sys.argv[1] 6 | answer_path = sys.argv[2] 7 | file_names = os.listdir(root_dir) 8 | 9 | num_correct = 0 10 | num_wrong = 0 11 | 12 | with open(answer_path, 'r') as fh: 13 | id2answer_dict = json.load(fh) 14 | 15 | for file_name in file_names: 16 | if not file_name.endswith(".question"): 17 | continue 18 | with open(os.path.join(root_dir, file_name), 'r') as fh: 19 | url = fh.readline().strip() 20 | _ = fh.readline() 21 | para = fh.readline().strip() 22 | _ = fh.readline() 23 | ques = fh.readline().strip() 24 | _ = fh.readline() 25 | answer = fh.readline().strip() 26 | _ = fh.readline() 27 | if file_name in id2answer_dict: 28 | pred = id2answer_dict[file_name] 29 | if pred == answer: 30 | num_correct += 1 31 | else: 32 | num_wrong += 1 33 | else: 34 | num_wrong += 1 35 | 36 | total = num_correct + num_wrong 37 | acc = float(num_correct) / total 38 | print("{} = {} / {}".format(acc, num_correct, total)) -------------------------------------------------------------------------------- /cnn_dm/prepro.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | # data: q, cq, (dq), (pq), y, *x, *cx 5 | # shared: x, cx, (dx), (px), word_counter, char_counter, word2vec 6 | # no metadata 7 | from collections import Counter 8 | 9 | from tqdm import tqdm 10 | 11 | from my.utils import process_tokens 12 | from squad.utils import get_word_span, process_tokens 13 | 14 | 15 | def bool_(arg): 16 | if arg == 'True': 17 | return True 18 | elif arg == 'False': 19 | return False 20 | raise Exception(arg) 21 | 22 | 23 | def main(): 24 | args = get_args() 25 | prepro(args) 26 | 27 | 28 | def get_args(): 29 | parser = argparse.ArgumentParser() 30 | home = os.path.expanduser("~") 31 | source_dir = os.path.join(home, "data", "cnn", 'questions') 32 | target_dir = "data/cnn" 33 | glove_dir = os.path.join(home, "data", "glove") 34 | parser.add_argument("--source_dir", default=source_dir) 35 | parser.add_argument("--target_dir", default=target_dir) 36 | parser.add_argument("--glove_dir", default=glove_dir) 37 | parser.add_argument("--glove_corpus", default='6B') 38 | parser.add_argument("--glove_vec_size", default=100, type=int) 39 | parser.add_argument("--debug", default=False, type=bool_) 40 | parser.add_argument("--num_sents_th", default=200, type=int) 41 | parser.add_argument("--ques_size_th", default=30, type=int) 42 | parser.add_argument("--width", default=5, type=int) 43 | # TODO : put more args here 44 | return parser.parse_args() 45 | 46 | 47 | def 
prepro(args): 48 | prepro_each(args, 'train') 49 | prepro_each(args, 'dev') 50 | prepro_each(args, 'test') 51 | 52 | 53 | def para2sents(para, width): 54 | """ 55 | Turn para into double array of words (wordss) 56 | Where each sentence is up to 5 word neighbors of each entity 57 | :param para: 58 | :return: 59 | """ 60 | words = para.split(" ") 61 | sents = [] 62 | for i, word in enumerate(words): 63 | if word.startswith("@"): 64 | start = max(i - width, 0) 65 | stop = min(i + width + 1, len(words)) 66 | sent = words[start:stop] 67 | sents.append(sent) 68 | return sents 69 | 70 | 71 | def get_word2vec(args, word_counter): 72 | glove_path = os.path.join(args.glove_dir, "glove.{}.{}d.txt".format(args.glove_corpus, args.glove_vec_size)) 73 | sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '2B': int(1.2e6)} 74 | total = sizes[args.glove_corpus] 75 | word2vec_dict = {} 76 | with open(glove_path, 'r', encoding='utf-8') as fh: 77 | for line in tqdm(fh, total=total): 78 | array = line.lstrip().rstrip().split(" ") 79 | word = array[0] 80 | vector = list(map(float, array[1:])) 81 | if word in word_counter: 82 | word2vec_dict[word] = vector 83 | elif word.capitalize() in word_counter: 84 | word2vec_dict[word.capitalize()] = vector 85 | elif word.lower() in word_counter: 86 | word2vec_dict[word.lower()] = vector 87 | elif word.upper() in word_counter: 88 | word2vec_dict[word.upper()] = vector 89 | 90 | print("{}/{} of word vocab have corresponding vectors in {}".format(len(word2vec_dict), len(word_counter), glove_path)) 91 | return word2vec_dict 92 | 93 | 94 | def prepro_each(args, mode): 95 | source_dir = os.path.join(args.source_dir, mode) 96 | word_counter = Counter() 97 | lower_word_counter = Counter() 98 | ent_counter = Counter() 99 | char_counter = Counter() 100 | max_sent_size = 0 101 | max_word_size = 0 102 | max_ques_size = 0 103 | max_num_sents = 0 104 | 105 | file_names = list(os.listdir(source_dir)) 106 | if args.debug: 107 | file_names = file_names[:1000] 108 | lens = [] 109 | 110 | out_file_names = [] 111 | for file_name in tqdm(file_names, total=len(file_names)): 112 | if file_name.endswith(".question"): 113 | with open(os.path.join(source_dir, file_name), 'r') as fh: 114 | url = fh.readline().strip() 115 | _ = fh.readline() 116 | para = fh.readline().strip() 117 | _ = fh.readline() 118 | ques = fh.readline().strip() 119 | _ = fh.readline() 120 | answer = fh.readline().strip() 121 | _ = fh.readline() 122 | cands = list(line.strip() for line in fh) 123 | cand_ents = list(cand.split(":")[0] for cand in cands) 124 | sents = para2sents(para, args.width) 125 | ques_words = ques.split(" ") 126 | 127 | # Filtering 128 | if len(sents) > args.num_sents_th or len(ques_words) > args.ques_size_th: 129 | continue 130 | 131 | max_sent_size = max(max(map(len, sents)), max_sent_size) 132 | max_ques_size = max(len(ques_words), max_ques_size) 133 | max_word_size = max(max(len(word) for sent in sents for word in sent), max_word_size) 134 | max_num_sents = max(len(sents), max_num_sents) 135 | 136 | for word in ques_words: 137 | if word.startswith("@"): 138 | ent_counter[word] += 1 139 | word_counter[word] += 1 140 | else: 141 | word_counter[word] += 1 142 | lower_word_counter[word.lower()] += 1 143 | for c in word: 144 | char_counter[c] += 1 145 | for sent in sents: 146 | for word in sent: 147 | if word.startswith("@"): 148 | ent_counter[word] += 1 149 | word_counter[word] += 1 150 | else: 151 | word_counter[word] += 1 152 | lower_word_counter[word.lower()] += 1 153 | for c in word: 154 | 
char_counter[c] += 1 155 | 156 | out_file_names.append(file_name) 157 | lens.append(len(sents)) 158 | num_examples = len(out_file_names) 159 | 160 | assert len(out_file_names) == len(lens) 161 | sorted_file_names, lens = zip(*sorted(zip(out_file_names, lens), key=lambda each: each[1])) 162 | assert lens[-1] == max_num_sents 163 | 164 | word2vec_dict = get_word2vec(args, word_counter) 165 | lower_word2vec_dict = get_word2vec(args, lower_word_counter) 166 | 167 | shared = {'word_counter': word_counter, 'ent_counter': ent_counter, 'char_counter': char_counter, 168 | 'lower_word_counter': lower_word_counter, 169 | 'max_num_sents': max_num_sents, 'max_sent_size': max_sent_size, 'max_word_size': max_word_size, 170 | 'max_ques_size': max_ques_size, 171 | 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict, 'sorted': sorted_file_names, 172 | 'num_examples': num_examples} 173 | 174 | print("max num sents: {}".format(max_num_sents)) 175 | print("max ques size: {}".format(max_ques_size)) 176 | 177 | if not os.path.exists(args.target_dir): 178 | os.makedirs(args.target_dir) 179 | shared_path = os.path.join(args.target_dir, "shared_{}.json".format(mode)) 180 | with open(shared_path, 'w') as fh: 181 | json.dump(shared, fh) 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 |
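For intuition about the preprocessing above: `para2sents` does not use real sentence boundaries; it emits one window of up to `width` neighbors on each side of every `@entity` token. A small illustrative example (made-up paragraph):

```python
# Illustrative example of para2sents' entity-centered windows.
from cnn_dm.prepro import para2sents

para = "yesterday @entity1 met @entity2 in @entity3 to discuss trade"
sents = para2sents(para, width=2)
# one window per @entity token:
# [['yesterday', '@entity1', 'met', '@entity2'],
#  ['@entity1', 'met', '@entity2', 'in', '@entity3'],
#  ['@entity2', 'in', '@entity3', 'to', 'discuss']]
assert sents[0] == ['yesterday', '@entity1', 'met', '@entity2']
```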
-------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA_DIR=$HOME/data 4 | mkdir $DATA_DIR 5 | 6 | # Download SQuAD 7 | SQUAD_DIR=$DATA_DIR/squad 8 | mkdir $SQUAD_DIR 9 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $SQUAD_DIR/train-v1.1.json 10 | wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $SQUAD_DIR/dev-v1.1.json 11 | 12 | 13 | # Download CNN and DailyMail 14 | # Download at: http://cs.nyu.edu/~kcho/DMQA/ 15 | 16 | 17 | # Download GloVe 18 | GLOVE_DIR=$DATA_DIR/glove 19 | mkdir $GLOVE_DIR 20 | wget http://nlp.stanford.edu/data/glove.6B.zip -O $GLOVE_DIR/glove.6B.zip 21 | unzip $GLOVE_DIR/glove.6B.zip -d $GLOVE_DIR 22 | 23 | # Download NLTK (for tokenizer) 24 | # Make sure that nltk is installed! 25 | python3 -m nltk.downloader -d $HOME/nltk_data punkt 26 | -------------------------------------------------------------------------------- /my/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/my/__init__.py -------------------------------------------------------------------------------- /my/corenlp_interface.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import requests 4 | import nltk 5 | import json 6 | import networkx as nx 7 | import time 8 | 9 | 10 | class CoreNLPInterface(object): 11 | def __init__(self, url, port): 12 | self._url = url 13 | self._port = port 14 | 15 | def get(self, type_, in_, num_max_requests=100): 16 | in_ = in_.encode("utf-8") 17 | url = "http://{}:{}/{}".format(self._url, self._port, type_) 18 | out = None 19 | for _ in range(num_max_requests): 20 | try: 21 | r = requests.post(url, data=in_) 22 | out = r.content.decode('utf-8') 23 | if out == 'error': 24 | out = None 25 | break 26 | except requests.exceptions.RequestException: 27 | time.sleep(1) 28 | return out 29 | 30 | def split_doc(self, doc): 31 | out = self.get("doc", doc) 32 | return out if out is None else json.loads(out) 33 | 34 | def split_sent(self, sent): 35 | out = self.get("sent", sent) 36 | return out if out is None else json.loads(out) 37 | 38 | def get_dep(self, sent): 39 | out = self.get("dep", sent) 40 | return out if out is None else json.loads(out) 41 | 42 | def get_const(self, sent): 43 | out = self.get("const", sent) 44 | return out 45 | 46 | def get_const_tree(self, sent): 47 | out = self.get_const(sent) 48 | return out if out is None else nltk.tree.Tree.fromstring(out) 49 | 50 | @staticmethod 51 | def dep2tree(dep): 52 | tree = nx.DiGraph() 53 | for dep, i, gov, j, label in dep: 54 | tree.add_edge(gov, dep, label=label) 55 | return tree 56 | -------------------------------------------------------------------------------- /my/nltk_utils.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import numpy as np 3 | 4 | 5 | def _set_span(t, i): 6 | if isinstance(t[0], str): 7 | t.span = (i, i+len(t)) 8 | else: 9 | first = True 10 | for c in t: 11 | cur_span = _set_span(c, i) 12 | i = cur_span[1] 13 | if first: 14 | min_ = cur_span[0] 15 | first = False 16 | max_ = cur_span[1] 17 | t.span = (min_, max_) 18 | return t.span 19 | 20 | 21 | def set_span(t): 22 | assert isinstance(t, nltk.tree.Tree) 23 | try: 24 | return _set_span(t, 0) 25 | except: 26 | print(t) 27 | exit() 28 | 29 | 30 | def tree_contains_span(tree, span): 31 | """ 32 | Assumes that the tree's spans have been set with set_span. 33 | Returns true if any subtree of tree has exactly the given span 34 | :param tree: 35 | :param span: 36 | :return bool: 37 | """ 38 | return span in set(t.span for t in tree.subtrees()) 39 | 40 | 41 | def span_len(span): 42 | return span[1] - span[0] 43 | 44 | 45 | def span_overlap(s1, s2): 46 | start = max(s1[0], s2[0]) 47 | stop = min(s1[1], s2[1]) 48 | if stop > start: 49 | return start, stop 50 | return None 51 | 52 | 53 | def span_prec(true_span, pred_span): 54 | overlap = span_overlap(true_span, pred_span) 55 | if overlap is None: 56 | return 0 57 | return span_len(overlap) / span_len(pred_span) 58 | 59 | 60 | def span_recall(true_span, pred_span): 61 | overlap = span_overlap(true_span, pred_span) 62 | if overlap is None: 63 | return 0 64 | return span_len(overlap) / span_len(true_span) 
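A quick worked example of these span helpers (the values are illustrative; `span_f1`, defined just below, combines the two):

```python
# Spans are (start, stop) token offsets with stop exclusive.
from my.nltk_utils import span_prec, span_recall, span_f1

true_span, pred_span = (2, 5), (3, 7)              # overlap is (3, 5), length 2
assert span_prec(true_span, pred_span) == 2 / 4    # overlap / |pred|
assert span_recall(true_span, pred_span) == 2 / 3  # overlap / |true|
assert abs(span_f1(true_span, pred_span) - 4 / 7) < 1e-9
```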
65 | 66 | 67 | def span_f1(true_span, pred_span): 68 | p = span_prec(true_span, pred_span) 69 | r = span_recall(true_span, pred_span) 70 | if p == 0 or r == 0: 71 | return 0.0 72 | return 2 * p * r / (p + r) 73 | 74 | 75 | def find_max_f1_span(tree, span): 76 | return find_max_f1_subtree(tree, span).span 77 | 78 | 79 | def find_max_f1_subtree(tree, span): 80 | return max(((t, span_f1(span, t.span)) for t in tree.subtrees()), key=lambda p: p[1])[0] 81 | 82 | 83 | def tree2matrix(tree, node2num, row_size=None, col_size=None, dtype='int32'): 84 | set_span(tree) 85 | D = tree.height() - 1 86 | B = len(tree.leaves()) 87 | row_size = row_size or D 88 | col_size = col_size or B 89 | matrix = np.zeros([row_size, col_size], dtype=dtype) 90 | mask = np.zeros([row_size, col_size, col_size], dtype='bool') 91 | 92 | for subtree in tree.subtrees(): 93 | row = subtree.height() - 2 94 | col = subtree.span[0] 95 | matrix[row, col] = node2num(subtree) 96 | for subsub in subtree.subtrees(): 97 | if isinstance(subsub, nltk.tree.Tree): 98 | mask[row, col, subsub.span[0]] = True 99 | if not isinstance(subsub[0], nltk.tree.Tree): 100 | c = subsub.span[0] 101 | for r in range(row): 102 | mask[r, c, c] = True 103 | else: 104 | mask[row, col, col] = True 105 | 106 | return matrix, mask 107 | 108 | 109 | def load_compressed_tree(s): 110 | 111 | def compress_tree(tree): 112 | assert not isinstance(tree, str) 113 | if len(tree) == 1: 114 | if isinstance(tree[0], nltk.tree.Tree): 115 | return compress_tree(tree[0]) 116 | else: 117 | return tree 118 | else: 119 | for i, t in enumerate(tree): 120 | if isinstance(t, nltk.tree.Tree): 121 | tree[i] = compress_tree(t) 122 | else: 123 | tree[i] = t 124 | return tree 125 | 126 | return compress_tree(nltk.tree.Tree.fromstring(s)) 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /my/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | from my.tensorflow.general import * -------------------------------------------------------------------------------- /my/tensorflow/general.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | 3 | import tensorflow as tf 4 | from functools import reduce 5 | from operator import mul 6 | import numpy as np 7 | 8 | VERY_BIG_NUMBER = 1e30 9 | VERY_SMALL_NUMBER = 1e-30 10 | VERY_POSITIVE_NUMBER = VERY_BIG_NUMBER 11 | VERY_NEGATIVE_NUMBER = -VERY_BIG_NUMBER 12 | 13 | 14 | def get_initializer(matrix): 15 | def _initializer(shape, dtype=None, partition_info=None, **kwargs): return matrix 16 | return _initializer 17 | 18 | 19 | def variable_on_cpu(name, shape, initializer): 20 | """Helper to create a Variable stored on CPU memory. 21 | 22 | Args: 23 | name: name of the variable 24 | shape: list of ints 25 | initializer: initializer for Variable 26 | 27 | Returns: 28 | Variable Tensor 29 | """ 30 | with tf.device('/cpu:0'): 31 | var = tf.get_variable(name, shape, initializer=initializer) 32 | return var 33 | 34 | 35 | def variable_with_weight_decay(name, shape, stddev, wd): 36 | """Helper to create an initialized Variable with weight decay. 37 | 38 | Note that the Variable is initialized with a truncated normal distribution. 39 | A weight decay is added only if one is specified. 40 | 41 | Args: 42 | name: name of the variable 43 | shape: list of ints 44 | stddev: standard deviation of a truncated Gaussian 45 | wd: add L2Loss weight decay multiplied by this float. 
If None, weight 46 | decay is not added for this Variable. 47 | 48 | Returns: 49 | Variable Tensor 50 | """ 51 | var = variable_on_cpu(name, shape, 52 | tf.truncated_normal_initializer(stddev=stddev)) 53 | if wd: 54 | weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss') 55 | tf.add_to_collection('losses', weight_decay) 56 | return var 57 | 58 | 59 | def average_gradients(tower_grads): 60 | """Calculate the average gradient for each shared variable across all towers. 61 | 62 | Note that this function provides a synchronization point across all towers. 63 | 64 | Args: 65 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 66 | is over individual gradients. The inner list is over the gradient 67 | calculation for each tower. 68 | Returns: 69 | List of pairs of (gradient, variable) where the gradient has been averaged 70 | across all towers. 71 | """ 72 | average_grads = [] 73 | for grad_and_vars in zip(*tower_grads): 74 | # Note that each grad_and_vars looks like the following: 75 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 76 | grads = [] 77 | for g, var in grad_and_vars: 78 | # Add 0 dimension to the gradients to represent the tower. 79 | assert g is not None, var.name 80 | expanded_g = tf.expand_dims(g, 0) 81 | 82 | # Append on a 'tower' dimension which we will average over below. 83 | grads.append(expanded_g) 84 | 85 | # Average over the 'tower' dimension. 86 | grad = tf.concat(0, grads) 87 | grad = tf.reduce_mean(grad, 0) 88 | 89 | # Keep in mind that the Variables are redundant because they are shared 90 | # across towers. So .. we will just return the first tower's pointer to 91 | # the Variable. 92 | v = grad_and_vars[0][1] 93 | grad_and_var = (grad, v) 94 | average_grads.append(grad_and_var) 95 | return average_grads 96 | 97 | 98 | def mask(val, mask, name=None): 99 | if name is None: 100 | name = 'mask' 101 | return tf.mul(val, tf.cast(mask, 'float'), name=name) 102 | 103 | 104 | def exp_mask(val, mask, name=None): 105 | """Give a very negative number to the masked-out elements of val (i.e. where mask is False). 106 | For example, [-3, -2, 10], [True, True, False] -> [-3, -2, -1e9]. 107 | Typically, this effectively masks in exponential space (e.g. 
softmax) 108 | Args: 109 | val: values to be masked 110 | mask: masking boolean tensor, same shape as tensor 111 | name: name for output tensor 112 | 113 | Returns: 114 | Same shape as val, where some elements are very small (exponentially zero) 115 | """ 116 | if name is None: 117 | name = "exp_mask" 118 | return tf.add(val, (1 - tf.cast(mask, 'float')) * VERY_NEGATIVE_NUMBER, name=name) 119 | 120 | 121 | def flatten(tensor, keep): 122 | fixed_shape = tensor.get_shape().as_list() 123 | start = len(fixed_shape) - keep 124 | left = reduce(mul, [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start)]) 125 | out_shape = [left] + [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start, len(fixed_shape))] 126 | flat = tf.reshape(tensor, out_shape) 127 | return flat 128 | 129 | 130 | def reconstruct(tensor, ref, keep): 131 | ref_shape = ref.get_shape().as_list() 132 | tensor_shape = tensor.get_shape().as_list() 133 | ref_stop = len(ref_shape) - keep 134 | tensor_start = len(tensor_shape) - keep 135 | pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(ref_stop)] 136 | keep_shape = [tensor_shape[i] or tf.shape(tensor)[i] for i in range(tensor_start, len(tensor_shape))] 137 | # pre_shape = [tf.shape(ref)[i] for i in range(len(ref.get_shape().as_list()[:-keep]))] 138 | # keep_shape = tensor.get_shape().as_list()[-keep:] 139 | target_shape = pre_shape + keep_shape 140 | out = tf.reshape(tensor, target_shape) 141 | return out 142 | 143 | 144 | def add_wd(wd, scope=None): 145 | scope = scope or tf.get_variable_scope().name 146 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) 147 | with tf.name_scope("weight_decay"): 148 | for var in variables: 149 | weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name="{}/wd".format(var.op.name)) 150 | tf.add_to_collection('losses', weight_decay) 151 | 152 | 153 | def grouper(iterable, n, fillvalue=None, shorten=False, num_groups=None): 154 | args = [iter(iterable)] * n 155 | out = zip_longest(*args, fillvalue=fillvalue) 156 | out = list(out) 157 | if num_groups is not None: 158 | default = (fillvalue, ) * n 159 | assert isinstance(num_groups, int) 160 | out = list(each for each, _ in zip_longest(out, range(num_groups), fillvalue=default)) 161 | if shorten: 162 | assert fillvalue is None 163 | out = (tuple(e for e in each if e is not None) for each in out) 164 | return out 165 | 166 | def padded_reshape(tensor, shape, mode='CONSTANT', name=None): 167 | paddings = [[0, shape[i] - tf.shape(tensor)[i]] for i in range(len(shape))] 168 | return tf.pad(tensor, paddings, mode=mode, name=name) -------------------------------------------------------------------------------- /my/tensorflow/nn.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.ops.rnn_cell import _linear 2 | from tensorflow.python.util import nest 3 | import tensorflow as tf 4 | 5 | from my.tensorflow import flatten, reconstruct, add_wd, exp_mask 6 | 7 | 8 | def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False, wd=0.0, input_keep_prob=1.0, 9 | is_train=None): 10 | if args is None or (nest.is_sequence(args) and not args): 11 | raise ValueError("`args` must be specified") 12 | if not nest.is_sequence(args): 13 | args = [args] 14 | 15 | flat_args = [flatten(arg, 1) for arg in args] 16 | if input_keep_prob < 1.0: 17 | assert is_train is not None 18 | flat_args = [tf.cond(is_train, lambda: tf.nn.dropout(arg, input_keep_prob), lambda: arg) 19 | for arg in flat_args] 20 | flat_out = 
_linear(flat_args, output_size, bias, bias_start=bias_start, scope=scope) 21 | out = reconstruct(flat_out, args[0], 1) 22 | if squeeze: 23 | out = tf.squeeze(out, [len(args[0].get_shape().as_list())-1]) 24 | if wd: 25 | add_wd(wd) 26 | 27 | return out 28 | 29 | 30 | def dropout(x, keep_prob, is_train, noise_shape=None, seed=None, name=None): 31 | with tf.name_scope(name or "dropout"): 32 | if keep_prob < 1.0: 33 | d = tf.nn.dropout(x, keep_prob, noise_shape=noise_shape, seed=seed) 34 | out = tf.cond(is_train, lambda: d, lambda: x) 35 | return out 36 | return x 37 | 38 | 39 | def softmax(logits, mask=None, scope=None): 40 | with tf.name_scope(scope or "Softmax"): 41 | if mask is not None: 42 | logits = exp_mask(logits, mask) 43 | flat_logits = flatten(logits, 1) 44 | flat_out = tf.nn.softmax(flat_logits) 45 | out = reconstruct(flat_out, logits, 1) 46 | 47 | return out 48 | 49 | 50 | def softsel(target, logits, mask=None, scope=None): 51 | """ 52 | 53 | :param target: [ ..., J, d] dtype=float 54 | :param logits: [ ..., J], dtype=float 55 | :param mask: [ ..., J], dtype=bool 56 | :param scope: 57 | :return: [..., d], dtype=float 58 | """ 59 | with tf.name_scope(scope or "Softsel"): 60 | a = softmax(logits, mask=mask) 61 | target_rank = len(target.get_shape().as_list()) 62 | out = tf.reduce_sum(tf.expand_dims(a, -1) * target, target_rank - 2) 63 | return out 64 | 65 | 66 | def double_linear_logits(args, size, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None): 67 | with tf.variable_scope(scope or "Double_Linear_Logits"): 68 | first = tf.tanh(linear(args, size, bias, bias_start=bias_start, scope='first', 69 | wd=wd, input_keep_prob=input_keep_prob, is_train=is_train)) 70 | second = linear(first, 1, bias, bias_start=bias_start, squeeze=True, scope='second', 71 | wd=wd, input_keep_prob=input_keep_prob, is_train=is_train) 72 | if mask is not None: 73 | second = exp_mask(second, mask) 74 | return second 75 | 76 | 77 | def linear_logits(args, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None): 78 | with tf.variable_scope(scope or "Linear_Logits"): 79 | logits = linear(args, 1, bias, bias_start=bias_start, squeeze=True, scope='first', 80 | wd=wd, input_keep_prob=input_keep_prob, is_train=is_train) 81 | if mask is not None: 82 | logits = exp_mask(logits, mask) 83 | return logits 84 | 85 | 86 | def sum_logits(args, mask=None, name=None): 87 | with tf.name_scope(name or "sum_logits"): 88 | if args is None or (nest.is_sequence(args) and not args): 89 | raise ValueError("`args` must be specified") 90 | if not nest.is_sequence(args): 91 | args = [args] 92 | rank = len(args[0].get_shape()) 93 | logits = sum(tf.reduce_sum(arg, rank-1) for arg in args) 94 | if mask is not None: 95 | logits = exp_mask(logits, mask) 96 | return logits 97 | 98 | 99 | def get_logits(args, size, bias, bias_start=0.0, scope=None, mask=None, wd=0.0, input_keep_prob=1.0, is_train=None, func=None): 100 | if func is None: 101 | func = "sum" 102 | if func == 'sum': 103 | return sum_logits(args, mask=mask, name=scope) 104 | elif func == 'linear': 105 | return linear_logits(args, bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob, 106 | is_train=is_train) 107 | elif func == 'double': 108 | return double_linear_logits(args, size, bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob, 109 | is_train=is_train) 110 | elif func == 'dot': 111 | assert len(args) == 2 112 | arg = args[0] * 
args[1] 113 | return sum_logits([arg], mask=mask, name=scope) 114 | elif func == 'mul_linear': 115 | assert len(args) == 2 116 | arg = args[0] * args[1] 117 | return linear_logits([arg], bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob, 118 | is_train=is_train) 119 | elif func == 'proj': 120 | assert len(args) == 2 121 | d = args[1].get_shape()[-1] 122 | proj = linear([args[0]], d, False, bias_start=bias_start, scope=scope, wd=wd, input_keep_prob=input_keep_prob, 123 | is_train=is_train) 124 | return sum_logits([proj * args[1]], mask=mask) 125 | elif func == 'tri_linear': 126 | assert len(args) == 2 127 | new_arg = args[0] * args[1] 128 | return linear_logits([args[0], args[1], new_arg], bias, bias_start=bias_start, scope=scope, mask=mask, wd=wd, input_keep_prob=input_keep_prob, 129 | is_train=is_train) 130 | else: 131 | raise Exception() 132 | 133 | 134 | def highway_layer(arg, bias, bias_start=0.0, scope=None, wd=0.0, input_keep_prob=1.0, is_train=None): 135 | with tf.variable_scope(scope or "highway_layer"): 136 | d = arg.get_shape()[-1] 137 | trans = linear([arg], d, bias, bias_start=bias_start, scope='trans', wd=wd, input_keep_prob=input_keep_prob, is_train=is_train) 138 | trans = tf.nn.relu(trans) 139 | gate = linear([arg], d, bias, bias_start=bias_start, scope='gate', wd=wd, input_keep_prob=input_keep_prob, is_train=is_train) 140 | gate = tf.nn.sigmoid(gate) 141 | out = gate * trans + (1 - gate) * arg 142 | return out 143 | 144 | 145 | def highway_network(arg, num_layers, bias, bias_start=0.0, scope=None, wd=0.0, input_keep_prob=1.0, is_train=None): 146 | with tf.variable_scope(scope or "highway_network"): 147 | prev = arg 148 | cur = None 149 | for layer_idx in range(num_layers): 150 | cur = highway_layer(prev, bias, bias_start=bias_start, scope="layer_{}".format(layer_idx), wd=wd, 151 | input_keep_prob=input_keep_prob, is_train=is_train) 152 | prev = cur 153 | return cur 154 | 155 | 156 | def conv1d(in_, filter_size, height, padding, is_train=None, keep_prob=1.0, scope=None): 157 | with tf.variable_scope(scope or "conv1d"): 158 | num_channels = in_.get_shape()[-1] 159 | filter_ = tf.get_variable("filter", shape=[1, height, num_channels, filter_size], dtype='float') 160 | bias = tf.get_variable("bias", shape=[filter_size], dtype='float') 161 | strides = [1, 1, 1, 1] 162 | if is_train is not None and keep_prob < 1.0: 163 | in_ = dropout(in_, keep_prob, is_train) 164 | xxc = tf.nn.conv2d(in_, filter_, strides, padding) + bias # [N*M, JX, W/filter_stride, d] 165 | out = tf.reduce_max(tf.nn.relu(xxc), 2) # [-1, JX, d] 166 | return out 167 | 168 | 169 | def multi_conv1d(in_, filter_sizes, heights, padding, is_train=None, keep_prob=1.0, scope=None): 170 | with tf.variable_scope(scope or "multi_conv1d"): 171 | assert len(filter_sizes) == len(heights) 172 | outs = [] 173 | for filter_size, height in zip(filter_sizes, heights): 174 | if filter_size == 0: 175 | continue 176 | out = conv1d(in_, filter_size, height, padding, is_train=is_train, keep_prob=keep_prob, scope="conv1d_{}".format(height)) 177 | outs.append(out) 178 | concat_out = tf.concat(2, outs) 179 | return concat_out 180 | -------------------------------------------------------------------------------- /my/tensorflow/rnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops.rnn import dynamic_rnn as _dynamic_rnn, \ 3 | bidirectional_dynamic_rnn as _bidirectional_dynamic_rnn 4 | from 
tensorflow.python.ops.rnn import bidirectional_rnn as _bidirectional_rnn 5 | 6 | from my.tensorflow import flatten, reconstruct 7 | 8 | 9 | def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 10 | dtype=None, parallel_iterations=None, swap_memory=False, 11 | time_major=False, scope=None): 12 | assert not time_major # TODO : to be implemented later! 13 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 14 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 15 | 16 | flat_outputs, final_state = _dynamic_rnn(cell, flat_inputs, sequence_length=flat_len, 17 | initial_state=initial_state, dtype=dtype, 18 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 19 | time_major=time_major, scope=scope) 20 | 21 | outputs = reconstruct(flat_outputs, inputs, 2) 22 | return outputs, final_state 23 | 24 | 25 | def bw_dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, 26 | dtype=None, parallel_iterations=None, swap_memory=False, 27 | time_major=False, scope=None): 28 | assert not time_major # TODO : to be implemented later! 29 | 30 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 31 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 32 | 33 | flat_inputs = tf.reverse(flat_inputs, 1) if sequence_length is None \ 34 | else tf.reverse_sequence(flat_inputs, sequence_length, 1) 35 | flat_outputs, final_state = _dynamic_rnn(cell, flat_inputs, sequence_length=flat_len, 36 | initial_state=initial_state, dtype=dtype, 37 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 38 | time_major=time_major, scope=scope) 39 | flat_outputs = tf.reverse(flat_outputs, 1) if sequence_length is None \ 40 | else tf.reverse_sequence(flat_outputs, sequence_length, 1) 41 | 42 | outputs = reconstruct(flat_outputs, inputs, 2) 43 | return outputs, final_state 44 | 45 | 46 | def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, 47 | initial_state_fw=None, initial_state_bw=None, 48 | dtype=None, parallel_iterations=None, 49 | swap_memory=False, time_major=False, scope=None): 50 | assert not time_major 51 | 52 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 53 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 54 | 55 | (flat_fw_outputs, flat_bw_outputs), final_state = \ 56 | _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len, 57 | initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw, 58 | dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory, 59 | time_major=time_major, scope=scope) 60 | 61 | fw_outputs = reconstruct(flat_fw_outputs, inputs, 2) 62 | bw_outputs = reconstruct(flat_bw_outputs, inputs, 2) 63 | # FIXME : final state is not reshaped! 
64 | return (fw_outputs, bw_outputs), final_state 65 | 66 | 67 | def bidirectional_rnn(cell_fw, cell_bw, inputs, 68 | initial_state_fw=None, initial_state_bw=None, 69 | dtype=None, sequence_length=None, scope=None): 70 | 71 | flat_inputs = flatten(inputs, 2) # [-1, J, d] 72 | flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64') 73 | 74 | (flat_fw_outputs, flat_bw_outputs), final_state = \ 75 | _bidirectional_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len, 76 | initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw, 77 | dtype=dtype, scope=scope) 78 | 79 | fw_outputs = reconstruct(flat_fw_outputs, inputs, 2) 80 | bw_outputs = reconstruct(flat_bw_outputs, inputs, 2) 81 | # FIXME : final state is not reshaped! 82 | return (fw_outputs, bw_outputs), final_state 83 | -------------------------------------------------------------------------------- /my/tensorflow/rnn_cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops.rnn_cell import DropoutWrapper, RNNCell, LSTMStateTuple 3 | 4 | from my.tensorflow import exp_mask, flatten 5 | from my.tensorflow.nn import linear, softsel, double_linear_logits 6 | 7 | 8 | class SwitchableDropoutWrapper(DropoutWrapper): 9 | def __init__(self, cell, is_train, input_keep_prob=1.0, output_keep_prob=1.0, 10 | seed=None): 11 | super(SwitchableDropoutWrapper, self).__init__(cell, input_keep_prob=input_keep_prob, output_keep_prob=output_keep_prob, 12 | seed=seed) 13 | self.is_train = is_train 14 | 15 | def __call__(self, inputs, state, scope=None): 16 | outputs_do, new_state_do = super(SwitchableDropoutWrapper, self).__call__(inputs, state, scope=scope) 17 | tf.get_variable_scope().reuse_variables() 18 | outputs, new_state = self._cell(inputs, state, scope) 19 | outputs = tf.cond(self.is_train, lambda: outputs_do, lambda: outputs) 20 | if isinstance(state, tuple): 21 | new_state = state.__class__(*[tf.cond(self.is_train, lambda: new_state_do_i, lambda: new_state_i) 22 | for new_state_do_i, new_state_i in zip(new_state_do, new_state)]) 23 | else: 24 | new_state = tf.cond(self.is_train, lambda: new_state_do, lambda: new_state) 25 | return outputs, new_state 26 | 27 | 28 | class TreeRNNCell(RNNCell): 29 | def __init__(self, cell, input_size, reduce_func): 30 | self._cell = cell 31 | self._input_size = input_size 32 | self._reduce_func = reduce_func 33 | 34 | def __call__(self, inputs, state, scope=None): 35 | """ 36 | :param inputs: [N*B, I + B] 37 | :param state: [N*B, d] 38 | :param scope: 39 | :return: [N*B, d] 40 | """ 41 | with tf.variable_scope(scope or self.__class__.__name__): 42 | d = self.state_size 43 | x = tf.slice(inputs, [0, 0], [-1, self._input_size]) # [N*B, I] 44 | mask = tf.slice(inputs, [0, self._input_size], [-1, -1]) # [N*B, B] 45 | B = tf.shape(mask)[1] 46 | prev_state = tf.expand_dims(tf.reshape(state, [-1, B, d]), 1) # [N, B, d] -> [N, 1, B, d] 47 | mask = tf.tile(tf.expand_dims(tf.reshape(mask, [-1, B, B]), -1), [1, 1, 1, d]) # [N, B, B, d] 48 | # prev_state = self._reduce_func(tf.tile(prev_state, [1, B, 1, 1]), 2) 49 | prev_state = self._reduce_func(exp_mask(prev_state, mask), 2) # [N, B, d] 50 | prev_state = tf.reshape(prev_state, [-1, d]) # [N*B, d] 51 | return self._cell(x, prev_state) 52 | 53 | @property 54 | def state_size(self): 55 | return self._cell.state_size 56 | 57 | @property 58 | def output_size(self): 59 | return self._cell.output_size 60 | 61 | 62 | class 
NoOpCell(RNNCell): 63 | def __init__(self, num_units): 64 | self._num_units = num_units 65 | 66 | def __call__(self, inputs, state, scope=None): 67 | return state, state 68 | 69 | @property 70 | def state_size(self): 71 | return self._num_units 72 | 73 | @property 74 | def output_size(self): 75 | return self._num_units 76 | 77 | 78 | class MatchCell(RNNCell): 79 | def __init__(self, cell, input_size, q_len): 80 | self._cell = cell 81 | self._input_size = input_size 82 | # FIXME : This won't be needed with good shape guessing 83 | self._q_len = q_len 84 | 85 | @property 86 | def state_size(self): 87 | return self._cell.state_size 88 | 89 | @property 90 | def output_size(self): 91 | return self._cell.output_size 92 | 93 | def __call__(self, inputs, state, scope=None): 94 | """ 95 | 96 | :param inputs: [N, d + JQ + JQ * d] 97 | :param state: [N, d] 98 | :param scope: 99 | :return: 100 | """ 101 | with tf.variable_scope(scope or self.__class__.__name__): 102 | c_prev, h_prev = state 103 | x = tf.slice(inputs, [0, 0], [-1, self._input_size]) 104 | q_mask = tf.slice(inputs, [0, self._input_size], [-1, self._q_len]) # [N, JQ] 105 | qs = tf.slice(inputs, [0, self._input_size + self._q_len], [-1, -1]) 106 | qs = tf.reshape(qs, [-1, self._q_len, self._input_size]) # [N, JQ, d] 107 | x_tiled = tf.tile(tf.expand_dims(x, 1), [1, self._q_len, 1]) # [N, JQ, d] 108 | h_prev_tiled = tf.tile(tf.expand_dims(h_prev, 1), [1, self._q_len, 1]) # [N, JQ, d] 109 | f = tf.tanh(linear([qs, x_tiled, h_prev_tiled], self._input_size, True, scope='f')) # [N, JQ, d] 110 | a = tf.nn.softmax(exp_mask(linear(f, 1, True, squeeze=True, scope='a'), q_mask)) # [N, JQ] 111 | q = tf.reduce_sum(qs * tf.expand_dims(a, -1), 1) 112 | z = tf.concat(1, [x, q]) # [N, 2d] 113 | return self._cell(z, state) 114 | 115 | 116 | class AttentionCell(RNNCell): 117 | def __init__(self, cell, memory, mask=None, controller=None, mapper=None, input_keep_prob=1.0, is_train=None): 118 | """ 119 | Early fusion attention cell: uses the (inputs, state) to control the current attention. 
120 | 121 | :param cell: 122 | :param memory: [N, M, m] 123 | :param mask: 124 | :param controller: (inputs, prev_state, memory) -> memory_logits 125 | """ 126 | self._cell = cell 127 | self._memory = memory 128 | self._mask = mask 129 | self._flat_memory = flatten(memory, 2) 130 | self._flat_mask = flatten(mask, 1) 131 | if controller is None: 132 | controller = AttentionCell.get_linear_controller(True, is_train=is_train) 133 | self._controller = controller 134 | if mapper is None: 135 | mapper = AttentionCell.get_concat_mapper() 136 | elif mapper == 'sim': 137 | mapper = AttentionCell.get_sim_mapper() 138 | self._mapper = mapper 139 | 140 | @property 141 | def state_size(self): 142 | return self._cell.state_size 143 | 144 | @property 145 | def output_size(self): 146 | return self._cell.output_size 147 | 148 | def __call__(self, inputs, state, scope=None): 149 | with tf.variable_scope(scope or "AttentionCell"): 150 | memory_logits = self._controller(inputs, state, self._flat_memory) 151 | sel_mem = softsel(self._flat_memory, memory_logits, mask=self._flat_mask) # [N, m] 152 | new_inputs, new_state = self._mapper(inputs, state, sel_mem) 153 | return self._cell(new_inputs, state) 154 | 155 | @staticmethod 156 | def get_double_linear_controller(size, bias, input_keep_prob=1.0, is_train=None): 157 | def double_linear_controller(inputs, state, memory): 158 | """ 159 | 160 | :param inputs: [N, i] 161 | :param state: [N, d] 162 | :param memory: [N, M, m] 163 | :return: [N, M] 164 | """ 165 | rank = len(memory.get_shape()) 166 | _memory_size = tf.shape(memory)[rank-2] 167 | tiled_inputs = tf.tile(tf.expand_dims(inputs, 1), [1, _memory_size, 1]) 168 | if isinstance(state, tuple): 169 | tiled_states = [tf.tile(tf.expand_dims(each, 1), [1, _memory_size, 1]) 170 | for each in state] 171 | else: 172 | tiled_states = [tf.tile(tf.expand_dims(state, 1), [1, _memory_size, 1])] 173 | 174 | # [N, M, d] 175 | in_ = tf.concat(2, [tiled_inputs] + tiled_states + [memory]) 176 | out = double_linear_logits(in_, size, bias, input_keep_prob=input_keep_prob, 177 | is_train=is_train) 178 | return out 179 | return double_linear_controller 180 | 181 | @staticmethod 182 | def get_linear_controller(bias, input_keep_prob=1.0, is_train=None): 183 | def linear_controller(inputs, state, memory): 184 | rank = len(memory.get_shape()) 185 | _memory_size = tf.shape(memory)[rank-2] 186 | tiled_inputs = tf.tile(tf.expand_dims(inputs, 1), [1, _memory_size, 1]) 187 | if isinstance(state, tuple): 188 | tiled_states = [tf.tile(tf.expand_dims(each, 1), [1, _memory_size, 1]) 189 | for each in state] 190 | else: 191 | tiled_states = [tf.tile(tf.expand_dims(state, 1), [1, _memory_size, 1])] 192 | 193 | # [N, M, d] 194 | in_ = tf.concat(2, [tiled_inputs] + tiled_states + [memory]) 195 | out = linear(in_, 1, bias, squeeze=True, input_keep_prob=input_keep_prob, is_train=is_train) 196 | return out 197 | return linear_controller 198 | 199 | @staticmethod 200 | def get_concat_mapper(): 201 | def concat_mapper(inputs, state, sel_mem): 202 | """ 203 | 204 | :param inputs: [N, i] 205 | :param state: [N, d] 206 | :param sel_mem: [N, m] 207 | :return: (new_inputs, new_state) tuple 208 | """ 209 | return tf.concat(1, [inputs, sel_mem]), state 210 | return concat_mapper 211 | 212 | @staticmethod 213 | def get_sim_mapper(): 214 | def sim_mapper(inputs, state, sel_mem): 215 | """ 216 | Assume that inputs and sel_mem are the same size 217 | :param inputs: [N, i] 218 | :param state: [N, d] 219 | :param sel_mem: [N, i] 220 | :return: (new_inputs, 
new_state) tuple 221 | """ 222 | return tf.concat(1, [inputs, sel_mem, inputs * sel_mem, tf.abs(inputs - sel_mem)]), state 223 | return sim_mapper 224 | -------------------------------------------------------------------------------- /my/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import deque 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | def mytqdm(list_, desc="", show=True): 9 | if show: 10 | pbar = tqdm(list_) 11 | pbar.set_description(desc) 12 | return pbar 13 | return list_ 14 | 15 | 16 | def json_pretty_dump(obj, fh): 17 | return json.dump(obj, fh, sort_keys=True, indent=2, separators=(',', ': ')) 18 | 19 | 20 | def index(l, i): 21 | return index(l[i[0]], i[1:]) if len(i) > 1 else l[i[0]] 22 | 23 | 24 | def fill(l, shape, dtype=None): 25 | out = np.zeros(shape, dtype=dtype) 26 | stack = deque() 27 | stack.appendleft(((), l)) 28 | while len(stack) > 0: 29 | indices, cur = stack.pop() 30 | if len(indices) < len(shape): 31 | for i, sub in enumerate(cur): 32 | stack.appendleft([indices + (i,), sub]) 33 | else: 34 | out[indices] = cur 35 | return out 36 | 37 | 38 | def short_floats(o, precision): 39 | class ShortFloat(float): 40 | def __repr__(self): 41 | return '%.{}g'.format(precision) % self 42 | 43 | def _short_floats(obj): 44 | if isinstance(obj, float): 45 | return ShortFloat(obj) 46 | elif isinstance(obj, dict): 47 | return dict((k, _short_floats(v)) for k, v in obj.items()) 48 | elif isinstance(obj, (list, tuple)): 49 | return tuple(map(_short_floats, obj)) 50 | return obj 51 | 52 | return _short_floats(o) 53 | 54 | 55 | def argmax(x): 56 | return np.unravel_index(x.argmax(), x.shape) 57 | 58 | 59 | -------------------------------------------------------------------------------- /my/zip_save.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import shutil 5 | from zipfile import ZipFile 6 | 7 | from tqdm import tqdm 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('paths', nargs='+') 13 | parser.add_argument('-o', '--out', default='save.zip') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def zip_save(args): 19 | temp_dir = "."
20 | save_dir = os.path.join(temp_dir, "save") 21 | if not os.path.exists(save_dir): 22 | os.makedirs(save_dir) 23 | for save_source_path in tqdm(args.paths): 24 | # path = "out/basic/30/save/basic-18000" 25 | # target_path = "save_dir/30/save" 26 | # also output full path name to "save_dir/30/readme.txt 27 | # need to also extract "out/basic/30/shared.json" 28 | temp, _ = os.path.split(save_source_path) # "out/basic/30/save", _ 29 | model_dir, _ = os.path.split(temp) # "out/basic/30, _ 30 | _, model_name = os.path.split(model_dir) 31 | cur_dir = os.path.join(save_dir, model_name) 32 | if not os.path.exists(cur_dir): 33 | os.makedirs(cur_dir) 34 | save_target_path = os.path.join(cur_dir, "save") 35 | shared_target_path = os.path.join(cur_dir, "shared.json") 36 | readme_path = os.path.join(cur_dir, "readme.txt") 37 | shared_source_path = os.path.join(model_dir, "shared.json") 38 | shutil.copy(save_source_path, save_target_path) 39 | shutil.copy(shared_source_path, shared_target_path) 40 | with open(readme_path, 'w') as fh: 41 | fh.write(save_source_path) 42 | 43 | os.system("zip {} -r {}".format(args.out, save_dir)) 44 | 45 | def main(): 46 | args = get_args() 47 | zip_save(args) 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=0.11 2 | nltk 3 | tqdm 4 | jinja2 -------------------------------------------------------------------------------- /squad/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/squad/__init__.py -------------------------------------------------------------------------------- /squad/aug_squad.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from tqdm import tqdm 5 | 6 | from my.corenlp_interface import CoreNLPInterface 7 | 8 | in_path = sys.argv[1] 9 | out_path = sys.argv[2] 10 | url = sys.argv[3] 11 | port = int(sys.argv[4]) 12 | data = json.load(open(in_path, 'r')) 13 | 14 | h = CoreNLPInterface(url, port) 15 | 16 | 17 | def find_all(a_str, sub): 18 | start = 0 19 | while True: 20 | start = a_str.find(sub, start) 21 | if start == -1: return 22 | yield start 23 | start += len(sub) # use start += 1 to find overlapping matches 24 | 25 | 26 | def to_hex(s): 27 | return " ".join(map(hex, map(ord, s))) 28 | 29 | 30 | def handle_nobreak(cand, text): 31 | if cand == text: 32 | return cand 33 | if cand.replace(u'\u00A0', ' ') == text: 34 | return cand 35 | elif cand == text.replace(u'\u00A0', ' '): 36 | return text 37 | raise Exception("{} '{}' {} '{}'".format(cand, to_hex(cand), text, to_hex(text))) 38 | 39 | 40 | # resolving unicode complication 41 | 42 | wrong_loc_count = 0 43 | loc_diffs = [] 44 | 45 | for article in data['data']: 46 | for para in article['paragraphs']: 47 | para['context'] = para['context'].replace(u'\u000A', '') 48 | para['context'] = para['context'].replace(u'\u00A0', ' ') 49 | context = para['context'] 50 | for qa in para['qas']: 51 | for answer in qa['answers']: 52 | answer['text'] = answer['text'].replace(u'\u00A0', ' ') 53 | text = answer['text'] 54 | answer_start = answer['answer_start'] 55 | if context[answer_start:answer_start + len(text)] == text: 56 | if text.lstrip() == text: 57 | pass 58 | else: 59 | answer_start += len(text) - 
len(text.lstrip()) 60 | answer['answer_start'] = answer_start 61 | text = text.lstrip() 62 | answer['text'] = text 63 | else: 64 | wrong_loc_count += 1 65 | text = text.lstrip() 66 | answer['text'] = text 67 | starts = list(find_all(context, text)) 68 | if len(starts) == 1: 69 | answer_start = starts[0] 70 | elif len(starts) > 1: 71 | new_answer_start = min(starts, key=lambda s: abs(s - answer_start)) 72 | loc_diffs.append(abs(new_answer_start - answer_start)) 73 | answer_start = new_answer_start 74 | else: 75 | raise Exception() 76 | answer['answer_start'] = answer_start 77 | 78 | answer_stop = answer_start + len(text) 79 | answer['answer_stop'] = answer_stop 80 | assert para['context'][answer_start:answer_stop] == answer['text'], "{} {}".format( 81 | para['context'][answer_start:answer_stop], answer['text']) 82 | 83 | print(wrong_loc_count, loc_diffs) 84 | 85 | mismatch_count = 0 86 | dep_fail_count = 0 87 | no_answer_count = 0 88 | 89 | size = sum(len(article['paragraphs']) for article in data['data']) 90 | pbar = tqdm(range(size)) 91 | 92 | for ai, article in enumerate(data['data']): 93 | for pi, para in enumerate(article['paragraphs']): 94 | context = para['context'] 95 | sents = h.split_doc(context) 96 | words = h.split_sent(context) 97 | sent_starts = [] 98 | ref_idx = 0 99 | for sent in sents: 100 | new_idx = context.find(sent, ref_idx) 101 | sent_starts.append(new_idx) 102 | ref_idx = new_idx + len(sent) 103 | para['sents'] = sents 104 | para['words'] = words 105 | para['sent_starts'] = sent_starts 106 | 107 | consts = list(map(h.get_const, sents)) 108 | para['consts'] = consts 109 | deps = list(map(h.get_dep, sents)) 110 | para['deps'] = deps 111 | 112 | for qa in para['qas']: 113 | question = qa['question'] 114 | question_const = h.get_const(question) 115 | qa['const'] = question_const 116 | question_dep = h.get_dep(question) 117 | qa['dep'] = question_dep 118 | qa['words'] = h.split_sent(question) 119 | 120 | for answer in qa['answers']: 121 | answer_start = answer['answer_start'] 122 | text = answer['text'] 123 | answer_stop = answer_start + len(text) 124 | # answer_words = h.split_sent(text) 125 | word_idxs = [] 126 | answer_words = [] 127 | for sent_idx, (sent, sent_start, dep) in enumerate(zip(sents, sent_starts, deps)): 128 | if dep is None: 129 | print("dep parse failed at {} {} {}".format(ai, pi, sent_idx)) 130 | dep_fail_count += 1 131 | continue 132 | nodes, edges = dep 133 | words = [node[0] for node in nodes] 134 | 135 | for word_idx, (word, _, _, start, _) in enumerate(nodes): 136 | global_start = sent_start + start 137 | global_stop = global_start + len(word) 138 | if answer_start <= global_start < answer_stop or answer_start < global_stop <= answer_stop: 139 | word_idxs.append((sent_idx, word_idx)) 140 | answer_words.append(word) 141 | if len(word_idxs) > 0: 142 | answer['answer_word_start'] = word_idxs[0] 143 | answer['answer_word_stop'] = word_idxs[-1][0], word_idxs[-1][1] + 1 144 | if not text.startswith(answer_words[0]): 145 | print("'{}' '{}'".format(text, ' '.join(answer_words))) 146 | mismatch_count += 1 147 | else: 148 | answer['answer_word_start'] = None 149 | answer['answer_word_stop'] = None 150 | no_answer_count += 1 151 | pbar.update(1) 152 | pbar.close() 153 | 154 | print(mismatch_count, dep_fail_count, no_answer_count) 155 | 156 | print("saving...") 157 | json.dump(data, open(out_path, 'w')) -------------------------------------------------------------------------------- /squad/eda_aug_dev.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import json\n", 12 | "\n", 13 | "aug_data_path = \"/Users/minjoons/data/squad/dev-v1.0-aug.json\"\n", 14 | "aug_data = json.load(open(aug_data_path, 'r'))" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 17, 20 | "metadata": { 21 | "collapsed": false 22 | }, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "(['Denver', 'Broncos'], 'Denver Broncos')\n", 29 | "(['Denver', 'Broncos'], 'Denver Broncos')\n", 30 | "(['Denver', 'Broncos'], 'Denver Broncos ')\n", 31 | "(['Carolina', 'Panthers'], 'Carolina Panthers')\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "def compare_answers():\n", 37 | " for article in aug_data['data']:\n", 38 | " for para in article['paragraphs']:\n", 39 | " deps = para['deps']\n", 40 | " nodess = []\n", 41 | " for dep in deps:\n", 42 | " nodes, edges = dep\n", 43 | " if dep is not None:\n", 44 | " nodess.append(nodes)\n", 45 | " else:\n", 46 | " nodess.append([])\n", 47 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n", 48 | " for qa in para['qas']:\n", 49 | " for answer in qa['answers']:\n", 50 | " text = answer['text']\n", 51 | " word_start = answer['answer_word_start']\n", 52 | " word_stop = answer['answer_word_stop']\n", 53 | " answer_words = wordss[word_start[0]][word_start[1]:word_stop[1]]\n", 54 | " yield answer_words, text\n", 55 | "\n", 56 | "ca = compare_answers()\n", 57 | "print(next(ca))\n", 58 | "print(next(ca))\n", 59 | "print(next(ca))\n", 60 | "print(next(ca))" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 18, 66 | "metadata": { 67 | "collapsed": false 68 | }, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "8\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "def counter():\n", 80 | " count = 0\n", 81 | " for article in aug_data['data']:\n", 82 | " for para in article['paragraphs']:\n", 83 | " deps = para['deps']\n", 84 | " nodess = []\n", 85 | " for dep in deps:\n", 86 | " if dep is None:\n", 87 | " count += 1\n", 88 | " print(count)\n", 89 | "counter()\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 19, 95 | "metadata": { 96 | "collapsed": false 97 | }, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "0\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "def bad_node_counter():\n", 109 | " count = 0\n", 110 | " for article in aug_data['data']:\n", 111 | " for para in article['paragraphs']:\n", 112 | " sents = para['sents']\n", 113 | " deps = para['deps']\n", 114 | " nodess = []\n", 115 | " for dep in deps:\n", 116 | " if dep is not None:\n", 117 | " nodes, edges = dep\n", 118 | " for node in nodes:\n", 119 | " if len(node) != 5:\n", 120 | " count += 1\n", 121 | " print(count)\n", 122 | "bad_node_counter() " 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 20, 128 | "metadata": { 129 | "collapsed": false 130 | }, 131 | "outputs": [ 132 | { 133 | "name": "stdout", 134 | "output_type": "stream", 135 | "text": [ 136 | "7\n" 137 | ] 138 | } 139 | ], 140 | "source": [ 141 | "def noanswer_counter():\n", 142 | " count = 0\n", 143 | " for article in aug_data['data']:\n", 144 | " for para in article['paragraphs']:\n", 145 | " deps = para['deps']\n", 
146 | " nodess = []\n", 147 | " for dep in deps:\n", 148 | " if dep is not None:\n", 149 | " nodes, edges = dep\n", 150 | " nodess.append(nodes)\n", 151 | " else:\n", 152 | " nodess.append([])\n", 153 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n", 154 | " for qa in para['qas']:\n", 155 | " for answer in qa['answers']:\n", 156 | " text = answer['text']\n", 157 | " word_start = answer['answer_word_start']\n", 158 | " word_stop = answer['answer_word_stop']\n", 159 | " if word_start is None:\n", 160 | " count += 1\n", 161 | " print(count)\n", 162 | "noanswer_counter()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 22, 168 | "metadata": { 169 | "collapsed": false 170 | }, 171 | "outputs": [ 172 | { 173 | "name": "stdout", 174 | "output_type": "stream", 175 | "text": [ 176 | "10600\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "print(sum(len(para['qas']) for a in aug_data['data'] for para in a['paragraphs']))" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 5, 187 | "metadata": { 188 | "collapsed": false 189 | }, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "10348\n" 196 | ] 197 | } 198 | ], 199 | "source": [ 200 | "import nltk\n", 201 | "\n", 202 | "def _set_span(t, i):\n", 203 | " if isinstance(t[0], str):\n", 204 | " t.span = (i, i+len(t))\n", 205 | " else:\n", 206 | " first = True\n", 207 | " for c in t:\n", 208 | " cur_span = _set_span(c, i)\n", 209 | " i = cur_span[1]\n", 210 | " if first:\n", 211 | " min_ = cur_span[0]\n", 212 | " first = False\n", 213 | " max_ = cur_span[1]\n", 214 | " t.span = (min_, max_)\n", 215 | " return t.span\n", 216 | "\n", 217 | "\n", 218 | "def set_span(t):\n", 219 | " assert isinstance(t, nltk.tree.Tree)\n", 220 | " try:\n", 221 | " return _set_span(t, 0)\n", 222 | " except:\n", 223 | " print(t)\n", 224 | " exit()\n", 225 | "\n", 226 | "def same_span_counter():\n", 227 | " count = 0\n", 228 | " for article in aug_data['data']:\n", 229 | " for para in article['paragraphs']:\n", 230 | " consts = para['consts']\n", 231 | " for const in consts:\n", 232 | " tree = nltk.tree.Tree.fromstring(const)\n", 233 | " set_span(tree)\n", 234 | " if len(list(tree.subtrees())) > len(set(t.span for t in tree.subtrees())):\n", 235 | " count += 1\n", 236 | " print(count)\n", 237 | "same_span_counter()" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": { 244 | "collapsed": true 245 | }, 246 | "outputs": [], 247 | "source": [] 248 | } 249 | ], 250 | "metadata": { 251 | "kernelspec": { 252 | "display_name": "Python 3", 253 | "language": "python", 254 | "name": "python3" 255 | }, 256 | "language_info": { 257 | "codemirror_mode": { 258 | "name": "ipython", 259 | "version": 3 260 | }, 261 | "file_extension": ".py", 262 | "mimetype": "text/x-python", 263 | "name": "python", 264 | "nbconvert_exporter": "python", 265 | "pygments_lexer": "ipython3", 266 | "version": "3.5.1" 267 | } 268 | }, 269 | "nbformat": 4, 270 | "nbformat_minor": 0 271 | } 272 | -------------------------------------------------------------------------------- /squad/eda_aug_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import json\n", 12 | "\n", 13 | "aug_data_path = 
\"/Users/minjoons/data/squad/train-v1.0-aug.json\"\n", 14 | "aug_data = json.load(open(aug_data_path, 'r'))" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "collapsed": false 22 | }, 23 | "outputs": [ 24 | { 25 | "name": "stdout", 26 | "output_type": "stream", 27 | "text": [ 28 | "(['Saint', 'Bernadette', 'Soubirous'], 'Saint Bernadette Soubirous')\n", 29 | "(['a', 'copper', 'statue', 'of', 'Christ'], 'a copper statue of Christ')\n", 30 | "(['the', 'Main', 'Building'], 'the Main Building')\n", 31 | "(['a', 'Marian', 'place', 'of', 'prayer', 'and', 'reflection'], 'a Marian place of prayer and reflection')\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "def compare_answers():\n", 37 | " for article in aug_data['data']:\n", 38 | " for para in article['paragraphs']:\n", 39 | " deps = para['deps']\n", 40 | " nodess = []\n", 41 | " for dep in deps:\n", 42 | " nodes, edges = dep\n", 43 | " if dep is not None:\n", 44 | " nodess.append(nodes)\n", 45 | " else:\n", 46 | " nodess.append([])\n", 47 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n", 48 | " for qa in para['qas']:\n", 49 | " for answer in qa['answers']:\n", 50 | " text = answer['text']\n", 51 | " word_start = answer['answer_word_start']\n", 52 | " word_stop = answer['answer_word_stop']\n", 53 | " answer_words = wordss[word_start[0]][word_start[1]:word_stop[1]]\n", 54 | " yield answer_words, text\n", 55 | "\n", 56 | "ca = compare_answers()\n", 57 | "print(next(ca))\n", 58 | "print(next(ca))\n", 59 | "print(next(ca))\n", 60 | "print(next(ca))" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 11, 66 | "metadata": { 67 | "collapsed": false 68 | }, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "x: .\n", 75 | "x: .\n", 76 | "x: .\n", 77 | "x: .\n", 78 | "x: .\n", 79 | "x: .\n", 80 | "x: .\n", 81 | "x: .\n", 82 | "q: k\n", 83 | "q: j\n", 84 | "q: n\n", 85 | "q: b\n", 86 | "q: v\n", 87 | "x: .\n", 88 | "x: :208\n", 89 | "x: .\n", 90 | "x: .\n", 91 | "x: .\n", 92 | "x: .\n", 93 | "x: .\n", 94 | "x: .\n", 95 | "x: .\n", 96 | "x: .\n", 97 | "x: .\n", 98 | "x: .\n", 99 | "x: .\n", 100 | "q: dd\n", 101 | "q: dd\n", 102 | "q: dd\n", 103 | "q: dd\n", 104 | "q: d\n", 105 | "x: .\n", 106 | "x: .\n", 107 | "x: .\n", 108 | "x: .\n", 109 | "x: .\n", 110 | "x: .\n", 111 | "x: .\n", 112 | "x: .\n", 113 | "x: :411\n", 114 | "x: .\n", 115 | "x: .\n", 116 | "x: .\n", 117 | "x: .\n", 118 | "x: .\n", 119 | "x: .\n", 120 | "x: :40\n", 121 | "x: .\n", 122 | "x: *\n", 123 | "x: :14\n", 124 | "x: .\n", 125 | "x: .\n", 126 | "x: .\n", 127 | "x: :131\n", 128 | "x: .\n", 129 | "x: .\n", 130 | "x: .\n", 131 | "x: .\n", 132 | "x: .\n", 133 | "x: .\n", 134 | "x: .\n", 135 | "x: .\n", 136 | "x: .\n", 137 | "53 10\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "def nodep_counter():\n", 143 | " x_count = 0\n", 144 | " q_count = 0\n", 145 | " for article in aug_data['data']:\n", 146 | " for para in article['paragraphs']:\n", 147 | " deps = para['deps']\n", 148 | " nodess = []\n", 149 | " for sent, dep in zip(para['sents'], deps):\n", 150 | " if dep is None:\n", 151 | " print(\"x:\", sent)\n", 152 | " x_count += 1\n", 153 | " for qa in para['qas']:\n", 154 | " if qa['dep'] is None:\n", 155 | " print(\"q:\", qa['question'])\n", 156 | " q_count += 1\n", 157 | " print(x_count, q_count)\n", 158 | "nodep_counter()\n" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 4, 164 | "metadata": { 165 | "collapsed": false 
166 | }, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "0\n" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "def bad_node_counter():\n", 178 | " count = 0\n", 179 | " for article in aug_data['data']:\n", 180 | " for para in article['paragraphs']:\n", 181 | " sents = para['sents']\n", 182 | " deps = para['deps']\n", 183 | " nodess = []\n", 184 | " for dep in deps:\n", 185 | " if dep is not None:\n", 186 | " nodes, edges = dep\n", 187 | " for node in nodes:\n", 188 | " if len(node) != 5:\n", 189 | " count += 1\n", 190 | " print(count)\n", 191 | "bad_node_counter() " 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 5, 197 | "metadata": { 198 | "collapsed": false 199 | }, 200 | "outputs": [ 201 | { 202 | "name": "stdout", 203 | "output_type": "stream", 204 | "text": [ 205 | "36\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "def noanswer_counter():\n", 211 | " count = 0\n", 212 | " for article in aug_data['data']:\n", 213 | " for para in article['paragraphs']:\n", 214 | " deps = para['deps']\n", 215 | " nodess = []\n", 216 | " for dep in deps:\n", 217 | " if dep is not None:\n", 218 | " nodes, edges = dep\n", 219 | " nodess.append(nodes)\n", 220 | " else:\n", 221 | " nodess.append([])\n", 222 | " wordss = [[node[0] for node in nodes] for nodes in nodess]\n", 223 | " for qa in para['qas']:\n", 224 | " for answer in qa['answers']:\n", 225 | " text = answer['text']\n", 226 | " word_start = answer['answer_word_start']\n", 227 | " word_stop = answer['answer_word_stop']\n", 228 | " if word_start is None:\n", 229 | " count += 1\n", 230 | " print(count)\n", 231 | "noanswer_counter()" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 14, 237 | "metadata": { 238 | "collapsed": false 239 | }, 240 | "outputs": [ 241 | { 242 | "name": "stdout", 243 | "output_type": "stream", 244 | "text": [ 245 | "106\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "def mult_sent_answer_counter():\n", 251 | " count = 0\n", 252 | " for article in aug_data['data']:\n", 253 | " for para in article['paragraphs']:\n", 254 | " for qa in para['qas']:\n", 255 | " for answer in qa['answers']:\n", 256 | " text = answer['text']\n", 257 | " word_start = answer['answer_word_start']\n", 258 | " word_stop = answer['answer_word_stop']\n", 259 | " if word_start is not None and word_start[0] != word_stop[0]:\n", 260 | " count += 1\n", 261 | " print(count)\n", 262 | "mult_sent_answer_counter()" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": { 269 | "collapsed": true 270 | }, 271 | "outputs": [], 272 | "source": [] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": { 278 | "collapsed": true 279 | }, 280 | "outputs": [], 281 | "source": [] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": { 287 | "collapsed": true 288 | }, 289 | "outputs": [], 290 | "source": [] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.5.1" 310 | } 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 0 314 | } 315 | 
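Both evaluator scripts below share the same normalization and scoring logic; `squad/evaluate.py` keeps it under an importable name. A minimal usage sketch (assuming the repository root is on `PYTHONPATH`):

```
from squad.evaluate import normalize_answer, f1_score, exact_match_score

# Lowercasing, punctuation/article removal, and whitespace collapsing:
assert normalize_answer("The Denver Broncos!") == "denver broncos"

# Token-level F1: precision = 1/2, recall = 1/1 -> F1 = 2/3
print(f1_score("Denver Broncos", "Broncos"))

# Exact match is computed on the normalized strings:
print(exact_match_score("the Broncos", "Broncos  "))  # True
```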
-------------------------------------------------------------------------------- /squad/evaluate-v1.1.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. """ 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | f1 = exact_match = total = 0 56 | for article in dataset: 57 | for paragraph in article['paragraphs']: 58 | for qa in paragraph['qas']: 59 | total += 1 60 | if qa['id'] not in predictions: 61 | message = 'Unanswered question ' + qa['id'] + \ 62 | ' will receive score 0.' 
63 | print(message, file=sys.stderr) 64 | continue 65 | ground_truths = list(map(lambda x: x['text'], qa['answers'])) 66 | prediction = predictions[qa['id']] 67 | exact_match += metric_max_over_ground_truths( 68 | exact_match_score, prediction, ground_truths) 69 | f1 += metric_max_over_ground_truths( 70 | f1_score, prediction, ground_truths) 71 | 72 | exact_match = 100.0 * exact_match / total 73 | f1 = 100.0 * f1 / total 74 | 75 | return {'exact_match': exact_match, 'f1': f1} 76 | 77 | 78 | if __name__ == '__main__': 79 | expected_version = '1.1' 80 | parser = argparse.ArgumentParser( 81 | description='Evaluation for SQuAD ' + expected_version) 82 | parser.add_argument('dataset_file', help='Dataset file') 83 | parser.add_argument('prediction_file', help='Prediction File') 84 | args = parser.parse_args() 85 | with open(args.dataset_file) as dataset_file: 86 | dataset_json = json.load(dataset_file) 87 | if (dataset_json['version'] != expected_version): 88 | print('Evaluation expects v-' + expected_version + 89 | ', but got dataset with v-' + dataset_json['version'], 90 | file=sys.stderr) 91 | dataset = dataset_json['data'] 92 | with open(args.prediction_file) as prediction_file: 93 | predictions = json.load(prediction_file) 94 | print(json.dumps(evaluate(dataset, predictions))) 95 | -------------------------------------------------------------------------------- /squad/evaluate.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for v1.1 of the SQuAD dataset. [Changed name for external importing]""" 2 | from __future__ import print_function 3 | from collections import Counter 4 | import string 5 | import re 6 | import argparse 7 | import json 8 | import sys 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | def remove_articles(text): 14 | return re.sub(r'\b(a|an|the)\b', ' ', text) 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | 29 | def f1_score(prediction, ground_truth): 30 | prediction_tokens = normalize_answer(prediction).split() 31 | ground_truth_tokens = normalize_answer(ground_truth).split() 32 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 33 | num_same = sum(common.values()) 34 | if num_same == 0: 35 | return 0 36 | precision = 1.0 * num_same / len(prediction_tokens) 37 | recall = 1.0 * num_same / len(ground_truth_tokens) 38 | f1 = (2 * precision * recall) / (precision + recall) 39 | return f1 40 | 41 | 42 | def exact_match_score(prediction, ground_truth): 43 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 44 | 45 | 46 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 47 | scores_for_ground_truths = [] 48 | for ground_truth in ground_truths: 49 | score = metric_fn(prediction, ground_truth) 50 | scores_for_ground_truths.append(score) 51 | return max(scores_for_ground_truths) 52 | 53 | 54 | def evaluate(dataset, predictions): 55 | f1 = exact_match = total = 0 56 | for article in dataset: 57 | for paragraph in article['paragraphs']: 58 | for qa in paragraph['qas']: 59 | total += 1 60 | if qa['id'] not in predictions: 61 | message = 'Unanswered question ' + qa['id'] + \ 62 | ' will receive 
score 0.'
63 | print(message, file=sys.stderr)
64 | continue
65 | ground_truths = list(map(lambda x: x['text'], qa['answers']))
66 | prediction = predictions[qa['id']]
67 | exact_match += metric_max_over_ground_truths(
68 | exact_match_score, prediction, ground_truths)
69 | f1 += metric_max_over_ground_truths(
70 | f1_score, prediction, ground_truths)
71 |
72 | exact_match = 100.0 * exact_match / total
73 | f1 = 100.0 * f1 / total
74 |
75 | return {'exact_match': exact_match, 'f1': f1}
76 |
77 |
78 | if __name__ == '__main__':
79 | expected_version = '1.1'
80 | parser = argparse.ArgumentParser(
81 | description='Evaluation for SQuAD ' + expected_version)
82 | parser.add_argument('dataset_file', help='Dataset file')
83 | parser.add_argument('prediction_file', help='Prediction File')
84 | args = parser.parse_args()
85 | with open(args.dataset_file) as dataset_file:
86 | dataset_json = json.load(dataset_file)
87 | if (dataset_json['version'] != expected_version):
88 | print('Evaluation expects v-' + expected_version +
89 | ', but got dataset with v-' + dataset_json['version'],
90 | file=sys.stderr)
91 | dataset = dataset_json['data']
92 | with open(args.prediction_file) as prediction_file:
93 | predictions = json.load(prediction_file)
94 | print(json.dumps(evaluate(dataset, predictions)))
95 |
-------------------------------------------------------------------------------- /squad/prepro.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | # data: q, cq, (dq), (pq), y, *x, *cx
5 | # shared: x, cx, (dx), (px), word_counter, char_counter, word2vec
6 | # no metadata
7 | from collections import Counter
8 |
9 | from tqdm import tqdm
10 |
11 | from squad.utils import get_word_span, get_word_idx, process_tokens
12 |
13 |
14 | def main():
15 | args = get_args()
16 | prepro(args)
17 |
18 |
19 | def get_args():
20 | parser = argparse.ArgumentParser()
21 | home = os.path.expanduser("~")
22 | source_dir = os.path.join(home, "data", "squad")
23 | target_dir = "data/squad"
24 | glove_dir = os.path.join(home, "data", "glove")
25 | parser.add_argument('-s', "--source_dir", default=source_dir)
26 | parser.add_argument('-t', "--target_dir", default=target_dir)
27 | parser.add_argument('-d', "--debug", action='store_true')
28 | parser.add_argument("--train_ratio", default=0.9, type=float)  # a fraction in [0, 1]; type=int would reject values like 0.9
29 | parser.add_argument("--glove_corpus", default="6B")
30 | parser.add_argument("--glove_dir", default=glove_dir)
31 | parser.add_argument("--glove_vec_size", default=100, type=int)
32 | parser.add_argument("--mode", default="full", type=str)
33 | parser.add_argument("--single_path", default="", type=str)
34 | parser.add_argument("--tokenizer", default="PTB", type=str)
35 | parser.add_argument("--url", default="vision-server2.corp.ai2", type=str)
36 | parser.add_argument("--port", default=8000, type=int)
37 | parser.add_argument("--split", action='store_true')
38 | # TODO : put more args here
39 | return parser.parse_args()
40 |
41 |
42 | def create_all(args):
43 | out_path = os.path.join(args.source_dir, "all-v1.1.json")
44 | if os.path.exists(out_path):
45 | return
46 | train_path = os.path.join(args.source_dir, "train-v1.1.json")
47 | train_data = json.load(open(train_path, 'r'))
48 | dev_path = os.path.join(args.source_dir, "dev-v1.1.json")
49 | dev_data = json.load(open(dev_path, 'r'))
50 | train_data['data'].extend(dev_data['data'])
51 | print("dumping all data ...")
52 | json.dump(train_data, open(out_path, 'w'))
53 |
54 |
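# A small worked sketch for get_word_span (imported above), which maps an answer's
# character span back to (sentence index, word index) pairs:
#   context = "Super Bowl 50 was an American football game."
#   wordss = [["Super", "Bowl", "50", "was", "an", "American", "football", "game", "."]]
#   get_word_span(context, wordss, 21, 38)  ->  ((0, 5), (0, 7))
# i.e. "American football" starts at word 5 and stops before word 7 of sentence 0.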
55 | def prepro(args):
56 | if not os.path.exists(args.target_dir):
57 | os.makedirs(args.target_dir)
58 |
59 | if args.mode == 'full':  # official train/dev split; dev doubles as the test set
60 | prepro_each(args, 'train', out_name='train')
61 | prepro_each(args, 'dev', out_name='dev')
62 | prepro_each(args, 'dev', out_name='test')
63 | elif args.mode == 'all':  # train on train+dev merged by create_all; dev/test stubs are empty
64 | create_all(args)
65 | prepro_each(args, 'dev', 0.0, 0.0, out_name='dev')
66 | prepro_each(args, 'dev', 0.0, 0.0, out_name='test')
67 | prepro_each(args, 'all', out_name='train')
68 | elif args.mode == 'single':  # preprocess one arbitrary SQuAD-format file given via --single_path
69 | assert len(args.single_path) > 0
70 | prepro_each(args, "NULL", out_name="single", in_path=args.single_path)
71 | else:  # carve dev out of the training set using --train_ratio
72 | prepro_each(args, 'train', 0.0, args.train_ratio, out_name='train')
73 | prepro_each(args, 'train', args.train_ratio, 1.0, out_name='dev')
74 | prepro_each(args, 'dev', out_name='test')
75 |
76 |
77 | def save(args, data, shared, data_type):
78 | data_path = os.path.join(args.target_dir, "data_{}.json".format(data_type))
79 | shared_path = os.path.join(args.target_dir, "shared_{}.json".format(data_type))
80 | json.dump(data, open(data_path, 'w'))
81 | json.dump(shared, open(shared_path, 'w'))
82 |
83 |
84 | def get_word2vec(args, word_counter):
85 | glove_path = os.path.join(args.glove_dir, "glove.{}.{}d.txt".format(args.glove_corpus, args.glove_vec_size))
86 | sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '2B': int(1.2e6)}
87 | total = sizes[args.glove_corpus]
88 | word2vec_dict = {}
89 | with open(glove_path, 'r', encoding='utf-8') as fh:
90 | for line in tqdm(fh, total=total):
91 | array = line.lstrip().rstrip().split(" ")
92 | word = array[0]
93 | vector = list(map(float, array[1:]))
94 | if word in word_counter:
95 | word2vec_dict[word] = vector
96 | elif word.capitalize() in word_counter:
97 | word2vec_dict[word.capitalize()] = vector
98 | elif word.lower() in word_counter:
99 | word2vec_dict[word.lower()] = vector
100 | elif word.upper() in word_counter:
101 | word2vec_dict[word.upper()] = vector
102 |
103 | print("{}/{} of word vocab have corresponding vectors in {}".format(len(word2vec_dict), len(word_counter), glove_path))
104 | return word2vec_dict
105 |
106 |
107 | def prepro_each(args, data_type, start_ratio=0.0, stop_ratio=1.0, out_name="default", in_path=None):
108 | if args.tokenizer == "PTB":
109 | import nltk
110 | sent_tokenize = nltk.sent_tokenize
111 | def word_tokenize(tokens):
112 | return [token.replace("''", '"').replace("``", '"') for token in nltk.word_tokenize(tokens)]
113 | elif args.tokenizer == 'Stanford':
114 | from my.corenlp_interface import CoreNLPInterface
115 | interface = CoreNLPInterface(args.url, args.port)
116 | sent_tokenize = interface.split_doc
117 | word_tokenize = interface.split_sent
118 | else:
119 | raise Exception()
120 |
121 | if not args.split:
122 | sent_tokenize = lambda para: [para]
123 |
124 | source_path = in_path or os.path.join(args.source_dir, "{}-v1.1.json".format(data_type))
125 | source_data = json.load(open(source_path, 'r'))
126 |
127 | q, cq, y, rx, rcx, ids, idxs = [], [], [], [], [], [], []
128 | cy = []
129 | x, cx = [], []
130 | answerss = []
131 | p = []
132 | word_counter, char_counter, lower_word_counter = Counter(), Counter(), Counter()
133 | start_ai = int(round(len(source_data['data']) * start_ratio))
134 | stop_ai = int(round(len(source_data['data']) * stop_ratio))
135 | for ai, article in enumerate(tqdm(source_data['data'][start_ai:stop_ai])):
136 | xp, cxp = [], []
137 | pp = []
138 | x.append(xp)
139 | cx.append(cxp)
140 | p.append(pp)
141 | for pi, para in 
enumerate(article['paragraphs']): 142 | # wordss 143 | context = para['context'] 144 | context = context.replace("''", '" ') 145 | context = context.replace("``", '" ') 146 | xi = list(map(word_tokenize, sent_tokenize(context))) 147 | xi = [process_tokens(tokens) for tokens in xi] # process tokens 148 | # given xi, add chars 149 | cxi = [[list(xijk) for xijk in xij] for xij in xi] 150 | xp.append(xi) 151 | cxp.append(cxi) 152 | pp.append(context) 153 | 154 | for xij in xi: 155 | for xijk in xij: 156 | word_counter[xijk] += len(para['qas']) 157 | lower_word_counter[xijk.lower()] += len(para['qas']) 158 | for xijkl in xijk: 159 | char_counter[xijkl] += len(para['qas']) 160 | 161 | rxi = [ai, pi] 162 | assert len(x) - 1 == ai 163 | assert len(x[ai]) - 1 == pi 164 | for qa in para['qas']: 165 | # get words 166 | qi = word_tokenize(qa['question']) 167 | cqi = [list(qij) for qij in qi] 168 | yi = [] 169 | cyi = [] 170 | answers = [] 171 | for answer in qa['answers']: 172 | answer_text = answer['text'] 173 | answers.append(answer_text) 174 | answer_start = answer['answer_start'] 175 | answer_stop = answer_start + len(answer_text) 176 | # TODO : put some function that gives word_start, word_stop here 177 | yi0, yi1 = get_word_span(context, xi, answer_start, answer_stop) 178 | # yi0 = answer['answer_word_start'] or [0, 0] 179 | # yi1 = answer['answer_word_stop'] or [0, 1] 180 | assert len(xi[yi0[0]]) > yi0[1] 181 | assert len(xi[yi1[0]]) >= yi1[1] 182 | w0 = xi[yi0[0]][yi0[1]] 183 | w1 = xi[yi1[0]][yi1[1]-1] 184 | i0 = get_word_idx(context, xi, yi0) 185 | i1 = get_word_idx(context, xi, (yi1[0], yi1[1]-1)) 186 | cyi0 = answer_start - i0 187 | cyi1 = answer_stop - i1 - 1 188 | # print(answer_text, w0[cyi0:], w1[:cyi1+1]) 189 | assert answer_text[0] == w0[cyi0], (answer_text, w0, cyi0) 190 | assert answer_text[-1] == w1[cyi1] 191 | assert cyi0 < 32, (answer_text, w0) 192 | assert cyi1 < 32, (answer_text, w1) 193 | 194 | yi.append([yi0, yi1]) 195 | cyi.append([cyi0, cyi1]) 196 | 197 | for qij in qi: 198 | word_counter[qij] += 1 199 | lower_word_counter[qij.lower()] += 1 200 | for qijk in qij: 201 | char_counter[qijk] += 1 202 | 203 | q.append(qi) 204 | cq.append(cqi) 205 | y.append(yi) 206 | cy.append(cyi) 207 | rx.append(rxi) 208 | rcx.append(rxi) 209 | ids.append(qa['id']) 210 | idxs.append(len(idxs)) 211 | answerss.append(answers) 212 | 213 | if args.debug: 214 | break 215 | 216 | word2vec_dict = get_word2vec(args, word_counter) 217 | lower_word2vec_dict = get_word2vec(args, lower_word_counter) 218 | 219 | # add context here 220 | data = {'q': q, 'cq': cq, 'y': y, '*x': rx, '*cx': rcx, 'cy': cy, 221 | 'idxs': idxs, 'ids': ids, 'answerss': answerss, '*p': rx} 222 | shared = {'x': x, 'cx': cx, 'p': p, 223 | 'word_counter': word_counter, 'char_counter': char_counter, 'lower_word_counter': lower_word_counter, 224 | 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict} 225 | 226 | print("saving ...") 227 | save(args, data, shared, out_name) 228 | 229 | 230 | 231 | if __name__ == "__main__": 232 | main() -------------------------------------------------------------------------------- /squad/prepro_aug.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | # data: q, cq, (dq), (pq), y, *x, *cx 5 | # shared: x, cx, (dx), (px), word_counter, char_counter, word2vec 6 | # no metadata 7 | from collections import Counter 8 | 9 | import nltk 10 | from tqdm import tqdm 11 | 12 | from my.nltk_utils import 
load_compressed_tree
13 |
14 |
15 | def bool_(arg):
16 | if arg == 'True':
17 | return True
18 | elif arg == 'False':
19 | return False
20 | raise Exception()
21 |
22 |
23 | def main():
24 | args = get_args()
25 | prepro(args)
26 |
27 |
28 | def get_args():
29 | parser = argparse.ArgumentParser()
30 | home = os.path.expanduser("~")
31 | source_dir = os.path.join(home, "data", "squad")
32 | target_dir = "data/squad"
33 | glove_dir = os.path.join(home, "data", "glove")
34 | parser.add_argument("--source_dir", default=source_dir)
35 | parser.add_argument("--target_dir", default=target_dir)
36 | parser.add_argument("--debug", default=False, type=bool_)
37 | parser.add_argument("--train_ratio", default=0.9, type=float)  # a fraction in [0, 1], as in squad/prepro.py
38 | parser.add_argument("--glove_corpus", default="6B")
39 | parser.add_argument("--glove_dir", default=glove_dir)
40 | parser.add_argument("--glove_vec_size", default=100, type=int)
41 | parser.add_argument("--full_train", default=False, type=bool_)
42 | # TODO : put more args here
43 | return parser.parse_args()
44 |
45 |
46 | def prepro(args):
47 | if not os.path.exists(args.target_dir):
48 | os.makedirs(args.target_dir)
49 |
50 | if args.full_train:
51 | data_train, shared_train = prepro_each(args, 'train')
52 | data_dev, shared_dev = prepro_each(args, 'dev')
53 | else:
54 | data_train, shared_train = prepro_each(args, 'train', 0.0, args.train_ratio)
55 | data_dev, shared_dev = prepro_each(args, 'train', args.train_ratio, 1.0)
56 | data_test, shared_test = prepro_each(args, 'dev')
57 |
58 | print("saving ...")
59 | save(args, data_train, shared_train, 'train')
60 | save(args, data_dev, shared_dev, 'dev')
61 | save(args, data_test, shared_test, 'test')
62 |
63 |
64 | def save(args, data, shared, data_type):
65 | data_path = os.path.join(args.target_dir, "data_{}.json".format(data_type))
66 | shared_path = os.path.join(args.target_dir, "shared_{}.json".format(data_type))
67 | json.dump(data, open(data_path, 'w'))
68 | json.dump(shared, open(shared_path, 'w'))
69 |
70 |
71 | def get_word2vec(args, word_counter):
72 | glove_path = os.path.join(args.glove_dir, "glove.{}.{}d.txt".format(args.glove_corpus, args.glove_vec_size))
73 | sizes = {'6B': int(4e5), '42B': int(1.9e6), '840B': int(2.2e6), '2B': int(1.2e6)}
74 | total = sizes[args.glove_corpus]
75 | word2vec_dict = {}
76 | with open(glove_path, 'r') as fh:
77 | for line in tqdm(fh, total=total):
78 | array = line.lstrip().rstrip().split(" ")
79 | word = array[0]
80 | vector = list(map(float, array[1:]))
81 | if word in word_counter:
82 | word2vec_dict[word] = vector
83 | elif word.capitalize() in word_counter:
84 | word2vec_dict[word.capitalize()] = vector
85 | elif word.lower() in word_counter:
86 | word2vec_dict[word.lower()] = vector
87 | elif word.upper() in word_counter:
88 | word2vec_dict[word.upper()] = vector
89 |
90 | print("{}/{} of word vocab have corresponding vectors in {}".format(len(word2vec_dict), len(word_counter), glove_path))
91 | return word2vec_dict
92 |
93 |
94 | def prepro_each(args, data_type, start_ratio=0.0, stop_ratio=1.0):
95 | source_path = os.path.join(args.source_dir, "{}-v1.0-aug.json".format(data_type))
96 | source_data = json.load(open(source_path, 'r'))
97 |
98 | q, cq, y, rx, rcx, ids, idxs = [], [], [], [], [], [], []
99 | x, cx, tx, stx = [], [], [], []
100 | answerss = []
101 | word_counter, char_counter, lower_word_counter = Counter(), Counter(), Counter()
102 | pos_counter = Counter()
103 | start_ai = int(round(len(source_data['data']) * start_ratio))
104 | stop_ai = 
int(round(len(source_data['data']) * stop_ratio)) 105 | for ai, article in enumerate(tqdm(source_data['data'][start_ai:stop_ai])): 106 | xp, cxp, txp, stxp = [], [], [], [] 107 | x.append(xp) 108 | cx.append(cxp) 109 | tx.append(txp) 110 | stx.append(stxp) 111 | for pi, para in enumerate(article['paragraphs']): 112 | xi = [] 113 | for dep in para['deps']: 114 | if dep is None: 115 | xi.append([]) 116 | else: 117 | xi.append([node[0] for node in dep[0]]) 118 | cxi = [[list(xijk) for xijk in xij] for xij in xi] 119 | xp.append(xi) 120 | cxp.append(cxi) 121 | txp.append(para['consts']) 122 | stxp.append([str(load_compressed_tree(s)) for s in para['consts']]) 123 | trees = map(nltk.tree.Tree.fromstring, para['consts']) 124 | for tree in trees: 125 | for subtree in tree.subtrees(): 126 | pos_counter[subtree.label()] += 1 127 | 128 | for xij in xi: 129 | for xijk in xij: 130 | word_counter[xijk] += len(para['qas']) 131 | lower_word_counter[xijk.lower()] += len(para['qas']) 132 | for xijkl in xijk: 133 | char_counter[xijkl] += len(para['qas']) 134 | 135 | rxi = [ai, pi] 136 | assert len(x) - 1 == ai 137 | assert len(x[ai]) - 1 == pi 138 | for qa in para['qas']: 139 | dep = qa['dep'] 140 | qi = [] if dep is None else [node[0] for node in dep[0]] 141 | cqi = [list(qij) for qij in qi] 142 | yi = [] 143 | answers = [] 144 | for answer in qa['answers']: 145 | answers.append(answer['text']) 146 | yi0 = answer['answer_word_start'] or [0, 0] 147 | yi1 = answer['answer_word_stop'] or [0, 1] 148 | assert len(xi[yi0[0]]) > yi0[1] 149 | assert len(xi[yi1[0]]) >= yi1[1] 150 | yi.append([yi0, yi1]) 151 | 152 | for qij in qi: 153 | word_counter[qij] += 1 154 | lower_word_counter[qij.lower()] += 1 155 | for qijk in qij: 156 | char_counter[qijk] += 1 157 | 158 | q.append(qi) 159 | cq.append(cqi) 160 | y.append(yi) 161 | rx.append(rxi) 162 | rcx.append(rxi) 163 | ids.append(qa['id']) 164 | idxs.append(len(idxs)) 165 | answerss.append(answers) 166 | 167 | if args.debug: 168 | break 169 | 170 | word2vec_dict = get_word2vec(args, word_counter) 171 | lower_word2vec_dict = get_word2vec(args, lower_word_counter) 172 | 173 | data = {'q': q, 'cq': cq, 'y': y, '*x': rx, '*cx': rcx, '*tx': rx, '*stx': rx, 174 | 'idxs': idxs, 'ids': ids, 'answerss': answerss} 175 | shared = {'x': x, 'cx': cx, 'tx': tx, 'stx': stx, 176 | 'word_counter': word_counter, 'char_counter': char_counter, 'lower_word_counter': lower_word_counter, 177 | 'word2vec': word2vec_dict, 'lower_word2vec': lower_word2vec_dict, 'pos_counter': pos_counter} 178 | 179 | return data, shared 180 | 181 | 182 | if __name__ == "__main__": 183 | main() -------------------------------------------------------------------------------- /squad/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def get_2d_spans(text, tokenss): 5 | spanss = [] 6 | cur_idx = 0 7 | for tokens in tokenss: 8 | spans = [] 9 | for token in tokens: 10 | if text.find(token, cur_idx) < 0: 11 | print(tokens) 12 | print("{} {} {}".format(token, cur_idx, text)) 13 | raise Exception() 14 | cur_idx = text.find(token, cur_idx) 15 | spans.append((cur_idx, cur_idx + len(token))) 16 | cur_idx += len(token) 17 | spanss.append(spans) 18 | return spanss 19 | 20 | 21 | def get_word_span(context, wordss, start, stop): 22 | spanss = get_2d_spans(context, wordss) 23 | idxs = [] 24 | for sent_idx, spans in enumerate(spanss): 25 | for word_idx, span in enumerate(spans): 26 | if not (stop <= span[0] or start >= span[1]): 27 | idxs.append((sent_idx, 
word_idx))
28 |
29 | assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop)
30 | return idxs[0], (idxs[-1][0], idxs[-1][1] + 1)
31 |
32 |
33 | def get_phrase(context, wordss, span):
34 | """
35 | Obtain phrase as substring of context given start and stop indices in word level
36 | :param context:
37 | :param wordss:
38 | :param span: pair of (sent_idx, word_idx) positions:
39 | start (inclusive) and stop (exclusive)
40 | :return:
41 | """
42 | start, stop = span
43 | flat_start = get_flat_idx(wordss, start)
44 | flat_stop = get_flat_idx(wordss, stop)
45 | words = sum(wordss, [])
46 | char_idx = 0
47 | char_start, char_stop = None, None
48 | for word_idx, word in enumerate(words):
49 | char_idx = context.find(word, char_idx)
50 | assert char_idx >= 0
51 | if word_idx == flat_start:
52 | char_start = char_idx
53 | char_idx += len(word)
54 | if word_idx == flat_stop - 1:
55 | char_stop = char_idx
56 | assert char_start is not None
57 | assert char_stop is not None
58 | return context[char_start:char_stop]
59 |
60 |
61 | def get_flat_idx(wordss, idx):
62 | return sum(len(words) for words in wordss[:idx[0]]) + idx[1]
63 |
64 |
65 | def get_word_idx(context, wordss, idx):
66 | spanss = get_2d_spans(context, wordss)
67 | return spanss[idx[0]][idx[1]][0]
68 |
69 |
70 | def process_tokens(temp_tokens):
71 | tokens = []
72 | for token in temp_tokens:
73 | # split each token further on the delimiter characters in l below
74 | l = ("-", "\u2212", "\u2014", "\u2013", "/", "~", '"', "'", "\u201C", "\u2019", "\u201D", "\u2018", "\u00B0")
75 | # \u2013 is en-dash. Used for number-to-number ranges
76 | # l = ("-", "\u2212", "\u2014", "\u2013")
77 | # l = ("\u2013",)
78 | tokens.extend(re.split("([{}])".format("".join(l)), token))
79 | return tokens
80 |
81 |
82 | def get_best_span(ypi, yp2i):
83 | max_val = 0
84 | best_word_span = (0, 1)
85 | best_sent_idx = 0
86 | for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
87 | argmax_j1 = 0
88 | for j in range(len(ypif)):
89 | val1 = ypif[argmax_j1]
90 | if val1 < ypif[j]:
91 | val1 = ypif[j]
92 | argmax_j1 = j
93 |
94 | val2 = yp2if[j]
95 | if val1 * val2 > max_val:
96 | best_word_span = (argmax_j1, j)
97 | best_sent_idx = f
98 | max_val = val1 * val2
99 | return ((best_sent_idx, best_word_span[0]), (best_sent_idx, best_word_span[1] + 1)), float(max_val)
100 |
101 |
102 | def get_span_score_pairs(ypi, yp2i):
103 | span_score_pairs = []
104 | for f, (ypif, yp2if) in enumerate(zip(ypi, yp2i)):
105 | for j in range(len(ypif)):
106 | for k in range(j, len(yp2if)):
107 | span = ((f, j), (f, k+1))
108 | score = ypif[j] * yp2if[k]
109 | span_score_pairs.append((span, score))
110 | return span_score_pairs
111 |
112 |
113 |
-------------------------------------------------------------------------------- /tree/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/bi-att-flow/49004549e9a88b78c359b31481afa7792dbb3f4a/tree/__init__.py
-------------------------------------------------------------------------------- /tree/cli.py: --------------------------------------------------------------------------------
1 | import os
2 | from pprint import pprint
3 |
4 | import tensorflow as tf
5 |
6 | from tree.main import main as m
7 |
8 | flags = tf.app.flags
9 |
10 | flags.DEFINE_string("model_name", "tree", "Model name [tree]")
11 | flags.DEFINE_string("data_dir", "data/squad", "Data dir [data/squad]")
12 | flags.DEFINE_integer("run_id", 0, "Run ID [0]")
13 |
14 | flags.DEFINE_integer("batch_size", 128, "Batch size [128]")
15 | flags.DEFINE_float("init_lr", 0.5, "Initial learning rate [0.5]")
16 | flags.DEFINE_integer("num_epochs", 50, "Total number of epochs for training [50]")
17 | flags.DEFINE_integer("num_steps", 0, "Number of steps [0]")
18 | flags.DEFINE_integer("eval_num_batches", 100, "eval num batches [100]")
19 | flags.DEFINE_integer("load_step", 0, "load step [0]")
20 | flags.DEFINE_integer("early_stop", 4, "early stop [4]")
21 |
22 | flags.DEFINE_string("mode", "test", "train | test | forward [test]")
23 | flags.DEFINE_boolean("load", True, "load saved data? [True]")
24 | flags.DEFINE_boolean("progress", True, "Show progress? [True]")
25 | flags.DEFINE_integer("log_period", 100, "Log period [100]")
26 | flags.DEFINE_integer("eval_period", 1000, "Eval period [1000]")
27 | flags.DEFINE_integer("save_period", 1000, "Save Period [1000]")
28 | flags.DEFINE_float("decay", 0.9, "Exponential moving average decay [0.9]")
29 |
30 | flags.DEFINE_boolean("draft", False, "Draft for quick testing? [False]")
31 |
32 | flags.DEFINE_integer("hidden_size", 32, "Hidden size [32]")
33 | flags.DEFINE_float("input_keep_prob", 0.5, "Input keep prob [0.5]")
34 | flags.DEFINE_integer("char_emb_size", 8, "Char emb size [8]")
35 | flags.DEFINE_integer("char_filter_height", 5, "Char filter height [5]")
36 | flags.DEFINE_float("wd", 0.0001, "Weight decay [0.0001]")
37 | flags.DEFINE_bool("lower_word", True, "lower word [True]")
38 | flags.DEFINE_bool("dump_eval", True, "dump eval? [True]")
39 |
40 | flags.DEFINE_integer("word_count_th", 100, "word count th [100]")
41 | flags.DEFINE_integer("char_count_th", 500, "char count th [500]")
42 | flags.DEFINE_integer("sent_size_th", 64, "sent size th [64]")
43 | flags.DEFINE_integer("num_sents_th", 8, "num sents th [8]")
44 | flags.DEFINE_integer("ques_size_th", 64, "ques size th [64]")
45 | flags.DEFINE_integer("word_size_th", 16, "word size th [16]")
46 | flags.DEFINE_integer("tree_height_th", 16, "tree height th [16]")
47 |
48 |
49 | def main(_):
50 | config = flags.FLAGS
51 |
52 | config.out_dir = os.path.join("out", config.model_name, str(config.run_id).zfill(2))
53 |
54 | m(config)
55 |
56 | if __name__ == "__main__":
57 | tf.app.run()
58 |
-------------------------------------------------------------------------------- /tree/evaluator.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from tree.read_data import DataSet
5 | from my.nltk_utils import span_f1
6 |
7 |
8 | class Evaluation(object):
9 | def __init__(self, data_type, global_step, idxs, yp):
10 | self.data_type = data_type
11 | self.global_step = global_step
12 | self.idxs = idxs
13 | self.yp = yp
14 | self.num_examples = len(yp)
15 | self.dict = {'data_type': data_type,
16 | 'global_step': global_step,
17 | 'yp': yp,
18 | 'idxs': idxs,
19 | 'num_examples': self.num_examples}
20 | self.summaries = None
21 |
22 | def __repr__(self):
23 | return "{} step {}".format(self.data_type, self.global_step)
24 |
25 | def __add__(self, other):
26 | if other == 0:
27 | return self
28 | assert self.data_type == other.data_type
29 | assert self.global_step == other.global_step
30 | new_yp = self.yp + other.yp
31 | new_idxs = self.idxs + other.idxs
32 | return Evaluation(self.data_type, self.global_step, new_idxs, new_yp)
33 |
34 | def __radd__(self, other):
35 | return self.__add__(other)
36 |
37 |
38 | class LabeledEvaluation(Evaluation):
39 | def __init__(self, data_type, global_step, idxs, yp, y):
40 | super(LabeledEvaluation, self).__init__(data_type, global_step, idxs, yp)
41
| self.y = y 42 | self.dict['y'] = y 43 | 44 | def __add__(self, other): 45 | if other == 0: 46 | return self 47 | assert self.data_type == other.data_type 48 | assert self.global_step == other.global_step 49 | new_yp = self.yp + other.yp 50 | new_y = self.y + other.y 51 | new_idxs = self.idxs + other.idxs 52 | return LabeledEvaluation(self.data_type, self.global_step, new_idxs, new_yp, new_y) 53 | 54 | 55 | class AccuracyEvaluation(LabeledEvaluation): 56 | def __init__(self, data_type, global_step, idxs, yp, y, correct, loss): 57 | super(AccuracyEvaluation, self).__init__(data_type, global_step, idxs, yp, y) 58 | self.loss = loss 59 | self.correct = correct 60 | self.acc = sum(correct) / len(correct) 61 | self.dict['loss'] = loss 62 | self.dict['correct'] = correct 63 | self.dict['acc'] = self.acc 64 | loss_summary = tf.Summary(value=[tf.Summary.Value(tag='dev/loss', simple_value=self.loss)]) 65 | acc_summary = tf.Summary(value=[tf.Summary.Value(tag='dev/acc', simple_value=self.acc)]) 66 | self.summaries = [loss_summary, acc_summary] 67 | 68 | def __repr__(self): 69 | return "{} step {}: accuracy={}, loss={}".format(self.data_type, self.global_step, self.acc, self.loss) 70 | 71 | def __add__(self, other): 72 | if other == 0: 73 | return self 74 | assert self.data_type == other.data_type 75 | assert self.global_step == other.global_step 76 | new_idxs = self.idxs + other.idxs 77 | new_yp = self.yp + other.yp 78 | new_y = self.y + other.y 79 | new_correct = self.correct + other.correct 80 | new_loss = (self.loss * self.num_examples + other.loss * other.num_examples) / len(new_correct) 81 | return AccuracyEvaluation(self.data_type, self.global_step, new_idxs, new_yp, new_y, new_correct, new_loss) 82 | 83 | 84 | class Evaluator(object): 85 | def __init__(self, config, model): 86 | self.config = config 87 | self.model = model 88 | 89 | def get_evaluation(self, sess, batch): 90 | idxs, data_set = batch 91 | feed_dict = self.model.get_feed_dict(data_set, False, supervised=False) 92 | global_step, yp = sess.run([self.model.global_step, self.model.yp], feed_dict=feed_dict) 93 | yp = yp[:data_set.num_examples] 94 | e = Evaluation(data_set.data_type, int(global_step), idxs, yp.tolist()) 95 | return e 96 | 97 | def get_evaluation_from_batches(self, sess, batches): 98 | e = sum(self.get_evaluation(sess, batch) for batch in batches) 99 | return e 100 | 101 | 102 | class LabeledEvaluator(Evaluator): 103 | def get_evaluation(self, sess, batch): 104 | idxs, data_set = batch 105 | feed_dict = self.model.get_feed_dict(data_set, False, supervised=False) 106 | global_step, yp = sess.run([self.model.global_step, self.model.yp], feed_dict=feed_dict) 107 | yp = yp[:data_set.num_examples] 108 | y = feed_dict[self.model.y] 109 | e = LabeledEvaluation(data_set.data_type, int(global_step), idxs, yp.tolist(), y.tolist()) 110 | return e 111 | 112 | 113 | class AccuracyEvaluator(LabeledEvaluator): 114 | def get_evaluation(self, sess, batch): 115 | idxs, data_set = batch 116 | assert isinstance(data_set, DataSet) 117 | feed_dict = self.model.get_feed_dict(data_set, False) 118 | global_step, yp, loss = sess.run([self.model.global_step, self.model.yp, self.model.loss], feed_dict=feed_dict) 119 | y = feed_dict[self.model.y] 120 | yp = yp[:data_set.num_examples] 121 | correct = [self.__class__.compare(yi, ypi) for yi, ypi in zip(y, yp)] 122 | e = AccuracyEvaluation(data_set.data_type, int(global_step), idxs, yp.tolist(), y.tolist(), correct, float(loss)) 123 | return e 124 | 125 | @staticmethod 126 | def compare(yi, ypi): 
127 | return int(np.argmax(yi)) == int(np.argmax(ypi)) 128 | 129 | 130 | class AccuracyEvaluator2(AccuracyEvaluator): 131 | @staticmethod 132 | def compare(yi, ypi): 133 | i = int(np.argmax(yi.flatten())) 134 | j = int(np.argmax(ypi.flatten())) 135 | # print(i, j, i == j) 136 | return i == j 137 | 138 | 139 | class TempEvaluation(AccuracyEvaluation): 140 | def __init__(self, data_type, global_step, idxs, yp, yp2, y, y2, correct, loss, f1s): 141 | super(TempEvaluation, self).__init__(data_type, global_step, idxs, yp, y, correct, loss) 142 | self.y2 = y2 143 | self.yp2 = yp2 144 | self.f1s = f1s 145 | self.f1 = float(np.mean(f1s)) 146 | self.dict['y2'] = y2 147 | self.dict['yp2'] = yp2 148 | self.dict['f1s'] = f1s 149 | self.dict['f1'] = self.f1 150 | f1_summary = tf.Summary(value=[tf.Summary.Value(tag='dev/f1', simple_value=self.f1)]) 151 | self.summaries.append(f1_summary) 152 | 153 | def __add__(self, other): 154 | if other == 0: 155 | return self 156 | assert self.data_type == other.data_type 157 | assert self.global_step == other.global_step 158 | new_idxs = self.idxs + other.idxs 159 | new_yp = self.yp + other.yp 160 | new_yp2 = self.yp2 + other.yp2 161 | new_y = self.y + other.y 162 | new_y2 = self.y2 + other.y2 163 | new_correct = self.correct + other.correct 164 | new_f1s = self.f1s + other.f1s 165 | new_loss = (self.loss * self.num_examples + other.loss * other.num_examples) / len(new_correct) 166 | return TempEvaluation(self.data_type, self.global_step, new_idxs, new_yp, new_yp2, new_y, new_y2, new_correct, new_loss, new_f1s) 167 | 168 | 169 | class TempEvaluator(LabeledEvaluator): 170 | def get_evaluation(self, sess, batch): 171 | idxs, data_set = batch 172 | assert isinstance(data_set, DataSet) 173 | feed_dict = self.model.get_feed_dict(data_set, False) 174 | global_step, yp, yp2, loss = sess.run([self.model.global_step, self.model.yp, self.model.yp2, self.model.loss], feed_dict=feed_dict) 175 | y, y2 = feed_dict[self.model.y], feed_dict[self.model.y2] 176 | yp, yp2 = yp[:data_set.num_examples], yp2[:data_set.num_examples] 177 | correct = [self.__class__.compare(yi, y2i, ypi, yp2i) for yi, y2i, ypi, yp2i in zip(y, y2, yp, yp2)] 178 | f1s = [self.__class__.span_f1(yi, y2i, ypi, yp2i) for yi, y2i, ypi, yp2i in zip(y, y2, yp, yp2)] 179 | e = TempEvaluation(data_set.data_type, int(global_step), idxs, yp.tolist(), yp2.tolist(), y.tolist(), y2.tolist(), correct, float(loss), f1s) 180 | return e 181 | 182 | @staticmethod 183 | def compare(yi, y2i, ypi, yp2i): 184 | i = int(np.argmax(yi.flatten())) 185 | j = int(np.argmax(ypi.flatten())) 186 | k = int(np.argmax(y2i.flatten())) 187 | l = int(np.argmax(yp2i.flatten())) 188 | # print(i, j, i == j) 189 | return i == j and k == l 190 | 191 | @staticmethod 192 | def span_f1(yi, y2i, ypi, yp2i): 193 | true_span = (np.argmax(yi.flatten()), np.argmax(y2i.flatten())+1) 194 | pred_span = (np.argmax(ypi.flatten()), np.argmax(yp2i.flatten())+1) 195 | f1 = span_f1(true_span, pred_span) 196 | return f1 197 | 198 | -------------------------------------------------------------------------------- /tree/graph_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | from json import encoder 3 | import os 4 | 5 | import tensorflow as tf 6 | 7 | from tree.evaluator import Evaluation 8 | from my.utils import short_floats 9 | 10 | 11 | class GraphHandler(object): 12 | def __init__(self, config): 13 | self.config = config 14 | self.saver = tf.train.Saver() 15 | self.writer = None 16 | self.save_path = 
os.path.join(config.save_dir, config.model_name) 17 | 18 | def initialize(self, sess): 19 | if self.config.load: 20 | self._load(sess) 21 | else: 22 | sess.run(tf.initialize_all_variables()) 23 | 24 | if self.config.mode == 'train': 25 | self.writer = tf.train.SummaryWriter(self.config.log_dir, graph=tf.get_default_graph()) 26 | 27 | def save(self, sess, global_step=None): 28 | self.saver.save(sess, self.save_path, global_step=global_step) 29 | 30 | def _load(self, sess): 31 | config = self.config 32 | if config.load_step > 0: 33 | save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step)) 34 | else: 35 | save_dir = config.save_dir 36 | checkpoint = tf.train.get_checkpoint_state(save_dir) 37 | assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir) 38 | save_path = checkpoint.model_checkpoint_path 39 | print("Loading saved model from {}".format(save_path)) 40 | self.saver.restore(sess, save_path) 41 | 42 | def add_summary(self, summary, global_step): 43 | self.writer.add_summary(summary, global_step) 44 | 45 | def add_summaries(self, summaries, global_step): 46 | for summary in summaries: 47 | self.add_summary(summary, global_step) 48 | 49 | def dump_eval(self, e, precision=2): 50 | assert isinstance(e, Evaluation) 51 | path = os.path.join(self.config.eval_dir, "{}-{}.json".format(e.data_type, str(e.global_step).zfill(6))) 52 | with open(path, 'w') as fh: 53 | json.dump(short_floats(e.dict, precision), fh) 54 | 55 | -------------------------------------------------------------------------------- /tree/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | import shutil 6 | from pprint import pprint 7 | 8 | import tensorflow as tf 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from tree.evaluator import AccuracyEvaluator2, Evaluator 13 | from tree.graph_handler import GraphHandler 14 | from tree.model import Model 15 | from tree.trainer import Trainer 16 | 17 | from tree.read_data import load_metadata, read_data, get_squad_data_filter, update_config 18 | 19 | 20 | def main(config): 21 | set_dirs(config) 22 | if config.mode == 'train': 23 | _train(config) 24 | elif config.mode == 'test': 25 | _test(config) 26 | elif config.mode == 'forward': 27 | _forward(config) 28 | else: 29 | raise ValueError("invalid value for 'mode': {}".format(config.mode)) 30 | 31 | 32 | def _config_draft(config): 33 | if config.draft: 34 | config.num_steps = 10 35 | config.eval_period = 10 36 | config.log_period = 1 37 | config.save_period = 10 38 | config.eval_num_batches = 1 39 | 40 | 41 | def _train(config): 42 | # load_metadata(config, 'train') # this updates the config file according to metadata file 43 | 44 | data_filter = get_squad_data_filter(config) 45 | train_data = read_data(config, 'train', config.load, data_filter=data_filter) 46 | dev_data = read_data(config, 'dev', True, data_filter=data_filter) 47 | update_config(config, [train_data, dev_data]) 48 | 49 | _config_draft(config) 50 | 51 | word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec'] 52 | word2idx_dict = train_data.shared['word2idx'] 53 | idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict} 54 | print("{}/{} unique words have corresponding glove vectors.".format(len(idx2vec_dict), len(word2idx_dict))) 55 | emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict 
56 | else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size)) 57 | for idx in range(config.word_vocab_size)]) 58 | config.emb_mat = emb_mat 59 | 60 | # construct model graph and variables (using default graph) 61 | pprint(config.__flags, indent=2) 62 | model = Model(config) 63 | trainer = Trainer(config, model) 64 | evaluator = AccuracyEvaluator2(config, model) 65 | graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading/saving 66 | 67 | # Variables 68 | sess = tf.Session() 69 | graph_handler.initialize(sess) 70 | 71 | # begin training 72 | num_steps = config.num_steps or int(config.num_epochs * train_data.num_examples / config.batch_size) 73 | max_acc = 0 74 | noupdate_count = 0 75 | global_step = 0 76 | for _, batch in tqdm(train_data.get_batches(config.batch_size, num_batches=num_steps, shuffle=True), total=num_steps): 77 | global_step = sess.run(model.global_step) + 1 # +1 because all calculations are done after the step 78 | get_summary = global_step % config.log_period == 0 79 | loss, summary, train_op = trainer.step(sess, batch, get_summary=get_summary) 80 | if get_summary: 81 | graph_handler.add_summary(summary, global_step) 82 | 83 | # Occasional evaluation and saving 84 | if global_step % config.save_period == 0: 85 | graph_handler.save(sess, global_step=global_step) 86 | if global_step % config.eval_period == 0: 87 | num_batches = math.ceil(dev_data.num_examples / config.batch_size) 88 | if 0 < config.eval_num_batches < num_batches: 89 | num_batches = config.eval_num_batches 90 | e = evaluator.get_evaluation_from_batches( 91 | sess, tqdm(dev_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches)) 92 | graph_handler.add_summaries(e.summaries, global_step) 93 | if e.acc > max_acc: 94 | max_acc = e.acc 95 | noupdate_count = 0 96 | else: 97 | noupdate_count += 1 98 | if noupdate_count == config.early_stop: 99 | break 100 | if config.dump_eval: 101 | graph_handler.dump_eval(e) 102 | if global_step % config.save_period != 0: 103 | graph_handler.save(sess, global_step=global_step) 104 | 105 | 106 | def _test(config): 107 | test_data = read_data(config, 'test', True) 108 | update_config(config, [test_data]) 109 | 110 | _config_draft(config) 111 | 112 | pprint(config.__flags, indent=2) 113 | model = Model(config) 114 | evaluator = AccuracyEvaluator2(config, model) 115 | graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading/saving 116 | 117 | sess = tf.Session() 118 | graph_handler.initialize(sess) 119 | 120 | num_batches = math.ceil(test_data.num_examples / config.batch_size) 121 | if 0 < config.eval_num_batches < num_batches: 122 | num_batches = config.eval_num_batches 123 | e = evaluator.get_evaluation_from_batches(sess, tqdm(test_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches)) 124 | print(e) 125 | if config.dump_eval: 126 | graph_handler.dump_eval(e) 127 | 128 | 129 | def _forward(config): 130 | 131 | forward_data = read_data(config, 'forward', True) 132 | 133 | _config_draft(config) 134 | 135 | pprint(config.__flags, indent=2) 136 | model = Model(config) 137 | evaluator = Evaluator(config, model) 138 | graph_handler = GraphHandler(config) # controls all tensors and variables in the graph, including loading/saving 139 | 140 | sess = tf.Session() 141 | graph_handler.initialize(sess) 142 | 143 | num_batches = math.ceil(forward_data.num_examples / config.batch_size) 144 | if 0 <
config.eval_num_batches < num_batches: 145 | num_batches = config.eval_num_batches 146 | e = evaluator.get_evaluation_from_batches(sess, tqdm(forward_data.get_batches(config.batch_size, num_batches=num_batches), total=num_batches)) 147 | print(e) 148 | if config.dump_eval: 149 | graph_handler.dump_eval(e) 150 | 151 | 152 | def set_dirs(config): 153 | # create directories 154 | if not config.load and os.path.exists(config.out_dir): 155 | shutil.rmtree(config.out_dir) 156 | 157 | config.save_dir = os.path.join(config.out_dir, "save") 158 | config.log_dir = os.path.join(config.out_dir, "log") 159 | config.eval_dir = os.path.join(config.out_dir, "eval") 160 | if not os.path.exists(config.out_dir): 161 | os.makedirs(config.out_dir) 162 | if not os.path.exists(config.save_dir): 163 | os.mkdir(config.save_dir) 164 | if not os.path.exists(config.log_dir): os.mkdir(config.log_dir) 165 | if not os.path.exists(config.eval_dir): os.mkdir(config.eval_dir) 166 | 167 | 168 | def _get_args(): 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("config_path") 171 | return parser.parse_args() 172 | 173 | 174 | class Config(object): 175 | def __init__(self, **entries): 176 | self.__dict__.update(entries) 177 | 178 | 179 | def _run(): 180 | args = _get_args() 181 | with open(args.config_path, 'r') as fh: 182 | config = Config(**json.load(fh)) 183 | main(config) 184 | 185 | 186 | if __name__ == "__main__": 187 | _run() 188 | -------------------------------------------------------------------------------- /tree/read_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import itertools 5 | import math 6 | 7 | import nltk 8 | 9 | from my.nltk_utils import load_compressed_tree 10 | from my.utils import index 11 | 12 | 13 | class DataSet(object): 14 | def __init__(self, data, data_type, shared=None, valid_idxs=None): 15 | total_num_examples = len(next(iter(data.values()))) 16 | self.data = data # e.g.
{'X': [0, 1, 2], 'Y': [2, 3, 4]} 17 | self.data_type = data_type 18 | self.shared = shared 19 | self.valid_idxs = range(total_num_examples) if valid_idxs is None else valid_idxs 20 | self.num_examples = len(self.valid_idxs) 21 | 22 | def get_batches(self, batch_size, num_batches=None, shuffle=False): 23 | num_batches_per_epoch = int(math.ceil(self.num_examples / batch_size)) 24 | if num_batches is None: 25 | num_batches = num_batches_per_epoch 26 | num_epochs = int(math.ceil(num_batches / num_batches_per_epoch)) 27 | 28 | idxs = itertools.chain.from_iterable(random.sample(self.valid_idxs, len(self.valid_idxs)) 29 | if shuffle else self.valid_idxs 30 | for _ in range(num_epochs)) 31 | for _ in range(num_batches): 32 | batch_idxs = tuple(itertools.islice(idxs, batch_size)) 33 | batch_data = {} 34 | for key, val in self.data.items(): 35 | if key.startswith('*'): 36 | assert self.shared is not None 37 | shared_key = key[1:] 38 | batch_data[shared_key] = [index(self.shared[shared_key], val[idx]) for idx in batch_idxs] 39 | else: 40 | batch_data[key] = list(map(val.__getitem__, batch_idxs)) 41 | 42 | batch_ds = DataSet(batch_data, self.data_type, shared=self.shared) 43 | yield batch_idxs, batch_ds 44 | 45 | 46 | class SquadDataSet(DataSet): 47 | def __init__(self, data, data_type, shared=None, valid_idxs=None): 48 | super(SquadDataSet, self).__init__(data, data_type, shared=shared, valid_idxs=valid_idxs) 49 | 50 | 51 | def load_metadata(config, data_type): 52 | metadata_path = os.path.join(config.data_dir, "metadata_{}.json".format(data_type)) 53 | with open(metadata_path, 'r') as fh: 54 | metadata = json.load(fh) 55 | for key, val in metadata.items(): 56 | config.__setattr__(key, val) 57 | return metadata 58 | 59 | 60 | def read_data(config, data_type, ref, data_filter=None): 61 | data_path = os.path.join(config.data_dir, "data_{}.json".format(data_type)) 62 | shared_path = os.path.join(config.data_dir, "shared_{}.json".format(data_type)) 63 | with open(data_path, 'r') as fh: 64 | data = json.load(fh) 65 | with open(shared_path, 'r') as fh: 66 | shared = json.load(fh) 67 | 68 | num_examples = len(next(iter(data.values()))) 69 | if data_filter is None: 70 | valid_idxs = range(num_examples) 71 | else: 72 | mask = [] 73 | keys = data.keys() 74 | values = data.values() 75 | for vals in zip(*values): 76 | each = {key: val for key, val in zip(keys, vals)} 77 | mask.append(data_filter(each, shared)) 78 | valid_idxs = [idx for idx in range(len(mask)) if mask[idx]] 79 | 80 | print("Loaded {}/{} examples from {}".format(len(valid_idxs), num_examples, data_type)) 81 | 82 | shared_path = os.path.join(config.out_dir, "shared.json") 83 | if not ref: 84 | word_counter = shared['lower_word_counter'] if config.lower_word else shared['word_counter'] 85 | char_counter = shared['char_counter'] 86 | pos_counter = shared['pos_counter'] 87 | shared['word2idx'] = {word: idx + 2 for idx, word in 88 | enumerate(word for word, count in word_counter.items() 89 | if count > config.word_count_th)} 90 | shared['char2idx'] = {char: idx + 2 for idx, char in 91 | enumerate(char for char, count in char_counter.items() 92 | if count > config.char_count_th)} 93 | shared['pos2idx'] = {pos: idx + 2 for idx, pos in enumerate(pos_counter.keys())} 94 | NULL = "-NULL-" 95 | UNK = "-UNK-" 96 | shared['word2idx'][NULL] = 0 97 | shared['word2idx'][UNK] = 1 98 | shared['char2idx'][NULL] = 0 99 | shared['char2idx'][UNK] = 1 100 | shared['pos2idx'][NULL] = 0 101 | shared['pos2idx'][UNK] = 1 102 | json.dump({'word2idx': shared['word2idx'], 
'char2idx': shared['char2idx'], 103 | 'pos2idx': shared['pos2idx']}, open(shared_path, 'w')) 104 | else: 105 | new_shared = json.load(open(shared_path, 'r')) 106 | for key, val in new_shared.items(): 107 | shared[key] = val 108 | 109 | data_set = DataSet(data, data_type, shared=shared, valid_idxs=valid_idxs) 110 | return data_set 111 | 112 | 113 | def get_squad_data_filter(config): 114 | def data_filter(data_point, shared): 115 | assert shared is not None 116 | rx, rcx, q, cq, y = (data_point[key] for key in ('*x', '*cx', 'q', 'cq', 'y')) 117 | x, cx, stx = shared['x'], shared['cx'], shared['stx'] 118 | if len(q) > config.ques_size_th: 119 | return False 120 | xi = x[rx[0]][rx[1]] 121 | if len(xi) > config.num_sents_th: 122 | return False 123 | if any(len(xij) > config.sent_size_th for xij in xi): 124 | return False 125 | stxi = stx[rx[0]][rx[1]] 126 | if any(nltk.tree.Tree.fromstring(s).height() > config.tree_height_th for s in stxi): 127 | return False 128 | return True 129 | return data_filter 130 | 131 | 132 | def update_config(config, data_sets): 133 | config.max_num_sents = 0 134 | config.max_sent_size = 0 135 | config.max_ques_size = 0 136 | config.max_word_size = 0 137 | config.max_tree_height = 0 138 | for data_set in data_sets: 139 | data = data_set.data 140 | shared = data_set.shared 141 | for idx in data_set.valid_idxs: 142 | rx = data['*x'][idx] 143 | q = data['q'][idx] 144 | sents = shared['x'][rx[0]][rx[1]] 145 | trees = map(nltk.tree.Tree.fromstring, shared['stx'][rx[0]][rx[1]]) 146 | config.max_tree_height = max(config.max_tree_height, max(tree.height() for tree in trees)) 147 | config.max_num_sents = max(config.max_num_sents, len(sents)) 148 | config.max_sent_size = max(config.max_sent_size, max(map(len, sents))) 149 | config.max_word_size = max(config.max_word_size, max(len(word) for sent in sents for word in sent)) 150 | if len(q) > 0: 151 | config.max_ques_size = max(config.max_ques_size, len(q)) 152 | config.max_word_size = max(config.max_word_size, max(len(word) for word in q)) 153 | 154 | config.max_word_size = min(config.max_word_size, config.word_size_th) 155 | 156 | config.char_vocab_size = len(data_sets[0].shared['char2idx']) 157 | config.word_emb_size = len(next(iter(data_sets[0].shared['word2vec'].values()))) 158 | config.word_vocab_size = len(data_sets[0].shared['word2idx']) 159 | config.pos_vocab_size = len(data_sets[0].shared['pos2idx']) 160 | -------------------------------------------------------------------------------- /tree/templates/visualizer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | {{ title }} 6 | 7 | 8 | 19 | 20 | 23 | 24 |

<h2>{{ title }}</h2> 25 | <table> 26 | <tr> 27 | <th>ID</th> 28 | <th>Question</th> 29 | <th>Answer</th> 30 | <th>Paragraph</th> 31 | </tr> 32 | {% for row in rows %} 33 | <tr> 34 | <td>{{ row.id }}</td> 35 | <td> 36 | {% for qj in row.ques %} 37 | {{ qj }} 38 | {% endfor %} 39 | </td> 40 | <td>{{ row.a }}</td> 41 | <td> 42 | <table> 43 | {% for xj, yj, y2j, ypj, yp2j in zip(row.para, row.y, row.y2, row.yp, row.yp2) %} 44 | <tr> 45 | {% for xjk, yjk, y2jk, ypjk in zip(xj, yj, y2j, ypj) %} 46 | <td> 47 | {% if yjk or y2jk %} 48 | <b>{{ xjk }}</b> 49 | {% else %} 50 | {{ xjk }} 51 | {% endif %} 52 | </td> 53 | {% endfor %} 54 | </tr> 55 | <tr> 56 | {% for xjk, yp2jk in zip(xj, yp2j) %} 57 | <td>-</td> 58 | {% endfor %} 59 | </tr> 60 | {% endfor %} 61 | </table> 62 | </td> 63 | </tr> 64 | {% endfor %} 65 | </table>
    66 | 67 | -------------------------------------------------------------------------------- /tree/test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import nltk\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "%matplotlib inline" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 10, 19 | "metadata": { 20 | "collapsed": false 21 | }, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n", 28 | "(PRP I)\n", 29 | "(VP (VBP am) (NNP Sam))\n", 30 | "(VBP am)\n", 31 | "(NNP Sam)\n", 32 | "(. .)\n", 33 | "(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))\"\n", 39 | "tree = nltk.tree.Tree.fromstring(string)\n", 40 | "\n", 41 | "def load_compressed_tree(s):\n", 42 | "\n", 43 | " def compress_tree(tree):\n", 44 | " if len(tree) == 1:\n", 45 | " if isinstance(tree[0], nltk.tree.Tree):\n", 46 | " return compress_tree(tree[0])\n", 47 | " else:\n", 48 | " return tree\n", 49 | " else:\n", 50 | " for i, t in enumerate(tree):\n", 51 | " tree[i] = compress_tree(t)\n", 52 | " return tree\n", 53 | "\n", 54 | " return compress_tree(nltk.tree.Tree.fromstring(s))\n", 55 | "tree = load_compressed_tree(string)\n", 56 | "for t in tree.subtrees():\n", 57 | " print(t)\n", 58 | " \n", 59 | "print(str(tree))" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": { 66 | "collapsed": false 67 | }, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "(ROOT I am Sam .)\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "print(tree.flatten())" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 10, 84 | "metadata": { 85 | "collapsed": false 86 | }, 87 | "outputs": [ 88 | { 89 | "name": "stdout", 90 | "output_type": "stream", 91 | "text": [ 92 | "['ROOT', 'S', 'NP', 'PRP', 'VP', 'VBP', 'NP', 'NNP', '.']\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "print(list(t.label() for t in tree.subtrees()))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 11, 103 | "metadata": { 104 | "collapsed": true 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "import json\n", 109 | "d = json.load(open(\"data/squad/shared_dev.json\", 'r'))" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 12, 115 | "metadata": { 116 | "collapsed": false 117 | }, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "73" 123 | ] 124 | }, 125 | "execution_count": 12, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "len(d['pos_counter'])" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 13, 137 | "metadata": { 138 | "collapsed": false 139 | }, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/plain": [ 144 | "{'#': 6,\n", 145 | " '$': 80,\n", 146 | " \"''\": 1291,\n", 147 | " ',': 14136,\n", 148 | " '-LRB-': 1926,\n", 149 | " '-RRB-': 1925,\n", 150 | " '.': 9505,\n", 151 | " ':': 1455,\n", 152 | " 'ADJP': 3426,\n", 153 | " 'ADVP': 4936,\n", 154 | " 'CC': 9300,\n", 155 | " 'CD': 6216,\n", 156 | " 'CONJP': 191,\n", 157 | " 'DT': 26286,\n", 158 | " 'EX': 288,\n", 159 | " 'FRAG': 
107,\n", 160 | " 'FW': 96,\n", 161 | " 'IN': 32564,\n", 162 | " 'INTJ': 12,\n", 163 | " 'JJ': 21452,\n", 164 | " 'JJR': 563,\n", 165 | " 'JJS': 569,\n", 166 | " 'LS': 7,\n", 167 | " 'LST': 1,\n", 168 | " 'MD': 1051,\n", 169 | " 'NAC': 19,\n", 170 | " 'NN': 34750,\n", 171 | " 'NNP': 28392,\n", 172 | " 'NNPS': 1400,\n", 173 | " 'NNS': 16716,\n", 174 | " 'NP': 91636,\n", 175 | " 'NP-TMP': 236,\n", 176 | " 'NX': 108,\n", 177 | " 'PDT': 89,\n", 178 | " 'POS': 1451,\n", 179 | " 'PP': 33278,\n", 180 | " 'PRN': 2085,\n", 181 | " 'PRP': 2320,\n", 182 | " 'PRP$': 1959,\n", 183 | " 'PRT': 450,\n", 184 | " 'QP': 838,\n", 185 | " 'RB': 7611,\n", 186 | " 'RBR': 301,\n", 187 | " 'RBS': 252,\n", 188 | " 'ROOT': 9587,\n", 189 | " 'RP': 454,\n", 190 | " 'RRC': 19,\n", 191 | " 'S': 21557,\n", 192 | " 'SBAR': 5009,\n", 193 | " 'SBARQ': 6,\n", 194 | " 'SINV': 135,\n", 195 | " 'SQ': 5,\n", 196 | " 'SYM': 17,\n", 197 | " 'TO': 5167,\n", 198 | " 'UCP': 143,\n", 199 | " 'UH': 15,\n", 200 | " 'VB': 4197,\n", 201 | " 'VBD': 8377,\n", 202 | " 'VBG': 3570,\n", 203 | " 'VBN': 7218,\n", 204 | " 'VBP': 2897,\n", 205 | " 'VBZ': 4146,\n", 206 | " 'VP': 33696,\n", 207 | " 'WDT': 1368,\n", 208 | " 'WHADJP': 5,\n", 209 | " 'WHADVP': 439,\n", 210 | " 'WHNP': 1927,\n", 211 | " 'WHPP': 153,\n", 212 | " 'WP': 482,\n", 213 | " 'WP$': 50,\n", 214 | " 'WRB': 442,\n", 215 | " 'X': 23,\n", 216 | " '``': 1269}" 217 | ] 218 | }, 219 | "execution_count": 13, 220 | "metadata": {}, 221 | "output_type": "execute_result" 222 | } 223 | ], 224 | "source": [ 225 | "d['pos_counter']" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 3, 231 | "metadata": { 232 | "collapsed": false 233 | }, 234 | "outputs": [ 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "[[False False False False]\n", 240 | " [False True False False]\n", 241 | " [False False False False]]\n", 242 | "[[0 2 2 0]\n", 243 | " [2 2 0 2]\n", 244 | " [2 0 0 0]]\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "from my.nltk_utils import tree2matrix, load_compressed_tree, find_max_f1_subtree, set_span\n", 250 | "string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. 
.)))\"\n", 251 | "tree = load_compressed_tree(string)\n", 252 | "span = (1, 3)\n", 253 | "set_span(tree)\n", 254 | "subtree = find_max_f1_subtree(tree, span)\n", 255 | "f = lambda t: t == subtree\n", 256 | "g = lambda t: 1 if isinstance(t, str) else 2\n", 257 | "a, b = tree2matrix(tree, f, dtype='bool')\n", 258 | "c, d = tree2matrix(tree, g, dtype='int32')\n", 259 | "print(a)\n", 260 | "print(c)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": { 267 | "collapsed": true 268 | }, 269 | "outputs": [], 270 | "source": [] 271 | } 272 | ], 273 | "metadata": { 274 | "kernelspec": { 275 | "display_name": "Python 3", 276 | "language": "python", 277 | "name": "python3" 278 | }, 279 | "language_info": { 280 | "codemirror_mode": { 281 | "name": "ipython", 282 | "version": 3 283 | }, 284 | "file_extension": ".py", 285 | "mimetype": "text/x-python", 286 | "name": "python", 287 | "nbconvert_exporter": "python", 288 | "pygments_lexer": "ipython3", 289 | "version": "3.5.1" 290 | } 291 | }, 292 | "nbformat": 4, 293 | "nbformat_minor": 0 294 | } 295 | -------------------------------------------------------------------------------- /tree/trainer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tree.model import Model 4 | 5 | 6 | class Trainer(object): 7 | def __init__(self, config, model): 8 | assert isinstance(model, Model) 9 | self.config = config 10 | self.model = model 11 | self.opt = tf.train.AdagradOptimizer(config.init_lr) 12 | self.loss = model.get_loss() 13 | self.var_list = model.get_var_list() 14 | self.global_step = model.get_global_step() 15 | self.ema_op = model.ema_op 16 | self.summary = model.summary 17 | self.grads = self.opt.compute_gradients(self.loss, var_list=self.var_list) 18 | opt_op = self.opt.apply_gradients(self.grads, global_step=self.global_step) 19 | 20 | # Define train op 21 | with tf.control_dependencies([opt_op]): 22 | self.train_op = tf.group(self.ema_op) 23 | 24 | def get_train_op(self): 25 | return self.train_op 26 | 27 | def step(self, sess, batch, get_summary=False): 28 | assert isinstance(sess, tf.Session) 29 | feed_dict = self.model.get_feed_dict(batch, True) 30 | if get_summary: 31 | loss, summary, train_op = \ 32 | sess.run([self.loss, self.summary, self.train_op], feed_dict=feed_dict) 33 | else: 34 | loss, train_op = sess.run([self.loss, self.train_op], feed_dict=feed_dict) 35 | summary = None 36 | return loss, summary, train_op 37 | -------------------------------------------------------------------------------- /tree/visualizer.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from collections import OrderedDict 3 | import http.server 4 | import socketserver 5 | import argparse 6 | import json 7 | import os 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from jinja2 import Environment, FileSystemLoader 12 | 13 | 14 | def bool_(string): 15 | if string == 'True': 16 | return True 17 | elif string == 'False': 18 | return False 19 | else: 20 | raise Exception() 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--model_name", type=str, default='basic') 25 | parser.add_argument("--data_type", type=str, default='dev') 26 | parser.add_argument("--step", type=int, default=5000) 27 | parser.add_argument("--template_name", type=str, default="visualizer.html") 28 | parser.add_argument("--num_per_page", type=int, default=100) 29 | 
parser.add_argument("--data_dir", type=str, default="data/squad") 30 | parser.add_argument("--port", type=int, default=8000) 31 | parser.add_argument("--host", type=str, default="0.0.0.0") 32 | parser.add_argument("--open", type=str, default='False') 33 | parser.add_argument("--run_id", type=str, default="0") 34 | 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def _decode(decoder, sent): 40 | return " ".join(decoder[idx] for idx in sent) 41 | 42 | 43 | def accuracy2_visualizer(args): 44 | model_name = args.model_name 45 | data_type = args.data_type 46 | num_per_page = args.num_per_page 47 | data_dir = args.data_dir 48 | run_id = args.run_id.zfill(2) 49 | step = args.step 50 | 51 | eval_path =os.path.join("out", model_name, run_id, "eval", "{}-{}.json".format(data_type, str(step).zfill(6))) 52 | eval_ = json.load(open(eval_path, 'r')) 53 | 54 | _id = 0 55 | html_dir = "/tmp/list_results%d" % _id 56 | while os.path.exists(html_dir): 57 | _id += 1 58 | html_dir = "/tmp/list_results%d" % _id 59 | 60 | if os.path.exists(html_dir): 61 | shutil.rmtree(html_dir) 62 | os.mkdir(html_dir) 63 | 64 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 65 | templates_dir = os.path.join(cur_dir, 'templates') 66 | env = Environment(loader=FileSystemLoader(templates_dir)) 67 | env.globals.update(zip=zip, reversed=reversed) 68 | template = env.get_template(args.template_name) 69 | 70 | data_path = os.path.join(data_dir, "data_{}.json".format(data_type)) 71 | shared_path = os.path.join(data_dir, "shared_{}.json".format(data_type)) 72 | data = json.load(open(data_path, 'r')) 73 | shared = json.load(open(shared_path, 'r')) 74 | 75 | rows = [] 76 | for i, (idx, yi, ypi) in enumerate(zip(*[eval_[key] for key in ('idxs', 'y', 'yp')])): 77 | id_, q, rx = (data[key][idx] for key in ('ids', 'q', '*x')) 78 | x = shared['x'][rx[0]][rx[1]] 79 | ques = [" ".join(q)] 80 | para = [[word for word in sent] for sent in x] 81 | row = { 82 | 'id': id_, 83 | 'title': "Hello world!", 84 | 'ques': ques, 85 | 'para': para, 86 | 'y': yi, 87 | 'y2': yi, 88 | 'yp': ypi, 89 | 'yp2': ypi, 90 | 'a': "" 91 | } 92 | rows.append(row) 93 | 94 | if i % num_per_page == 0: 95 | html_path = os.path.join(html_dir, "%s.html" % str(i).zfill(8)) 96 | 97 | if (i + 1) % num_per_page == 0 or (i + 1) == len(eval_['y']): 98 | var_dict = {'title': "Accuracy Visualization", 99 | 'rows': rows 100 | } 101 | with open(html_path, "wb") as f: 102 | f.write(template.render(**var_dict).encode('UTF-8')) 103 | rows = [] 104 | 105 | os.chdir(html_dir) 106 | port = args.port 107 | host = args.host 108 | # Overriding to suppress log message 109 | class MyHandler(http.server.SimpleHTTPRequestHandler): 110 | def log_message(self, format, *args): 111 | pass 112 | handler = MyHandler 113 | httpd = socketserver.TCPServer((host, port), handler) 114 | if args.open == 'True': 115 | os.system("open http://%s:%d" % (args.host, args.port)) 116 | print("serving at %s:%d" % (host, port)) 117 | httpd.serve_forever() 118 | 119 | 120 | if __name__ == "__main__": 121 | ARGS = get_args() 122 | accuracy2_visualizer(ARGS) --------------------------------------------------------------------------------