├── eval-script ├── log.py ├── utils.py ├── patterns.py ├── dataset.py ├── mlm.py └── evaluate.py └── README.md /eval-script/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | names = set() 4 | 5 | 6 | def __setup_custom_logger(name: str) -> logging.Logger: 7 | formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(module)s - %(message)s') 8 | 9 | names.add(name) 10 | 11 | handler = logging.StreamHandler() 12 | handler.setFormatter(formatter) 13 | 14 | logger = logging.getLogger(name) 15 | logger.setLevel(logging.INFO) 16 | logger.addHandler(handler) 17 | return logger 18 | 19 | 20 | def get_logger(name: str) -> logging.Logger: 21 | if name in names: 22 | return logging.getLogger(name) 23 | else: 24 | return __setup_custom_logger(name) 25 | -------------------------------------------------------------------------------- /eval-script/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import itertools 4 | import torch 5 | import io 6 | 7 | import log 8 | 9 | logger = log.get_logger('root') 10 | 11 | 12 | def pairwise(iterable): 13 | a, b = itertools.tee(iterable) 14 | next(b, None) 15 | return zip(a, b) 16 | 17 | 18 | def load_embeddings(embd_file: str) -> Dict[str, torch.Tensor]: 19 | logger.info('Loading embeddings from {}'.format(embd_file)) 20 | embds = {} 21 | with io.open(embd_file, 'r', encoding='utf8') as f: 22 | for line in f: 23 | comps = line.split() 24 | word = comps[0] 25 | embd = [float(x) for x in comps[1:]] 26 | embds[word] = torch.tensor(embd) 27 | logger.info('Found {} embeddings'.format(len(embds))) 28 | return embds 29 | -------------------------------------------------------------------------------- /eval-script/patterns.py: -------------------------------------------------------------------------------- 1 | from dataset import ANTONYM, COHYPONYM, HYPERNYM, CORRUPTION, RELATIONS, AnnotatedWord 2 | 3 | # TODO [MASK] -> tokenizer mask token!! 4 | 5 | WORD_TOKEN = '' 6 | MASK_TOKEN = '[MASK]' 7 | 8 | def get_patterns(word: AnnotatedWord, relation: str): 9 | if relation == ANTONYM: 10 | return get_patterns_antonym(word) 11 | if relation == COHYPONYM: 12 | return get_patterns_cohyponym(word) 13 | if relation == HYPERNYM: 14 | return get_patterns_hypernym(word) 15 | if relation == CORRUPTION: 16 | return get_patterns_corruption(word) 17 | raise ValueError("No patterns found for relation {}".format(relation)) 18 | 19 | 20 | def get_patterns_antonym(_): 21 | return [ 22 | ' is the opposite of [MASK]', 23 | ' is not [MASK]', 24 | 'someone who is is not [MASK]', 25 | 'something that is is not [MASK]', 26 | '" " is the opposite of " [MASK] "' 27 | ] 28 | 29 | 30 | def get_patterns_hypernym(word): 31 | article = _get_article(word) 32 | 33 | return [ 34 | ' is a [MASK]', 35 | ' is an [MASK]', 36 | article + ' is a [MASK]', 37 | article + ' is an [MASK]', 38 | '" " refers to a [MASK]', 39 | '" " refers to an [MASK]', 40 | ' is a kind of [MASK]', 41 | article + ' is a kind of [MASK]' 42 | ] 43 | 44 | 45 | def get_patterns_cohyponym(_): 46 | return [ 47 | ' and [MASK]', 48 | '" " and " [MASK] "' 49 | ] 50 | 51 | 52 | def get_patterns_corruption(_): 53 | return [ 54 | '" " is a misspelling of " [MASK] " .', 55 | '" " . did you mean " [MASK] " ?' 56 | ] 57 | 58 | 59 | def _get_article(word): 60 | if word.word[0] in ['a', 'e', 'i', 'o', 'u']: 61 | return 'an' 62 | return 'a' 63 | 64 | 65 | if __name__ == '__main__': 66 | 67 | dummy = AnnotatedWord('dummy', pos='n', freq=1, count=1) 68 | 69 | for rel in RELATIONS: 70 | try: 71 | print('=== {} patterns ({}) ==='.format(rel, len(get_patterns(dummy, rel)))) 72 | for p in get_patterns(dummy, rel): 73 | print(p) 74 | except ValueError: 75 | print('=== no patterns found for relation {} ==='.format(rel)) 76 | print('') 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WordNet Language Model Probing 2 | 3 | This repository contains the **WordNet Language Model Probing** (WNLaMPro) dataset. Each line of the dataset file (`dataset/WNLaMPro.txt`) has the following form (note that all columns are separated by tabs rather than spaces): 4 | 5 | ... 6 | 7 | The columns have the following meaning: 8 | 9 | - ``: A unique identifier for this dataset entry 10 | - ``: Either `test` or `dev`, depending on whether this entry belongs to the development or test subset of PSR 11 | - ``: The key word in the `` format (see below) 12 | - ``: The relation of this entry, either `antonym`, `hypernym`, `cohyponym` or `corruption` 13 | - ``: The `n`-th target word for this dataset entry, in the `` format (see below) 14 | 15 | ### Annotated Words 16 | 17 | Each key and target word of the WNLaMPro dataset is represented as an `` in the following form: 18 | 19 | := (,,) 20 | 21 | The columns have the following meaning: 22 | 23 | - ``: The actual word 24 | - ``: The part-of-speech tag for this word (either `n`oun or `a`djective) 25 | - ``: The estimated Zipf frequency for this word, obtained using [wordfreq](https://pypi.org/project/wordfreq/) 26 | - ``: The number of occurrences of this word in the [Westbury Wikipedia corpus](http://www.psych.ualberta.ca/~westburylab/downloads/westburylab.wikicorp.download.html) 27 | 28 | ## Evaluation Script 29 | 30 | You can evaluate a pretrained language model on WNLaMPro as follows: 31 | ``` 32 | python3 eval-script/evaluate.py --root ROOT --predictions_file PREDICTIONS_FILE --output_file OUTPUT_FILE --model_cls MODEL_CLS --model_name MODEL_NAME (--embeddings EMBEDDINGS) 33 | ``` 34 | where 35 | - `ROOT` is the path to the directory where `WNLaMPro.txt` can be found; 36 | - `PREDICTIONS_FILE` is the name of the file in which predictions are to be stored (relative to `ROOT`); 37 | - `OUTPUT_FILE` is the name of the file in which the model's MRR is to be stored (relative to `ROOT`); 38 | - `MODEL_CLS` is either `bert` or `roberta` (the evaluation script currently does not support other pretrained language models); 39 | - `MODEL_NAME` is either the name of a pretrained model from the [Hugging Face Transformers Library](https://github.com/huggingface/transformers) (e.g., `bert-base-uncased`) or the path to a finetuned model; 40 | - `EMBEDDINGS` (optional) is the path (relative to `ROOT`) of a file that contains embeddings which are used to overwrite the language model's original embeddings. Each line of this file has to be in the format ` `, for example `apple -0.12 3.45 0.23 ... 0.03`. 41 | 42 | For additional parameters, check the content of `eval-script/evaluate.py` or run `python3 eval-script/evaluate.py --help`. 43 | 44 | ## Citation 45 | 46 | If you make use of the WNLaMPro dataset, please cite the following paper: 47 | 48 | ``` 49 | @inproceedings{schick2020rare, 50 | title={Rare words: A major problem for contextualized representation and how to fix it by attentive mimicking}, 51 | author={Schick, Timo and Sch{\"u}tze, Hinrich}, 52 | url="https://arxiv.org/abs/1904.06707", 53 | booktitle={Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence}, 54 | year={2020} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /eval-script/dataset.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | from typing import Tuple, List 3 | import log 4 | 5 | logger = log.get_logger('root') 6 | 7 | SYNONYM = 'synonym' 8 | ANTONYM = 'antonym' 9 | HYPERNYM = 'hypernym' 10 | COHYPONYM = 'cohyponym' 11 | CORRUPTION = 'corruption' 12 | 13 | TEST = 'test' 14 | DEV = 'dev' 15 | 16 | RELATIONS = [SYNONYM, ANTONYM, HYPERNYM, COHYPONYM, CORRUPTION] 17 | 18 | 19 | class AnnotatedWord: 20 | def __init__(self, word, pos=None, freq=None, count=None): 21 | self.word = word 22 | self.pos = pos 23 | self.freq = freq 24 | self.count = count 25 | 26 | def __repr__(self): 27 | return '{} ({},{},{})'.format(self.word, self.pos, self.freq, self.count) 28 | 29 | 30 | class DatasetEntry: 31 | def __init__(self, base_word: AnnotatedWord, relation: str, matching_words: List[AnnotatedWord], 32 | set_type: str = TEST, eid=None): 33 | if relation not in RELATIONS: 34 | raise ValueError('Relation must be one of {}, got {}'.format(RELATIONS, relation)) 35 | self.base_word = base_word 36 | self.relation = relation 37 | self.matching_words = matching_words 38 | self.set_type = set_type 39 | self.id = eid 40 | 41 | def __repr__(self): 42 | return '{}({}): {} <{}> {}'.format(self.id, self.set_type, self.base_word, self.relation, self.matching_words) 43 | 44 | 45 | class Dataset(list): 46 | 47 | def select(self, pos: str = None, relation: str = None, freq: Tuple[int, int] = None, 48 | count: Tuple[int, int] = None, set_type: str = None) -> 'Dataset': 49 | ret = Dataset() 50 | for entry in self: 51 | if pos and entry.base_word.pos != pos: 52 | continue 53 | if relation and entry.relation != relation: 54 | continue 55 | if freq: 56 | if entry.base_word.freq < freq[0]: 57 | continue 58 | if 0 < freq[1] <= entry.base_word.freq: 59 | continue 60 | if count: 61 | if entry.base_word.count < count[0]: 62 | continue 63 | if 0 < count[1] <= entry.base_word.count: 64 | continue 65 | if set_type and entry.set_type != set_type: 66 | continue 67 | ret.append(entry) 68 | 69 | return ret 70 | 71 | def print_statistics(self, relations=None, set_types=None, counts=None) -> None: 72 | if relations is None: 73 | relations = RELATIONS 74 | 75 | if set_types is None: 76 | set_types = [DEV, TEST] 77 | 78 | if counts is None: 79 | counts = [(0, 9), (10, 99), (100, -1)] 80 | 81 | for relation in relations: 82 | for set_type in set_types: 83 | 84 | for (min_count, max_count) in counts: 85 | ds_subset = self.select(relation=relation, set_type=set_type, count=(min_count, max_count)) 86 | 87 | if not ds_subset: 88 | continue 89 | matching_words = [len(x.matching_words) for x in ds_subset] 90 | 91 | logger.info('{} - {} ({},{}): size = {}, mean targets = {}, median targets = {}'.format( 92 | relation, set_type, min_count, max_count, len(ds_subset), 93 | statistics.mean(matching_words), statistics.median(matching_words) 94 | )) 95 | 96 | 97 | def _string_to_entry(estr: str) -> DatasetEntry: 98 | cmps = estr.split('\t') 99 | eid = cmps[0] 100 | set_type = cmps[1] 101 | base_word = _string_to_annotated_word(cmps[2]) 102 | relation = cmps[3] 103 | matching_words = [_string_to_annotated_word(w) for w in cmps[4:]] 104 | return DatasetEntry(base_word, relation, matching_words, eid=eid, set_type=set_type) 105 | 106 | 107 | def _string_to_annotated_word(astr: str) -> AnnotatedWord: 108 | # e.g. disease (n,4.95,64105) 109 | word, meta_info = astr.split() 110 | meta_info = meta_info[1:-1] 111 | pos, freq, count = meta_info.split(',') 112 | freq = float(freq) 113 | count = int(count) 114 | return AnnotatedWord(word, pos=pos, freq=freq, count=count) 115 | 116 | 117 | def file_to_dataset(path: str, keep_frequent_corruptions: bool = False) -> Dataset: 118 | logger.info('Loading dataset from {}'.format(path)) 119 | ret = Dataset() 120 | with open(path, 'r', encoding='utf8') as file: 121 | for line in file: 122 | entry = _string_to_entry(line) 123 | if entry.relation == CORRUPTION and entry.base_word.count >= 10 and not keep_frequent_corruptions: 124 | continue 125 | ret.append(entry) 126 | logger.info('Done loading dataset') 127 | return ret 128 | -------------------------------------------------------------------------------- /eval-script/mlm.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional, Dict 3 | 4 | import torch 5 | from torch.nn import Module, Embedding 6 | from transformers import BertTokenizer, RobertaTokenizer, BertForMaskedLM, RobertaForMaskedLM, GPT2Tokenizer 7 | 8 | import log 9 | from patterns import WORD_TOKEN, MASK_TOKEN 10 | 11 | logger = log.get_logger('root') 12 | 13 | 14 | class OverwriteableEmbedding(Module): 15 | 16 | def __init__(self, embedding: Embedding, overwrite_fct=None): 17 | super().__init__() 18 | self.embedding = embedding 19 | self.overwrite_fct = overwrite_fct 20 | 21 | def forward(self, inp: torch.Tensor): 22 | embds = self.embedding(inp) 23 | if self.overwrite_fct is not None: 24 | embds = self.overwrite_fct(embds) 25 | return embds 26 | 27 | 28 | class AbstractMaskedLanguageModel(ABC): 29 | @abstractmethod 30 | def get_predictions(self, pattern: str, base_word: str, num_predictions: int) -> List[str]: 31 | pass 32 | 33 | 34 | class MockMaskedLanguageModel(AbstractMaskedLanguageModel): 35 | def get_predictions(self, pattern: str, base_word: str, num_predictions: int) -> List[str]: 36 | return ['cat', 'dog', 'coffee', 'mouse', 'tree', 'apple', 'orange'] 37 | 38 | 39 | class BertMaskedLanguageModel(AbstractMaskedLanguageModel): 40 | tokenizer_cls = BertTokenizer 41 | model_cls = BertForMaskedLM 42 | model_str = 'bert' 43 | 44 | def __init__(self, model_name: str, embeddings: Optional[Dict[str, torch.Tensor]] = None): 45 | self.tokenizer = type(self).tokenizer_cls.from_pretrained(model_name) 46 | self.model = type(self).model_cls.from_pretrained(model_name) 47 | self.model.eval() 48 | 49 | word_embeddings = getattr(self.model, type(self).model_str).embeddings.word_embeddings 50 | getattr(self.model, type(self).model_str).embeddings.word_embeddings = OverwriteableEmbedding(word_embeddings) 51 | self.embeddings = embeddings 52 | 53 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 54 | self.model.to(self.device) 55 | 56 | if self.embeddings: 57 | for embedding in self.embeddings.values(): 58 | embedding.to(self.device) 59 | 60 | def get_predictions(self, pattern: str, base_word: str, num_predictions: int) -> List[str]: 61 | 62 | replace_base_word = self.embeddings and base_word in self.embeddings 63 | 64 | pattern = pattern.replace(MASK_TOKEN, self.tokenizer.mask_token) 65 | left_context, right_context = pattern.split(WORD_TOKEN) 66 | 67 | if not replace_base_word: 68 | model_input = self._prepare_text(''.join([left_context, base_word, right_context])) 69 | logger.debug('Inferring embedding for {} without replacement'.format(model_input['tokenized_text'])) 70 | 71 | else: 72 | model_input = self._prepare_text(' '.join([left_context, self.tokenizer.unk_token, right_context])) 73 | base_word_idx = model_input['tokenized_text'].index(self.tokenizer.unk_token) 74 | 75 | if model_input['tokenized_text'][base_word_idx] != self.tokenizer.unk_token: 76 | raise ValueError("Got wrong base_word_idx, word at position {} is {} and not [UNK]".format( 77 | base_word_idx, model_input['tokenized_text'][base_word_idx])) 78 | 79 | getattr(self.model, type(self).model_str).embeddings.word_embeddings.overwrite_fct \ 80 | = lambda embeddings: self._overwrite_embeddings(embeddings, base_word_idx, self.embeddings[base_word]) 81 | 82 | logger.debug( 83 | 'Inferring embedding for {} with replacement, base_word_idx = {}'.format(model_input['tokenized_text'], 84 | base_word_idx)) 85 | 86 | if len(model_input['masked_indices']) != 1: 87 | raise ValueError( 88 | 'The pattern must contain exactly one "{}", got "{}" with base word "{}"'.format( 89 | self.tokenizer.mask_token, pattern, base_word) 90 | ) 91 | 92 | with torch.no_grad(): 93 | predictions = self.model( 94 | input_ids=model_input['tokens'].to(self.device), 95 | token_type_ids=model_input['segments'].to(self.device) 96 | )[0] 97 | 98 | getattr(self.model, type(self).model_str).embeddings.word_embeddings.overwrite_fct = None 99 | predicted_tokens = [] 100 | 101 | for masked_index in model_input['masked_indices']: 102 | _, predicted_indices = torch.topk(predictions[0, masked_index], num_predictions) 103 | 104 | for i in range(len(predicted_indices)): 105 | predicted_index = predicted_indices[i] 106 | predicted_token = self.tokenizer.convert_ids_to_tokens([predicted_index.item()])[0] 107 | predicted_tokens.append(predicted_token) 108 | return predicted_tokens 109 | 110 | def _prepare_text(self, text, use_sep=True, use_cls=True, use_full_stop=True): 111 | if use_cls: 112 | text = self.tokenizer.cls_token + ' ' + text 113 | if use_full_stop and not text[len(text) - 1] in ['?', '.', '!']: 114 | text += '.' 115 | if use_sep: 116 | text += ' ' + self.tokenizer.sep_token 117 | 118 | if isinstance(self.tokenizer, GPT2Tokenizer): 119 | tokenized_text = self.tokenizer.tokenize(text, add_prefix_space=True) 120 | else: 121 | tokenized_text = self.tokenizer.tokenize(text) 122 | 123 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text) 124 | segments_ids = [0] * len(indexed_tokens) 125 | 126 | # Convert inputs to PyTorch tensors 127 | tokens_tensor = torch.tensor([indexed_tokens]) 128 | segments_tensors = torch.tensor([segments_ids]) 129 | 130 | # get all masked indices 131 | masked_indices = [i for i, x in enumerate(tokenized_text) if x == self.tokenizer.mask_token] 132 | 133 | return {'tokenized_text': tokenized_text, 134 | 'tokens': tokens_tensor, 135 | 'segments': segments_tensors, 136 | 'masked_indices': masked_indices} 137 | 138 | @staticmethod 139 | def _overwrite_embeddings(embeddings: torch.Tensor, index: int, replacement_embedding: torch.Tensor): 140 | # this function is currently not designed to work with more than one batch 141 | if embeddings.shape[0] != 1: 142 | raise ValueError('expected a batch of size 1 but found ' + str(embeddings.shape[0].item())) 143 | 144 | embeddings[0, index, :] = replacement_embedding 145 | return embeddings 146 | 147 | 148 | class RobertaMaskedLanguageModel(BertMaskedLanguageModel): 149 | tokenizer_cls = RobertaTokenizer 150 | model_cls = RobertaForMaskedLM 151 | model_str = 'roberta' 152 | 153 | def get_predictions(self, pattern: str, base_word: str, num_predictions: int) -> List[str]: 154 | predictions = super().get_predictions(pattern, base_word, num_predictions) 155 | return [w.replace('Ġ', '').lower() for w in predictions] 156 | 157 | 158 | if __name__ == '__main__': 159 | mlm_bert = BertMaskedLanguageModel('bert-base-uncased') 160 | mlm_roberta = RobertaMaskedLanguageModel('roberta-large') 161 | 162 | predictions_bert = mlm_bert.get_predictions(pattern="a is a [MASK]", base_word="lime", num_predictions=10) 163 | predictions_roberta = mlm_roberta.get_predictions(pattern="a is a [MASK]", base_word="lime", num_predictions=10) 164 | print(predictions_bert) 165 | print(predictions_roberta) 166 | -------------------------------------------------------------------------------- /eval-script/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import log 3 | import os 4 | import collections 5 | import time 6 | from collections import defaultdict 7 | from typing import List, Dict 8 | 9 | import jsonpickle 10 | 11 | import utils 12 | from dataset import file_to_dataset, DatasetEntry, SYNONYM, RELATIONS 13 | from patterns import get_patterns 14 | from mlm import BertMaskedLanguageModel, AbstractMaskedLanguageModel, RobertaMaskedLanguageModel 15 | 16 | logger = log.get_logger('root') 17 | 18 | MODELS = { 19 | 'bert': BertMaskedLanguageModel, 20 | 'roberta': RobertaMaskedLanguageModel, 21 | } 22 | 23 | 24 | class EntryResult: 25 | def __init__(self, entry: DatasetEntry, predictions: List[List[str]], rank: int, precision_at: Dict[int, float]): 26 | self.entry = entry 27 | self.predictions = predictions 28 | self.rank = rank 29 | self.precision_at = precision_at 30 | 31 | def to_tsv_str(self, use_rank: bool, precision_k_values: List[int]) -> str: 32 | predictions_str = [','.join(pred) for pred in self.predictions] 33 | predictions_str = ' '.join(predictions_str) 34 | 35 | ret = self.entry.base_word.word + '\t' \ 36 | + self.entry.relation + '\t' \ 37 | + ' '.join([w.word for w in self.entry.matching_words]) 38 | if use_rank: 39 | ret += '\t{}'.format(self.rank) 40 | for k in precision_k_values: 41 | ret += '\t{}'.format(self.precision_at[k]) 42 | ret += '\t' + predictions_str 43 | return ret 44 | 45 | @staticmethod 46 | def headline_tsv(use_rank: bool, precision_k_values: List[int]): 47 | headline = 'keyword\trelation\ttargets' 48 | if use_rank: 49 | headline += '\tMRR' 50 | for k in precision_k_values: 51 | headline += '\tP@{}'.format(k) 52 | headline += '\tpredictions' 53 | return headline 54 | 55 | @staticmethod 56 | def to_file(results: List['EntryResult'], use_rank: bool, precision_k_values: List[int], path: str) -> None: 57 | with open(path, 'w', encoding='utf8') as f: 58 | f.write(EntryResult.headline_tsv(use_rank, precision_k_values) + '\n') 59 | for res in results: 60 | f.write(res.to_tsv_str(use_rank, precision_k_values) + '\n') 61 | 62 | 63 | class Result: 64 | def __init__(self, mrr: float, precision_at: Dict[int, float], ranks: List[float] = None, 65 | precision_vals: Dict[int, List[float]] = None, entry_results: List[EntryResult] = None): 66 | self.mrr = mrr 67 | self.precision_at = precision_at 68 | self.ranks = ranks 69 | self.precision_vals = precision_vals 70 | self.entry_results = entry_results 71 | 72 | def stringify(self, use_mrr, precision_k_values): 73 | ret = '' 74 | if use_mrr: 75 | ret += '{:5.3f} '.format(self.mrr) 76 | for k in precision_k_values: 77 | ret += '{:5.3f} '.format(self.precision_at[k]) 78 | return ret 79 | 80 | @staticmethod 81 | def stringify_results(results: Dict[str, 'Result'], use_mrr=True, precision_k_values=None) -> str: 82 | space_for_name = max(len(key) for key in results.keys()) + 2 83 | ret = ' ' * space_for_name + Result.headline(use_mrr, precision_k_values) + '\n' 84 | 85 | for key in results: 86 | ret += (('{:' + str(space_for_name) + 's}').format(key) + 87 | results[key].stringify(use_mrr, precision_k_values)) + '\n' 88 | return ret 89 | 90 | @staticmethod 91 | def headline(use_mrr, precision_k_values) -> str: 92 | headline = '' 93 | if use_mrr: 94 | headline += 'MRR ' 95 | for k in precision_k_values: 96 | headline += 'P@{:<3d} '.format(k) 97 | return headline 98 | 99 | 100 | def evaluate_from_predictions(dataset: List[DatasetEntry], predictions: Dict[str, List[List[str]]], compute_mrr=True, 101 | precision_at=None, with_raw: bool = False) -> Result: 102 | if precision_at is None: 103 | precision_at = [3, 10, 100] 104 | 105 | reciprocal_ranks = [] 106 | precision_vals = defaultdict(list) 107 | entry_results = [] 108 | 109 | for idx, entry in enumerate(dataset): 110 | 111 | if entry.id not in predictions: 112 | logger.warning('Found no predictions for entry with id "{}"'.format(entry.id)) 113 | continue 114 | 115 | prediction = predictions[entry.id] 116 | actuals = [w.word for w in entry.matching_words] 117 | 118 | entry_rank = -1 119 | entry_precision_at = {} 120 | 121 | if compute_mrr: 122 | reciprocal_rank = get_reciprocal_rank(actuals, prediction) 123 | reciprocal_ranks.append(reciprocal_rank) 124 | entry_rank = 1 / reciprocal_rank if reciprocal_rank > 0 else 1000 125 | 126 | for k in precision_at: 127 | precision_at_k = get_precision_at(k, actuals, prediction) 128 | precision_vals[k].append(precision_at_k) 129 | entry_precision_at[k] = precision_at_k 130 | 131 | entry_result = EntryResult(entry, prediction, entry_rank, entry_precision_at) 132 | entry_results.append(entry_result) 133 | 134 | if idx % 100 == 0: 135 | logger.info('Done processing {} of {} entries'.format(idx + 1, len(dataset))) 136 | 137 | result = Result(0, {}, entry_results=entry_results) 138 | 139 | if compute_mrr: 140 | mrr = avg(reciprocal_ranks) 141 | result.mrr = mrr 142 | for k in precision_at: 143 | p_at_k = avg(precision_vals[k]) 144 | result.precision_at[k] = p_at_k 145 | 146 | if with_raw: 147 | result.ranks = [1 / r if r > 0 else 1000 for r in reciprocal_ranks] 148 | result.precision_vals = precision_vals 149 | return result 150 | 151 | 152 | def avg(l: List[float]): 153 | if len(l) == 0: 154 | logger.warning('Computing average of empty list, returning -1 instead') 155 | return -1 156 | return sum(l) / len(l) 157 | 158 | 159 | def get_reciprocal_rank(actuals: List[str], predictions: List[List[str]]): 160 | rr = 0 161 | for pattern_predictions in predictions: 162 | for idx, predicted_word in enumerate(pattern_predictions): 163 | if predicted_word in actuals: 164 | rr = max(1 / (idx + 1), rr) 165 | return rr 166 | 167 | 168 | def get_precision_at(k: int, actuals: List[str], predictions: List[List[str]]): 169 | matches = get_matches_at(k, actuals, predictions) 170 | return min(1, len(matches) / k) 171 | 172 | 173 | def get_matches_at(k: int, actuals: List[str], predictions: List[List[str]], use_sum_instead_of_max: bool = False): 174 | if use_sum_instead_of_max: 175 | all_predictions = set() 176 | for pattern_predictions in predictions: 177 | all_predictions.update(pattern_predictions[:k]) 178 | return get_matches(all_predictions, actuals) 179 | else: 180 | best_predictions = set() 181 | for pattern_predictions in predictions: 182 | pp_set = set(pattern_predictions[:k]) 183 | if len(get_matches(pp_set, actuals)) > len(get_matches(best_predictions, actuals)): 184 | best_predictions = pp_set 185 | return get_matches(best_predictions, actuals) 186 | 187 | 188 | def predictions_to_file(model: AbstractMaskedLanguageModel, dataset: List[DatasetEntry], num_predictions=100, 189 | out_path: str = None) -> None: 190 | if os.path.isfile(out_path): 191 | raise FileExistsError('File {} already exists'.format(out_path)) 192 | 193 | with open(out_path, 'w', encoding='utf-8') as out_file: 194 | 195 | t0 = time.time() 196 | 197 | for idx, entry in enumerate(dataset): 198 | if entry.relation == SYNONYM: 199 | continue 200 | 201 | entry_predictions = predictions_for_entry(model, entry, num_predictions) 202 | _write_predictions_to_file(out_file, entry.id, entry_predictions) 203 | 204 | if idx % 100 == 0: 205 | total_time = time.time() - t0 206 | time_per_entry = total_time / (idx + 1) 207 | remaining_entries = len(dataset) - (idx + 1) 208 | time_for_remaining_entries = remaining_entries * time_per_entry 209 | logger.info('Done processing {} of {} dataset entries, estimated remaining time: {}s'.format( 210 | idx + 1, len(dataset), time_for_remaining_entries)) 211 | 212 | 213 | def _write_predictions_to_file(file, entry_id, predictions: List[List[str]]) -> None: 214 | file.write(str(entry_id) + '\t' + jsonpickle.dumps(predictions) + '\n') 215 | 216 | 217 | def _load_predictions_from_file(path: str) -> Dict[str, List[List[str]]]: 218 | logger.info("Loading model predictions from {}".format(path)) 219 | predictions = {} 220 | with open(path, 'r', encoding='utf-8') as file: 221 | for line in file: 222 | key, value = line.split('\t', 1) 223 | value_as_list = jsonpickle.decode(value) 224 | predictions[key] = value_as_list 225 | logger.info('Done loading model predictions') 226 | return predictions 227 | 228 | 229 | def predictions_for_entry(model: AbstractMaskedLanguageModel, entry: DatasetEntry, num_predictions=100): 230 | # get the corresponding patterns 231 | patterns = get_patterns(entry.base_word, entry.relation) 232 | predictions = [] 233 | 234 | for pattern in patterns: 235 | pattern_predictions = model.get_predictions(pattern, entry.base_word.word, num_predictions) 236 | predictions.append(pattern_predictions) 237 | 238 | return predictions 239 | 240 | 241 | def get_matches(predictions, actuals): 242 | return predictions.intersection(actuals) 243 | 244 | 245 | if __name__ == '__main__': 246 | 247 | parser = argparse.ArgumentParser() 248 | 249 | # file parameters 250 | parser.add_argument('--root', type=str, required=True) 251 | parser.add_argument('--dataset', type=str, default='WNLaMPro.txt') 252 | parser.add_argument('--predictions_file', type=str, default=None, required=True) 253 | parser.add_argument('--output_file', default=None, type=str) 254 | parser.add_argument('--raw_output_file', default=None, type=str) 255 | 256 | # parameters for computing new predictions 257 | parser.add_argument('--model_cls', choices=['bert', 'roberta'], default='bert') 258 | parser.add_argument('--model_name', type=str, default='bert-base-uncased') 259 | parser.add_argument('--embeddings', type=str, default=None) 260 | parser.add_argument('--num_predictions', type=int, default=100) 261 | 262 | # evaluation parameters 263 | parser.add_argument('--print_statistics', action='store_true', 264 | help='If set, instead of evaluating a model, statistics about the used dataset are printed.') 265 | parser.add_argument('--set_type', choices=['dev', 'test'], default=None) 266 | parser.add_argument('--count_thresholds', '-cs', type=int, nargs='*', default=[10, 100]) 267 | parser.add_argument('--min_subset_size', type=int, default=10) 268 | parser.add_argument('--precision_at', type=int, nargs='*', default=[3, 10, 100]) 269 | parser.add_argument('--keep_frequent_corruptions', action='store_true', 270 | help='If set, corruption entries with key word frequencies at or above 10 are kept') 271 | 272 | args = parser.parse_args() 273 | ds = file_to_dataset(os.path.join(args.root, args.dataset), 274 | keep_frequent_corruptions=args.keep_frequent_corruptions) 275 | 276 | if args.print_statistics: 277 | ds.print_statistics() 278 | 279 | predictions_file = os.path.join(args.root, args.predictions_file) 280 | 281 | if os.path.isfile(predictions_file): 282 | predictions = _load_predictions_from_file(predictions_file) 283 | 284 | else: 285 | logger.info('Found no precomputed predictions at {}'.format(predictions_file)) 286 | embeddings = None 287 | 288 | if args.embeddings: 289 | embeddings = utils.load_embeddings(os.path.join(args.root, args.embeddings)) 290 | 291 | model_cls = MODELS[args.model_cls] 292 | model = model_cls(args.model_name, embeddings) 293 | predictions_to_file(model, ds, args.num_predictions, predictions_file) 294 | predictions = _load_predictions_from_file(predictions_file) 295 | 296 | result = evaluate_from_predictions(ds.select(set_type=args.set_type), predictions) 297 | result_dict = collections.OrderedDict() 298 | result_dict['all_values'] = result 299 | 300 | count_thresholds = [0] + args.count_thresholds + [-1] 301 | 302 | for lower_bound, upper_bound in utils.pairwise(count_thresholds): 303 | ds_restricted = ds.select(count=(lower_bound, upper_bound), set_type=args.set_type) 304 | 305 | if len(ds_restricted) >= args.min_subset_size: 306 | result = evaluate_from_predictions(ds_restricted, predictions) 307 | result_dict['all_values ({},{})'.format(lower_bound, upper_bound)] = result 308 | 309 | if args.raw_output_file: 310 | EntryResult.to_file(result.entry_results, True, args.precision_at, 311 | os.path.join(args.root, 312 | args.raw_output_file + '-{}-{}'.format(lower_bound, upper_bound))) 313 | 314 | for rel in RELATIONS: 315 | for lower_bound, upper_bound in utils.pairwise(count_thresholds): 316 | ds_restricted = ds.select(relation=rel, count=(lower_bound, upper_bound), set_type=args.set_type) 317 | 318 | if len(ds_restricted) >= args.min_subset_size: 319 | result = evaluate_from_predictions(ds_restricted, predictions) 320 | result_dict['{} ({},{})'.format(rel, lower_bound, upper_bound)] = result 321 | 322 | results_str = Result.stringify_results(result_dict, True, args.precision_at) 323 | print(results_str) 324 | if args.output_file: 325 | with open(os.path.join(args.root, args.output_file), 'w', encoding='utf8') as f: 326 | f.write(results_str) 327 | 328 | if args.raw_output_file: 329 | EntryResult.to_file(result_dict['all_values'].entry_results, True, args.precision_at, 330 | os.path.join(args.root, args.raw_output_file)) 331 | --------------------------------------------------------------------------------