├── experiments
│   ├── mnli_vocabulary
│   │   ├── non_padded_namespaces.txt
│   │   └── labels.txt
│   ├── squad_vocabulary
│   │   └── non_padded_namespaces.txt
│   ├── combo_snli_mnli_vocab
│   │   ├── non_padded_namespaces.txt
│   │   └── labels.txt
│   ├── pair2vec_train.json
│   ├── mnli_mism.json
│   ├── pair2vec_nli.json
│   └── pair2vec_squad2.json
├── embeddings
│   ├── pair_to_index.pkl
│   ├── cmd.txt
│   ├── cooccurance.py
│   ├── metrics.py
│   ├── representation.py
│   ├── util.py
│   ├── preprocess.py
│   ├── matrix_data.py
│   ├── bats_analysis.py
│   ├── model.py
│   ├── train.py
│   ├── vocab.py
│   └── indexed_field.py
├── .gitignore
├── download_pair2vec.sh
├── endtasks
│   ├── modules.py
│   ├── squad_predictor.py
│   ├── util.py
│   ├── squad2_eval.py
│   ├── squad2_reader.py
│   ├── esim_pair2vec.py
│   └── bidaf_pair2vec.py
├── download_corpus.sh
├── requirements.txt
├── README.md
└── LICENSE
/experiments/mnli_vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *tags 2 | *labels 3 | -------------------------------------------------------------------------------- /experiments/squad_vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *tags 2 | *labels 3 | -------------------------------------------------------------------------------- /experiments/combo_snli_mnli_vocab/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *tags 2 | *labels 3 | -------------------------------------------------------------------------------- /experiments/mnli_vocabulary/labels.txt: -------------------------------------------------------------------------------- 1 | entailment 2 | contradiction 3 | neutral 4 | -------------------------------------------------------------------------------- /experiments/combo_snli_mnli_vocab/labels.txt: -------------------------------------------------------------------------------- 1 | entailment 2 | contradiction 3 | neutral 4 | -------------------------------------------------------------------------------- /embeddings/pair_to_index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mandarjoshi90/pair2vec/HEAD/embeddings/pair_to_index.pkl -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .cache/ 3 | .coverage 4 | .ipynb_checkpoints 5 | *__pycache__* 6 | *.idea* 7 | *.history 8 | scratch/* 9 | ./data 10 | ./models 11 | .mypy_cache 12 | .vector_cache 13 | -------------------------------------------------------------------------------- /embeddings/cmd.txt: -------------------------------------------------------------------------------- 1 | python preprocess.py /sdb/data/wikipedia-sentences/shuf_sentences.txt /sdb/data/models/temp/ /sdb/data/wikipedia-sentences/counts.txt /sdb/data/wikipedia-sentences/sorted_coor_counts.txt 2 | -------------------------------------------------------------------------------- /download_pair2vec.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | curl -o experiments/pair2vec_pretrained.tar.gz http://nlp.cs.washington.edu/pair2vec/pair2vec_pretrained.tar.gz 3 | (cd experiments && tar xvfz pair2vec_pretrained.tar.gz) 4 | rm experiments/pair2vec_pretrained.tar.gz 5 | -------------------------------------------------------------------------------- /endtasks/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | class VariationalDropout(torch.nn.Dropout): 5 | def forward(self, input): 6 | """ 7 | input is shape (batch_size, timesteps, embedding_dim) 8 | Samples one mask of size (batch_size, embedding_dim) and applies it to every time step. 9 | """ 10 | ones = Variable(input.data.new(input.shape[0], input.shape[-1]).fill_(1)) 11 | dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False) 12 | if self.inplace: 13 | input *= dropout_mask.unsqueeze(1) 14 | return None 15 | else: 16 | return dropout_mask.unsqueeze(1) * input 17 | -------------------------------------------------------------------------------- /download_corpus.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # download the preprocessed corpus. Alternatively, download the raw data from http://nlp.cs.washington.edu/pair2vec/wikipedia.tar.gz, unzip it, and run python -m embeddings.preprocess 4 | data_dir=data 5 | mkdir $data_dir 6 | echo "Downloading preprocessed corpus" 7 | curl -o $data_dir/preprocessed.tar.gz http://nlp.cs.washington.edu/pair2vec/preprocessed.tar.gz 8 | (cd $data_dir && tar xvfz preprocessed.tar.gz) 9 | rm $data_dir/preprocessed.tar.gz 10 | # fasttext 11 | echo "Downloading fastText" 12 | mkdir $data_dir/fasttext 13 | curl -o $data_dir/fasttext/wiki-news-300d-1M-subword.vec.zip https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki-news-300d-1M-subword.vec.zip 14 | unzip $data_dir/fasttext/wiki-news-300d-1M-subword.vec.zip -d $data_dir/fasttext/ 15 | rm $data_dir/fasttext/wiki-news-300d-1M-subword.vec.zip 16 | ln -s wiki-news-300d-1M-subword.vec $data_dir/fasttext/wiki.en.vec # symlink target resolves relative to the link's directory 17 | -------------------------------------------------------------------------------- /experiments/pair2vec_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "compositional_rels": true, 3 | "compositional_args": false, 4 | "positional_rels": false, 5 | "triplet_dir": "data/softsample", 6 | "vocab_file": "experiments/vocab_pair2vec.txt", 7 | "train_batch_size": 600, 8 | "dev_batch_size": 1000, 9 | "normalize_args": true, 10 | "grad_norm": 10, 11 | "d_args": 300, 12 | "d_rels": 300, 13 | "d_pos": 300, 14 | "d_embed": 300, 15 | "d_lstm_input": 300, 16 | "d_lstm_hidden": 100, 17 | "n_lstm_layers": 1, 18 | "num_neg_samples": 2, 19 | "num_sampled_relations": 2, 20 | "negative_rel_loss": 1.0, 21 | "negative_subject_loss": 1.0, 22 | "negative_object_loss": 1.0, 23 | "type_subject_loss": 1.0, 24 | "type_object_loss": 1.0, 25 | "dropout": 0.0, 26 | "relation_predictor": "mlp", 27 | "lr": 0.01, 28 | "epochs": 12, 29 | "threshold": 0.5, 30 | "type_scores_file": "data/softsample/topk_scores.npy", 31 | "type_indices_file": "data/softsample/topk_indxs.npy", 32 | "save_every": 30000, 33 | "dev_every": 30000, 34 | "log_every": 3000, 35 | "separate_mlr": false, 36 | "sample_arguments": true, 37 | "n_args": 100004, 38 | "n_rels": 100004 39 | } 40 | --------------------------------------------------------------------------------
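The JSON config above is what `embeddings.train` consumes: `get_config` in embeddings/util.py (further down in this dump) parses the file with pyhocon and exposes each key as an attribute of a simple `Config` object. A minimal sketch of loading it under that assumption; the `checkpoints/demo` save directory is a hypothetical placeholder, not a path shipped with the repo:

```python
# Minimal sketch: load the training config via get_config from embeddings/util.py.
# 'checkpoints/demo' is a hypothetical save directory, not a path from the repo.
from embeddings.util import get_config

config = get_config('experiments/pair2vec_train.json', save_path='checkpoints/demo')
print(config.d_embed, config.n_lstm_layers, config.train_batch_size)  # -> 300 1 600
```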
/endtasks/squad_predictor.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | 3 | from allennlp.common.util import JsonDict 4 | from allennlp.data import Instance 5 | from allennlp.predictors.predictor import Predictor 6 | 7 | @Predictor.register('squad2-predictor') 8 | class Squad2Predictor(Predictor): 9 | """ 10 | Predictor for the :class:`~allennlp.models.bidaf.BidirectionalAttentionFlow` model. 11 | """ 12 | 13 | def predict(self, question: str, passage: str, question_id: str) -> JsonDict: 14 | """ 15 | Make a machine comprehension prediction on the supplied input. 16 | See https://rajpurkar.github.io/SQuAD-explorer/ for more information about the machine comprehension task. 17 | 18 | Parameters 19 | ---------- 20 | question : ``str`` 21 | A question about the content in the supplied paragraph. Since this is SQuAD 2.0, the question 22 | may or may not be answerable by a span in the paragraph. 23 | passage : ``str`` 24 | A paragraph of information relevant to the question. 25 | question_id : ``str`` An identifier for the question, carried through to the output. 26 | Returns 27 | ------- 28 | A dictionary that represents the prediction made by the system. The answer string will be under the 29 | "best_span_str" key. 30 | """ 31 | return self.predict_json({"passage" : passage, "question" : question, "question_id": question_id}) 32 | 33 | @overrides 34 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 35 | """ 36 | Expects JSON that looks like ``{"question": "...", "passage": "...", "question_id": "..."}``. 37 | """ 38 | question_text = json_dict["question"] 39 | passage_text = json_dict["passage"] 40 | question_id = json_dict["question_id"] 41 | return self._dataset_reader.text_to_instance(question_text, passage_text, question_id), {} 42 | -------------------------------------------------------------------------------- /embeddings/cooccurance.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import sys 3 | from tqdm import tqdm 4 | from embeddings.vocab import Vocab 5 | 6 | def read_vocab_from_file(vocab_path, specials): 7 | tokens = None 8 | with open(vocab_path) as f: 9 | text = f.read() 10 | tokens = text.rstrip().split('\n') 11 | tokens = tokens[3:] 12 | vocab = Vocab(tokens, specials=specials) 13 | print('Loaded vocab with {} tokens'.format(len(tokens))) 14 | return vocab 15 | 16 | def get_cooccurance(fname, vocab_file, outf): 17 | vocab = read_vocab_from_file(vocab_file, specials=['<unk>', '<pad>', '<X>', '<Y>']) 18 | counts = defaultdict(int) 19 | win = 5 20 | with open(fname, encoding='utf-8') as f: 21 | for i_line, line in tqdm(enumerate(f)): 22 | tokens = line.strip().lower().split() 23 | token_ids = [vocab.stoi[t] for t in tokens if vocab.stoi[t] != 0] 24 | len_tokens = len(token_ids) 25 | for ix, x in enumerate(token_ids): 26 | y_iter = [iy for iy in range(ix + 1, ix + 2 + win) if iy < len_tokens and token_ids[iy] != 0] 27 | for iy in y_iter: 28 | pair = (token_ids[ix], token_ids[iy]) if token_ids[ix] < token_ids[iy] else (token_ids[iy], token_ids[ix]) 29 | counts[pair] += 1 30 | counts = sorted(counts.items(), key=lambda x : x[1], reverse=True) 31 | with open(outf, mode='w', encoding='utf-8') as f: 32 | for pair, count in counts: 33 | f.write(vocab.itos[pair[0]] + '\t' + vocab.itos[pair[1]] + '\t' + str(count) + '\n') 34 | 35 | corpus_file = sys.argv[1] # input corpus 36 | pair_counts_file = sys.argv[2] # output 37 | vocab_file = sys.argv[3] if len(sys.argv) > 3 else 'vocabulary/pair2vec_tokens.txt' # optional vocabulary file from the repo 38 |
get_cooccurance(corpus_file, vocab_file, pair_counts_file) 39 | #get_cooccurance('/sdb/data/wikipedia-sentences/shuf_sentences.txt', '/sdb/data/wikipedia-sentences/triplet_contexts/vocab.txt', 40 | # '/sdb/data/wikipedia-sentences/coor_counts.txt') 41 | -------------------------------------------------------------------------------- /embeddings/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.autograd import Variable 4 | 5 | 6 | def positive_predictions_for(predicted_probs, threshold=0.5): 7 | #return sum(torch.gt(predicted_probs.data, threshold).cpu().numpy().tolist()) 8 | return (torch.gt(predicted_probs.data, threshold).float().sum()) 9 | 10 | def mrr(predictions, gold_labels, all_true, candidates=None): 11 | reciprocal_ranks = [] 12 | candidate_mask = get_mask(all_true, candidates, gold_labels, predictions.size(1)) 13 | predictions = torch.sigmoid(predictions) 14 | predictions = predictions * candidate_mask 15 | max_values, argsort = torch.sort(predictions, 1, descending=True) 16 | argsort = argsort.data.cpu().numpy() 17 | gold_labels = gold_labels.data.cpu().numpy() 18 | for i in range(predictions.size(0)): 19 | rank = np.where(argsort[i] == gold_labels[i])[0][0] 20 | reciprocal_ranks.append(rank + 1) # NB: despite the variable name, this stores the 1-based rank, not its reciprocal 21 | return reciprocal_ranks 22 | 23 | 24 | def masked_index_fill(tensor, index, index_mask, value): 25 | num_indices = index_mask.long().sum() 26 | valid_indices = index[: num_indices] 27 | tensor.index_fill_(0, valid_indices, value) 28 | 29 | 30 | def get_mask(all_true_objects, candidates, gold_labels, num_labels): 31 | batch_size = gold_labels.size(0) 32 | all_true_objects_mask = (1 - torch.eq(all_true_objects, -1).float()).byte() 33 | if candidates is None: 34 | candidates_mask = torch.ones((all_true_objects.size(0), num_labels), out=all_true_objects.data.new()) 35 | else: 36 | candidates_mask = torch.zeros((candidates.size(0), num_labels), out=all_true_objects.data.new()) 37 | cand_index_mask = (1 - torch.eq(candidates, -1).float()) 38 | for i in range(batch_size): 39 | if candidates is not None: 40 | masked_index_fill(candidates_mask[i], candidates[i].data, cand_index_mask[i].data, 1) 41 | masked_index_fill(candidates_mask[i], all_true_objects[i].data, all_true_objects_mask[i].data, 0) 42 | # candidates_mask.scatter_(1, all_true_objects.data, 0) 43 | candidates_mask.scatter_(1, gold_labels.unsqueeze(1).data, 1) 44 | return Variable(candidates_mask.float(), requires_grad=True) 45 | 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.6.1 2 | alabaster==0.7.12 3 | allennlp==0.5.0 4 | astor==0.7.1 5 | atomicwrites==1.2.1 6 | attrs==18.2.0 7 | awscli==1.16.61 8 | Babel==2.6.0 9 | backcall==0.1.0 10 | bleach==1.5.0 11 | boto3==1.9.51 12 | botocore==1.12.51 13 | certifi==2018.10.15 14 | cffi==1.11.2 15 | chardet==3.0.4 16 | Click==7.0 17 | colorama==0.3.9 18 | conllu==1.2.1 19 | cymem==2.0.2 20 | cytoolz==0.9.0.1 21 | dask==1.0.0 22 | decorator==4.3.0 23 | dill==0.2.8.2 24 | docutils==0.14 25 | editdistance==0.5.2 26 | en-core-web-sm==2.0.0 27 | flaky==3.4.0 28 | Flask==0.12.1 29 | Flask-Cors==3.0.3 30 | gast==0.2.0 31 | gevent==1.2.2 32 | greenlet==0.4.15 33 | grpcio==1.17.0 34 | h5py==2.8.0 35 | html5lib==0.9999999 36 | idna==2.7 37 | imagesize==1.1.0 38 | ipdb==0.11 39 | ipython==7.1.1 40 | ipython-genutils==0.2.0 41 | itsdangerous==1.1.0 42
| jedi==0.13.1 43 | Jinja2==2.10 44 | jmespath==0.9.3 45 | jsonnet==0.11.2 46 | MarkupSafe==1.1.0 47 | more-itertools==4.3.0 48 | msgpack==0.5.6 49 | msgpack-numpy==0.4.3.2 50 | murmurhash==1.0.1 51 | nltk==3.4.5 52 | numpy==1.15.4 53 | numpydoc==0.8.0 54 | overrides==1.9 55 | packaging==18.0 56 | parsimonious==0.8.1 57 | parso==0.3.1 58 | pexpect==4.6.0 59 | pickleshare==0.7.5 60 | plac==0.9.6 61 | pluggy==0.8.0 62 | preshed==2.0.1 63 | prompt-toolkit==2.0.7 64 | protobuf==3.6.1 65 | psycopg2==2.7.6.1 66 | ptyprocess==0.6.0 67 | py==1.7.0 68 | pyasn1==0.4.4 69 | pycparser==2.19 70 | Pygments==2.3.0 71 | pyhocon==0.3.48 72 | pyparsing==2.3.0 73 | pytest==4.0.1 74 | python-dateutil==2.7.5 75 | pytz==2017.3 76 | PyYAML==5.1 77 | regex==2017.11.9 78 | requests==2.20.1 79 | responses==0.10.4 80 | rsa==3.4.2 81 | s3transfer==0.1.13 82 | scikit-learn==0.20.1 83 | scipy==1.1.0 84 | sentencepiece==0.1.6 85 | singledispatch==3.4.0.3 86 | six==1.11.0 87 | sklearn==0.0 88 | snowballstemmer==1.2.1 89 | spacy==2.0.17 90 | Sphinx==1.8.2 91 | sphinxcontrib-websupport==1.1.0 92 | tensorboard==1.7.0 93 | tensorboardX==1.2 94 | tensorflow-gpu==1.7.0 95 | tensorflow-hub==0.2.0 96 | termcolor==1.1.0 97 | thinc==6.12.0 98 | toolz==0.9.0 99 | torch==0.4.1 100 | torchtext==0.3.1 101 | tqdm==4.28.1 102 | traitlets==4.3.2 103 | typing==3.6.6 104 | ujson==1.35 105 | Unidecode==1.0.23 106 | urllib3==1.24.1 107 | wcwidth==0.1.7 108 | Werkzeug==0.14.1 109 | wrapt==1.10.11 110 | -------------------------------------------------------------------------------- /endtasks/util.py: -------------------------------------------------------------------------------- 1 | from embeddings.vocab import Vocab 2 | from embeddings.matrix_data import create_vocab 3 | from embeddings.indexed_field import Field 4 | from embeddings.util import load_model, get_config 5 | from embeddings.model import Pair2Vec 6 | import torch 7 | # Get input to the encoder by concatenating representations (ELMo, charCNN etc.)
specified in keys 8 | def get_encoder_input(text_field_embedder, text_field_input, keys): 9 | token_vectors = None 10 | for key in keys: 11 | tensor = text_field_input[key] 12 | embedder = getattr(text_field_embedder, 'token_embedder_{}'.format(key)) if key != 'pair2vec_tokens' else get_pair2vec_word_embeddings 13 | embedding = embedder(tensor) 14 | token_vectors = embedding if token_vectors is None else torch.cat((token_vectors, embedding), -1) 15 | return token_vectors 16 | 17 | # Initialize pair2vec, load from the pretrained model file, and freeze parameters 18 | def get_pair2vec(pair2vec_config_file, pair2vec_model_file): 19 | pair2vec_config = get_config(pair2vec_config_file) 20 | field = Field(batch_first=True) 21 | create_vocab(pair2vec_config, field) 22 | pair2vec_config.n_args = len(field.vocab) 23 | pair2vec = Pair2Vec(pair2vec_config, field.vocab, field.vocab) 24 | load_model(pair2vec_model_file, pair2vec) 25 | # freeze pair2vec 26 | for param in pair2vec.parameters(): 27 | param.requires_grad = False 28 | del pair2vec.represent_relations 29 | return pair2vec 30 | 31 | # Get cross-sequence pair embeddings given two sequences 32 | def get_pair_embeddings(pair2vec, seq1, seq2): 33 | (batch_size, sl1, dim), (_, sl2, _) = seq1.size(), seq2.size() 34 | seq1 = seq1.unsqueeze(2).expand(batch_size, sl1, sl2, dim).contiguous().view(-1, dim) 35 | seq2 = seq2.unsqueeze(1).expand(batch_size, sl1, sl2, dim).contiguous().view(-1, dim) 36 | pair_embeddings = pair2vec.predict_relations(seq1, seq2).contiguous().view(batch_size, sl1, sl2, dim) 37 | return pair_embeddings 38 | 39 | # Get word/argument embeddings 40 | def get_pair2vec_word_embeddings(pair2vec, tokens): 41 | batch_size, seq_len = tokens.size() 42 | argument_embedding = pair2vec.represent_arguments(tokens.view(-1, 1)).view(batch_size, seq_len, -1) 43 | return argument_embedding 44 | 45 | def get_mask(text_field_tensors, key): 46 | if text_field_tensors[key].dim() == 2: 47 | return text_field_tensors[key] > 0 48 | elif text_field_tensors[key].dim() == 3: 49 | return ((text_field_tensors[key] > 0).long().sum(dim=-1) > 0).long() 50 | else: 51 | raise NotImplementedError() 52 | -------------------------------------------------------------------------------- /embeddings/representation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import Module, Linear, Dropout, Sequential, LSTM, Embedding, GRU, ReLU, Parameter 4 | from embeddings.util import masked_softmax 5 | from torch.autograd import Variable 6 | from torch.nn.init import xavier_normal, constant 7 | from embeddings.util import pretrained_embeddings_or_xavier 8 | from embeddings.vocab import Vocab, Vectors 9 | from torch.nn.functional import normalize 10 | class SpanRepresentation(Module): 11 | def __init__(self, config, d_output, vocab): 12 | super(SpanRepresentation, self).__init__() 13 | self.config = config 14 | self.vocab = vocab 15 | n_input = len(vocab) 16 | self.embedding = Embedding(n_input, config.d_embed) 17 | self.normalize_pretrained = getattr(config, 'normalize_pretrained', False) 18 | 19 | 20 | self.contextualizer = LSTMContextualizer(config) if config.n_lstm_layers > 0 else lambda x : x 21 | self.dropout = Dropout(p=config.dropout) 22 | self.head_attention = Sequential(self.dropout, Linear(2 * config.d_lstm_hidden, 1)) 23 | self.head_transform = Sequential(self.dropout, Linear(2 * config.d_lstm_hidden, d_output)) 24 | self.init() 25 | 26 | def init(self): 27 | [xavier_normal(p) for p in self.parameters() if
len(p.size()) > 1] 28 | if self.vocab.vectors is not None: 29 | pretrained = normalize(self.vocab.vectors, dim=-1) if self.normalize_pretrained else self.vocab.vectors 30 | self.embedding.weight.data.copy_(pretrained) 31 | print('Copied pretrained vectors into relation span representation') 32 | else: 33 | #xavier_normal(self.embedding.weight.data) 34 | self.embedding.reset_parameters() 35 | 36 | def forward(self, inputs): 37 | text, mask = inputs 38 | text = self.dropout(self.embedding(text)) 39 | text = self.contextualizer(text) 40 | weights = masked_softmax(self.head_attention(text).squeeze(-1), mask.float()) 41 | representation = (weights.unsqueeze(2) * self.head_transform(text)).sum(dim=1) 42 | return representation 43 | 44 | 45 | class LSTMContextualizer(Module): 46 | def __init__(self, config): 47 | super(LSTMContextualizer, self).__init__() 48 | self.config = config 49 | bidirectional = getattr(config, 'bidirectional', True) 50 | self.rnn = LSTM(input_size=config.d_lstm_input, hidden_size=config.d_lstm_hidden, num_layers=config.n_lstm_layers, dropout=config.dropout, bidirectional=bidirectional) 51 | 52 | def forward(self, inputs): 53 | inputs = inputs.permute(1, 0, 2) 54 | outputs, _ = self.rnn(inputs) # outputs: [seq_len, batch, hidden * 2] 55 | return outputs.permute(1, 0, 2) 56 | 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pair2vec: Compositional Word-Pair Embeddings for Cross-Sentence Inference 2 | ## Introduction 3 | This repository contains the code for replicating results from 4 | 5 | * [pair2vec: Compositional Word-Pair Embeddings for Cross-Sentence Inference](https://arxiv.org/abs/1810.08854) 6 | * [Mandar Joshi](https://homes.cs.washington.edu/~mandar90/), [Eunsol Choi](https://homes.cs.washington.edu/~eunsol), [Omer Levy](https://levyomer.wordpress.com/), [Dan Weld](https://www.cs.washington.edu/people/faculty/weld), and [Luke Zettlemoyer](https://www.cs.washington.edu/people/faculty/lsz) 7 | 8 | ## Getting Started 9 | * Install python3 requirements: `pip install -r requirements.txt` 10 | 11 | ## Using pretrained pair2vec embeddings 12 | * Download pretrained pair2vec: `./download_pair2vec.sh` 13 | * If you want to reproduce results from the paper on QA/NLI, please use the following: 14 | * Download and extract the pretrained models [tar file](http://nlp.cs.washington.edu/pair2vec/pretrained_models.tar.gz) 15 | * Run evaluation: 16 | ``` 17 | python -m allennlp.run evaluate [--output-file OUTPUT_FILE] 18 | --cuda-device 0 19 | --include-package endtasks 20 | ARCHIVE_FILE INPUT_FILE 21 | ``` 22 | * If you want to train your own QA/NLI model: 23 | ``` 24 | python -m allennlp.run train <config_file> -s <serialization_dir> --include-package endtasks 25 | ``` 26 | See the `experiments` directory for relevant config files. 27 | 28 | ## Training your own embeddings 29 | * Download the preprocessed corpus if you want to train pair2vec from scratch: `./download_corpus.sh` 30 | * Training: this starts the training process, which typically takes 7-10 days. It takes in a config file and a directory to save checkpoints.
31 | ``` 32 | python -m embeddings.train --config experiments/pair2vec_train.json --save_path <save_dir> 33 | ``` 34 | 35 | ## Miscellaneous 36 | * If you use the code, please cite the following paper: 37 | ``` 38 | @inproceedings{joshi-etal-2019-pair2vec, 39 | title = "pair2vec: Compositional Word-Pair Embeddings for Cross-Sentence Inference", 40 | author = "Joshi, Mandar and 41 | Choi, Eunsol and 42 | Levy, Omer and 43 | Weld, Daniel and 44 | Zettlemoyer, Luke", 45 | booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)", 46 | month = jun, 47 | year = "2019", 48 | address = "Minneapolis, Minnesota", 49 | publisher = "Association for Computational Linguistics", 50 | url = "https://www.aclweb.org/anthology/N19-1362", 51 | pages = "3597--3608" 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /experiments/mnli_mism.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "snli", 4 | "token_indexers": { 5 | "pair2vec_tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": true, 8 | "namespace": "pair2vec_tokens" 9 | }, 10 | "elmo": { 11 | "type": "elmo_characters" 12 | } 13 | } 14 | }, 15 | // "train_data_path": "data/mnli/multinli_1.0_train.jsonl", 16 | // "validation_data_path": "data/mnli/multinli_1.0_dev_mismatched.jsonl", 17 | "train_data_path": "data/mnli/snli.jsonl", 18 | "validation_data_path": "data/mnli/snli.jsonl", 19 | "vocabulary": { 20 | "directory_path": "experiments/mnli_vocabulary" 21 | }, 22 | "model": { 23 | "type": "esim-pair2vec", 24 | "dropout": 0.5, 25 | "pair2vec_dropout": 0.15, 26 | "mask_key": "elmo", 27 | "pair2vec_config_file": "models/typed_2-2_normalized/saved_config.json", 28 | "pair2vec_model_file": "models/typed_2-2_normalized/best.pt", 29 | "encoder_keys": [ 30 | "elmo" 31 | ], 32 | "text_field_embedder": { 33 | "elmo": { 34 | "type": "elmo_token_embedder", 35 | "options_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json", 36 | "weight_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5", 37 | "do_layer_norm": false, 38 | "dropout": 0.0 39 | } 40 | }, 41 | "encoder": { 42 | "type": "lstm", 43 | "input_size": 1024, 44 | "hidden_size": 300, 45 | "num_layers": 1, 46 | "bidirectional": true 47 | }, 48 | "similarity_function": { 49 | "type": "dot_product" 50 | }, 51 | "projection_feedforward": { 52 | "input_dim": 3000, 53 | "hidden_dims": 300, 54 | "num_layers": 1, 55 | "activations": "relu" 56 | }, 57 | "inference_encoder": { 58 | "type": "lstm", 59 | "input_size": 300, 60 | "hidden_size": 300, 61 | "num_layers": 1, 62 | "bidirectional": true 63 | }, 64 | "output_feedforward": { 65 | "input_dim": 2400, 66 | "num_layers": 1, 67 | "hidden_dims": 300, 68 | "activations": "relu", 69 | "dropout": 0.5 70 | }, 71 | "output_logit": { 72 | "input_dim": 300, 73 | "num_layers": 1, 74 | "hidden_dims": 3, 75 | "activations": "linear" 76 | }, 77 | "initializer": [ 78 | [ 79 | ".*linear_layers.*weight", 80 | { 81 | "type": "xavier_uniform" 82 | } 83 | ], 84 | [ 85 | ".*linear_layers.*bias", 86 | { 87 | "type": "zero" 88 | } 89 | ], 90 | [ 91 | ".*weight_ih.*", 92 | { 93 | "type": "xavier_uniform" 94 | } 95 | ], 96 | [ 97 | ".*weight_hh.*", 98 | { 99 |
"type": "orthogonal" 100 | } 101 | ] 102 | ] 103 | }, 104 | "iterator": { 105 | "type": "bucket", 106 | "sorting_keys": [ 107 | [ 108 | "premise", 109 | "num_tokens" 110 | ], 111 | [ 112 | "hypothesis", 113 | "num_tokens" 114 | ] 115 | ], 116 | "batch_size": 32 117 | }, 118 | "trainer": { 119 | "optimizer": { 120 | "type": "adam", 121 | "lr": 0.0004 122 | }, 123 | "validation_metric": "+accuracy", 124 | "num_serialized_models_to_keep": 2, 125 | "num_epochs": 75, 126 | "grad_norm": 10.0, 127 | "patience": 5, 128 | "cuda_device": 0, 129 | "learning_rate_scheduler": { 130 | "type": "reduce_on_plateau", 131 | "factor": 0.5, 132 | "mode": "max", 133 | "patience": 0 134 | } 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /experiments/pair2vec_nli.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "snli", 4 | "token_indexers": { 5 | "pair2vec_tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": true, 8 | "namespace": "pair2vec_tokens" 9 | }, 10 | "elmo": { 11 | "type": "elmo_characters" 12 | } 13 | } 14 | }, 15 | "train_data_path": "data/mnli/multinli_1.0_train.jsonl", 16 | "validation_data_path": "data/mnli/multinli_1.0_dev_mismatched.jsonl", 17 | // "train_data_path": "data/mnli/snli.jsonl", 18 | // "validation_data_path": "data/mnli/snli.jsonl", 19 | "vocabulary": { 20 | "directory_path": "experiments/mnli_vocabulary" 21 | }, 22 | "model": { 23 | "type": "esim-pair2vec", 24 | "dropout": 0.5, 25 | "pair2vec_dropout": 0.15, 26 | "mask_key": "elmo", 27 | "pair2vec_config_file": "models/typed_2-2_normalized/saved_config.json", 28 | "pair2vec_model_file": "models/typed_2-2_normalized/best.pt", 29 | "encoder_keys": [ 30 | "elmo" 31 | ], 32 | "text_field_embedder": { 33 | "elmo": { 34 | "type": "elmo_token_embedder", 35 | "options_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json", 36 | "weight_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5", 37 | "do_layer_norm": false, 38 | "dropout": 0.0 39 | } 40 | }, 41 | "encoder": { 42 | "type": "lstm", 43 | "input_size": 1024, 44 | "hidden_size": 300, 45 | "num_layers": 1, 46 | "bidirectional": true 47 | }, 48 | "similarity_function": { 49 | "type": "dot_product" 50 | }, 51 | "projection_feedforward": { 52 | "input_dim": 3000, 53 | "hidden_dims": 300, 54 | "num_layers": 1, 55 | "activations": "relu" 56 | }, 57 | "inference_encoder": { 58 | "type": "lstm", 59 | "input_size": 300, 60 | "hidden_size": 300, 61 | "num_layers": 1, 62 | "bidirectional": true 63 | }, 64 | "output_feedforward": { 65 | "input_dim": 2400, 66 | "num_layers": 1, 67 | "hidden_dims": 300, 68 | "activations": "relu", 69 | "dropout": 0.5 70 | }, 71 | "output_logit": { 72 | "input_dim": 300, 73 | "num_layers": 1, 74 | "hidden_dims": 3, 75 | "activations": "linear" 76 | }, 77 | "initializer": [ 78 | [ 79 | ".*linear_layers.*weight", 80 | { 81 | "type": "xavier_uniform" 82 | } 83 | ], 84 | [ 85 | ".*linear_layers.*bias", 86 | { 87 | "type": "zero" 88 | } 89 | ], 90 | [ 91 | ".*weight_ih.*", 92 | { 93 | "type": "xavier_uniform" 94 | } 95 | ], 96 | [ 97 | ".*weight_hh.*", 98 | { 99 | "type": "orthogonal" 100 | } 101 | ] 102 | ] 103 | }, 104 | "iterator": { 105 | "type": "bucket", 106 | "sorting_keys": [ 107 | [ 108 | "premise", 109 | "num_tokens" 110 | ], 111 | [ 112 | 
"hypothesis", 113 | "num_tokens" 114 | ] 115 | ], 116 | "batch_size": 32 117 | }, 118 | "trainer": { 119 | "optimizer": { 120 | "type": "adam", 121 | "lr": 0.0004 122 | }, 123 | "validation_metric": "+accuracy", 124 | "num_serialized_models_to_keep": 2, 125 | "num_epochs": 75, 126 | "grad_norm": 10.0, 127 | "patience": 5, 128 | "cuda_device": 0, 129 | "learning_rate_scheduler": { 130 | "type": "reduce_on_plateau", 131 | "factor": 0.5, 132 | "mode": "max", 133 | "patience": 0 134 | } 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /experiments/pair2vec_squad2.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "no_answer_squad2", 4 | "token_indexers": { 5 | "elmo": { 6 | "type": "elmo_characters" 7 | }, 8 | "pair2vec_tokens": { 9 | "type": "single_id", 10 | "lowercase_tokens": true, 11 | "namespace": "pair2vec_tokens" 12 | }, 13 | "token_characters": { 14 | "type": "characters", 15 | "character_tokenizer": { 16 | "byte_encoding": "utf-8", 17 | "end_tokens": [ 18 | 260 19 | ], 20 | "start_tokens": [ 21 | 259 22 | ] 23 | } 24 | } 25 | } 26 | }, 27 | "iterator": { 28 | "type": "bucket", 29 | "batch_size": 30, 30 | "sorting_keys": [ 31 | [ 32 | "passage", 33 | "num_tokens" 34 | ] 35 | ] 36 | }, 37 | "model": { 38 | "type": "bidaf-pair2vec", 39 | "dropout": 0.2, 40 | "initializer": [], 41 | "max_span_length": 17, 42 | "phrase_layer": { 43 | "type": "gru", 44 | "bidirectional": true, 45 | "hidden_size": 100, 46 | "input_size": 1124, 47 | "num_layers": 1 48 | }, 49 | "pair2vec_config_file": "models/typed_2-2_normalized/saved_config.json", 50 | "pair2vec_model_file": "models/typed_2-2_normalized/best.pt", 51 | "pair2vec_dropout": 0.15, 52 | "residual_encoder": { 53 | "type": "gru", 54 | "bidirectional": true, 55 | "hidden_size": 100, 56 | "input_size": 200, 57 | "num_layers": 1 58 | }, 59 | "span_end_encoder": { 60 | "type": "gru", 61 | "bidirectional": true, 62 | "hidden_size": 100, 63 | "input_size": 400, 64 | "num_layers": 1 65 | }, 66 | "span_start_encoder": { 67 | "type": "gru", 68 | "bidirectional": true, 69 | "hidden_size": 100, 70 | "input_size": 200, 71 | "num_layers": 1 72 | }, 73 | "text_field_embedder": { 74 | "token_embedders": { 75 | "elmo": { 76 | "type": "elmo_token_embedder", 77 | "do_layer_norm": false, 78 | "dropout": 0.2, 79 | "options_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json", 80 | "weight_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" 81 | }, 82 | "token_characters": { 83 | "type": "character_encoding", 84 | "dropout": 0.2, 85 | "embedding": { 86 | "embedding_dim": 20, 87 | "num_embeddings": 262 88 | }, 89 | "encoder": { 90 | "type": "cnn", 91 | "embedding_dim": 20, 92 | "ngram_filter_sizes": [ 93 | 5 94 | ], 95 | "num_filters": 100 96 | } 97 | } 98 | } 99 | } 100 | }, 101 | "train_data_path": "data/squad/squad.json", 102 | "validation_data_path": "data/squad/squad.json", 103 | "trainer": { 104 | "cuda_device": 0, 105 | "learning_rate_scheduler": { 106 | "type": "reduce_on_plateau", 107 | "factor": 0.5, 108 | "mode": "max", 109 | "patience": 3 110 | }, 111 | "num_epochs": 30, 112 | "num_serialized_models_to_keep": 2, 113 | "optimizer": { 114 | "type": "sgd", 115 | "lr": 0.01, 116 | "momentum": 0.9 117 | }, 118 | "patience": 10, 119 | "validation_metric": 
"+f1" 120 | }, 121 | "vocabulary": { 122 | "directory_path": "experiments/squad_vocabulary" 123 | }, 124 | "validation_iterator": { 125 | "type": "bucket", 126 | "batch_size": 30, 127 | "sorting_keys": [ 128 | [ 129 | "passage", 130 | "num_tokens" 131 | ] 132 | ] 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /embeddings/util.py: -------------------------------------------------------------------------------- 1 | import pyhocon 2 | from argparse import ArgumentParser 3 | from torch.nn.init import xavier_normal 4 | import torch 5 | import os 6 | import logging 7 | import json 8 | import glob 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | # From AllenNLP 13 | def masked_softmax(vector, mask): 14 | """ 15 | ``torch.nn.functional.softmax(vector)`` does not work if some elements of ``vector`` should be 16 | masked. This performs a softmax on just the non-masked portions of ``vector``. Passing 17 | ``None`` in for the mask is also acceptable; you'll just get a regular softmax. 18 | We assume that both ``vector`` and ``mask`` (if given) have shape ``(batch_size, vector_dim)``. 19 | In the case that the input vector is completely masked, this function returns an array 20 | of ``0.0``. This behavior may cause ``NaN`` if this is used as the last layer of a model 21 | that uses categorical cross-entropy loss. 22 | """ 23 | if mask is None: 24 | result = torch.nn.functional.softmax(vector, dim=-1) 25 | else: 26 | # To limit numerical errors from large vector elements outside the mask, we zero these out. 27 | result = torch.nn.functional.softmax(vector * mask, dim=-1) 28 | result = result * mask 29 | result = result / (result.sum(dim=1, keepdim=True) + 1e-13) 30 | return result 31 | 32 | def load_model(resume_snapshot, model): 33 | if os.path.isfile(resume_snapshot): 34 | checkpoint = torch.load(resume_snapshot) 35 | print("Loaded checkpoint '{}' (epoch {} iter: {} train_loss: {}, dev_loss: {}, train_pos:{}, train_neg: {}, dev_pos: {}, dev_neg: {})" 36 | .format(resume_snapshot, checkpoint['epoch'], checkpoint['iterations'], checkpoint['train_loss'], checkpoint['dev_loss'], checkpoint['train_pos'], checkpoint['train_neg'], checkpoint['dev_pos'], checkpoint['dev_neg'])) 37 | model.load_state_dict(checkpoint['state_dict'], strict=True) 38 | else: 39 | # logger.info("No checkpoint found at '{}'".format(resume_snapshot)) 40 | raise ValueError("No checkpoint found at {}".format(resume_snapshot)) 41 | 42 | def resume_from(resume_snapshot, model, optimizer): 43 | if os.path.isfile(resume_snapshot): 44 | logger.info("Loading checkpoint '{}'".format(resume_snapshot)) 45 | checkpoint = torch.load(resume_snapshot) 46 | model.load_state_dict(checkpoint['state_dict']) 47 | if optimizer is not None: 48 | optimizer.load_state_dict(checkpoint['optimizer']) 49 | logger.info("Loaded checkpoint '{}' (epoch {} iter: {} train_loss: {}, dev_loss: {}, train_pos:{}, train_neg: {}, dev_pos: {}, dev_neg: {})" 50 | .format(resume_snapshot, checkpoint['epoch'], checkpoint['iterations'], checkpoint['train_loss'], checkpoint['dev_loss'], checkpoint['train_pos'], checkpoint['train_neg'], checkpoint['dev_pos'], checkpoint['dev_neg'])) 51 | return checkpoint 52 | else: 53 | logger.info("No checkpoint found at '{}'".format(resume_snapshot)) 54 | return None 55 | 56 | def save_checkpoint(config, model, optimizer, epoch, iterations, train_eval_stats, dev_eval_stats, name, remove=True): 57 | # save config 58 | config.dump_to_file(os.path.join(config.save_path, 
"saved_config.json")) 59 | 60 | train_loss, train_pos, train_neg = train_eval_stats.average() 61 | dev_loss, dev_pos, dev_neg = dev_eval_stats.average() if dev_eval_stats is not None else (-1.0, -1.0, -1.0) 62 | 63 | snapshot_prefix = os.path.join(config.save_path, name) 64 | snapshot_path = snapshot_prefix + '_loss_{:.6f}_iter_{}_pos_{}_neg_{}_model.pt'.format(train_loss, iterations, 65 | train_pos, train_neg) 66 | 67 | state = { 68 | 'epoch': epoch, 69 | 'iterations': iterations + 1, 70 | 'state_dict': model.state_dict(), 71 | 'train_loss': train_loss, 72 | 'dev_loss': dev_loss, 73 | 'train_pos': train_pos, 74 | 'train_neg': train_neg, 75 | 'dev_pos': dev_pos, 76 | 'dev_neg': dev_neg, 77 | 'optimizer' : optimizer.state_dict(), 78 | } 79 | torch.save(state, snapshot_path) 80 | if remove: 81 | for f in glob.glob(snapshot_prefix + '*'): 82 | if f != snapshot_path: 83 | os.remove(f) 84 | 85 | 86 | def pretrained_embeddings_or_xavier(config, embedding, vocab, namespace): 87 | pretrained_file = config.pretrained_file if hasattr(config, "pretrained_file") else None 88 | if pretrained_file is not None: 89 | pretrained_embeddings(pretrained_file, embedding, 90 | vocab, namespace) 91 | else: 92 | xavier_normal(embedding.weight.data) 93 | 94 | def pretrained_embeddings(pretrained_file, embedding, vocab, namespace): 95 | weight = _read_pretrained_embedding_file(pretrained_file, embedding.embedding_dim, 96 | vocab, namespace) 97 | embedding.weight.data.copy_(weight) 98 | 99 | 100 | def makedirs(name): 101 | """helper function for python 2 and 3 to call os.makedirs() 102 | avoiding an error if the directory to be created already exists""" 103 | import os, errno 104 | try: 105 | os.makedirs(name) 106 | except OSError as ex: 107 | if ex.errno == errno.EEXIST and os.path.isdir(name): 108 | # ignore existing directory 109 | pass 110 | else: 111 | # a different error happened 112 | raise 113 | 114 | 115 | class Config: 116 | def __init__(self, **entries): 117 | self.__dict__.update(entries) 118 | 119 | def __str__(self): 120 | string = '' 121 | for key, value in self.__dict__.items(): 122 | string += key + ': ' + str(value) + '\n' 123 | return string 124 | 125 | def dump_to_file(self, path): 126 | with open(path, 'w') as f: 127 | json.dump(self.__dict__, f, indent=4) 128 | 129 | def get_config(filename, exp_name=None, save_path=None): 130 | config_dict = pyhocon.ConfigFactory.parse_file(filename) 131 | if exp_name is not None and exp_name in config_dict: 132 | config_dict = config_dict[exp_name] 133 | config = Config(**config_dict) 134 | if save_path is not None: 135 | config.save_path = save_path 136 | return config 137 | 138 | 139 | def print_config(config): 140 | print (pyhocon.HOCONConverter.convert(config, "hocon")) 141 | 142 | 143 | def get_args(): 144 | parser = ArgumentParser(description='Relation Embeddings') 145 | parser.add_argument('--config', type=str, default="experiments.conf") 146 | parser.add_argument('--save_path', type=str) 147 | parser.add_argument('--gpu', type=int, default=0) 148 | parser.add_argument('--seed', type=int, default=45) 149 | parser.add_argument('--exp', type=str, default='multiplication') 150 | parser.add_argument('--resume_snapshot', type=str, default='') 151 | args = parser.parse_args() 152 | return args 153 | -------------------------------------------------------------------------------- /embeddings/preprocess.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from math import sqrt 3 | from 
random import Random 4 | from docopt import docopt 5 | import numpy as np 6 | import os 7 | from tqdm import tqdm 8 | from embeddings.vocab import Vocab 9 | import pickle 10 | from collections import defaultdict 11 | 12 | stop_words = set(['the', 'of', ',', 'in', 'and', 'to', '"', '(', ')', 'a', 'is', 'was', 'for', '.', '-', 'as', 'by', 'at', 'an', 'with', 'from', 'that', 'which', 'also', 'be', 'were', 'are', 'but', 'this', 'had', 'can', 'into', 'could', 'would', 'should', 'then', 'do', 'does', 'above', 'after', 'again', 'same', 'any', 'been']) 13 | def read_filtered_pairs(fname, vocab, thr=None, sorted_file=False): 14 | pairs_count = {} 15 | count = 1 16 | with open(fname, encoding='utf-8') as f: 17 | for line in tqdm(f): 18 | w1, w2, count = line.strip().split('\t') 19 | if thr is None or float(count) > thr: 20 | this_pair = (vocab.stoi[w1],vocab.stoi[w2]) if vocab.stoi[w1] < vocab.stoi[w2] else (vocab.stoi[w2],vocab.stoi[w1]) 21 | pairs_count[this_pair] = float(count) 22 | elif sorted_file: 23 | break 24 | total = float(sum(pairs_count.values())) 25 | for k, count in pairs_count.items(): 26 | pairs_count[k] /= total 27 | return pairs_count 28 | 29 | 30 | def read_counts(fname, vocab, thr=10): 31 | count_dict = defaultdict(int) 32 | with open(fname, encoding='utf-8') as f: 33 | for line in tqdm(f): 34 | w1, count = line.strip().split('\t') 35 | if int(count) > thr and w1 in vocab.stoi: 36 | count_dict[vocab.stoi[w1]] = float(count) 37 | total = float(sum(count_dict.values())) 38 | for k, count in count_dict.items(): 39 | count_dict[k] /= total 40 | # print('total {}, min {}'.format(total, thr / total)) 41 | return count_dict 42 | 43 | 44 | def main(): 45 | args = docopt(""" 46 | Usage: 47 | preprocess.py [options] <corpus> <triplets_dir> <word_count_file> <pair_count_file> 48 | 49 | Options: 50 | --chunk NUM The number of lines to read before dumping each matrix [default: 1000000] 51 | --win NUM Maximal number of tokens between X and Y [default: 4] 52 | --left NUM Left window size [default: 1] 53 | --right NUM Right window size [default: 1] 54 | --word_thr NUM Minimum word count threshold [default: 10] 55 | --pair_thr NUM Minimum pair count threshold [default: 50] 56 | """) 57 | print(args) 58 | corpus_file = args['<corpus>'] 59 | triplets_dir = args['<triplets_dir>'] 60 | word_count_file = args['<word_count_file>'] 61 | pair_count_file = args['<pair_count_file>'] 62 | vocab_file = os.path.join(triplets_dir, 'vocab.txt') 63 | word_thr = int(args['--word_thr']) 64 | pair_thr = int(args['--pair_thr']) 65 | chunk = int(args['--chunk']) 66 | win = int(args['--win']) 67 | left = int(args['--left']) 68 | right = int(args['--right']) 69 | unk, pad, x_placeholder, y_placeholder = '<unk>', '<pad>', '<X>', '<Y>' 70 | print('reading vocab from {}'.format(vocab_file)) 71 | specials = [unk, pad, x_placeholder, y_placeholder] 72 | vocab = get_vocab(vocab_file, corpus_file, specials) 73 | print('Vocab Size:', len(vocab)) 74 | chunk_i = 1 75 | matrix = [] 76 | pair_filter = read_filtered_pairs(pair_count_file, vocab, pair_thr, sorted_file=True) 77 | stop_word_ids = set([vocab.stoi[w] for w in stop_words]) 78 | keep_wordpair = keep_wordpair_by_mult 79 | word_unigram_dict = read_counts(word_count_file, vocab, word_thr) 80 | 81 | with open(corpus_file, 'r', encoding='utf-8') as f: 82 | for i_line, line in tqdm(enumerate(f)): 83 | tokens = line.strip().lower().split() 84 | token_ids = [vocab.stoi[t] for t in tokens if vocab.stoi[t] != 0] 85 | len_tokens = len(token_ids) 86 | for ix, x in enumerate(token_ids): 87 | # use ix+1 to start from adjacent word 88 | y_start = (ix + 1) 89 | y_iter = [iy for iy in range(y_start, ix + 2 + win) if iy < len_tokens and
token_ids[iy] != 0] 90 | for iy in y_iter: 91 | ordered_pair = (token_ids[ix], token_ids[iy]) 92 | this_pair = (token_ids[ix], token_ids[iy]) if (token_ids[ix] < token_ids[iy]) else (token_ids[iy], token_ids[ix]) 93 | 94 | if this_pair in pair_filter and keep_wordpair(word_unigram_dict, this_pair, vocab, stop_words=stop_word_ids): 95 | contexts = token_ids[max(0, ix - left): ix] + [vocab.stoi[x_placeholder]] + token_ids[ix+1: iy] + [vocab.stoi[y_placeholder]] + token_ids[iy+1:iy+right+1] 96 | contexts += [vocab.stoi[pad]] * (left + right + win + 2 - len(contexts)) 97 | matrix += [[token_ids[ix], token_ids[iy]] + contexts] 98 | 99 | if (i_line + 1) % chunk == 0: 100 | size = len(matrix) 101 | save(matrix, triplets_dir, chunk_i) 102 | print('chunk {} len {}'.format(chunk_i, len(matrix))) 103 | matrix = [] 104 | chunk_i += 1 105 | if len(matrix) > 0: 106 | save(matrix, triplets_dir, chunk_i) 107 | 108 | 109 | def keep_wordpair_by_mult(count_dict, word_pair, vocab, thr=5e-5, stop_words=None): 110 | x, y = word_pair 111 | clamp = lambda x : x if x < 1.0 else 1.0 112 | keep_x = clamp(sqrt(thr / count_dict[x])) if x in count_dict else 0.0 113 | keep_y = clamp(sqrt(thr / count_dict[y])) if y in count_dict else 0.0 114 | random_prob = np.random.uniform() 115 | return random_prob < keep_x * keep_y 116 | 117 | def get_vocab(vocab_path, corpus_file, specials): 118 | if os.path.isfile(vocab_path): 119 | vocab = read_vocab_from_file(vocab_path, specials) 120 | else: 121 | selected = read_vocab(corpus_file) 122 | vocab = Vocab(selected, specials) 123 | save_vocab(selected, vocab_path) 124 | return vocab 125 | 126 | def read_vocab_from_file(vocab_path, specials): 127 | tokens = None 128 | with open(vocab_path) as f: 129 | text = f.read() 130 | tokens = text.rstrip().split('\n') 131 | vocab = Vocab(tokens, specials=specials) 132 | return vocab 133 | 134 | def read_vocab(corpus_file, thr=100, max_size=150000): 135 | counter = Counter() 136 | with open(corpus_file, mode='r', encoding='utf-8') as f: 137 | for i_line, line in enumerate(f): 138 | counter.update(Counter(line.strip().lower().split())) 139 | if i_line % 1000000 == 0: 140 | print(i_line) 141 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 142 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 143 | selected = [] 144 | for word, freq in words_and_frequencies: 145 | if freq < thr or len(selected) == max_size: 146 | break 147 | selected.append(word) 148 | return selected 149 | 150 | def save_vocab(selected, path): 151 | with open(path, 'w', encoding='utf-8') as fout: 152 | fout.write('\n'.join(selected)) 153 | 154 | 155 | def save(matrix, triplets_dir, chunk_i): 156 | np.save(triplets_dir + '/triplets_' + str(chunk_i) + '.npy', np.array(tuple(matrix), dtype=np.int32)) 157 | 158 | 159 | if __name__ == '__main__': 160 | main() 161 | 162 | -------------------------------------------------------------------------------- /embeddings/matrix_data.py: -------------------------------------------------------------------------------- 1 | from embeddings.vocab import Vocab 2 | from embeddings.indexed_field import Field 3 | from torch.autograd import Variable 4 | import torch 5 | import numpy as np 6 | 7 | from typing import Optional, Dict, Union, Sequence, Iterable, Iterator, TypeVar, List 8 | from tqdm import tqdm 9 | from collections import defaultdict, Counter, OrderedDict 10 | from embeddings import util 11 | import os 12 | import random 13 | 14 | import logging 15 | logger = logging.getLogger(__name__) 16 | 17 | # 
From AllenNLP 18 | class _LazyInstances(Iterable): 19 | """ 20 | An ``Iterable`` that just wraps a thunk for generating instances and calls it for 21 | each call to ``__iter__``. 22 | """ 23 | def __init__(self, instance_generator) -> None: 24 | super().__init__() 25 | self.instance_generator = instance_generator 26 | 27 | def __iter__(self): 28 | instances = self.instance_generator() 29 | return instances 30 | 31 | 32 | def smoothed_sampling(instances, alpha=None, num_neg_samples=1): 33 | unique, counts = np.unique(instances, return_counts=True, axis=0) 34 | unique_idxs = np.arange(0, unique.shape[0]) 35 | if alpha is not None: 36 | counts = np.power(counts, alpha) 37 | probs = counts.astype('float') / counts.sum() 38 | sample_idxs = np.random.choice(unique_idxs, size=instances.shape[0]*num_neg_samples, replace=True, p=probs) 39 | sample = np.take(unique, sample_idxs, axis=0) 40 | return sample 41 | 42 | def uniform_type_sampling(instances, scores_matrix, indxs_matrix): 43 | # (num_ins, topk) 44 | batch_indxs = np.take(indxs_matrix, instances, axis=0) 45 | # (num_ins, 1)) 46 | sample_idx_idxs = np.random.randint(0, indxs_matrix.shape[1], indxs_matrix.shape[0]) 47 | # import ipdb 48 | # ipdb.set_trace() 49 | # (num_ins, 1)) 50 | sample_idxs = np.take(batch_indxs, sample_idx_idxs, axis=1) 51 | return sample_idxs 52 | def batched_unigram_type_sampling(instances, scores_matrix, indxs_matrix): 53 | # (num_ins, topk) 54 | batch_indxs = np.take(indxs_matrix, instances, axis=0) 55 | # (num_ins, topk) 56 | batch_scores = np.take(scores_matrix, instances, axis=0) 57 | # (num_ins, 1)) 58 | sample_idx_idxs = torch.multinomial(torch.from_numpy(batch_scores), 1, replacement=True) 59 | # import ipdb 60 | # ipdb.set_trace() 61 | # (num_ins, 1)) 62 | sample_idxs = np.take(batch_indxs, sample_idx_idxs.cpu().numpy(), axis=1) 63 | return sample_idxs 64 | 65 | def unigram_type_sampling(instances, scores_matrix, indxs_matrix, batch_size=10000): 66 | samples = [] 67 | for i in range(0, instances.shape[0], batch_size): 68 | samples.append(batched_unigram_type_sampling(instances[i: i + batch_size], scores_matrix, indxs_matrix)) 69 | return np.concatenate(samples) 70 | 71 | 72 | def shuffled_sampling(instances): 73 | return np.random.permutation(instances) 74 | 75 | def sample_compositional(instances, alpha=None, compositional_rels=True, type_scores=None, type_indices=None, num_neg_samples=1, num_sampled_relations=1, model_type='sampling'): 76 | np.random.shuffle(instances) 77 | subjects, objects, relations = instances[:, 0], instances[:, 1], instances[:, 2:] 78 | relations = relations if compositional_rels or relations.shape[1] > 1 else relations.reshape(relations.shape[0]) 79 | sample_fn, kwargs = (smoothed_sampling, {'alpha': alpha, 'num_neg_samples': num_sampled_relations}) if alpha is not None else (shuffled_sampling, {}) 80 | sampled_relations = sample_fn(relations, **kwargs) 81 | sampled_relations = sampled_relations.reshape((relations.shape[0], relations.shape[1], num_sampled_relations)) 82 | sample_fn, kwargs = (smoothed_sampling, {'alpha': alpha, 'num_neg_samples': num_neg_samples}) if alpha is not None else (shuffled_sampling, {}) 83 | sampled_subjects, sampled_objects = sample_fn(subjects, **kwargs).reshape((instances.shape[0], num_neg_samples)), sample_fn(objects, **kwargs).reshape((instances.shape[0], num_neg_samples)) 84 | return subjects, objects, relations, sampled_relations, sampled_subjects, sampled_objects #, type_sampled_subjects, type_sampled_objects 85 | 86 | 87 | 88 | class 
TripletIterator(): 89 | def __init__(self, batch_size, fields, return_nl=False, limit=None, compositional_rels=True, type_scores_file=None, type_indices_file=None, num_neg_samples=1, 90 | alpha=0.75, num_sampled_relations=1, model_type='sampling'): 91 | self.batch_size = batch_size 92 | self.fields = fields 93 | self.return_nl = return_nl 94 | self.limit = limit 95 | self.alpha = alpha 96 | self.compositional_rels = compositional_rels 97 | self.num_neg_samples = num_neg_samples 98 | self.num_sampled_relations = num_sampled_relations 99 | self.model_type = model_type 100 | self.type_scores = None if type_scores_file is None else np.load(type_scores_file) 101 | self.type_indices = None if type_indices_file is None else np.load(type_indices_file) 102 | 103 | def __call__(self, data, device=-1, train=True): 104 | batches = self._create_batches(data, device, train) 105 | for batch in batches: 106 | yield batch 107 | 108 | 109 | def _create_batches(self, instance_gen, device=-1, train=True): 110 | for instances in instance_gen: 111 | start = 0 112 | sample = sample_compositional 113 | inputs = instances if (not train) else sample(instances, self.alpha, self.compositional_rels, self.type_scores, self.type_indices, self.num_neg_samples, self.num_sampled_relations, model_type=self.model_type) 114 | for num, batch_start in enumerate(range(0, inputs[0].shape[0], self.batch_size)): 115 | tensors = tuple(Variable(torch.LongTensor(x[batch_start: batch_start + self.batch_size]), requires_grad=False) for x in inputs) 116 | if device is None: 117 | tensors = tuple([t.cuda() if t is not None else None for t in tensors]) 118 | if self.return_nl: 119 | relation_nl = [] 120 | rel_index = 2 121 | for rel in inputs[rel_index][batch_start: batch_start + self.batch_size]: 122 | relation_nl += [' '.join([self.fields[rel_index].vocab.itos[j] for j in rel])] 123 | yield tensors, (relation_nl) 124 | else: 125 | yield tensors 126 | 127 | def create_vocab(config, field): 128 | vocab_path = getattr(config, 'vocab_file', os.path.join(config.triplet_dir, "vocab.txt")) 129 | tokens = None 130 | with open(vocab_path) as f: 131 | text = f.read() 132 | tokens = text.rstrip().split('\n') 133 | specials = ['<unk>', '<pad>', '<X>', '<Y>'] if config.compositional_rels else ['<unk>', '<pad>'] 134 | init_with_pretrained = getattr(config, 'init_with_pretrained', True) 135 | vectors, vectors_cache = (None, None) if not init_with_pretrained else (getattr(config, 'word_vecs', 'fasttext.en.300d'), getattr(config, 'word_vecs_cache', 'data/fasttext')) 136 | vocab = Vocab(tokens, specials=specials, vectors=vectors, vectors_cache=vectors_cache) 137 | field.vocab = vocab 138 | 139 | def read(filenames): 140 | for fname in filenames: 141 | if os.path.isfile(fname): 142 | instances = np.load(fname) 143 | logger.info('Loading {} instances from {}'.format(instances.shape[0], fname)) 144 | yield instances 145 | 146 | def read_dev(fname, limit=None, compositional_rels=True, type_scores_file=None, type_indices_file=None, num_neg_samples=1, num_sampled_relations=1, model_type='sampling'): 147 | instances = np.load(fname) 148 | instances = instances[:limit] if limit is not None else instances 149 | logger.info('Loading {} instances from {}'.format(instances.shape[0], fname)) 150 | type_scores = None if type_scores_file is None else np.load(type_scores_file) 151 | type_indices = None if type_indices_file is None else np.load(type_indices_file) 152 | sample = sample_compositional 153 | return sample(instances, alpha=.75, compositional_rels=compositional_rels,
type_scores=type_scores, type_indices=type_indices, num_neg_samples=num_neg_samples, num_sampled_relations=num_sampled_relations, model_type=model_type) 154 | 155 | def dev_data(sample): 156 | yield sample 157 | 158 | def create_dataset(config, triplet_dir=None): 159 | triplet_dir = config.triplet_dir if triplet_dir is None else triplet_dir 160 | #files = [os.path.join(config.triplet_dir, fname) for fname in os.listdir(config.triplet_dir) if fname.endswith('.npy')] 161 | files = [os.path.join(triplet_dir, 'triplets_' + str(i) + '.npy') for i in range(1, 1000)] 162 | train_data = _LazyInstances(lambda : iter(read(files[1:]))) 163 | type_scores_file = config.type_scores_file if hasattr(config, 'type_scores_file') else None 164 | type_indices_file = config.type_indices_file if hasattr(config, 'type_indices_file') else None 165 | model_type = getattr(config, 'model_type', 'sampling') 166 | validation_sample = read_dev(files[0], 500000, config.compositional_rels, type_scores_file, type_indices_file, config.num_neg_samples, config.num_sampled_relations, model_type) 167 | validation_data = _LazyInstances(lambda : iter (dev_data(validation_sample))) 168 | return train_data, validation_data 169 | 170 | def read_data(config, return_nl=False, preindex=True): 171 | args = Field(lower=True, batch_first=True) 172 | rels = Field(lower=True, batch_first=True) if config.compositional_rels else Field(batch_first=True) 173 | fields = [args, args, rels] 174 | train, dev = create_dataset(config) 175 | create_vocab(config, args) 176 | rels.vocab = args.vocab 177 | config.n_args = len(args.vocab) 178 | config.n_rels = len(rels.vocab) 179 | sample_arguments = getattr(config, "sample_arguments", True) 180 | fields = fields + [rels, args, args] if sample_arguments else fields + [rels] 181 | type_scores_file = config.type_scores_file if hasattr(config, 'type_scores_file') else None 182 | type_indices_file = config.type_indices_file if hasattr(config, 'type_indices_file') else None 183 | model_type = getattr(config, 'model_type', 'sampling') 184 | 185 | train_iterator = TripletIterator(config.train_batch_size, fields , return_nl=return_nl, 186 | compositional_rels=config.compositional_rels, type_scores_file=type_scores_file, type_indices_file=type_indices_file, num_neg_samples=config.num_neg_samples, 187 | alpha=getattr(config, 'alpha', 0.75), num_sampled_relations=getattr(config, 'num_sampled_relations', 1), model_type=model_type) 188 | dev_iterator = TripletIterator(config.dev_batch_size, fields, return_nl=return_nl, compositional_rels=config.compositional_rels, num_neg_samples=config.num_neg_samples, 189 | alpha=getattr(config, 'alpha', 0.75), num_sampled_relations=getattr(config, 'num_sampled_relations', 1), model_type=model_type) 190 | 191 | return train, dev, train_iterator, dev_iterator, args, rels 192 | -------------------------------------------------------------------------------- /embeddings/bats_analysis.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | from torch.autograd import Variable 4 | from embeddings.model import Pair2Vec 5 | from embeddings.matrix_data import create_vocab 6 | from embeddings.indexed_field import Field 7 | from endtasks.util import get_pair2vec 8 | import torch 9 | import os 10 | import sys 11 | import fnmatch 12 | from tqdm import tqdm 13 | from random import shuffle 14 | import random 15 | from torch.nn.functional import softmax, normalize 16 | from embeddings.vocab import 
Vectors 17 | 18 | class DistributionalModel(torch.nn.Module): 19 | def __init__(self, vocab, dim, name='wikipedia-jan-18-model-300.vec', cache='/fasttext'): 20 | super(DistributionalModel, self).__init__() 21 | self.arg_vocab = vocab 22 | self.represent_arguments = torch.nn.Embedding(len(vocab), dim) 23 | self.represent_arguments.weight.requires_grad = False 24 | self.arg_vocab.load_vectors(Vectors(name=name, cache=cache)) 25 | pretrained = self.arg_vocab.vectors 26 | #pretrained = normalize(pretrained) 27 | self.represent_arguments.weight.data.copy_(pretrained) 28 | 29 | def forward(self): 30 | pass 31 | 32 | 33 | def predict_relations(self, subjects, objects): 34 | return subjects - objects 35 | 36 | def read_pairs(fname, vocab): 37 | pairs, idxs = [], [] 38 | oov, total = 0, 0 39 | with open(fname, encoding='utf-8') as f: 40 | id_line = 0 41 | for id_line, line in enumerate(f): 42 | try: 43 | if "\t" in line: 44 | left,right = line.lower().split("\t") 45 | else: 46 | left,right = line.lower().split() 47 | right = right.strip() 48 | if "/" in right: 49 | right=[i.strip() for i in right.split("/")] 50 | else: 51 | right=[i.strip() for i in right.split(",")] 52 | left_idx = vocab.stoi[left] 53 | right = [r for r in right if vocab.stoi[r] != 0] 54 | right_idxs = [vocab.stoi[r] for r in right] 55 | total += 1 + len(right_idxs) 56 | oov += int(left_idx == 0) + sum(int(r == 0) for r in right_idxs) 57 | if left_idx != 0 and len(right_idxs) > 0: 58 | pairs.append([left,right]) 59 | idxs.append([left_idx, right_idxs]) 60 | except: 61 | print ("error reading pairs") 62 | print ("in file", fname) 63 | print ("in line",id_line,line) 64 | exit(-1) 65 | print('oov', oov * 100.0 / total) 66 | return pairs, idxs 67 | 68 | 69 | 70 | def create_dataset(bats_dir): 71 | pairs = [] 72 | for root, dirnames, filenames in os.walk(bats_dir): 73 | for filename in fnmatch.filter(sorted(filenames), '*'): 74 | pairs += read_pairs(os.path.join(root,filename)) 75 | return pairs 76 | 77 | 78 | def predict_relations(pair, model): 79 | word1, word2 = pair 80 | word1_embedding = model.represent_arguments(word1) 81 | word2_embedding = model.represent_arguments(word2) 82 | mlp_output = model.predict_relations(word1_embedding, word2_embedding) 83 | mlp_output = normalize(mlp_output, dim=-1) 84 | return mlp_output 85 | 86 | 87 | def vocab_pair_embeddings(model, word1): 88 | # (bs, dim) 89 | word1_embedding = model.represent_arguments(word1) 90 | # (V, dim) 91 | vocab_embedding = Variable(model.represent_arguments.weight.data, requires_grad=False) 92 | bs, dim = word1_embedding.size() 93 | vocab_size, _ = vocab_embedding.size() 94 | rep_word1_embedding = word1_embedding.unsqueeze(1).expand(bs, vocab_size, dim) 95 | rep_vocab_embedding = vocab_embedding.unsqueeze(0).expand(bs, vocab_size, dim) 96 | vocab_pair_fwd = model.predict_relations(rep_word1_embedding.contiguous().view(-1, dim), rep_vocab_embedding.contiguous().view(-1, dim)).contiguous().view(bs, vocab_size, dim) 97 | vocab_pair_bwd = model.predict_relations(rep_vocab_embedding.contiguous().view(-1, dim), rep_word1_embedding.contiguous().view(-1, dim)).contiguous().view(bs, vocab_size, dim) 98 | vocab_pair = normalize(vocab_pair_fwd, dim=-1), normalize(vocab_pair_bwd, dim=-1) 99 | return vocab_pair 100 | 101 | def pairs_to_analogies(pairs): 102 | tups = [] 103 | for pair1 in pairs: 104 | for pair2 in pairs: 105 | if pair1 != pair2: 106 | tups.append((pair1[0], pair1[1][0], pair2[0], pair2[1])) 107 | shuffle(tups) 108 | get = lambda i : [x[i] for x in tups] 109 | w1, w2, 
w3, w4 = get(0), get(1), get(2), get(3) 110 | return Variable(torch.LongTensor(w1), requires_grad=False).cuda(), Variable(torch.LongTensor(w2).cuda(), requires_grad=False).cuda(), Variable(torch.LongTensor(w3), requires_grad=False).cuda(), w4 111 | 112 | 113 | def get_accuracy(org_scores, w4, vocab, w1, w2, w3, mask, batch_num, preds, filename): 114 | if mask is not None: 115 | mask = Variable(torch.from_numpy(mask).cuda(), requires_grad=False).float() 116 | scores = (org_scores - (org_scores.min(-1, keepdim=True)[0])) * mask 117 | else: 118 | scores = org_scores 119 | sorted_scores, indices = torch.sort(scores, descending=True, dim=-1) 120 | w1, w2, w3 = w1.data.cpu().numpy().tolist(), w2.data.cpu().numpy().tolist(), w3.data.cpu().numpy().tolist() 121 | predictions = indices[:, 0].cpu().data.numpy().tolist() 122 | acc = 0 123 | for i, (pred, gold) in enumerate(zip(predictions, w4)): 124 | ranks = indices[i].cpu().data.numpy().tolist() 125 | topk = indices[i, :10].cpu().data.numpy().tolist() 126 | topk_scores = sorted_scores[i, :10].cpu().data.numpy().tolist() 127 | topk = [vocab.itos[w] for w in topk] 128 | 129 | gold_ranks = [ranks.index(g) for g in gold] 130 | preds += [(filename, vocab.itos[w1[i]], vocab.itos[w2[i]], vocab.itos[w3[i]], '\t'.join(topk), '\t'.join([vocab.itos[g] for g in gold]))] 131 | if pred in gold: 132 | acc += 1 133 | # if batch_num < 15: 134 | # print(vocab.itos[w1[i]], ':', vocab.itos[w2[i]], '::', vocab.itos[w3[i]], ':', vocab.itos[gold[0]], min(gold_ranks), topk) 135 | topk = indices[i, :10].cpu().data.numpy().tolist() 136 | 137 | return acc 138 | 139 | def mask_out_analogy_words(file_mask, w1_batch, w2_batch, w3_batch, model): 140 | mask = np.tile(file_mask.copy(), (w1_batch.shape[0], 1)) 141 | for i, (w1, w2, w3) in enumerate(zip(w1_batch, w2_batch, w3_batch)): 142 | mask[i, w1] = 0 143 | mask[i, w2] = 0 144 | mask[i, w3] = 0 145 | return mask 146 | 147 | 148 | def get_scores(model, w1, w2, w3, batch, method='3CosAdd'): 149 | vocab_size, dim = model.represent_arguments.weight.data.size() 150 | # (bs, V, dim)) 151 | if method == '3CosAdd': 152 | vocab_emb = Variable(normalize(model.represent_arguments.weight.data, dim=-1), requires_grad=False).unsqueeze(0).expand(batch, vocab_size, dim) 153 | p1_relemb = normalize(model.represent_arguments(w3) - model.represent_arguments(w1) + model.represent_arguments(w2), dim=-1) 154 | scores = torch.bmm(vocab_emb, p1_relemb.unsqueeze(2)).squeeze(2) 155 | else: 156 | vocab_pair = vocab_pair_embeddings(model, w3) 157 | p1_fwd, p1_bwd = predict_relations((w1, w2), model), predict_relations((w2, w1), model) 158 | vocab_pair_fwd, vocab_pair_bwd = vocab_pair 159 | scores_fwd = (torch.bmm(vocab_pair_fwd, p1_fwd.unsqueeze(2)).squeeze(2)) 160 | scores_bwd = (torch.bmm(vocab_pair_bwd, p1_bwd.unsqueeze(2)).squeeze(2)) 161 | scores = (scores_fwd + scores_bwd) / 2 162 | return scores 163 | 164 | def eval_on_bats_interpolate(bats_dir, model_file, config_file, pred_file, batch=1): 165 | random.seed(10) 166 | pair2vec = get_pair2vec(config_file, model_file) 167 | vocab = pair2vec.arg_vocab 168 | distrib_model = DistributionalModel(vocab, 300) 169 | pair2vec.cuda() 170 | pair2vec.eval() 171 | distrib_model.cuda() 172 | distrib_model.eval() 173 | 174 | file_mask = np.ones(len(vocab)) 175 | correct, total = 0, 0 176 | per_cat_acc, preds = [], [] 177 | all_alpha_acc = [] 178 | for root, dirnames, filenames in os.walk(bats_dir): 179 | for filename in fnmatch.filter(sorted(filenames), '*.txt'): 180 | pairs, idxs = 
read_pairs(os.path.join(root,filename), pair2vec.arg_vocab) 181 | print(filename, len(idxs)) 182 | best_correct, best_alpha = 0, 0 183 | for alpha in np.linspace(0,1,11): 184 | alpha = float(alpha) 185 | print('alpha', alpha) 186 | file_correct, file_total = 0, 0 187 | all_w1, all_w2, all_w3, all_w4 = pairs_to_analogies(idxs) 188 | 189 | bs = len(all_w1) 190 | print(bs) 191 | for i in tqdm(range(0, len(all_w1), batch)): 192 | w1, w2, w3, w4 = all_w1[i:i+batch], all_w2[i:i+batch], all_w3[i:i+batch], all_w4[i:i+batch] 193 | distrib_scores = get_scores(distrib_model, w1, w2, w3, batch, method='3CosAdd') 194 | scores = get_scores(pair2vec, w1, w2, w3, batch, method='pair2vec') 195 | scores = alpha * distrib_scores + (1- alpha) * scores 196 | mask = mask_out_analogy_words(file_mask, w1.data.cpu().numpy(), w2.data.cpu().numpy(), w3.data.cpu().numpy(), None) 197 | file_correct += get_accuracy(scores, w4, pair2vec.arg_vocab, w1, w2, w3, mask, i, preds, filename) 198 | file_total += len(w4) 199 | print(filename, file_correct * 100.0 / file_total, file_correct, file_total, alpha) 200 | all_alpha_acc.append((filename, file_correct, file_total, file_correct * 100.0 / file_total, alpha)) 201 | if file_correct > best_correct: 202 | best_correct, best_alpha = file_correct, alpha 203 | file_correct = best_correct 204 | correct += file_correct 205 | total += file_total 206 | print(filename, file_correct * 100.0 / file_total, file_correct, file_total, best_alpha) 207 | print('cumulative', correct * 100.0 / total) 208 | per_cat_acc += [(filename, file_correct, file_total, best_alpha)] 209 | print('Summary') 210 | group_correct = defaultdict(int) 211 | group_total = defaultdict(int) 212 | for cat, cat_correct, cat_total, best_alpha in per_cat_acc: 213 | group_correct[cat[0]] += cat_correct 214 | group_total[cat[0]] += cat_total 215 | print(cat, cat_correct * 100.0 / cat_total, best_alpha) 216 | print('Final', correct * 100 / total) 217 | for group in group_correct.keys(): 218 | acc = group_correct[group] * 100.0 / group_total[group] 219 | print(group, acc) 220 | with open(pred_file, encoding='utf-8', mode='w') as f: 221 | for info in all_alpha_acc: 222 | f.write('\t'.join([str(x) for x in info]) + '\n') 223 | 224 | if __name__ == '__main__': 225 | bats_dir = sys.argv[1] 226 | model_dir = sys.argv[2] 227 | output_file = sys.argv[3] 228 | eval_on_bats_interpolate(bats_dir, os.path.join(model_dir, 'best.pt'), os.path.join(model_dir, 'saved_config.json'), output_file) 229 | -------------------------------------------------------------------------------- /embeddings/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from typing import Dict 4 | from torch.nn import Module, Linear, Dropout, Sequential, Embedding, LogSigmoid, ReLU 5 | from torch.nn.functional import sigmoid, logsigmoid, softmax, normalize, log_softmax 6 | from embeddings.representation import SpanRepresentation 7 | from torch.nn.init import xavier_normal 8 | from embeddings.util import pretrained_embeddings_or_xavier 9 | import numpy as np 10 | from torch.nn.functional import cosine_similarity 11 | 12 | def get_type_file(filename, vocab, indxs=False): 13 | data = np.load(filename) 14 | if len(vocab) - data.shape[0] > 0: 15 | if indxs: 16 | data = data + (len(vocab) - data.shape[0]) 17 | data = np.concatenate((np.ones((len(vocab) - data.shape[0], data.shape[1]), dtype=data.dtype), data)) 18 | return torch.from_numpy(data) 19 | 20 | class 
Pair2Vec(Module):
21 |     def __init__(self, config, arg_vocab, rel_vocab):
22 |         super(Pair2Vec, self).__init__()
23 |         self.config = config
24 |         self.arg_vocab = arg_vocab
25 |         self.rel_vocab = rel_vocab
26 |         self.compositional_rels = config.compositional_rels
27 |         self.normalize_pretrained = getattr(config, 'normalize_pretrained', False)
28 |         self.separate_mlr = getattr(config, 'separate_mlr', False)
29 |         self.positional_rels = getattr(config, 'positional_rels', False)
30 |         self.type_scores = get_type_file(config.type_scores_file, arg_vocab).cuda() if hasattr(config, 'type_scores_file') else None
31 |         self.type_indices = get_type_file(config.type_indices_file, arg_vocab, indxs=True).cuda() if hasattr(config, 'type_indices_file') else None
32 |         self.pad = arg_vocab.stoi['<pad>']
33 |         score_fn_str = getattr(config, 'score_function', 'dot_product')
34 |         if score_fn_str == 'dot_product':
35 |             self.score = (lambda predicted, observed : (predicted * observed).sum(-1))
36 |         elif score_fn_str == 'cosine':
37 |             self.score = (lambda predicted, observed : cosine_similarity(predicted, observed, dim=1, eps=1e-8))
38 |         else:
39 |             raise NotImplementedError()
40 |         self.num_neg_samples = getattr(config, 'num_neg_samples', 1)
41 |         self.num_sampled_relations = getattr(config, 'num_sampled_relations', 1)
42 |         self.subword_vocab_file = getattr(config, 'subword_vocab_file', None)
43 |         self.loss_weights = [('positive_loss', getattr(config, 'positive_loss', 1.0)),
44 |                 ('negative_rel_loss', getattr(config, 'negative_rel_loss', 1.0)),
45 |                 ('negative_subject_loss', getattr(config, 'negative_subject_loss', 1.0)),
46 |                 ('negative_object_loss', getattr(config, 'negative_object_loss', 1.0))]
47 |         if self.type_scores is not None:
48 |             self.loss_weights += [('type_subject_loss', getattr(config, 'type_subject_loss', 0.3)), ('type_object_loss', getattr(config, 'type_object_loss', 0.3))]
49 |         self.shared_arg_embeddings = getattr(config, 'shared_arg_embeddings', True)
50 |         self.represent_arguments = Embedding(config.n_args, config.d_embed)
51 |         self.represent_left_argument = lambda x : self.represent_arguments(x)
52 |         self.represent_right_argument = (lambda x : self.represent_arguments(x)) if self.shared_arg_embeddings else Embedding(config.n_args, config.d_embed)
53 |         if config.compositional_rels:
54 |             self.represent_relations = SpanRepresentation(config, config.d_rels, rel_vocab)
55 |         else:
56 |             raise NotImplementedError()
57 |         if config.relation_predictor == 'multiplication':
58 |             self.predict_relations = lambda x, y: x * y
59 |         elif config.relation_predictor == 'mlp':
60 |             self.predict_relations = MLP(config)
61 |         else:
62 |             raise Exception('Unknown relation predictor: ' + config.relation_predictor)
63 |         self.init()
64 | 
65 |     def to_tensors(self, fields):
66 |         return ((field, 1.0 - torch.eq(field, self.pad).float()) if (len(field.size()) > 1 and (self.compositional_rels)) else field for field in fields)  # pairs each padded token sequence with its 0/1 mask
67 | 
68 |     def init(self):
69 |         for arg_matrix in [self.represent_arguments, self.represent_right_argument]:
70 |             if isinstance(arg_matrix, Embedding):
71 |                 if self.arg_vocab.vectors is not None:
72 |                     pretrained = normalize(self.arg_vocab.vectors, dim=-1) if self.normalize_pretrained else self.arg_vocab.vectors
73 |                     arg_matrix.weight.data[:, :pretrained.size(1)].copy_(pretrained)
74 |                     print('Copied pretrained vecs for argument matrix')
75 |                 else:
76 |                     arg_matrix.reset_parameters()
77 | 
78 |     def forward(self, batch):
79 |         if len(batch) == 4:
80 |             batch = batch + (None, None)
81 |         subjects, objects, observed_relations,
sampled_relations, sampled_subjects, sampled_objects = batch 82 | sampled_relations = sampled_relations.view(-1, observed_relations.size(1), 1).squeeze(-1) 83 | subjects, objects = self.to_tensors((subjects, objects)) 84 | embedded_subjects = self.represent_left_argument(subjects) 85 | embedded_objects = self.represent_right_argument(objects) 86 | predicted_relations = self.predict_relations(embedded_subjects, embedded_objects) 87 | 88 | observed_relations, sampled_relations = self.to_tensors((observed_relations, sampled_relations)) 89 | observed_relations = self.represent_relations(observed_relations) 90 | sampled_relations = self.represent_relations(sampled_relations) 91 | # score = lambda predicted, observed : (predicted * observed).sum(-1) 92 | rep_observed_relations = observed_relations.repeat(self.num_sampled_relations, 1) 93 | rep_predicted_relations = predicted_relations.repeat(self.num_sampled_relations, 1) 94 | pos_rel_scores, neg_rel_scores = self.score(predicted_relations, observed_relations), self.score(rep_predicted_relations, sampled_relations) 95 | 96 | output_dict = {} 97 | output_dict['positive_loss'] = -logsigmoid(pos_rel_scores).sum() 98 | output_dict['negative_rel_loss'] = -logsigmoid(-neg_rel_scores).sum() 99 | # fake pair loss 100 | if sampled_subjects is not None and sampled_objects is not None: 101 | # sampled_subjects, sampled_objects = self.to_tensors((sampled_subjects, sampled_objects)) 102 | sampled_subjects, sampled_objects = sampled_subjects.view(-1, 1).squeeze(-1), sampled_objects.view(-1, 1).squeeze(-1) 103 | sampled_subjects, sampled_objects = self.represent_left_argument(sampled_subjects), self.represent_right_argument(sampled_objects) 104 | rep_embedded_objects, rep_embedded_subjects = embedded_objects.repeat(self.num_neg_samples, 1), embedded_subjects.repeat(self.num_neg_samples, 1) 105 | pred_relations_for_sampled_sub = self.predict_relations(sampled_subjects, rep_embedded_objects) 106 | pred_relations_for_sampled_obj = self.predict_relations(rep_embedded_subjects, sampled_objects) 107 | rep_observed_relations = observed_relations.repeat(self.num_neg_samples, 1) 108 | output_dict['negative_subject_loss'] = -logsigmoid(-self.score(pred_relations_for_sampled_sub, rep_observed_relations)).sum() #/ self.num_neg_samples 109 | output_dict['negative_object_loss'] = -logsigmoid(-self.score(pred_relations_for_sampled_obj, rep_observed_relations)).sum() #/ self.num_neg_samples 110 | if self.type_scores is not None: 111 | # loss_weights += [('type_subject_loss', 0.3), ('type_object_loss', 0.3)] 112 | method = 'uniform' 113 | type_sampled_subjects, type_sampled_objects = self.get_type_sampled_arguments(subjects, method), self.get_type_sampled_arguments(objects, method) 114 | type_sampled_subjects, type_sampled_objects = self.represent_left_argument(type_sampled_subjects), self.represent_right_argument(type_sampled_objects) 115 | pred_relations_for_type_sampled_sub = self.predict_relations(type_sampled_subjects, embedded_objects) 116 | pred_relations_for_type_sampled_obj = self.predict_relations(embedded_subjects, type_sampled_objects) 117 | output_dict['type_subject_loss'] = -logsigmoid(-self.score(pred_relations_for_type_sampled_sub, observed_relations)).sum() 118 | output_dict['type_object_loss'] = -logsigmoid(-self.score(pred_relations_for_type_sampled_obj, observed_relations)).sum() 119 | loss = 0.0 120 | for loss_name, weight in self.loss_weights: 121 | loss += weight * output_dict[loss_name] 122 | output_dict['observed_probabilities'] = 
sigmoid(pos_rel_scores) 123 | output_dict['sampled_probabilities'] = sigmoid(neg_rel_scores) 124 | return predicted_relations, loss, output_dict 125 | 126 | def get_type_sampled_arguments(self, arguments, method='uniform'): 127 | argument_indices = torch.index_select(self.type_indices, 0, arguments.data) 128 | if method == 'unigram': 129 | argument_scores = torch.index_select(self.type_scores, 0, arguments.data) 130 | sampled_idx_idxs = torch.multinomial(argument_scores, 1, replacement=True).squeeze(1).cuda() 131 | sampled_idxs = torch.gather(argument_indices, 1, sampled_idx_idxs.unsqueeze(1)).squeeze(1) 132 | else: 133 | # sampled_idx_idxs = torch.randint(0, self.type_scores.size(1), size=arguments.size(0), replacement=True) 134 | sampled_idx_idxs = torch.LongTensor(arguments.size(0)).random_(0, self.type_scores.size(1)).cuda() 135 | sampled_idxs = torch.gather(argument_indices, 1, sampled_idx_idxs.unsqueeze(1)).squeeze(1) 136 | return Variable(sampled_idxs, requires_grad=False) 137 | 138 | def score(self, predicted, observed): 139 | return torch.bmm(predicted.unsqueeze(1), observed.unsqueeze(2)).squeeze(-1).squeeze(-1) 140 | 141 | 142 | 143 | class MLP(Module): 144 | def __init__(self, config): 145 | super(MLP, self).__init__() 146 | self.dropout = Dropout(p=config.dropout) 147 | self.nonlinearity = ReLU() 148 | self.normalize = normalize if getattr(config, 'normalize_args', False) else (lambda x : x) 149 | layers = getattr(config, "mlp_layers", 4) 150 | if layers == 2: 151 | self.mlp = Sequential(self.dropout, Linear(3 * config.d_args, config.d_args), self.nonlinearity, self.dropout, Linear(config.d_args, config.d_rels)) 152 | elif layers == 3: 153 | self.mlp = Sequential(self.dropout, Linear(3 * config.d_args, config.d_args), self.nonlinearity, self.dropout, Linear(config.d_args, config.d_args), self.nonlinearity, self.dropout, Linear(config.d_args, config.d_rels)) 154 | elif layers == 4: 155 | self.mlp = Sequential(self.dropout, Linear(3 * config.d_args, config.d_args), self.nonlinearity, self.dropout, Linear(config.d_args, config.d_args), self.nonlinearity, self.dropout, Linear(config.d_args, config.d_args), self.nonlinearity, self.dropout, Linear(config.d_args, config.d_rels)) 156 | else: 157 | raise NotImplementedError() 158 | 159 | def forward(self, subjects, objects): 160 | subjects = self.normalize(subjects) 161 | objects = self.normalize(objects) 162 | return self.mlp(torch.cat([subjects, objects, subjects * objects], dim=-1)) 163 | 164 | -------------------------------------------------------------------------------- /endtasks/squad2_eval.py: -------------------------------------------------------------------------------- 1 | """Official evaluation script for SQuAD version 2.0. 2 | 3 | In addition to basic functionality, we also compute additional statistics and 4 | plot precision-recall curves if an additional na_prob.json file is provided. 5 | This file is expected to map question ID's to the model's predicted probability 6 | that a question is unanswerable. 
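Example command line (file names are illustrative): python endtasks/squad2_eval.py dev-v2.0.json predictions.json --na-prob-file na_prob.json --out-file eval.json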
7 | """ 8 | import argparse 9 | import collections 10 | import json 11 | import numpy as np 12 | import os 13 | import re 14 | import string 15 | import sys 16 | 17 | OPTS = None 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 21 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 22 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 23 | parser.add_argument('--out-file', '-o', metavar='eval.json', 24 | help='Write accuracy metrics to file (default is stdout).') 25 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 26 | help='Model estimates of probability of no answer.') 27 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 28 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 29 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 30 | help='Save precision-recall curves to directory.') 31 | parser.add_argument('--verbose', '-v', action='store_true') 32 | if len(sys.argv) == 1: 33 | parser.print_help() 34 | sys.exit(1) 35 | return parser.parse_args() 36 | 37 | def make_qid_to_has_ans(dataset): 38 | qid_to_has_ans = {} 39 | for article in dataset: 40 | for p in article['paragraphs']: 41 | for qa in p['qas']: 42 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 43 | return qid_to_has_ans 44 | 45 | def normalize_answer(s): 46 | """Lower text and remove punctuation, articles and extra whitespace.""" 47 | def remove_articles(text): 48 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 49 | return re.sub(regex, ' ', text) 50 | def white_space_fix(text): 51 | return ' '.join(text.split()) 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return ''.join(ch for ch in text if ch not in exclude) 55 | def lower(text): 56 | return text.lower() 57 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 58 | 59 | def get_tokens(s): 60 | if not s: return [] 61 | return normalize_answer(s).split() 62 | 63 | def compute_exact(a_gold, a_pred): 64 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 65 | 66 | def compute_f1(a_gold, a_pred): 67 | gold_toks = get_tokens(a_gold) 68 | pred_toks = get_tokens(a_pred) 69 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 70 | num_same = sum(common.values()) 71 | if len(gold_toks) == 0 or len(pred_toks) == 0: 72 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 73 | return int(gold_toks == pred_toks) 74 | if num_same == 0: 75 | return 0 76 | precision = 1.0 * num_same / len(pred_toks) 77 | recall = 1.0 * num_same / len(gold_toks) 78 | f1 = (2 * precision * recall) / (precision + recall) 79 | return f1 80 | 81 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 82 | return max(metric_fn(ground_truth, prediction) for ground_truth in ground_truths) 83 | 84 | def get_raw_scores(dataset, preds): 85 | exact_scores = {} 86 | f1_scores = {} 87 | for article in dataset: 88 | for p in article['paragraphs']: 89 | for qa in p['qas']: 90 | qid = qa['id'] 91 | gold_answers = [a['text'] for a in qa['answers'] 92 | if normalize_answer(a['text'])] 93 | if not gold_answers: 94 | # For unanswerable questions, only correct answer is empty string 95 | gold_answers = [''] 96 | if qid not in preds: 97 | print('Missing prediction for %s' % qid) 98 | continue 99 | a_pred = preds[qid] 100 | # Take max over all gold answers 101 | exact_scores[qid] = 
max(compute_exact(a, a_pred) for a in gold_answers) 102 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 103 | return exact_scores, f1_scores 104 | 105 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 106 | new_scores = {} 107 | for qid, s in scores.items(): 108 | pred_na = na_probs[qid] > na_prob_thresh 109 | if pred_na: 110 | new_scores[qid] = float(not qid_to_has_ans[qid]) 111 | else: 112 | new_scores[qid] = s 113 | return new_scores 114 | 115 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 116 | if not qid_list: 117 | total = len(exact_scores) 118 | return collections.OrderedDict([ 119 | ('exact', 100.0 * sum(exact_scores.values()) / total), 120 | ('f1', 100.0 * sum(f1_scores.values()) / total), 121 | ('total', total), 122 | ]) 123 | else: 124 | total = len(qid_list) 125 | return collections.OrderedDict([ 126 | ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total), 127 | ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total), 128 | ('total', total), 129 | ]) 130 | 131 | def merge_eval(main_eval, new_eval, prefix): 132 | for k in new_eval: 133 | main_eval['%s_%s' % (prefix, k)] = new_eval[k] 134 | 135 | def plot_pr_curve(precisions, recalls, out_image, title): 136 | plt.step(recalls, precisions, color='b', alpha=0.2, where='post') 137 | plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b') 138 | plt.xlabel('Recall') 139 | plt.ylabel('Precision') 140 | plt.xlim([0.0, 1.05]) 141 | plt.ylim([0.0, 1.05]) 142 | plt.title(title) 143 | plt.savefig(out_image) 144 | plt.clf() 145 | 146 | def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans, 147 | out_image=None, title=None): 148 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 149 | true_pos = 0.0 150 | cur_p = 1.0 151 | cur_r = 0.0 152 | precisions = [1.0] 153 | recalls = [0.0] 154 | avg_prec = 0.0 155 | for i, qid in enumerate(qid_list): 156 | if qid_to_has_ans[qid]: 157 | true_pos += scores[qid] 158 | cur_p = true_pos / float(i+1) 159 | cur_r = true_pos / float(num_true_pos) 160 | if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]: 161 | # i.e., if we can put a threshold after this point 162 | avg_prec += cur_p * (cur_r - recalls[-1]) 163 | precisions.append(cur_p) 164 | recalls.append(cur_r) 165 | if out_image: 166 | plot_pr_curve(precisions, recalls, out_image, title) 167 | return {'ap': 100.0 * avg_prec} 168 | 169 | def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs, 170 | qid_to_has_ans, out_image_dir): 171 | if out_image_dir and not os.path.exists(out_image_dir): 172 | os.makedirs(out_image_dir) 173 | num_true_pos = sum(1 for v in qid_to_has_ans.values() if v) 174 | if num_true_pos == 0: 175 | return 176 | pr_exact = make_precision_recall_eval( 177 | exact_raw, na_probs, num_true_pos, qid_to_has_ans, 178 | out_image=os.path.join(out_image_dir, 'pr_exact.png'), 179 | title='Precision-Recall curve for Exact Match score') 180 | pr_f1 = make_precision_recall_eval( 181 | f1_raw, na_probs, num_true_pos, qid_to_has_ans, 182 | out_image=os.path.join(out_image_dir, 'pr_f1.png'), 183 | title='Precision-Recall curve for F1 score') 184 | oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()} 185 | pr_oracle = make_precision_recall_eval( 186 | oracle_scores, na_probs, num_true_pos, qid_to_has_ans, 187 | out_image=os.path.join(out_image_dir, 'pr_oracle.png'), 188 | title='Oracle Precision-Recall curve (binary task of HasAns vs. 
NoAns)') 189 | merge_eval(main_eval, pr_exact, 'pr_exact') 190 | merge_eval(main_eval, pr_f1, 'pr_f1') 191 | merge_eval(main_eval, pr_oracle, 'pr_oracle') 192 | 193 | def histogram_na_prob(na_probs, qid_list, image_dir, name): 194 | if not qid_list: 195 | return 196 | x = [na_probs[k] for k in qid_list] 197 | weights = np.ones_like(x) / float(len(x)) 198 | plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0)) 199 | plt.xlabel('Model probability of no-answer') 200 | plt.ylabel('Proportion of dataset') 201 | plt.title('Histogram of no-answer probability: %s' % name) 202 | plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name)) 203 | plt.clf() 204 | 205 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 206 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 207 | cur_score = num_no_ans 208 | best_score = cur_score 209 | best_thresh = 0.0 210 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 211 | for i, qid in enumerate(qid_list): 212 | if qid not in scores: continue 213 | if qid_to_has_ans[qid]: 214 | diff = scores[qid] 215 | else: 216 | if preds[qid]: 217 | diff = -1 218 | else: 219 | diff = 0 220 | cur_score += diff 221 | if cur_score > best_score: 222 | best_score = cur_score 223 | best_thresh = na_probs[qid] 224 | return 100.0 * best_score / len(scores), best_thresh 225 | 226 | def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 227 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 228 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 229 | main_eval['best_exact'] = best_exact 230 | main_eval['best_exact_thresh'] = exact_thresh 231 | main_eval['best_f1'] = best_f1 232 | main_eval['best_f1_thresh'] = f1_thresh 233 | 234 | def main(): 235 | with open(OPTS.data_file) as f: 236 | dataset_json = json.load(f) 237 | dataset = dataset_json['data'] 238 | with open(OPTS.pred_file) as f: 239 | preds = json.load(f) 240 | if OPTS.na_prob_file: 241 | with open(OPTS.na_prob_file) as f: 242 | na_probs = json.load(f) 243 | else: 244 | na_probs = {k: 0.0 for k in preds} 245 | qid_to_has_ans = make_qid_to_has_ans(dataset) # maps qid to True/False 246 | has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] 247 | no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] 248 | exact_raw, f1_raw = get_raw_scores(dataset, preds) 249 | exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans, 250 | OPTS.na_prob_thresh) 251 | f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans, 252 | OPTS.na_prob_thresh) 253 | out_eval = make_eval_dict(exact_thresh, f1_thresh) 254 | if has_ans_qids: 255 | has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids) 256 | merge_eval(out_eval, has_ans_eval, 'HasAns') 257 | if no_ans_qids: 258 | no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids) 259 | merge_eval(out_eval, no_ans_eval, 'NoAns') 260 | if OPTS.na_prob_file: 261 | find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans) 262 | if OPTS.na_prob_file and OPTS.out_image_dir: 263 | run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs, 264 | qid_to_has_ans, OPTS.out_image_dir) 265 | histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns') 266 | histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns') 267 | if OPTS.out_file: 268 | with open(OPTS.out_file, 'w') as f: 269 | json.dump(out_eval, f) 270 | else: 271 | 
print(json.dumps(out_eval, indent=2)) 272 | 273 | if __name__ == '__main__': 274 | OPTS = parse_args() 275 | if OPTS.out_image_dir: 276 | import matplotlib 277 | matplotlib.use('Agg') 278 | import matplotlib.pyplot as plt 279 | main() 280 | 281 | 282 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2017 Kenton Lee 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /embeddings/train.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.optim as optim 7 | from torch.nn.utils import clip_grad_norm 8 | from tensorboardX import SummaryWriter 9 | 10 | from embeddings.model import Pair2Vec 11 | from embeddings.matrix_data import read_data 12 | from embeddings.util import get_args, get_config, makedirs 13 | from embeddings import metrics, util 14 | import numpy 15 | 16 | import logging 17 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau 18 | 19 | format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 20 | logging.basicConfig(format=format, level=logging.INFO) 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def prepare_env(args, config): 25 | # logging 26 | mode = 'a' if args.resume_snapshot else 'w' 27 | fh = logging.FileHandler(os.path.join(config.save_path, 'stdout.log'), mode=mode) 28 | fh.setFormatter(logging.Formatter(format)) 29 | logger.addHandler(fh) 30 | 31 | # add seeds 32 | seed = args.seed 33 | numpy.random.seed(seed) 34 | torch.manual_seed(seed) 35 | # Seed all GPUs with the same seed if available. 36 | if torch.cuda.is_available(): 37 | torch.cuda.manual_seed_all(seed) 38 | 39 | def main(args, config): 40 | prepare_env(args, config) 41 | train_data, dev_data, train_iterator, dev_iterator, args_field, rels_field = read_data(config, preindex=True) 42 | 43 | model_type = getattr(config, 'model_type', 'sampling') 44 | if model_type == 'sampling': 45 | model = Pair2Vec(config, args_field.vocab, rels_field.vocab) 46 | else: 47 | raise NotImplementedError() 48 | 49 | model.cuda() 50 | params = filter(lambda p: p.requires_grad, model.parameters()) 51 | opt = optim.SGD(params, lr=config.lr) 52 | 53 | checkpoint = None 54 | if args.resume_snapshot: 55 | checkpoint = util.resume_from(args.resume_snapshot, model, opt) 56 | 57 | writer = SummaryWriter(comment="_" + args.exp) 58 | 59 | train(train_data, dev_data, train_iterator, dev_iterator, model, config, opt, writer, checkpoint) 60 | 61 | writer.export_scalars_to_json("./all_scalars.json") 62 | writer.close() 63 | 64 | def get_lr(optimizer): 65 | lr=[] 66 | for param_group in optimizer.param_groups: 67 | lr +=[ param_group['lr'] ] 68 | return lr 69 | 70 | def train(train_data, dev_data, train_iterator, dev_iterator, model, config, opt, writer, checkpoint=None): 71 | 72 | logger.info( model) 73 | start = time.time() 74 | best_dev_loss, best_train_loss = 1000, 1000 75 | 76 | makedirs(config.save_path) 77 | stats_logger = StatsLogger(writer, start, 0) 78 | 79 | iterations = 0 if checkpoint is None else checkpoint['iterations'] 80 | start_epoch = 0 if checkpoint is None else checkpoint['epoch'] 81 | #scheduler = StepLR(opt, step_size=1, gamma=0.9) 82 | scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.9, patience=10, verbose=True, threshold=0.001) 83 | 84 | logger.info('LR: {}'.format(get_lr(opt))) 85 | logger.info(' Time Epoch Iteration Progress Loss Dev_Loss Train_Pos Train_Neg Dev_Pos Dev_Neg') 86 | 87 | dev_eval_stats = None 88 | #import ipdb 89 | #ipdb.set_trace() 90 | for epoch in range(start_epoch, config.epochs): 91 | # train_iter.init_epoch() 92 | train_eval_stats = EvaluationStatistics(config) 93 | 94 | for batch_index, batch in enumerate(train_iterator(train_data, device=None, train=True)): 95 | # Switch model to training mode, clear gradient accumulators 96 | 
model.train()
97 |             opt.zero_grad()
98 |             iterations += 1
99 | 
100 |             # forward pass
101 |             answer, loss, output_dict = model(batch)
102 | 
103 |             # backpropagate and update optimizer learning rate
104 |             loss.backward()
105 | 
106 |             # grad clipping
107 |             rescale_gradients(model, config.grad_norm)
108 |             opt.step()
109 | 
110 |             # aggregate training error
111 |             train_eval_stats.update(loss, output_dict)
112 | 
113 | 
114 |             # evaluate performance on validation set periodically
115 |             if iterations % config.dev_every == 0:
116 |                 model.eval()
117 |                 dev_eval_stats = EvaluationStatistics(config)
118 |                 for dev_batch_index, dev_batch in enumerate(dev_iterator(dev_data, device=None, train=False)):
119 |                     answer, loss, dev_output_dict = model(dev_batch)
120 |                     dev_eval_stats.update(loss, dev_output_dict)
121 | 
122 |                 scheduler.step(train_eval_stats.average()[0])
123 |                 stats_logger.log(epoch, iterations, batch_index, train_eval_stats, dev_eval_stats)
124 |                 stats_logger.epoch_log(epoch, iterations, train_eval_stats, dev_eval_stats)
125 | 
126 |                 # update best validation set accuracy
127 |                 train_loss = train_eval_stats.average()[0]
128 |                 if train_loss < best_train_loss:
129 |                     best_train_loss = train_loss
130 |                     util.save_checkpoint(config, model, opt, epoch, iterations, train_eval_stats, dev_eval_stats, 'best_train_snapshot')
131 | 
132 |                 # reset train stats
133 |                 train_eval_stats = EvaluationStatistics(config)
134 |                 logger.info('LR: {}'.format(get_lr(opt)))
135 | 
136 |             elif iterations % config.log_every == 0:
137 |                 stats_logger.log(epoch, iterations, batch_index, train_eval_stats, dev_eval_stats)
138 |                 train_loss = train_eval_stats.average()[0]
139 |                 util.save_checkpoint(config, model, opt, epoch, iterations, train_eval_stats, dev_eval_stats, 'epoch_train_snapshot', remove=False)
140 | 
141 | 
142 | def rescale_gradients(model, grad_norm):
143 |     parameters_to_clip = [p for p in model.parameters() if p.grad is not None]
144 |     clip_grad_norm(parameters_to_clip, grad_norm)
145 | 
146 | 
147 | def save(config, model, loss, iterations, name):
148 |     snapshot_prefix = os.path.join(config.save_path, name)
149 |     snapshot_path = snapshot_prefix + '_loss_{:.6f}_iter_{}_model.pt'.format(loss.data[0], iterations)
150 |     torch.save(model.state_dict(), snapshot_path)
151 |     for f in glob.glob(snapshot_prefix + '*'):
152 |         if f != snapshot_path:
153 |             os.remove(f)
154 | 
155 | 
156 | class EvaluationStatistics:
157 | 
158 |     def __init__(self, config):
159 |         self.n_examples = 0
160 |         self.loss = 0.0
161 |         self.pos_from_observed = 0.0
162 |         self.pos_from_sampled = 0.0
163 |         self.threshold = config.threshold
164 |         self.pos_pred = 0.0
165 |         self.neg_pred = 0.0
166 |         self.positive_loss = 0
167 |         self.neg_sub_loss = 0
168 |         self.neg_obj_loss = 0
169 |         self.neg_rel_loss = 0
170 |         self.type_obj_loss = 0
171 |         self.type_sub_loss = 0
172 |         self.num_neg_samples = config.num_neg_samples
173 |         self.num_sampled_relations = config.num_sampled_relations
174 |         self.config = config
175 | 
176 |     def update(self, loss, output_dict):
177 |         self.loss += loss.item()
178 |         self.positive_loss += output_dict['positive_loss'].item()
179 |         self.neg_sub_loss += output_dict['negative_subject_loss'].item() if 'negative_subject_loss' in output_dict else 0  # add 0 (not the running total) when a loss term is absent
180 |         self.neg_obj_loss += output_dict['negative_object_loss'].item() if 'negative_object_loss' in output_dict else 0
181 | 
182 |         self.type_sub_loss += output_dict['type_subject_loss'].item() if 'type_subject_loss' in output_dict else 0
183 |         self.type_obj_loss += output_dict['type_object_loss'].item() if 'type_object_loss' in output_dict else 0
184 |         self.neg_rel_loss += output_dict['negative_rel_loss'].item() if 'negative_rel_loss' in output_dict else 0
185 |         if 'observed_probabilities' in output_dict:
186 |             observed_probabilities = output_dict['observed_probabilities']
187 |             self.n_examples += observed_probabilities.size()[0]
188 |             sampled_probabilities = output_dict['sampled_probabilities']
189 |             self.pos_pred += metrics.positive_predictions_for(observed_probabilities, self.threshold)
190 |             self.neg_pred += metrics.positive_predictions_for(sampled_probabilities, self.threshold)
191 |         else:
192 |             self.n_examples += 1
193 | 
194 |     def average(self):
195 |         return self.loss / self.n_examples, self.pos_pred / self.n_examples, (self.neg_pred / self.n_examples) / self.num_sampled_relations
196 | 
197 |     def average_loss(self):
198 |         return self.positive_loss / self.n_examples, self.neg_sub_loss / self.n_examples, self.neg_obj_loss / self.n_examples, self.neg_rel_loss / self.n_examples, self.type_sub_loss / self.n_examples, self.type_obj_loss / self.n_examples
199 | 
200 | 
201 | class StatsLogger:
202 | 
203 |     def __init__(self, writer, start, batches_per_epoch):
204 |         self.log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f},{:>8.6f},{:8.6f},{:12.4f},{:12.4f},{:12.4f},{:12.4f}'.split(','))
205 |         self.writer = writer
206 |         self.start = start
207 |         self.batches_per_epoch = batches_per_epoch
208 | 
209 |     def log(self, epoch, iterations, batch_index, train_eval_stats, dev_eval_stats=None):
210 |         train_loss, train_pos, train_neg = train_eval_stats.average()
211 |         dev_loss, dev_pos, dev_neg = dev_eval_stats.average() if dev_eval_stats is not None else (-1.0, -1.0, -1.0)
212 |         logger.info(self.log_template.format(
213 |             time.time() - self.start,
214 |             epoch,
215 |             iterations,
216 |             batch_index + 1,
217 |             self.batches_per_epoch,
218 |             train_loss,
219 |             dev_loss,
220 |             train_pos,
221 |             train_neg,
222 |             dev_pos,
223 |             dev_neg))
224 | 
225 |         self.writer.add_scalar('Train_Loss', train_loss, iterations)
226 |         self.writer.add_scalar('Dev_Loss', dev_loss, iterations)
227 |         self.writer.add_scalar('Train_Pos.', train_pos, iterations)
228 |         self.writer.add_scalar('Train_Neg.', train_neg, iterations)
229 |         self.writer.add_scalar('Dev_Pos.', dev_pos, iterations)
230 |         self.writer.add_scalar('Dev_Neg.', dev_neg, iterations)
231 |         # pos_loss, neg_sub_loss, neg_obj_loss, neg_rel_loss = train_eval_stats.average_loss()
232 |         # logger.info('pos_loss {:.3f}, neg_sub_loss {:.3f}, neg_obj_loss {:.3f}, neg_rel_loss {:.3f}'.format(pos_loss, neg_sub_loss, neg_obj_loss, neg_rel_loss))
233 | 
234 |     def epoch_log(self, epoch, iterations, train_eval_stats, dev_eval_stats):
235 |         train_loss, train_pos, train_neg = train_eval_stats.average()
236 |         dev_loss, dev_pos, dev_neg = dev_eval_stats.average()
237 |         pos_loss, neg_sub_loss, neg_obj_loss, neg_rel_loss, type_sub_loss, type_obj_loss = train_eval_stats.average_loss()
238 | 
239 |         logger.info("In epoch {}".format(epoch))
240 |         logger.info("Epoch:{}, iter:{}, train loss: {:.6f}, dev loss:{:.6f}, train pos:{:.4f}, train neg:{:.4f}, dev pos: {:.4f} dev neg: {:.4f}".format(epoch, iterations, train_loss, dev_loss, train_pos, train_neg, dev_pos, dev_neg))
241 |         logger.info('pos_loss {:.3f}, neg_sub_loss {:.3f}, neg_obj_loss {:.3f}, neg_rel_loss {:.3f}, type_sub_loss {:.3f}, type_obj_loss {:.3f}'.format(pos_loss, neg_sub_loss, neg_obj_loss, neg_rel_loss, type_sub_loss, type_obj_loss))
242 | 
243 | 
244 | if
244 | if __name__ == "__main__":
245 |     args = get_args()
246 |     print("Running experiment:", args.exp)
247 |     arg_save_path = args.save_path if hasattr(args, "save_path") else None
248 |     config = get_config(args.config, args.exp, arg_save_path)
249 |     print(config)
250 |     torch.cuda.set_device(args.gpu)
251 |     main(args, config)
252 | 
--------------------------------------------------------------------------------
/endtasks/squad2_reader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from typing import Dict, List, Tuple, Any
4 | from allennlp.data.fields import Field, TextField, MetadataField, ListField, SpanField
6 | from overrides import overrides
7 | 
8 | from allennlp.common import Params
9 | from allennlp.common.file_utils import cached_path
10 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader
11 | from allennlp.data.instance import Instance
12 | from allennlp.data.dataset_readers.reading_comprehension import util
13 | from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
14 | from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
15 | 
16 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
17 | 
18 | 
19 | def make_reading_comprehension_instance(question_tokens: List[Token],
20 |                                         passage_tokens: List[Token],
21 |                                         token_indexers: Dict[str, TokenIndexer],
22 |                                         passage_text: str,
23 |                                         question_id: str,
24 |                                         token_spans: List[Tuple[int, int]] = None,
25 |                                         answer_texts: List[str] = None,
26 |                                         additional_metadata: Dict[str, Any] = None) -> Instance:
27 |     """
28 |     Converts a question, a passage, and an optional answer (or answers) to an ``Instance`` for use
29 |     in a reading comprehension model.
30 | 
31 |     Creates an ``Instance`` with these fields: ``question`` and ``passage``, both ``TextFields``;
32 |     ``metadata``, a ``MetadataField``; and ``spans``, a ``ListField`` of ``SpanField``s into the
33 |     passage. When ``token_spans`` is empty, ``spans`` holds a single span pointing at the final
34 |     passage token (the no-answer token appended by the reader below).
35 | 
36 |     Parameters
37 |     ----------
38 |     question_tokens : ``List[Token]``
39 |         An already-tokenized question.
40 |     passage_tokens : ``List[Token]``
41 |         An already-tokenized passage that contains the answer to the given question.
42 |     token_indexers : ``Dict[str, TokenIndexer]``
43 |         Determines how the question and passage ``TextFields`` will be converted into tensors that
44 |         get input to a model. See :class:`TokenIndexer`.
45 |     passage_text : ``str``
46 |         The original passage text. We need this so that we can recover the actual span from the
47 |         original passage that the model predicts as the answer to the question. This is used in
48 |         official evaluation scripts.
49 |     token_spans : ``List[Tuple[int, int]]``, optional
50 |         Indices into ``passage_tokens`` to use as the answer to the question for training. This is
51 |         a list because there might be several possible correct answer spans in the passage;
52 |         all distinct spans are kept (SQuAD has multiple annotations on the dev set).
53 |     answer_texts : ``List[str]``, optional
54 |         All valid answer strings for the given question. In SQuAD, e.g., the training set has
55 |         exactly one answer per question, but the dev and test sets have several. TriviaQA has many
58 |         possible answers, which are the aliases for the known correct entity. This is put into the
59 |         metadata for use with official evaluation scripts, but not used anywhere else.
60 |     additional_metadata : ``Dict[str, Any]``, optional
61 |         The constructed ``metadata`` field will by default contain ``original_passage``,
62 |         ``token_offsets``, ``question_id``, ``question_tokens``, ``passage_tokens``, and
63 |         ``answer_texts`` keys. If you want any other metadata to be associated with each instance,
64 |         you can pass that in here. This dictionary will get added to the ``metadata`` dictionary
65 |         we already construct.
66 |     """
67 |     additional_metadata = additional_metadata or {}
68 |     fields: Dict[str, Field] = {}
69 |     passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
70 | 
71 |     # This is separate so we can reference it later with a known type.
72 |     passage_field = TextField(passage_tokens, token_indexers)
73 |     fields['passage'] = passage_field
74 |     fields['question'] = TextField(question_tokens, token_indexers)
75 |     metadata = {
76 |         'original_passage': passage_text,
77 |         'token_offsets': passage_offsets,
78 |         'question_id': question_id,
79 |         'question_tokens': [token.text for token in question_tokens],
80 |         'passage_tokens': [token.text for token in passage_tokens],
81 |     }
82 |     # The official evaluation script expects a (possibly empty-string) list here.
83 |     metadata['answer_texts'] = answer_texts if answer_texts else ['']
84 | 
87 |     if token_spans:
88 |         # There may be multiple answer annotations (SQuAD has several on the dev set), so we
89 |         # keep all distinct spans. This means our computed span metrics ("start_acc",
90 |         # "end_acc", and "span_acc") aren't quite the same as the official metrics, which is
91 |         # why there is a separate official SQuAD metric calculation (the "em" and "f1"
92 |         # metrics use the official script).
94 |         token_spans = list(set(token_spans))
97 |         span_fields = ListField([SpanField(start, end, passage_field)
98 |                                  for start, end in token_spans])
99 |     else:
100 |         # No gold answer: point the single span at the appended no-answer token.
101 |         span_fields = ListField([SpanField(len(passage_tokens) - 1, len(passage_tokens) - 1, passage_field)])
102 | 
103 |     fields['spans'] = span_fields
104 |     metadata.update(additional_metadata)
105 |     fields['metadata'] = MetadataField(metadata)
106 |     return Instance(fields)
107 | 
108 | @DatasetReader.register("no_answer_squad2")
109 | class NoAnswerSquad2Reader(DatasetReader):
110 |     """
111 |     Reads a JSON-formatted SQuAD 2.0 file and returns a ``Dataset`` where the ``Instances`` have
112 |     four fields: ``question``, a ``TextField``, ``passage``, another ``TextField``, ``spans``, a
113 |     ``ListField`` of ``SpanField``s into the ``passage`` ``TextField``, and ``metadata``, a
114 |     ``MetadataField`` that stores the instance's ID, the original passage text, gold answer strings,
115 |     and token offsets into the original passage, accessible as ``metadata['question_id']``,
116 |     ``metadata['original_passage']``, ``metadata['answer_texts']`` and
117 |     ``metadata['token_offsets']``. This is so that we can more easily use the official SQuAD
118 |     evaluation script to get metrics.
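    Unanswerable questions are supported by appending a special ``CANNOTANSWER`` token
    to every passage in ``_read`` below; when a question has no gold answer, its gold
    span points at that appended token.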
119 | 
120 |     Parameters
121 |     ----------
126 |     tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
127 |         We use this ``Tokenizer`` for both the question and the passage. See :class:`Tokenizer`.
128 |         Default is ``WordTokenizer()``.
129 |     token_indexers : ``Dict[str, TokenIndexer]``, optional
130 |         We similarly use this for both the question and the passage. See :class:`TokenIndexer`.
131 |         Default is ``{"tokens": SingleIdTokenIndexer()}``.
132 |     """
133 |     def __init__(self,
134 |                  tokenizer: Tokenizer = None,
135 |                  token_indexers: Dict[str, TokenIndexer] = None,
136 |                  lazy: bool = False) -> None:
137 |         super().__init__(lazy)
138 |         self._tokenizer = tokenizer or WordTokenizer()
139 |         self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
140 | 
141 |     @overrides
142 |     def _read(self, file_path: str):
143 |         # if `file_path` is a URL, redirect to the cache
144 |         no_answer_token = 'CANNOTANSWER'
145 |         file_path = cached_path(file_path)
146 | 
147 |         logger.info("Reading file at %s", file_path)
148 |         with open(file_path) as dataset_file:
149 |             dataset_json = json.load(dataset_file)
150 |             dataset = dataset_json['data']
151 |         logger.info("Reading the dataset")
152 |         for article in dataset:
153 |             for paragraph_json in article['paragraphs']:
154 |                 paragraph = paragraph_json["context"] + ' ' + no_answer_token
155 |                 tokenized_paragraph = self._tokenizer.tokenize(paragraph)
156 | 
157 |                 for question_answer in paragraph_json['qas']:
158 |                     question_text = question_answer["question"].strip().replace("\n", "")
159 |                     answers = [{'text': no_answer_token, 'answer_start': len(paragraph) - len(no_answer_token)}] if len(question_answer['answers']) == 0 else question_answer['answers']
160 |                     answer_texts = [answer['text'] for answer in answers]
161 |                     span_starts = [answer['answer_start'] for answer in answers]
162 |                     span_ends = [start + len(answer) for start, answer in zip(span_starts, answer_texts)]
164 |                     instance = self.text_to_instance(question_text,
165 |                                                      paragraph,
166 |                                                      question_answer['id'],
167 |                                                      zip(span_starts, span_ends),
168 |                                                      answer_texts,
169 |                                                      tokenized_paragraph)
170 |                     yield instance
171 | 
172 |     @overrides
173 |     def text_to_instance(self,  # type: ignore
174 |                          question_text: str,
175 |                          passage_text: str,
176 |                          question_id: str,
177 |                          char_spans: List[Tuple[int, int]] = None,
178 |                          answer_texts: List[str] = None,
179 |                          passage_tokens: List[Token] = None) -> Instance:
180 |         # pylint: disable=arguments-differ
181 |         if not passage_tokens:
182 |             passage_tokens = self._tokenizer.tokenize(passage_text)
183 |         char_spans = char_spans or []
184 | 
185 |         # We need to convert character indices in `passage_text` to token indices in
186 |         # `passage_tokens`, as the latter is what we'll actually use for supervision.
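        # Illustrative example (hypothetical data): for the passage "The cat sat",
        # passage_offsets is [(0, 3), (4, 7), (8, 11)], so the character span
        # (4, 7) covering "cat" maps to the token span (1, 1).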
187 | token_spans: List[Tuple[int, int]] = [] 188 | passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens] 189 | for char_span_start, char_span_end in char_spans: 190 | (span_start, span_end), error = util.char_span_to_token_span(passage_offsets, 191 | (char_span_start, char_span_end)) 192 | if error: 193 | logger.debug("Passage: %s", passage_text) 194 | logger.debug("Passage tokens: %s", passage_tokens) 195 | logger.debug("Question text: %s", question_text) 196 | logger.debug("Answer span: (%d, %d)", char_span_start, char_span_end) 197 | logger.debug("Token span: (%d, %d)", span_start, span_end) 198 | logger.debug("Tokens in answer: %s", passage_tokens[span_start:span_end + 1]) 199 | logger.debug("Answer: %s", passage_text[char_span_start:char_span_end]) 200 | token_spans.append((span_start, span_end)) 201 | 202 | return make_reading_comprehension_instance( 203 | self._tokenizer.tokenize(question_text), 204 | passage_tokens, 205 | self._token_indexers, 206 | passage_text, 207 | question_id, 208 | token_spans, 209 | answer_texts) 210 | 211 | # @classmethod 212 | # def from_params(cls, params: Params) -> 'Squad2Reader': 213 | # tokenizer = Tokenizer.from_params(params.pop('tokenizer', {})) 214 | # token_indexers = TokenIndexer.dict_from_params(params.pop('token_indexers', {})) 215 | # lazy = params.pop('lazy', False) 216 | # params.assert_empty(cls.__name__) 217 | # return cls(tokenizer=tokenizer, token_indexers=token_indexers, lazy=lazy) 218 | -------------------------------------------------------------------------------- /endtasks/esim_pair2vec.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List 2 | from torch.nn import Linear 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.nn.functional import normalize 6 | from allennlp.common import Params 7 | from allennlp.common.checks import check_dimensions_match 8 | from allennlp.data import Vocabulary 9 | from allennlp.models.model import Model 10 | from allennlp.modules import FeedForward 11 | from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder 12 | from allennlp.nn import InitializerApplicator, RegularizerApplicator 13 | from allennlp.nn.util import get_text_field_mask, last_dim_softmax, weighted_sum, replace_masked_values 14 | from allennlp.training.metrics import CategoricalAccuracy 15 | from endtasks import util 16 | from endtasks.modules import VariationalDropout 17 | 18 | 19 | @Model.register("esim-pair2vec") 20 | class ESIMPair2Vec(Model): 21 | """ 22 | This ``Model`` implements the ESIM sequence model described in `"Enhanced LSTM for Natural Language Inference" 23 | `_ 24 | by Chen et al., 2017. 25 | 26 | Parameters 27 | ---------- 28 | vocab : ``Vocabulary`` 29 | text_field_embedder : ``TextFieldEmbedder`` 30 | Used to embed the ``premise`` and ``hypothesis`` ``TextFields`` we get as input to the 31 | model. 32 | attend_feedforward : ``FeedForward`` 33 | This feedforward network is applied to the encoded sentence representations before the 34 | similarity matrix is computed between words in the premise and words in the hypothesis. 35 | similarity_function : ``SimilarityFunction`` 36 | This is the similarity function used when computing the similarity matrix between words in 37 | the premise and words in the hypothesis. 
38 | compare_feedforward : ``FeedForward`` 39 | This feedforward network is applied to the aligned premise and hypothesis representations, 40 | individually. 41 | aggregate_feedforward : ``FeedForward`` 42 | This final feedforward network is applied to the concatenated, summed result of the 43 | ``compare_feedforward`` network, and its output is used as the entailment class logits. 44 | premise_encoder : ``Seq2SeqEncoder``, optional (default=``None``) 45 | After embedding the premise, we can optionally apply an encoder. If this is ``None``, we 46 | will do nothing. 47 | hypothesis_encoder : ``Seq2SeqEncoder``, optional (default=``None``) 48 | After embedding the hypothesis, we can optionally apply an encoder. If this is ``None``, 49 | we will use the ``premise_encoder`` for the encoding (doing nothing if ``premise_encoder`` 50 | is also ``None``). 51 | initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) 52 | Used to initialize the model parameters. 53 | regularizer : ``RegularizerApplicator``, optional (default=``None``) 54 | If provided, will be used to calculate the regularization penalty during training. 55 | """ 56 | def __init__(self, vocab: Vocabulary, 57 | encoder_keys: List[str], 58 | mask_key: str, 59 | pair2vec_config_file: str, 60 | pair2vec_model_file: str, 61 | text_field_embedder: TextFieldEmbedder, 62 | encoder: Seq2SeqEncoder, 63 | similarity_function: SimilarityFunction, 64 | projection_feedforward: FeedForward, 65 | inference_encoder: Seq2SeqEncoder, 66 | output_feedforward: FeedForward, 67 | output_logit: FeedForward, 68 | initializer: InitializerApplicator = InitializerApplicator(), 69 | dropout: float = 0.5, 70 | pair2vec_dropout: float = 0.0, 71 | bidirectional_pair2vec: bool = True, 72 | regularizer: Optional[RegularizerApplicator] = None) -> None: 73 | super().__init__(vocab, regularizer) 74 | self._vocab = vocab 75 | self.pair2vec = util.get_pair2vec(pair2vec_config_file, pair2vec_model_file) 76 | self._encoder_keys = encoder_keys 77 | self._mask_key = mask_key 78 | self._text_field_embedder = text_field_embedder 79 | self._projection_feedforward = projection_feedforward 80 | self._encoder = encoder 81 | from allennlp.modules.matrix_attention import DotProductMatrixAttention 82 | 83 | self._matrix_attention = DotProductMatrixAttention() 84 | 85 | 86 | self._inference_encoder = inference_encoder 87 | self._pair2vec_dropout = torch.nn.Dropout(pair2vec_dropout) 88 | self._bidirectional_pair2vec = bidirectional_pair2vec 89 | 90 | if dropout: 91 | self.dropout = torch.nn.Dropout(dropout) 92 | self.rnn_input_dropout = VariationalDropout(dropout) 93 | else: 94 | self.dropout = None 95 | self.rnn_input_dropout = None 96 | 97 | self._output_feedforward = output_feedforward 98 | self._output_logit = output_logit 99 | 100 | self._num_labels = vocab.get_vocab_size(namespace="labels") 101 | 102 | 103 | self._accuracy = CategoricalAccuracy() 104 | self._loss = torch.nn.CrossEntropyLoss() 105 | 106 | initializer(self) 107 | 108 | def forward(self, # type: ignore 109 | premise: Dict[str, torch.LongTensor], 110 | hypothesis: Dict[str, torch.LongTensor], 111 | label: torch.IntTensor = None, 112 | metadata: Dict = None) -> Dict[str, torch.Tensor]: 113 | # pylint: disable=arguments-differ 114 | """ 115 | Parameters 116 | ---------- 117 | premise : Dict[str, torch.LongTensor] 118 | From a ``TextField`` 119 | hypothesis : Dict[str, torch.LongTensor] 120 | From a ``TextField`` 121 | label : torch.IntTensor, optional (default = None) 122 | From a 
``LabelField`` 123 | 124 | Returns 125 | ------- 126 | An output dictionary consisting of: 127 | 128 | label_logits : torch.FloatTensor 129 | A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log 130 | probabilities of the entailment label. 131 | label_probs : torch.FloatTensor 132 | A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the 133 | entailment label. 134 | loss : torch.FloatTensor, optional 135 | A scalar loss to be optimised. 136 | """ 137 | embedded_premise = util.get_encoder_input(self._text_field_embedder, premise, self._encoder_keys) 138 | embedded_hypothesis = util.get_encoder_input(self._text_field_embedder, hypothesis, self._encoder_keys) 139 | premise_as_args = util.get_pair2vec_word_embeddings(self.pair2vec, premise['pair2vec_tokens']) 140 | hypothesis_as_args = util.get_pair2vec_word_embeddings(self.pair2vec, hypothesis['pair2vec_tokens']) 141 | 142 | premise_mask = util.get_mask(premise, self._mask_key).float() 143 | hypothesis_mask = util.get_mask(hypothesis, self._mask_key).float() 144 | 145 | # apply dropout for LSTM 146 | if self.rnn_input_dropout: 147 | embedded_premise = self.rnn_input_dropout(embedded_premise) 148 | embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis) 149 | 150 | # encode premise and hypothesis 151 | encoded_premise = self._encoder(embedded_premise, premise_mask) 152 | encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask) 153 | 154 | 155 | # Shape: (batch_size, premise_length, hypothesis_length) 156 | similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis) 157 | 158 | # Shape: (batch_size, premise_length, hypothesis_length) 159 | p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask) 160 | # Shape: (batch_size, premise_length, embedding_dim) 161 | attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention) 162 | 163 | # Shape: (batch_size, hypothesis_length, premise_length) 164 | h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask) 165 | # Shape: (batch_size, hypothesis_length, embedding_dim) 166 | attended_premise = weighted_sum(encoded_premise, h2p_attention) 167 | 168 | # cross sequence embeddings 169 | ph_pair_embeddings = normalize(util.get_pair_embeddings(self.pair2vec, premise_as_args, hypothesis_as_args), dim=-1) 170 | hp_pair_embeddings = normalize(util.get_pair_embeddings(self.pair2vec, hypothesis_as_args, premise_as_args), dim=-1) 171 | if self._bidirectional_pair2vec: 172 | temp = torch.cat((ph_pair_embeddings, hp_pair_embeddings.transpose(1,2)), dim=-1) 173 | hp_pair_embeddings = torch.cat((hp_pair_embeddings, ph_pair_embeddings.transpose(1,2)), dim=-1) 174 | ph_pair_embeddings = temp 175 | # pair_embeddings = torch.cat((ph_pair_embeddings, hp_pair_embeddings.transpose(1,2)), dim=-1) 176 | # pair2vec masks 177 | pair2vec_premise_mask = 1 - (torch.eq(premise['pair2vec_tokens'], 0).long() + torch.eq(premise['pair2vec_tokens'], 1).long()) 178 | pair2vec_hypothesis_mask = 1 - (torch.eq(hypothesis['pair2vec_tokens'], 0).long() + torch.eq(hypothesis['pair2vec_tokens'], 1).long()) 179 | # re-normalize attention using pair2vec masks 180 | h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), pair2vec_premise_mask) 181 | p2h_attention = last_dim_softmax(similarity_matrix, pair2vec_hypothesis_mask) 182 | 183 | attended_hypothesis_pairs = self._pair2vec_dropout(weighted_sum(ph_pair_embeddings, p2h_attention)) * 
pair2vec_premise_mask.float().unsqueeze(-1) 184 | attended_premise_pairs = self._pair2vec_dropout(weighted_sum(hp_pair_embeddings, h2p_attention)) * pair2vec_hypothesis_mask.float().unsqueeze(-1) 185 | 186 | 187 | # the "enhancement" layer 188 | premise_enhanced = torch.cat( 189 | [encoded_premise, attended_hypothesis, 190 | encoded_premise - attended_hypothesis, 191 | encoded_premise * attended_hypothesis, 192 | attended_hypothesis_pairs], 193 | dim=-1 194 | ) 195 | hypothesis_enhanced = torch.cat( 196 | [encoded_hypothesis, attended_premise, 197 | encoded_hypothesis - attended_premise, 198 | encoded_hypothesis * attended_premise, 199 | attended_premise_pairs], 200 | dim=-1 201 | ) 202 | 203 | projected_enhanced_premise = self._projection_feedforward(premise_enhanced) 204 | projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced) 205 | 206 | # Run the inference layer 207 | if self.rnn_input_dropout: 208 | projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise) 209 | projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis) 210 | v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask) 211 | v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask) 212 | 213 | # The pooling layer -- max and avg pooling. 214 | # (batch_size, model_dim) 215 | v_a_max, _ = replace_masked_values( 216 | v_ai, premise_mask.unsqueeze(-1), -1e7 217 | ).max(dim=1) 218 | v_b_max, _ = replace_masked_values( 219 | v_bi, hypothesis_mask.unsqueeze(-1), -1e7 220 | ).max(dim=1) 221 | 222 | v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(premise_mask, 1, keepdim=True) 223 | v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(hypothesis_mask, 1, keepdim=True) 224 | 225 | # Now concat 226 | # (batch_size, model_dim * 2 * 4) 227 | v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1) 228 | 229 | # the final MLP -- apply dropout to input, and MLP applies to output & hidden 230 | if self.dropout: 231 | v = self.dropout(v) 232 | 233 | output_hidden = self._output_feedforward(v) 234 | label_logits = self._output_logit(output_hidden) 235 | label_probs = torch.nn.functional.softmax(label_logits, dim=-1) 236 | 237 | output_dict = {"label_logits": label_logits, "label_probs": label_probs} 238 | 239 | if label is not None: 240 | loss = self._loss(label_logits, label.long().view(-1)) 241 | self._accuracy(label_logits, label) 242 | output_dict["loss"] = loss 243 | 244 | return output_dict 245 | 246 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 247 | return { 248 | 'accuracy': self._accuracy.get_metric(reset), 249 | } 250 | 251 | -------------------------------------------------------------------------------- /embeddings/vocab.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | import array 3 | from collections import defaultdict 4 | from functools import partial 5 | import io 6 | import logging 7 | import os 8 | import zipfile 9 | 10 | import six 11 | from six.moves.urllib.request import urlretrieve 12 | import torch 13 | from tqdm import tqdm 14 | import tarfile 15 | 16 | from torchtext.utils import reporthook 17 | 18 | logger = logging.getLogger(__name__) 19 | # Monkey-patch because I trained with a newer version. 20 | # This can be removed once PyTorch 0.4.x is out. 
21 | # See https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560 22 | import torch._utils 23 | try: 24 | torch._utils._rebuild_tensor_v2 25 | except AttributeError: 26 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 27 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 28 | tensor.requires_grad = requires_grad 29 | tensor._backward_hooks = backward_hooks 30 | return tensor 31 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 32 | 33 | class Vocab(object): 34 | """Defines a vocabulary object that will be used to numericalize a field. 35 | 36 | Attributes: 37 | freqs: A collections.Counter object holding the frequencies of tokens 38 | in the data used to build the Vocab. 39 | stoi: A collections.defaultdict instance mapping token strings to 40 | numerical identifiers. 41 | itos: A list of token strings indexed by their numerical identifiers. 42 | """ 43 | def __init__(self, word_list, max_size=None, min_freq=1, specials=[''], 44 | vectors=None, unk_init=None, vectors_cache=None): 45 | """Create a Vocab object from a collections.Counter. 46 | 47 | Arguments: 48 | counter: collections.Counter object holding the frequencies of 49 | each value found in the data. 50 | max_size: The maximum size of the vocabulary, or None for no 51 | maximum. Default: None. 52 | min_freq: The minimum frequency needed to include a token in the 53 | vocabulary. Values less than 1 will be set to 1. Default: 1. 54 | specials: The list of special tokens (e.g., padding or eos) that 55 | will be prepended to the vocabulary in addition to an 56 | token. Default: [''] 57 | vectors: One of either the available pretrained vectors 58 | or custom pretrained vectors (see Vocab.load_vectors); 59 | or a list of aforementioned vectors 60 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 61 | to zero vectors; can be any function that takes in a Tensor and 62 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 63 | vectors_cache: directory for cached vectors. Default: '.vector_cache' 64 | """ 65 | 66 | self.itos = list(specials) + word_list 67 | self.specials = specials 68 | self.stoi = defaultdict(_default_unk_index) 69 | # stoi is simply a reverse dict for itos 70 | self.stoi.update({tok: i for i, tok in enumerate(self.itos)}) 71 | 72 | self.vectors = None 73 | if vectors is not None: 74 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 75 | print('Loaded from {} {}'.format(vectors, vectors_cache)) 76 | else: 77 | assert unk_init is None and vectors_cache is None 78 | 79 | def __eq__(self, other): 80 | if self.stoi != other.stoi: 81 | return False 82 | if self.itos != other.itos: 83 | return False 84 | if self.vectors != other.vectors: 85 | return False 86 | return True 87 | 88 | def __len__(self): 89 | return len(self.itos) 90 | 91 | def extend(self, v, sort=False): 92 | words = sorted(v.itos) if sort else v.itos 93 | for w in words: 94 | if w not in self.stoi: 95 | self.itos.append(w) 96 | self.stoi[w] = len(self.itos) - 1 97 | 98 | def load_vectors(self, vectors, **kwargs): 99 | """ 100 | Arguments: 101 | vectors: one of or a list containing instantiations of the 102 | GloVe, CharNGram, or Vectors classes. 
Alternatively, one 103 | of or a list of available pretrained vectors: 104 | charngram.100d 105 | fasttext.en.300d 106 | fasttext.simple.300d 107 | glove.42B.300d 108 | glove.840B.300d 109 | glove.twitter.27B.25d 110 | glove.twitter.27B.50d 111 | glove.twitter.27B.100d 112 | glove.twitter.27B.200d 113 | glove.6B.50d 114 | glove.6B.100d 115 | glove.6B.200d 116 | glove.6B.300d 117 | Remaining keyword arguments: Passed to the constructor of Vectors classes. 118 | """ 119 | if not isinstance(vectors, list): 120 | vectors = [vectors] 121 | for idx, vector in enumerate(vectors): 122 | if six.PY2 and isinstance(vector, str): 123 | vector = six.text_type(vector) 124 | if isinstance(vector, six.string_types): 125 | # Convert the string pretrained vector identifier 126 | # to a Vectors object 127 | if vector not in pretrained_aliases: 128 | raise ValueError( 129 | "Got string input vector {}, but allowed pretrained " 130 | "vectors are {}".format( 131 | vector, list(pretrained_aliases.keys()))) 132 | vectors[idx] = pretrained_aliases[vector](**kwargs) 133 | elif not isinstance(vector, Vectors): 134 | raise ValueError( 135 | "Got input vectors of type {}, expected str or " 136 | "Vectors object".format(type(vector))) 137 | 138 | tot_dim = sum(v.dim for v in vectors) 139 | self.vectors = torch.Tensor(len(self), tot_dim) 140 | for i, token in enumerate(self.itos): 141 | start_dim = 0 142 | for v in vectors: 143 | end_dim = start_dim + v.dim 144 | self.vectors[i][start_dim:end_dim] = v[token.strip()] 145 | start_dim = end_dim 146 | assert(start_dim == tot_dim) 147 | 148 | def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_): 149 | """ 150 | Set the vectors for the Vocab instance from a collection of Tensors. 151 | 152 | Arguments: 153 | stoi: A dictionary of string to the index of the associated vector 154 | in the `vectors` input argument. 155 | vectors: An indexed iterable (or other structure supporting __getitem__) that 156 | given an input index, returns a FloatTensor representing the vector 157 | for the token associated with the index. For example, 158 | vector[stoi["string"]] should return the vector for "string". 159 | dim: The dimensionality of the vectors. 160 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 161 | to zero vectors; can be any function that takes in a Tensor and 162 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 163 | """ 164 | self.vectors = torch.Tensor(len(self), dim) 165 | for i, token in enumerate(self.itos): 166 | wv_index = stoi.get(token, None) 167 | if wv_index is not None: 168 | self.vectors[i] = vectors[wv_index] 169 | else: 170 | self.vectors[i] = unk_init(self.vectors[i]) 171 | 172 | 173 | class SubwordVocab(Vocab): 174 | 175 | def __init__(self, counter, max_size=None, specials=[''], 176 | vectors=None, unk_init=torch.Tensor.zero_): 177 | """Create a revtok subword vocabulary from a collections.Counter. 178 | 179 | Arguments: 180 | counter: collections.Counter object holding the frequencies of 181 | each word found in the data. 182 | max_size: The maximum size of the subword vocabulary, or None for no 183 | maximum. Default: None. 184 | specials: The list of special tokens (e.g., padding or eos) that 185 | will be prepended to the vocabulary in addition to an 186 | token. 
187 | """ 188 | try: 189 | import revtok 190 | except ImportError: 191 | print("Please install revtok.") 192 | raise 193 | 194 | self.stoi = defaultdict(_default_unk_index) 195 | self.stoi.update({tok: i for i, tok in enumerate(specials)}) 196 | self.itos = specials 197 | 198 | self.segment = revtok.SubwordSegmenter(counter, max_size) 199 | 200 | max_size = None if max_size is None else max_size + len(self.itos) 201 | 202 | # sort by frequency/entropy, then alphabetically 203 | toks = sorted(self.segment.vocab.items(), 204 | key=lambda tup: (len(tup[0]) != 1, -tup[1], tup[0])) 205 | 206 | for tok, _ in toks: 207 | self.itos.append(tok) 208 | self.stoi[tok] = len(self.itos) - 1 209 | 210 | if vectors is not None: 211 | self.load_vectors(vectors, unk_init=unk_init) 212 | 213 | 214 | class Vectors(object): 215 | 216 | def __init__(self, name, cache=None, 217 | url=None, unk_init=None): 218 | """ 219 | Arguments: 220 | name: name of the file that contains the vectors 221 | cache: directory for cached vectors 222 | url: url for download if vectors not found in cache 223 | unk_init (callback): by default, initalize out-of-vocabulary word vectors 224 | to zero vectors; can be any function that takes in a Tensor and 225 | returns a Tensor of the same size 226 | """ 227 | cache = '.vector_cache' if cache is None else cache 228 | self.unk_init = torch.Tensor.zero_ if unk_init is None else unk_init 229 | self.cache(name, cache, url=url) 230 | 231 | def __getitem__(self, token): 232 | if token in self.stoi: 233 | return self.vectors[self.stoi[token]] 234 | else: 235 | return self.unk_init(torch.Tensor(1, self.dim)) 236 | 237 | def cache(self, name, cache, url=None): 238 | if os.path.isfile(name): 239 | path = name 240 | path_pt = os.path.join(cache, os.path.basename(name)) + '.pt' 241 | else: 242 | path = os.path.join(cache, name) 243 | path_pt = path + '.pt' 244 | 245 | if not os.path.isfile(path_pt): 246 | if not os.path.isfile(path) and url: 247 | logger.info('Downloading vectors from {}'.format(url)) 248 | if not os.path.exists(cache): 249 | os.makedirs(cache) 250 | dest = os.path.join(cache, os.path.basename(url)) 251 | if not os.path.isfile(dest): 252 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=dest) as t: 253 | try: 254 | urlretrieve(url, dest, reporthook=reporthook(t)) 255 | except KeyboardInterrupt as e: # remove the partial zip file 256 | os.remove(dest) 257 | raise e 258 | logger.info('Extracting vectors into {}'.format(cache)) 259 | ext = os.path.splitext(dest)[1][1:] 260 | if ext == 'zip': 261 | with zipfile.ZipFile(dest, "r") as zf: 262 | zf.extractall(cache) 263 | elif ext == 'gz': 264 | with tarfile.open(dest, 'r:gz') as tar: 265 | tar.extractall(path=cache) 266 | if not os.path.isfile(path): 267 | raise RuntimeError('no vectors found at {}'.format(path)) 268 | 269 | # str call is necessary for Python 2/3 compatibility, since 270 | # argument must be Python 2 str (Python 3 bytes) or 271 | # Python 3 str (Python 2 unicode) 272 | itos, vectors, dim = [], array.array(str('d')), None 273 | 274 | # Try to read the whole file with utf-8 encoding. 
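        # Sketch of the fallback below: a well-formed vector file is UTF-8 text
        # with one "word 0.1 0.2 ..." entry per line; if decoding the whole file
        # fails, we re-read it as raw bytes and decode each word individually,
        # skipping entries whose words are not valid UTF-8.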
275 | binary_lines = False 276 | try: 277 | with io.open(path, encoding="utf8") as f: 278 | lines = [line for line in f] 279 | # If there are malformed lines, read in binary mode 280 | # and manually decode each word from utf-8 281 | except: 282 | logger.warning("Could not read {} as UTF8 file, " 283 | "reading file as bytes and skipping " 284 | "words with malformed UTF8.".format(path)) 285 | with open(path, 'rb') as f: 286 | lines = [line for line in f] 287 | binary_lines = True 288 | 289 | logger.info("Loading vectors from {}".format(path)) 290 | for iline, line in tqdm(enumerate(lines), total=len(lines)): 291 | # Explicitly splitting on " " is important, so we don't 292 | # get rid of Unicode non-breaking spaces in the vectors. 293 | entries = line.rstrip().split(b" " if binary_lines else " ") 294 | 295 | word, entries = entries[0], entries[1:] 296 | if dim is None and len(entries) > 1: 297 | dim = len(entries) 298 | elif len(entries) == 1: 299 | logger.warning("Skipping token {} with 1-dimensional " 300 | "vector {}; likely a header".format(word, entries)) 301 | continue 302 | elif dim != len(entries): 303 | raise RuntimeError( 304 | "Vector for token {} at line {} has {} dimensions, but previously " 305 | "read vectors have {} dimensions. All vectors must have " 306 | "the same number of dimensions.".format(word, iline, len(entries), dim)) 307 | 308 | if binary_lines: 309 | try: 310 | if isinstance(word, six.binary_type): 311 | word = word.decode('utf-8') 312 | except: 313 | logger.info("Skipping non-UTF8 token {}".format(repr(word))) 314 | continue 315 | vectors.extend(float(x) for x in entries) 316 | itos.append(word) 317 | 318 | self.itos = itos 319 | self.stoi = {word: i for i, word in enumerate(itos)} 320 | self.vectors = torch.Tensor(vectors).view(-1, dim) 321 | self.dim = dim 322 | logger.info('Saving vectors to {}'.format(path_pt)) 323 | torch.save((self.itos, self.stoi, self.vectors, self.dim), path_pt) 324 | else: 325 | logger.info('Loading vectors from {}'.format(path_pt)) 326 | self.itos, self.stoi, self.vectors, self.dim = torch.load(path_pt) 327 | 328 | 329 | class GloVe(Vectors): 330 | url = { 331 | '42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 332 | '840B': 'http://nlp.stanford.edu/data/glove.840B.300d.zip', 333 | 'twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 334 | '6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 335 | } 336 | 337 | def __init__(self, name='840B', dim=300, **kwargs): 338 | url = self.url[name] 339 | name = 'glove.{}.{}d.txt'.format(name, str(dim)) 340 | super(GloVe, self).__init__(name, url=url, **kwargs) 341 | 342 | 343 | class FastText(Vectors): 344 | 345 | url_base = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.{}.vec' 346 | 347 | def __init__(self, language="en", **kwargs): 348 | url = self.url_base.format(language) 349 | name = os.path.basename(url) 350 | super(FastText, self).__init__(name, url=url, **kwargs) 351 | 352 | 353 | class CharNGram(Vectors): 354 | 355 | name = 'charNgram.txt' 356 | url = ('http://www.logos.t.u-tokyo.ac.jp/~hassy/publications/arxiv2016jmt/' 357 | 'jmt_pre-trained_embeddings.tar.gz') 358 | 359 | def __init__(self, **kwargs): 360 | super(CharNGram, self).__init__(self.name, url=self.url, **kwargs) 361 | 362 | def __getitem__(self, token): 363 | vector = torch.Tensor(1, self.dim).zero_() 364 | if token == "": 365 | return self.unk_init(vector) 366 | # These literals need to be coerced to unicode for Python 2 compatibility 367 | # when we try to join them with read ngrams 
from the files. 368 | chars = ['#BEGIN#'] + list(token) + ['#END#'] 369 | num_vectors = 0 370 | for n in [2, 3, 4]: 371 | end = len(chars) - n + 1 372 | grams = [chars[i:(i + n)] for i in range(end)] 373 | for gram in grams: 374 | gram_key = '{}gram-{}'.format(n, ''.join(gram)) 375 | if gram_key in self.stoi: 376 | vector += self.vectors[self.stoi[gram_key]] 377 | num_vectors += 1 378 | if num_vectors > 0: 379 | vector /= num_vectors 380 | else: 381 | vector = self.unk_init(vector) 382 | return vector 383 | 384 | 385 | def _default_unk_index(): 386 | return 0 387 | 388 | 389 | pretrained_aliases = { 390 | "charngram.100d": partial(CharNGram), 391 | "fasttext.en.300d": partial(FastText, language="en"), 392 | "fasttext.simple.300d": partial(FastText, language="simple"), 393 | "glove.42B.300d": partial(GloVe, name="42B", dim="300"), 394 | "glove.840B.300d": partial(GloVe, name="840B", dim="300"), 395 | "glove.twitter.27B.25d": partial(GloVe, name="twitter.27B", dim="25"), 396 | "glove.twitter.27B.50d": partial(GloVe, name="twitter.27B", dim="50"), 397 | "glove.twitter.27B.100d": partial(GloVe, name="twitter.27B", dim="100"), 398 | "glove.twitter.27B.200d": partial(GloVe, name="twitter.27B", dim="200"), 399 | "glove.6B.50d": partial(GloVe, name="6B", dim="50"), 400 | "glove.6B.100d": partial(GloVe, name="6B", dim="100"), 401 | "glove.6B.200d": partial(GloVe, name="6B", dim="200"), 402 | "glove.6B.300d": partial(GloVe, name="6B", dim="300") 403 | } 404 | """Mapping from string name to factory function""" 405 | -------------------------------------------------------------------------------- /embeddings/indexed_field.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | from collections import Counter, OrderedDict 3 | import six 4 | import torch 5 | from torch.autograd import Variable 6 | from tqdm import tqdm 7 | 8 | from torchtext.data.dataset import Dataset 9 | from torchtext.data.pipeline import Pipeline 10 | from torchtext.data.utils import get_tokenizer 11 | from torchtext.vocab import Vocab, SubwordVocab 12 | 13 | 14 | class RawField(object): 15 | """ Defines a general datatype. 16 | 17 | Every dataset consists of one or more types of data. For instance, a text 18 | classification dataset contains sentences and their classes, while a 19 | machine translation dataset contains paired examples of text in two 20 | languages. Each of these types of data is represented by an RawField object. 21 | An RawField object does not assume any property of the data type and 22 | it holds parameters relating to how a datatype should be processed. 23 | 24 | Attributes: 25 | preprocessing: The Pipeline that will be applied to examples 26 | using this field before creating an example. 27 | Default: None. 28 | postprocessing: A Pipeline that will be applied to a list of examples 29 | using this field before assigning to a batch. 30 | Function signature: (batch(list)) -> object 31 | Default: None. 32 | """ 33 | 34 | def __init__(self, preprocessing=None, postprocessing=None): 35 | self.preprocessing = preprocessing 36 | self.postprocessing = postprocessing 37 | 38 | def preprocess(self, x): 39 | """ Preprocess an example if the `preprocessing` Pipeline is provided. """ 40 | if self.preprocessing is not None: 41 | return self.preprocessing(x) 42 | else: 43 | return x 44 | 45 | def process(self, batch, *args, **kargs): 46 | """ Process a list of examples to create a batch. 47 | 48 | Postprocess the batch with user-provided Pipeline. 
49 | 
50 |         Args:
51 |             batch (list(object)): A list of objects from a batch of examples.
52 |         Returns:
53 |             object: Processed object given the input and custom
54 |                 postprocessing Pipeline.
55 |         """
56 |         if self.postprocessing is not None:
57 |             batch = self.postprocessing(batch)
58 |         return batch
59 | 
60 | 
61 | class Field(RawField):
62 |     """Defines a datatype together with instructions for converting to Tensor.
63 | 
64 |     Field class models common text processing datatypes that can be represented
65 |     by tensors. It holds a Vocab object that defines the set of possible values
66 |     for elements of the field and their corresponding numerical representations.
67 |     The Field object also holds other parameters relating to how a datatype
68 |     should be numericalized, such as a tokenization method and the kind of
69 |     Tensor that should be produced.
70 | 
71 |     If a Field is shared between two columns in a dataset (e.g., question and
72 |     answer in a QA dataset), then they will have a shared vocabulary.
73 | 
74 |     Attributes:
75 |         sequential: Whether the datatype represents sequential data. If False,
76 |             no tokenization is applied. Default: True.
77 |         use_vocab: Whether to use a Vocab object. If False, the data in this
78 |             field should already be numerical. Default: True.
79 |         init_token: A token that will be prepended to every example using this
80 |             field, or None for no initial token. Default: None.
81 |         eos_token: A token that will be appended to every example using this
82 |             field, or None for no end-of-sentence token. Default: None.
83 |         fix_length: A fixed length that all examples using this field will be
84 |             padded to, or None for flexible sequence lengths. Default: None.
85 |         tensor_type: The torch.Tensor class that represents a batch of examples
86 |             of this kind of data. Default: torch.LongTensor.
87 |         preprocessing: The Pipeline that will be applied to examples
88 |             using this field after tokenizing but before numericalizing. Many
89 |             Datasets replace this attribute with a custom preprocessor.
90 |             Default: None.
91 |         postprocessing: A Pipeline that will be applied to examples using
92 |             this field after numericalizing but before the numbers are turned
93 |             into a Tensor. The pipeline function takes the batch as a list,
94 |             the field's Vocab, and train (a bool).
95 |             Default: None.
96 |         lower: Whether to lowercase the text in this field. Default: False.
97 |         tokenize: The function used to tokenize strings using this field into
98 |             sequential examples. If "spacy", the SpaCy English tokenizer is
99 |             used. Default: str.split.
100 |         include_lengths: Whether to return a tuple of a padded minibatch and
101 |             a list containing the lengths of each example, or just a padded
102 |             minibatch. Default: False.
103 |         batch_first: Whether to produce tensors with the batch dimension first.
104 |             Default: False.
105 |         pad_token: The string token used as padding. Default: "<pad>".
106 |         unk_token: The string token used to represent OOV words. Default: "<unk>".
107 |         pad_first: Do the padding of the sequence at the beginning. Default: False.
108 |         truncate_first: Do the truncating of the sequence at the beginning. Default: False.
109 |     """
110 | 
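    # A minimal end-to-end sketch (hypothetical data) of how this class is used:
    #
    #     field = Field(lower=True, include_lengths=True)
    #     examples = [field.preprocess("The cat sat"), field.preprocess("A dog")]
    #     field.build_vocab(examples)
    #     batch, lengths = field.process(examples, device=-1, train=True)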
111 |     vocab_cls = Vocab
112 |     # Dictionary mapping PyTorch tensor types to the appropriate Python
113 |     # numeric type.
114 |     tensor_types = {
115 |         torch.FloatTensor: float,
116 |         torch.cuda.FloatTensor: float,
117 |         torch.DoubleTensor: float,
118 |         torch.cuda.DoubleTensor: float,
119 |         torch.HalfTensor: float,
120 |         torch.cuda.HalfTensor: float,
121 | 
122 |         torch.ByteTensor: int,
123 |         torch.cuda.ByteTensor: int,
124 |         torch.CharTensor: int,
125 |         torch.cuda.CharTensor: int,
126 |         torch.ShortTensor: int,
127 |         torch.cuda.ShortTensor: int,
128 |         torch.IntTensor: int,
129 |         torch.cuda.IntTensor: int,
130 |         torch.LongTensor: int,
131 |         torch.cuda.LongTensor: int
132 |     }
133 | 
134 |     def __init__(self, sequential=True, use_vocab=True, init_token=None,
135 |                  eos_token=None, fix_length=None, tensor_type=torch.LongTensor,
136 |                  preprocessing=None, postprocessing=None, lower=False,
137 |                  tokenize=(lambda s: s.split()), include_lengths=False,
138 |                  batch_first=False, pad_token="<pad>", unk_token="<unk>",
139 |                  pad_first=False, truncate_first=False):
140 |         self.sequential = sequential
141 |         self.use_vocab = use_vocab
142 |         self.init_token = init_token
143 |         self.eos_token = eos_token
144 |         self.unk_token = unk_token
145 |         self.fix_length = fix_length
146 |         self.tensor_type = tensor_type
147 |         self.preprocessing = preprocessing
148 |         self.postprocessing = postprocessing
149 |         self.lower = lower
150 |         self.tokenize = get_tokenizer(tokenize)
151 |         self.include_lengths = include_lengths
152 |         self.batch_first = batch_first
153 |         self.pad_token = pad_token if self.sequential else None
154 |         self.pad_first = pad_first
155 |         self.truncate_first = truncate_first
156 | 
157 |     def preprocess(self, x):
158 |         """Load a single example using this field, tokenizing if necessary.
159 | 
160 |         If the input is a Python 2 `str`, it will be converted to Unicode
161 |         first. If `sequential=True`, it will be tokenized. Then the input
162 |         will be optionally lowercased and passed to the user-provided
163 |         `preprocessing` Pipeline."""
164 |         if (six.PY2 and isinstance(x, six.string_types) and
165 |                 not isinstance(x, six.text_type)):
166 |             x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
167 |         if self.sequential and isinstance(x, six.text_type):
168 |             x = self.tokenize(x.rstrip('\n'))
169 |         if self.lower:
170 |             x = Pipeline(six.text_type.lower)(x)
171 |         if self.preprocessing is not None:
172 |             return self.preprocessing(x)
173 |         else:
174 |             return x
175 | 
176 |     def process(self, batch, device, train, indexed=False):
177 |         """ Process a list of examples to create a torch.Tensor.
178 | 
179 |         Pad, numericalize, and postprocess a batch and create a tensor.
180 | 
181 |         Args:
182 |             batch (list(object)): A list of objects from a batch of examples.
183 |         Returns:
184 |             torch.autograd.Variable: Processed object given the input
185 |                 and custom postprocessing Pipeline.
186 |         """
187 |         if not indexed:
188 |             padded = self.pad(batch)
189 |             tensor = self.numericalize(padded, device=device, train=train)
190 |         else:
191 |             padded = self.pad_indexed(batch)
192 |             tensor = self.numericalize_indexed(padded, device=device, train=train)
193 |         return tensor
194 | 
195 |     def pad(self, minibatch):
196 |         """Pad a batch of examples using this field.
197 | 
198 |         Pads to self.fix_length if provided, otherwise pads to the length of
199 |         the longest example in the batch. Prepends self.init_token and appends
200 |         self.eos_token if those attributes are not None. Returns a tuple of the
201 |         padded list and a list containing lengths of each example if
202 |         `self.include_lengths` is `True` and `self.sequential` is `True`, else just
203 |         returns the padded list. If `self.sequential` is `False`, no padding is applied.
204 |         """
205 |         minibatch = list(minibatch)
206 |         if not self.sequential:
207 |             return minibatch
208 |         if self.fix_length is None:
209 |             max_len = max(len(x) for x in minibatch)
210 |         else:
211 |             max_len = self.fix_length + (
212 |                 self.init_token, self.eos_token).count(None) - 2
213 |         padded, lengths = [], []
214 |         for x in minibatch:
215 |             if self.pad_first:
216 |                 padded.append(
217 |                     [self.pad_token] * max(0, max_len - len(x)) +
218 |                     ([] if self.init_token is None else [self.init_token]) +
219 |                     list(x[-max_len:] if self.truncate_first else x[:max_len]) +
220 |                     ([] if self.eos_token is None else [self.eos_token]))
221 |             else:
222 |                 padded.append(
223 |                     ([] if self.init_token is None else [self.init_token]) +
224 |                     list(x[-max_len:] if self.truncate_first else x[:max_len]) +
225 |                     ([] if self.eos_token is None else [self.eos_token]) +
226 |                     [self.pad_token] * max(0, max_len - len(x)))
227 |             lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
228 |         if self.include_lengths:
229 |             return (padded, lengths)
230 |         return padded
231 | 
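    # pad() in isolation (hypothetical tokens), assuming a field created as
    # Field(include_lengths=True) with the default pad_token "<pad>":
    #
    #     padded, lengths = field.pad([["the", "cat"], ["a", "dog", "ran"]])
    #     # padded  -> [["the", "cat", "<pad>"], ["a", "dog", "ran"]]
    #     # lengths -> [2, 3]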
232 |     def pad_indexed(self, minibatch):
233 |         """Pad a batch of pre-indexed examples using this field.
234 | 
235 |         Works like ``pad`` above, except that examples are already lists of
236 |         vocabulary indices, so ``init_token``, ``eos_token``, and the pad token
237 |         are mapped through ``self.vocab.stoi`` before being added. Pads to
238 |         self.fix_length if provided, otherwise pads to the length of the longest
239 |         example in the batch. Returns a tuple of the padded list and a list
240 |         containing lengths of each example if `self.include_lengths` is `True`
241 |         and `self.sequential` is `True`, else just returns the padded list. If
242 |         `self.sequential` is `False`, no padding is applied.
243 |         """
244 |         minibatch = list(minibatch)
245 |         if not self.sequential:
246 |             return minibatch
247 |         if self.fix_length is None:
248 |             max_len = max(len(x) for x in minibatch)
249 |         else:
250 |             max_len = self.fix_length + (
251 |                 self.init_token, self.eos_token).count(None) - 2
252 |         pad_id = self.vocab.stoi[self.pad_token]
253 |         padded, lengths = [], []
254 |         for x in minibatch:
255 |             if self.pad_first:
256 |                 padded.append(
257 |                     [pad_id] * max(0, max_len - len(x)) +
258 |                     ([] if self.init_token is None else [self.vocab.stoi[self.init_token]]) +
259 |                     list(x[-max_len:] if self.truncate_first else x[:max_len]) +
260 |                     ([] if self.eos_token is None else [self.vocab.stoi[self.eos_token]]))
261 |             else:
262 |                 padded.append(
263 |                     ([] if self.init_token is None else [self.vocab.stoi[self.init_token]]) +
264 |                     list(x[-max_len:] if self.truncate_first else x[:max_len]) +
265 |                     ([] if self.eos_token is None else [self.vocab.stoi[self.eos_token]]) +
266 |                     [pad_id] * max(0, max_len - len(x)))
267 |             lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
268 |         if self.include_lengths:
269 |             return (padded, lengths)
270 |         return padded
271 | 
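    # build_vocab (below) is typically called once on the training data; keyword
    # arguments are forwarded to the Vocab constructor, so e.g. (hypothetical)
    # field.build_vocab(train_dataset, vectors="glove.6B.100d") also attaches
    # pretrained vectors.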
272 |     def build_vocab(self, *args, **kwargs):
273 |         """Construct the Vocab object for this field from one or more datasets.
274 | 
275 |         Arguments:
276 |             Positional arguments: Dataset objects or other iterable data
277 |                 sources from which to construct the Vocab object that
278 |                 represents the set of possible values for this field. If
279 |                 a Dataset object is provided, all columns corresponding
280 |                 to this field are used; individual columns can also be
281 |                 provided directly.
282 |             Remaining keyword arguments: Passed to the constructor of Vocab.
283 |         """
284 |         counter = Counter()
285 |         sources = []
286 |         for arg in args:
287 |             if isinstance(arg, Dataset):
288 |                 sources += [getattr(arg, name) for name, field in
289 |                             arg.fields.items() if field is self]
290 |             else:
291 |                 sources.append(arg)
292 |         for data in sources:
293 |             for x in data:
294 |                 if not self.sequential:
295 |                     x = [x]
296 |                 counter.update(x)
297 |         specials = list(OrderedDict.fromkeys(
298 |             tok for tok in [self.unk_token, self.pad_token, self.init_token,
299 |                             self.eos_token]
300 |             if tok is not None))
301 |         self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
302 | 
303 |     def index(self, arr, train=True):
304 |         """Turn a batch of padded examples that use this field into lists of indices.
305 | 
306 |         Unlike ``numericalize``, this only maps tokens to vocabulary indices;
307 |         it does not build a tensor (``numericalize_indexed`` can do that for
308 |         pre-indexed data later).
309 | 
310 |         Arguments:
311 |             arr (List[List[str]], or tuple of (List[List[str]], List[int])):
312 |                 List of tokenized and padded examples, or tuple of List of
313 |                 tokenized and padded examples and List of lengths of each
314 |                 example if self.include_lengths is True.
315 |             train (boolean): Passed through to the postprocessing Pipeline,
316 |                 if one is set. Default: True.
317 |         """
318 |         if self.include_lengths and not isinstance(arr, tuple):
319 |             raise ValueError("Field has include_lengths set to True, but "
320 |                              "input data is not a tuple of "
321 |                              "(data batch, batch lengths).")
322 |         if isinstance(arr, tuple):
323 |             arr, lengths = arr
324 | 
325 |         if self.use_vocab:
326 |             if self.sequential:
327 |                 arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
328 |             else:
329 |                 arr = [self.vocab.stoi[x] for x in arr]
330 | 
331 |             if self.postprocessing is not None:
332 |                 arr = self.postprocessing(arr, self.vocab, train)
333 |         else:
334 |             if self.tensor_type not in self.tensor_types:
335 |                 raise ValueError(
336 |                     "Specified Field tensor_type {} can not be used with "
337 |                     "use_vocab=False because we do not know how to numericalize it. "
338 |                     "Please raise an issue at "
339 |                     "https://github.com/pytorch/text/issues".format(self.tensor_type))
340 |             numericalization_func = self.tensor_types[self.tensor_type]
341 |             # It doesn't make sense to explicitly coerce to a numeric type if
342 |             # the data is sequential, since it's unclear how to coerce padding tokens
343 |             # to a numeric type.
344 |             if not self.sequential:
345 |                 arr = [numericalization_func(x) if isinstance(x, six.string_types)
346 |                        else x for x in arr]
347 |             if self.postprocessing is not None:
348 |                 arr = self.postprocessing(arr, None, train)
349 | 
350 |         if self.include_lengths:
351 |             return arr, lengths  # plain Python lists; no tensor is built here
352 |         return arr
353 | 
354 |     def numericalize_indexed(self, arr, device=None, train=True):
355 |         """Turn a batch of pre-indexed examples that use this field into a Variable.
356 | 
357 |         If the field has include_lengths=True, a tensor of lengths will be
358 |         included in the return value.
359 | 
360 |         Arguments:
361 |             arr (List[List[int]], or tuple of (List[List[int]], List[int])):
362 |                 List of pre-indexed and padded examples, or tuple of List of
363 |                 pre-indexed and padded examples and List of lengths of each
364 |                 example if self.include_lengths is True.
365 |             device (-1 or None): Device to create the Variable's Tensor on.
366 |                 Use -1 for CPU and None for the currently active GPU device.
367 |                 Default: None.
368 | train (boolean): Whether the batch is for a training set. 369 | If False, the Variable will be created with volatile=True. 370 | Default: True. 371 | """ 372 | if self.include_lengths and not isinstance(arr, tuple): 373 | raise ValueError("Field has include_lengths set to True, but " 374 | "input data is not a tuple of " 375 | "(data batch, batch lengths).") 376 | if isinstance(arr, tuple): 377 | arr, lengths = arr 378 | lengths = torch.LongTensor(lengths) 379 | arr = self.tensor_type(arr) 380 | if self.sequential and not self.batch_first: 381 | arr.t_() 382 | if device == -1: 383 | if self.sequential: 384 | arr = arr.contiguous() 385 | else: 386 | arr = arr.cuda(device) 387 | if self.include_lengths: 388 | lengths = lengths.cuda(device) 389 | if self.include_lengths: 390 | return Variable(arr, volatile=not train), lengths 391 | return Variable(arr, volatile=not train) 392 | 393 | def numericalize(self, arr, device=None, train=True): 394 | """Turn a batch of examples that use this field into a Variable. 395 | 396 | If the field has include_lengths=True, a tensor of lengths will be 397 | included in the return value. 398 | 399 | Arguments: 400 | arr (List[List[str]], or tuple of (List[List[str]], List[int])): 401 | List of tokenized and padded examples, or tuple of List of 402 | tokenized and padded examples and List of lengths of each 403 | example if self.include_lengths is True. 404 | device (-1 or None): Device to create the Variable's Tensor on. 405 | Use -1 for CPU and None for the currently active GPU device. 406 | Default: None. 407 | train (boolean): Whether the batch is for a training set. 408 | If False, the Variable will be created with volatile=True. 409 | Default: True. 410 | """ 411 | if self.include_lengths and not isinstance(arr, tuple): 412 | raise ValueError("Field has include_lengths set to True, but " 413 | "input data is not a tuple of " 414 | "(data batch, batch lengths).") 415 | if isinstance(arr, tuple): 416 | arr, lengths = arr 417 | lengths = torch.LongTensor(lengths) 418 | 419 | if self.use_vocab: 420 | if self.sequential: 421 | arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] 422 | else: 423 | arr = [self.vocab.stoi[x] for x in arr] 424 | 425 | if self.postprocessing is not None: 426 | arr = self.postprocessing(arr, self.vocab, train) 427 | else: 428 | if self.tensor_type not in self.tensor_types: 429 | raise ValueError( 430 | "Specified Field tensor_type {} can not be used with " 431 | "use_vocab=False because we do not know how to numericalize it. " 432 | "Please raise an issue at " 433 | "https://github.com/pytorch/text/issues".format(self.tensor_type)) 434 | numericalization_func = self.tensor_types[self.tensor_type] 435 | # It doesn't make sense to explictly coerce to a numeric type if 436 | # the data is sequential, since it's unclear how to coerce padding tokens 437 | # to a numeric type. 
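        # e.g. float("3.14") succeeds for a non-sequential numeric field, while a
        # padded sequence would contain pad tokens that cannot be coerced.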
438 | if not self.sequential: 439 | arr = [numericalization_func(x) if isinstance(x, six.string_types) 440 | else x for x in arr] 441 | if self.postprocessing is not None: 442 | arr = self.postprocessing(arr, None, train) 443 | 444 | arr = self.tensor_type(arr) 445 | if self.sequential and not self.batch_first: 446 | arr.t_() 447 | if device == -1: 448 | if self.sequential: 449 | arr = arr.contiguous() 450 | else: 451 | arr = arr.cuda(device) 452 | if self.include_lengths: 453 | lengths = lengths.cuda(device) 454 | if self.include_lengths: 455 | return Variable(arr, volatile=not train), lengths 456 | return Variable(arr, volatile=not train) 457 | 458 | 459 | -------------------------------------------------------------------------------- /endtasks/bidaf_pair2vec.py: -------------------------------------------------------------------------------- 1 | import logging, os 2 | from typing import Any, Dict, List 3 | from overrides import overrides 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.nn.functional import nll_loss 7 | from torch.nn import Module, Linear, Sequential, ReLU 8 | from allennlp.data import Vocabulary 9 | from allennlp.models.model import Model 10 | from allennlp.modules import Seq2SeqEncoder, TimeDistributed, TextFieldEmbedder 11 | from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention 12 | from allennlp.nn import InitializerApplicator, util 13 | from allennlp.training.metrics import Average, BooleanAccuracy, CategoricalAccuracy, SquadEmAndF1 14 | from torch.nn.functional import normalize 15 | from endtasks import util as pair2vec_util 16 | from endtasks import squad2_eval 17 | from endtasks.modules import VariationalDropout as InputVariationalDropout 18 | 19 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 20 | 21 | 22 | @Model.register("bidaf-pair2vec") 23 | class BidafPair2Vec(Model): 24 | """ 25 | This class implements modified version of BiDAF 26 | (with self attention and residual layer, from Clark and Gardner ACL 17 paper) model as used in 27 | Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf]. 28 | 29 | In this set-up, a single instance is a dialog, list of question answer pairs. 30 | 31 | Parameters 32 | ---------- 33 | vocab : ``Vocabulary`` 34 | text_field_embedder : ``TextFieldEmbedder`` 35 | Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. 36 | phrase_layer : ``Seq2SeqEncoder`` 37 | The encoder (with its own internal stacking) that we will use in between embedding tokens 38 | and doing the bidirectional attention. 39 | span_start_encoder : ``Seq2SeqEncoder`` 40 | The encoder that we will use to incorporate span start predictions into the passage state 41 | before predicting span end. 42 | span_end_encoder : ``Seq2SeqEncoder`` 43 | The encoder that we will use to incorporate span end predictions into the passage state. 44 | dropout : ``float``, optional (default=0.2) 45 | If greater than 0, we will apply dropout with this probability after all encoders (pytorch 46 | LSTMs do not apply dropout to their last layer). 47 | num_context_answers : ``int``, optional (default=0) 48 | If greater than 0, the model will consider previous question answering context. 49 | max_span_length: ``int``, optional (default=0) 50 | Maximum token length of the output span. 
51 | """ 52 | 53 | def __init__(self, vocab: Vocabulary, 54 | text_field_embedder: TextFieldEmbedder, 55 | phrase_layer: Seq2SeqEncoder, 56 | residual_encoder: Seq2SeqEncoder, 57 | span_start_encoder: Seq2SeqEncoder, 58 | span_end_encoder: Seq2SeqEncoder, 59 | initializer: InitializerApplicator, 60 | dropout: float = 0.2, 61 | pair2vec_dropout: float = 0.15, 62 | max_span_length: int = 30, 63 | pair2vec_model_file: str = None, 64 | pair2vec_config_file: str = None 65 | ) -> None: 66 | super().__init__(vocab) 67 | self._max_span_length = max_span_length 68 | self._text_field_embedder = text_field_embedder 69 | self._phrase_layer = phrase_layer 70 | self._encoding_dim = phrase_layer.get_output_dim() 71 | 72 | self.pair2vec = pair2vec_util.get_pair2vec(pair2vec_config_file, pair2vec_model_file) 73 | self._pair2vec_dropout = torch.nn.Dropout(pair2vec_dropout) 74 | 75 | self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') 76 | 77 | # atten_dim = self._encoding_dim * 4 + 600 if ablation_type == 'attn_over_rels' else self._encoding_dim * 4 78 | atten_dim = self._encoding_dim * 4 + 600 79 | self._merge_atten = TimeDistributed(torch.nn.Linear(atten_dim, self._encoding_dim)) 80 | 81 | self._residual_encoder = residual_encoder 82 | 83 | self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y') 84 | 85 | self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3, 86 | self._encoding_dim)) 87 | 88 | self._span_start_encoder = span_start_encoder 89 | self._span_end_encoder = span_end_encoder 90 | 91 | self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1)) 92 | self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1)) 93 | self._squad_metrics = SquadEmAndF1() 94 | initializer(self) 95 | 96 | self._span_start_accuracy = CategoricalAccuracy() 97 | self._span_end_accuracy = CategoricalAccuracy() 98 | self._official_em = Average() 99 | self._official_f1 = Average() 100 | 101 | self._span_accuracy = BooleanAccuracy() 102 | self._variational_dropout = InputVariationalDropout(dropout) 103 | 104 | 105 | 106 | 107 | def forward(self, # type: ignore 108 | question: Dict[str, torch.LongTensor], 109 | passage: Dict[str, torch.LongTensor], 110 | spans: torch.IntTensor = None, 111 | metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: 112 | # pylint: disable=arguments-differ 113 | """ 114 | Parameters 115 | ---------- 116 | question : Dict[str, torch.LongTensor] 117 | From a ``TextField``. 118 | passage : Dict[str, torch.LongTensor] 119 | From a ``TextField``. The model assumes that this passage contains the answer to the 120 | question, and predicts the beginning and ending positions of the answer within the 121 | passage. 122 | span_start : ``torch.IntTensor``, optional 123 | From an ``IndexField``. This is one of the things we are trying to predict - the 124 | beginning position of the answer with the passage. This is an `inclusive` token index. 125 | If this is given, we will compute a loss that gets included in the output dictionary. 126 | span_end : ``torch.IntTensor``, optional 127 | From an ``IndexField``. This is one of the things we are trying to predict - the 128 | ending position of the answer with the passage. This is an `inclusive` token index. 129 | If this is given, we will compute a loss that gets included in the output dictionary. 
130 |         metadata : ``List[Dict[str, Any]]``, optional
131 |             If present, this should contain the question ID, original passage text, and token
132 |             offsets into the passage for each instance in the batch. We use this for computing
133 |             official metrics using the official SQuAD evaluation script. The length of this list
134 |             should be the batch size, and each dictionary should have the keys ``question_id``,
135 |             ``original_passage``, and ``token_offsets``, plus ``answer_texts`` whenever gold
136 |             answers are available for computing the official metrics.
137 | 
138 |         Returns
139 |         -------
140 |         An output dictionary consisting of the following keys, each a flat list with one
141 |         entry per instance in the batch.
142 | 
143 |         question_id : List[str]
144 |             A list of question ids.
145 |         best_span_str : List[str]
146 |             If sufficient metadata was provided for the instances in the batch, we also return the
147 |             string from the original passage that the model thinks is the best answer to the
148 |             question.
149 |         loss : torch.FloatTensor, optional
150 |             A scalar loss to be optimised.
151 |         """
152 |         span_start = None if spans is None else spans[:, 0, 0]
153 |         span_end = None if spans is None else spans[:, 0, 1]
154 |         pair2vec_question_tokens = question['pair2vec_tokens']
155 |         pair2vec_passage_tokens = passage['pair2vec_tokens']
156 |         del question['pair2vec_tokens']
157 |         del passage['pair2vec_tokens']
158 |         embedded_question = self._variational_dropout(self._text_field_embedder(question))
159 |         embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
160 | 
161 |         # Extended batch size takes into account batch size * num paragraphs
162 |         extended_batch_size = embedded_question.size(0)
163 |         passage_length = embedded_passage.size(1)
164 |         #question_mask = util.get_text_field_mask(question).float()
165 |         #passage_mask = util.get_text_field_mask(passage).float()
166 |         question_mask = pair2vec_util.get_mask(question, 'elmo').float()
167 |         passage_mask = pair2vec_util.get_mask(passage, 'elmo').float()
168 | 
169 |         # Phrase layer is the shared Bi-GRU in the paper
170 |         # (extended_batch_size, sequence_length, input_dim) -> (extended_batch_size, sequence_length, encoding_dim)
171 |         encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))
172 |         encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
173 |         batch_size, num_tokens, _ = encoded_passage.size()
174 |         encoding_dim = encoded_question.size(-1)
175 | 
176 |         # Shape: (extended_batch_size, passage_length, question_length)
177 |         passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
178 |         # Shape: (extended_batch_size, passage_length, question_length)
179 |         passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
180 |         # Shape: (extended_batch_size, passage_length, encoding_dim)
181 |         passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
182 | 
183 |         # We replace masked values with something really negative here, so they don't affect the
184 |         # max below.
185 |         masked_similarity = util.replace_masked_values(passage_question_similarity,
186 |                                                        question_mask.unsqueeze(1),
187 |                                                        -1e7)
188 | 
189 |         question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
190 |         question_passage_attention = util.masked_softmax(question_passage_similarity,
191 |                                                          passage_mask)
192 |         passage_as_args = pair2vec_util.get_pair2vec_word_embeddings(self.pair2vec, pair2vec_passage_tokens)
193 |         question_as_args = pair2vec_util.get_pair2vec_word_embeddings(self.pair2vec, pair2vec_question_tokens)
194 |         # get mask for padding (index 0) and unknown (index 1) tokens
195 |         pair2vec_passage_mask = 1 - (torch.eq(pair2vec_passage_tokens, 0).long() + torch.eq(pair2vec_passage_tokens, 1).long())
196 |         pair2vec_question_mask = 1 - (torch.eq(pair2vec_question_tokens, 0).long() + torch.eq(pair2vec_question_tokens, 1).long())
197 |         # normalize with a masked softmax
198 |         pair2vec_attention = util.last_dim_softmax(passage_question_similarity, pair2vec_question_mask)
199 |         # get relation embeddings for both argument orders
200 |         p2q_pairs = normalize(pair2vec_util.get_pair_embeddings(self.pair2vec, passage_as_args, question_as_args), dim=-1)
201 |         q2p_pairs = normalize(pair2vec_util.get_pair_embeddings(self.pair2vec, question_as_args, passage_as_args), dim=-1)
202 |         # attention over pair2vec relation embeddings
203 |         attended_question_relations = self._pair2vec_dropout(util.weighted_sum(p2q_pairs, pair2vec_attention))
204 |         attended_passage_relations = self._pair2vec_dropout(util.weighted_sum(q2p_pairs.transpose(1, 2), pair2vec_attention))
205 |         # zero out pair vectors at padding/unknown passage positions
206 |         attended_question_pairs = attended_question_relations * pair2vec_passage_mask.float().unsqueeze(-1)
207 |         attended_passage_pairs = attended_passage_relations * pair2vec_passage_mask.float().unsqueeze(-1)
208 |         attended_pairs = torch.cat((attended_question_pairs, attended_passage_pairs), dim=-1)
209 |         # Shape: (extended_batch_size, encoding_dim)
210 |         question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
211 |         tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(extended_batch_size,
212 |                                                                                     passage_length,
213 |                                                                                     encoding_dim)
214 | 
215 | 
216 |         # Shape: (extended_batch_size, passage_length, encoding_dim * 4 + 600)
217 |         final_merged_passage = torch.cat([encoded_passage,
218 |                                           passage_question_vectors,
219 |                                           encoded_passage * passage_question_vectors,
220 |                                           encoded_passage * tiled_question_passage_vector,
221 |                                           attended_pairs],
222 |                                          dim=-1)
223 | 
224 |         final_merged_passage = F.relu(self._merge_atten(final_merged_passage))
225 | 
226 |         residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
227 |                                                                           passage_mask))
228 |         self_attention_matrix = self._self_attention(residual_layer, residual_layer)
229 |         # Expand mask for self-attention
230 |         mask = (passage_mask.view(extended_batch_size, passage_length, 1) *
231 |                 passage_mask.view(extended_batch_size, 1, passage_length))
232 | 
233 |         self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
234 |         self_mask = self_mask.view(1, passage_length, passage_length)
235 |         mask = mask * (1 - self_mask)
236 | 
237 |         self_attention_probs = util.masked_softmax(self_attention_matrix, mask)
238 | 
239 |         # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
240 |         self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
241 |         self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
242 |                                          residual_layer * self_attention_vecs],
243 |                                         dim=-1)
244 |         residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))
245 | 
246 |         final_merged_passage = final_merged_passage + residual_layer
247 |         # Shape: (extended_batch_size, passage_length, encoding_dim)
248 |         final_merged_passage = self._variational_dropout(final_merged_passage)
249 |         start_rep = self._span_start_encoder(final_merged_passage, passage_mask)
250 |         span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)
251 | 
252 |         end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
253 |                                          passage_mask)
254 |         span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)
255 | 
256 |         span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
257 |         # Shape: (extended_batch_size, passage_length)
258 |         span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
259 | 
260 |         best_span = self._get_best_span(span_start_logits, span_end_logits, self._max_span_length)
261 | 
262 |         output_dict: Dict[str, Any] = {}
263 | 
264 |         # Compute the loss.
265 |         if span_start is not None:
266 |             loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.view(-1),
267 |                             ignore_index=-1)
268 |             loss += nll_loss(util.masked_log_softmax(span_end_logits,
269 |                              passage_mask), span_end.view(-1), ignore_index=-1)
270 |             # TODO: when multiple gold spans exist, select the best-matching one for the loss
271 |             output_dict["loss"] = loss
272 | 
273 |         # Compute F1 and prepare the output dictionary.
274 |         output_dict['best_span_str'] = []
275 |         output_dict['question_id'] = []
276 |         best_span_cpu = best_span.detach().cpu().numpy()
277 |         for i in range(batch_size):
278 |             passage_str = metadata[i]['original_passage']
279 |             offsets = metadata[i]['token_offsets']
280 |             predicted_span = tuple(best_span_cpu[i])
281 |             start_offset = offsets[predicted_span[0]][0]
282 |             end_offset = offsets[predicted_span[1]][1]
283 |             best_span_string = passage_str[start_offset:end_offset]
284 |             # if best_span_string == 'noanswertoken':
285 |             #     best_span_string = ''
286 |             # print(predicted_span, best_span_string)
287 |             output_dict['best_span_str'].append(best_span_string)
288 |             output_dict['question_id'].append(metadata[i]['question_id'])
289 | 
290 |             answer_texts = metadata[i].get('answer_texts', [])
291 |             exact_match = f1_score = 0
292 |             if answer_texts:
293 |                 exact_match = squad2_eval.metric_max_over_ground_truths(
294 |                     squad2_eval.compute_exact,
295 |                     best_span_string,
296 |                     answer_texts)
297 |                 f1_score = squad2_eval.metric_max_over_ground_truths(
298 |                     squad2_eval.compute_f1,
299 |                     best_span_string,
300 |                     answer_texts)
301 |             self._official_em(100 * exact_match)
302 |             self._official_f1(100 * f1_score)
303 |         return output_dict
304 | 
305 |     @overrides
306 |     def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
307 |         return output_dict
308 | 
309 |     def get_metrics(self, reset: bool = False) -> Dict[str, float]:
310 |         return {'em': self._official_em.get_metric(reset),
311 |                 'f1': self._official_f1.get_metric(reset)}
312 | 
313 |     @staticmethod
314 |     def _get_best_span(span_start_logits: torch.Tensor,
315 |                        span_end_logits: torch.Tensor,
316 |                        max_span_length: int) -> torch.Tensor:
317 |         # Returns the (start, end) token indices of the highest-scoring span whose length
318 |         # does not exceed ``max_span_length``, found with a single left-to-right scan.
319 | if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: 320 | raise ValueError("Input shapes must be (batch_size, passage_length)") 321 | batch_size, passage_length = span_start_logits.size() 322 | max_span_log_prob = [-1e20] * batch_size 323 | span_start_argmax = [0] * batch_size 324 | 325 | best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long) 326 | 327 | span_start_logits = span_start_logits.data.cpu().numpy() 328 | span_end_logits = span_end_logits.data.cpu().numpy() 329 | for b_i in range(batch_size): # pylint: disable=invalid-name 330 | for j in range(passage_length): 331 | val1 = span_start_logits[b_i, span_start_argmax[b_i]] 332 | if val1 < span_start_logits[b_i, j]: 333 | span_start_argmax[b_i] = j 334 | val1 = span_start_logits[b_i, j] 335 | val2 = span_end_logits[b_i, j] 336 | if val1 + val2 > max_span_log_prob[b_i]: 337 | if j - span_start_argmax[b_i] > max_span_length: 338 | continue 339 | best_word_span[b_i, 0] = span_start_argmax[b_i] 340 | best_word_span[b_i, 1] = j 341 | max_span_log_prob[b_i] = val1 + val2 342 | return best_word_span 343 | 344 | --------------------------------------------------------------------------------
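A note on the span decoder in `_get_best_span` above: for each candidate end index `j`, only the best start index seen so far can participate in the best span ending at `j`, so a single left-to-right scan suffices and the search is O(passage_length) per example instead of O(passage_length ** 2) over all start/end pairs. The sketch below is a minimal, self-contained NumPy re-implementation of the same loop for one example; the function name `best_span` and the toy logits are illustrative only, not part of this repository.

    import numpy as np

    def best_span(start_logits, end_logits, max_span_length=30):
        # start_logits, end_logits: 1-D arrays of per-token logits for one passage.
        best_score = -np.inf
        best = (0, 0)
        start_argmax = 0
        for j in range(len(end_logits)):
            # Track the best span start seen up to position j.
            if start_logits[j] > start_logits[start_argmax]:
                start_argmax = j
            score = start_logits[start_argmax] + end_logits[j]
            # Accept the span only if it beats the running best and fits the
            # length cap, mirroring the `continue` guard in `_get_best_span`.
            if score > best_score and j - start_argmax <= max_span_length:
                best_score = score
                best = (start_argmax, j)
        return best

    # Toy example: start peaks at index 1, end peaks at index 3 -> span (1, 3).
    print(best_span(np.array([0.1, 2.0, 0.3, 0.2]), np.array([0.0, 0.5, 0.1, 1.5])))

Like the loop in the model, this sketch skips a candidate end when the running best start would make the span too long, rather than re-searching earlier starts, so on rare inputs the result can differ from an exact argmax over all length-capped spans.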