├── __init__.py ├── nectar ├── nectar │ ├── base │ │ ├── __init__.py │ │ ├── numbers.py │ │ ├── vecops.py │ │ ├── codalab.py │ │ ├── util.py │ │ ├── sequences.py │ │ ├── vocabulary.py │ │ ├── trie.py │ │ ├── intervals.py │ │ └── graph.py │ ├── fig │ │ ├── __init__.py │ │ └── lisptree.py │ ├── __init__.py │ ├── corenlp │ │ ├── __init__.py │ │ ├── repl.py │ │ ├── server.py │ │ ├── util.py │ │ └── client.py │ └── theanoutil │ │ ├── __init__.py │ │ ├── rnn.py │ │ ├── treelstm.py │ │ ├── util.py │ │ ├── args.py │ │ └── model.py └── .gitignore ├── requirements.txt ├── download_data.sh ├── LICENSE ├── nearest_glove ├── utils.py └── get_nearest.py ├── README.md ├── bridge_entity_rules.py ├── answer_rules.py └── convert_sp_facts.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nectar/nectar/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nectar/nectar/base/numbers.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nectar/nectar/fig/__init__.py: -------------------------------------------------------------------------------- 1 | import lisptree 2 | -------------------------------------------------------------------------------- /nectar/nectar/__init__.py: -------------------------------------------------------------------------------- 1 | from base.util import * # Put util at top level 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk 2 | tqdm 3 | pattern 4 | bottle 5 | numpy 6 | scipy 7 | scikit-learn 8 | termcolor -------------------------------------------------------------------------------- /nectar/nectar/corenlp/__init__.py: -------------------------------------------------------------------------------- 1 | from client import CoreNLPClient 2 | from server import CoreNLPServer 3 | from util import * 4 | -------------------------------------------------------------------------------- /nectar/nectar/theanoutil/__init__.py: -------------------------------------------------------------------------------- 1 | from args import * 2 | from model import * 3 | import rnn 4 | import treelstm 5 | from util import * 6 | -------------------------------------------------------------------------------- /nectar/.gitignore: -------------------------------------------------------------------------------- 1 | # Downloaded dependencies 2 | lib/ 3 | 4 | #Compiled files 5 | *.jar 6 | *.o 7 | *.pyc 8 | 9 | # LaTeX 10 | *.aux 11 | *.log 12 | *.out 13 | *.pdf 14 | 15 | # Swap files 16 | *.sw* 17 | 18 | # Temporary files 19 | tmp* 20 | *.tmp 21 | -------------------------------------------------------------------------------- /nectar/nectar/corenlp/repl.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from client import CoreNLPClient 4 | from server import CoreNLPServer 5 | 6 | def main(): 7 | """Start a REPL to interact with the server.""" 8 | # TODO: actuall make it a REPL 9 | c = CoreNLPClient() 10 | with CoreNLPServer() as s: 11 | print(json.dumps(c.query_depparse(['Bills on ports and immigration were submitted by Senator Brownback 
, Republican of Kansas']), indent=2)) 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /nectar/nectar/base/vecops.py: -------------------------------------------------------------------------------- 1 | """Operations on sparse vectors represented as dicts.""" 2 | import collections 3 | 4 | # Mutating a vector 5 | def add(v, other, scale=1): 6 | for k in other: 7 | if k in v: 8 | v[k] += scale * other[k] 9 | else: 10 | v[k] = scale * other[k] 11 | 12 | def scale(v, scale): 13 | for k in v.keys(): 14 | v[k] *= scale 15 | 16 | # Returning a new vector or scalar 17 | def dot(v1, v2): 18 | if len(v1) > len(v2): 19 | return dot(v2, v1) 20 | ans = 0 21 | for k in v1: 22 | if k in v2: 23 | ans += v1[k] * v2[k] 24 | return ans 25 | 26 | def sum(v1, v2): 27 | ans = collections.defaultdict(float) 28 | for k in v1: 29 | ans[k] += v1[k] 30 | for k in v2: 31 | ans[k] += v2[k] 32 | return ans 33 | 34 | def l2norm(v): 35 | return math.sqrt(sum(v[k]**2 for k in v)) 36 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -eu -o pipefail 3 | 4 | mkdir -p data 5 | cd data 6 | 7 | ### GloVe vectors ### 8 | if [ ! -d glove ] 9 | then 10 | mkdir glove 11 | cd glove 12 | wget http://nlp.stanford.edu/data/glove.6B.zip 13 | unzip glove.6B.zip 14 | cd .. 15 | fi 16 | 17 | ### HotpotQA ### 18 | if [ ! -d hotpotqa ] 19 | then 20 | mkdir hotpotqa 21 | cd hotpotqa 22 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json 23 | wget http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_distractor_v1.json 24 | cd .. 25 | fi 26 | cd .. 27 | 28 | cd nectar 29 | mkdir -p lib 30 | cd lib 31 | 32 | # CoreNLP 3.6.0 33 | corenlp='stanford-corenlp-full-2015-12-09' 34 | if [ ! -d "${corenlp}" ] 35 | then 36 | wget "http://nlp.stanford.edu/software/${corenlp}.zip" 37 | unzip "${corenlp}.zip" 38 | ln -s "${corenlp}" stanford-corenlp 39 | fi 40 | 41 | cd .. 42 | cd .. 43 | pwd 44 | mkdir resources 45 | mkdir out 46 | ln -s ../nectar/nectar resources 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yichen Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /nectar/nectar/base/codalab.py: -------------------------------------------------------------------------------- 1 | """Utilities for interacting with codalab.""" 2 | import collections 3 | import subprocess 4 | import sys 5 | 6 | from .. import log 7 | 8 | DOCKER_IMAGE = 'robinjia/robinjia-codalab:2.1.1' 9 | 10 | def run(cmd, deps, name, description, queue='john', host=None, cpus=1, 11 | docker_image=DOCKER_IMAGE, is_theano=False, omp_num_threads=1, 12 | dry_run=False): 13 | params = collections.OrderedDict() 14 | if host: 15 | params['--request-queue'] = 'host=%s' % host 16 | else: 17 | params['--request-queue'] = queue 18 | params['--request-cpus'] = str(cpus) 19 | params['--request-docker-image'] = docker_image 20 | params['-n'] = name 21 | params['-d'] = description 22 | param_list = [x for k_v in params.iteritems() for x in k_v] 23 | if is_theano: 24 | prefix = 'OMP_NUM_THREADS=%d THEANO_FLAGS=blas.ldflags=-lopenblas' % omp_num_threads 25 | cmd = prefix + ' ' + cmd 26 | call_args = ['cl', 'run'] + deps + [cmd] + param_list 27 | if dry_run: 28 | log('Dry run: %s' % str(call_args)) 29 | else: 30 | subprocess.call(call_args) 31 | 32 | def upload(filename, name=None, description=None, dry_run=False): 33 | call_args = ['cl', 'up', filename] 34 | if name: 35 | call_args.extend(['-n', name]) 36 | if description: 37 | call_args.extend(['-d', description]) 38 | if dry_run: 39 | log('Dry run: %s' % str(call_args)) 40 | else: 41 | subprocess.call(call_args) 42 | -------------------------------------------------------------------------------- /nearest_glove/utils.py: -------------------------------------------------------------------------------- 1 | import codecs, json 2 | import numpy as np 3 | 4 | '''Serializable/Pickleable class to replicate the functionality of collections.defaultdict''' 5 | class autovivify_list(dict): 6 | def __missing__(self, key): 7 | value = self[key] = [] 8 | return value 9 | 10 | def __add__(self, x): 11 | '''Override addition for numeric types when self is empty''' 12 | if not self and isinstance(x, Number): 13 | return x 14 | raise ValueError 15 | 16 | def __sub__(self, x): 17 | '''Also provide subtraction method''' 18 | if not self and isinstance(x, Number): 19 | return -1 * x 20 | raise ValueError 21 | 22 | def build_word_vector_matrix(vector_file, n_words): 23 | '''Read a GloVe array from sys.argv[1] and return its vectors and labels as arrays''' 24 | np_arrays = [] 25 | labels_array = [] 26 | 27 | with codecs.open(vector_file, 'r', 'utf-8') as f: 28 | for i, line in enumerate(f): 29 | sr = line.split() 30 | labels_array.append(sr[0]) 31 | np_arrays.append(np.array([float(j) for j in sr[1:]])) 32 | if i == n_words - 1: 33 | return np.array(np_arrays), labels_array 34 | 35 | def get_cache_filename_from_args(args): 36 | a = (args.vector_dim, args.num_words, args.num_clusters) 37 | return '{}D_{}-words_{}-clusters.json'.format(*a) 38 | 39 | def get_label_dictionaries(labels_array): 40 | id_to_word = dict(zip(range(len(labels_array)), labels_array)) 41 | word_to_id = dict((v,k) for k,v in id_to_word.items()) 42 | return word_to_id, id_to_word 43 | 44 | def save_json(filename, results): 45 | with open(filename, 'w') as f: 46 | json.dump(results, f) 47 | 48 | def load_json(filename): 49 | with open(filename, 'r') as f: 50 | return json.load(f) 51 | -------------------------------------------------------------------------------- /nectar/nectar/fig/lisptree.py: 
-------------------------------------------------------------------------------- 1 | """Utilities for handling fig LispTree objects.""" 2 | 3 | def tokenize(s): 4 | toks = [] 5 | cur_chars = [] 6 | inside_str = False 7 | inside_escape = False 8 | for c in s: 9 | if inside_str: 10 | if inside_escape: 11 | inside_escape = False 12 | cur_chars.append(c) 13 | else: 14 | if c == '"': 15 | inside_str = False 16 | toks.append(''.join(cur_chars)) 17 | cur_chars = [] 18 | elif c == '\\': 19 | inside_escape = True 20 | else: 21 | cur_chars.append(c) 22 | else: 23 | if inside_escape: 24 | inside_escape = False 25 | cur_chars.append(c) 26 | else: 27 | if c in ('(', ')'): 28 | if cur_chars: 29 | toks.append(''.join(cur_chars)) 30 | cur_chars = [] 31 | toks.append(c) 32 | elif c == ' ': 33 | if cur_chars: 34 | toks.append(''.join(cur_chars)) 35 | cur_chars = [] 36 | elif c == '"': 37 | if cur_chars: 38 | raise ValueError('" character found in middle of token') 39 | inside_str = True 40 | elif c == '\\': 41 | inside_escape = True 42 | else: 43 | cur_chars.append(c) 44 | if cur_chars: 45 | toks.append(''.join(cur_chars)) 46 | return toks 47 | 48 | def from_string(s): 49 | """Parse a Java fig LispTree from a string.""" 50 | toks = tokenize(s) 51 | def recurse(i): 52 | if toks[i] == '(': 53 | subtrees = [] 54 | j = i+1 55 | while True: 56 | subtree, j = recurse(j) 57 | subtrees.append(subtree) 58 | if toks[j] == ')': 59 | return tuple(subtrees), j + 1 60 | else: 61 | return toks[i], i+1 62 | lisp_tree, final_ind = recurse(0) 63 | return lisp_tree 64 | -------------------------------------------------------------------------------- /nectar/nectar/base/util.py: -------------------------------------------------------------------------------- 1 | """General, miscellaneous utilities.""" 2 | from contextlib import contextmanager 3 | import sys 4 | import time 5 | 6 | def flatten(x): 7 | """Flatten a list of lists.""" 8 | return [a for b in x for a in b] 9 | 10 | def log(msg, disappearing=False): 11 | if not sys.stdout.isatty(): 12 | # Only print to stdout if it's being redirected or piped 13 | print(msg) 14 | # if disappearing: 15 | # # Trailing comma suppresses newline 16 | # print >> sys.stderr, msg + '\r', 17 | # else: 18 | # print >> sys.stderr, msg 19 | 20 | def log_dict(d, name): 21 | log('%s {' % name) 22 | for k in d: 23 | log(' %s: %s' % (k, str(d[k]))) 24 | log('}') 25 | 26 | SECS_PER_MIN = 60 27 | SECS_PER_HOUR = SECS_PER_MIN * 60 28 | SECS_PER_DAY = SECS_PER_HOUR * 24 29 | 30 | def secs_to_str(secs): 31 | """Convert a number of seconds to human-readable string.""" 32 | days = int(secs) / SECS_PER_DAY 33 | secs -= days * SECS_PER_DAY 34 | hours = int(secs) / SECS_PER_HOUR 35 | secs -= hours * SECS_PER_HOUR 36 | mins = int(secs) / SECS_PER_MIN 37 | secs -= mins * SECS_PER_MIN 38 | if days > 0: 39 | return '%dd%02dh%02dm' % (days, hours, mins) 40 | elif hours > 0: 41 | return '%dh%02dm%02ds' % (hours, mins, int(secs)) 42 | elif mins > 0: 43 | return '%dm%02ds' % (mins, int(secs)) 44 | elif secs >= 1: 45 | return '%.1fs' % secs 46 | return '%.2fs' % secs 47 | 48 | def timed(func, msg, allow_overwrite=True): 49 | msg1 = '%s...' % msg 50 | log(msg1, disappearing=allow_overwrite) 51 | t0 = time.time() 52 | retval = func() 53 | t1 = time.time() 54 | msg2 = '%s [took %s].' % (msg, secs_to_str(t1 - t0)) 55 | log(msg2) 56 | return retval 57 | 58 | @contextmanager 59 | def timer(msg, allow_overwrite=True): 60 | msg1 = '%s...' 
% msg 61 | log(msg1, disappearing=allow_overwrite) 62 | t0 = time.time() 63 | yield 64 | t1 = time.time() 65 | msg2 = '%s [took %s].' % (msg, secs_to_str(t1 - t0)) 66 | log(msg2) 67 | -------------------------------------------------------------------------------- /nectar/nectar/base/sequences.py: -------------------------------------------------------------------------------- 1 | """Utilities related to sequences.""" 2 | def edit_distance(x1, x2, dist_func=None, gap_penalty=1): 3 | """Compute edit distance between two sequences. 4 | 5 | Args: 6 | x1: First sequence. 7 | x2: Second sequence. 8 | dist_func: Distance score between two tokens (default = Levenstein). 9 | gap_penalty: Penalty for gaps (default = 1, for Levenstein) 10 | """ 11 | if not dist_func: 12 | dist_func = lambda x, y: int(x != y) 13 | n1, n2 = len(x1), len(x2) 14 | scores = [[i * gap_penalty for i in range(n2+1)]] 15 | ptrs = [[None] + [(0, i-1) for i in range(1, n2+1)]] 16 | for i in range(1, n1 + 1): 17 | cur_scores = [scores[i-1][0] + gap_penalty] 18 | cur_ptrs = [(i-1, 0)] 19 | for j in range(1, n2 + 1): 20 | local_score = dist_func(x1[i-1], x2[j-1]) 21 | poss_scores = [scores[i-1][j-1] + local_score, 22 | scores[i-1][j] + gap_penalty, 23 | cur_scores[j-1] + gap_penalty] 24 | score_ind, cur_score = min(enumerate(poss_scores), key=lambda x: x[1]) 25 | cur_scores.append(cur_score) 26 | if score_ind == 0: 27 | cur_ptr = (i-1, j-1) 28 | elif score_ind == 1: 29 | cur_ptr = (i-1, j) 30 | else: 31 | cur_ptr = (i, j-1) 32 | cur_ptrs.append(cur_ptr) 33 | scores.append(cur_scores) 34 | ptrs.append(cur_ptrs) 35 | dist = scores[n1][n2] 36 | return dist, ptrs 37 | 38 | def get_unaligned_spans(x1, x2, ptrs): 39 | """Get unaligned spans from an edit distance pointer matrix.""" 40 | n1, n2 = len(x1), len(x2) 41 | i1, i2 = n1, n2 42 | spans = [] 43 | end1, end2 = None, None 44 | while i1 != 0 or i2 != 0: 45 | i1_new, i2_new = ptrs[i1][i2] 46 | if i1_new == i1 or i2_new == i2 or x1[i1_new] != x2[i2_new]: # mismatch/gap 47 | if end1 is None: 48 | end1, end2 = i1, i2 49 | else: 50 | if end1 is not None: 51 | spans.append(((i1, end1), (i2, end2))) 52 | end1, end2 = None, None 53 | i1, i2 = i1_new, i2_new 54 | if end1 is not None: 55 | spans.append(((0, end1), (0, end2))) 56 | return spans 57 | -------------------------------------------------------------------------------- /nectar/nectar/base/vocabulary.py: -------------------------------------------------------------------------------- 1 | """A basic vocabulary class.""" 2 | import collections 3 | 4 | UNK_TOKEN = '' 5 | UNK_INDEX = 0 6 | 7 | class Vocabulary(object): 8 | def __init__(self, unk_threshold=0): 9 | """Initialize the vocabulary. 10 | 11 | Args: 12 | unk_threshold: words with <= this many counts will be considered . 
13 | """ 14 | self.unk_threshold = unk_threshold 15 | self.counts = collections.Counter() 16 | self.word2index = {UNK_TOKEN: UNK_INDEX} 17 | self.word_list = [UNK_TOKEN] 18 | 19 | def add_word(self, word, count=1): 20 | """Add a word (may still map to UNK if it doesn't pass unk_threshold).""" 21 | self.counts[word] += count 22 | if word not in self.word2index and self.counts[word] > self.unk_threshold: 23 | index = len(self.word_list) 24 | self.word2index[word] = index 25 | self.word_list.append(word) 26 | 27 | def add_words(self, words): 28 | for w in words: 29 | self.add_word(w) 30 | 31 | def add_sentence(self, sentence): 32 | self.add_words(sentence.split(' ')) 33 | 34 | def add_sentences(self, sentences): 35 | for s in sentences: 36 | self.add_sentence(s) 37 | 38 | def add_word_hard(self, word): 39 | """Add word, make sure it is not UNK.""" 40 | self.add_word(word, count=(self.unk_threshold+1)) 41 | 42 | def get_word(self, index): 43 | return self.word_list[index] 44 | 45 | def get_index(self, word): 46 | if word in self.word2index: 47 | return self.word2index[word] 48 | return UNK_INDEX 49 | 50 | def indexify_sentence(self, sentence): 51 | return [self.get_index(w) for w in sentence.split(' ')] 52 | 53 | def indexify_list(self, elems): 54 | return [self.get_index(w) for w in elems] 55 | 56 | def recover_sentence(self, indices): 57 | return ' '.join(self.get_word(i) for i in indices) 58 | 59 | def has_word(self, word): 60 | return word in self.word2index 61 | 62 | def __contains__(self, word): 63 | return self.has_word(word) 64 | 65 | def size(self): 66 | # Report number of words that have been assigned an index 67 | return len(self.word2index) 68 | 69 | def __len__(self): 70 | return self.size() 71 | 72 | def __iter__(self): 73 | return iter(self.word_list) 74 | -------------------------------------------------------------------------------- /nectar/nectar/theanoutil/rnn.py: -------------------------------------------------------------------------------- 1 | """Common RNN-related functions.""" 2 | import theano 3 | from theano import tensor as T 4 | from util import printed 5 | 6 | def lstm_split(c_h): 7 | """Split the joint c_t and h_t of the LSTM state.""" 8 | d = c_h.shape[0]/2 9 | c = c_h[:d] 10 | h = c_h[d:] 11 | return c, h 12 | 13 | def lstm_step(c_h_prev, input_t, W_mat): 14 | """The LSTM recurrence. 15 | 16 | Args: 17 | c_h_prev: A vector of size 2d, concatenation of 18 | memory cell c_prev and hidden state h_prev 19 | input_t: Current input, as a vector of size e 20 | W_mat: transition matrix of size (d+e) x (4d) 21 | """ 22 | c_prev, h_prev = lstm_split(c_h_prev) 23 | d = c_prev.shape[0] 24 | vec_t = T.concatenate([h_prev, input_t]) 25 | prod = T.dot(vec_t, W_mat) 26 | i_t = T.nnet.sigmoid(prod[:d]) 27 | f_t = T.nnet.sigmoid(prod[d:2*d]) 28 | o_t = T.nnet.sigmoid(prod[2*d:3*d]) 29 | c_tilde_t = T.tanh(prod[3*d:]) 30 | c_t = f_t * c_prev + i_t * c_tilde_t 31 | h_t = o_t * T.tanh(c_t) 32 | c_h_t = T.concatenate([c_t, h_t]) 33 | return c_h_t 34 | 35 | def batch_lstm_split(c_h): 36 | """Split the joint c_t and h_t of the LSTM state (batch mode).""" 37 | d = c_h.shape[1]/2 38 | c = c_h[:,:d] 39 | h = c_h[:,d:] 40 | return c, h 41 | 42 | def time_batch_lstm_split(c_h): 43 | """Split the joint c_t and h_t of the LSTM state (time + batch mode).""" 44 | d = c_h.shape[2]/2 45 | c = c_h[:,:,:d] 46 | h = c_h[:,:,d:] 47 | return c, h 48 | 49 | def batch_lstm_step(c_h_prev, input_t, W_mat): 50 | """The LSTM recurrence (batch mode). 
51 | 52 | Args: 53 | c_h_prev: matrix of size batch_sz x 2d, concatenation of 54 | memory cell c_prev and hidden state h_prev 55 | input_t: Current input, as matrix of size batch_sz x e 56 | W_mat: transition matrix of size (d+e) x (4d) 57 | """ 58 | c_prev, h_prev = batch_lstm_split(c_h_prev) # batch_sz x d each 59 | d = c_prev.shape[1] 60 | vec_t = T.concatenate([h_prev, input_t], axis=1) # batch_sz x (d+e) 61 | prod = T.dot(vec_t, W_mat) # batch_sz x 4d 62 | i_t = T.nnet.sigmoid(prod[:,:d]) 63 | f_t = T.nnet.sigmoid(prod[:,d:2*d]) 64 | o_t = T.nnet.sigmoid(prod[:,2*d:3*d]) 65 | c_tilde_t = T.tanh(prod[:,3*d:]) 66 | c_t = f_t * c_prev + i_t * c_tilde_t 67 | h_t = o_t * T.tanh(c_t) 68 | c_h_t = T.concatenate([c_t, h_t], axis=1) # batch_sz x 2d 69 | return c_h_t 70 | -------------------------------------------------------------------------------- /nectar/nectar/corenlp/server.py: -------------------------------------------------------------------------------- 1 | """Run a CoreNLP Server.""" 2 | import atexit 3 | import errno 4 | import os 5 | import socket 6 | import subprocess 7 | import sys 8 | import time 9 | 10 | LIB_PATH = os.path.join( 11 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), 12 | 'lib/stanford-corenlp/*') 13 | DEVNULL = open(os.devnull, 'wb') 14 | 15 | class CoreNLPServer(object): 16 | """An object that runs the CoreNLP server.""" 17 | def __init__(self, port=7000, lib_path=LIB_PATH, flags=None, logfile=None): 18 | """Create the CoreNLPServer object. 19 | 20 | Args: 21 | port: Port on which to serve requests. 22 | flags: If provided, pass this list of additional flags to the java server. 23 | logfile: If provided, log stderr to this file. 24 | lib_path: The path to the CoreNLP *.jar files. 25 | """ 26 | self.port = port 27 | self.lib_path = LIB_PATH 28 | self.process = None 29 | self.p_stderr = None 30 | if flags: 31 | self.flags = flags 32 | else: 33 | self.flags = [] 34 | if logfile: 35 | self.logfd = open(logfile, 'wb') 36 | else: 37 | self.logfd = DEVNULL 38 | 39 | def start(self, flags=None): 40 | """Start up the server on a separate process.""" 41 | print('Using lib directory %s' % self.lib_path) 42 | if not flags: 43 | flags = self.flags 44 | p = subprocess.Popen( 45 | ['java', '-mx4g', '-cp', self.lib_path, 46 | 'edu.stanford.nlp.pipeline.StanfordCoreNLPServer', 47 | '--port', str(self.port)] + flags, 48 | stderr=self.logfd, stdout=self.logfd) 49 | self.process = p 50 | atexit.register(self.stop) 51 | 52 | # Keep trying to connect until the server is up 53 | s = socket.socket() 54 | while True: 55 | time.sleep(1) 56 | try: 57 | s.connect(('127.0.0.1', self.port)) 58 | except socket.error as e: 59 | if e.errno != errno.ECONNREFUSED: 60 | # Something other than Connection refused means server is running 61 | break 62 | s.close() 63 | 64 | 65 | def stop(self): 66 | """Stop running the server on a separate process.""" 67 | if self.process: 68 | self.process.terminate() 69 | if self.logfd != DEVNULL: 70 | self.logfd.close() 71 | 72 | def __enter__(self): 73 | self.start() 74 | return self 75 | 76 | def __exit__(self, type, value, traceback): 77 | self.stop() 78 | -------------------------------------------------------------------------------- /nectar/nectar/theanoutil/treelstm.py: -------------------------------------------------------------------------------- 1 | """Tree-LSTM encoder. 2 | 3 | Based on Tai et al., 2015, 4 | "Improved Semantic Representations From 5 | Tree-Structured Long Short-Term Memory Networks." 
6 | 7 | Actually works on any DAG. 8 | 9 | By convention, the root of the tree has only outgoing edges. 10 | In DAG terminology, this means we start at sink nodes 11 | and end at source nodes. 12 | """ 13 | import theano 14 | from theano import tensor as T 15 | from theano.ifelse import ifelse 16 | 17 | import __init__ as ntu 18 | from .. import log 19 | 20 | def encode_child_sum(x_vecs, topo_order, adj_mat, c0, h0, W, U, Uf): 21 | """Run a child-sum tree-LSTM on a DAG. 22 | 23 | Args: 24 | x_vecs: n x e vector of node embeddings 25 | topo_order: a permutation of range(n) that gives a topological sort 26 | i.e. topo_order[i]'s children are in topo_order[:i] 27 | adj_mat: matrix where adj_mat[i,j] == 1 iff there is an i -> j edge. 28 | c0, h0, W, U, Uf: parameters of sizes 1 x d, 1 x d, e x 4d, d x 3d, d x d, respectively. 29 | """ 30 | def recurrence(j, c_mat, h_mat, n, d, *args): 31 | x_j = x_vecs[j] 32 | children = T.eq(adj_mat[j,], 1).nonzero() # let c(j) be number of children of node j 33 | c_children = c_mat[children] # c(j) x d 34 | h_children = h_mat[children] # c(j) x d 35 | # If this node has no children, use c0 and h0; else use c_mat and h_mat 36 | c_prev = ifelse(T.eq(c_children.shape[0], 0), 37 | c0, c_children) # max(c(j), 1) x d 38 | h_prev = ifelse(T.eq(h_children.shape[0], 0), 39 | h0, h_children) # max(c(j), 1) x d 40 | 41 | h_tilde = T.sum(h_prev, axis=0) # d 42 | w_prod = T.dot(x_j, W) # 4d 43 | u_prod = T.dot(h_tilde, U) # 3d 44 | uf_prod = T.dot(h_prev, Uf) # max(c(j), 1) x d 45 | i_j = T.nnet.sigmoid(w_prod[:d] + u_prod[:d]) # d 46 | f_jk = T.nnet.sigmoid(w_prod[d:2*d] + uf_prod) # c(j) x d 47 | o_j = T.nnet.sigmoid(w_prod[2*d:3*d] + u_prod[d:2*d]) # d 48 | u_j = T.tanh(w_prod[3*d:] + u_prod[2*d:]) # d 49 | fc_j = T.sum(f_jk * c_prev, axis=0) # d 50 | c_j = i_j * u_j + fc_j 51 | h_j = o_j * T.tanh(c_j) 52 | 53 | # Update c_mat and h_mat 54 | new_c_mat = T.set_subtensor(c_mat[j], c_j) 55 | new_h_mat = T.set_subtensor(h_mat[j], h_j) 56 | return new_c_mat, new_h_mat 57 | 58 | n = x_vecs.shape[0] 59 | d = U.shape[0] 60 | (c_list, h_list), _ = theano.scan( 61 | recurrence, sequences=[topo_order], 62 | outputs_info=[T.zeros((n, d)), T.zeros((n, d))], 63 | non_sequences=[n, d, x_vecs, adj_mat, c0, h0, W, U, Uf]) 64 | return c_list[-1], h_list[-1] 65 | -------------------------------------------------------------------------------- /nectar/nectar/theanoutil/util.py: -------------------------------------------------------------------------------- 1 | """Some common theano utilities.""" 2 | import numpy as np 3 | import theano 4 | from theano import tensor as T 5 | from theano.ifelse import ifelse 6 | 7 | def printed(var, name=''): 8 | return theano.printing.Print(name)(var) 9 | 10 | def logsumexp(mat, axis=0): 11 | """Apply a row-wise log-sum-exp, summing along axis.""" 12 | maxes = T.max(mat, axis=axis) 13 | return T.log(T.sum(T.exp(mat - maxes), axis=axis)) + maxes 14 | 15 | def clip_gradients(grads, clip_thresh): 16 | """Clip gradients to some total norm.""" 17 | total_norm = T.sqrt(sum(T.sum(g**2) for g in grads)) 18 | scale = ifelse(T.gt(total_norm, clip_thresh), 19 | clip_thresh / total_norm, 20 | np.dtype(theano.config.floatX).type(1.0)) 21 | clipped_grads = [scale * g for g in grads] 22 | return clipped_grads 23 | 24 | def create_grad_cache(param_list, name='grad_cache'): 25 | """Create a grad cache, for things like momentum or AdaGrad.""" 26 | cache = [theano.shared(name='%s_%s' % (p.name, name), 27 | value=np.zeros_like(p.get_value())) 28 | for p in param_list] 29 | 
return cache 30 | 31 | def get_vanilla_sgd_updates(param_list, gradients, lr): 32 | """Do SGD updates with vanilla step rule.""" 33 | updates = [] 34 | for p, g in zip(param_list, gradients): 35 | new_p = p - lr * g 36 | has_non_finite = T.any(T.isnan(new_p) + T.isinf(new_p)) 37 | updates.append((p, ifelse(has_non_finite, p, new_p))) 38 | return updates 39 | 40 | def get_nesterov_sgd_updates(param_list, gradients, velocities, lr, mu): 41 | """Do SGD updates with Nesterov momentum.""" 42 | updates = [] 43 | for p, g, v in zip(param_list, gradients, velocities): 44 | new_v = mu * v - lr * g 45 | new_p = p - mu * v + (1 + mu) * new_v 46 | has_non_finite = (T.any(T.isnan(new_p) + T.isinf(new_p)) + 47 | T.any(T.isnan(new_v) + T.isinf(new_v))) 48 | updates.append((p, ifelse(has_non_finite, p, new_p))) 49 | updates.append((v, ifelse(has_non_finite, v, new_v))) 50 | return updates 51 | 52 | def plot_learning_curve(data, outfile=None): 53 | if outfile: 54 | import matplotlib 55 | matplotlib.use('Agg') 56 | import matplotlib.pyplot as plt 57 | else: 58 | import matplotlib.pyplot as plt 59 | 60 | plt.figure(figsize=(12, 5)) 61 | for i, (name, cur_data) in enumerate(data): 62 | plt.subplot(1, len(data), i+1) 63 | plt.plot(cur_data) 64 | plt.xlabel('Epochs') 65 | plt.ylabel(name) 66 | plt.ylim(0, 1.1 * max(cur_data)) 67 | 68 | if outfile: 69 | plt.savefig(outfile) 70 | else: 71 | plt.show() 72 | -------------------------------------------------------------------------------- /nearest_glove/get_nearest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/brannondorsey/GloVe-experiments 3 | """ 4 | 5 | import argparse, sys, readline 6 | from scipy.spatial.distance import cosine 7 | from nearest_glove.utils import build_word_vector_matrix, get_label_dictionaries 8 | from difflib import SequenceMatcher 9 | from nltk.corpus import wordnet 10 | 11 | def word_arithmetic(start_word, minus_words, plus_words, word_to_id, id_to_word, df, num_results=5): 12 | '''Returns a word string that is the result of the vector arithmetic''' 13 | try: 14 | start_vec = df[word_to_id[start_word]] 15 | minus_vecs = [df[word_to_id[minus_word]] for minus_word in minus_words] 16 | plus_vecs = [df[word_to_id[plus_word]] for plus_word in plus_words] 17 | except KeyError as err: 18 | return err, None 19 | 20 | result = start_vec 21 | 22 | if minus_vecs: 23 | for i, vec in enumerate(minus_vecs): 24 | result = result - vec 25 | 26 | if plus_vecs: 27 | for i, vec in enumerate(plus_vecs): 28 | result = result + vec 29 | 30 | # result = start_vec - minus_vec + plus_vec 31 | words = [start_word] + minus_words + plus_words 32 | return None, find_nearest(words, result, id_to_word, df, num_results) 33 | 34 | def find_nearest(words, vec, id_to_word, df, num_results, method='cosine'): 35 | 36 | if method == 'cosine': 37 | minim = [] # min, index 38 | for i, v in enumerate(df): 39 | # skip the base word, its usually the closest 40 | if id_to_word[i] in words: 41 | continue 42 | dist = cosine(vec, v) 43 | minim.append((dist, i)) 44 | minim = sorted(minim, key=lambda v: v[0]) 45 | # return list of (word, cosine distance) tuples 46 | return [(id_to_word[minim[i][1]], minim[i][0]) for i in range(num_results)] 47 | else: 48 | raise Exception('{} is not an excepted method parameter'.format(method)) 49 | 50 | def longest_common_substring(string1, string2): 51 | match_obj = SequenceMatcher(None, string1, string2).find_longest_match(0, len(string1), 0, len(string2)) 52 | return 
match_obj.size 53 | 54 | def find_nearest_word(start_word, df, word_to_id, id_to_word, num_results=5): 55 | err, results = word_arithmetic (start_word=start_word, 56 | minus_words=[], 57 | plus_words=[], 58 | word_to_id=word_to_id, 59 | id_to_word=id_to_word, 60 | df=df, 61 | num_results=num_results) 62 | if results: 63 | res = [t[0] for t in results if t[0] not in start_word and start_word not in t[0] and longest_common_substring(t[0], start_word) < 4] 64 | if len(res) > 0: 65 | return res 66 | else: 67 | return None 68 | else: 69 | return None 70 | 71 | 72 | def build_glove_matrix(num_words=40000): 73 | vector_file = 'data/glove/glove.6B.100d.txt' 74 | df, labels_array = build_word_vector_matrix(vector_file, num_words) 75 | word_to_id, id_to_word = get_label_dictionaries(labels_array) 76 | return df, word_to_id, id_to_word 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial-MultiHopQA 2 | * Official code for our [ACL 2019 paper](https://arxiv.org/pdf/1906.07132.pdf). 3 | * The initial code was adapted from [Adversarial Squad](https://github.com/robinjia/adversarial-squad). 4 | * We adapt the Corenlp code (nectar) to support python 3. 5 | * The arithmetic on Glove word embeddings are adapted from (https://github.com/brannondorsey/GloVe-experiments). 6 | 7 | ### Dependencies 8 | Run `download_data.sh` to pull HotpotQA data and GloVe vectors. 9 | * We tested our code on TF1.3, TF1.8, TF1.11 and TF1.13. 10 | * See `requirements.txt`. 11 | 12 | ### 1. Preprocess the data using Corenlp 13 | Run: 14 | ``` 15 | python3 convert_sp_facts.py corenlp -d dev 16 | ``` 17 | to store preprocessed data in `data/hotpotqa/dev_corenlp_cache_***.json`. This avoids rerunning Corenlp every time we generate an adversarial data. 18 | If you want to create the adversarial training data, run: 19 | ``` 20 | python3 convert_sp_facts.py corenlp -d train 21 | ``` 22 | Warning: preprocessing both the training set and dev set requires a storage space of ~22G. 23 | 24 | 25 | ### 2. Collect the candidate answer and title set 26 | Run: 27 | ``` 28 | python3 convert_sp_facts.py gen-answer-set -d dev 29 | ``` 30 | and 31 | ``` 32 | python3 convert_sp_facts.py gen-title-set -d dev 33 | ``` 34 | This step collect all answers and Wikipedia article titles in the dev set and classify them based on their NER and POS tag. 35 | 36 | 37 | ### 3. (Optional) Collect all paragraphs appearining in the context 38 | If you want to eliminate the title-balancing bias in the adversarial documents (described in the last paragraph of Sec. 2.2), run: 39 | ``` 40 | python3 convert_sp_facts.py gen-all-docs -d dev 41 | ``` 42 | 43 | ### 4. Generate Adverarial Dev set 44 | To generate the adversarial dev set described in our paper, run: 45 | ``` 46 | python3 convert_sp_facts.py dump-addDoc -d dev -b --rule wordnet_dyn_gen --replace_partial_answer --num_new_doc=4 --dont_replace_full_answer --find_nearest_glove --add_doc_incl_adv_title 47 | ``` 48 | This will create the adversarial training set in `out/hotpot_dev_addDoc.json` 49 | Note: `--add_doc_incl_adv_title` can be set only if Step 3 is done. 50 | 51 | 52 | ### 5. Generate Adverarial Training set 53 | Generating the adversarial training set all at once could take days. 
Therefore, we divide the training set into 19 batches with the size of 5000, and process each batch in a separate program by running: 54 | ``` 55 | python3 convert_sp_facts.py dumpBatch-addDoc -d train -b --rule wordnet_dyn_gen --replace_partial_answer --num_new_doc=4 --dont_replace_full_answer --find_nearest_glove --add_doc_incl_adv_title --batch_idx=0 56 | ``` 57 | with `batch_idx` set to 0~18. After they finish, run: 58 | ``` 59 | python3 convert_sp_facts.py merge_files -d train 60 | ``` 61 | This will create the adversarial training set in `out/hotpot_train_addDoc.json` 62 | 63 | **In order to recreate the adversarial training data we used in the paper, randomly sample 40% of the adversarial training data generated using this code and combine with the original HotpotQA training set.** 64 | 65 | 66 | # Citation 67 | ``` 68 | @inproceedings{Jiang2019reasoningshortcut, 69 | author={Yichen Jiang and Mohit Bansal}, 70 | booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics}, 71 | title={Avoiding Reasoning Shortcuts: Adversarial Evaluation, Training, and Model Development for Multi-Hop QA}, 72 | year={2019}, 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /nectar/nectar/base/trie.py: -------------------------------------------------------------------------------- 1 | """A basic trie.""" 2 | import argparse 3 | import sys 4 | 5 | class Trie(object): 6 | def __init__(self): 7 | self.root = {} 8 | 9 | def add(self, seq): 10 | node = self.root 11 | for i, x in enumerate(seq): 12 | if x not in node: 13 | node[x] = (False, {}) 14 | if i == len(seq) - 1: 15 | node[x] = (True, node[x][1]) 16 | else: 17 | is_terminal, node = node[x] 18 | 19 | def remove(self, seq): 20 | node = self.root 21 | nodes = [] 22 | for i, x in enumerate(seq): 23 | nodes.append(node) 24 | if x not in node: 25 | raise ValueError('Item not found, cannot be removed') 26 | if i == len(seq) - 1: 27 | # Actually remove 28 | node[x] = (False, node[x][1]) 29 | else: 30 | is_terminal, node = node[x] 31 | # Clean up 32 | for i in range(len(seq) - 1, -1, -1): 33 | # nodes[i] contains seq[i] 34 | node = nodes[i] 35 | x = seq[i] 36 | is_terminal, next_node = node[x] 37 | if not is_terminal and not next_node: 38 | del node[x] 39 | else: 40 | break 41 | 42 | def contains(self, seq): 43 | node = self.root 44 | for x in seq: 45 | if x not in node: 46 | return False 47 | is_terminal, node = node[x] 48 | return is_terminal 49 | 50 | def contains_prefix(self, seq): 51 | node = self.root 52 | for x in seq: 53 | if x not in node: 54 | return False 55 | is_terminal, node = node[x] 56 | return True 57 | 58 | def get_node(self, seq): 59 | node = self.root 60 | for x in seq: 61 | if x not in node: 62 | return None 63 | is_terminal, node = node[x] 64 | return node 65 | 66 | def __iter__(self): 67 | stack = [((), self.root)] 68 | while stack: 69 | prefix, node = stack.pop() 70 | for k in node: 71 | new_prefix = prefix + (k,) 72 | is_terminal, new_node = node[k] 73 | if is_terminal: 74 | yield new_prefix 75 | stack.append((new_prefix, new_node)) 76 | 77 | def main(): 78 | trie = Trie() 79 | print 'Running basic tests...' 
80 | trie.add((0,)) 81 | trie.add((1, 2, 3)) 82 | assert trie.contains((0,)) == True 83 | assert trie.contains((1, 2, 3)) == True 84 | assert trie.contains((1,)) == False 85 | assert trie.contains_prefix((1,)) == True 86 | assert trie.contains((1, 2)) == False 87 | assert trie.contains_prefix((1, 2)) == True 88 | assert trie.contains((2,)) == False 89 | trie.add((1, 2)) 90 | trie.add((1, 4)) 91 | trie.add((5, 6)) 92 | assert trie.contains((1, 2, 3)) == True 93 | assert trie.contains((1, 2)) == True 94 | assert trie.contains_prefix((1, 2)) == True 95 | assert trie.contains((2,)) == False 96 | assert trie.contains_prefix((2,)) == False 97 | assert trie.contains((5,)) == False 98 | assert trie.contains((1, 4)) == True 99 | assert trie.contains((5, 6)) == True 100 | assert trie.contains_prefix((5,)) == True 101 | trie.remove((1, 2, 3)) 102 | assert trie.contains((1, 2, 3)) == False 103 | assert trie.contains((1, 2)) == True 104 | assert trie.contains_prefix((1, 2)) == True 105 | trie.add((1, 2, 3)) 106 | trie.remove((1, 2)) 107 | trie.add((1,)) 108 | assert trie.contains((1, 2, 3)) == True 109 | assert trie.contains((1, 2)) == False 110 | assert trie.contains((1,)) == True 111 | assert trie.contains_prefix((1, 2)) == True 112 | assert set(trie) == set([(0,), (1,), (1, 2, 3), (1, 4), (5, 6)]) 113 | print trie.root 114 | print 'All pass!' 115 | 116 | if __name__ == '__main__': 117 | main() 118 | -------------------------------------------------------------------------------- /nectar/nectar/corenlp/util.py: -------------------------------------------------------------------------------- 1 | """CoreNLP-related utilities.""" 2 | def rejoin(tokens, sep=None): 3 | """Rejoin tokens into the original sentence. 4 | 5 | Args: 6 | tokens: a list of dicts containing 'originalText' and 'before' fields. 7 | All other fields will be ignored. 8 | sep: if provided, use the given character as a separator instead of 9 | the 'before' field (e.g. if you want to preserve where tokens are). 10 | Returns: the original sentence that generated this CoreNLP token list. 11 | """ 12 | if sep is None: 13 | return ''.join('%s%s' % (t['before'], t['originalText']) for t in tokens) 14 | else: 15 | # Use the given separator instead 16 | return sep.join(t['originalText'] for t in tokens) 17 | 18 | class ConstituencyParse(object): 19 | """A CoreNLP constituency parse (or a node in a parse tree). 20 | 21 | Word-level constituents have |word| and |index| set and no children. 22 | Phrase-level constituents have no |word| or |index| and have at least one child. 
23 | """ 24 | def __init__(self, tag, children=None, word=None, index=None): 25 | self.tag = tag 26 | if children: 27 | self.children = children 28 | else: 29 | self.children = None 30 | self.word = word 31 | self.index = index 32 | 33 | @classmethod 34 | def _recursive_parse_corenlp(cls, tokens, i, j): 35 | orig_i = i 36 | if tokens[i] == '(': 37 | tag = tokens[i + 1] 38 | children = [] 39 | i = i + 2 40 | while True: 41 | child, i, j = cls._recursive_parse_corenlp(tokens, i, j) 42 | if isinstance(child, cls): 43 | children.append(child) 44 | if tokens[i] == ')': 45 | return cls(tag, children), i + 1, j 46 | else: 47 | if tokens[i] != ')': 48 | raise ValueError('Expected ")" following leaf') 49 | return cls(tag, word=child, index=j), i + 1, j + 1 50 | else: 51 | # Only other possibility is it's a word 52 | return tokens[i], i + 1, j 53 | 54 | @classmethod 55 | def from_corenlp(cls, s): 56 | """Parses the "parse" attribute returned by CoreNLP parse annotator.""" 57 | # "parse": "(ROOT\n (SBARQ\n (WHNP (WDT What)\n (NP (NN portion)\n (PP (IN of)\n (NP\n (NP (NNS households))\n (PP (IN in)\n (NP (NNP Jacksonville)))))))\n (SQ\n (VP (VBP have)\n (NP (RB only) (CD one) (NN person))))\n (. ? )))", 58 | s_spaced = s.replace('\n', ' ').replace('(', ' ( ').replace(')', ' ) ') 59 | tokens = [t for t in s_spaced.split(' ') if t] 60 | tree, index, num_words = cls._recursive_parse_corenlp(tokens, 0, 0) 61 | if index != len(tokens): 62 | raise ValueError('Only parsed %d of %d tokens' % (index, len(tokens))) 63 | return tree 64 | 65 | def is_singleton(self): 66 | if self.word: return True 67 | if len(self.children) > 1: return False 68 | return self.children[0].is_singleton() 69 | 70 | def print_tree(self, indent=0): 71 | spaces = ' ' * indent 72 | if self.word: 73 | print ('%s%s: %s (%d)' % (spaces, self.tag, self.word, self.index)).encode('utf-8') 74 | else: 75 | print ('%s%s:' % (spaces, self.tag)) 76 | for c in self.children: 77 | c.print_tree(indent=indent + 1) 78 | 79 | def get_phrase(self): 80 | if self.word: return self.word 81 | toks = [] 82 | for i, c in enumerate(self.children): 83 | p = c.get_phrase() 84 | if i == 0 or p.startswith("'"): 85 | toks.append(p) 86 | else: 87 | toks.append(' ' + p) 88 | return ''.join(toks) 89 | 90 | def get_start_index(self): 91 | if self.index is not None: return self.index 92 | return self.children[0].get_start_index() 93 | 94 | def get_end_index(self): 95 | if self.index is not None: return self.index + 1 96 | return self.children[-1].get_end_index() 97 | 98 | @classmethod 99 | def _recursive_replace_words(cls, tree, new_words, i): 100 | if tree.word: 101 | new_word = new_words[i] 102 | return (cls(tree.tag, word=new_word, index=tree.index), i + 1) 103 | new_children = [] 104 | for c in tree.children: 105 | new_child, i = cls._recursive_replace_words(c, new_words, i) 106 | new_children.append(new_child) 107 | return cls(tree.tag, children=new_children), i 108 | 109 | @classmethod 110 | def replace_words(cls, tree, new_words): 111 | """Return a new tree, with new words replacing old ones.""" 112 | new_tree, i = cls._recursive_replace_words(tree, new_words, 0) 113 | if i != len(new_words): 114 | raise ValueError('len(new_words) == %d != i == %d' % (len(new_words), i)) 115 | return new_tree 116 | -------------------------------------------------------------------------------- /nectar/nectar/corenlp/client.py: -------------------------------------------------------------------------------- 1 | """A client for a CoreNLP Server.""" 2 | import json 3 | import os 4 | 
import requests 5 | 6 | from server import CoreNLPServer 7 | 8 | class CoreNLPClient(object): 9 | """A client that interacts with the CoreNLPServer.""" 10 | def __init__(self, hostname='http://localhost', port=7000, 11 | start_server=False, server_flags=None, server_log=None, 12 | cache_file=None,): 13 | """Create the client. 14 | 15 | Args: 16 | hostname: hostname of server. 17 | port: port of server. 18 | start_server: start the server on first cache miss. 19 | server_flags: passed to CoreNLPServer.__init__() 20 | server_log: passed to CoreNLPServer.__init__() 21 | cache_file: load and save cache to this file. 22 | """ 23 | self.hostname = hostname 24 | self.port = port 25 | self.start_server = start_server 26 | self.server_flags = server_flags 27 | self.server_log = server_log 28 | self.server = None 29 | self.cache_file = cache_file 30 | self.has_cache_misses = False 31 | if cache_file: 32 | if os.path.exists(cache_file): 33 | with open(cache_file) as f: 34 | self.cache = json.load(f) 35 | else: 36 | self.cache = {} 37 | else: 38 | self.cache = None 39 | 40 | def save_cache(self): 41 | if self.cache_file and self.has_cache_misses: 42 | with open(self.cache_file, 'w') as f: 43 | json.dump(self.cache, f) 44 | self.has_cache_misses = False 45 | 46 | def query(self, sents, properties): 47 | """Most general way to query the server. 48 | 49 | Args: 50 | sents: Either a string or a list of strings. 51 | properties: CoreNLP properties to send as part of the request. 52 | """ 53 | url = '%s:%d' % (self.hostname, self.port) 54 | params = {'properties': str(properties)} 55 | if isinstance(sents, list): 56 | data = '\n'.join(sents) 57 | else: 58 | data = sents 59 | key = '%s\t%s' % (data, str(properties)) 60 | if self.cache and key in self.cache: 61 | return self.cache[key] 62 | self.has_cache_misses = True 63 | if self.start_server and not self.server: 64 | self.server = CoreNLPServer(port=self.port, flags=self.server_flags, 65 | logfile=self.server_log) 66 | self.server.start() 67 | r = requests.post(url, params=params, data=data.encode('utf-8')) 68 | r.encoding = 'utf-8' 69 | json_response = json.loads(r.text, strict=False) 70 | if self.cache is not None: 71 | self.cache[key] = json_response 72 | return json_response 73 | 74 | def __enter__(self): 75 | return self 76 | 77 | def __exit__(self, type, value, traceback): 78 | if self.server: 79 | self.server.stop() 80 | if self.cache_file: 81 | self.save_cache() 82 | 83 | def query_pos(self, sents): 84 | """Standard query for getting POS tags.""" 85 | properties = { 86 | 'ssplit.newlineIsSentenceBreak': 'always', 87 | 'annotators': 'tokenize,ssplit,pos', 88 | 'outputFormat':'json' 89 | } 90 | return self.query(sents, properties) 91 | 92 | def query_ner(self, paragraphs): 93 | """Standard query for getting NERs on raw paragraphs.""" 94 | annotators = 'tokenize,ssplit,pos,ner,entitymentions' 95 | properties = { 96 | 'ssplit.newlineIsSentenceBreak': 'always', 97 | 'annotators': annotators, 98 | 'outputFormat':'json' 99 | } 100 | return self.query(paragraphs, properties) 101 | 102 | def query_depparse_ptb(self, sents, use_sd=False): 103 | """Standard query for getting dependency parses on PTB-tokenized input.""" 104 | annotators = 'tokenize,ssplit,pos,depparse' 105 | properties = { 106 | 'tokenize.whitespace': True, 107 | 'ssplit.eolonly': True, 108 | 'ssplit.newlineIsSentenceBreak': 'always', 109 | 'annotators': annotators, 110 | 'outputFormat':'json' 111 | } 112 | if use_sd: 113 | # Use Stanford Dependencies trained on PTB 114 | # Default is 
Universal Dependencies 115 | properties['depparse.model'] = 'edu/stanford/nlp/models/parser/nndep/english_SD.gz' 116 | return self.query(sents, properties) 117 | 118 | def query_depparse(self, sents, use_sd=False, add_ner=False): 119 | """Standard query for getting dependency parses on raw sentences.""" 120 | annotators = 'tokenize,ssplit,pos,depparse' 121 | if add_ner: 122 | annotators += ',ner' 123 | properties = { 124 | 'ssplit.eolonly': True, 125 | 'ssplit.newlineIsSentenceBreak': 'always', 126 | 'annotators': annotators, 127 | 'outputFormat':'json' 128 | } 129 | if use_sd: 130 | # Use Stanford Dependencies trained on PTB 131 | # Default is Universal Dependencies 132 | properties['depparse.model'] = 'edu/stanford/nlp/models/parser/nndep/english_SD.gz' 133 | return self.query(sents, properties) 134 | 135 | def query_const_parse(self, sents, add_ner=False): 136 | """Standard query for getting constituency parses on raw sentences.""" 137 | annotators = 'tokenize,ssplit,pos,parse' 138 | if add_ner: 139 | annotators += ',ner' 140 | properties = { 141 | 'ssplit.eolonly': True, 142 | 'ssplit.newlineIsSentenceBreak': 'always', 143 | 'annotators': annotators, 144 | 'outputFormat':'json' 145 | } 146 | return self.query(sents, properties) 147 | -------------------------------------------------------------------------------- /nectar/nectar/base/intervals.py: -------------------------------------------------------------------------------- 1 | """Data structures for dealing with intervals.""" 2 | import argparse 3 | import sys 4 | 5 | OPTS = None 6 | 7 | class Interval(object): 8 | """Represents a half-open interval.""" 9 | def __init__(self, start, end, value=None): 10 | self.start = start 11 | self.end = end 12 | self.value = value 13 | 14 | def contains_pt(self, pt): 15 | return pt >= self.start and pt < self.end 16 | 17 | def contains(self, other): 18 | return self.start <= other.start and self.end >= other.end 19 | 20 | def overlaps(self, other, closed_boundaries=False): 21 | if closed_boundaries: 22 | return not (self.start > other.end or self.end < other.start) 23 | return not (self.start >= other.end or self.end <= other.start) 24 | 25 | def overlap_len(self, other): 26 | start = max(self.start, other.start) 27 | end = min(self.end, other.end) 28 | if start > end: return 0 29 | return end - start 30 | 31 | def length(self): 32 | return self.end - self.start 33 | 34 | def __key(self): 35 | return (self.start, self.end) 36 | 37 | def __eq__(self, other): 38 | return self.__key() == other.__key() 39 | 40 | def __hash__(self): 41 | return hash(self.__key()) 42 | 43 | def __lt__(self, other): 44 | return self.__key() < other.__key() 45 | 46 | def __str__(self): 47 | return str((self.start, self.end)) 48 | 49 | class IntervalSet(object): 50 | """Represents a monotincally growing set of half-open intervals.""" 51 | def __init__(self): 52 | self.intervals = [] 53 | 54 | @classmethod 55 | def from_list(cls, interval_list): 56 | ret = cls() 57 | ret.intervals = sorted(interval_list) 58 | return ret 59 | 60 | def _extend_match(self, index, interval, closed_boundaries=False): 61 | # Linear search forward and backward 62 | start = index 63 | while start-1 >= 0 and self.intervals[start-1].overlaps(interval): 64 | start -= 1 65 | end = index 66 | while end+1 < len(self.intervals) and self.intervals[end+1].overlaps(interval): 67 | end += 1 68 | return (start, end + 1) # Return half-open interval 69 | 70 | def search(self, interval, closed_boundaries=False): 71 | """Search for all overlapping intervals. 
72 | 73 | Returns (Found/not Found, start_ind, end_ind) 74 | """ 75 | lo = 0 76 | hi = len(self.intervals) 77 | while hi - lo > 4: 78 | mid = (lo + hi) / 2 79 | x = self.intervals[mid] 80 | if x.overlaps(interval, closed_boundaries=closed_boundaries): 81 | return (True,) + self._extend_match(mid, interval, closed_boundaries=closed_boundaries) 82 | elif x.start >= interval.end: 83 | hi = mid 84 | elif x.end <= interval.start: 85 | lo = mid 86 | for i in range(lo, hi): 87 | x = self.intervals[i] 88 | if x.overlaps(interval): 89 | return (True,) + self._extend_match(i, interval, closed_boundaries=closed_boundaries) 90 | elif x.start >= interval.end: 91 | # We should insert at this index 92 | return (False, i, i) 93 | return (False, hi, hi) 94 | 95 | def contains(self, interval): 96 | found, start_ind, end_ind = self.search(interval) 97 | if not found: return False 98 | if start_ind - end_ind != 1: return False 99 | return self.intervals[start_ind].contains(interval) 100 | 101 | def overlaps(self, interval, closed_boundaries=False): 102 | return self.search(interval, closed_boundaries=closed_boundaries)[0] 103 | 104 | def add(self, interval): 105 | found, start_ind, end_ind = self.search(interval, closed_boundaries=True) 106 | if found: 107 | if start_ind - end_ind == 1 and self.intervals[start_ind].contains(interval): 108 | return 109 | before_elems = self.intervals[:start_ind] 110 | after_elems = self.intervals[end_ind:] 111 | new_start = min((interval.start, self.intervals[start_ind].start)) 112 | new_end = max((interval.end, self.intervals[end_ind - 1].end)) 113 | interval = Interval(new_start, new_end) 114 | self.intervals = before_elems + [interval] + after_elems 115 | else: 116 | self.intervals.insert(start_ind, interval) 117 | # Check if sorted 118 | # if any(self.intervals[i] > self.intervals[i+1] for i in range(len(self.intervals) - 1)): 119 | # for x in self.intervals: 120 | # print x 121 | # raise ValueError('Not sorted!') 122 | 123 | # if self.contains(interval): return 124 | # new_start = interval.start 125 | # new_end = interval.end 126 | # to_remove = set() 127 | # for x in self.intervals: 128 | # if interval.contains(x): 129 | # to_remove.add(x) 130 | # else: 131 | # if x.end >= interval.start and x.start < new_start: 132 | # new_start = x.start 133 | # to_remove.add(x) 134 | # if x.start <= interval.end and x.end > new_end: 135 | # new_end = x.end 136 | # to_remove.add(x) 137 | # new_interval = Interval(new_start, new_end) 138 | # for x in to_remove: 139 | # self.intervals.remove(x) 140 | # self.intervals.append(new_interval) 141 | # self.intervals.sort() # TODO: binary search? 
142 | 143 | def complement(self, interval, min_size=0): 144 | """Return the complement of current set within the given interval.""" 145 | cur_start = interval.start 146 | new_intervals = [] 147 | for x in self.intervals: # Relies on self.intervals being sorted 148 | if x.end < interval.start: continue 149 | if x.start > interval.end: continue 150 | if x.start < interval.start: 151 | cur_start = x.end 152 | else: 153 | cur_end = x.start 154 | new_intervals.append((Interval(cur_start, cur_end))) 155 | cur_start = x.end 156 | if cur_start < interval.end: 157 | new_intervals.append(Interval(cur_start, interval.end)) 158 | return IntervalSet.from_list(new_intervals) 159 | -------------------------------------------------------------------------------- /nectar/nectar/theanoutil/args.py: -------------------------------------------------------------------------------- 1 | """Add standard theano-related flags to an argparse.ArgumentParser.""" 2 | import argparse 3 | import sys 4 | import theano 5 | 6 | from .. import log, log_dict 7 | 8 | class NLPArgumentParser(argparse.ArgumentParser): 9 | """An ArgumentParser with some built-in arguments. 10 | 11 | Allows you to not have to retype help messages every time. 12 | """ 13 | def add_flag_helper(self, long_name, short_name, *args, **kwargs): 14 | long_flag = '--%s' % long_name 15 | if 'help' in kwargs: 16 | if 'default' in kwargs: 17 | # Append default value to help message 18 | new_help = '%s (default=%s)' % (kwargs['help'], str(kwargs['default'])) 19 | kwargs['help'] = new_help 20 | # Append period to end, if missing 21 | if not kwargs['help'].endswith('.'): 22 | kwargs['help'] = kwargs['help'] + '.' 23 | if short_name: 24 | short_flag = '-%s' % short_name 25 | self.add_argument(long_flag, short_flag, *args, **kwargs) 26 | else: 27 | self.add_argument(long_flag, *args, **kwargs) 28 | 29 | # Model hyperparameters 30 | def add_hidden_size(self, short_name=None): 31 | self.add_flag_helper('hidden-size', short_name, type=int, 32 | help='Dimension of hidden vectors') 33 | def add_emb_size(self, short_name=None): 34 | self.add_flag_helper('emb-size', short_name, type=int, 35 | help='Dimension of word vectors') 36 | def add_weight_scale(self, short_name=None, default=1e-1): 37 | self.add_flag_helper('weight-scale', short_name, type=float, default=default, 38 | help='Weight scale for initialization') 39 | def add_l2_reg(self, short_name=None, default=0.0): 40 | self.add_flag_helper('l2-reg', short_name, type=float, default=default, 41 | help='L2 Regularization constant.') 42 | def add_unk_cutoff(self, short_name=None): 43 | self.add_flag_helper('unk-cutoff', short_name, type=int, default=0, 44 | help='Treat input words with <= this many occurrences as UNK') 45 | 46 | # Training hyperparameters 47 | def add_num_epochs(self, short_name=None): 48 | self.add_flag_helper( 49 | 'num-epochs', short_name, default=[], type=lambda s: [int(x) for x in s.split(',')], 50 | help=('Number of epochs to train. 
If comma-separated list, will run for some epochs, halve learning rate, etc.')) 51 | def add_learning_rate(self, short_name=None, default=0.1): 52 | self.add_flag_helper('learning-rate', short_name, type=float, default=default, 53 | help='Initial learning rate.') 54 | def add_clip_thresh(self, short_name=None): 55 | self.add_flag_helper('clip-thresh', short_name, type=float, default=1.0, 56 | help='Total-norm threshold to clip gradients.') 57 | def add_batch_size(self, short_name=None): 58 | self.add_flag_helper('batch-size', short_name, type=int, default=1, 59 | help='Maximum batch size') 60 | # Decoding hyperparameters 61 | def add_beam_size(self, short_name=None): 62 | self.add_flag_helper('beam-size', short_name, type=int, default=0, 63 | help='Use beam search with given beam size, or greedy if 0') 64 | # Data 65 | def add_train_file(self, short_name=None): 66 | self.add_flag_helper('train-file', short_name, help='Path to training data') 67 | def add_dev_file(self, short_name=None): 68 | self.add_flag_helper('dev-file', short_name, help='Path to dev data') 69 | def add_test_file(self, short_name=None): 70 | self.add_flag_helper('test-file', short_name, help='Path to test data') 71 | def add_dev_frac(self, short_name=None): 72 | self.add_flag_helper('dev-frac', short_name, type=float, default=0.0, 73 | help='Take this fraction of train data as dev data') 74 | 75 | # Random seeds 76 | def add_dev_seed(self, short_name=None): 77 | self.add_flag_helper('dev-seed', short_name, type=int, default=0, 78 | help='RNG seed for the train/dev splits') 79 | def add_model_seed(self, short_name=None): 80 | self.add_flag_helper('model-seed', short_name, type=int, default=0, 81 | help="RNG seed for the model") 82 | 83 | # Sasving and loading 84 | def add_save_file(self, short_name=None): 85 | self.add_flag_helper('save-file', short_name, help='Path for saving model') 86 | def add_load_file(self, short_name=None): 87 | self.add_flag_helper('load-file', short_name, help='Path for loading model') 88 | 89 | # Output 90 | def add_stats_file(self, short_name=None): 91 | self.add_flag_helper('stats-file', short_name, help='File to write stats JSON') 92 | def add_html_file(self, short_name=None): 93 | self.add_flag_helper('html-file', short_name, help='File to write output HTML') 94 | 95 | def add_theano_flags(self): 96 | self.add_flag_helper('theano-fast-compile', None, action='store_true', help='Run Theano in fast compile mode') 97 | self.add_flag_helper('theano-profile', None, action='store_true', help='Turn on profiling in Theano') 98 | 99 | def parse_args(self): 100 | """Configure theano and print help on empty arguments.""" 101 | if len(sys.argv) == 1: 102 | self.print_help() 103 | sys.exit(1) 104 | args = super(NLPArgumentParser, self).parse_args() 105 | log_dict(vars(args), 'Command-line Arguments') 106 | configure_theano(args) 107 | return args 108 | 109 | 110 | def configure_theano(opts): 111 | """Configure theano given arguments passed in.""" 112 | if opts.theano_fast_compile: 113 | theano.config.mode='FAST_COMPILE' 114 | theano.config.optimizer = 'None' 115 | theano.config.traceback.limit = 20 116 | else: 117 | theano.config.mode='FAST_RUN' 118 | theano.config.linker='cvm' 119 | if opts.theano_profile: 120 | theano.config.profile = True 121 | -------------------------------------------------------------------------------- /nectar/nectar/theanoutil/model.py: -------------------------------------------------------------------------------- 1 | """Standard utilties for a theano model.""" 2 | import 
collections 3 | import numbers 4 | import numpy as np 5 | import pickle 6 | import random 7 | import sys 8 | import theano 9 | import time 10 | from Tkinter import TclError 11 | 12 | import __init__ as ntu 13 | from .. import log, secs_to_str 14 | 15 | class TheanoModel(object): 16 | """A generic theano model. 17 | 18 | This class handles some standard boilerplate. 19 | Current features include: 20 | Basic training loop 21 | Saving and reloading of model 22 | 23 | An implementing subclass must override the following methods: 24 | self.__init__(*args, **kwargs) 25 | self.init_params() 26 | self.setup_theano_funcs() 27 | self.get_metrics(example) 28 | self.train_one(example, lr) 29 | self.evaluate(dataset) 30 | 31 | A basic self.__init__() routine is provided here, just as an example. 32 | Most users should override __init__() to perform additional functionality. 33 | 34 | See these methods for more details. 35 | """ 36 | def __init__(self): 37 | """A minimal example of what functions must be called during initialization. 38 | 39 | Implementing subclasses should override this method, 40 | but maintain the basic functionality presented here. 41 | """ 42 | # self.params, self.params_list, and self.param_names are required by self.create_matrix() 43 | self.params = {} 44 | self.param_list = [] 45 | self.param_names = [] 46 | 47 | # Initialize parameters 48 | self.init_params() 49 | 50 | # If desired, set up grad norm caches for momentum, AdaGrad, etc. here. 51 | # It must be after params are initialized but before theano functionss are created. 52 | # self.velocities = nt.create_grad_cache(self.param_list) 53 | 54 | # Set up theano functions 55 | self.theano_funcs = {} 56 | self.setup_theano_funcs() 57 | 58 | def init_params(self): 59 | """Initialize parameters with repeated calls to self.create_matrix().""" 60 | raise NotImplementedError 61 | 62 | def setup_theano_funcs(self): 63 | """Create theano.function objects for this model in self.theano_funcs.""" 64 | raise NotImplementedError 65 | 66 | def get_metrics(self, example): 67 | """Get accuracy metrics on a single example. 68 | 69 | Args: 70 | example: An example (possibly batch) 71 | lr: Current learning rate 72 | Returns: 73 | dictionary mapping metric_name to (value, weight); 74 | |weight| is used to compute weighted average over dataset. 75 | """ 76 | raise NotImplementedError 77 | 78 | def train_one(self, example, lr): 79 | """Run training on a single example. 80 | 81 | Args: 82 | example: An example (possibly batch) 83 | lr: Current learning rate 84 | Returns: 85 | dictionary mapping metric_name to (value, weight); 86 | |weight| is used to compute weighted average over dataset. 87 | """ 88 | raise NotImplementedError 89 | 90 | def create_matrix(self, name, shape, weight_scale, value=None): 91 | """A helper method that creates a parameter matrix.""" 92 | if value: 93 | pass 94 | elif shape: 95 | value = weight_scale * np.random.uniform(-1.0, 1.0, shape).astype( 96 | theano.config.floatX) 97 | else: 98 | # None means it's a scalar 99 | dtype = np.dtype(theano.config.floatX) 100 | value = dtype.type(weight_scale * np.random.uniform(-1.0, 1.0)) 101 | mat = theano.shared(name=name, value=value) 102 | self.params[name] = mat 103 | self.param_list.append(mat) 104 | self.param_names.append(name) 105 | 106 | def train(self, train_data, lr_init, epochs, dev_data=None, rng_seed=0, 107 | plot_metric=None, plot_outfile=None): 108 | """Train the model. 
109 | 110 | Args: 111 | train_data: A list of training examples 112 | lr_init: Initial learning rate 113 | epochs: An integer number of epochs to train, or a list of integers, 114 | where we halve the learning rate after each period. 115 | dev_data: A list of dev examples, evaluate loss on this each epoch. 116 | rng_seed: Random seed for shuffling the dataset at each epoch. 117 | plot_metric: If True, plot a learning for the given metric. 118 | plot_outfile: If provided, save learning curve to file. 119 | """ 120 | random.seed(rng_seed) 121 | train_data = list(train_data) 122 | lr = lr_init 123 | if isinstance(epochs, numbers.Number): 124 | lr_changes = [] 125 | num_epochs = epochs 126 | else: 127 | lr_changes = set([sum(epochs[:i]) for i in range(1, len(epochs))]) 128 | num_epochs = sum(epochs) 129 | num_epochs_digits = len(str(num_epochs)) 130 | train_plot_list = [] 131 | dev_plot_list = [] 132 | str_len_dict = collections.defaultdict(int) 133 | len_time = 0 134 | for epoch in range(num_epochs): 135 | t0 = time.time() 136 | random.shuffle(train_data) 137 | if epoch in lr_changes: 138 | lr *= 0.5 139 | train_metric_list = [] 140 | for ex in train_data: 141 | cur_metrics = self.train_one(ex, lr) 142 | train_metric_list.append(cur_metrics) 143 | if dev_data: 144 | dev_metric_list = [self.get_metrics(ex) for ex in dev_data] 145 | else: 146 | dev_metric_list = [] 147 | t1 = time.time() 148 | 149 | # Compute the averaged metrics 150 | train_metrics = aggregate_metrics(train_metric_list) 151 | dev_metrics = aggregate_metrics(dev_metric_list) 152 | if plot_metric: 153 | train_plot_list.append(train_metrics[plot_metric]) 154 | if dev_metrics: 155 | dev_plot_list.append(dev_metrics[plot_metric]) 156 | 157 | # Some formatting to make things align in columns 158 | train_str = format_epoch_str('train', train_metrics, str_len_dict) 159 | dev_str = format_epoch_str('dev', dev_metrics, str_len_dict) 160 | metric_str = ', '.join(x for x in [train_str, dev_str] if x) 161 | time_str = secs_to_str(t1 - t0) 162 | len_time = max(len(time_str), len_time) 163 | log('Epoch %s: %s [lr = %.1e] [took %s]' % ( 164 | str(epoch+1).rjust(num_epochs_digits), metric_str, lr, 165 | time_str.rjust(len_time))) 166 | 167 | if plot_metric: 168 | plot_data = [('%s on train data' % plot_metric, train_plot_list)] 169 | if dev_plot_list: 170 | plot_data.append(('%s on dev data' % plot_metric, dev_plot_list)) 171 | try: 172 | ntu.plot_learning_curve(plot_data, outfile=plot_outfile) 173 | except TclError: 174 | print >> sys.stderr, 'Encoutered error while plotting learning curve' 175 | 176 | 177 | def evaluate(self, dataset): 178 | """Evaluate the model.""" 179 | metrics_list = [self.get_metrics(ex) for ex in dataset] 180 | return aggregate_metrics(metrics_list) 181 | 182 | def save(self, filename): 183 | # Save 184 | tf = self.theano_funcs 185 | params = self.params 186 | param_list = self.param_list 187 | # Don't pickle theano functions 188 | self.theano_funcs = None 189 | # CPU/GPU portability 190 | self.params = {k: v.get_value() for k, v in params.iteritems()} 191 | self.param_list = None 192 | # Any other things to do before saving 193 | saved = self._prepare_save() 194 | with open(filename, 'wb') as f: 195 | pickle.dump(self, f) 196 | # Restore 197 | self.theano_funcs = tf 198 | self.params = params 199 | self.param_list = param_list 200 | self._after_save(saved) 201 | 202 | def _prepare_save(self): 203 | """Any additional things before calling pickle.dump().""" 204 | pass 205 | 206 | def _after_save(self, saved): 207 | 
"""Any additional things after calling pickle.dump().""" 208 | pass 209 | 210 | @classmethod 211 | def load(cls, filename): 212 | with open(filename, 'rb') as f: 213 | model = pickle.load(f) 214 | # Recreate theano shared variables 215 | params = model.params 216 | model.params = {} 217 | model.param_list = [] 218 | for name in model.param_names: 219 | value = params[name] 220 | mat = theano.shared(name=name, value=value) 221 | model.params[name] = mat 222 | model.param_list.append(mat) 223 | model._after_load() 224 | # Recompile theano functions 225 | model.theano_funcs = {} 226 | model.setup_theano_funcs() 227 | return model 228 | 229 | def _after_load(self): 230 | """Any additional things after calling pickle.load().""" 231 | pass 232 | 233 | def aggregate_metrics(metric_list): 234 | metrics = collections.OrderedDict() 235 | if metric_list: 236 | keys = metric_list[0].keys() 237 | for k in keys: 238 | numer = sum(x[k][0] * x[k][1] for x in metric_list) 239 | denom = sum(x[k][1] for x in metric_list) 240 | metrics[k] = float(numer) / denom 241 | return metrics 242 | 243 | def format_epoch_str(name, metrics, str_len_dict): 244 | if not metrics: return '' 245 | toks = [] 246 | for k in metrics: 247 | val_str = '%.4f' % metrics[k] 248 | len_key = '%s:%s' % (name, k) 249 | str_len_dict[len_key] = max(str_len_dict[len_key], len(val_str)) 250 | cur_str = '%s=%s' % (k, val_str.rjust(str_len_dict[len_key])) 251 | toks.append(cur_str) 252 | return '%s(%s)' % (name, ', '.join(toks)) 253 | 254 | -------------------------------------------------------------------------------- /nectar/nectar/base/graph.py: -------------------------------------------------------------------------------- 1 | """A directed graph.""" 2 | import collections 3 | import copy 4 | import numpy as np 5 | import pygraphviz as pgz 6 | 7 | class Graph(object): 8 | """A labeled, unweighted directed graph.""" 9 | def __init__(self): 10 | self.nodes = [] 11 | self.edges = [] # triples (i, j, label) 12 | self.label2index = collections.defaultdict(set) 13 | self.out_edges = collections.defaultdict(set) 14 | self.in_edges = collections.defaultdict(set) 15 | self.edge_to_label = {} 16 | self.conn_comps = [] # Connected components 17 | 18 | @classmethod 19 | def make_chain(cls, nodes): 20 | """Make a chain-structured graph from the list of nodes.""" 21 | g = cls() 22 | for n in nodes: 23 | g.add_node(n) 24 | for i in range(len(nodes) - 1): 25 | g.add_edge(i, i+1) 26 | return g 27 | 28 | @classmethod 29 | def from_string(cls, s): 30 | """Load a Graph from a string generated by make_string()""" 31 | g = cls() 32 | toks = s.split(' ') 33 | nodes = toks[:-1] 34 | edges = [x.split(',') for x in toks[-1].split(';')] 35 | for n in nodes: 36 | g.add_node(n) 37 | for e in edges: 38 | e_new = [int(e[0]), int(e[1])] + e[2:] 39 | g.add_edge(*e_new) 40 | return g 41 | 42 | def make_string(self): 43 | """Serialize the graph as a string.""" 44 | edge_str = ';'.join('%d,%d,%s' % (i, j, lab) for i, j, lab in self.edges) 45 | return '%s %s' % (' '.join(self.nodes), edge_str) 46 | 47 | def get_num_nodes(self): 48 | return len(self.nodes) 49 | 50 | def get_num_edges(self): 51 | return len(self.edges) 52 | 53 | def add_node(self, node_label): 54 | new_index = len(self.nodes) 55 | self.nodes.append(node_label) 56 | self.label2index[node_label].add(new_index) 57 | self.conn_comps.append(set([new_index])) 58 | 59 | def check_index_in_range(self, ind): 60 | if ind < 0 or ind >= len(self.nodes): 61 | raise ValueError('Index %d not in range (len(nodes) == %d)' % ( 
62 | ind, len(self.nodes))) 63 | 64 | def add_edge(self, start, end, label='_'): 65 | self.check_index_in_range(start) 66 | self.check_index_in_range(end) 67 | if (start, end) in self.edge_to_label: 68 | raise ValueError('Edge between %d and %d already exists' % (start, end)) 69 | self.edges.append((start, end, label)) 70 | self.out_edges[start].add(end) 71 | self.in_edges[end].add(start) 72 | self.edge_to_label[(start, end)] = label 73 | ind_start = self.find_conn_comp(start) 74 | ind_end = self.find_conn_comp(end) 75 | if ind_start != ind_end: 76 | self.conn_comps[ind_start] |= self.conn_comps[ind_end] 77 | self.conn_comps.pop(ind_end) 78 | 79 | def add_graph(self, other): 80 | base_index = len(self.nodes) 81 | for label in other.nodes: 82 | self.add_node(label) 83 | for i, j, label in other.edges: 84 | self.add_edge(base_index + i, base_index + j, label) 85 | 86 | def find_conn_comp(self, index): 87 | self.check_index_in_range(index) 88 | for i, cc in enumerate(self.conn_comps): 89 | if index in cc: return i 90 | raise ValueError('Connected components missing node index %d' % index) 91 | 92 | def has_edge(self, start, end, label=None): 93 | """Return if there exists an edge from start to end.""" 94 | if end not in self.out_edges[start]: return False 95 | return (not label) or self.edge_to_label[(start, end)] == label 96 | 97 | def has_undirected_edge(self, start, end, label=None): 98 | """Return if there exists an edge from start to end or end to start.""" 99 | return self.has_edge(start, end, label) or self.has_edge(end, start, label) 100 | 101 | def get_adjacency_matrix(self): 102 | """Get a matrix where mat[i,j] == 1 iff there is an i->j edge.""" 103 | n = len(self.nodes) 104 | mat = np.zeros((n, n), dtype=np.int64) 105 | for i, j, label in self.edges: 106 | mat[i,j] = 1 107 | return mat 108 | 109 | def toposort(self, start_at_sink=False): 110 | """Return a topological sort of the nodes. 111 | 112 | In particular, finds a permutation topo_order of range(len(self.nodes)) 113 | such that topo_order[i]'s parents are in topo_order[:i]. 114 | In other words, the topological order starts with source nodes 115 | and ends with sink nodes. 116 | 117 | Args: 118 | start_at_sink: if True, start at sink nodes and end at source nodes. 119 | Returns: 120 | A topological ordering of the nodes, or None if the graph is not a DAG. 
121 | """ 122 | topo_order = [] 123 | in_degrees = [len(self.in_edges[i]) for i in range(len(self.nodes))] 124 | source_nodes = [i for i, d in enumerate(in_degrees) if d == 0] 125 | while len(topo_order) < len(self.nodes): 126 | if len(source_nodes) == 0: 127 | return None # graph is not a DAG 128 | i = source_nodes.pop() 129 | topo_order.append(i) 130 | for j in self.out_edges[i]: 131 | in_degrees[j] -= 1 132 | if in_degrees[j] == 0: 133 | source_nodes.append(j) 134 | if start_at_sink: 135 | topo_order = topo_order[::-1] 136 | return topo_order 137 | 138 | def is_connected(self): 139 | return len(self.conn_comps) <= 1 140 | 141 | def __str__(self): 142 | node_str = ','.join(self.nodes) 143 | edge_str = ';'.join('(%s)' % ','.join(str(t) for t in e) for e in self.edges) 144 | return '(%s, nodes=[%s], edges=[%s])' % (self.__class__, node_str, edge_str) 145 | 146 | def to_agraph(self, id_prefix=''): 147 | """Return a pygraphviz AGraph representation of the graph.""" 148 | def make_id(s): 149 | return '%s-%s' % (id_prefix, s) if id_prefix else s 150 | ag = pgz.AGraph(directed=True) 151 | for i, label in enumerate(self.nodes): 152 | ag.add_node(i, label=label, id=make_id('node%d' % i)) 153 | for index, (i, j, label) in enumerate(self.edges): 154 | ag.add_edge(i, j, label=label, id=make_id('edge%d' % index)) 155 | return ag 156 | 157 | def draw_svg(self, id_prefix='', filename=None, horizontal=False): 158 | """Render the graph as SVG, either to a string or to a file.""" 159 | ag = self.to_agraph(id_prefix=id_prefix) 160 | args = '-Grankdir=LR' if horizontal else '' 161 | ag.layout('dot', args=args) 162 | if filename: 163 | # Write to file 164 | svg_str = ag.draw(filename) 165 | else: 166 | # Write to string, return the string 167 | svg_str = ag.draw(format='svg') 168 | start_ind = svg_str.index(' 0 218 | 219 | def can_add_edge(self, start, end, label): 220 | for func in self.funcs: 221 | if self.parent_graph.has_edge(func[start], func[end], label): 222 | return True 223 | return False 224 | 225 | def can_add_graph(self, other): 226 | base_index = len(self.nodes) 227 | for func in self.funcs: 228 | success = True 229 | g = copy.deepcopy(self) # Need to deepcopy since we need to mutate 230 | for label in other.nodes: 231 | if g.can_add_node(label): 232 | g.add_node(label) 233 | else: 234 | success = False 235 | break 236 | if not success: continue 237 | for i, j, label in other.edges: 238 | if g.can_add_edge(base_index + i, base_index + j, label): 239 | g.add_edge(base_index + i, base_index + j, label) 240 | else: 241 | success = False 242 | break 243 | if not success: continue 244 | if success: return True 245 | return False 246 | 247 | def is_finished(self): 248 | return (len(self.nodes) == len(self.parent_graph.nodes) and 249 | len(self.edges) == len(self.parent_graph.edges)) 250 | 251 | def get_valid_new_nodes(self): 252 | """Get a list of all node labels that can be added.""" 253 | return list(x for x in self.counts_left if self.counts_left[x] > 0) 254 | -------------------------------------------------------------------------------- /bridge_entity_rules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import string 4 | import json 5 | import numpy as np 6 | from pattern.en import conjugate 7 | from resources.nectar import corenlp 8 | from nltk.corpus import wordnet 9 | from nltk.corpus import stopwords 10 | 11 | COMMON_WORDS = list(stopwords.words('english')) 12 | PUNCTUATIONS = list(string.punctuation) 13 | 14 | 
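# --- Editor's sketch (not part of the original module) --------------------------
# The rule functions below generally share one calling convention: they take the
# original surface string, its CoreNLP-style token dicts, and the question, and
# return a fake replacement (a string, or a (string, tokens) pair for the
# lookup-based rules), or None when the rule does not apply. A minimal, hedged
# illustration of that contract, assuming the token-dict shape produced by the
# CoreNLP client used elsewhere in this repo:
def _example_bridge_rule_usage():
    tokens = [{'originalText': 'Chicago', 'word': 'Chicago', 'pos': 'NNP', 'ner': 'LOCATION'}]
    rule = bridge_entity_full('LOCATION', 'Chicago')  # defined later in this module
    return rule('Chicago', tokens, 'Where was the band formed?')
# ---------------------------------------------------------------------------------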
FIRST_NAMES = ['Jason', 'Mary', 'James', 'Jeff', 'Abi', 'Bran', 'Sansa', 'Jon', 'Ned', 'Peter', 'Jaime', \ 15 | 'Marcus', 'Chris', 'Diana', 'Phoebe', 'Leo', 'Phil', 'Nick', 'Steve'] 16 | LAST_NAMES = ['Kid', 'Jordan', 'Harden', 'Dean', 'Stark', 'Parker', 'Morris', 'Wallace', 'Manning', 'Rogers', 'Folt', 'White'] 17 | LOCATIONS = ['Chicago', 'Beijing', 'Tokyo', 'Pittsburg', 'Paris', 'Barcelona', 'Madrid', 'Berlin', 'Europe', 'California'] 18 | ORGANIZATIONS_START = LAST_NAMES + LOCATIONS 19 | ORGANIZATIONS_END = ['Corporations', 'Industries', 'University', 'Association', 'Department'] 20 | NNP_START = ['Central', 'Eastern', 'Western', 'Golden', 'Stony', 'Student', 'Brooks'] 21 | NNP_END = ['Park', 'House', 'Center', 'Palace', 'Place', 'Store'] 22 | NNPS_START = NNP_START 23 | NNPS_END = ['Parks', 'Gardens', 'Bullets', 'Lakers', 'Brothers'] 24 | NNPS = ['Cool Kids', 'Kew Gardens', 'Silver Bullets', 'LA Lakers', 'Brooks Brothers'] 25 | NN = ['hamster', 'composer', 'man', 'statement'] 26 | NNS = ['hamsters', 'composers', 'men', 'statements'] 27 | 28 | 29 | def lookup_title_generate(checker, rule): 30 | ''' 31 | uses cached list of old answers to generate a real answer, return None if it 32 | does not pass the checker function 33 | ''' 34 | 35 | def func(a, tokens, question, title_cache, **kwargs): 36 | fake = checker(a, tokens, question, **kwargs) 37 | if fake is None: 38 | return None 39 | 40 | tok_len = len(tokens) 41 | counter = 0 42 | new_ans = a 43 | while new_ans == a: 44 | if tok_len <= 0: 45 | return None 46 | 47 | counter2 = 0 48 | while True: 49 | if str(tok_len) in title_cache[rule]: 50 | new_ans, new_ans_tok = random.choice(title_cache[rule][str(tok_len)]) 51 | if a.lower().startswith('the ') and (new_ans.lower().startswith('the ') is False) \ 52 | or a.lower().startswith('the ') is False and new_ans.lower().startswith('the '): 53 | counter2 += 1 54 | if counter2 == 40: 55 | break 56 | else: 57 | break 58 | else: 59 | tok_len -= 1 60 | if tok_len <= 0: 61 | return None 62 | 63 | counter += 1 64 | if counter == 40: 65 | tok_len -= 1 66 | counter = 0 67 | 68 | if a.lower().startswith('the ') and (new_ans.lower().startswith('the ') is False): 69 | new_ans = 'The ' + new_ans 70 | new_ans_tok = [{'originalText': 'The', 'pos': 'DT', 'word': 'the'}] + new_ans_tok 71 | assert counter2 == 40 72 | 73 | return new_ans, new_ans_tok 74 | return func 75 | 76 | 77 | MONTHS = ['january', 'february', 'march', 'april', 'may', 'june', 'july', 78 | 'august', 'september', 'october', 'november', 'december'] 79 | 80 | 81 | def bridge_date(a, tokens, q, **kwargs): 82 | out_toks = [] 83 | if not all(t['ner'] == 'DATE' for t in tokens): return None 84 | for t in tokens: 85 | if t['pos'] == 'CD' or t['word'].isdigit(): 86 | try: 87 | value = int(t['word']) 88 | except: 89 | value = 10 # fallback 90 | if value > 50: 91 | rand = np.random.randint(0, 10) 92 | if rand%2 == 0: 93 | new_val = str(value - np.random.randint(10, 25)) # Year 94 | else: 95 | new_val = str(value + np.random.randint(10, 25)) # Year 96 | else: # Day of month 97 | if value > 15: new_val = str(value - np.random.randint(1, 12)) 98 | else: new_val = str(value + np.random.randint(1, 12)) 99 | else: 100 | if t['word'].lower() in MONTHS: 101 | m_ind = MONTHS.index(t['word'].lower()) 102 | new_val = MONTHS[(m_ind + np.random.randint(1, 11)) % 12].title() 103 | else: 104 | # Give up 105 | new_val = t['originalText'] 106 | out_toks.append({'before': t['before'], 'originalText': new_val}) 107 | new_ans = corenlp.rejoin(out_toks).strip() 108 | if 
new_ans == a: return None 109 | return new_ans, out_toks 110 | 111 | 112 | def bridge_number(a, tokens, q, **kwargs): 113 | """ 114 | Difference with ans_number: 115 | 1. Not changing 'thousand', 'million', etc. 116 | 2. Change trailing digit 117 | """ 118 | out_toks = [] 119 | seen_num = False 120 | for t in tokens: 121 | ner = t['ner'] 122 | pos = t['pos'] 123 | w = t['word'] 124 | out_tok = {'before': t['before']} 125 | 126 | # Split on dashes 127 | leftover = '' 128 | dash_toks = w.split('-') 129 | if len(dash_toks) > 1: 130 | w = dash_toks[0] 131 | leftover = '-'.join(dash_toks[1:]) 132 | 133 | # Try to get a number out 134 | value = None 135 | if w != '%': 136 | # Percent sign should just pass through 137 | try: 138 | value = float(w.replace(',', '')) 139 | except: 140 | try: 141 | norm_ner = t['normalizedNER'] 142 | if norm_ner[0] in ('%', '>', '<'): 143 | norm_ner = norm_ner[1:] 144 | value = float(norm_ner) 145 | except: 146 | pass 147 | if not value and ( 148 | ner == 'NUMBER' or 149 | (ner == 'PERCENT' and pos == 'CD')): 150 | # Force this to be a number anyways 151 | value = 10 152 | if value: 153 | if math.isinf(value) or math.isnan(value): value = 9001 154 | seen_num = True 155 | if w in ('thousand', 'million', 'billion', 'trillion'): 156 | new_val = w 157 | else: 158 | if value < 2500: # This could be years, so don't change too much 159 | rand = np.random.randint(0, 10) 160 | if rand%2 == 0: 161 | new_val = str(value - np.random.randint(1, 11)) 162 | else: 163 | new_val = str(value + np.random.randint(1, 11)) 164 | else: 165 | # Change leading digit 166 | if value == int(value): 167 | val_chars = list('%d' % value) 168 | else: 169 | val_chars = list('%g' % value) 170 | c = val_chars[-1] 171 | for i in range(len(val_chars)): 172 | c = val_chars[len(val_chars)-1-i] 173 | if c >= '0' and c <= '9': 174 | val_chars[len(val_chars)-1-i] = str(max((int(c) + np.random.randint(1, 10)) % 10, 1)) 175 | break 176 | new_val = ''.join(val_chars) 177 | if leftover: 178 | new_val = '%s-%s' % (new_val, leftover) 179 | out_tok['originalText'] = new_val 180 | else: 181 | out_tok['originalText'] = t['originalText'] 182 | 183 | if t['originalText'].endswith('.0') is False and out_tok['originalText'].endswith('.0'): 184 | out_tok['originalText'] = out_tok['originalText'][:-2] 185 | out_toks.append(out_tok) 186 | if seen_num: 187 | return corenlp.rejoin(out_toks).strip(), out_toks 188 | else: 189 | return None 190 | 191 | 192 | def process_token(word, original_tok): 193 | new_word = word 194 | if original_tok['pos'].startswith('V'): 195 | if original_tok['pos'] == 'VB': 196 | new_word = conjugate(word, 'VB') 197 | elif original_tok['pos'] == 'VBD': 198 | new_word = conjugate(word, 'VBD') 199 | elif original_tok['pos'] == 'VBN': 200 | new_word = conjugate(word, 'VBN') 201 | elif original_tok['pos'] == 'VBG': 202 | new_word = conjugate(word, 'VBG') 203 | elif original_tok['pos'] == 'VBZ': 204 | new_word = conjugate(word, 'VBZ') 205 | elif original_tok['pos'] == 'VBP': 206 | new_word = conjugate(word, 'VBP') 207 | return new_word 208 | 209 | 210 | def bridge_wordnet_catch_amap(a, tokens, q, **kwargs): 211 | """Returns a function that yields new_ans if the wordnet can find its antonyms""" 212 | new_anss = [] 213 | for t in tokens: 214 | if t['originalText'].lower() in COMMON_WORDS + PUNCTUATIONS: 215 | new_anss.append(t['originalText']) 216 | continue 217 | antonyms = [], [] 218 | for syn in wordnet.synsets(t['originalText']): 219 | for l in syn.lemmas(): 220 | if l.antonyms(): 221 | 
antonyms.append(l.antonyms()[0].name()) 222 | 223 | new_word = None 224 | if t['pos'].startswith('VB') or t['pos'].startswith('JJ') or t['pos'].startswith('R'): 225 | for w in antonyms: 226 | if w.lower() != t['originalText'].lower() and t['originalText'] not in w.lower() and '_' not in w: 227 | new_word = process_token(w, t) 228 | break 229 | if new_word: 230 | new_anss.append(new_word) 231 | else: 232 | return None 233 | 234 | if len(new_anss) == 0: 235 | return None 236 | for new_ans in new_anss: 237 | if new_ans.lower() != a.lower(): 238 | return new_ans 239 | return None 240 | 241 | 242 | def bridge_entity_full(ner_tag, new_ans): 243 | """Returns a function that yields new_ans iff every token has |ner_tag|.""" 244 | def func(a, tokens, q, is_end=False, **kwargs): 245 | for t in tokens: 246 | if t['ner'] != ner_tag: return None 247 | if ner_tag == 'PERSON': 248 | if is_end: 249 | return LAST_NAMES[random.randint(0, len(LAST_NAMES)-1)] 250 | else: 251 | return FIRST_NAMES[random.randint(0, len(FIRST_NAMES)-1)] 252 | elif ner_tag == 'LOCATION': 253 | return LOCATIONS[random.randint(0, len(LOCATIONS)-1)] 254 | elif ner_tag == 'ORGANIZATION': 255 | if is_end: 256 | return ORGANIZATIONS_END[random.randint(0, len(ORGANIZATIONS_END)-1)] 257 | else: 258 | return ORGANIZATIONS_START[random.randint(0, len(ORGANIZATIONS_START)-1)] 259 | return new_ans 260 | return func 261 | 262 | 263 | def bridge_abbrev(new_ans): 264 | def func(a, tokens, q, **kwargs): 265 | s = a 266 | if s == s.upper() and s != s.lower(): 267 | return new_ans 268 | return None 269 | return func 270 | 271 | 272 | def bridge_pos(pos, new_ans, end=False, add_dt=False): 273 | """Returns a function that yields new_ans if the first/last token has |pos|.""" 274 | def func(a, tokens, q, is_end=True, **kwargs): 275 | if end: 276 | for it in range(len(tokens)): 277 | t = tokens[-1-it] 278 | if t['originalText'] not in PUNCTUATIONS: 279 | break 280 | else: 281 | t = tokens[0] 282 | if t['pos'] != pos: return None 283 | if pos == 'NN': 284 | return NN[random.randint(0, len(NN)-1)] 285 | if pos == 'NNS': 286 | return NNS[random.randint(0, len(NNS)-1)] 287 | if pos == 'NNP': 288 | if is_end: 289 | return NNP_END[random.randint(0, len(NNP_END)-1)] 290 | else: 291 | return NNP_START[random.randint(0, len(NNP_START)-1)] 292 | if pos == 'NNPS': 293 | if is_end: 294 | return NNPS_END[random.randint(0, len(NNPS_END)-1)] 295 | else: 296 | return NNPS_START[random.randint(0, len(NNPS_START)-1)] 297 | return new_ans 298 | return func 299 | 300 | 301 | def bridge_catch_all(new_ans): 302 | def func(a, tokens, q, **kwargs): 303 | if tokens[0]['originalText'][0].isupper(): 304 | return new_ans[0].upper()+new_ans[1:] 305 | return new_ans 306 | return func -------------------------------------------------------------------------------- /answer_rules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import string 4 | import json 5 | import numpy as np 6 | from pattern.en import conjugate 7 | from resources.nectar import corenlp 8 | from nltk.corpus import wordnet 9 | from nltk.corpus import stopwords 10 | 11 | COMMON_WORDS = list(stopwords.words('english')) 12 | PUNCTUATIONS = list(string.punctuation) 13 | 14 | 15 | def ans_number(a, tokens, q, **kwargs): 16 | out_toks = [] 17 | seen_num = False 18 | for t in tokens: 19 | ner = t['ner'] 20 | pos = t['pos'] 21 | w = t['word'] 22 | out_tok = {'before': t['before']} 23 | 24 | # Split on dashes 25 | leftover = '' 26 | dash_toks = 
w.split('-') 27 | if len(dash_toks) > 1: 28 | w = dash_toks[0] 29 | leftover = '-'.join(dash_toks[1:]) 30 | 31 | # Try to get a number out 32 | value = None 33 | if w != '%': 34 | # Percent sign should just pass through 35 | try: 36 | value = float(w.replace(',', '')) 37 | except: 38 | try: 39 | norm_ner = t['normalizedNER'] 40 | if norm_ner[0] in ('%', '>', '<'): 41 | norm_ner = norm_ner[1:] 42 | value = float(norm_ner) 43 | except: 44 | pass 45 | if not value and ( 46 | ner == 'NUMBER' or 47 | (ner == 'PERCENT' and pos == 'CD')): 48 | # Force this to be a number anyways 49 | value = 10 50 | if value: 51 | if math.isinf(value) or math.isnan(value): value = 9001 52 | seen_num = True 53 | if w in ('thousand', 'million', 'billion', 'trillion'): 54 | if w == 'thousand': 55 | new_val = 'million' 56 | else: 57 | new_val = 'thousand' 58 | else: 59 | if value < 2500 and value > 1000: 60 | rand = np.random.randint(0, 10) 61 | if rand%2 == 0: 62 | new_val = str(value - np.random.randint(1, 11)) 63 | else: 64 | new_val = str(value + np.random.randint(1, 11)) 65 | else: 66 | # Change leading digit 67 | if value == int(value): 68 | val_chars = list('%d' % value) 69 | else: 70 | val_chars = list('%g' % value) 71 | c = val_chars[0] 72 | for i in range(len(val_chars)): 73 | c = val_chars[i] 74 | if c >= '0' and c <= '9': 75 | val_chars[i] = str(max((int(c) + np.random.randint(1, 10)) % 10, 1)) 76 | break 77 | new_val = ''.join(val_chars) 78 | if leftover: 79 | new_val = '%s-%s' % (new_val, leftover) 80 | out_tok['originalText'] = new_val 81 | else: 82 | out_tok['originalText'] = t['originalText'] 83 | 84 | if t['originalText'].endswith('.0') is False and out_tok['originalText'].endswith('.0'): 85 | out_tok['originalText'] = out_tok['originalText'][:-2] 86 | 87 | out_toks.append(out_tok) 88 | if seen_num: 89 | return corenlp.rejoin(out_toks).strip() 90 | else: 91 | return None 92 | 93 | 94 | MONTHS = ['january', 'february', 'march', 'april', 'may', 'june', 'july', 95 | 'august', 'september', 'october', 'november', 'december'] 96 | 97 | 98 | def ans_date(a, tokens, q, **kwargs): 99 | out_toks = [] 100 | if not all(t['ner'] == 'DATE' for t in tokens): return None 101 | for t in tokens: 102 | if t['pos'] == 'CD' or t['word'].isdigit(): 103 | try: 104 | value = int(t['word']) 105 | except: 106 | value = 10 # fallback 107 | if value > 50: 108 | rand = np.random.randint(0, 10) 109 | if rand%2 == 0: 110 | new_val = str(value - np.random.randint(10, 25)) # Year 111 | else: 112 | new_val = str(value + np.random.randint(10, 25)) # Year 113 | else: # Day of month 114 | if value > 15: new_val = str(value - np.random.randint(1, 12)) 115 | else: new_val = str(value + np.random.randint(1, 12)) 116 | else: 117 | if t['word'].lower() in MONTHS: 118 | m_ind = MONTHS.index(t['word'].lower()) 119 | new_val = MONTHS[(m_ind + np.random.randint(1, 11)) % 12].title() 120 | else: 121 | # Give up 122 | new_val = t['originalText'] 123 | out_toks.append({'before': t['before'], 'originalText': new_val}) 124 | new_ans = corenlp.rejoin(out_toks).strip() 125 | if new_ans == a: return None 126 | return new_ans 127 | 128 | 129 | def lookup_answer_generate(checker, rule): 130 | ''' 131 | uses cached list of old answers to generate a real answer, return None if it 132 | does not pass the checker function 133 | ''' 134 | def func(a, tokens, question, ans_cache, **kwargs): 135 | fake = checker(a, tokens, question, **kwargs) 136 | if fake is None: 137 | return None 138 | 139 | tok_len = len(tokens) 140 | counter = 0 141 | new_ans = a 142 | 
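        # Sample a cached answer with the same token length as the original answer;
        # back off to shorter lengths when the cache has no entry for that length,
        # or after 40 draws that only reproduce the original answer. `counter2`
        # counts draws whose leading "the " disagrees with the original; that
        # mismatch is handled after the loop.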
while new_ans == a: 143 | if tok_len <= 0: 144 | return None 145 | counter2 = 0 146 | while True: 147 | if str(tok_len) in ans_cache[rule]: 148 | new_ans, new_ans_tok = random.choice(ans_cache[rule][str(tok_len)]) 149 | if a.lower().startswith('the ') and (new_ans.lower().startswith('the ') is False) \ 150 | or a.lower().startswith('the ') is False and new_ans.lower().startswith('the '): 151 | counter2 += 1 152 | if counter2 == 40: 153 | break 154 | else: 155 | break 156 | else: 157 | tok_len -= 1 158 | if tok_len <= 0: 159 | return None 160 | 161 | counter += 1 162 | if counter == 40: 163 | tok_len -= 1 164 | counter = 0 165 | 166 | if a.lower().startswith('the ') and (new_ans.lower().startswith('the ') is False): 167 | new_ans = 'The ' + new_ans 168 | new_ans_tok = [{'originalText': 'The', 'pos': 'DT', 'word': 'the'}] + new_ans_tok 169 | assert counter2 == 40 170 | 171 | return new_ans, new_ans_tok 172 | return func 173 | 174 | 175 | FIRST_NAMES = ['Jason', 'Mary', 'James', 'Jeff', 'Abi', 'Bran', 'Sansa', 'Jon', 'Ned', 'Peter', 'Jaime', \ 176 | 'Marcus', 'Chris', 'Diana', 'Phoebe', 'Leo', 'Phil', 'Nick', 'Steve'] 177 | LAST_NAMES = ['Kid', 'Jordan', 'Harden', 'Dean', 'Stark', 'Parker', 'Morris', 'Wallace', 'Manning', 'Rogers', 'Folt', 'White'] 178 | LOCATIONS = ['Chicago', 'New York', 'Beijing', 'Tokyo', 'Pittsburg', 'Los Angeles', 'Paris', 'Barcelona', 'Madrid', 'Berlin'] 179 | ORGANIZATIONS = ['Stark Industries', 'Google Inc', 'Baidu Inc', 'Nike Corp', 'House of Stark', 'University of Southern Texas', \ 180 | 'National Student Association', 'Facebook', 'Department of Education'] 181 | NNP = ['Central Park', 'Student Store', 'White House', 'Pacific Ocean', 'Gourmet Center', 'Golden Palace', 'Stony River', 'Staples Center'] 182 | NNPS = ['Cool Kids', 'Kew Gardens', 'Silver Bullets', 'LA Lakers', 'Brooks Brothers'] 183 | NN = ['hamster', 'composer', 'man', 'statement'] 184 | NNS = ['hamsters', 'composers', 'men', 'statements'] 185 | 186 | 187 | def process_token(word, original_tok): 188 | new_word = word 189 | if original_tok['pos'].startswith('V'): 190 | if original_tok['pos'] == 'VB': 191 | new_word = conjugate(word, 'VB') 192 | elif original_tok['pos'] == 'VBD': 193 | new_word = conjugate(word, 'VBD') 194 | elif original_tok['pos'] == 'VBN': 195 | new_word = conjugate(word, 'VBN') 196 | elif original_tok['pos'] == 'VBG': 197 | new_word = conjugate(word, 'VBG') 198 | elif original_tok['pos'] == 'VBZ': 199 | new_word = conjugate(word, 'VBZ') 200 | elif original_tok['pos'] == 'VBP': 201 | new_word = conjugate(word, 'VBP') 202 | return new_word 203 | 204 | 205 | def ans_wordnet_catch_amap(a, tokens, q, **kwargs): 206 | new_ans, new_ans_tok = [], [] 207 | for t in tokens: 208 | if t['originalText'].lower() in COMMON_WORDS + PUNCTUATIONS: 209 | new_ans.append(t['originalText']) 210 | new_ans_tok.append(t) 211 | continue 212 | synonyms, antonyms = [], [] 213 | for syn in wordnet.synsets(t['originalText']): 214 | for l in syn.lemmas(): 215 | synonyms.append(l.name()) 216 | if l.antonyms(): 217 | antonyms.append(l.antonyms()[0].name()) 218 | 219 | new_word = None 220 | if t['pos'].startswith('VB') or t['pos'].startswith('JJ') or t['pos'].startswith('R'): 221 | for w in antonyms: 222 | if w.lower() != t['originalText'].lower() and t['originalText'] not in w.lower() and '_' not in w: 223 | new_word = process_token(w, t) 224 | break 225 | if new_word is None: 226 | for w in synonyms: 227 | if w.lower() not in t['originalText'].lower() and t['originalText'] not in w.lower() and '_' not in w: 228 | 
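                    # Conjugate the chosen synonym so it matches the original token's
                    # verb form (process_token leaves non-verb POS tags unchanged).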
new_word = process_token(w, t) 229 | break 230 | if new_word and new_word not in t['originalText'] and t['originalText'] not in new_word: 231 | new_ans.append(new_word) 232 | new_ans_tok.append({'originalText': new_word, 'pos': t['pos']}) 233 | else: 234 | return None 235 | new_ans = ' '.join(new_ans) 236 | if new_ans.lower() == a.lower(): 237 | return None 238 | return new_ans, new_ans_tok 239 | 240 | 241 | def ans_entity_full(ner_tag, new_ans): 242 | """Returns a function that yields new_ans iff every token has |ner_tag|.""" 243 | def func(a, tokens, q, **kwargs): 244 | for t in tokens: 245 | if t['ner'] != ner_tag: return None 246 | if ner_tag == 'PERSON': 247 | fname = FIRST_NAMES[random.randint(0, len(FIRST_NAMES)-1)] 248 | lname = LAST_NAMES[random.randint(0, len(LAST_NAMES)-1)] 249 | return fname + ' ' + lname 250 | elif ner_tag == 'LOCATIONS': 251 | return LOCATIONS[random.randint(0, len(LOCATIONS)-1)] 252 | elif ner_tag == 'ORGANIZATION': 253 | return ORGANIZATIONS[random.randint(0, len(ORGANIZATIONS)-1)] 254 | return new_ans 255 | return func 256 | 257 | 258 | def ans_abbrev(new_ans): 259 | def func(a, tokens, q, **kwargs): 260 | s = a 261 | if s == s.upper() and s != s.lower(): 262 | return new_ans 263 | return None 264 | return func 265 | 266 | 267 | def ans_match_wh(wh_word, new_ans): 268 | """Returns a function that yields new_ans if the question starts with |wh_word|.""" 269 | def func(a, tokens, q, **kwargs): 270 | if q.lower().startswith(wh_word + ' '): 271 | if wh_word == 'who': 272 | fname = FIRST_NAMES[random.randint(0, len(FIRST_NAMES)-1)] 273 | lname = LAST_NAMES[random.randint(0, len(LAST_NAMES)-1)] 274 | return fname + ' ' + lname 275 | elif wh_word == 'where': 276 | return LOCATIONS[random.randint(0, len(LOCATIONS)-1)] 277 | return new_ans 278 | return None 279 | return func 280 | 281 | 282 | def ans_pos(pos, new_ans, end=False, add_dt=False): 283 | """Returns a function that yields new_ans if the first/last token has |pos|.""" 284 | def func(a, tokens, q, determiner, **kwargs): 285 | if end: 286 | for it in range(len(tokens)): 287 | t = tokens[-1-it] 288 | if t['originalText'] not in PUNCTUATIONS: 289 | break 290 | else: 291 | t = tokens[0] 292 | if t['pos'] != pos: return None 293 | if add_dt and determiner: 294 | if pos == 'NN': 295 | return '%s %s' % (determiner, NN[random.randint(0, len(NN)-1)]) 296 | if pos == 'NNS': 297 | return '%s %s' % (determiner, NNS[random.randint(0, len(NNS)-1)]) 298 | if pos == 'NNP': 299 | return '%s %s' % (determiner, NNP[random.randint(0, len(NNP)-1)]) 300 | if pos == 'NNPS': 301 | return '%s %s' % (determiner, NNPS[random.randint(0, len(NNPS)-1)]) 302 | return '%s %s' % (determiner, new_ans) 303 | if pos == 'NN': 304 | return NN[random.randint(0, len(NN)-1)] 305 | if pos == 'NNS': 306 | return NNS[random.randint(0, len(NNS)-1)] 307 | if pos == 'NNP': 308 | return NNP[random.randint(0, len(NNP)-1)] 309 | if pos == 'NNPS': 310 | return NNPS[random.randint(0, len(NNPS)-1)] 311 | return new_ans 312 | return func 313 | 314 | 315 | def ans_catch_all(new_ans): 316 | def func(a, tokens, q, **kwargs): 317 | if tokens[0]['originalText'][0].isupper(): 318 | return new_ans[0].upper()+new_ans[1:] 319 | return new_ans 320 | return func -------------------------------------------------------------------------------- /convert_sp_facts.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import collections 3 | import json 4 | import math 5 | import os 6 | import string 7 | import re 8 | import 
numpy as np 9 | import itertools 10 | import random 11 | import nltk 12 | import urllib.parse 13 | from tqdm import tqdm 14 | # sys.path.append('./en/wordnet') 15 | # sys.path.append('./en') 16 | sys.path.append('./resources/nectar/corenlp') 17 | sys.path.append('./resources/nectar/base') 18 | sys.path.append('./resources/nectar') 19 | import argparse 20 | from resources.nectar import corenlp 21 | from answer_rules import ans_date, ans_number, ans_entity_full, ans_abbrev, ans_match_wh, ans_pos, ans_catch_all, ans_wordnet_catch_amap, lookup_answer_generate 22 | from bridge_entity_rules import bridge_entity_full, bridge_abbrev, bridge_wordnet_catch_amap, bridge_pos, bridge_number, lookup_title_generate, bridge_date 23 | from nearest_glove.get_nearest import build_glove_matrix, find_nearest_word 24 | from nltk.corpus import stopwords 25 | COMMON_WORDS = list(stopwords.words('english')) 26 | PUNCTUATIONS = list(string.punctuation) + ['–'] 27 | SHORT_PUNCT = PUNCTUATIONS.copy() + [' '] 28 | PUNCT_COMMON = PUNCTUATIONS + COMMON_WORDS + [' '] 29 | del SHORT_PUNCT[SHORT_PUNCT.index('+')] 30 | OPTS = None 31 | SOURCE_DIR = os.path.join("data", "hotpotqa") 32 | DATASETS = { 33 | 'dev': os.path.join(SOURCE_DIR, 'hotpot_dev_distractor_v1.json'), 34 | 'train': os.path.join(SOURCE_DIR, 'hotpot_train_v1.1.json'), 35 | } 36 | 37 | CORENLP_CACHES = { 38 | 'dev': 'data/hotpotqa/dev_corenlp_cache_', 39 | 'train': 'data/hotpotqa/train_corenlp_cache_', 40 | } 41 | 42 | COMMANDS = ['corenlp', 'dump-addDoc', 'gen-answer-set', 'gen-title-set', 'gen-all-docs'] 43 | CORENLP_PORT = 8765 44 | CORENLP_LOG = 'corenlp.log' 45 | CATCH_ALL_NUM = 0 46 | 47 | def parse_args(): 48 | parser = argparse.ArgumentParser('Generate adversarial support facts for HotpotQA.') 49 | parser.add_argument('command', 50 | help='Command (options: [%s]).' 
% (', '.join(COMMANDS))) 51 | parser.add_argument('--substitute_bridge_entities', '-b', default=False, action='store_true') 52 | parser.add_argument('--dataset', '-d', default='dev', 53 | help='Which dataset (options: [%s])' % (', '.join(DATASETS))) 54 | parser.add_argument('--prepend', '-p', default=False, 55 | action='store_true', help='Prepend fake answer to the original answer.') 56 | parser.add_argument('--quiet', '-q', default=False, action='store_true') 57 | parser.add_argument('--rule', '-r', default='wordnet_dyn_gen', help='[wordnet | wordnet_dyn_gen]') 58 | parser.add_argument('--seed', '-s', default=-1, type=int, help='Shuffle with RNG seed.') 59 | parser.add_argument('--split', default='0-1000') 60 | parser.add_argument('--replace_partial_answer', default=False, action='store_true') 61 | parser.add_argument('--num_new_doc', default=1, type=int) 62 | parser.add_argument('--dont_replace_full_answer', default=False, action='store_true', help='If true, only a few answer words, instead of the full answer span, will be replaced') 63 | parser.add_argument('--dont_replace_full_title', default=False, action='store_true', help='If true, only a few title words, instead of the full title span, will be replaced') 64 | parser.add_argument('--find_nearest_glove', default=False, action='store_true') 65 | parser.add_argument('--num_glove_words_to_use', default=100000, type=int) 66 | parser.add_argument('--batch_idx', default=0, type=int) 67 | parser.add_argument('--batch_size', default=5000, type=int) 68 | parser.add_argument('--add_doc_incl_adv_title', default=False, action='store_true') 69 | 70 | if len(sys.argv) == 1: 71 | parser.print_help() 72 | sys.exit(1) 73 | return parser.parse_args() 74 | 75 | 76 | def read_data(): 77 | filename = DATASETS[OPTS.dataset] 78 | with open(filename) as f: 79 | return json.load(f) 80 | 81 | 82 | def run_corenlp(dataset, bsz=5000): 83 | with corenlp.CoreNLPServer(port=CORENLP_PORT, logfile=CORENLP_LOG) as server: 84 | client = corenlp.CoreNLPClient(port=CORENLP_PORT) 85 | for ib in range(int(len(dataset)/bsz) + 1): 86 | cache_file = CORENLP_CACHES[OPTS.dataset] + str(ib*bsz) + '-' + str(min((ib+1)*bsz, len(dataset))) + '.json' 87 | print(cache_file) 88 | cache = {} 89 | print('Running NER for paragraphs...') 90 | for ie, e in tqdm(enumerate(dataset[ib*bsz : (ib+1)*bsz])): 91 | context, question, answer, supports = e['context'], e['question'], e['answer'], e['supporting_facts'] 92 | titles, partial_titles, sp_doc_ids = [], [], [] 93 | for si, doc in enumerate(context): 94 | title = doc[0] 95 | titles.append(title) 96 | title_split = re.split("([{}])".format("()"), title) 97 | if title_split[0] != '': 98 | partial_titles.append(title_split[0]) 99 | elif title_split[-1] != '': 100 | partial_titles.append(title_split[-1]) 101 | else: 102 | real_title = title_split[title_split.index('(')+1] 103 | assert real_title != ')' 104 | partial_titles.append(real_title) 105 | for sp_doc_title, sent_id in supports: 106 | sp_doc_id = titles.index(sp_doc_title) 107 | if sp_doc_id not in sp_doc_ids: 108 | sp_doc_ids.append(sp_doc_id) 109 | 110 | response_context, response_title = [], [] 111 | for _id, doc in enumerate(context): 112 | response_doc = [] 113 | response_context.append(response_doc) 114 | response_title.append(client.query_ner(partial_titles[_id])) 115 | if _id not in sp_doc_ids: 116 | continue 117 | for sent in doc[1]: 118 | if sent == '' or sent == ' ': 119 | continue 120 | response = client.query_ner(sent) 121 | response_doc.append(response) 122 | 123 | 
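                # Cache layout per example id:
                #   [context NER parses, answer NER parse, question NER parse, title NER parses]
                # create_adversarial_exammple() unpacks load_cache() entries in exactly this order.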
cache[e['_id']] = [response_context] 124 | response_a = client.query_ner(answer) 125 | cache[e['_id']].append(response_a) 126 | response_q = client.query_ner(question) 127 | cache[e['_id']].append(response_q) 128 | cache[e['_id']].append(response_title) 129 | 130 | print('Dumping caches...') 131 | with open(cache_file, 'w') as f: 132 | json.dump(cache, f, indent=2) 133 | 134 | 135 | def load_cache(start, end): 136 | cache_file = CORENLP_CACHES[OPTS.dataset] + str(start) + '-' + str(end) + '.json' 137 | print(cache_file) 138 | with open(cache_file, 'r') as f: 139 | return json.load(f) 140 | 141 | 142 | def load_ans_title_cache(source): 143 | cache_file = 'data/' + OPTS.dataset + '_' + source + '_set.json' 144 | print(cache_file) 145 | with open(cache_file, 'r') as f: 146 | return json.load(f) 147 | 148 | 149 | def load_inv_index(): 150 | filepath = 'data/inv_index.json' 151 | print("reading inv_index.json") 152 | with open(filepath) as f: 153 | return json.load(f) 154 | 155 | 156 | FIXED_VOCAB_ANSWER_RULES = [ 157 | ('date', ans_date), 158 | ('number', ans_number), 159 | ('ner_person', ans_entity_full('PERSON', 'Jeff Dean')), 160 | ('ner_location', ans_entity_full('LOCATION', 'Chicago')), 161 | ('ner_organization', ans_entity_full('ORGANIZATION', 'Stark Industries')), 162 | ('ner_misc', ans_entity_full('MISC', 'Jupiter')), 163 | ('abbrev', ans_abbrev('LSTM')), 164 | ('wh_who', ans_match_wh('who', 'Jeff Dean')), 165 | ('wh_when', ans_match_wh('when', '1956')), 166 | ('wh_where', ans_match_wh('where', 'Chicago')), 167 | ('wh_how_many', ans_match_wh('how many', '42')), 168 | # Starts with verb 169 | ('pos_begin_vb', ans_pos('VB', 'learn')), 170 | ('pos_end_vbd', ans_pos('VBD', 'learned')), 171 | ('pos_end_vbd', ans_pos('VBN', 'learned')), 172 | ('pos_end_vbg', ans_pos('VBG', 'learning')), 173 | ('pos_end_vbp', ans_pos('VBP', 'learns')), 174 | ('pos_end_vbz', ans_pos('VBZ', 'learns')), 175 | # Ends with some POS tag 176 | ('pos_end_nn', ans_pos('NN', 'hamster', end=True, add_dt=True)), 177 | ('pos_end_nnp', ans_pos('NNP', 'Central Park', end=True, add_dt=True)), 178 | ('pos_end_nns', ans_pos('NNS', 'hamsters', end=True, add_dt=True)), 179 | ('pos_end_nnps', ans_pos('NNPS', 'Kew Gardens', end=True, add_dt=True)), 180 | ('pos_end_jj', ans_pos('JJ', 'deep', end=True)), 181 | ('pos_end_jjr', ans_pos('JJR', 'deeper', end=True)), 182 | ('pos_end_jjs', ans_pos('JJS', 'deepest', end=True)), 183 | ('pos_end_rb', ans_pos('RB', 'silently', end=True)), 184 | ('pos_end_vbg', ans_pos('VBG', 'learning', end=True)), 185 | ('catch_all', ans_catch_all('aliens')), 186 | ] 187 | 188 | 189 | WORDNET_ANSWER_RULES_DYN_GEN = [ 190 | ('date', ans_date), 191 | ('number', ans_number), 192 | ('ner_person', lookup_answer_generate(ans_entity_full('PERSON', 'Jeff Dean'), 193 | 'ner_person')), 194 | ('ner_location', lookup_answer_generate(ans_entity_full('LOCATION', 'Chicago'), 195 | 'ner_location')), 196 | ('ner_organization', lookup_answer_generate(ans_entity_full('ORGANIZATION', 197 | 'Stark Industries'), 'ner_organization')), 198 | ('wordnet_catch', ans_wordnet_catch_amap), 199 | ('ner_misc', lookup_answer_generate(ans_entity_full('MISC', 'Jupiter'), 200 | 'ner_misc')), 201 | ('abbrev', lookup_answer_generate(ans_abbrev('LSTM'), 'abbrev')), 202 | ('wh_who', lookup_answer_generate(ans_match_wh('who', 'Jeff Dean'), 'wh_who')), 203 | ('wh_when', lookup_answer_generate(ans_match_wh('when', '1956'), 'wh_when')), 204 | ('wh_where', lookup_answer_generate(ans_match_wh('where', 'Chicago'), 'wh_where')), 205 | ('wh_how_many', 
lookup_answer_generate(ans_match_wh('how many', '42'), 206 | 'wh_how_many')), 207 | # Starts with verb 208 | ('pos_begin_vb', lookup_answer_generate(ans_pos('VB', 'learn'), 'pos_begin_vb')), 209 | ('pos_end_vbd', lookup_answer_generate(ans_pos('VBD', 'learned'), 'pos_end_vbd')), 210 | ('pos_end_vbg', lookup_answer_generate(ans_pos('VBG', 'learning'), 'pos_end_vbg')), 211 | ('pos_end_vbp', lookup_answer_generate(ans_pos('VBP', 'learns'), 'pos_end_vbp')), 212 | ('pos_end_vbz', lookup_answer_generate(ans_pos('VBZ', 'learns'), 'pos_end_vbz')), 213 | # Ends with some POS tag 214 | ('pos_end_nn', lookup_answer_generate(ans_pos('NN', 'hamster', end=True, 215 | add_dt=True), 'pos_end_nn')), 216 | ('pos_end_nnp', lookup_answer_generate(ans_pos('NNP', 'Central Park', end=True, 217 | add_dt=True), 'pos_end_nnp')), 218 | ('pos_end_nns', lookup_answer_generate(ans_pos('NNS', 'hamsters', end=True, 219 | add_dt=True), 'pos_end_nns')), 220 | ('pos_end_nnps', lookup_answer_generate(ans_pos('NNPS', 'Kew Gardens', end=True, add_dt=True), 221 | 'pos_end_nnps')), 222 | ('pos_end_jj', lookup_answer_generate(ans_pos('JJ', 'deep', end=True), 223 | 'pos_end_jj')), 224 | ('pos_end_jjr', lookup_answer_generate(ans_pos('JJR', 'deeper', end=True), 225 | 'pos_end_jjr')), 226 | ('pos_end_jjs', lookup_answer_generate(ans_pos('JJS', 'deepest', end=True), 227 | 'pos_end_jjs')), 228 | ('pos_end_rb', lookup_answer_generate(ans_pos('RB', 'silently', end=True), 229 | 'pos_end_rb')), 230 | ('pos_end_vbg', lookup_answer_generate(ans_pos('VBG', 'learning', end=True), 231 | 'pos_end_vbg')), 232 | ('catch_all', lookup_answer_generate(ans_catch_all('aliens'), 233 | 'catch_all')) 234 | ] 235 | 236 | 237 | def get_tokens_from_coreobj(original, obj): 238 | """Get CoreNLP tokens corresponding to a SQuAD answer object.""" 239 | toks = [] 240 | for s in obj['sentences']: 241 | for t in s['tokens']: 242 | toks.append(t) 243 | if corenlp.rejoin(toks).strip() == original.strip(): 244 | # Make sure that the tokens reconstruct the answer 245 | return toks 246 | else: 247 | if len(toks) == 0 and original == ' 🇦🇷': 248 | return toks 249 | assert False, (toks, corenlp.rejoin(toks).strip(), original) 250 | 251 | 252 | def get_determiner_for_answer(answer): 253 | words = answer.split(' ') 254 | if words[0].lower() == 'the': return 'the' 255 | if words[0].lower() in ('a', 'an'): return 'a' 256 | return None 257 | 258 | 259 | def process_sp_facts(e, answer_tok): 260 | """ 261 | Return: 262 | sp_fact_w_answer: A list of at most 2 sublists, each sublist is a list of [Integer: doc_id, List: [Integer: sent_id]]. 263 | Representing the sentence-level supporting facts that contain the answer. 264 | sp_doc_w_answer: A list of at most 2 tuple (Integer: doc_id, String: doc_words). 265 | Representing the document-level supporting facts that contain the answer. 266 | sp_doc_wo_answer: same as above. Representing the document-level supporting facts containing no answer. 267 | sp_doc_ids: A list of two integers. 
268 | """ 269 | supports, context, answer = e['supporting_facts'], e['context'], e['answer'] 270 | titles, _context = [], [] 271 | for si, doc in enumerate(context): 272 | _context.append(doc[1]) 273 | titles.append(doc[0]) 274 | sp_doc_w_answer_ids, sp_doc_wo_answer_ids, sp_doc_ids = [], [], [] 275 | sp_fact_w_answer, sp_doc_wo_answer, sp_doc_w_answer = [], [], [] 276 | 277 | sp_doc_sent_id, sp_doc_ids = [], [] 278 | for sp_doc_title, sent_id in supports: 279 | sp_doc_id = titles.index(sp_doc_title) 280 | if sp_doc_id not in sp_doc_ids: 281 | sp_doc_ids.append(sp_doc_id) 282 | sp_doc_sent_id.append([sp_doc_id, [sent_id]]) 283 | else: 284 | sp_doc_sent_id[-1][1].append(sent_id) 285 | 286 | for sp_doc_id, sent_ids in sp_doc_sent_id: 287 | sent_ids = range(len(context[sp_doc_id][1])) 288 | 289 | for sent_id in sent_ids: 290 | if sent_id == 902: 291 | continue 292 | sent = _context[sp_doc_id][sent_id] 293 | a = re.search(r'({})'.format(re.escape(answer.lower())), sent.lower()) 294 | if a: 295 | if sp_doc_id not in sp_doc_w_answer_ids: 296 | sp_fact_w_answer.append([sp_doc_id, [sent_id]]) 297 | sp_doc_w_answer_ids.append(sp_doc_id) 298 | else: 299 | assert sp_doc_id == sp_fact_w_answer[-1][0] 300 | sp_fact_w_answer[-1][1].append(sent_id) 301 | else: 302 | found_tok, not_found_tok = 0, 0 303 | for at in answer_tok: 304 | if at['originalText'] in PUNCT_COMMON: 305 | continue 306 | a = re.search(r'({})'.format('(?= not_found_tok: 312 | if sp_doc_id not in sp_doc_w_answer_ids: 313 | sp_fact_w_answer.append([sp_doc_id, [sent_id]]) 314 | sp_doc_w_answer_ids.append(sp_doc_id) 315 | else: 316 | assert sp_doc_id == sp_fact_w_answer[-1][0] 317 | sp_fact_w_answer[-1][1].append(sent_id) 318 | 319 | for id in sp_doc_ids: 320 | if id in sp_doc_w_answer_ids: 321 | sp_doc_w_answer.append((id,''.join(_context[id]))) 322 | else: 323 | sp_doc_wo_answer.append((id,''.join(_context[id]))) 324 | 325 | return sp_fact_w_answer, sp_doc_w_answer, sp_doc_wo_answer, sp_doc_ids 326 | 327 | 328 | FIXED_VOCAB_BRIDGE_RULES = [ 329 | ('ner_person', bridge_entity_full('PERSON', 'Jeff Dean')), 330 | ('ner_location', bridge_entity_full('LOCATION', 'Chicago')), 331 | ('ner_organization', bridge_entity_full('ORGANIZATION', 'Stark Industries')), 332 | ('ner_misc', bridge_entity_full('MISC', 'Jupiter')), 333 | ('abbrev', bridge_abbrev('LSTM')), 334 | # Starts with verb 335 | ('pos_begin_vb', bridge_pos('VB', 'act')), 336 | ('pos_end_vbd', bridge_pos('VBD', 'acted')), 337 | ('pos_end_vbd', bridge_pos('VBN', 'acted')), 338 | ('pos_end_vbg', bridge_pos('VBG', 'acting')), 339 | ('pos_end_vbp', bridge_pos('VBP', 'acts')), 340 | ('pos_end_vbz', bridge_pos('VBZ', 'acts')), 341 | # Ends with some POS tag 342 | ('pos_end_nn', bridge_pos('NN', 'table', end=True)), 343 | ('pos_end_nnp', bridge_pos('NNP', 'Hyde Park', end=True)), 344 | ('pos_end_nns', bridge_pos('NNS', 'tables', end=True)), 345 | ('pos_end_nnps', bridge_pos('NNPS', 'Trump Towers', end=True)), 346 | ('pos_end_jj', bridge_pos('JJ', 'hard', end=True)), 347 | ('pos_end_jjr', bridge_pos('JJR', 'harder', end=True)), 348 | ('pos_end_jjs', bridge_pos('JJS', 'hardest', end=True)), 349 | ('pos_end_rb', bridge_pos('RB', 'loudly', end=True)), 350 | ('date', ans_date), 351 | ('number', bridge_number), 352 | ('catch_all', ans_catch_all('players')), 353 | ] 354 | 355 | BRIDGE_RULES_DYN_GEN = [ 356 | ('date', bridge_date), 357 | ('number', bridge_number), 358 | ('ner_person', lookup_title_generate(bridge_entity_full('PERSON', 'Jeff Dean'), 'ner_person')), 359 | ('ner_location', 
lookup_title_generate(bridge_entity_full('LOCATION', 'Chicago'), 'ner_location')), 360 | ('ner_organization', lookup_title_generate(bridge_entity_full('ORGANIZATION', 'Stark Industries'), 'ner_organization')), 361 | ('ner_misc', lookup_title_generate(bridge_entity_full('MISC', 'Jupiter'), 'ner_misc')), 362 | ('abbrev', lookup_title_generate(bridge_abbrev('LSTM'), 'abbrev')), 363 | # Starts with verb 364 | ('pos_begin_vb', lookup_title_generate(bridge_pos('VB', 'learn'), 'pos_begin_vb')), 365 | ('pos_end_vbd', lookup_title_generate(bridge_pos('VBD', 'learned'), 'pos_end_vbd')), 366 | ('pos_end_vbg', lookup_title_generate(bridge_pos('VBG', 'learning'), 'pos_end_vbg')), 367 | ('pos_end_vbp', lookup_title_generate(bridge_pos('VBP', 'learns'), 'pos_end_vbp')), 368 | ('pos_end_vbz', lookup_title_generate(bridge_pos('VBZ', 'learns'), 'pos_end_vbz')), 369 | # Ends with some POS tag 370 | ('pos_end_nn', lookup_title_generate(bridge_pos('NN', 'hamster', end=True), 'pos_end_nn')), 371 | ('pos_end_nnp', lookup_title_generate(bridge_pos('NNP', 'Central Park', end=True), 'pos_end_nnp')), 372 | ('pos_end_nns', lookup_title_generate(bridge_pos('NNS', 'hamsters', end=True), 'pos_end_nns')), 373 | ('pos_end_nnps', lookup_title_generate(bridge_pos('NNPS', 'Kew Gardens', end=True), 'pos_end_nnps')), 374 | ('pos_end_jj', lookup_title_generate(bridge_pos('JJ', 'deep', end=True), 'pos_end_jj')), 375 | ('pos_end_jjr', lookup_title_generate(bridge_pos('JJR', 'deeper', end=True), 'pos_end_jjr')), 376 | ('pos_end_jjs', lookup_title_generate(bridge_pos('JJS', 'deepest', end=True), 'pos_end_jjs')), 377 | ('pos_end_rb', lookup_title_generate(bridge_pos('RB', 'silently', end=True), 'pos_end_rb')), 378 | ('pos_end_vbg', lookup_title_generate(bridge_pos('VBG', 'learning', end=True), 'pos_end_vbg')), 379 | ('catch_all', lookup_title_generate(ans_catch_all('aliens'), 'catch_all')), 380 | ] 381 | 382 | 383 | def create_fake_answer(answer, a_toks, question, determiner, ans_cache=None): 384 | if OPTS.rule == 'fixed_vocab': 385 | rules = FIXED_VOCAB_ANSWER_RULES 386 | elif OPTS.rule == 'wordnet_dyn_gen': 387 | rules = WORDNET_ANSWER_RULES_DYN_GEN 388 | else: 389 | raise NotImplementedError 390 | for rule_name, func in rules: 391 | new_answer = func(answer, a_toks, question, ans_cache=ans_cache, determiner=determiner) 392 | if new_answer and (rule_name == 'date' or rule_name == 'number'): 393 | return new_answer, None 394 | if new_answer: break 395 | else: 396 | raise ValueError('Missing answer') 397 | return new_answer 398 | 399 | 400 | def create_fake_bridge_entity(entity, t_toks, q, title_cache=None, is_end=False): 401 | if OPTS.rule == 'fixed_vocab': 402 | rules = FIXED_VOCAB_BRIDGE_RULES 403 | elif OPTS.rule == 'wordnet_dyn_gen': 404 | rules = BRIDGE_RULES_DYN_GEN 405 | else: 406 | raise NotImplementedError 407 | for rule_name, func in rules: 408 | new_entity = func(entity, t_toks, q, title_cache=title_cache, is_end=is_end) 409 | if new_entity: break 410 | else: 411 | raise ValueError('Missing entity') 412 | return new_entity 413 | 414 | 415 | def find_bridge_entities(tok_doc1, tok_doc2, tok_question, tok_answer, tok_title, find_title_only=False): 416 | """ 417 | Args: 418 | tok_doc1: [num_sent, sent_len] 419 | tok_doc2: [doc_len] 420 | """ 421 | doc1_wordss = [[t['originalText'].lower() for t in s] for s in tok_doc1] 422 | doc1_ners = [[t['ner'] for t in s] for s in tok_doc1] 423 | doc2_words = [t['originalText'].lower() for t in tok_doc2] 424 | doc2_ner = [t['ner'] for t in tok_doc2] 425 | question_words = 
[t['originalText'].lower() for t in tok_question] 426 | answer_words = [t['originalText'].lower() for t in tok_answer] 427 | title_words = [t['originalText'].lower() for t in tok_title] 428 | title_ner = [t['ner'] for t in tok_title] 429 | entity_list, entity_idx = [], {} 430 | ngram_entity_list, ngram_entity_idx = [], {} 431 | ngram_entity = [] 432 | 433 | for i in range(2): 434 | if i == 0: 435 | words_to_look, ners_to_look = title_words, title_ner 436 | else: 437 | words_to_look, ners_to_look = doc2_words, doc2_ner 438 | 439 | for iss, (doc1_words, doc1_ner) in enumerate(zip(doc1_wordss, doc1_ners)): 440 | end_entity = True 441 | for ie, (entity, ner) in enumerate(zip(doc1_words[::-1], doc1_ner[::-1])): 442 | if entity in words_to_look and entity not in answer_words and entity not in question_words and entity not in PUNCTUATIONS+COMMON_WORDS: 443 | if i == 1 and (ner == 'O' and ners_to_look[words_to_look.index(entity)] == 'O'): 444 | continue 445 | if entity not in entity_list: 446 | entity_list.append(entity) 447 | entity_idx[entity] = [end_entity, [(iss, len(doc1_words)-1-ie)]] 448 | else: 449 | entity_idx[entity][0] = entity_idx[entity][0] | end_entity 450 | entity_idx[entity][1].append([iss, len(doc1_words)-1-ie]) 451 | ngram_entity.append(entity) 452 | end_entity = False 453 | else: 454 | if len(ngram_entity) > 1: 455 | ngram_entity_list.append(ngram_entity[::-1]) 456 | ngram_entity = [] 457 | end_entity = True 458 | if find_title_only: 459 | break 460 | 461 | return entity_list, entity_idx 462 | 463 | 464 | def create_adversarial_exammple(data, adv_strategy='addDoc', start=0, end=5000, ans_cache=None, title_cache=None, glove_tools=None, 465 | all_docs=None): 466 | corenlp_cache = load_cache(start, end) 467 | if OPTS.find_nearest_glove: 468 | assert glove_tools 469 | df, word_to_id, id_to_word = glove_tools[0], glove_tools[1], glove_tools[2] 470 | 471 | unmatched_qas = [] 472 | num_matched = 0 473 | new_sp_facts = {} 474 | 475 | for ie, e in tqdm(enumerate(data[start : end])): 476 | answer = e['answer'] 477 | if answer == 'yes' or answer == 'no': 478 | continue 479 | context, question = e['context'], e['question'] 480 | context_parse, answer_parse, question_parse, title_parse = corenlp_cache[e['_id']] 481 | answer_tok = get_tokens_from_coreobj(answer, answer_parse) 482 | determiner = get_determiner_for_answer(answer) 483 | 484 | sps_w_answer, sp_docs_w_answer, sp_docs_wo_answer, sp_doc_ids = process_sp_facts(e, answer_tok) 485 | assert len(sp_doc_ids) == 2 486 | non_sp_doc_ids = [_i for _i in range(len(e['context'])) if _i not in sp_doc_ids] 487 | 488 | assert len(sps_w_answer) == len(sp_docs_w_answer), (ie, sps_w_answer) 489 | new_docs, new_docs_2 = [], [] 490 | 491 | titles = [] 492 | for si, doc in enumerate(context): 493 | title = doc[0] 494 | title_split = re.split("([{}])".format("()"), title) 495 | if title_split[0] != '': 496 | titles.append(title_split[0]) 497 | elif title_split[-1] != '': 498 | titles.append(title_split[-1]) 499 | else: 500 | real_title = title_split[title_split.index('(')+1] 501 | assert real_title != ')' 502 | titles.append(real_title) 503 | 504 | if OPTS.find_nearest_glove: 505 | new_entities_dict = {} 506 | 507 | for isp, (sp_doc_id, sps_in_one_doc) in enumerate(sps_w_answer): 508 | # All sp_facts in a single doc should get the same fake answer. 
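            # create_fake_answer() yields a (fake_answer, fake_answer_tokens) pair here;
            # for the 'date' and 'number' rules the token list is None, in which case
            # the full answer span is later replaced verbatim rather than token-by-token.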
509 | new_ans, new_ans_tok = create_fake_answer(answer, answer_tok, question, determiner=determiner, ans_cache=ans_cache) 510 | ans_len_diff = len(new_ans) - len(answer) 511 | if new_ans and isp == 0: 512 | num_matched += 1 513 | if new_ans is None: 514 | unmatched_qas.append((question, answer)) 515 | 516 | for _ind in range(OPTS.num_new_doc): 517 | if _ind > 0: 518 | new_ans, new_ans_tok = create_fake_answer(answer, answer_tok, question, determiner=determiner, ans_cache=ans_cache) 519 | ans_len_diff = len(new_ans) - len(answer) 520 | new_doc = context[sp_doc_id][1].copy() 521 | 522 | ## Step 1: substitute any bridge entities 523 | if OPTS.substitute_bridge_entities and e['type'] == 'bridge': 524 | if len(sps_w_answer) == 1: 525 | doc_id_to_compare, doc_to_compare = sp_docs_wo_answer[0] 526 | else: # There is no bridge entity in this case actually, only need to change the title entities. 527 | doc_id_to_compare, doc_to_compare = sp_docs_w_answer[(isp+1)%2] 528 | assert doc_id_to_compare != sp_doc_id, (doc_id_to_compare, sp_doc_id) 529 | doc_tok_to_compare_list = [get_tokens_from_coreobj(s, s_parse) for (s, s_parse) in zip(context[doc_id_to_compare][1], context_parse[doc_id_to_compare])] 530 | doc_tok_to_compare = list(itertools.chain.from_iterable(doc_tok_to_compare_list)) 531 | question_tok = get_tokens_from_coreobj(question, question_parse) 532 | 533 | new_doc_tok = [get_tokens_from_coreobj(_sent, context_parse[sp_doc_id][iis]) for iis, _sent in enumerate(new_doc) if _sent != ' ' and _sent != ''] 534 | old_title = titles[sp_doc_id] 535 | old_title_compare = titles[doc_id_to_compare] 536 | title_tok = get_tokens_from_coreobj(old_title, title_parse[sp_doc_id]) 537 | title_tok_to_compare = get_tokens_from_coreobj(titles[doc_id_to_compare], title_parse[doc_id_to_compare]) 538 | for i in range(2): # Substitute the title of both sp docs. 539 | if i == 1 and len(sps_w_answer) == 2: 540 | break 541 | if i == 0: 542 | _title, _title_tok = old_title, title_tok 543 | else: 544 | _title, _title_tok = old_title_compare, title_tok_to_compare 545 | 546 | # First, substitute the title. 547 | new_title, new_title_tok = create_fake_bridge_entity(_title, _title_tok, question, title_cache=title_cache) 548 | if OPTS.add_doc_incl_adv_title and len(sps_w_answer) == 1 and i == 0: 549 | for _doc in all_docs: 550 | _doc_title = _doc[0] 551 | if '(' in new_title: 552 | _new_title = new_title[:new_title.index('(')].strip() 553 | else: 554 | _new_title = new_title 555 | if _doc_title != new_title: 556 | if _new_title.lower() in ''.join(_doc[1]).lower(): 557 | new_docs_2.append(_doc) 558 | break 559 | 560 | ent_len_diff = len(new_title) - len(_title) 561 | 562 | # Substitute the entire title. 563 | if new_title_tok is None or OPTS.dont_replace_full_title is False or len(_title_tok) == 1: 564 | foundd = 0 565 | for isent, _sent in enumerate(new_doc): 566 | try: 567 | a = re.finditer(r'({})'.format('(? found.start()+ifo*ent_len_diff 576 | _sent = _sent[:(found.start()+ifo*ent_len_diff)] + new_title + _sent[(found.end()+ifo*ent_len_diff):] 577 | new_doc[isent] = _sent 578 | if i == 0 and new_title_tok is None and foundd == 0: 579 | print(ie) 580 | assert False 581 | 582 | title_tok_wo_common = [t for t in _title_tok if t['originalText'] not in COMMON_WORDS + PUNCTUATIONS] 583 | num_title_tok_to_replace = 0 584 | if OPTS.dont_replace_full_title and len(title_tok_wo_common) > 1: 585 | title_tok_to_replace = [] 586 | for _token in title_tok_wo_common: 587 | a = re.search(r'({})'.format('(? 
found.start()+ifo*ent_len_diff 625 | _sent = _sent[:(found.start()+ifo*ent_len_diff)] + new_entity + _sent[(found.end()+ifo*ent_len_diff):] 626 | new_doc[isent] = _sent 627 | 628 | new_docs.append(new_doc) 629 | 630 | for ispsent, sp_sent_id in enumerate(sps_in_one_doc): 631 | ## Step 2: replace the original answer in sp with new_ans. 632 | sp = new_doc[sp_sent_id] 633 | new_sp = sp 634 | 635 | if new_ans_tok is None or OPTS.dont_replace_full_answer is False or len(answer_tok) == 1: 636 | a = re.finditer(r'({})'.format(re.escape(answer.lower())), new_sp.lower()) 637 | for ifo, found in enumerate(a): 638 | new_sp = new_sp[:(found.start()+ifo*ans_len_diff)] + new_ans + new_sp[(found.end()+ifo*ans_len_diff):] 639 | 640 | ## Step 3: add the adversarial sp 641 | new_doc[sp_sent_id] = new_sp 642 | 643 | ## Step 4: replace partial answer 644 | if OPTS.replace_partial_answer and len(answer_tok) > 1 and new_ans_tok: 645 | # If the answer already doesn't exi 646 | answer_tok_wo_common = [t for t in answer_tok if t['originalText'] not in COMMON_WORDS + PUNCTUATIONS] 647 | _replace = False 648 | num_replaced = 0 649 | answer_tok_to_replace = [] 650 | for _token in answer_tok_wo_common: 651 | a = re.search(r'({})'.format('(? 1: 663 | if _replace is False: 664 | _replace = True 665 | continue 666 | else: 667 | _replace = False 668 | num_replaced += 1 669 | if OPTS.find_nearest_glove: 670 | if isp == 0 and _ind == 0: 671 | new_entities = find_nearest_word(entity.lower(), df, word_to_id, id_to_word) 672 | new_entities_dict[entity] = new_entities 673 | else: 674 | new_entities = new_entities_dict[entity] 675 | else: 676 | new_entities = None 677 | 678 | if new_entities and len(new_entities) > 0: 679 | new_entity = new_entities[0] 680 | new_entities_dict[entity] = new_entities[1:] 681 | else: 682 | if ittok < len(new_ans_tok): 683 | new_entity = new_ans_tok[ittok]['originalText'] 684 | else: 685 | new_entity = new_ans_tok[-1]['originalText'] 686 | if new_entity in PUNCT_COMMON: 687 | if new_ans_tok[-1]['originalText'] not in PUNCT_COMMON: 688 | new_entity = new_ans_tok[-1]['originalText'] 689 | elif new_ans_tok[0]['originalText'] not in PUNCT_COMMON: 690 | new_entity = new_ans_tok[0]['originalText'] 691 | else: 692 | for _tok in new_ans_tok[1:-1]: 693 | if _tok['originalText'] not in PUNCT_COMMON: 694 | new_entity = _tok['originalText'] 695 | assert new_entity not in PUNCT_COMMON, (new_ans_tok) 696 | 697 | ent_len_diff = len(new_entity) - len(entity) 698 | for isent, _sent in enumerate(new_doc): 699 | a = re.finditer(r'({})'.format('(? found.start()+ifo*ent_len_diff 702 | _sent = _sent[:(found.start()+ifo*ent_len_diff)] + new_entity + _sent[(found.end()+ifo*ent_len_diff):] 703 | new_doc[isent] = _sent 704 | 705 | if new_ans_tok and OPTS.dont_replace_full_answer and len(answer_tok) > 1: # In this case, the replace_full_answer procedure is skipped. 
706 | if answer != '""': 707 | assert num_replaced > 0, (answer, new_ans, ie) 708 | 709 | a = re.search(r'({})'.format(re.escape('(?= len(non_sp_doc_ids): 735 | break 736 | new_context.append(context[non_sp_doc_ids[_rm]]) 737 | 738 | if OPTS.prepend is False: 739 | random.shuffle(new_context) 740 | 741 | e['context'] = new_context 742 | return data, num_matched 743 | 744 | 745 | def dump_data(data, adv_strategy='addDoc', bsz=5000): 746 | if OPTS.add_doc_incl_adv_title: 747 | filepath = 'data/all_' + OPTS.dataset + '_docs.json' 748 | all_docs = json.load(open(filepath)) 749 | else: 750 | all_docs = None 751 | 752 | total_num_matched = 0 753 | ans_cache, title_cache = None, None 754 | if adv_strategy == 'addDoc' and 'dyn_gen' in OPTS.rule: 755 | ans_cache = load_ans_title_cache('answer') 756 | title_cache = load_ans_title_cache('title') 757 | 758 | glove_tools = None 759 | if OPTS.find_nearest_glove: 760 | assert OPTS.dont_replace_full_answer 761 | glove_tools = build_glove_matrix(num_words=OPTS.num_glove_words_to_use) 762 | 763 | all_data = [] 764 | for i in range(int(len(data)/bsz) + 1): 765 | data, num_matched = create_adversarial_exammple(data, adv_strategy, start=i*bsz, end=min((i+1)*bsz, len(data)), 766 | ans_cache=ans_cache, title_cache=title_cache, glove_tools=glove_tools, all_docs=all_docs) 767 | total_num_matched += num_matched 768 | all_data.extend(data[i*bsz : min((i+1)*bsz, len(data))]) 769 | # Print stats 770 | print('=== Summary ===') 771 | print('Matched %d/%d = %.2f%% questions' % ( 772 | total_num_matched, len(data), 100.0 * total_num_matched / len(data))) 773 | prefix = '%s-%s' % (OPTS.dataset, adv_strategy) 774 | if OPTS.prepend: 775 | prefix = '%s-pre' % prefix 776 | with open(os.path.join('out', 'hotpot_' + prefix + '.json'), 'w') as f: 777 | json.dump(all_data, f) 778 | 779 | 780 | def dump_data_batch(data, adv_strategy, i, bsz=5000): 781 | if OPTS.add_doc_incl_adv_title: 782 | filepath = 'data/all_' + OPTS.dataset + '_docs.json' 783 | all_docs = json.load(open(filepath)) 784 | else: 785 | all_docs = None 786 | 787 | total_num_matched = 0 788 | ans_cache, title_cache = None, None 789 | if adv_strategy == 'addDoc' and 'dyn_gen' in OPTS.rule: 790 | ans_cache = load_ans_title_cache('answer') 791 | title_cache = load_ans_title_cache('title') 792 | 793 | glove_tools = None 794 | if OPTS.find_nearest_glove: 795 | assert OPTS.dont_replace_full_answer 796 | glove_tools = build_glove_matrix(num_words=OPTS.num_glove_words_to_use) 797 | 798 | data, num_matched = create_adversarial_exammple(data, adv_strategy, start=i*bsz, end=min((i+1)*bsz, len(data)), 799 | ans_cache=ans_cache, title_cache=title_cache, glove_tools=glove_tools, all_docs=all_docs) 800 | data_batch = data[i*bsz : min((i+1)*bsz, len(data))] 801 | total_num_matched += num_matched 802 | # Print stats 803 | print('=== Summary ===') 804 | print('Matched %d/%d = %.2f%% questions' % ( 805 | total_num_matched, len(data), 100.0 * total_num_matched / len(data_batch))) 806 | prefix = '%s-%s' % (OPTS.dataset, adv_strategy) 807 | if OPTS.prepend: 808 | prefix = '%s-pre' % prefix 809 | with open(os.path.join('out', prefix + '-' + str(OPTS.num_new_doc) + '_' + str(i) +'.json'), 'w') as f: 810 | json.dump(data_batch, f) 811 | 812 | 813 | def generate_candidate_set(data, out_path, source='answer', bsz=5000): 814 | ans_set = {} 815 | for i in range(int(len(data)/bsz) + 1): 816 | start = i*bsz 817 | end = min((i+1)*bsz, len(data)) 818 | corenlp_cache = load_cache(start, end) 819 | for ie, e in tqdm(enumerate(data[start : end])): 820 
| answer, question, context = e['answer'], e['question'], e['context'] 821 | if answer == 'yes' or answer == 'no': 822 | continue 823 | _, answer_parse, _, title_parse = corenlp_cache[e['_id']] 824 | if source == 'answer': 825 | determiner = get_determiner_for_answer(answer) 826 | entities = [answer] 827 | entity_toks = [get_tokens_from_coreobj(answer, answer_parse)] 828 | else: 829 | entities = [] 830 | for si, doc in enumerate(context): 831 | title = doc[0] 832 | title_split = re.split("([{}])".format("()"), title) 833 | if title_split[0] != '': 834 | entities.append(title_split[0]) 835 | elif title_split[-1] != '': 836 | entities.append(title_split[-1]) 837 | else: 838 | real_title = title_split[title_split.index('(')+1] 839 | assert real_title != ')' 840 | entities.append(real_title) 841 | 842 | entity_toks = [get_tokens_from_coreobj(ent, title_parse[itt]) for itt, ent in enumerate(entities)] 843 | 844 | rules_to_check = FIXED_VOCAB_ANSWER_RULES if source == 'answer' else FIXED_VOCAB_BRIDGE_RULES 845 | 846 | for entity, entity_tok in zip(entities, entity_toks): 847 | for rule_name, func in rules_to_check: 848 | if source == 'answer': 849 | new_entity = func(entity, entity_tok, question, determiner=determiner) 850 | else: 851 | new_entity = func(entity, entity_tok, question) 852 | 853 | tok_len = len(entity_tok) 854 | if new_entity: 855 | if rule_name not in ans_set: 856 | ans_set[rule_name] = {} 857 | ans_set[rule_name][tok_len] = [(entity, entity_tok)] 858 | else: 859 | if tok_len in ans_set[rule_name]: 860 | if (entity, entity_tok) not in ans_set[rule_name][tok_len]: 861 | ans_set[rule_name][tok_len].append((entity, entity_tok)) 862 | else: 863 | ans_set[rule_name][tok_len] = [(entity, entity_tok)] 864 | 865 | for rule in ans_set: 866 | for tok_len in list(ans_set[rule].keys()): 867 | ans_set[rule][tok_len] = list(ans_set[rule][tok_len]) 868 | 869 | with open(out_path, 'w') as out: 870 | json.dump(ans_set, out, indent=2) 871 | 872 | 873 | def generate_all_docs(data, out_path): 874 | all_docs, all_titles = [], [] 875 | for i, e in enumerate(data): 876 | for c in e['context']: 877 | if c[0] not in all_titles: 878 | all_titles.append(c[0]) 879 | all_docs.append(c) 880 | with open(out_path, 'w') as f: 881 | json.dump(all_docs, f) 882 | 883 | 884 | def merge_adv_examples(out_path): 885 | if OPTS.dataset == 'train': 886 | num_files = 19 887 | else: 888 | num_files = 2 889 | 890 | adv_data = [] 891 | for i in range(num_files): 892 | filepath = 'out/' + OPTS.dataset + '-addDoc-4_' + str(i) + '.json' 893 | with open(filepath, 'r') as f: 894 | data = json.load(f) 895 | adv_data += data 896 | 897 | with open(out_path, 'w') as f: 898 | json.dump(adv_data, f) 899 | 900 | 901 | def main(): 902 | dataset = read_data() 903 | if OPTS.command == 'corenlp': 904 | import re 905 | run_corenlp(dataset) 906 | elif OPTS.command == 'dump-addDoc': 907 | dump_data(dataset, 'addDoc') 908 | elif OPTS.command == 'dumpBatch-addDoc': 909 | dump_data_batch(dataset, 'addDoc', OPTS.batch_idx) 910 | elif OPTS.command == 'gen-answer-set': 911 | generate_candidate_set(dataset, 'data/' + OPTS.dataset + '_answer_set.json', source='answer') 912 | elif OPTS.command == 'gen-title-set': 913 | generate_candidate_set(dataset, 'data/' + OPTS.dataset + '_title_set.json', source='title') 914 | elif OPTS.command == 'gen-all-docs': 915 | generate_all_docs(dataset, 'data/' + 'all_' + OPTS.dataset + '_docs.json') 916 | elif OPTS.command == 'merge_files': 917 | merge_adv_examples('out/hotpot_' + OPTS.dataset + '_addDoc.json') 918 |
else: 919 | raise ValueError('Unknown command "%s"' % OPTS.command) 920 | 921 | 922 | if __name__ == '__main__': 923 | OPTS = parse_args() 924 | main() 925 | 926 | --------------------------------------------------------------------------------
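Note on the substitution step used throughout create_adversarial_exammple above: each sentence is rewritten in place from re.finditer matches, and because every match offset refers to the original sentence, the ifo-th match is shifted by the accumulated length difference (ifo * len_diff) from earlier replacements in that sentence. The sketch below shows that bookkeeping in isolation. It is a minimal illustration, not the project's code: the helper name replace_entity is invented here, and the word-boundary lookaround pattern is an assumption about the escaped pattern the script builds.

import re


def replace_entity(sentence, old, new):
    # Hypothetical helper (not from convert_sp_facts.py): swap every
    # whole-word occurrence of `old` for `new` in one sentence.
    len_diff = len(new) - len(old)
    # Assumed pattern: word-boundary lookarounds around the escaped entity,
    # so 'Central Park' does not also match inside 'Central Parkway'.
    pattern = r'(?<!\w){}(?!\w)'.format(re.escape(old))
    # finditer scans the original string, so the offsets of the ifo-th match
    # must be shifted by ifo * len_diff after the earlier in-place replacements.
    for ifo, found in enumerate(re.finditer(pattern, sentence)):
        start = found.start() + ifo * len_diff
        end = found.end() + ifo * len_diff
        sentence = sentence[:start] + new + sentence[end:]
    return sentence


if __name__ == '__main__':
    print(replace_entity('Central Park is near Central Parkway.',
                         'Central Park', 'Hyde Park'))
    # -> Hyde Park is near Central Parkway.  (the longer token is untouched)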