├── .gitignore
├── LICENSE
├── README.md
├── bin
│   ├── cloze_task_example.py
│   ├── combine_wordnet_embeddings.py
│   ├── create_pretraining_data_for_bert.py
│   ├── evaluate_mrr.py
│   ├── evaluate_perplexity.py
│   ├── evaluate_wiki_linking.py
│   ├── evaluate_wsd_official.py
│   ├── extract_wordnet.py
│   ├── preprocess_wsd.py
│   ├── run_hyperparameter_seeds.sh
│   ├── semeval2010_task8_scorer-v1.2.pl
│   ├── tacred_scorer.py
│   ├── write_semeval2010_task8_for_official_eval.py
│   ├── write_tacred_for_official_scorer.py
│   └── write_wic_for_codalab.py
├── kb
│   ├── __init__.py
│   ├── bert_pretraining_reader.py
│   ├── bert_tokenizer_and_candidate_generator.py
│   ├── bert_utils.py
│   ├── common.py
│   ├── dict_field.py
│   ├── entity_linking.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── classification_model.py
│   │   ├── exponential_average_metric.py
│   │   ├── fbeta_measure.py
│   │   ├── semeval2010_task8.py
│   │   ├── tacred_dataset_reader.py
│   │   ├── tacred_predictor.py
│   │   ├── ultra_fine_reader.py
│   │   ├── weighted_average.py
│   │   └── wic_dataset_reader.py
│   ├── include_all.py
│   ├── kg_embedding.py
│   ├── kg_probe_reader.py
│   ├── knowbert.py
│   ├── knowbert_utils.py
│   ├── metrics.py
│   ├── multitask.py
│   ├── self_attn_bucket_iterator.py
│   ├── span_attention_layer.py
│   ├── testing.py
│   ├── wiki_linking_reader.py
│   ├── wiki_linking_util.py
│   └── wordnet.py
├── requirements.txt
├── setup.py
├── tests
│   ├── __init__.py
│   ├── evaluation
│   │   ├── test_semeval2010_task8.py
│   │   ├── test_simple_classifier.py
│   │   ├── test_tacred_reader.py
│   │   ├── test_ultra_fine_reader.py
│   │   └── test_wic_reader.py
│   ├── fixtures
│   │   ├── bert
│   │   │   ├── bert_config.json
│   │   │   ├── bert_test_fixture.tar.gz
│   │   │   ├── vocab.txt
│   │   │   └── vocab_dir_with_entities_for_tokenizer_and_generator
│   │   │       ├── entity.txt
│   │   │       ├── non_padded_namespaces.txt
│   │   │       └── tokens.txt
│   │   ├── bert_pretraining
│   │   │   └── shard1.txt
│   │   ├── evaluation
│   │   │   ├── semeval2010_task8
│   │   │   │   ├── semeval2010_task8.json
│   │   │   │   └── vocab_entity_markers.txt
│   │   │   ├── ultra_fine
│   │   │   │   ├── train.json
│   │   │   │   └── vocab.txt
│   │   │   └── wic
│   │   │       ├── train.data.txt
│   │   │       ├── train.gold.txt
│   │   │       └── vocab_entity_markers.txt
│   │   ├── kg_embeddings
│   │   │   ├── tucker_wordnet
│   │   │   │   ├── model.tar.gz
│   │   │   │   └── vocabulary
│   │   │   │       ├── entity.txt
│   │   │   │       ├── non_padded_namespaces.txt
│   │   │   │       └── relation.txt
│   │   │   ├── wn18rr_dev.txt
│   │   │   └── wn18rr_train.txt
│   │   ├── kg_probe
│   │   │   └── file1.txt
│   │   ├── linking
│   │   │   ├── aida.txt
│   │   │   ├── entities_universe.txt
│   │   │   └── priors.txt
│   │   ├── multitask
│   │   │   ├── ccgbank.txt
│   │   │   └── conll2003.txt
│   │   ├── tacred
│   │   │   ├── LDC2018T24.json
│   │   │   └── vocab.txt
│   │   ├── wordnet
│   │   │   ├── cat_hat_mask_null_embedding.hdf5
│   │   │   ├── cat_hat_synset_mask_null_vocab.txt
│   │   │   ├── cat_hat_vocabdir
│   │   │   │   ├── entity.txt
│   │   │   │   ├── non_padded_namespaces.txt
│   │   │   │   └── tokens.txt
│   │   │   ├── entities_cat_hat.jsonl
│   │   │   ├── entities_fixture.jsonl
│   │   │   └── wsd_dataset.json
│   │   └── wordnet_wiki_vocab
│   │       ├── entity_wiki.txt
│   │       ├── entity_wordnet.txt
│   │       └── non_padded_namespaces.txt
│   ├── test_bert_pretraining_reader.py
│   ├── test_bert_tokenizer_and_candidate_generator.py
│   ├── test_common.py
│   ├── test_dict_field.py
│   ├── test_entity_linking.py
│   ├── test_kg_embedding.py
│   ├── test_kg_probe_reader.py
│   ├── test_knowbert.py
│   ├── test_metrics.py
│   ├── test_multitask.py
│   ├── test_self_attn_iterator.py
│   ├── test_span_attention_layer.py
│   ├── test_wiki_reader.py
│   └── test_wordnet.py
└── training_config
    ├── downstream
    │   ├── entity_typing.jsonnet
    │   ├── semeval2010_task8.jsonnet
    │   ├── tacred.jsonnet
    │   └── wic.jsonnet
    └── pretraining
        ├── knowbert_wiki.jsonnet
        ├── knowbert_wiki_linker.jsonnet
        ├── knowbert_wordnet.jsonnet
        ├── knowbert_wordnet_linker.jsonnet
        ├── knowbert_wordnet_wiki.jsonnet
        ├── 
knowbert_wordnet_wiki_linker.jsonnet └── wordnet_tucker.json /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | .vscode/ 3 | log/ 4 | *.pyc 5 | build 6 | kb.egg-info 7 | dist 8 | -------------------------------------------------------------------------------- /bin/cloze_task_example.py: -------------------------------------------------------------------------------- 1 | 2 | from kb.include_all import ModelArchiveFromParams 3 | from kb.knowbert_utils import KnowBertBatchifier 4 | from allennlp.common import Params 5 | 6 | import torch 7 | 8 | if __name__ == '__main__': 9 | archive_file = 'https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz' 10 | params = Params({"archive_file": archive_file}) 11 | 12 | # load model and batcher 13 | model = ModelArchiveFromParams.from_params(params=params) 14 | batcher = KnowBertBatchifier(archive_file, masking_strategy='full_mask') 15 | 16 | sentences = ["Paris is located in [MASK].", "La Mauricie National Park is located in [MASK]."] 17 | 18 | mask_id = batcher.tokenizer_and_candidate_generator.bert_tokenizer.vocab['[MASK]'] 19 | for batch in batcher.iter_batches(sentences): 20 | model_output = model(**batch) 21 | token_mask = batch['tokens']['tokens'] == mask_id 22 | 23 | # (batch_size, timesteps, vocab_size) 24 | prediction_scores, _ = model.pretraining_heads( 25 | model_output['contextual_embeddings'], model_output['pooled_output'] 26 | ) 27 | 28 | mask_token_probabilities = prediction_scores.masked_select(token_mask.unsqueeze(-1)).view(-1, prediction_scores.shape[-1]) # (num_masked_tokens, vocab_size) 29 | 30 | predicted_token_ids = mask_token_probabilities.argmax(dim=-1) 31 | 32 | predicted_tokens = [batcher.tokenizer_and_candidate_generator.bert_tokenizer.ids_to_tokens[int(i)] 33 | for i in predicted_token_ids] 34 | 35 | print(predicted_tokens) 36 | 37 | -------------------------------------------------------------------------------- /bin/combine_wordnet_embeddings.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import h5py 5 | 6 | from allennlp.models.archival import load_archive 7 | from kb.common import JsonFile 8 | 9 | 10 | # includes @@PADDING@@, @@UNKNOWN@@, @@MASK@@, @@NULL@@ 11 | NUM_EMBEDDINGS = 117663 12 | 13 | def generate_wordnet_synset_vocab(entity_file, vocab_file): 14 | vocab = ['@@UNKNOWN@@'] 15 | 16 | with JsonFile(entity_file, 'r') as fin: 17 | for node in fin: 18 | if node['type'] == 'synset': 19 | vocab.append(node['id']) 20 | 21 | vocab.append('@@MASK@@') 22 | vocab.append('@@NULL@@') 23 | 24 | with open(vocab_file, 'w') as fout: 25 | fout.write('\n'.join(vocab)) 26 | 27 | 28 | def extract_tucker_embeddings(tucker_archive, vocab_file, tucker_hdf5): 29 | archive = load_archive(tucker_archive) 30 | 31 | with open(vocab_file, 'r') as fin: 32 | vocab_list = fin.read().strip().split('\n') 33 | 34 | # get embeddings 35 | embed = archive.model.kg_tuple_predictor.entities.weight.detach().numpy() 36 | out_embeddings = np.zeros((NUM_EMBEDDINGS, embed.shape[1])) 37 | 38 | vocab = archive.model.vocab 39 | 40 | for k, entity in enumerate(vocab_list): 41 | embed_id = vocab.get_token_index(entity, 'entity') 42 | if entity in ('@@MASK@@', '@@NULL@@'): 43 | # these aren't in the tucker vocab -> random init 44 | out_embeddings[k + 1, :] = np.random.randn(1, embed.shape[1]) * 0.004 45 | elif entity != '@@UNKNOWN@@': 46 | assert embed_id != 1 47 | # k = 0 is 
@@UNKNOWN@@, and want it at index 1 in output 48 | out_embeddings[k + 1, :] = embed[embed_id, :] 49 | 50 | # write out to file 51 | with h5py.File(tucker_hdf5, 'w') as fout: 52 | ds = fout.create_dataset('tucker', data=out_embeddings) 53 | 54 | 55 | def get_gensen_synset_definitions(entity_file, vocab_file, gensen_file): 56 | from gensen import GenSen, GenSenSingle 57 | 58 | gensen_1 = GenSenSingle( 59 | model_folder='./data/models', 60 | filename_prefix='nli_large_bothskip', 61 | pretrained_emb='./data/embedding/glove.840B.300d.h5' 62 | ) 63 | gensen_1.eval() 64 | 65 | definitions = {} 66 | with open(entity_file, 'r') as fin: 67 | for line in fin: 68 | node = json.loads(line) 69 | if node['type'] == 'synset': 70 | definitions[node['id']] = node['definition'] 71 | 72 | with open(vocab_file, 'r') as fin: 73 | vocab_list = fin.read().strip().split('\n') 74 | 75 | # get the descriptions 76 | sentences = [''] * NUM_EMBEDDINGS 77 | for k, entity in enumerate(vocab_list): 78 | definition = definitions.get(entity) 79 | if definition is None: 80 | assert entity in ('@@UNKNOWN@@', '@@MASK@@', '@@NULL@@') 81 | else: 82 | sentences[k + 1] = definition 83 | 84 | embeddings = np.zeros((NUM_EMBEDDINGS, 2048), dtype=np.float32) 85 | for k in range(0, NUM_EMBEDDINGS, 32): 86 | sents = sentences[k:(k+32)] 87 | reps_h, reps_h_t = gensen_1.get_representation( 88 | sents, pool='last', return_numpy=True, tokenize=True 89 | ) 90 | embeddings[k:(k+32), :] = reps_h_t 91 | print(k) 92 | 93 | with h5py.File(gensen_file, 'w') as fout: 94 | ds = fout.create_dataset('gensen', data=embeddings) 95 | 96 | 97 | def combine_tucker_gensen(tucker_hdf5, gensen_hdf5, all_file): 98 | with h5py.File(tucker_hdf5, 'r') as fin: 99 | tucker = fin['tucker'][...] 100 | 101 | with h5py.File(gensen_hdf5, 'r') as fin: 102 | gensen = fin['gensen'][...] 
103 | 104 | all_embeds = np.concatenate([tucker, gensen], axis=1) 105 | all_e = all_embeds.astype(np.float32) 106 | 107 | with h5py.File(all_file, 'w') as fout: 108 | ds = fout.create_dataset('tucker_gensen', data=all_e) 109 | 110 | 111 | if __name__ == '__main__': 112 | import argparse 113 | 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--generate_wordnet_synset_vocab', default=False, action="store_true") 116 | parser.add_argument('--entity_file', type=str) 117 | parser.add_argument('--vocab_file', type=str) 118 | 119 | parser.add_argument('--generate_gensen_embeddings', default=False, action="store_true") 120 | parser.add_argument('--gensen_file', type=str) 121 | 122 | parser.add_argument('--extract_tucker', default=False, action="store_true") 123 | parser.add_argument('--tucker_archive_file', type=str) 124 | parser.add_argument('--tucker_hdf5_file', type=str) 125 | 126 | parser.add_argument('--combine_tucker_gensen', default=False, action="store_true") 127 | parser.add_argument('--all_embeddings_file', type=str) 128 | 129 | args = parser.parse_args() 130 | 131 | 132 | if args.generate_wordnet_synset_vocab: 133 | generate_wordnet_synset_vocab(args.entity_file, args.vocab_file) 134 | elif args.generate_gensen_embeddings: 135 | get_gensen_synset_definitions(args.entity_file, args.vocab_file, args.gensen_file) 136 | elif args.extract_tucker: 137 | extract_tucker_embeddings(args.tucker_archive_file, args.vocab_file, args.tucker_hdf5_file) 138 | elif args.combine_tucker_gensen: 139 | combine_tucker_gensen(args.tucker_hdf5_file, args.gensen_file, args.all_embeddings_file) 140 | else: 141 | raise ValueError 142 | 143 | -------------------------------------------------------------------------------- /bin/create_pretraining_data_for_bert.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import codecs 4 | 5 | import glob 6 | import time 7 | 8 | CACHE_SIZE = int(1e8) 9 | 10 | 11 | def read_files(file_glob): 12 | # Our memory: read self.cache_size tokens, from which we will generate BERT instances. 
13 | memory = [[]] 14 | n_tokens = 0 15 | 16 | t1 = time.time() 17 | all_file_names = glob.glob(file_glob) 18 | print("Found {} total files".format(len(all_file_names))) 19 | 20 | file_number = 0 21 | for fname in glob.glob(file_glob): 22 | file_number += 1 23 | print("reading file number {}, {}, total memory {}, total time {}".format(file_number, fname, n_tokens, time.time() - t1)) 24 | with codecs.open(fname, 'r', encoding='utf8') as open_file: 25 | for sentence in open_file: 26 | words = sentence.strip().split() 27 | 28 | # Empty lines are used as document delimiters 29 | if len(words) == 0: 30 | memory.append([]) 31 | else: 32 | memory[-1].append(words) 33 | n_tokens += len(words) 34 | 35 | if n_tokens > CACHE_SIZE: 36 | for document_index in range(len(memory)): 37 | for instance in gen_bert_instances(memory, document_index): 38 | yield instance 39 | 40 | n_tokens = 0 41 | memory = [[]] 42 | 43 | for document_index in range(len(memory)): 44 | for instance in gen_bert_instances(memory, document_index): 45 | yield instance 46 | 47 | 48 | def gen_bert_instances(all_documents, document_index): 49 | """ 50 | Create bert instances from a given document 51 | """ 52 | word_to_wordpiece_ratio = 1.33 53 | document = all_documents[document_index] 54 | 55 | # First, defining target sequence length 56 | # [[0.9, 128], [0.05, 256], [0.04, 384], [0.01, 512]] 57 | randnum = np.random.random() 58 | if randnum < 0.9: 59 | if np.random.random() < 0.1: 60 | target_seq_length = np.random.randint(2, 128) 61 | else: 62 | target_seq_length = 128 63 | elif randnum < 0.95: 64 | target_seq_length = 256 65 | elif randnum < 0.99: 66 | target_seq_length = 384 67 | else: 68 | target_seq_length = 512 69 | 70 | word_target_seq_length = target_seq_length / word_to_wordpiece_ratio 71 | 72 | # We DON'T just concatenate all of the tokens from a document into a long 73 | # sequence and choose an arbitrary split point because this would make the 74 | # next sentence prediction task too easy. Instead, we split the input into 75 | # segments "A" and "B" based on the actual "sentences" provided by the user 76 | # input. 77 | current_chunk = [] 78 | current_length = 0 79 | i = 0 80 | while i < len(document): 81 | segment = document[i] 82 | current_chunk.append(segment) 83 | 84 | current_length += len(segment) 85 | if i == len(document) - 1 or current_length >= word_target_seq_length: 86 | if current_chunk: 87 | # `a_end` is how many segments from `current_chunk` go into the `A` 88 | # (first) sentence. 89 | a_end = 1 90 | if len(current_chunk) >= 2: 91 | a_end = np.random.randint(1, len(current_chunk)) 92 | 93 | tokens_a = [] 94 | for j in range(a_end): 95 | tokens_a.extend(current_chunk[j]) 96 | 97 | tokens_b = [] 98 | # Random next 99 | next_sentence_label = 0 100 | if len(current_chunk) == 1 or np.random.random() < 0.5: 101 | next_sentence_label = 1 102 | target_b_length = word_target_seq_length - len(tokens_a) 103 | 104 | # This should rarely go for more than one iteration for large 105 | # corpora. However, just to be careful, we try to make sure that 106 | # the random document is not the same as the document 107 | # we're processing. 
108 | for _ in range(10): 109 | random_document_index = np.random.randint(0, len(all_documents)) 110 | if random_document_index != document_index and len( 111 | all_documents[random_document_index] 112 | ) > 0: 113 | break 114 | 115 | random_document = all_documents[random_document_index] 116 | random_start = np.random.randint(0, len(random_document)) 117 | for j in range(random_start, len(random_document)): 118 | tokens_b.extend(random_document[j]) 119 | if len(tokens_b) >= target_b_length: 120 | break 121 | # We didn't actually use these segments so we "put them back" so 122 | # they don't go to waste. 123 | num_unused_segments = len(current_chunk) - a_end 124 | i -= num_unused_segments 125 | # Actual next 126 | else: 127 | for j in range(a_end, len(current_chunk)): 128 | tokens_b.extend(current_chunk[j]) 129 | 130 | yield tokens_a, tokens_b, next_sentence_label 131 | 132 | current_chunk = [] 133 | current_length = 0 134 | i += 1 135 | 136 | def read_file_sample_nsp_write_shards(file_glob, output_prefix, num_output_files): 137 | output_files = [] 138 | for k in range(num_output_files): 139 | fname = output_prefix + str(k) + '.txt' 140 | output_files.append(open(fname, 'w')) 141 | 142 | for tokens_a, tokens_b, label in read_files(file_glob): 143 | file_index = np.random.randint(0, num_output_files) 144 | line = "{}\t{}\t{}\n".format(label, ' '.join(tokens_a), ' '.join(tokens_b)) 145 | output_files[file_index].write(line) 146 | 147 | for k in range(num_output_files): 148 | output_files[k].close() 149 | 150 | -------------------------------------------------------------------------------- /bin/evaluate_mrr.py: -------------------------------------------------------------------------------- 1 | 2 | from allennlp.commands.evaluate import * 3 | from kb.include_all import * 4 | 5 | import glob 6 | 7 | def go(archive_file, cuda_device, datadir, all_only=True): 8 | archive = load_archive(archive_file, cuda_device) 9 | 10 | config = archive.config 11 | prepare_environment(config) 12 | model = archive.model 13 | model.eval() 14 | 15 | reader_params = config.pop('dataset_reader') 16 | if reader_params['type'] == 'multitask_reader': 17 | reader_params = reader_params['dataset_readers']['language_modeling'] 18 | 19 | validation_reader_params = { 20 | "type": "kg_probe", 21 | "tokenizer_and_candidate_generator": reader_params['base_reader']['tokenizer_and_candidate_generator'].as_dict() 22 | } 23 | dataset_reader = DatasetReader.from_params(Params(validation_reader_params)) 24 | 25 | iterator = DataIterator.from_params(Params( 26 | {"type": "basic", "batch_size": 32} 27 | )) 28 | iterator.index_with(model.vocab) 29 | 30 | if all_only: 31 | fnames = [datadir + '/all.text'] 32 | else: 33 | fnames = glob.glob(datadir + '/*.text') 34 | 35 | for fname in sorted(fnames): 36 | instances = dataset_reader.read(fname) 37 | metrics = evaluate(model, instances, iterator, cuda_device, "") 38 | print("================") 39 | print(fname) 40 | print(metrics) 41 | print("================") 42 | 43 | 44 | if __name__ == '__main__': 45 | import argparse 46 | 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--model_archive', type=str) 49 | parser.add_argument('--datadir', type=str) 50 | parser.add_argument('--cuda_device', type=int, default=-1) 51 | 52 | args = parser.parse_args() 53 | 54 | go(args.model_archive, args.cuda_device, args.datadir) 55 | 56 | -------------------------------------------------------------------------------- /bin/evaluate_perplexity.py: 
-------------------------------------------------------------------------------- 1 | 2 | # compute the heldout perplexity, next sentence prediction accuracy and loss 3 | 4 | import tqdm 5 | 6 | from allennlp.models.archival import load_archive 7 | from allennlp.data import DatasetReader, DataIterator 8 | from allennlp.common import Params 9 | from allennlp.nn.util import move_to_device 10 | 11 | from kb.include_all import BertPretrainedMaskedLM, KnowBert 12 | 13 | 14 | def run_evaluation(evaluation_file, model_archive, 15 | random_candidates=False): 16 | 17 | archive = load_archive(model_archive) 18 | model = archive.model 19 | vocab = model.vocab 20 | params = archive.config 21 | 22 | model.multitask = False 23 | model.multitask_kg = False 24 | model.cuda() 25 | model.eval() 26 | for p in model.parameters(): 27 | p.requires_grad_(False) 28 | 29 | reader_params = params.pop('dataset_reader') 30 | if reader_params['type'] == 'multitask_reader': 31 | reader_params = reader_params['dataset_readers']['language_modeling'] 32 | 33 | if random_candidates: 34 | for k, v in reader_params['base_reader']['tokenizer_and_candidate_generator']['entity_candidate_generators'].items(): 35 | v['random_candidates'] = True 36 | 37 | reader = DatasetReader.from_params(Params(reader_params)) 38 | 39 | iterator = DataIterator.from_params(Params({ 40 | "type": "self_attn_bucket", 41 | "batch_size_schedule": "base-11gb-fp32", 42 | "iterator":{ 43 | "type": "bucket", 44 | "batch_size": 32, 45 | "sorting_keys": [["tokens", "num_tokens"]], 46 | "max_instances_in_memory": 2500, 47 | } 48 | })) 49 | iterator.index_with(vocab) 50 | instances = reader.read(evaluation_file) 51 | 52 | for batch_no, batch in enumerate(tqdm.tqdm(iterator(instances, num_epochs=1))): 53 | b = move_to_device(batch, 0) 54 | loss = model(**b) 55 | if batch_no % 100 == 0: 56 | print(model.get_metrics()) 57 | 58 | print(model.get_metrics()) 59 | 60 | 61 | if __name__ == '__main__': 62 | import argparse, os 63 | 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('-e', '--evaluation_file', type=str) 66 | parser.add_argument('-m', '--model_archive', type=str) 67 | 68 | args = parser.parse_args() 69 | 70 | run_evaluation(args.evaluation_file, 71 | model_archive=args.model_archive, 72 | random_candidates=False) 73 | 74 | -------------------------------------------------------------------------------- /bin/evaluate_wiki_linking.py: -------------------------------------------------------------------------------- 1 | 2 | # the wiki entity linking model 3 | 4 | from kb.knowbert import BertPretrainedMaskedLM, KnowBert 5 | from kb.bert_tokenizer_and_candidate_generator import BertTokenizerAndCandidateGenerator 6 | from kb.bert_pretraining_reader import BertPreTrainingReader 7 | from kb.include_all import ModelArchiveFromParams 8 | from kb.include_all import TokenizerAndCandidateGenerator 9 | 10 | from allennlp.data import DatasetReader, Vocabulary, DataIterator 11 | from allennlp.models import Model 12 | from allennlp.common import Params 13 | import tqdm 14 | 15 | from allennlp.nn.util import move_to_device 16 | 17 | import torch 18 | import copy 19 | 20 | from allennlp.models.archival import load_archive 21 | 22 | from allennlp.data import Instance 23 | from allennlp.data.dataset import Batch 24 | 25 | 26 | 27 | def run_evaluation(evaluation_file, 28 | model_archive_file, 29 | is_wordnet_and_wiki=False): 30 | archive = load_archive(model_archive_file) 31 | 32 | params = archive.config 33 | vocab = 
Vocabulary.from_params(params.pop('vocabulary')) 34 | 35 | model = archive.model 36 | model.cuda() 37 | model.eval() 38 | 39 | if is_wordnet_and_wiki: 40 | reader_params = Params({ 41 | "type": "aida_wiki_linking", 42 | "entity_disambiguation_only": False, 43 | "entity_indexer": { 44 | "type": "characters_tokenizer", 45 | "namespace": "entity_wiki", 46 | "tokenizer": { 47 | "type": "word", 48 | "word_splitter": { 49 | "type": "just_spaces" 50 | } 51 | } 52 | }, 53 | "extra_candidate_generators": { 54 | "wordnet": { 55 | "type": "wordnet_mention_generator", 56 | "entity_file": "s3://allennlp/knowbert/wordnet/entities.jsonl" 57 | } 58 | }, 59 | "should_remap_span_indices": True, 60 | "token_indexers": { 61 | "tokens": { 62 | "type": "bert-pretrained", 63 | "do_lowercase": True, 64 | "max_pieces": 512, 65 | "pretrained_model": "bert-base-uncased", 66 | "use_starting_offsets": True, 67 | } 68 | } 69 | }) 70 | else: 71 | reader_params = Params({ 72 | "type": "aida_wiki_linking", 73 | "entity_disambiguation_only": False, 74 | "token_indexers": { 75 | "tokens": { 76 | "type": "bert-pretrained", 77 | "pretrained_model": "bert-base-uncased", 78 | "do_lowercase": True, 79 | "use_starting_offsets": True, 80 | "max_pieces": 512, 81 | }, 82 | }, 83 | "entity_indexer": { 84 | "type": "characters_tokenizer", 85 | "tokenizer": { 86 | "type": "word", 87 | "word_splitter": {"type": "just_spaces"}, 88 | }, 89 | "namespace": "entity", 90 | }, 91 | "should_remap_span_indices": True, 92 | }) 93 | 94 | if is_wordnet_and_wiki: 95 | cg_params = Params({ 96 | "type": "bert_tokenizer_and_candidate_generator", 97 | "bert_model_type": "bert-base-uncased", 98 | "do_lower_case": True, 99 | "entity_candidate_generators": { 100 | "wordnet": { 101 | "type": "wordnet_mention_generator", 102 | "entity_file": "s3://allennlp/knowbert/wordnet/entities.jsonl" 103 | } 104 | }, 105 | "entity_indexers": { 106 | "wordnet": { 107 | "type": "characters_tokenizer", 108 | "namespace": "entity_wordnet", 109 | "tokenizer": { 110 | "type": "word", 111 | "word_splitter": { 112 | "type": "just_spaces" 113 | } 114 | } 115 | } 116 | } 117 | }) 118 | candidate_generator = TokenizerAndCandidateGenerator.from_params(cg_params) 119 | 120 | reader = DatasetReader.from_params(Params(reader_params)) 121 | 122 | iterator = DataIterator.from_params(Params({"type": "basic", "batch_size": 16})) 123 | iterator.index_with(vocab) 124 | 125 | instances = reader.read(evaluation_file) 126 | 127 | for batch_no, batch in enumerate(iterator(instances, shuffle=False, num_epochs=1)): 128 | b = move_to_device(batch, 0) 129 | 130 | b['candidates'] = {'wiki': { 131 | 'candidate_entities': b.pop('candidate_entities'), 132 | 'candidate_entity_priors': b.pop('candidate_entity_prior'), 133 | 'candidate_segment_ids': b.pop('candidate_segment_ids'), 134 | 'candidate_spans': b.pop('candidate_spans')}} 135 | gold_entities = b.pop('gold_entities') 136 | b['gold_entities'] = {'wiki': gold_entities} 137 | 138 | if is_wordnet_and_wiki: 139 | extra_candidates = b.pop('extra_candidates') 140 | seq_len = b['tokens']['tokens'].shape[1] 141 | bbb = [] 142 | for e in extra_candidates: 143 | for k in e.keys(): 144 | e[k]['candidate_segment_ids'] = [0] * len(e[k]['candidate_spans']) 145 | ee = {'tokens': ['[CLS]'] * seq_len, 'segment_ids': [0] * seq_len, 146 | 'candidates': e} 147 | ee_fields = candidate_generator.convert_tokens_candidates_to_fields(ee) 148 | bbb.append(Instance(ee_fields)) 149 | eb = Batch(bbb) 150 | eb.index_instances(vocab) 151 | padding_lengths = 
eb.get_padding_lengths() 152 | tensor_dict = eb.as_tensor_dict(padding_lengths) 153 | b['candidates'].update(tensor_dict['candidates']) 154 | bb = move_to_device(b, 0) 155 | else: 156 | bb = b 157 | 158 | loss = model(**bb) 159 | if batch_no % 100 == 0: 160 | print(model.get_metrics()) 161 | 162 | print(model.get_metrics()) 163 | 164 | 165 | if __name__ == '__main__': 166 | import argparse 167 | 168 | parser = argparse.ArgumentParser() 169 | parser.add_argument('-e', '--evaluation_file', type=str) 170 | parser.add_argument('-a', '--model_archive', type=str) 171 | parser.add_argument('--wiki_and_wordnet', action='store_true') 172 | 173 | args = parser.parse_args() 174 | 175 | run_evaluation(args.evaluation_file, args.model_archive, is_wordnet_and_wiki=args.wiki_and_wordnet) 176 | 177 | -------------------------------------------------------------------------------- /bin/preprocess_wsd.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Original data from: http://lcl.uniroma1.it/wsdeval/ 3 | 4 | Data converted from XML to jsonl 5 | ''' 6 | 7 | import os 8 | import h5py 9 | import json 10 | 11 | import numpy as np 12 | 13 | from kb.common import JsonFile 14 | 15 | def read_gold_data(fname): 16 | gold_data = {} 17 | with open(fname, 'r') as fin: 18 | for line in fin: 19 | ls = line.strip().split() 20 | lid = ls[0] 21 | if lid not in gold_data: 22 | gold_data[lid] = set() 23 | for sense in ls[1:]: 24 | gold_data[lid].add(sense) 25 | return gold_data 26 | 27 | 28 | def read_wsd_data(fname, fname_gold): 29 | from lxml import etree 30 | 31 | gold_data = read_gold_data(fname_gold) 32 | 33 | with open(fname, 'r') as fin: 34 | data = fin.read() 35 | 36 | corpus = etree.fromstring(data.encode('utf-8')) 37 | 38 | sentences = [] 39 | n_sentences = 0 40 | for node in corpus.iterdescendants(): 41 | if node.tag == 'sentence': 42 | sentence = [] 43 | for token_node in node.iterdescendants(): 44 | token = { 45 | 'token': token_node.text, 46 | 'lemma': token_node.attrib['lemma'], 47 | 'pos': token_node.attrib['pos'] 48 | } 49 | if token_node.tag == 'instance': 50 | token['id'] = token_node.attrib['id'] 51 | token['senses'] = [] 52 | for sense in gold_data[token['id']]: 53 | lemma, _, ss = sense.partition('%') 54 | assert lemma == token['lemma'] 55 | token['senses'].append(ss) 56 | 57 | sentence.append(token) 58 | 59 | sentences.append(sentence) 60 | 61 | return sentences 62 | 63 | 64 | def get_dataset_metadata(wsd_framework_root): 65 | return [ 66 | [ 67 | 'semcor', 68 | os.path.join( 69 | wsd_framework_root, 'Training_Corpora', 'SemCor', 'semcor' 70 | ) 71 | ], [ 72 | 'senseval2', 73 | os.path.join( 74 | wsd_framework_root, 'Evaluation_Datasets', 'senseval2', 75 | 'senseval2' 76 | ) 77 | ], [ 78 | 'senseval3', 79 | os.path.join( 80 | wsd_framework_root, 'Evaluation_Datasets', 'senseval3', 81 | 'senseval3' 82 | ) 83 | ], [ 84 | 'semeval2015', 85 | os.path.join( 86 | wsd_framework_root, 'Evaluation_Datasets', 'semeval2015', 87 | 'semeval2015' 88 | ) 89 | ], [ 90 | 'semeval2013', 91 | os.path.join( 92 | wsd_framework_root, 'Evaluation_Datasets', 'semeval2013', 93 | 'semeval2013' 94 | ) 95 | ], [ 96 | 'semeval2007', 97 | os.path.join( 98 | wsd_framework_root, 'Evaluation_Datasets', 'semeval2007', 99 | 'semeval2007' 100 | ) 101 | ] 102 | ] 103 | 104 | 105 | def convert_all_wsd_datasets(outdir, wsd_framework_root): 106 | datasets = get_dataset_metadata(wsd_framework_root) 107 | 108 | for ds in datasets: 109 | ds_name, ds_root = ds 110 | data = read_wsd_data(ds_root + 
'.data.xml', ds_root + '.gold.key.txt') 111 | with JsonFile(os.path.join(outdir, ds_name + '.json'), 'w') as fout: 112 | for line in data: 113 | fout.write(line) 114 | 115 | 116 | if __name__ == '__main__': 117 | import argparse 118 | 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--wsd_framework_root', type=str) 121 | parser.add_argument('--outdir', type=str) 122 | 123 | args = parser.parse_args() 124 | 125 | convert_all_wsd_datasets(args.outdir, args.wsd_framework_root) 126 | 127 | -------------------------------------------------------------------------------- /bin/run_hyperparameter_seeds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ./bin/run_hyperparameter_seeds.sh training_config output_prefix 4 | # 5 | 6 | 7 | training_config=$1 8 | shift; 9 | output_prefix=$1 10 | shift; 11 | key=$1 12 | 13 | random_seeds=(1989894904 2294922467 2002866410 1004506748 4076792239) 14 | numpy_seeds=(1053248695 2739105195 1071118652 755056791 3842727116) 15 | pytorch_seeds=(81406405 807621944 3166916287 3467634827 1189731539) 16 | 17 | i=0 18 | while [ $i -lt 5 ]; do 19 | rs=${random_seeds[$i]} 20 | ns=${numpy_seeds[$i]} 21 | ps=${pytorch_seeds[$i]} 22 | 23 | for LR in 2e-5 3e-5 5e-5 24 | do 25 | for NUM_EPOCHS in 3 4 26 | do 27 | 28 | echo "$i $LR $NUM_EPOCHS" 29 | 30 | overrides="{\"random_seed\": $rs, \"numpy_seed\": $ns, \"pytorch_seed\": $ps, \"trainer\": {\"num_epochs\": $NUM_EPOCHS, \"learning_rate_scheduler\": {\"num_epochs\": $NUM_EPOCHS}, \"optimizer\": {\"lr\": $LR}}}" 31 | 32 | outdir=${output_prefix}_lr${LR}_${NUM_EPOCHS}epochs_SEED_$i 33 | allennlp train --file-friendly-logging --include-package kb.include_all $training_config -s $outdir --overrides "$overrides" 34 | 35 | done 36 | 37 | done 38 | 39 | let i=i+1 40 | done 41 | 42 | 43 | -------------------------------------------------------------------------------- /bin/tacred_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Score the predictions with gold labels, using precision, recall and F1 metrics. 
5 | """ 6 | 7 | import argparse 8 | import sys 9 | from collections import Counter 10 | 11 | NO_RELATION = "no_relation" 12 | 13 | def parse_arguments(): 14 | parser = argparse.ArgumentParser(description='Score a prediction file using the gold labels.') 15 | parser.add_argument('gold_file', help='The gold relation file; one relation per line') 16 | parser.add_argument('pred_file', help='A prediction file; one relation per line, in the same order as the gold file.') 17 | args = parser.parse_args() 18 | return args 19 | 20 | def score(key, prediction, verbose=False): 21 | correct_by_relation = Counter() 22 | guessed_by_relation = Counter() 23 | gold_by_relation = Counter() 24 | 25 | # Loop over the data to compute a score 26 | for row in range(len(key)): 27 | gold = key[row] 28 | guess = prediction[row] 29 | 30 | if gold == NO_RELATION and guess == NO_RELATION: 31 | pass 32 | elif gold == NO_RELATION and guess != NO_RELATION: 33 | guessed_by_relation[guess] += 1 34 | elif gold != NO_RELATION and guess == NO_RELATION: 35 | gold_by_relation[gold] += 1 36 | elif gold != NO_RELATION and guess != NO_RELATION: 37 | guessed_by_relation[guess] += 1 38 | gold_by_relation[gold] += 1 39 | if gold == guess: 40 | correct_by_relation[guess] += 1 41 | 42 | # Print verbose information 43 | if verbose: 44 | print("Per-relation statistics:") 45 | relations = gold_by_relation.keys() 46 | longest_relation = 0 47 | for relation in sorted(relations): 48 | longest_relation = max(len(relation), longest_relation) 49 | for relation in sorted(relations): 50 | # (compute the score) 51 | correct = correct_by_relation[relation] 52 | guessed = guessed_by_relation[relation] 53 | gold = gold_by_relation[relation] 54 | prec = 1.0 55 | if guessed > 0: 56 | prec = float(correct) / float(guessed) 57 | recall = 0.0 58 | if gold > 0: 59 | recall = float(correct) / float(gold) 60 | f1 = 0.0 61 | if prec + recall > 0: 62 | f1 = 2.0 * prec * recall / (prec + recall) 63 | # (print the score) 64 | sys.stdout.write(("{:<" + str(longest_relation) + "}").format(relation)) 65 | sys.stdout.write(" P: ") 66 | if prec < 0.1: sys.stdout.write(' ') 67 | if prec < 1.0: sys.stdout.write(' ') 68 | sys.stdout.write("{:.2%}".format(prec)) 69 | sys.stdout.write(" R: ") 70 | if recall < 0.1: sys.stdout.write(' ') 71 | if recall < 1.0: sys.stdout.write(' ') 72 | sys.stdout.write("{:.2%}".format(recall)) 73 | sys.stdout.write(" F1: ") 74 | if f1 < 0.1: sys.stdout.write(' ') 75 | if f1 < 1.0: sys.stdout.write(' ') 76 | sys.stdout.write("{:.2%}".format(f1)) 77 | sys.stdout.write(" #: %d" % gold) 78 | sys.stdout.write("\n") 79 | print("") 80 | 81 | # Print the aggregate score 82 | if verbose: 83 | print("Final Score:") 84 | prec_micro = 1.0 85 | if sum(guessed_by_relation.values()) > 0: 86 | prec_micro = float(sum(correct_by_relation.values())) / float(sum(guessed_by_relation.values())) 87 | recall_micro = 0.0 88 | if sum(gold_by_relation.values()) > 0: 89 | recall_micro = float(sum(correct_by_relation.values())) / float(sum(gold_by_relation.values())) 90 | f1_micro = 0.0 91 | if prec_micro + recall_micro > 0.0: 92 | f1_micro = 2.0 * prec_micro * recall_micro / (prec_micro + recall_micro) 93 | print( "Precision (micro): {:.3%}".format(prec_micro) ) 94 | print( " Recall (micro): {:.3%}".format(recall_micro) ) 95 | print( " F1 (micro): {:.3%}".format(f1_micro) ) 96 | return prec_micro, recall_micro, f1_micro 97 | 98 | if __name__ == "__main__": 99 | # Parse the arguments from stdin 100 | args = parse_arguments() 101 | key = [str(line).rstrip('\n') for line 
in open(str(args.gold_file))] 102 | prediction = [str(line).rstrip('\n') for line in open(str(args.pred_file))] 103 | 104 | # Check that the lengths match 105 | if len(prediction) != len(key): 106 | print("Gold and prediction file must have same number of elements: %d in gold vs %d in prediction" % (len(key), len(prediction))) 107 | exit(1) 108 | 109 | # Score the predictions 110 | score(key, prediction, verbose=True) 111 | 112 | -------------------------------------------------------------------------------- /bin/write_semeval2010_task8_for_official_eval.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from allennlp.models.archival import load_archive 4 | from allennlp.data import DatasetReader, Vocabulary, DataIterator 5 | from allennlp.nn.util import move_to_device 6 | from allennlp.common import Params 7 | 8 | import numpy as np 9 | 10 | from kb.include_all import * 11 | 12 | from kb.evaluation.semeval2010_task8 import LABEL_MAP 13 | 14 | 15 | def write_for_official_eval(model_archive_file, test_file, output_file, 16 | label_ids_to_label): 17 | archive = load_archive(model_archive_file) 18 | model = archive.model 19 | 20 | reader = DatasetReader.from_params(archive.config['dataset_reader']) 21 | 22 | iterator = DataIterator.from_params(Params({"type": "basic", "batch_size": 4})) 23 | vocab = Vocabulary.from_params(archive.config['vocabulary']) 24 | iterator.index_with(vocab) 25 | 26 | model.cuda() 27 | model.eval() 28 | 29 | instances = reader.read(test_file) 30 | predictions = [] 31 | for batch in iterator(instances, num_epochs=1, shuffle=False): 32 | batch = move_to_device(batch, cuda_device=0) 33 | output = model(**batch) 34 | 35 | batch_labels = [ 36 | label_ids_to_label[i] 37 | for i in output['predictions'].cpu().numpy().tolist() 38 | ] 39 | 40 | predictions.extend(batch_labels) 41 | 42 | to_write = ''.join(["{}\t{}\n".format(i + 8001, e) for i, e in enumerate(model.metrics[0].pred)]) 43 | with open(output_file, 'w') as fout: 44 | fout.write(to_write) 45 | 46 | 47 | if __name__ == '__main__': 48 | import argparse 49 | 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--model_archive', type=str) 52 | parser.add_argument('--evaluation_file', type=str) 53 | parser.add_argument('--output_file', type=str) 54 | 55 | args = parser.parse_args() 56 | 57 | # int -> str 58 | label_ids_to_label = {v:k for k, v in LABEL_MAP.items()} 59 | 60 | write_for_official_eval(args.model_archive, 61 | args.evaluation_file, 62 | args.output_file, 63 | label_ids_to_label) 64 | -------------------------------------------------------------------------------- /bin/write_tacred_for_official_scorer.py: -------------------------------------------------------------------------------- 1 | 2 | from allennlp.models.archival import load_archive 3 | from allennlp.data import DatasetReader, Vocabulary, DataIterator 4 | from allennlp.nn.util import move_to_device 5 | from allennlp.common import Params 6 | 7 | import numpy as np 8 | 9 | from kb.include_all import * 10 | 11 | 12 | def write_for_official_eval(model_archive_file, test_file, output_file, 13 | label_ids_to_label): 14 | archive = load_archive(model_archive_file) 15 | model = archive.model 16 | 17 | reader = DatasetReader.from_params(archive.config['dataset_reader']) 18 | 19 | iterator = DataIterator.from_params(Params({"type": "basic", "batch_size": 4})) 20 | vocab = Vocabulary.from_params(archive.config['vocabulary']) 21 | iterator.index_with(vocab) 22 | 23 | model.cuda() 24 | model.eval() 25 | 
26 | instances = reader.read(test_file) 27 | predictions = [] 28 | for batch in iterator(instances, num_epochs=1, shuffle=False): 29 | batch = move_to_device(batch, cuda_device=0) 30 | output = model(**batch) 31 | 32 | batch_labels = [ 33 | label_ids_to_label[i] 34 | for i in output['predictions'].cpu().numpy().tolist() 35 | ] 36 | 37 | predictions.extend(batch_labels) 38 | 39 | 40 | with open(output_file, 'w') as fout: 41 | for p in predictions: 42 | fout.write("{}\n".format(p)) 43 | 44 | if __name__ == '__main__': 45 | import argparse 46 | from kb.evaluation.tacred_dataset_reader import LABEL_MAP 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--model_archive', type=str) 50 | parser.add_argument('--evaluation_file', type=str) 51 | parser.add_argument('--output_file', type=str) 52 | 53 | args = parser.parse_args() 54 | 55 | # int -> str 56 | label_ids_to_label = {v:k for k, v in LABEL_MAP.items()} 57 | 58 | write_for_official_eval(args.model_archive, 59 | args.evaluation_file, 60 | args.output_file, 61 | label_ids_to_label) 62 | 63 | -------------------------------------------------------------------------------- /bin/write_wic_for_codalab.py: -------------------------------------------------------------------------------- 1 | 2 | from allennlp.models.archival import load_archive 3 | from allennlp.data import DatasetReader, Vocabulary, DataIterator 4 | from allennlp.nn.util import move_to_device 5 | from allennlp.common import Params 6 | 7 | import numpy as np 8 | 9 | from kb.include_all import * 10 | 11 | 12 | def write_for_official_eval(model_archive_file, test_file, output_file): 13 | archive = load_archive(model_archive_file) 14 | model = archive.model 15 | 16 | reader = DatasetReader.from_params(archive.config['dataset_reader']) 17 | 18 | iterator = DataIterator.from_params(Params({"type": "basic", "batch_size": 32})) 19 | vocab = Vocabulary.from_params(archive.config['vocabulary']) 20 | iterator.index_with(vocab) 21 | 22 | model.cuda() 23 | model.eval() 24 | 25 | label_ids_to_label = {0: 'F', 1: 'T'} 26 | 27 | instances = reader.read(test_file) 28 | predictions = [] 29 | for batch in iterator(instances, num_epochs=1, shuffle=False): 30 | batch = move_to_device(batch, cuda_device=0) 31 | output = model(**batch) 32 | 33 | batch_labels = [ 34 | label_ids_to_label[i] 35 | for i in output['predictions'].cpu().numpy().tolist() 36 | ] 37 | 38 | predictions.extend(batch_labels) 39 | 40 | assert len(predictions) == 1400 41 | 42 | with open(output_file, 'w') as fout: 43 | for p in predictions: 44 | fout.write("{}\n".format(p)) 45 | 46 | -------------------------------------------------------------------------------- /kb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/kb/4c37bbccd6c871828aceb24b342841825d86df14/kb/__init__.py -------------------------------------------------------------------------------- /kb/bert_utils.py: -------------------------------------------------------------------------------- 1 | from allennlp.nn import Activation 2 | import torch 3 | import math 4 | 5 | def truncate_seq_pair(tokens_a, tokens_b, max_length): 6 | """Truncates a sequence pair in place to the maximum length.""" 7 | # Copied from pytorch_pretrained_bert/examples/run_classifier.py 8 | # This is a simple heuristic which will always truncate the longer sequence 9 | # one token at a time. 
This makes more sense than truncating an equal percent 10 | # of tokens from each, since if one sequence is very short then each token 11 | # that's truncated likely contains more information than a longer sequence. 12 | while True: 13 | total_length = len(tokens_a) + len(tokens_b) 14 | if total_length <= max_length: 15 | break 16 | if len(tokens_a) > len(tokens_b): 17 | tokens_a.pop() 18 | else: 19 | tokens_b.pop() 20 | 21 | 22 | @Activation.register("gelu") 23 | class GeLu(Activation): 24 | def __call__(self, x): 25 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 26 | -------------------------------------------------------------------------------- /kb/common.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | 4 | import torch 5 | 6 | from pytorch_pretrained_bert.modeling import \ 7 | BertLayer, BertAttention, BertSelfAttention, BertSelfOutput, \ 8 | BertOutput, BertIntermediate, BertEncoder, BertLayerNorm, BertConfig 9 | 10 | from allennlp.common.registrable import Registrable 11 | from allennlp.training.metrics.metric import Metric 12 | 13 | import spacy 14 | from spacy.tokens import Doc 15 | 16 | 17 | 18 | class MentionGenerator(Registrable): 19 | pass 20 | 21 | 22 | class EntityEmbedder(Registrable): 23 | pass 24 | 25 | 26 | def get_empty_candidates(): 27 | """ 28 | The mention generators always return at least one candidate, but signal 29 | it with this special candidate 30 | """ 31 | return { 32 | "candidate_spans": [[-1, -1]], 33 | "candidate_entities": [["@@PADDING@@"]], 34 | "candidate_entity_priors": [[1.0]] 35 | } 36 | 37 | 38 | # from https://spacy.io/usage/linguistic-features#custom-tokenizer-example 39 | class WhitespaceTokenizer(object): 40 | def __init__(self, vocab): 41 | self.vocab = vocab 42 | 43 | def __call__(self, text): 44 | words = text.split(' ') 45 | # All tokens 'own' a subsequent space character in this tokenizer 46 | spaces = [True] * len(words) 47 | return Doc(self.vocab, words=words, spaces=spaces) 48 | 49 | 50 | @Metric.register("f1_set") 51 | class F1Metric(Metric): 52 | """ 53 | A generic set based F1 metric. 54 | Takes two lists of predicted and gold elements and computes F1. 55 | Only requirements are that the elements are hashable. 56 | """ 57 | def __init__(self, filter_func=None): 58 | self.reset() 59 | if filter_func is None: 60 | filter_func = lambda x: True 61 | self.filter_func = filter_func 62 | 63 | def reset(self): 64 | self._true_positives = 0.0 65 | self._false_positives = 0.0 66 | self._false_negatives = 0.0 67 | 68 | def get_metric(self, reset: bool = False): 69 | """ 70 | Returns 71 | ------- 72 | A tuple of the following metrics based on the accumulated count statistics: 73 | precision : float 74 | recall : float 75 | f1-measure : float 76 | """ 77 | precision = float(self._true_positives) / float(self._true_positives + self._false_positives + 1e-13) 78 | recall = float(self._true_positives) / float(self._true_positives + self._false_negatives + 1e-13) 79 | f1_measure = 2. * ((precision * recall) / (precision + recall + 1e-13)) 80 | if reset: 81 | self.reset() 82 | return precision, recall, f1_measure 83 | 84 | def __call__(self, predictions, gold_labels): 85 | """ 86 | predictions = batch of predictions that can be compared 87 | gold labels = list of gold labels 88 | 89 | e.g. 
90 | predictions = [ 91 | [('ORG', (0, 1)), ('PER', (5, 8))], 92 | [('MISC', (9, 13))] 93 | ] 94 | gold_labels = [ 95 | [('ORG', (0, 1))], 96 | [] 97 | ] 98 | 99 | elements must be hashable 100 | """ 101 | assert len(predictions) == len(gold_labels) 102 | 103 | for pred, gold in zip(predictions, gold_labels): 104 | s_gold = set(g for g in gold if self.filter_func(g)) 105 | s_pred = set(p for p in pred if self.filter_func(p)) 106 | 107 | for p in s_pred: 108 | if p in s_gold: 109 | self._true_positives += 1 110 | else: 111 | self._false_positives += 1 112 | 113 | for p in s_gold: 114 | if p not in s_pred: 115 | self._false_negatives += 1 116 | 117 | 118 | def get_dtype_for_module(module): 119 | # gets dtype for module parameters, for fp16 support when casting 120 | # we unfortunately can't set this during module construction as module 121 | # will be moved to GPU or cast to half after construction. 122 | return next(module.parameters()).dtype 123 | 124 | def set_requires_grad(module, requires_grad): 125 | for param in module.parameters(): 126 | param.requires_grad_(requires_grad) 127 | 128 | def extend_attention_mask_for_bert(mask, dtype): 129 | # mask = (batch_size, timesteps) 130 | # returns an attention_mask useable with BERT 131 | # see: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L696 132 | extended_attention_mask = mask.unsqueeze(1).unsqueeze(2) 133 | extended_attention_mask = extended_attention_mask.to(dtype=dtype) 134 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 135 | return extended_attention_mask 136 | 137 | 138 | def init_bert_weights(module, initializer_range, extra_modules_without_weights=()): 139 | # these modules don't have any weights, other then ones in submodules, 140 | # so don't have to worry about init 141 | modules_without_weights = ( 142 | BertEncoder, torch.nn.ModuleList, torch.nn.Dropout, BertLayer, 143 | BertAttention, BertSelfAttention, BertSelfOutput, 144 | BertOutput, BertIntermediate 145 | ) + extra_modules_without_weights 146 | 147 | 148 | # modified from pytorch_pretrained_bert 149 | def _do_init(m): 150 | if isinstance(m, (torch.nn.Linear, torch.nn.Embedding)): 151 | # Slightly different from the TF version which uses truncated_normal for initialization 152 | # cf https://github.com/pytorch/pytorch/pull/5617 153 | m.weight.data.normal_(mean=0.0, std=initializer_range) 154 | elif isinstance(m, BertLayerNorm): 155 | m.bias.data.zero_() 156 | m.weight.data.fill_(1.0) 157 | elif isinstance(m, modules_without_weights): 158 | pass 159 | else: 160 | raise ValueError(str(m)) 161 | 162 | if isinstance(m, torch.nn.Linear) and m.bias is not None: 163 | m.bias.data.zero_() 164 | 165 | for mm in module.modules(): 166 | _do_init(mm) 167 | 168 | def get_linear_layer_init_identity(dim): 169 | ret = torch.nn.Linear(dim, dim) 170 | ret.weight.data.copy_(torch.eye(dim)) 171 | ret.bias.data.fill_(0.0) 172 | return ret 173 | 174 | 175 | 176 | class JsonFile: 177 | ''' 178 | A flat text file where each line is one json object 179 | 180 | # to read though a file line by line 181 | with JsonFile('file.json', 'r') as fin: 182 | for line in fin: 183 | # line is the deserialized json object 184 | pass 185 | 186 | 187 | # to write a file object by object 188 | with JsonFile('file.json', 'w') as fout: 189 | fout.write({'key1': 5, 'key2': 'token'}) 190 | fout.write({'key1': 0, 'key2': 'the'}) 191 | ''' 192 | 193 | def __init__(self, *args, **kwargs): 194 | self._args = args 195 | self._kwargs = kwargs 
196 | 197 | def __iter__(self): 198 | for line in self._file: 199 | yield json.loads(line) 200 | 201 | def write(self, item): 202 | item_as_json = json.dumps(item, ensure_ascii=False) 203 | encoded = '{0}\n'.format(item_as_json) 204 | self._file.write(encoded) 205 | 206 | def __enter__(self): 207 | self._file = open(*self._args, **self._kwargs) 208 | self._file.__enter__() 209 | return self 210 | 211 | def __exit__(self, exc_type, exc_val, exc_tb): 212 | self._file.__exit__(exc_type, exc_val, exc_tb) 213 | 214 | -------------------------------------------------------------------------------- /kb/dict_field.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Dict, List, Iterator 3 | 4 | from overrides import overrides 5 | 6 | from allennlp.data.fields.field import DataArray, Field 7 | from allennlp.data.vocabulary import Vocabulary 8 | from allennlp.data.fields.sequence_field import SequenceField 9 | from allennlp.common.util import pad_sequence_to_length 10 | 11 | SEPERATOR = '*' 12 | 13 | 14 | class DictField(Field): 15 | """ 16 | dict with values as fields 17 | """ 18 | def __init__(self, field_dict: Dict[str, Field]) -> None: 19 | self.field_dict = field_dict 20 | 21 | @overrides 22 | def count_vocab_items(self, counter: Dict[str, Dict[str, int]]): 23 | for field in self.field_dict.values(): 24 | field.count_vocab_items(counter) 25 | 26 | @overrides 27 | def index(self, vocab: Vocabulary): 28 | for field in self.field_dict.values(): 29 | field.index(vocab) 30 | 31 | @overrides 32 | def get_padding_lengths(self) -> Dict[str, int]: 33 | padding_lengths = {} 34 | for key, field in self.field_dict.items(): 35 | for sub_key, val in field.get_padding_lengths().items(): 36 | padding_lengths[key + SEPERATOR + sub_key] = val 37 | return padding_lengths 38 | 39 | @overrides 40 | def as_tensor(self, padding_lengths: Dict[str, int]) -> DataArray: 41 | # padding_lengths is flattened from the nested structure -- unflatten 42 | pl = {} 43 | for full_key, val in padding_lengths.items(): 44 | key, _, sub_key = full_key.partition(SEPERATOR) 45 | if key not in pl: 46 | pl[key] = {} 47 | pl[key][sub_key] = val 48 | 49 | ret = {} 50 | for key, field in self.field_dict.items(): 51 | ret[key] = field.as_tensor(pl[key]) 52 | 53 | return ret 54 | 55 | @overrides 56 | def empty_field(self): 57 | return DictField({key: field.empty_field() for key, field in self.field_dict.items()}) 58 | 59 | @overrides 60 | def batch_tensors(self, tensor_list_of_dict): 61 | ret = {} 62 | for key, field in self.field_dict.items(): 63 | ret[key] = field.batch_tensors([t[key] for t in tensor_list_of_dict]) 64 | return ret 65 | 66 | def __str__(self) -> str: 67 | return "" 68 | -------------------------------------------------------------------------------- /kb/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/kb/4c37bbccd6c871828aceb24b342841825d86df14/kb/evaluation/__init__.py -------------------------------------------------------------------------------- /kb/evaluation/exponential_average_metric.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | 3 | from allennlp.training.metrics.metric import Metric 4 | 5 | 6 | @Metric.register("ema") 7 | class ExponentialMovingAverage(Metric): 8 | """ 9 | Keep an exponentially weighted moving average. 10 | alpha is the decay constant. 
Alpha = 1 means just keep the most recent value. 11 | alpha = 0.5 will have almost no contribution from 10 time steps ago. 12 | """ 13 | def __init__(self, alpha:float = 0.5) -> None: 14 | self.alpha = alpha 15 | self.reset() 16 | 17 | @overrides 18 | def __call__(self, value): 19 | """ 20 | Parameters 21 | ---------- 22 | value : ``float`` 23 | The value to average. 24 | """ 25 | if self._ema is None: 26 | # first observation 27 | self._ema = value 28 | else: 29 | self._ema = self.alpha * value + (1.0 - self.alpha) * self._ema 30 | 31 | @overrides 32 | def get_metric(self, reset: bool = False): 33 | """ 34 | Returns 35 | ------- 36 | The average of all values that were passed to ``__call__``. 37 | """ 38 | if self._ema is None: 39 | ret = 0.0 40 | else: 41 | ret = self._ema 42 | 43 | if reset: 44 | self.reset() 45 | 46 | return ret 47 | 48 | @overrides 49 | def reset(self): 50 | self._ema = None 51 | -------------------------------------------------------------------------------- /kb/evaluation/tacred_predictor.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | 3 | from allennlp.common.util import JsonDict 4 | from allennlp.data import Instance 5 | from allennlp.predictors import Predictor 6 | from kb.evaluation.tacred_dataset_reader import LABEL_MAP 7 | 8 | 9 | REVERSE_LABEL_MAP = {y: x for x, y in LABEL_MAP.items()} 10 | 11 | 12 | @Predictor.register('tacred') 13 | class TacredPredictor(Predictor): 14 | @overrides 15 | def dump_line(self, outputs: JsonDict) -> str: 16 | return REVERSE_LABEL_MAP[outputs['predictions']] + '\n' 17 | 18 | -------------------------------------------------------------------------------- /kb/evaluation/ultra_fine_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Iterable 4 | 5 | from allennlp.common.file_utils import cached_path 6 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 7 | from allennlp.data.fields import ArrayField, LabelField 8 | from allennlp.data.instance import Instance 9 | import numpy as np 10 | 11 | from kb.bert_tokenizer_and_candidate_generator import TokenizerAndCandidateGenerator 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | LABEL_MAP = { 17 | 'entity': 0, 18 | 'event': 1, 19 | 'group': 2, 20 | 'location': 3, 21 | 'object': 4, 22 | 'organization': 5, 23 | 'person': 6, 24 | 'place': 7, 25 | 'time': 8 26 | } 27 | 28 | 29 | @DatasetReader.register('ultra_fine') 30 | class UltraFineReader(DatasetReader): 31 | """ 32 | Reads coarse grained entity typing data from "Ultra-Fine Entity Typing", 33 | Choi et al, ACL 2018. 34 | 35 | Reads data format from https://github.com/thunlp/ERNIE 36 | 37 | Encodes as: 38 | 39 | entity_masking = 'entity': 40 | [CLS] The left context [ENTITY] right context . [SEP] entity name [SEP] 41 | use [unused0] as the [ENTITY] token 42 | 43 | entity_masking = 'entity_markers': 44 | [CLS] The left context [e1start] entity name [e1end] right context . 
[SEP] 45 | """ 46 | def __init__(self, 47 | tokenizer_and_candidate_generator: TokenizerAndCandidateGenerator, 48 | entity_masking: str = 'entity', 49 | lazy: bool = False) -> None: 50 | super().__init__(lazy=lazy) 51 | self.tokenizer_and_candidate_generator = tokenizer_and_candidate_generator 52 | self.tokenizer_and_candidate_generator.whitespace_tokenize = True 53 | assert entity_masking in ('entity', 'entity_markers') 54 | self.entity_masking = entity_masking 55 | 56 | 57 | def _read(self, file_path: str) -> Iterable[Instance]: 58 | with open(cached_path(file_path), 'r') as f: 59 | data = json.load(f) 60 | 61 | for example in data: 62 | # whitespace separated 63 | tokens = example['sent'] 64 | 65 | left = tokens[:example['start']] 66 | span = tokens[example['start']:example['end']] 67 | right = tokens[example['end']:] 68 | 69 | if self.entity_masking == 'entity': 70 | sentence = left.strip() + ' [unused0] ' + right.strip() 71 | span = span.strip() 72 | index_entity_start = None 73 | elif self.entity_masking == 'entity_markers': 74 | sentence = left.strip() + ' [e1start] ' + span.strip() + ' [e1end] ' + right.strip() 75 | span = None 76 | index_entity_start = sentence.split().index('[e1start]') 77 | 78 | # get the labels 79 | labels = [0] * len(LABEL_MAP) 80 | for label in example['labels']: 81 | labels[LABEL_MAP[label]] = 1 82 | 83 | yield self.text_to_instance(sentence, span, labels, index_entity_start) 84 | 85 | def text_to_instance(self, sentence, span, labels, index_entity_start): 86 | token_candidates = self.tokenizer_and_candidate_generator.tokenize_and_generate_candidates(sentence, span) 87 | fields = self.tokenizer_and_candidate_generator.convert_tokens_candidates_to_fields(token_candidates) 88 | fields['label_ids'] = ArrayField(np.array(labels), dtype=np.int) 89 | 90 | # index of entity start 91 | if index_entity_start is not None: 92 | offsets = [1] + token_candidates['offsets_a'][:-1] 93 | idx1_offset = offsets[index_entity_start] 94 | fields['index_a'] = LabelField(idx1_offset, skip_indexing=True) 95 | 96 | return Instance(fields) 97 | 98 | -------------------------------------------------------------------------------- /kb/evaluation/weighted_average.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | 3 | from allennlp.training.metrics.metric import Metric 4 | 5 | 6 | @Metric.register("weighted_average") 7 | class WeightedAverage(Metric): 8 | """ 9 | This :class:`Metric` breaks with the typical ``Metric`` API and just stores values that were 10 | computed in some fashion outside of a ``Metric``. If you have some external code that computes 11 | the metric for you, for instance, you can use this to report the average result using our 12 | ``Metric`` API. 13 | """ 14 | def __init__(self) -> None: 15 | self._total_value = 0.0 16 | self._count = 0 17 | 18 | @overrides 19 | def __call__(self, value, count=1): 20 | """ 21 | Parameters 22 | ---------- 23 | value : ``float`` 24 | The value to average. 25 | """ 26 | self._total_value += (list(self.unwrap_to_tensors(value))[0] * count) 27 | self._count += count 28 | 29 | @overrides 30 | def get_metric(self, reset: bool = False): 31 | """ 32 | Returns 33 | ------- 34 | The average of all values that were passed to ``__call__``. 
35 | """ 36 | average_value = self._total_value / self._count if self._count > 0 else 0 37 | if reset: 38 | self.reset() 39 | return average_value 40 | 41 | @overrides 42 | def reset(self): 43 | self._total_value = 0.0 44 | self._count = 0 45 | -------------------------------------------------------------------------------- /kb/evaluation/wic_dataset_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 3 | from allennlp.data.fields import LabelField 4 | from allennlp.data.instance import Instance 5 | from kb.bert_tokenizer_and_candidate_generator import TokenizerAndCandidateGenerator 6 | from allennlp.common.file_utils import cached_path 7 | 8 | 9 | @DatasetReader.register("wic") 10 | class WicDatasetReader(DatasetReader): 11 | def __init__(self, 12 | tokenizer_and_candidate_generator: TokenizerAndCandidateGenerator, 13 | entity_markers: bool = False): 14 | super().__init__() 15 | self.label_to_index = {'T': 1, 'F': 0} 16 | self.tokenizer = tokenizer_and_candidate_generator 17 | self.tokenizer.whitespace_tokenize = True 18 | self.entity_markers = entity_markers 19 | 20 | def text_to_instance(self, line) -> Instance: 21 | raise NotImplementedError 22 | 23 | def _read(self, file_path: str) -> Iterable[Instance]: 24 | """Creates examples for the training and dev sets.""" 25 | 26 | with open(cached_path(file_path + '.gold.txt'), 'r') as f: 27 | labels = f.read().split() 28 | 29 | with open(cached_path(file_path + '.data.txt'), 'r') as f: 30 | sentences = f.read().splitlines() 31 | assert len(labels) == len(sentences), f'The length of the labels and sentences must match. ' \ 32 | f'Got {len(labels)} and {len(sentences)}.' 
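        # Each line of the .data.txt file has five tab-separated fields
        # (target word, its POS tag, "i-j" word indices of the target in the
        # two sentences, sentence A, sentence B), e.g. the fixture line
        # "carry\tV\t2-1\tYou must carry your camping gear .\tSound carries well over water ."
        # Each line pairs with a single T/F label in the .gold.txt file.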
33 | 34 | for line, label in zip(sentences, labels): 35 | tokens = line.split('\t') 36 | assert len(tokens) == 5, tokens 37 | 38 | text_a = tokens[3] 39 | text_b = tokens[4] 40 | if self.entity_markers: 41 | # insert entity markers 42 | idx1, idx2 = [int(ind) for ind in tokens[2].split('-')] 43 | tokens_a = text_a.strip().split() 44 | tokens_b = text_b.strip().split() 45 | tokens_a.insert(idx1, '[e1start]') 46 | tokens_a.insert(idx1 + 2, '[e1end]') 47 | tokens_b.insert(idx2, '[e2start]') 48 | tokens_b.insert(idx2 + 2, '[e2end]') 49 | text_a = ' '.join(tokens_a) 50 | text_b = ' '.join(tokens_b) 51 | 52 | token_candidates = self.tokenizer.tokenize_and_generate_candidates(text_a, text_b) 53 | fields = self.tokenizer.convert_tokens_candidates_to_fields(token_candidates) 54 | fields['label_ids'] = LabelField(self.label_to_index[label], skip_indexing=True) 55 | 56 | # get the indices of the marked words 57 | # index in the original tokens 58 | idx1, idx2 = [int(ind) for ind in tokens[2].split('-')] 59 | offsets_a = [1] + token_candidates['offsets_a'][:-1] 60 | idx1_offset = offsets_a[idx1] 61 | offsets_b = [token_candidates['offsets_a'][-1] + 1] + token_candidates['offsets_b'][:-1] 62 | idx2_offset = offsets_b[idx2] 63 | 64 | fields['index_a'] = LabelField(idx1_offset, skip_indexing=True) 65 | fields['index_b'] = LabelField(idx2_offset, skip_indexing=True) 66 | 67 | instance = Instance(fields) 68 | 69 | yield instance 70 | -------------------------------------------------------------------------------- /kb/include_all.py: -------------------------------------------------------------------------------- 1 | 2 | from kb.kg_embedding import KGTupleReader, KGTupleModel 3 | from kb.entity_linking import TokenCharactersIndexerTokenizer 4 | from kb.entity_linking import CrossSentenceLinking 5 | from kb.wordnet import WordNetFineGrainedSenseDisambiguationReader 6 | from kb.wordnet import WordNetAllEmbedding 7 | from kb.multitask import MultitaskDatasetReader, MultiTaskDataIterator 8 | from kb.bert_pretraining_reader import BertPreTrainingReader 9 | from kb.bert_tokenizer_and_candidate_generator import BertTokenizerAndCandidateGenerator, TokenizerAndCandidateGenerator 10 | from kb.self_attn_bucket_iterator import SelfAttnBucketIterator 11 | from kb.knowbert import KnowBert, BertPretrainedMaskedLM 12 | from kb.bert_utils import GeLu 13 | from kb.wiki_linking_reader import LinkingReader 14 | from kb.kg_probe_reader import KgProbeReader 15 | 16 | from kb.evaluation.classification_model import SimpleClassifier 17 | from kb.evaluation.tacred_dataset_reader import TacredDatasetReader 18 | from kb.evaluation.wic_dataset_reader import WicDatasetReader 19 | from kb.evaluation.semeval2010_task8 import SemEval2010Task8Reader, SemEval2010Task8Metric 20 | from kb.evaluation.fbeta_measure import FBetaMeasure 21 | from kb.evaluation.ultra_fine_reader import UltraFineReader 22 | 23 | from kb.common import F1Metric 24 | 25 | from allennlp.models.archival import load_archive 26 | from allennlp.models import Model 27 | 28 | import json 29 | 30 | 31 | @Model.register("from_archive") 32 | class ModelArchiveFromParams(Model): 33 | """ 34 | Loads a model from an archive 35 | """ 36 | @classmethod 37 | def from_params(cls, vocab=None, params=None): 38 | """ 39 | {"type": "from_archive", "archive_file": "path to archive", 40 | "overrides:" .... 
} 41 | 42 | "overrides" omits the "model" key 43 | """ 44 | archive_file = params.pop("archive_file") 45 | overrides = params.pop("overrides", None) 46 | params.assert_empty("ModelArchiveFromParams") 47 | if overrides is not None: 48 | archive = load_archive(archive_file, overrides=json.dumps({'model': overrides.as_dict()})) 49 | else: 50 | archive = load_archive(archive_file) 51 | return archive.model 52 | 53 | -------------------------------------------------------------------------------- /kb/kg_probe_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import codecs 5 | from typing import Dict, List, Iterable, Tuple 6 | 7 | from overrides import overrides 8 | 9 | from allennlp.common.file_utils import cached_path 10 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 11 | from allennlp.data.fields import Field, TextField, SequenceLabelField, LabelField, ListField, ArrayField 12 | from allennlp.data.instance import Instance 13 | from allennlp.data.tokenizers import Token 14 | from allennlp.data.token_indexers.wordpiece_indexer import PretrainedBertIndexer 15 | 16 | from kb.bert_tokenizer_and_candidate_generator import TokenizerAndCandidateGenerator, start_token, sep_token 17 | 18 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 19 | 20 | 21 | @DatasetReader.register('kg_probe') 22 | class KgProbeReader(DatasetReader): 23 | """ 24 | This DatasetReader is designed to read in sentences that render information contained in 25 | knowledge graph triples. Similar to the BertPreTrainingReader, but leverages provided entity 26 | spans to ensure that entity-related tokens are properly masked out. 27 | 28 | It returns a dataset of instances with the following fields: 29 | 30 | tokens : ``TextField`` 31 | The WordPiece tokens in the sentence. 32 | segment_ids : ``SequenceLabelField`` 33 | The labels of each of the tokens (0 - tokens from the first sentence, 34 | 1 - tokens from the second sentence). 35 | lm_label_ids : ``SequenceLabelField`` 36 | For each masked position, what is the correct label. 37 | next_sentence_label : ``LabelField`` 38 | Next sentence label: is the second sentence the next sentence following the 39 | first one, or is it a randomly selected sentence. 40 | candidates: ``DictField`` 41 | """ 42 | def __init__(self, 43 | tokenizer_and_candidate_generator: TokenizerAndCandidateGenerator, 44 | lazy: bool = False) -> None: 45 | 46 | super().__init__(lazy) 47 | 48 | self._tokenizer_and_candidate_generator = tokenizer_and_candidate_generator 49 | self._label_indexer = { 50 | "lm_labels": tokenizer_and_candidate_generator._bert_single_id_indexer["tokens"] 51 | } 52 | 53 | def _read(self, file_path: str): 54 | with open(cached_path(file_path), 'r') as f: 55 | for line in f: 56 | span_text, sentence = line.strip().split('\t') 57 | span = tuple(int(x) for x in span_text.split()) 58 | yield self.text_to_instance(sentence, span) 59 | 60 | def text_to_instance(self, sentence: str, span: Tuple[int, ...]): 61 | token_candidates = self._tokenizer_and_candidate_generator.tokenize_and_generate_candidates(sentence) 62 | 63 | # NOTE: Skipping the padding here since sentences are all quite short. 
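        # A rough sketch of what happens below for the fixture line "0 0\tCats are quickest .":
        # span = (0, 0) selects the word "Cats", whose wordpieces (roughly
        # "cat", "##s" with the lowercasing test vocab) are replaced by [MASK],
        # mask_indicator is set to 1 at those positions, and any entity
        # candidates whose spans overlap the word are replaced by @@MASK@@
        # with prior 1.0.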
64 | vocab = self._tokenizer_and_candidate_generator.bert_tokenizer.vocab 65 | lm_label_ids = TextField( 66 | [Token(t, text_id=vocab[t]) for t in token_candidates['tokens']], 67 | token_indexers=self._label_indexer 68 | ) 69 | 70 | # We need to offset the start and end of the span so that it aligns with word pieces. 71 | if span[0] == 0: 72 | start = 1 # Since 0'th elt. is 73 | else: 74 | start = token_candidates['offsets_a'][span[0] - 1] 75 | end = token_candidates['offsets_a'][span[1]] 76 | 77 | masked_tokens: List[str] = token_candidates['tokens'].copy() 78 | mask_indicator = np.zeros(len(masked_tokens), dtype=np.uint8) 79 | for i in range(start, end): 80 | masked_tokens[i] = '[MASK]' 81 | mask_indicator[i] = 1 82 | 83 | token_candidates['tokens'] = masked_tokens 84 | 85 | # mask out the entity candidates 86 | candidates = token_candidates['candidates'] 87 | for candidate_key in candidates.keys(): 88 | indices_to_mask = [] 89 | for k, candidate_span in enumerate(candidates[candidate_key]['candidate_spans']): 90 | # (end-1) as candidate spans are exclusive (e.g. candidate_span = (0, 0) has start=0, end=1) 91 | if (candidate_span[0] >= start and candidate_span[0] <= end-1) or ( 92 | candidate_span[1] >= start and candidate_span[1] <= end-1): 93 | indices_to_mask.append(k) 94 | for ind in indices_to_mask: 95 | candidates[candidate_key]['candidate_entities'][ind] = ['@@MASK@@'] 96 | candidates[candidate_key]['candidate_entity_priors'][ind] = [1.0] 97 | 98 | fields = self._tokenizer_and_candidate_generator. \ 99 | convert_tokens_candidates_to_fields(token_candidates) 100 | 101 | fields['lm_label_ids'] = lm_label_ids 102 | fields['mask_indicator'] = ArrayField(mask_indicator, dtype=np.uint8) 103 | 104 | return Instance(fields) 105 | -------------------------------------------------------------------------------- /kb/knowbert_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Union, List 3 | 4 | from allennlp.common import Params 5 | from allennlp.data import Instance, DataIterator, Vocabulary 6 | from allennlp.common.file_utils import cached_path 7 | 8 | 9 | from kb.include_all import TokenizerAndCandidateGenerator 10 | from kb.bert_pretraining_reader import replace_candidates_with_mask_entity 11 | 12 | import json 13 | 14 | 15 | def _extract_config_from_archive(model_archive): 16 | import tarfile 17 | import tempfile 18 | import os 19 | with tempfile.TemporaryDirectory() as tmp: 20 | with tarfile.open(model_archive, 'r:gz') as archive: 21 | archive.extract('config.json', path=tmp) 22 | config = Params.from_file(os.path.join(tmp, 'config.json')) 23 | return config 24 | 25 | 26 | def _find_key(d, key): 27 | val = None 28 | stack = [d.items()] 29 | while len(stack) > 0 and val is None: 30 | s = stack.pop() 31 | for k, v in s: 32 | if k == key: 33 | val = v 34 | break 35 | elif isinstance(v, dict): 36 | stack.append(v.items()) 37 | return val 38 | 39 | 40 | class KnowBertBatchifier: 41 | """ 42 | Takes a list of sentence strings and returns a tensor dict usable with 43 | a KnowBert model 44 | """ 45 | def __init__(self, model_archive, batch_size=32, 46 | masking_strategy=None, 47 | wordnet_entity_file=None, vocab_dir=None): 48 | 49 | # get bert_tokenizer_and_candidate_generator 50 | config = _extract_config_from_archive(cached_path(model_archive)) 51 | 52 | # look for the bert_tokenizers and candidate_generator 53 | candidate_generator_params = _find_key( 54 | config['dataset_reader'].as_dict(), 
'tokenizer_and_candidate_generator' 55 | ) 56 | 57 | if wordnet_entity_file is not None: 58 | candidate_generator_params['entity_candidate_generators']['wordnet']['entity_file'] = wordnet_entity_file 59 | 60 | self.tokenizer_and_candidate_generator = TokenizerAndCandidateGenerator.\ 61 | from_params(Params(candidate_generator_params)) 62 | self.tokenizer_and_candidate_generator.whitespace_tokenize = False 63 | 64 | assert masking_strategy is None or masking_strategy == 'full_mask' 65 | self.masking_strategy = masking_strategy 66 | 67 | # need bert_tokenizer_and_candidate_generator 68 | if vocab_dir is not None: 69 | vocab_params = Params({"directory_path": vocab_dir}) 70 | else: 71 | vocab_params = config['vocabulary'] 72 | self.vocab = Vocabulary.from_params(vocab_params) 73 | 74 | self.iterator = DataIterator.from_params( 75 | Params({"type": "basic", "batch_size": batch_size}) 76 | ) 77 | self.iterator.index_with(self.vocab) 78 | 79 | def _replace_mask(self, s): 80 | return s.replace('[MASK]', ' [MASK] ') 81 | 82 | def iter_batches(self, sentences_or_sentence_pairs: Union[List[str], List[List[str]]], verbose=True): 83 | # create instances 84 | instances = [] 85 | for sentence_or_sentence_pair in sentences_or_sentence_pairs: 86 | if isinstance(sentence_or_sentence_pair, list): 87 | assert len(sentence_or_sentence_pair) == 2 88 | tokens_candidates = self.tokenizer_and_candidate_generator.\ 89 | tokenize_and_generate_candidates( 90 | self._replace_mask(sentence_or_sentence_pair[0]), 91 | self._replace_mask(sentence_or_sentence_pair[1])) 92 | else: 93 | tokens_candidates = self.tokenizer_and_candidate_generator.\ 94 | tokenize_and_generate_candidates(self._replace_mask(sentence_or_sentence_pair)) 95 | 96 | if verbose: 97 | print(self._replace_mask(sentence_or_sentence_pair)) 98 | print(tokens_candidates['tokens']) 99 | 100 | # now modify the masking if needed 101 | if self.masking_strategy == 'full_mask': 102 | # replace the mask span with a @@mask@@ span 103 | masked_indices = [index for index, token in enumerate(tokens_candidates['tokens']) 104 | if token == '[MASK]'] 105 | 106 | spans_to_mask = set([(i, i) for i in masked_indices]) 107 | replace_candidates_with_mask_entity( 108 | tokens_candidates['candidates'], spans_to_mask 109 | ) 110 | 111 | # now make sure the spans are actually masked 112 | for key in tokens_candidates['candidates'].keys(): 113 | for span_to_mask in spans_to_mask: 114 | found = False 115 | for span in tokens_candidates['candidates'][key]['candidate_spans']: 116 | if tuple(span) == tuple(span_to_mask): 117 | found = True 118 | if not found: 119 | tokens_candidates['candidates'][key]['candidate_spans'].append(list(span_to_mask)) 120 | tokens_candidates['candidates'][key]['candidate_entities'].append(['@@MASK@@']) 121 | tokens_candidates['candidates'][key]['candidate_entity_priors'].append([1.0]) 122 | tokens_candidates['candidates'][key]['candidate_segment_ids'].append(0) 123 | # hack, assume only one sentence 124 | assert not isinstance(sentence_or_sentence_pair, list) 125 | 126 | 127 | fields = self.tokenizer_and_candidate_generator.\ 128 | convert_tokens_candidates_to_fields(tokens_candidates) 129 | 130 | instances.append(Instance(fields)) 131 | 132 | 133 | for batch in self.iterator(instances, num_epochs=1, shuffle=False): 134 | yield batch 135 | 136 | -------------------------------------------------------------------------------- /kb/multitask.py: -------------------------------------------------------------------------------- 1 | """ 2 | More fully featured 
dataset readers and iterators for multitask training 3 | than allennlp. 4 | 5 | Differences: 6 | - randomly sample batches from each dataset according to dataset size, 7 | so gradient steps are spread out over each dataset throughout the course 8 | of training. 9 | - allows using any generic iterator for each dataset 10 | - allows excluding some datasets from vocab creation 11 | 12 | Implementation in allennlp: 13 | 14 | Interface for dataset and iterator in allennlp trainer: 15 | train_generator = self._iterator(self._train_data, 16 | num_epochs=1, 17 | shuffle=shuffle) 18 | num_training_batches = self._iterator.get_num_batches(self._train_data) 19 | 20 | Interface for dataset and iterator in train command: 21 | instances_for_vocab = [] 22 | for instance in dataset: 23 | instances_for_vocab.append(instance) 24 | --> then pass into Vocabulary.from_params(...) 25 | 26 | So the dataset needs to implement __iter__, but it is only called 27 | to construct the Vocabulary, provided we also pair this dataset with a 28 | special iterator that doesn't call __iter__. 29 | """ 30 | 31 | from typing import Dict, List, Iterable 32 | 33 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 34 | from allennlp.data import Instance 35 | from allennlp.data.iterators import DataIterator 36 | from allennlp.data import Vocabulary 37 | 38 | import numpy as np 39 | 40 | import torch 41 | 42 | 43 | class MultitaskDataset: 44 | def __init__(self, datasets: Dict[str, Iterable[Instance]], 45 | datasets_for_vocab_creation: List[str]): 46 | self.datasets = datasets 47 | self.datasets_for_vocab_creation = datasets_for_vocab_creation 48 | 49 | def __iter__(self): 50 | # with our iterator, this is only called for vocab creation 51 | for key in self.datasets_for_vocab_creation: 52 | for instance in self.datasets[key]: 53 | yield instance 54 | 55 | 56 | @DatasetReader.register("multitask_reader") 57 | class MultitaskDatasetReader(DatasetReader): 58 | def __init__(self, 59 | dataset_readers: Dict[str, DatasetReader], 60 | datasets_for_vocab_creation: List[str]) -> None: 61 | super().__init__(False) 62 | self.dataset_readers = dataset_readers 63 | self.datasets_for_vocab_creation = datasets_for_vocab_creation 64 | 65 | def read(self, file_path: Dict[str, str]): 66 | """ 67 | read returns an iterable of instances that is directly 68 | iterated over when constructing vocab, and in the iterators. 69 | Since we will also pair this reader with a special iterator, 70 | we only have to worry about the case where the return value from 71 | this call is used to iterate for vocab creation.
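        For example, ``file_path`` might be something like
        ``{"wordnet": "/path/to/wordnet_train.json", "wiki": "/path/to/wiki_train.txt"}``
        (illustrative paths), with one entry per key in ``dataset_readers``.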
72 | 73 | In addition, it is the return value from this that is passed 74 | into Trainer as the dataset (and then into the iterator) 75 | """ 76 | datasets = {key: self.dataset_readers[key].read(fpath) 77 | for key, fpath in file_path.items()} 78 | return MultitaskDataset(datasets, self.datasets_for_vocab_creation) 79 | 80 | 81 | @DataIterator.register("multitask_iterator") 82 | class MultiTaskDataIterator(DataIterator): 83 | def __init__(self, 84 | iterators: Dict[str, DataIterator], 85 | names_to_index: List[str], 86 | iterate_forever: bool = False, 87 | sampling_rates: List[float] = None) -> None: 88 | self.iterators = iterators 89 | self.names_to_index = names_to_index 90 | self.sampling_rates = sampling_rates 91 | self.iterate_forever = iterate_forever 92 | 93 | def __call__(self, 94 | multitask_dataset: MultitaskDataset, 95 | num_epochs: int = None, 96 | shuffle: bool = True): 97 | 98 | # get the number of batches in each of the sub-iterators for 99 | # the sampling rate 100 | num_batches_per_iterator = [] 101 | for name in self.names_to_index: 102 | dataset = multitask_dataset.datasets[name] 103 | num_batches_per_iterator.append( 104 | self.iterators[name].get_num_batches(dataset) 105 | ) 106 | 107 | total_batches_per_epoch = sum(num_batches_per_iterator) 108 | 109 | # make the sampling rates -- 110 | p = np.array(num_batches_per_iterator, dtype=np.float) \ 111 | / total_batches_per_epoch 112 | 113 | if self.iterate_forever: 114 | total_batches_per_epoch = 1000000000 115 | if self.sampling_rates is not None: 116 | p = np.array(self.sampling_rates, dtype=np.float) 117 | 118 | for epoch in range(num_epochs): 119 | generators = [] 120 | for name in self.names_to_index: 121 | dataset = multitask_dataset.datasets[name] 122 | generators.append( 123 | self.iterators[name]( 124 | dataset, 125 | num_epochs=1, 126 | shuffle=shuffle, 127 | ) 128 | ) 129 | 130 | n_batches_this_epoch = 0 131 | all_indices = np.arange(len(generators)).tolist() 132 | while n_batches_this_epoch < total_batches_per_epoch: 133 | index = np.random.choice(len(generators), p=p) 134 | try: 135 | batch = next(generators[index]) 136 | except StopIteration: 137 | # remove this generator from the pile! 
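                    # and renormalize the sampling probabilities over the
                    # remaining generators, e.g. p = [0.5, 0.3, 0.2] with the
                    # second dataset exhausted becomes [0.5, 0.2] / 0.7
                    # ~= [0.71, 0.29]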
138 | del generators[index] 139 | if len(generators) == 0: 140 | # something went wrong 141 | raise ValueError 142 | del all_indices[index] 143 | newp = np.concatenate([p[:index], p[index+1:]]) 144 | newp /= newp.sum() 145 | p = newp 146 | continue 147 | 148 | # add the iterator id 149 | batch['dataset_index'] = torch.tensor(all_indices[index]) 150 | yield batch 151 | 152 | n_batches_this_epoch += 1 153 | 154 | def _take_instances(self, *args, **kwargs): 155 | raise NotImplementedError 156 | 157 | def _memory_sized_lists(self, *args, **kwargs): 158 | raise NotImplementedError 159 | 160 | def _ensure_batch_is_sufficiently_small(self, *args, **kwargs): 161 | raise NotImplementedError 162 | 163 | def get_num_batches(self, multitask_dataset: MultitaskDataset) -> int: 164 | num_batches = 0 165 | for name, dataset in multitask_dataset.datasets.items(): 166 | num_batches += self.iterators[name].get_num_batches(dataset) 167 | return num_batches 168 | 169 | def index_with(self, vocab: Vocabulary): 170 | for iterator in self.iterators.values(): 171 | iterator.index_with(vocab) 172 | 173 | -------------------------------------------------------------------------------- /kb/self_attn_bucket_iterator.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import random 4 | from collections import deque 5 | from typing import List, Tuple, Iterable, cast, Dict, Deque 6 | 7 | from overrides import overrides 8 | 9 | from allennlp.common.checks import ConfigurationError 10 | from allennlp.common.util import lazy_groups_of, add_noise_to_dict_values 11 | from allennlp.data.dataset import Batch 12 | from allennlp.data.instance import Instance 13 | from allennlp.data.iterators.data_iterator import DataIterator 14 | from allennlp.data.vocabulary import Vocabulary 15 | 16 | import bisect 17 | 18 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 19 | 20 | from allennlp.data.iterators.bucket_iterator import sort_by_padding 21 | 22 | 23 | SCHEDULES = { 24 | "base-24gb-bs64_fp32": [ 25 | [64, 115], 26 | [32, 220], 27 | [16, 380], 28 | [8, 512] 29 | ], 30 | "base-12gb-fp32": [ 31 | [32, 90], 32 | [16, 170], 33 | [8, 300], 34 | [4, 400], 35 | [2, 512] 36 | ], 37 | "base-11gb-fp32": [ 38 | [32, 80], 39 | [16, 150], 40 | [8, 270], 41 | [4, 370], 42 | [2, 512] 43 | ], 44 | "base-24gb-fp32": [ 45 | [32, 140], 46 | [16, 280], 47 | [8, 400], 48 | [4, 512], 49 | ], 50 | } 51 | 52 | 53 | @DataIterator.register("self_attn_bucket") 54 | class SelfAttnBucketIterator(DataIterator): 55 | """ 56 | Like a bucket iterator, but with a quadratic relationship between 57 | sequence length and batch size instead of linear. 58 | 59 | Has a fixed schedule of batch size vs sequence length. 
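    For example, with the "base-24gb-fp32" schedule a batch whose longest
    sequence is 300 wordpieces falls in the [8, 400] bucket and is re-split
    into chunks of at most 8 instances; batches longer than 512 wordpieces
    are skipped entirely.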
60 | """ 61 | def __init__(self, 62 | batch_size_schedule: str, 63 | iterator: DataIterator): 64 | 65 | if isinstance(batch_size_schedule, str): 66 | schedule = SCHEDULES[batch_size_schedule] 67 | else: 68 | # user is providing a dict directly 69 | schedule = batch_size_schedule 70 | 71 | # set batch size to max value in schedule 72 | batch_size = schedule[0][0] 73 | 74 | super().__init__( 75 | batch_size=batch_size, 76 | instances_per_epoch=iterator._instances_per_epoch, 77 | max_instances_in_memory=iterator._max_instances_in_memory, 78 | cache_instances=iterator._cache_instances, 79 | track_epoch=iterator._track_epoch, 80 | maximum_samples_per_batch=iterator._maximum_samples_per_batch 81 | ) 82 | 83 | self.iterator = iterator 84 | 85 | # process the schedule 86 | self._schedule_batch_sizes = [ele[0] for ele in schedule] 87 | self._schedule_lengths = [ele[1] for ele in schedule] 88 | 89 | def index_with(self, vocab: Vocabulary): 90 | self.vocab = vocab 91 | self.iterator.index_with(vocab) 92 | 93 | @overrides 94 | def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: 95 | for batch in self.iterator._create_batches(instances, shuffle): 96 | # split after shuffling so smaller batches are kept together 97 | batch_instances = batch.instances 98 | 99 | # split if needed 100 | batch_length = -1 101 | for instance in batch_instances: 102 | instance.index_fields(self.vocab) 103 | field_lengths = instance.get_padding_lengths() 104 | batch_length = max(batch_length, field_lengths['tokens']['num_tokens']) 105 | 106 | # get the required batch size 107 | index = bisect.bisect_left(self._schedule_lengths, batch_length) 108 | if index == len(self._schedule_lengths): 109 | # this batch exceeds the maximum allowed, just skip it 110 | continue 111 | batch_size = self._schedule_batch_sizes[index] 112 | start = 0 113 | while start < len(batch_instances): 114 | end = start + batch_size 115 | yield Batch(batch_instances[start:end]) 116 | start = end 117 | 118 | -------------------------------------------------------------------------------- /kb/span_attention_layer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import math 4 | 5 | from pytorch_pretrained_bert.modeling import BertIntermediate, BertOutput, BertLayer, BertSelfOutput 6 | 7 | from kb.common import get_dtype_for_module, extend_attention_mask_for_bert, get_linear_layer_init_identity, init_bert_weights 8 | 9 | 10 | class SpanWordAttention(torch.nn.Module): 11 | def __init__(self, config): 12 | super(SpanWordAttention, self).__init__() 13 | if config.hidden_size % config.num_attention_heads != 0: 14 | raise ValueError( 15 | "The hidden size (%d) is not a multiple of the number of attention " 16 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 17 | self.num_attention_heads = config.num_attention_heads 18 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 19 | self.all_head_size = self.num_attention_heads * self.attention_head_size 20 | 21 | #self.query = get_linear_layer_init_identity(config.hidden_size) 22 | #self.key = get_linear_layer_init_identity(config.hidden_size) 23 | #self.value = get_linear_layer_init_identity(config.hidden_size) 24 | 25 | self.query = torch.nn.Linear(config.hidden_size, self.all_head_size) 26 | self.key = torch.nn.Linear(config.hidden_size, self.all_head_size) 27 | self.value = torch.nn.Linear(config.hidden_size, self.all_head_size) 28 | 29 | self.dropout = 
torch.nn.Dropout(config.attention_probs_dropout_prob) 30 | 31 | def transpose_for_scores(self, x): 32 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 33 | x = x.view(*new_x_shape) 34 | return x.permute(0, 2, 1, 3) 35 | 36 | def forward(self, hidden_states, entity_embeddings, entity_mask): 37 | """ 38 | hidden_states = (batch_size, timesteps, dim) 39 | entity_embeddings = (batch_size, num_entities, dim) 40 | entity_mask = (batch_size, num_entities) with 0/1 41 | """ 42 | mixed_query_layer = self.query(hidden_states) 43 | mixed_key_layer = self.key(entity_embeddings) 44 | mixed_value_layer = self.value(entity_embeddings) 45 | 46 | # (batch_size, num_heads, timesteps, head_size) 47 | query_layer = self.transpose_for_scores(mixed_query_layer) 48 | # (batch_size, num_heads, num_entity_embeddings, head_size) 49 | key_layer = self.transpose_for_scores(mixed_key_layer) 50 | value_layer = self.transpose_for_scores(mixed_value_layer) 51 | 52 | # Take the dot product between "query" and "key" to get the raw attention scores. 53 | # (batch_size, num_heads, timesteps, num_entity_embeddings) 54 | # gives the attention from timestep i to embedding j 55 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 56 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 57 | 58 | # apply the attention mask. 59 | # the attention_mask masks out thing to attend TO so we extend 60 | # the entity mask 61 | attention_mask = extend_attention_mask_for_bert(entity_mask, get_dtype_for_module(self)) 62 | attention_scores = attention_scores + attention_mask 63 | 64 | # Normalize the attention scores to probabilities. 65 | attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) 66 | 67 | # This is actually dropping out entire entities to attend to, which might 68 | # seem a bit unusual, but is similar to the original Transformer paper. 
69 | attention_probs = self.dropout(attention_probs) 70 | 71 | # (batch_size, num_heads, timesteps, head_size) 72 | context_layer = torch.matmul(attention_probs, value_layer) 73 | # (batch_size, timesteps, num_heads, head_size) 74 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 75 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 76 | # (batch_size, timesteps, hidden_dim) 77 | context_layer = context_layer.view(*new_context_layer_shape) 78 | return context_layer, attention_probs 79 | 80 | 81 | class SpanAttention(torch.nn.Module): 82 | def __init__(self, config): 83 | super(SpanAttention, self).__init__() 84 | self.attention = SpanWordAttention(config) 85 | init_bert_weights(self.attention, config.initializer_range, (SpanWordAttention, )) 86 | self.output = BertSelfOutput(config) 87 | init_bert_weights(self.output, config.initializer_range) 88 | 89 | def forward(self, input_tensor, entity_embeddings, entity_mask): 90 | span_output, attention_probs = self.attention(input_tensor, entity_embeddings, entity_mask) 91 | attention_output = self.output(span_output, input_tensor) 92 | return attention_output, attention_probs 93 | 94 | 95 | class SpanAttentionLayer(torch.nn.Module): 96 | # WARNING: does it's own init, so don't re-init 97 | def __init__(self, config): 98 | super(SpanAttentionLayer, self).__init__() 99 | self.attention = SpanAttention(config) 100 | self.intermediate = BertIntermediate(config) 101 | self.output = BertOutput(config) 102 | init_bert_weights(self.intermediate, config.initializer_range) 103 | init_bert_weights(self.output, config.initializer_range) 104 | 105 | def forward(self, hidden_states, entity_embeddings, entity_mask): 106 | attention_output, attention_probs = self.attention(hidden_states, entity_embeddings, entity_mask) 107 | intermediate_output = self.intermediate(attention_output) 108 | layer_output = self.output(intermediate_output, attention_output) 109 | return {"output": layer_output, "attention_probs": attention_probs} 110 | 111 | -------------------------------------------------------------------------------- /kb/testing.py: -------------------------------------------------------------------------------- 1 | 2 | from allennlp.data import DatasetReader, Vocabulary, DataIterator, TokenIndexer 3 | from allennlp.common import Params 4 | from allennlp.modules import TokenEmbedder 5 | from allennlp.models import Model 6 | 7 | import torch 8 | 9 | from kb.include_all import * 10 | 11 | 12 | def get_bert_test_fixture(): 13 | embedder_params = { 14 | "type": "bert-pretrained", 15 | "pretrained_model": "tests/fixtures/bert/bert_test_fixture.tar.gz", 16 | "requires_grad": True, 17 | "top_layer_only": True, 18 | } 19 | embedder_params_copy = dict(embedder_params) 20 | embedder = TokenEmbedder.from_params(Params(embedder_params)) 21 | 22 | 23 | indexer_params = { 24 | "type": "bert-pretrained", 25 | "pretrained_model": "tests/fixtures/bert/vocab.txt", 26 | "do_lowercase": True, 27 | "use_starting_offsets": True, 28 | "max_pieces": 512, 29 | } 30 | indexer_params_copy = dict(indexer_params) 31 | indexer = TokenIndexer.from_params(Params(indexer_params)) 32 | 33 | return {'embedder': embedder, 'embedder_params': embedder_params_copy, 34 | 'indexer': indexer, 'indexer_params': indexer_params_copy} 35 | 36 | 37 | def get_wsd_reader(is_training, use_bert_indexer=False, wordnet_entity_file=None): 38 | if wordnet_entity_file is None: 39 | wordnet_entity_file = "tests/fixtures/wordnet/entities_cat_hat.jsonl" 40 | 41 | if 
use_bert_indexer: 42 | bert_fixtures = get_bert_test_fixture() 43 | indexer_params = bert_fixtures["indexer_params"] 44 | else: 45 | indexer_params = {"type": "single_id", "lowercase_tokens": True} 46 | 47 | reader_params = { 48 | "type": "wordnet_fine_grained", 49 | "wordnet_entity_file": wordnet_entity_file, 50 | "token_indexers": { 51 | "tokens": indexer_params, 52 | }, 53 | "entity_indexer": { 54 | "type": "characters_tokenizer", 55 | "tokenizer": { 56 | "type": "word", 57 | "word_splitter": {"type": "just_spaces"}, 58 | }, 59 | "namespace": "entity" 60 | }, 61 | "is_training": is_training, 62 | "use_surface_form": False 63 | } 64 | reader = DatasetReader.from_params(Params(reader_params)) 65 | 66 | vocab_params = { 67 | "directory_path": "tests/fixtures/wordnet/cat_hat_vocabdir" 68 | } 69 | vocab = Vocabulary.from_params(Params(vocab_params)) 70 | 71 | iterator = DataIterator.from_params(Params({"type": "basic"})) 72 | iterator.index_with(vocab) 73 | 74 | return reader, vocab, iterator 75 | 76 | 77 | def get_wsd_fixture_batch(is_training, use_bert_indexer=False): 78 | wsd_file = 'tests/fixtures/wordnet/wsd_dataset.json' 79 | reader, vocab, iterator = get_wsd_reader(is_training, use_bert_indexer=use_bert_indexer) 80 | instances = reader.read(wsd_file) 81 | 82 | for batch in iterator(instances, shuffle=False, num_epochs=1): 83 | break 84 | return batch 85 | 86 | 87 | def get_bert_pretraining_reader_with_kg( 88 | mask_candidate_strategy='none', masked_lm_prob=0.15, include_wiki=False 89 | ): 90 | params = { 91 | "type": "bert_pre_training", 92 | "tokenizer_and_candidate_generator": { 93 | "type": "bert_tokenizer_and_candidate_generator", 94 | "entity_candidate_generators": { 95 | "wordnet": {"type": "wordnet_mention_generator", 96 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 97 | }, 98 | "entity_indexers": { 99 | "wordnet": { 100 | "type": "characters_tokenizer", 101 | "tokenizer": { 102 | "type": "word", 103 | "word_splitter": {"type": "just_spaces"}, 104 | }, 105 | "namespace": "entity" 106 | } 107 | }, 108 | "bert_model_type": "tests/fixtures/bert/vocab.txt", 109 | "do_lower_case": True, 110 | }, 111 | "mask_candidate_strategy": mask_candidate_strategy, 112 | "masked_lm_prob": masked_lm_prob 113 | } 114 | 115 | if include_wiki: 116 | params["tokenizer_and_candidate_generator"]["entity_candidate_generators"]["wiki"] = { 117 | "type": "wiki", 118 | "candidates_file": "tests/fixtures/linking/priors.txt", 119 | } 120 | params["tokenizer_and_candidate_generator"]["entity_indexers"]["wiki"] = { 121 | "type": "characters_tokenizer", 122 | "tokenizer": { 123 | "type": "word", 124 | "word_splitter": {"type": "just_spaces"}, 125 | }, 126 | "namespace": "entity_wiki" 127 | } 128 | params["tokenizer_and_candidate_generator"]["entity_indexers"]["wordnet"]["namespace"] = "entity_wordnet" 129 | 130 | return DatasetReader.from_params(Params(params)) 131 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # this is the fp16_e_s3 branch 2 | git+git://github.com/matt-peters/allennlp.git@2d7ba1cb108428aaffe2dce875648253b44cb5ba 3 | pytest 4 | nltk 5 | ipython 6 | spacy 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from distutils.core import setup 3 | 4 | 5 | setup( 6 | name='kb', 7 | 
version='0.1', 8 | packages=['kb', 'kb.evaluation'], 9 | entry_points='', 10 | package_dir={'kb': 'kb'}, 11 | ) 12 | 13 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/kb/4c37bbccd6c871828aceb24b342841825d86df14/tests/__init__.py -------------------------------------------------------------------------------- /tests/evaluation/test_semeval2010_task8.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | from kb.include_all import SemEval2010Task8Reader, SemEval2010Task8Metric 4 | from allennlp.common import Params 5 | from allennlp.data import DatasetReader, DataIterator, Vocabulary 6 | import torch 7 | 8 | 9 | class TestSemEval2010Task8Metric(unittest.TestCase): 10 | def test_semeval2010_metric(self): 11 | predicted_ids = [ 12 | torch.tensor([0, 15, 3]), 13 | torch.tensor([7, 0]) 14 | ] 15 | gold_ids = [ 16 | torch.tensor([0, 3, 3]), 17 | torch.tensor([0, 7]) 18 | ] 19 | 20 | expected_f1 = 33.33 21 | 22 | metric = SemEval2010Task8Metric() 23 | for p, g in zip(predicted_ids, gold_ids): 24 | metric(p, g) 25 | 26 | f1 = metric.get_metric() 27 | 28 | self.assertAlmostEqual(expected_f1, f1) 29 | 30 | 31 | class TestSemEval2010Task8Reader(unittest.TestCase): 32 | def test_semeval2010_task8_reader(self): 33 | reader_params = Params({ 34 | "type": "semeval2010_task8", 35 | "tokenizer_and_candidate_generator": { 36 | "type": "bert_tokenizer_and_candidate_generator", 37 | "entity_candidate_generators": { 38 | "wordnet": {"type": "wordnet_mention_generator", 39 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 40 | }, 41 | "entity_indexers": { 42 | "wordnet": { 43 | "type": "characters_tokenizer", 44 | "tokenizer": { 45 | "type": "word", 46 | "word_splitter": {"type": "just_spaces"}, 47 | }, 48 | "namespace": "entity" 49 | } 50 | }, 51 | "bert_model_type": "tests/fixtures/bert/vocab.txt", 52 | "do_lower_case": True, 53 | }, 54 | }) 55 | 56 | reader = DatasetReader.from_params(reader_params) 57 | train_file = 'tests/fixtures/evaluation/semeval2010_task8/semeval2010_task8.json' 58 | 59 | instances = reader.read(train_file) 60 | 61 | # check that the offsets are right! 
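        # (segment id 1 marks the wordpieces of the first entity span
        # ("big cats"), 2 marks the second entity span ("laziest brown dogs"),
        # and 0 everything else)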
62 | segment_ids = instances[0]['segment_ids'].array.tolist() 63 | tokens = [t.text for t in instances[0]['tokens'].tokens] 64 | 65 | tokens_and_segments = list(zip(tokens, segment_ids)) 66 | 67 | expected_tokens_and_segments = [ 68 | ('[CLS]', 0), 69 | ('the', 0), 70 | ('big', 1), 71 | ('cat', 1), 72 | ('##s', 1), 73 | ('jumped', 0), 74 | ('[UNK]', 0), 75 | ('the', 0), 76 | ('la', 2), 77 | ('##zie', 2), 78 | ('##st', 2), 79 | ('brown', 2), 80 | ('dog', 2), 81 | ('##s', 2), 82 | ('.', 0), 83 | ('[SEP]', 0) 84 | ] 85 | 86 | self.assertEqual( 87 | tokens_and_segments, 88 | expected_tokens_and_segments 89 | ) 90 | 91 | def test_semeval2010_task8_reader_with_entity_markers(self): 92 | reader_params = Params({ 93 | "type": "semeval2010_task8", 94 | "entity_masking": "entity_markers", 95 | "tokenizer_and_candidate_generator": { 96 | "type": "bert_tokenizer_and_candidate_generator", 97 | "entity_candidate_generators": { 98 | "wordnet": {"type": "wordnet_mention_generator", 99 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 100 | }, 101 | "entity_indexers": { 102 | "wordnet": { 103 | "type": "characters_tokenizer", 104 | "tokenizer": { 105 | "type": "word", 106 | "word_splitter": {"type": "just_spaces"}, 107 | }, 108 | "namespace": "entity" 109 | } 110 | }, 111 | "bert_model_type": "tests/fixtures/evaluation/semeval2010_task8/vocab_entity_markers.txt", 112 | "do_lower_case": True, 113 | }, 114 | }) 115 | 116 | reader = DatasetReader.from_params(reader_params) 117 | train_file = 'tests/fixtures/evaluation/semeval2010_task8/semeval2010_task8.json' 118 | 119 | instances = reader.read(train_file) 120 | 121 | # check that the offsets are right! 122 | segment_ids = instances[0]['segment_ids'].array.tolist() 123 | tokens = [t.text for t in instances[0]['tokens'].tokens] 124 | 125 | tokens_and_segments = list(zip(tokens, segment_ids)) 126 | 127 | expected_tokens_and_segments = [ 128 | ('[CLS]', 0), 129 | ('the', 0), 130 | ('[e1start]', 0), 131 | ('big', 0), 132 | ('cat', 0), 133 | ('##s', 0), 134 | ('[e1end]', 0), 135 | ('jumped', 0), 136 | ('[UNK]', 0), 137 | ('the', 0), 138 | ('[e2start]', 0), 139 | ('la', 0), 140 | ('##zie', 0), 141 | ('##st', 0), 142 | ('brown', 0), 143 | ('dog', 0), 144 | ('##s', 0), 145 | ('[e2end]', 0), 146 | ('.', 0), 147 | ('[SEP]', 0)] 148 | 149 | self.assertEqual( 150 | tokens_and_segments, 151 | expected_tokens_and_segments 152 | ) 153 | 154 | tokens_1 = [t.text for t in instances[1]['tokens'].tokens] 155 | expected_tokens_1 = ['[CLS]', 156 | 'the', 157 | '[e2start]', 158 | 'big', 159 | 'cat', 160 | '##s', 161 | '[e2end]', 162 | 'jumped', 163 | '[e1start]', 164 | '[UNK]', 165 | 'the', 166 | 'la', 167 | '##zie', 168 | '##st', 169 | 'brown', 170 | 'dog', 171 | '##s', 172 | '[e1end]', 173 | '.', 174 | '[SEP]'] 175 | 176 | self.assertEqual( 177 | tokens_1, 178 | expected_tokens_1 179 | ) 180 | 181 | self.assertEqual( 182 | instances[0].fields['label_ids'].label, 0 183 | ) 184 | self.assertEqual( 185 | instances[1].fields['label_ids'].label, 8 186 | ) 187 | 188 | all_tokens = [[t.text for t in instances[k]['tokens'].tokens] for k in range(2)] 189 | 190 | for k in range(2): 191 | self.assertEqual(all_tokens[k][instances[k].fields['index_a'].label], '[e1start]') 192 | self.assertEqual(all_tokens[k][instances[k].fields['index_b'].label], '[e2start]') 193 | 194 | 195 | if __name__ == '__main__': 196 | unittest.main() 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /tests/evaluation/test_simple_classifier.py: 
-------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | from kb.include_all import SimpleClassifier, F1Metric 4 | from allennlp.common import Params 5 | from allennlp.models import Model 6 | from allennlp.data import DatasetReader, DataIterator, Vocabulary 7 | from allennlp.training.metrics import CategoricalAccuracy 8 | 9 | def get_wic_batch(): 10 | fixtures = 'tests/fixtures/evaluation/wic' 11 | 12 | reader_params = Params({ 13 | "type": "wic", 14 | "tokenizer_and_candidate_generator": { 15 | "type": "bert_tokenizer_and_candidate_generator", 16 | "entity_candidate_generators": { 17 | "wordnet": {"type": "wordnet_mention_generator", 18 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 19 | }, 20 | "entity_indexers": { 21 | "wordnet": { 22 | "type": "characters_tokenizer", 23 | "tokenizer": { 24 | "type": "word", 25 | "word_splitter": {"type": "just_spaces"}, 26 | }, 27 | "namespace": "entity" 28 | } 29 | }, 30 | "bert_model_type": "tests/fixtures/bert/vocab.txt", 31 | "do_lower_case": True, 32 | }, 33 | }) 34 | 35 | reader = DatasetReader.from_params(reader_params) 36 | instances = reader.read(fixtures + '/train') 37 | iterator = DataIterator.from_params(Params({"type": "basic"})) 38 | iterator.index_with(Vocabulary()) 39 | 40 | for batch in iterator(instances, num_epochs=1, shuffle=False): 41 | break 42 | 43 | return batch 44 | 45 | 46 | def get_ultra_fine_batch(): 47 | from kb.include_all import UltraFineReader 48 | 49 | params = { 50 | "type": "ultra_fine", 51 | "tokenizer_and_candidate_generator": { 52 | "type": "bert_tokenizer_and_candidate_generator", 53 | "entity_candidate_generators": { 54 | "wordnet": {"type": "wordnet_mention_generator", 55 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 56 | }, 57 | "entity_indexers": { 58 | "wordnet": { 59 | "type": "characters_tokenizer", 60 | "tokenizer": { 61 | "type": "word", 62 | "word_splitter": {"type": "just_spaces"}, 63 | }, 64 | "namespace": "entity" 65 | } 66 | }, 67 | "bert_model_type": "tests/fixtures/bert/vocab.txt", 68 | "do_lower_case": True, 69 | } 70 | } 71 | 72 | reader = DatasetReader.from_params(Params(params)) 73 | instances = reader.read('tests/fixtures/evaluation/ultra_fine/train.json') 74 | iterator = DataIterator.from_params(Params({"type": "basic"})) 75 | iterator.index_with(Vocabulary()) 76 | 77 | for batch in iterator(instances, num_epochs=1, shuffle=False): 78 | break 79 | 80 | return batch 81 | 82 | 83 | def get_knowbert_model(): 84 | vocab = Vocabulary.from_params(Params({ 85 | "directory_path": "tests/fixtures/kg_embeddings/tucker_wordnet/vocabulary", 86 | })) 87 | 88 | params = Params({ 89 | "type": "knowbert", 90 | "soldered_kgs": { 91 | "wordnet": { 92 | "type": "soldered_kg", 93 | "entity_linker": { 94 | "type": "entity_linking_with_candidate_mentions", 95 | "kg_model": { 96 | "type": "from_archive", 97 | "archive_file": "tests/fixtures/kg_embeddings/tucker_wordnet/model.tar.gz", 98 | }, 99 | "contextual_embedding_dim": 12, 100 | "max_sequence_length": 64, 101 | "span_encoder_config": { 102 | "hidden_size": 24, 103 | "num_hidden_layers": 1, 104 | "num_attention_heads": 3, 105 | "intermediate_size": 37 106 | }, 107 | }, 108 | "span_attention_config": { 109 | "hidden_size": 24, 110 | "num_hidden_layers": 2, 111 | "num_attention_heads": 4, 112 | "intermediate_size": 55 113 | } 114 | }, 115 | }, 116 | "soldered_layers": {"wordnet": 1}, 117 | "bert_model_name": "tests/fixtures/bert/bert_test_fixture.tar.gz", 118 | }) 119 | 120 
| model = Model.from_params(params, vocab=vocab) 121 | return model, vocab 122 | 123 | 124 | class TestSimpleClassifier(unittest.TestCase): 125 | def test_simple_classifier(self): 126 | batch = get_wic_batch() 127 | knowbert_model, vocab = get_knowbert_model() 128 | 129 | model = SimpleClassifier( 130 | vocab, 131 | knowbert_model, 132 | 'classification', 133 | 2, 134 | 12, 135 | CategoricalAccuracy() 136 | ) 137 | output = model(**batch) 138 | output['loss'].backward() 139 | 140 | self.assertTrue(True) 141 | 142 | def test_simple_classifier_with_concat_a_b(self): 143 | batch = get_wic_batch() 144 | knowbert_model, vocab = get_knowbert_model() 145 | 146 | model = SimpleClassifier( 147 | vocab, 148 | knowbert_model, 149 | 'classification', 150 | 2, 151 | 12, 152 | CategoricalAccuracy(), 153 | concat_word_a_b=True 154 | ) 155 | 156 | output = model(**batch) 157 | output['loss'].backward() 158 | 159 | self.assertTrue(True) 160 | 161 | def test_simple_classifier_bce_loss(self): 162 | batch = get_ultra_fine_batch() 163 | knowbert_model, vocab = get_knowbert_model() 164 | 165 | model = SimpleClassifier( 166 | vocab, 167 | knowbert_model, 168 | 'classification', 169 | 9, # 9 labels 170 | 12, 171 | F1Metric(), 172 | use_bce_loss=True 173 | ) 174 | 175 | output = model(**batch) 176 | output['loss'].backward() 177 | 178 | metrics = model.get_metrics() 179 | self.assertTrue('f1' in metrics) 180 | 181 | 182 | if __name__ == '__main__': 183 | unittest.main() 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /tests/evaluation/test_ultra_fine_reader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from allennlp.common import Params 4 | from allennlp.common.util import ensure_list 5 | from allennlp.data import DatasetReader, DataIterator, Vocabulary 6 | 7 | from kb.include_all import UltraFineReader 8 | 9 | 10 | def get_reader(entity_masking): 11 | params = { 12 | "type": "ultra_fine", 13 | "entity_masking": entity_masking, 14 | "tokenizer_and_candidate_generator": { 15 | "type": "bert_tokenizer_and_candidate_generator", 16 | "entity_candidate_generators": { 17 | "wordnet": {"type": "wordnet_mention_generator", 18 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 19 | }, 20 | "entity_indexers": { 21 | "wordnet": { 22 | "type": "characters_tokenizer", 23 | "tokenizer": { 24 | "type": "word", 25 | "word_splitter": {"type": "just_spaces"}, 26 | }, 27 | "namespace": "entity" 28 | } 29 | }, 30 | "bert_model_type": "tests/fixtures/evaluation/ultra_fine/vocab.txt", 31 | "do_lower_case": True, 32 | } 33 | } 34 | return DatasetReader.from_params(Params(params)) 35 | 36 | 37 | class TestUltraFineReader(unittest.TestCase): 38 | def test_ultra_fine_reader(self): 39 | reader = get_reader("entity") 40 | instances = ensure_list(reader.read('tests/fixtures/evaluation/ultra_fine/train.json')) 41 | 42 | # Check number of instances is correct 43 | self.assertEqual(len(instances), 2) 44 | 45 | # Check that first instance's tokens are correct 46 | tokens_0 = [x.text for x in instances[0]['tokens']] 47 | segments_0 = list(instances[0]['segment_ids'].array) 48 | actual = list(zip(tokens_0, segments_0)) 49 | expected = [('[CLS]', 0), 50 | ('the', 0), 51 | ('british', 0), 52 | ('information', 0), 53 | ('commissioner', 0), 54 | ("'s", 0), 55 | ('office', 0), 56 | ('invites', 0), 57 | ('[unused0]', 0), 58 | ('to', 0), 59 | ('locate', 0), 60 | ('its', 0), 61 | ('add', 0), 62 | ('##ress', 0), 63 | 
('using', 0), 64 | ('google', 0), 65 | ('[UNK]', 0), 66 | ('.', 0), 67 | ('[SEP]', 0), 68 | ('web', 1), 69 | ('users', 1), 70 | ('[SEP]', 1)] 71 | self.assertListEqual(actual, expected) 72 | 73 | iterator = DataIterator.from_params(Params({"type": "basic"})) 74 | iterator.index_with(Vocabulary()) 75 | 76 | for batch in iterator(instances, num_epochs=1, shuffle=False): 77 | break 78 | 79 | expected_labels = [[0, 0, 0, 0, 0, 0, 1, 0, 0], 80 | [1, 0, 0, 0, 0, 0, 0, 0, 0]] 81 | self.assertEqual(batch['label_ids'].numpy().tolist(), expected_labels) 82 | 83 | def test_ultra_fine_reader_entity_markers(self): 84 | reader = get_reader("entity_markers") 85 | instances = ensure_list(reader.read('tests/fixtures/evaluation/ultra_fine/train.json')) 86 | 87 | # Check number of instances is correct 88 | self.assertEqual(len(instances), 2) 89 | 90 | # Check that first instance's tokens are correct 91 | tokens_0 = [x.text for x in instances[0]['tokens']] 92 | segments_0 = list(instances[0]['segment_ids'].array) 93 | actual = list(zip(tokens_0, segments_0)) 94 | expected = [('[CLS]', 0), 95 | ('the', 0), 96 | ('british', 0), 97 | ('information', 0), 98 | ('commissioner', 0), 99 | ("'s", 0), 100 | ('office', 0), 101 | ('invites', 0), 102 | ('[e1start]', 0), 103 | ('web', 0), 104 | ('users', 0), 105 | ('[e1end]', 0), 106 | ('to', 0), 107 | ('locate', 0), 108 | ('its', 0), 109 | ('add', 0), 110 | ('##ress', 0), 111 | ('using', 0), 112 | ('google', 0), 113 | ('[UNK]', 0), 114 | ('.', 0), 115 | ('[SEP]', 0)] 116 | self.assertListEqual(actual, expected) 117 | 118 | self.assertEqual(actual[instances[0]['index_a'].label], ('[e1start]', 0)) 119 | 120 | 121 | if __name__ == '__main__': 122 | unittest.main() 123 | -------------------------------------------------------------------------------- /tests/evaluation/test_wic_reader.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | from kb.include_all import WicDatasetReader 4 | from allennlp.common import Params 5 | from allennlp.data import DatasetReader, DataIterator, Vocabulary 6 | 7 | 8 | FIXTURES = 'tests/fixtures/evaluation/wic' 9 | 10 | 11 | class TestWicReader(unittest.TestCase): 12 | def test_wic_reader(self): 13 | reader_params = Params({ 14 | "type": "wic", 15 | "tokenizer_and_candidate_generator": { 16 | "type": "bert_tokenizer_and_candidate_generator", 17 | "entity_candidate_generators": { 18 | "wordnet": {"type": "wordnet_mention_generator", 19 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 20 | }, 21 | "entity_indexers": { 22 | "wordnet": { 23 | "type": "characters_tokenizer", 24 | "tokenizer": { 25 | "type": "word", 26 | "word_splitter": {"type": "just_spaces"}, 27 | }, 28 | "namespace": "entity" 29 | } 30 | }, 31 | "bert_model_type": "tests/fixtures/bert/vocab.txt", 32 | "do_lower_case": True, 33 | }, 34 | }) 35 | 36 | reader = DatasetReader.from_params(reader_params) 37 | instances = reader.read(FIXTURES + '/train') 38 | iterator = DataIterator.from_params(Params({"type": "basic"})) 39 | iterator.index_with(Vocabulary()) 40 | 41 | for batch in iterator(instances, num_epochs=1, shuffle=False): 42 | break 43 | 44 | self.assertTrue(len(batch['label_ids']) == 5) 45 | 46 | self.assertEqual(batch['index_a'][0].item(), 3) 47 | self.assertEqual(batch['index_b'][0].item(), 10) 48 | 49 | def test_wic_reader_entity_markers(self): 50 | reader_params = Params({ 51 | "type": "wic", 52 | "entity_markers": True, 53 | "tokenizer_and_candidate_generator": { 54 | "type": 
"bert_tokenizer_and_candidate_generator", 55 | "entity_candidate_generators": { 56 | "wordnet": {"type": "wordnet_mention_generator", 57 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 58 | }, 59 | "entity_indexers": { 60 | "wordnet": { 61 | "type": "characters_tokenizer", 62 | "tokenizer": { 63 | "type": "word", 64 | "word_splitter": {"type": "just_spaces"}, 65 | }, 66 | "namespace": "entity" 67 | } 68 | }, 69 | "bert_model_type": "tests/fixtures/evaluation/wic/vocab_entity_markers.txt", 70 | "do_lower_case": True, 71 | }, 72 | }) 73 | 74 | reader = DatasetReader.from_params(reader_params) 75 | instances = reader.read(FIXTURES + '/train') 76 | iterator = DataIterator.from_params(Params({"type": "basic"})) 77 | iterator.index_with(Vocabulary()) 78 | 79 | for batch in iterator(instances, num_epochs=1, shuffle=False): 80 | break 81 | 82 | self.assertTrue(len(batch['label_ids']) == 5) 83 | 84 | self.assertEqual(batch['index_a'][0].item(), 3) 85 | self.assertEqual(batch['index_b'][0].item(), 12) 86 | 87 | instance_0_text = [token.text for token in instances[0].fields['tokens'].tokens] 88 | expected_instance_0_text = ['[CLS]', '[UNK]', '[UNK]', '[e1start]', '[UNK]', 89 | '[e1end]', '[UNK]', '[UNK]', '[UNK]', '.', '[SEP]', '[UNK]', '[e2start]', 90 | '[UNK]', '[e2end]', '[UNK]', 'over', '[UNK]', '.', '[SEP]' 91 | ] 92 | self.assertEqual(instance_0_text, expected_instance_0_text) 93 | self.assertEqual(instance_0_text[3], '[e1start]') 94 | self.assertEqual(instance_0_text[12], '[e2start]') 95 | 96 | 97 | if __name__ == '__main__': 98 | unittest.main() 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /tests/fixtures/bert/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "vocab_size": 18, 3 | "hidden_size": 12, 4 | "num_hidden_layers": 2, 5 | "num_attention_heads": 3, 6 | "intermediate_size": 6, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "attention_probs_dropout_prob": 0.1, 10 | "max_position_embeddings": 64, 11 | "type_vocab_size": 2, 12 | "initializer_range": 0.02 13 | } 14 | -------------------------------------------------------------------------------- /tests/fixtures/bert/bert_test_fixture.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/kb/4c37bbccd6c871828aceb24b342841825d86df14/tests/fixtures/bert/bert_test_fixture.tar.gz -------------------------------------------------------------------------------- /tests/fixtures/bert/vocab.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | the 4 | quick 5 | ##est 6 | brown 7 | fox 8 | ##iest 9 | jumped 10 | over 11 | ##zie 12 | ##st 13 | dog 14 | . 
15 | lazy 16 | la 17 | [CLS] 18 | [SEP] 19 | big 20 | cat 21 | ##s 22 | [MASK] 23 | -------------------------------------------------------------------------------- /tests/fixtures/bert/vocab_dir_with_entities_for_tokenizer_and_generator/entity.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | cat%1:04:00:: 3 | cat%1:05:00:: 4 | cat%1:05:02:: 5 | cat%1:06:01:: 6 | cat%1:06:00:: 7 | cat%1:06:02:: 8 | cat%1:18:00:: 9 | cat%1:18:01:: 10 | cat%2:29:00:: 11 | cat%2:35:00:: 12 | hat%1:04:00:: 13 | hat%1:06:00:: 14 | hat%2:29:00:: 15 | hat%2:40:00:: 16 | hats%1:00:01:: 17 | person%1:03:00:: 18 | hat-trick%1:05:55:: 19 | location%1:03:00:: 20 | person%1:03:00:: 21 | group%1:03:00:: 22 | person.n.03 23 | hat-trick.n.55 24 | hat-trick.n.01 25 | hat.n.01 26 | hat.n.02 27 | hat.v.01 28 | hat.v.02 29 | cat.n.01 30 | cat.n.04 31 | computerized_tomography.n.01 32 | @@MASK@@ 33 | @@NULL@@ 34 | -------------------------------------------------------------------------------- /tests/fixtures/bert/vocab_dir_with_entities_for_tokenizer_and_generator/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *labels 2 | *tags 3 | -------------------------------------------------------------------------------- /tests/fixtures/bert/vocab_dir_with_entities_for_tokenizer_and_generator/tokens.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | bob 3 | scored 4 | a 5 | hat 6 | trick 7 | the 8 | -------------------------------------------------------------------------------- /tests/fixtures/bert_pretraining/shard1.txt: -------------------------------------------------------------------------------- 1 | 0 Big cats are cats . The quickest dog . 2 | 1 Second . The hat is . 3 | -------------------------------------------------------------------------------- /tests/fixtures/evaluation/semeval2010_task8/semeval2010_task8.json: -------------------------------------------------------------------------------- 1 | {"sentence": "The big cats jumped near the laziest brown dogs.", "label": "Other", "sent_id": 3088} 2 | {"sentence": "The big cats jumped near the laziest brown dogs.", "label": "Entity-Destination(e2,e1)", "sent_id": 3089} 3 | -------------------------------------------------------------------------------- /tests/fixtures/evaluation/semeval2010_task8/vocab_entity_markers.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | the 4 | quick 5 | ##est 6 | brown 7 | fox 8 | ##iest 9 | jumped 10 | over 11 | ##zie 12 | ##st 13 | dog 14 | . 
15 | lazy 16 | la 17 | [CLS] 18 | [SEP] 19 | big 20 | cat 21 | ##s 22 | [MASK] 23 | [e1start] 24 | [e1end] 25 | [e2start] 26 | [e2end] 27 | -------------------------------------------------------------------------------- /tests/fixtures/evaluation/ultra_fine/train.json: -------------------------------------------------------------------------------- 1 | [{"sent": "The British Information Commissioner 's Office invites Web users to locate its address using Google Maps .", "start": 55, "labels": ["person"], "end": 64, "ents": [["Q849763", 4, 11, 0.120101936], ["Q1662473", 12, 46, 0.6343167], ["Q466", 55, 58, 0.19205835], ["Q12013", 93, 104, 0.30670428]]}, {"sent": "Tushar Gandhi said the Australian - born tycoon would be arrested if he visited Bombay or New Delhi again .", "start": 69, "labels": ["entity"], "end": 71, "ents": [["Q5625256", 0, 13, 0.5617667], ["Q142555", 23, 33, 0.1070978], ["Q1156", 80, 86, 0.34545702], ["Q987", 90, 99, 0.38672918]]}] 2 | -------------------------------------------------------------------------------- /tests/fixtures/evaluation/ultra_fine/vocab.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | the 4 | quick 5 | ##est 6 | brown 7 | fox 8 | ##iest 9 | jumped 10 | over 11 | ##zie 12 | ##st 13 | dog 14 | . 15 | lazy 16 | la 17 | [CLS] 18 | [SEP] 19 | big 20 | cat 21 | ##s 22 | [MASK] 23 | [unused0] 24 | british 25 | information 26 | commissioner 27 | 's 28 | office 29 | invites 30 | web 31 | users 32 | to 33 | locate 34 | its 35 | add 36 | ##ress 37 | using 38 | google 39 | [e1start] 40 | [e1end] 41 | -------------------------------------------------------------------------------- /tests/fixtures/evaluation/wic/train.data.txt: -------------------------------------------------------------------------------- 1 | carry V 2-1 You must carry your camping gear . Sound carries well over water . 2 | go V 2-6 Messages must go through diplomatic channels . Do you think the sofa will go through the door ? 3 | break V 0-2 Break an alibi . The wholesaler broke the container loads into palettes and boxes for local retailers . 4 | cup N 8-4 He wore a jock strap with a metal cup . Bees filled the waxen cups with honey . 5 | academy N 1-2 The Academy of Music . The French Academy . 6 | -------------------------------------------------------------------------------- /tests/fixtures/evaluation/wic/train.gold.txt: -------------------------------------------------------------------------------- 1 | F 2 | F 3 | F 4 | T 5 | F 6 | -------------------------------------------------------------------------------- /tests/fixtures/evaluation/wic/vocab_entity_markers.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | the 4 | quick 5 | ##est 6 | brown 7 | fox 8 | ##iest 9 | jumped 10 | over 11 | ##zie 12 | ##st 13 | dog 14 | . 
15 | lazy 16 | la 17 | [CLS] 18 | [SEP] 19 | big 20 | cat 21 | ##s 22 | [MASK] 23 | [e1start] 24 | [e1end] 25 | [e2start] 26 | [e2end] 27 | -------------------------------------------------------------------------------- /tests/fixtures/kg_embeddings/tucker_wordnet/model.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/kb/4c37bbccd6c871828aceb24b342841825d86df14/tests/fixtures/kg_embeddings/tucker_wordnet/model.tar.gz -------------------------------------------------------------------------------- /tests/fixtures/kg_embeddings/tucker_wordnet/vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *labels 2 | *tags 3 | -------------------------------------------------------------------------------- /tests/fixtures/kg_embeddings/tucker_wordnet/vocabulary/relation.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | synset_lemma_reverse 3 | synset_lemma 4 | synset_hypernyms 5 | synset_hypernyms_reverse 6 | synset_attributes 7 | synset_attributes_reverse 8 | synset_similar_tos 9 | synset_similar_tos_reverse 10 | synset_member_holonyms 11 | synset_member_holonyms_reverse 12 | synset_causes 13 | synset_causes_reverse 14 | synset_verb_groups 15 | synset_verb_groups_reverse 16 | synset_part_holonyms 17 | synset_part_holonyms_reverse 18 | synset_substance_holonyms 19 | synset_substance_holonyms_reverse 20 | -------------------------------------------------------------------------------- /tests/fixtures/kg_embeddings/wn18rr_dev.txt: -------------------------------------------------------------------------------- 1 | 1 _hypernym 5 2 | -------------------------------------------------------------------------------- /tests/fixtures/kg_embeddings/wn18rr_train.txt: -------------------------------------------------------------------------------- 1 | 1 _hypernym 2 2 | 1 _derivationally_related_form 3 3 | 1 _hypernym 3 4 | 4 _hypernym 1 5 | 4 _hypernym 5 6 | -------------------------------------------------------------------------------- /tests/fixtures/kg_probe/file1.txt: -------------------------------------------------------------------------------- 1 | 0 0 Cats are quickest . 2 | 6 7 The brown fox jumped over the laziest dog . 3 | -------------------------------------------------------------------------------- /tests/fixtures/linking/aida.txt: -------------------------------------------------------------------------------- 1 | DOCSTART_1_EU 2 | EU 3 | rejects 4 | MMSTART_11867 Germany 5 | German 6 | MMEND 7 | call 8 | to 9 | boycott 10 | MMSTART_31717 United_Kingdom 11 | British 12 | MMEND 13 | lamb 14 | . 
15 | *NL* 16 | *NL* 17 | The 18 | MMSTART_9974 European_Commission 19 | European 20 | Commission 21 | MMEND 22 | said 23 | on 24 | Thursday 25 | it 26 | disagreed 27 | with 28 | MMSTART_11867 Germany 29 | German 30 | MMEND 31 | advice 32 | to 33 | consumers 34 | to 35 | shun 36 | MMSTART_31717 United_Kingdom 37 | British 38 | MMEND 39 | lamb 40 | until 41 | scientists 42 | determine 43 | whether 44 | it 45 | is 46 | dangerous 47 | -------------------------------------------------------------------------------- /tests/fixtures/linking/entities_universe.txt: -------------------------------------------------------------------------------- 1 | 2344068 A_Forest 2 | 269070 Operation_Granby 3 | 5426016 Alfred_Bryan_Bonds 4 | 15421821 Killer_(1998_film) 5 | 1122869 1984_Brazilian_Grand_Prix 6 | 14560419 MRGPRX1 7 | 20568301 Forbidden_Corner 8 | 21482087 Consulate_of_the_United_States,_Liverpool 9 | 3424957 Clan_Cameron 10 | 320021 Bioneers 11 | -------------------------------------------------------------------------------- /tests/fixtures/multitask/ccgbank.txt: -------------------------------------------------------------------------------- 1 | ID=wsj_0001.1 PARSER=GOLD NUMPARSE=1 2 | ( ( ( ( ( ( ( () () ) ) () ) ( ( ( ( () () ) ) () ) ) ) () ) ( () ( ( ( () ( () () ) ) ( () ( () ( () () ) ) ) ) ( () () ) ) ) ) () ) 3 | ID=wsj_0001.2 PARSER=GOLD NUMPARSE=1 4 | ( ( ( ( () () ) ) ( () ( ( () ) ( () ( ( ( () () ) ) ( () ( () ( () ( () () ) ) ) ) ) ) ) ) ) () ) 5 | -------------------------------------------------------------------------------- /tests/fixtures/multitask/conll2003.txt: -------------------------------------------------------------------------------- 1 | -DOCSTART- -X- -X- O 2 | 3 | U.N. NNP I-NP I-ORG 4 | official NN I-NP O 5 | Ekeus NNP I-NP I-PER 6 | heads VBZ I-VP O 7 | for IN I-PP O 8 | Baghdad NNP I-NP I-LOC 9 | . . O O 10 | 11 | -DOCSTART- -X- -X- O 12 | 13 | AI2 NNP I-NP I-ORG 14 | engineer NN I-NP O 15 | Joel NNP I-NP I-PER 16 | lives VBZ I-VP O 17 | in IN I-PP O 18 | Seattle NNP I-NP I-LOC 19 | . . 
O O 20 | -------------------------------------------------------------------------------- /tests/fixtures/tacred/LDC2018T24.json: -------------------------------------------------------------------------------- 1 | [{"id": "e7798fb926b9403cfcd2", "docid": "APW_ENG_20101103.0539", "relation": "per:title", "token": ["At", "the", "same", "time", ",", "Chief", "Financial", "Officer", "Douglas", "Flint", "will", "become", "chairman", ",", "succeeding", "Stephen", "Green", "who", "is", "leaving", "to", "take", "a", "government", "job", "."], "subj_start": 8, "subj_end": 9, "obj_start": 12, "obj_end": 12, "subj_type": "PERSON", "obj_type": "TITLE", "stanford_pos": ["IN", "DT", "JJ", "NN", ",", "NNP", "NNP", "NNP", "NNP", "NNP", "MD", "VB", "NN", ",", "VBG", "NNP", "NNP", "WP", "VBZ", "VBG", "TO", "VB", "DT", "NN", "NN", "."], "stanford_ner": ["O", "O", "O", "O", "O", "O", "O", "O", "PERSON", "PERSON", "O", "O", "O", "O", "O", "PERSON", "PERSON", "O", "O", "O", "O", "O", "O", "O", "O", "O"], "stanford_head": [4, 4, 4, 12, 12, 10, 10, 10, 10, 12, 12, 0, 12, 12, 12, 17, 15, 20, 20, 17, 22, 20, 25, 25, 22, 12], "stanford_deprel": ["case", "det", "amod", "nmod", "punct", "compound", "compound", "compound", "compound", "nsubj", "aux", "ROOT", "xcomp", "punct", "xcomp", "compound", "dobj", "nsubj", "aux", "acl:relcl", "mark", "xcomp", "det", "compound", "dobj", "punct"]}, {"id": "e779865fb96bbbcc4ca4", "docid": "APW_ENG_20080229.1401.LDC2009T13", "relation": "no_relation", "token": ["U.S.", "District", "Court", "Judge", "Jeffrey", "White", "in", "mid-February", "issued", "an", "injunction", "against", "Wikileaks", "after", "the", "Zurich-based", "Bank", "Julius", "Baer", "accused", "the", "site", "of", "posting", "sensitive", "account", "information", "stolen", "by", "a", "disgruntled", "former", "employee", "."], "subj_start": 17, "subj_end": 18, "obj_start": 4, "obj_end": 5, "subj_type": "PERSON", "obj_type": "PERSON", "stanford_pos": ["NNP", "NNP", "NNP", "NNP", "NNP", "NNP", "IN", "NNP", "VBD", "DT", "NN", "IN", "NNP", "IN", "DT", "JJ", "NNP", "NNP", "NNP", "VBD", "DT", "NN", "IN", "VBG", "JJ", "NN", "NN", "VBN", "IN", "DT", "JJ", "JJ", "NN", "."], "stanford_ner": ["LOCATION", "O", "O", "O", "PERSON", "PERSON", "O", "O", "O", "O", "O", "O", "ORGANIZATION", "O", "O", "MISC", "O", "PERSON", "PERSON", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"], "stanford_head": [6, 6, 6, 6, 6, 9, 8, 6, 0, 11, 9, 13, 11, 20, 19, 19, 19, 19, 20, 9, 22, 20, 24, 20, 27, 27, 24, 27, 33, 33, 33, 33, 28, 9], "stanford_deprel": ["compound", "compound", "compound", "compound", "compound", "nsubj", "case", "nmod", "ROOT", "det", "dobj", "case", "nmod", "mark", "det", "amod", "compound", "compound", "nsubj", "advcl", "det", "dobj", "mark", "advcl", "amod", "compound", "dobj", "acl", "case", "det", "amod", "amod", "nmod", "punct"]}, {"id": "e7798ae9c0adbcdc81e7", "docid": "APW_ENG_20090707.0488", "relation": "per:city_of_death", "token": ["PARIS", "2009-07-07", "11:07:32", "UTC", "French", "media", "earlier", "reported", "that", "Montcourt", ",", "ranked", "119", ",", "was", "found", "dead", "by", "his", "girlfriend", "in", "the", "stairwell", "of", "his", "Paris", "apartment", "."], "subj_start": 9, "subj_end": 9, "obj_start": 0, "obj_end": 0, "subj_type": "PERSON", "obj_type": "CITY", "stanford_pos": ["NNP", "CD", "CD", "NNP", "NNP", "NNS", "RBR", "VBD", "IN", "NNP", ",", "VBD", "CD", ",", "VBD", "VBN", "JJ", "IN", "PRP$", "NN", "IN", "DT", "NN", "IN", "PRP$", "NNP", "NN", "."], "stanford_ner": 
["LOCATION", "TIME", "TIME", "TIME", "MISC", "O", "O", "O", "O", "PERSON", "O", "O", "NUMBER", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "O", "LOCATION", "O", "O"], "stanford_head": [6, 6, 6, 6, 6, 8, 8, 0, 16, 16, 10, 10, 12, 10, 16, 8, 16, 20, 20, 17, 23, 23, 16, 27, 27, 27, 23, 8], "stanford_deprel": ["compound", "nummod", "nummod", "compound", "compound", "nsubj", "advmod", "ROOT", "mark", "nsubjpass", "punct", "acl", "dobj", "punct", "auxpass", "ccomp", "xcomp", "case", "nmod:poss", "nmod", "case", "det", "nmod", "case", "nmod:poss", "compound", "nmod", "punct"]}] -------------------------------------------------------------------------------- /tests/fixtures/tacred/vocab.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | [CLS] 4 | [SEP] 5 | [MASK] 6 | [s-person] 7 | [o-title] 8 | [e1start] 9 | [e1end] 10 | [e2start] 11 | [e2end] 12 | ##ed 13 | ##ing 14 | ##ment 15 | at 16 | the 17 | same 18 | time 19 | chief 20 | financial 21 | officier 22 | douglas 23 | flint 24 | will 25 | become 26 | chairman 27 | , 28 | succeed 29 | who 30 | is 31 | leav 32 | to 33 | take 34 | a 35 | govern 36 | . 37 | paris 38 | french 39 | media 40 | earlier 41 | report 42 | that 43 | montcourt 44 | ranked 45 | was 46 | found 47 | dead 48 | by 49 | his 50 | girlfriend 51 | in 52 | the 53 | stairwell 54 | of 55 | apartment 56 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet/cat_hat_mask_null_embedding.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/kb/4c37bbccd6c871828aceb24b342841825d86df14/tests/fixtures/wordnet/cat_hat_mask_null_embedding.hdf5 -------------------------------------------------------------------------------- /tests/fixtures/wordnet/cat_hat_synset_mask_null_vocab.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | person.n.03 3 | hat-trick.n.5 4 | cat.v.01 5 | cat.n.03 6 | cat.n.01 7 | hat.n.01 8 | hat.n.02 9 | hat.v.01 10 | hat.v.02 11 | @@MASK@@ 12 | @@NULL@@ 13 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet/cat_hat_vocabdir/entity.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | cat%1:04:00:: 3 | cat%1:05:00:: 4 | cat%1:05:02:: 5 | cat%1:06:01:: 6 | cat%1:06:00:: 7 | cat%1:06:02:: 8 | cat%1:18:00:: 9 | cat%1:18:01:: 10 | cat%2:29:00:: 11 | cat%2:35:00:: 12 | hat%1:04:00:: 13 | hat%1:06:00:: 14 | hat%2:29:00:: 15 | hat%2:40:00:: 16 | hats%1:00:01:: 17 | person%1:03:00:: 18 | hat-trick%1:05:55:: 19 | location%1:03:00:: 20 | person%1:03:00:: 21 | group%1:03:00:: 22 | person.n.03 23 | hat-trick.n.55 24 | hat-trick.n.01 25 | hat.n.01 26 | hat.n.02 27 | hat.v.01 28 | hat.v.02 29 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet/cat_hat_vocabdir/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *labels 2 | *tags 3 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet/cat_hat_vocabdir/tokens.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | bob 3 | scored 4 | a 5 | hat 6 | trick 7 | the 8 | -------------------------------------------------------------------------------- 
/tests/fixtures/wordnet/entities_cat_hat.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "cat%1:04:00::", "pos": "n", "synset": "computerized_tomography.n.01", "type": "lemma", "count": 5} 2 | {"id": "cat%1:05:00::", "pos": "n", "synset": "cat.n.01", "type": "lemma", "count": 0} 3 | {"id": "cat%1:05:02::", "pos": "n", "synset": "big_cat.n.01", "type": "lemma", "count": 0} 4 | {"id": "cat%1:06:01::", "pos": "n", "synset": "caterpillar.n.02", "type": "lemma", "count": 0} 5 | {"id": "cat%1:06:00::", "pos": "n", "synset": "cat-o'-nine-tails.n.01", "type": "lemma", "count": 0} 6 | {"id": "cat%1:06:02::", "pos": "n", "synset": "kat.n.01", "type": "lemma", "count": 0} 7 | {"id": "cat%1:18:00::", "pos": "n", "synset": "cat.n.03", "type": "lemma", "count": 2} 8 | {"id": "cat%1:18:01::", "pos": "n", "synset": "guy.n.01", "type": "lemma", "count": 0} 9 | {"id": "cat%2:29:00::", "pos": "v", "synset": "vomit.v.01", "type": "lemma", "count": 0} 10 | {"id": "cat%2:35:00::", "pos": "v", "synset": "cat.v.01", "type": "lemma", "count": 0} 11 | {"id": "cat.n.01", "pos": "n", "lemmas": ["cat%1:05:00::", "true_cat%1:05:00::"], "examples": [], "definition": "feline mammal usually having thick soft fur and no ability to roar: domestic cats; wildcats", "type": "synset"} 12 | {"id": "cat.n.03", "pos": "n", "lemmas": ["cat%1:18:00::"], "examples": ["what a cat she is!"], "definition": "a spiteful woman gossip", "type": "synset"} 13 | {"id": "cat.v.01", "pos": "v", "lemmas": ["cat%2:35:00::"], "examples": [], "definition": "beat with a cat-o'-nine-tails", "type": "synset"} 14 | {"id": "hat.n.02", "pos": "n", "lemmas": ["hat%1:04:00::"], "examples": ["he took off his politician's hat and talked frankly"], "definition": "an informal term for a person's role", "type": "synset"} 15 | {"id": "hat.n.01", "pos": "n", "lemmas": ["hat%1:06:00::", "chapeau%1:06:00::", "lid%1:06:01::"], "examples": [], "definition": "headdress that protects the head from bad weather; has shaped crown and usually a brim", "type": "synset"} 16 | {"id": "hat.v.01", "pos": "v", "lemmas": ["hat%2:29:00::"], "examples": ["He was unsuitably hatted"], "definition": "put on or wear a hat", "type": "synset"} 17 | {"id": "hat.v.02", "pos": "v", "lemmas": ["hat%2:40:00::"], "examples": [], "definition": "furnish with a hat", "type": "synset"} 18 | {"id": "hat%1:04:00::", "pos": "n", "synset": "hat.n.02", "type": "lemma", "count": 1} 19 | {"id": "hat%1:06:00::", "pos": "n", "synset": "hat.n.01", "type": "lemma", "count": 2} 20 | {"id": "hat%2:29:00::", "pos": "v", "synset": "hat.v.01", "type": "lemma", "count": 3} 21 | {"id": "hat%2:40:00::", "pos": "v", "synset": "hat.v.02", "type": "lemma", "count": 4} 22 | {"id": "hats%1:00:01::", "pos": "n", "synset": "hats.n.01", "type": "lemma", "count": 0} 23 | {"id": "person%1:03:00::", "pos": "n", "synset": "person.n.03", "type": "lemma", "count": 0} 24 | {"id": "person.n.03", "pos": "n", "lemmas": ["person%1:03:00::"], "examples": [], "definition": "Bob", "type": "synset"} 25 | {"id": "hat-trick%1:05:55::", "pos": "n", "synset": "hat-trick.n.55", "type": "lemma", "count": 0} 26 | {"id": "hat-trick.n.5", "pos": "n", "lemmas": ["hat-trick%1:05:55::"], "examples": [], "definition": "something", "type": "synset"} 27 | {"id": "hat-trick%1:01:03::", "pos": "n", "synset": "hat-trick.n.01", "type": "lemma", "count": 2} 28 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet/entities_fixture.jsonl: 
-------------------------------------------------------------------------------- 1 | {"id": "cat%1:04:00::", "pos": "n", "synset": "computerized_tomography.n.01", "type": "lemma", "count": 5} 2 | {"id": "big_cat%1:05:00::", "pos": "n", "synset": "cat.n.01", "type": "lemma", "count": 0} 3 | {"id": "big_cat%1:04:00::", "pos": "n", "synset": "cat.n.04", "type": "lemma", "count": 4} 4 | {"id": "person%1:03:00::", "pos": "n", "synset": "person.n.01", "type": "lemma", "count": 5} 5 | {"id": "person%1:01:55::", "pos": "n", "synset": "person.n.02", "type": "lemma", "count": 8} 6 | {"id": "person%1:01:56::", "pos": "n", "synset": "person.n.03", "type": "lemma", "count": 10} 7 | {"id": "see_a%2:05:25::", "pos": "v", "synset": "see_a.v.01", "type": "lemma", "count": 10} 8 | {"id": "see_a%3:01:45::", "pos": "a", "synset": "see_a.a.02", "type": "lemma", "count": 10} 9 | {"id": "half-baked%3:01:22::", "pos": "a", "synset": "half-baked.a.01", "type": "lemma", "count": 10} 10 | {"id": "hot_dog%1:01:02::", "pos": "n", "synset": "hot_dog.n.01", "type": "lemma", "count": 15} 11 | {"id": "washingtonian%1:01:02::", "pos": "a", "synset": "washingtonian.a.01", "type": "lemma", "count": 2} 12 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet/wsd_dataset.json: -------------------------------------------------------------------------------- 1 | [{"token": "Bob", "lemma": "person", "pos": "NOUN", "senses": ["1:03:00::"], "id": "d000.s000.t000"}, {"token": "scored", "lemma": "score", "pos": "VERB"}, {"token": "a", "lemma": "a", "pos": "DET"}, {"token": "hat trick", "lemma": "hat-trick", "pos": "NOUN", "senses": ["1:05:55::"], "id": "d000.s000.t001"}] 2 | [{"token": "The", "lemma": "the", "pos": "DET"}, {"token": "hat", "lemma": "hat", "pos": "NOUN", "senses": ["1:06:00::"], "id": "d000.s001.t000"}] 3 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet_wiki_vocab/entity_wiki.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | A_Forest 3 | Operation_Granby 4 | Alfred_Bryan_Bonds 5 | Killer_(1998_film) 6 | 1984_Brazilian_Grand_Prix 7 | MRGPRX1 8 | Forbidden_Corner 9 | Consulate_of_the_United_States,_Liverpool 10 | Clan_Cameron 11 | Bioneers 12 | @@MASK@@ 13 | @@NULL@@ 14 | -------------------------------------------------------------------------------- /tests/fixtures/wordnet_wiki_vocab/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *labels 2 | *tags 3 | -------------------------------------------------------------------------------- /tests/test_bert_pretraining_reader.py: -------------------------------------------------------------------------------- 1 | 2 | from kb.bert_pretraining_reader import BertPreTrainingReader, \ 3 | replace_candidates_with_mask_entity 4 | from kb.wordnet import WordNetCandidateMentionGenerator 5 | from kb.wiki_linking_util import WikiCandidateMentionGenerator 6 | 7 | from kb.testing import get_bert_pretraining_reader_with_kg as get_reader 8 | 9 | from kb.bert_tokenizer_and_candidate_generator import BertTokenizerAndCandidateGenerator 10 | 11 | import unittest 12 | 13 | from allennlp.common import Params 14 | from allennlp.data import DatasetReader, Vocabulary, DataIterator 15 | 16 | import numpy as np 17 | 18 | import torch 19 | 20 | 21 | 22 | class TestReplaceCandidatesWithMaskEntity(unittest.TestCase): 23 | def test_replace_candidates_with_mask_entity(self): 24 | 
spans_to_mask = set([(0, 0), (1, 2), (3, 3)]) 25 | candidates = { 26 | 'wordnet': {'candidate_spans': [[0, 0], [1, 1], [1, 2]], 27 | 'candidate_entities': [["a"], ["b", "c"], ["d"]], 28 | 'candidate_entity_priors': [[1.0], [0.2, 0.8], [1.0]]}, 29 | 'wiki': {'candidate_spans': [[3, 3]], 30 | 'candidate_entities': [["d"]], 31 | 'candidate_entity_priors': [[1.0]]}, 32 | } 33 | replace_candidates_with_mask_entity(candidates, spans_to_mask) 34 | 35 | expected_candidates = { 36 | 'wordnet': {'candidate_spans': [[0, 0], [1, 1], [1, 2]], 37 | 'candidate_entities': [["@@MASK@@"], ["b", "c"], ["@@MASK@@"]], 38 | 'candidate_entity_priors': [[1.0], [0.2, 0.8], [1.0]]}, 39 | 'wiki': {'candidate_spans': [[3, 3]], 40 | 'candidate_entities': [["@@MASK@@"]], 41 | 'candidate_entity_priors': [[1.0]]}, 42 | } 43 | 44 | for key in ['wordnet', 'wiki']: 45 | for key2 in ['candidate_spans', 'candidate_entities']: 46 | self.assertListEqual( 47 | candidates[key][key2], expected_candidates[key][key2] 48 | ) 49 | 50 | 51 | class TestBertPretrainingReader(unittest.TestCase): 52 | def test_create_masked_lm_predictions(self): 53 | reader = get_reader(masked_lm_prob=0.5) 54 | np.random.seed(5) 55 | 56 | tokens, lm_labels = reader._tokenizer_masker.create_masked_lm_predictions( 57 | "The original tokens in the sentence .".split() 58 | ) 59 | 60 | expected_tokens = ['The', '[MASK]', '[MASK]', 'in', '[MASK]', 'sentence', '[MASK]'] 61 | expected_lm_labels = ['[PAD]', 'original', 'tokens', '[PAD]', 'the', '[PAD]', '.'] 62 | 63 | self.assertEqual(expected_tokens, tokens) 64 | self.assertEqual(expected_lm_labels, lm_labels) 65 | 66 | def test_reader_can_run_with_full_mask_strategy(self): 67 | reader = get_reader('full_mask', masked_lm_prob=0.5) 68 | instances = reader.read("tests/fixtures/bert_pretraining/shard1.txt") 69 | self.assertEqual(len(instances), 2) 70 | 71 | def test_reader_can_run_with_wordnet_and_wiki(self): 72 | reader = get_reader('full_mask', masked_lm_prob=0.5, include_wiki=True) 73 | instances = reader.read("tests/fixtures/bert_pretraining/shard1.txt") 74 | self.assertEqual(len(instances), 2) 75 | 76 | def test_reader(self): 77 | reader = get_reader(masked_lm_prob=0.15) 78 | 79 | np.random.seed(5) 80 | instances = reader.read("tests/fixtures/bert_pretraining/shard1.txt") 81 | 82 | vocab = Vocabulary.from_params(Params({ 83 | "directory_path": "tests/fixtures/bert/vocab_dir_with_entities_for_tokenizer_and_generator" 84 | })) 85 | iterator = DataIterator.from_params(Params({"type": "basic"})) 86 | iterator.index_with(vocab) 87 | 88 | for batch in iterator(instances, num_epochs=1, shuffle=False): 89 | break 90 | 91 | actual_tokens_ids = batch['tokens']['tokens'] 92 | expected_tokens_ids = torch.tensor( 93 | [[16, 18, 19, 20, 1, 19, 21, 13, 17, 21, 3, 4, 12, 13, 17], 94 | [16, 1, 13, 17, 21, 1, 1, 13, 17, 0, 0, 0, 0, 0, 0]]) 95 | 96 | self.assertEqual(actual_tokens_ids.tolist(), expected_tokens_ids.tolist()) 97 | 98 | actual_entities = batch['candidates']['wordnet']['candidate_entities']['ids'] 99 | expected_entities = torch.tensor( 100 | [[[29, 30], 101 | [31, 0], 102 | [31, 0]], 103 | 104 | [[ 0, 0], 105 | [ 0, 0], 106 | [ 0, 0]]]) 107 | self.assertEqual(actual_entities.tolist(), expected_entities.tolist()) 108 | 109 | expected_spans = torch.tensor( 110 | [[[ 1, 3], 111 | [ 2, 3], 112 | [ 5, 6]], 113 | 114 | [[-1, -1], 115 | [-1, -1], 116 | [-1, -1]]]) 117 | actual_spans = batch['candidates']['wordnet']['candidate_spans'] 118 | self.assertEqual(actual_spans.tolist(), expected_spans.tolist()) 119 | 120 | 
expected_lm_labels = torch.tensor( 121 | [[ 0, 0, 0, 0, 0, 0, 20, 0, 0, 2, 0, 0, 0, 0, 0], 122 | [ 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) 123 | actual_lm_labels = batch['lm_label_ids']['lm_labels'] 124 | self.assertEqual(actual_lm_labels.tolist(), expected_lm_labels.tolist()) 125 | 126 | expected_segment_ids = torch.tensor( 127 | [[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], 128 | [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]) 129 | self.assertEqual(batch['segment_ids'].tolist(), expected_segment_ids.tolist()) 130 | self.assertTrue(batch['segment_ids'].dtype == torch.long) 131 | 132 | 133 | if __name__ == '__main__': 134 | unittest.main() 135 | 136 | 137 | -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | 4 | import torch 5 | 6 | from kb.common import F1Metric 7 | 8 | 9 | class TestF1Metric(unittest.TestCase): 10 | def test_f1(self): 11 | 12 | f1 = F1Metric() 13 | 14 | predicted = [ 15 | ['a', 'b', 'c'], 16 | ['d'], 17 | [] 18 | ] 19 | gold = [ 20 | ['b', 'd'], 21 | ['d', 'e'], 22 | ['f'] 23 | ] 24 | 25 | f1(predicted, gold) 26 | predicted2 = [[6, 10]] 27 | gold2 = [[6, 15]] 28 | f1(predicted2, gold2) 29 | 30 | metrics = f1.get_metric() 31 | 32 | precision = 3 / 6 33 | recall = 3 / 7 34 | f1 = 2 * precision * recall / (precision + recall) 35 | 36 | expected_metrics = [precision, recall, f1] 37 | 38 | for m1, m2 in zip(metrics, expected_metrics): 39 | self.assertAlmostEqual(m1, m2) 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 | 45 | -------------------------------------------------------------------------------- /tests/test_dict_field.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | from allennlp.data import Token, Vocabulary 5 | from allennlp.data.fields import TextField, ListField, ArrayField, SpanField 6 | from allennlp.data.token_indexers import SingleIdTokenIndexer 7 | from allennlp.data.tokenizers import WordTokenizer 8 | from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter 9 | 10 | 11 | from kb.dict_field import DictField 12 | from kb.entity_linking import TokenCharactersIndexerTokenizer 13 | 14 | import unittest 15 | import torch 16 | 17 | 18 | class TestDictField(unittest.TestCase): 19 | def setUp(self): 20 | super(TestDictField, self).setUp() 21 | 22 | entity_tokenizer = WordTokenizer(word_splitter=JustSpacesWordSplitter()) 23 | 24 | self.vocab = Vocabulary() 25 | self.vocab.add_token_to_namespace("entity1", "entity") 26 | self.vocab.add_token_to_namespace("entity2", "entity") 27 | self.vocab.add_token_to_namespace("entity3", "entity") 28 | self.entity_indexer = {"entity": TokenCharactersIndexerTokenizer( 29 | "entity", character_tokenizer=entity_tokenizer) 30 | } 31 | 32 | tokens1 = "The sentence .".split() 33 | tokens_field = TextField( 34 | [Token(t) for t in tokens1], 35 | token_indexers={'tokens': SingleIdTokenIndexer()} 36 | ) 37 | 38 | self.instance1_fields = { 39 | "candidate_entities": TextField( 40 | [Token("entity1 entity2"), Token("entity_unk")], 41 | token_indexers=self.entity_indexer), 42 | "candidate_entity_prior": ArrayField(np.array([[0.5, 0.5], [1.0, 0.0]])), 43 | "candidate_spans": ListField( 44 | [SpanField(0, 0, tokens_field), 45 | SpanField(1, 2, tokens_field)] 46 | ) 47 | } 48 | 49 | tokens2 = "The sentence".split() 50 | tokens2_field = TextField( 51 | [Token(t) for t in 
tokens2], 52 | token_indexers={'tokens': SingleIdTokenIndexer()} 53 | ) 54 | 55 | self.instance2_fields = { 56 | "candidate_entities": TextField( 57 | [Token("entity1")], 58 | token_indexers=self.entity_indexer), 59 | "candidate_entity_prior": ArrayField(np.array([[1.0]])), 60 | "candidate_spans": ListField( 61 | [SpanField(1, 1, tokens2_field)], 62 | ) 63 | } 64 | 65 | def test_get_padding_lengths(self): 66 | field = DictField(self.instance1_fields) 67 | field.index(self.vocab) 68 | lengths = field.get_padding_lengths() 69 | self.assertDictEqual( 70 | lengths, 71 | {'candidate_entities*entity_length': 2, 72 | 'candidate_entities*num_token_characters': 2, 73 | 'candidate_entities*num_tokens': 2, 74 | 'candidate_entity_prior*dimension_0': 2, 75 | 'candidate_entity_prior*dimension_1': 2, 76 | 'candidate_spans*num_fields': 2} 77 | ) 78 | 79 | def test_dict_field_can_handle_empty(self): 80 | field = DictField(self.instance1_fields) 81 | empty = field.empty_field() 82 | self.assertTrue(True) 83 | 84 | def _check_tensors(self, tensor, expected): 85 | self.assertListEqual( 86 | sorted(list(tensor.keys())), sorted(list(expected.keys())) 87 | ) 88 | for key in tensor.keys(): 89 | if key == 'candidate_entities': 90 | a = tensor[key]['entity'] 91 | b = expected[key]['entity'] 92 | else: 93 | a = tensor[key] 94 | b = expected[key] 95 | self.assertTrue(np.allclose(a.numpy(), b.numpy())) 96 | 97 | 98 | def test_dict_field_as_tensor(self): 99 | field = DictField(self.instance1_fields) 100 | field.index(self.vocab) 101 | tensor = field.as_tensor(field.get_padding_lengths()) 102 | 103 | expected = {'candidate_entities': {'entity': torch.tensor([[2, 3], 104 | [1, 0]])}, 'candidate_entity_prior': torch.tensor([[0.5000, 0.5000], 105 | [1.0000, 0.0000]]), 'candidate_spans': torch.tensor([[0, 0], 106 | [1, 2]])} 107 | 108 | self._check_tensors(tensor, expected) 109 | 110 | def test_dict_field_can_iterator(self): 111 | from allennlp.data import Instance 112 | from allennlp.data.iterators import BasicIterator 113 | 114 | iterator = BasicIterator() 115 | iterator.index_with(self.vocab) 116 | 117 | instances = [ 118 | Instance({"candidates": DictField(self.instance1_fields)}), 119 | Instance({"candidates": DictField(self.instance2_fields)}) 120 | ] 121 | 122 | for batch in iterator(instances, num_epochs=1, shuffle=False): 123 | break 124 | 125 | expected_batch = {'candidates': { 126 | 'candidate_entities': {'entity': torch.tensor([[[2, 3], 127 | [1, 0]], 128 | 129 | [[2, 0], 130 | [0, 0]]])}, 131 | 'candidate_entity_prior': torch.tensor([[[0.5000, 0.5000], 132 | [1.0000, 0.0000]], 133 | 134 | [[1.0000, 0.0000], 135 | [0.0000, 0.0000]]]), 136 | 'candidate_spans': torch.tensor([[[ 0, 0], 137 | [ 1, 2]], 138 | 139 | [[ 1, 1], 140 | [-1, -1]]])} 141 | } 142 | 143 | self._check_tensors(batch['candidates'], expected_batch['candidates']) 144 | 145 | def test_list_field_of_dict_field(self): 146 | from allennlp.data import Instance 147 | from allennlp.data.iterators import BasicIterator 148 | 149 | tokens3 = "The long sentence .".split() 150 | tokens3_field = TextField( 151 | [Token(t) for t in tokens3], 152 | token_indexers={'tokens': SingleIdTokenIndexer()} 153 | ) 154 | 155 | instance3_fields = { 156 | "candidate_entities": TextField( 157 | [Token("entity1 entity2 entity3"), Token("entity_unk"), Token("entity2 entity3")], 158 | token_indexers=self.entity_indexer), 159 | "candidate_entity_prior": ArrayField(np.array([[0.1, 0.1, 0.8], 160 | [1.0, 0.0, 0.0], 161 | [0.33, 0.67, 0.0]])), 162 | "candidate_spans": ListField( 
163 | [SpanField(1, 1, tokens3_field), SpanField(1, 2, tokens3_field), SpanField(1, 3, tokens3_field)], 164 | ) 165 | } 166 | 167 | iterator = BasicIterator() 168 | iterator.index_with(self.vocab) 169 | 170 | instances = [Instance({"candidates": ListField([ 171 | DictField(self.instance1_fields), 172 | DictField(self.instance2_fields)])}), 173 | Instance({"candidates": ListField([ 174 | DictField(self.instance1_fields), 175 | DictField(instance3_fields)])}) 176 | ] 177 | 178 | for batch in iterator(instances, num_epochs=1, shuffle=False): 179 | pass 180 | 181 | self.assertTrue(batch['candidates']['candidate_entities']['entity'].shape == batch['candidates']['candidate_entity_prior'].shape) 182 | 183 | 184 | if __name__ == '__main__': 185 | unittest.main() 186 | 187 | -------------------------------------------------------------------------------- /tests/test_entity_linking.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from allennlp.data import TokenIndexer, Vocabulary, Token 8 | from allennlp.common import Params 9 | from allennlp.models import Model 10 | 11 | from kb.entity_linking import TokenCharactersIndexerTokenizer 12 | from kb.entity_linking import remap_span_indices_after_subword_tokenization 13 | from kb.testing import get_bert_test_fixture 14 | from kb.include_all import ModelArchiveFromParams 15 | 16 | 17 | class TestRemapAfterWordpiece(unittest.TestCase): 18 | def test_remap(self): 19 | bert_fixture = get_bert_test_fixture() 20 | indexer = bert_fixture['indexer'] 21 | 22 | tokens = [Token(t) for t in 'The words dog overst .'.split()] 23 | vocab = Vocabulary() 24 | indexed = indexer.tokens_to_indices(tokens, vocab, 'wordpiece') 25 | 26 | original_span_indices = [ 27 | [0, 0], [0, 1], [2, 3], [3, 3], [2, 4] 28 | ] 29 | offsets = indexed['wordpiece-offsets'] 30 | 31 | expected_remapped = [ 32 | [1, 1], 33 | [1, 2], 34 | [3, 5], 35 | [4, 5], 36 | [3, 6] 37 | ] 38 | 39 | remapped = remap_span_indices_after_subword_tokenization( 40 | original_span_indices, offsets, len(indexed['wordpiece']) 41 | ) 42 | 43 | self.assertEqual(expected_remapped, remapped) 44 | 45 | 46 | class TestTokenCharactersIndexerTokenizer(unittest.TestCase): 47 | 48 | def test_token_characters_indexer_tokenizer(self): 49 | params = Params({ 50 | "type": "characters_tokenizer", 51 | "tokenizer": { 52 | "type": "word", 53 | "word_splitter": {"type": "just_spaces"}, 54 | }, 55 | "namespace": "tok" 56 | }) 57 | 58 | indexer = TokenIndexer.from_params(params) 59 | 60 | vocab = Vocabulary() 61 | vocab.add_token_to_namespace("the", namespace="tok") 62 | vocab.add_token_to_namespace("2", namespace="tok") 63 | 64 | indices = indexer.tokens_to_indices( 65 | [Token(t) for t in "the 2 .".split()], vocab, 'a' 66 | ) 67 | 68 | self.assertListEqual(indices['a'], [[2], [3], [1]]) 69 | 70 | 71 | 72 | if __name__ == '__main__': 73 | unittest.main() 74 | 75 | -------------------------------------------------------------------------------- /tests/test_kg_embedding.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | import torch 4 | 5 | from kb.kg_embedding import KGTupleReader, get_labels_tensor_from_indices, \ 6 | RankingAndHitsMetric 7 | 8 | from allennlp.data import Vocabulary 9 | from allennlp.common import Params 10 | from allennlp.data.iterators import BasicIterator 11 | 12 | 13 | class TestRankingAndHitsMetric(unittest.TestCase): 14 | def 
test_ranking_and_hits(self): 15 | batch_size = 2 16 | num_entities = 5 17 | 18 | predicted = torch.rand(batch_size, num_entities) 19 | all_entity2 = torch.LongTensor([[2, 3, 4], [1, 0, 0]]) 20 | entity2 = torch.LongTensor([3, 2]) 21 | 22 | metric = RankingAndHitsMetric() 23 | metric(predicted, all_entity2, entity2) 24 | metrics = metric.get_metric() 25 | 26 | self.assertTrue(True) 27 | 28 | 29 | class TestLabelFromIndices(unittest.TestCase): 30 | def test_get_labels_tensor_from_indices(self): 31 | batch_size = 2 32 | num_embeddings = 7 33 | entity_ids = torch.tensor([[5, 6, 1, 0, 0, 0], [1, 2, 3, 4, 5, 6]]) 34 | labels = get_labels_tensor_from_indices( 35 | batch_size, num_embeddings, entity_ids 36 | ) 37 | 38 | expected_labels = torch.tensor( 39 | [[0., 1., 0., 0., 0., 1., 1.], 40 | [0., 1., 1., 1., 1., 1., 1.]] 41 | ) 42 | 43 | self.assertTrue(torch.abs(labels - expected_labels).max() < 1e-6) 44 | 45 | def test_get_labels_tensor_from_indices_with_smoothing(self): 46 | batch_size = 2 47 | num_embeddings = 7 48 | entity_ids = torch.tensor([[5, 6, 1, 0, 0, 0], [1, 2, 3, 4, 5, 6]]) 49 | labels = get_labels_tensor_from_indices( 50 | batch_size, num_embeddings, entity_ids, label_smoothing=0.1 51 | ) 52 | 53 | expected_labels = torch.tensor( 54 | [[0., 1., 0., 0., 0., 1., 1.], 55 | [0., 1., 1., 1., 1., 1., 1.]] 56 | ) 57 | smoothed_labels = (1.0 - 0.1) * expected_labels + 0.1 / 7 * torch.tensor([[3.0], [6.0]]) 58 | 59 | self.assertTrue(torch.abs(labels - smoothed_labels).max() < 1e-6) 60 | 61 | 62 | class TestKGTupleReader(unittest.TestCase): 63 | def _check_batch(self, batch, vocab, expected_entity, expected_relation, expected_entity2): 64 | expected_entity_ids = [vocab.get_token_index(str(e), 'entity') 65 | for e in expected_entity] 66 | self.assertListEqual(batch['entity']['entity'].flatten().tolist(), 67 | expected_entity_ids) 68 | 69 | expected_relation_ids = [vocab.get_token_index(r, 'relation') 70 | for r in expected_relation] 71 | self.assertListEqual(batch['relation']['relation'].flatten().tolist(), 72 | expected_relation_ids) 73 | 74 | # check the entity2 75 | expected_entity2_ids = [ 76 | [vocab.get_token_index(str(e), 'entity') for e in ee] 77 | for ee in expected_entity2 78 | ] 79 | self.assertEqual(len(expected_entity2), batch['entity2']['entity'].shape[0]) 80 | for k in range(len(expected_entity2)): 81 | self.assertListEqual( 82 | sorted(expected_entity2_ids[k]), 83 | sorted([e for e in batch['entity2']['entity'][k].tolist() if e != 0]) 84 | ) 85 | 86 | def test_no_eval(self): 87 | reader = KGTupleReader() 88 | instances = reader.read('tests/fixtures/kg_embeddings/wn18rr_train.txt') 89 | 90 | self.assertTrue(len(instances) == 8) 91 | 92 | # create the vocab and index to make sure things look good 93 | vocab = Vocabulary.from_params(Params({}), instances) 94 | # (+2 for @@PADDING@@ and @@UNKNOWN@@ 95 | self.assertEqual(vocab.get_vocab_size("entity"), 5 + 2) 96 | self.assertEqual(vocab.get_vocab_size("relation"), 4 + 2) 97 | 98 | # now get a batch 99 | iterator = BasicIterator(batch_size=32) 100 | iterator.index_with(vocab) 101 | for batch in iterator(instances, num_epochs=1, shuffle=False): 102 | pass 103 | 104 | # check it! 
105 | expected_entity = [1, 2, 1, 3, 3, 4, 1, 5] 106 | expected_relation = ['_hypernym', '_hypernym_reverse', 107 | '_derivationally_related_form', '_derivationally_related_form_reverse', 108 | '_hypernym_reverse', '_hypernym', '_hypernym_reverse', 109 | '_hypernym_reverse'] 110 | expected_entity2 = [[2, 3], [1], [3], [1], [1], [1, 5], [4], [4]] 111 | 112 | self._check_batch(batch, vocab, 113 | expected_entity, expected_relation, expected_entity2) 114 | 115 | def test_kg_reader_with_eval(self): 116 | train_file = 'tests/fixtures/kg_embeddings/wn18rr_train.txt' 117 | dev_file = 'tests/fixtures/kg_embeddings/wn18rr_dev.txt' 118 | 119 | train_instances = KGTupleReader().read(train_file) 120 | 121 | reader = KGTupleReader(extra_files_for_gold_pairs=[train_file]) 122 | instances = reader.read(dev_file) 123 | self.assertEqual(len(instances), 2) 124 | 125 | vocab = Vocabulary.from_params(Params({}), train_instances + instances) 126 | iterator = BasicIterator(batch_size=32) 127 | iterator.index_with(vocab) 128 | for batch in iterator(instances, num_epochs=1, shuffle=False): 129 | pass 130 | 131 | expected_entity = [1, 5] 132 | expected_relation = ['_hypernym', '_hypernym_reverse'] 133 | expected_entity2 = [[5, 2, 3], [1, 4]] 134 | self._check_batch(batch, vocab, 135 | expected_entity, expected_relation, expected_entity2) 136 | 137 | 138 | if __name__ == '__main__': 139 | unittest.main() 140 | 141 | -------------------------------------------------------------------------------- /tests/test_kg_probe_reader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from allennlp.common import Params 4 | from allennlp.common.util import ensure_list 5 | from allennlp.data import DatasetReader 6 | import numpy as np 7 | 8 | from kb.bert_tokenizer_and_candidate_generator import TokenizerAndCandidateGenerator 9 | from kb.kg_probe_reader import KgProbeReader 10 | from kb.wordnet import WordNetCandidateMentionGenerator 11 | 12 | 13 | def get_reader(): 14 | params = { 15 | "type": "kg_probe", 16 | "tokenizer_and_candidate_generator": { 17 | "type": "bert_tokenizer_and_candidate_generator", 18 | "entity_candidate_generators": { 19 | "wordnet": {"type": "wordnet_mention_generator", 20 | "entity_file": "tests/fixtures/wordnet/entities_fixture.jsonl"} 21 | }, 22 | "entity_indexers": { 23 | "wordnet": { 24 | "type": "characters_tokenizer", 25 | "tokenizer": { 26 | "type": "word", 27 | "word_splitter": {"type": "just_spaces"}, 28 | }, 29 | "namespace": "entity" 30 | } 31 | }, 32 | "bert_model_type": "tests/fixtures/bert/vocab.txt", 33 | "do_lower_case": True, 34 | }, 35 | } 36 | 37 | return DatasetReader.from_params(Params(params)) 38 | 39 | 40 | class TestKgProbeReader(unittest.TestCase): 41 | def test_kg_probe_reader(self): 42 | reader = get_reader() 43 | instances = ensure_list(reader.read('tests/fixtures/kg_probe/file1.txt')) 44 | 45 | # Check instances are correct length 46 | self.assertEqual(len(instances), 2) 47 | 48 | # Check masking is performed properly 49 | expected_tokens_0 = ['[CLS]', '[MASK]', '[MASK]', '[UNK]', 'quick', 50 | '##est', '.', '[SEP]'] 51 | tokens_0 = [x.text for x in instances[0]['tokens'].tokens] 52 | self.assertListEqual(expected_tokens_0, tokens_0) 53 | 54 | expected_mask_indicator_0 = np.array([0,1,1,0,0,0,0,0], dtype=np.uint8) 55 | mask_indicator_0 = instances[0]['mask_indicator'].array 56 | assert np.allclose(expected_mask_indicator_0, mask_indicator_0) 57 | 58 | expected_tokens_1 = ['[CLS]', 'the', 'brown', 'fox', 'jumped', 
'over', 59 | 'the', '[MASK]', '[MASK]', '[MASK]', '[MASK]', 60 | '.', '[SEP]'] 61 | tokens_1 = [x.text for x in instances[1]['tokens'].tokens] 62 | self.assertListEqual(expected_tokens_1, tokens_1) 63 | 64 | expected_mask_indicator_1 = np.array([0,0,0,0,0,0,0,1,1,1,1,0,0], dtype=np.uint8) 65 | mask_indicator_1 = instances[1]['mask_indicator'].array 66 | assert np.allclose(expected_mask_indicator_1, mask_indicator_1) 67 | 68 | 69 | if __name__ == '__main__': 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from kb.metrics import MeanReciprocalRank, MicroF1 6 | 7 | 8 | class TestMeanReciprocalRank(unittest.TestCase): 9 | def test_mrr(self): 10 | labels = torch.tensor([[1,2,3], [1,2,0]]) 11 | predictions = torch.tensor([[ 12 | [-4.00, 1.00, 0.00, -1.00], 13 | [-4.00, 1.00, 0.00, -1.00], 14 | [-4.00, 0.50, 0.00, 1.00] 15 | ], [ 16 | [-4.00, 1.00, 0.00, -1.00], 17 | [-4.00, 1.00, 0.00, -1.00], 18 | [ 0.00, 0.00, 0.00, 0.00] 19 | ]]) 20 | mask = torch.tensor([[1,1,1],[1,1,0]], dtype=torch.uint8) 21 | 22 | metric = MeanReciprocalRank() 23 | metric(predictions, labels, mask) 24 | 25 | expected = 0.8 26 | output = metric.get_metric(reset=True) 27 | self.assertAlmostEqual(expected, output) 28 | 29 | expected = 0.0 30 | output = metric.get_metric() 31 | self.assertAlmostEqual(expected, output) 32 | 33 | 34 | class TestMicroF1(unittest.TestCase): 35 | def test_micro_f1(self): 36 | labels = torch.tensor([0, 1, 1, 1, 0, 0], dtype=torch.int32) 37 | predictions = torch.tensor([1, 1, 0, 1, 1, 0], dtype=torch.int32) 38 | 39 | metric = MicroF1(negative_label=0) 40 | metric(predictions, labels) 41 | 42 | precision, recall, f1 = metric.get_metric(reset=True) 43 | self.assertAlmostEqual(precision, 1/2) 44 | self.assertAlmostEqual(recall, 2/3) 45 | self.assertAlmostEqual(f1, 2 * 1/2 * 2/3 / (1/2 + 2/3)) 46 | 47 | 48 | precision, recall, f1 = metric.get_metric(reset=True) 49 | self.assertAlmostEqual(precision, 0) 50 | self.assertAlmostEqual(recall, 0) 51 | self.assertAlmostEqual(f1, 0) 52 | -------------------------------------------------------------------------------- /tests/test_multitask.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | 4 | from allennlp.common.params import Params 5 | from allennlp.data.dataset_readers import CcgBankDatasetReader 6 | from allennlp.data.dataset_readers import DatasetReader 7 | from allennlp.data.dataset_readers import Conll2003DatasetReader 8 | from allennlp.data.iterators import DataIterator 9 | from allennlp.data import Vocabulary 10 | 11 | from kb.multitask import MultitaskDatasetReader, MultiTaskDataIterator 12 | 13 | 14 | FIXTURES_ROOT = 'tests/fixtures/multitask' 15 | 16 | 17 | def get_dataset_params_paths(datasets_for_vocab_creation): 18 | params = Params({ 19 | "type": "multitask_reader", 20 | "dataset_readers": { 21 | "ner": { 22 | "type": "conll2003", 23 | "tag_label": "ner", 24 | "token_indexers": { 25 | "tokens": { 26 | "type": "single_id", 27 | }, 28 | } 29 | }, 30 | "ccg": { 31 | "type": "ccgbank", 32 | "token_indexers": { 33 | "tokens": { 34 | "type": "single_id", 35 | }, 36 | }, 37 | "feature_labels": ["original_pos"], 38 | } 39 | }, 40 | "datasets_for_vocab_creation": datasets_for_vocab_creation 41 | }) 42 | 43 | file_paths = { 44 | 'ner': FIXTURES_ROOT + '/conll2003.txt', 45 | 'ccg': 
FIXTURES_ROOT + '/ccgbank.txt' 46 | } 47 | 48 | return params, file_paths 49 | 50 | 51 | class TestMultiTaskDatasetReader(unittest.TestCase): 52 | def test_read(self): 53 | params, file_paths = get_dataset_params_paths(["ner"]) 54 | 55 | multitask_reader = DatasetReader.from_params(params) 56 | 57 | dataset = multitask_reader.read(file_paths) 58 | 59 | # get all the instances -- only should have "original_pos_tags" 60 | # for NER 61 | for name, instances in dataset.datasets.items(): 62 | self.assertTrue(name in ('ner', 'ccg')) 63 | for instance in instances: 64 | if name == 'ner': 65 | self.assertTrue("original_pos_tags" not in instance.fields) 66 | else: 67 | self.assertTrue("original_pos_tags" in instance.fields) 68 | 69 | # when iterating directly, only get 'ner' 70 | for instance in dataset: 71 | self.assertTrue("original_pos_tags" not in instance.fields) 72 | 73 | 74 | class TestMultiTaskDataIterator(unittest.TestCase): 75 | def test_multi_iterator(self): 76 | params, file_paths = get_dataset_params_paths(['ner', 'ccg']) 77 | 78 | multitask_reader = DatasetReader.from_params(params) 79 | dataset = multitask_reader.read(file_paths) 80 | 81 | iterator_params = Params({ 82 | "type": "multitask_iterator", 83 | "iterators": { 84 | "ner": {"type": "bucket", 85 | "sorting_keys": [["tokens", "num_tokens"]], 86 | "padding_noise": 0.0, 87 | "batch_size" : 2}, 88 | "ccg": {"type": "basic", 89 | "batch_size" : 1} 90 | }, 91 | "names_to_index": ["ner", "ccg"], 92 | }) 93 | 94 | multi_iterator = DataIterator.from_params(iterator_params) 95 | 96 | # make the vocab 97 | vocab = Vocabulary.from_params(Params({}), 98 | (instance for instance in dataset)) 99 | multi_iterator.index_with(vocab) 100 | 101 | all_batches = [] 102 | for epoch in range(2): 103 | all_batches.append([]) 104 | for batch in multi_iterator(dataset, shuffle=True, 105 | num_epochs=1): 106 | all_batches[-1].append(batch) 107 | 108 | # 3 batches per epoch - 109 | self.assertEqual([len(b) for b in all_batches], [3, 3]) 110 | 111 | ner_batches = [] 112 | ccg_batches = [] 113 | for epoch_batches in all_batches: 114 | ner_batches.append(0) 115 | ccg_batches.append(0) 116 | for batch in epoch_batches: 117 | if 'original_pos_tags' not in batch: 118 | ner_batches[-1] += 1 119 | if 'original_pos_tags' in batch: 120 | ccg_batches[-1] += 1 121 | 122 | # 1 NER batch per epoch, 2 CCG per epoch 123 | self.assertEqual(ner_batches, [1, 1]) 124 | self.assertEqual(ccg_batches, [2, 2]) 125 | 126 | 127 | if __name__ == '__main__': 128 | unittest.main() 129 | 130 | -------------------------------------------------------------------------------- /tests/test_self_attn_iterator.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | 4 | import torch 5 | 6 | from kb.self_attn_bucket_iterator import SelfAttnBucketIterator 7 | from allennlp.data.fields import TextField 8 | from allennlp.data import Token, Vocabulary, Instance 9 | from allennlp.data.iterators import BucketIterator 10 | from allennlp.data.token_indexers import SingleIdTokenIndexer 11 | 12 | 13 | class TestSelfAttnBucketIterator(unittest.TestCase): 14 | def test_self_attn_iterator(self): 15 | indexer = {'tokens': SingleIdTokenIndexer()} 16 | 17 | # make some instances 18 | instances = [] 19 | for k in range(100): 20 | l = max(int(torch.rand(1).item() * 500), 1) 21 | instances.append(Instance( 22 | {'tokens': TextField( 23 | [Token('a') for i in range(l)], token_indexers=indexer)}) 24 | ) 25 | 26 | schedule = [[16, 128], [8, 256], [4, 512]] 
27 | 28 | sub_iterator = BucketIterator( 29 | batch_size=16, 30 | sorting_keys=[['tokens', 'num_tokens']], 31 | padding_noise=0.0 32 | ) 33 | 34 | it = SelfAttnBucketIterator(schedule, sub_iterator) 35 | it.index_with(Vocabulary()) 36 | 37 | batches = [batch for batch in it(instances, num_epochs=1)] 38 | 39 | n_instances = 0 40 | for batch in batches: 41 | batch_size = batch['tokens']['tokens'].shape[0] 42 | n_instances += batch_size 43 | timesteps = batch['tokens']['tokens'].shape[1] 44 | if timesteps <= 128: 45 | expected_batch_size = 16 46 | elif timesteps <= 256: 47 | expected_batch_size = 8 48 | else: 49 | expected_batch_size = 4 50 | # batch might be smaller then expected if we split a larger batch 51 | # and the sequence length for the shorter segment falls into a lower 52 | # bucket 53 | self.assertTrue(batch_size <= expected_batch_size) 54 | 55 | self.assertEqual(n_instances, 100) 56 | 57 | if __name__ == '__main__': 58 | unittest.main() 59 | 60 | -------------------------------------------------------------------------------- /tests/test_span_attention_layer.py: -------------------------------------------------------------------------------- 1 | 2 | from pytorch_pretrained_bert.modeling import BertConfig 3 | import json 4 | 5 | from kb.span_attention_layer import SpanAttentionLayer, SpanWordAttention 6 | 7 | import unittest 8 | 9 | import torch 10 | 11 | class TestSpanAttentionLayer(unittest.TestCase): 12 | def test_span_word_attention(self): 13 | config_file = 'tests/fixtures/bert/bert_config.json' 14 | with open(config_file) as fin: 15 | json_config = json.load(fin) 16 | 17 | vocab_size = json_config.pop("vocab_size") 18 | config = BertConfig(vocab_size, **json_config) 19 | 20 | span_attn = SpanWordAttention(config) 21 | 22 | batch_size = 7 23 | timesteps = 29 24 | hidden_states = torch.rand(batch_size, timesteps, config.hidden_size) 25 | 26 | num_entity_embeddings = 11 27 | entity_embeddings = torch.rand(batch_size, num_entity_embeddings, config.hidden_size) 28 | entity_mask = entity_embeddings[:, :, 0] > 0.5 29 | 30 | span_attn, attention_probs = span_attn(hidden_states, entity_embeddings, entity_mask) 31 | self.assertEqual(list(span_attn.shape), [batch_size, timesteps, config.hidden_size]) 32 | 33 | def test_span_attention_layer(self): 34 | config_file = 'tests/fixtures/bert/bert_config.json' 35 | with open(config_file) as fin: 36 | json_config = json.load(fin) 37 | 38 | vocab_size = json_config.pop("vocab_size") 39 | config = BertConfig(vocab_size, **json_config) 40 | 41 | batch_size = 7 42 | timesteps = 29 43 | hidden_states = torch.rand(batch_size, timesteps, config.hidden_size) 44 | 45 | num_entity_embeddings = 11 46 | entity_embeddings = torch.rand(batch_size, num_entity_embeddings, config.hidden_size) 47 | entity_mask = entity_embeddings[:, :, 0] > 0.5 48 | 49 | span_attention_layer = SpanAttentionLayer(config) 50 | 51 | output = span_attention_layer(hidden_states, entity_embeddings, entity_mask) 52 | 53 | self.assertEqual(list(output["output"].shape), [batch_size, timesteps, config.hidden_size]) 54 | 55 | 56 | if __name__ == '__main__': 57 | unittest.main() 58 | 59 | -------------------------------------------------------------------------------- /tests/test_wiki_reader.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy 3 | 4 | from allennlp.common.testing.test_case import AllenNlpTestCase 5 | 6 | from kb.wiki_linking_reader import LinkingReader 7 | from kb.wiki_linking_util import WikiCandidateMentionGenerator 8 
| 9 | from kb.bert_tokenizer_and_candidate_generator import BertTokenizerAndCandidateGenerator 10 | from kb.wordnet import WordNetCandidateMentionGenerator 11 | 12 | from allennlp.common import Params 13 | from allennlp.data import TokenIndexer, Vocabulary, DataIterator, Instance 14 | 15 | 16 | class MentionGeneratorTest(AllenNlpTestCase): 17 | 18 | def test_read(self): 19 | 20 | candidate_generator = WikiCandidateMentionGenerator("tests/fixtures/linking/priors.txt") 21 | assert len(candidate_generator.p_e_m) == 50 22 | 23 | assert set(candidate_generator.p_e_m.keys()) == { 24 | 'United States', 'Information', 'Wiki', 'France', 'English', 'Germany', 25 | 'World War II', '2007', 'England', 'American', 'Canada', 'Australia', 26 | 'Japan', '2008', 'India', '2006', 'Area Info', 'London', 'German', 27 | 'About Company', 'French', 'United Kingdom', 'Italy', 'en', 'California', 28 | 'China', '2005', 'New York', 'Spain', 'Europe', 'British', '2004', 29 | 'New York City', 'Russia', 'public domain', '2000', 'Brazil', 'Poland', 30 | 'micro-blogging', 'Greek', 'New Zealand', '2003', 'Mexico', 'Italian', 31 | 'Ireland', 'Wiki Image', 'Paris', 'USA', '[1]', 'Iran' 32 | } 33 | 34 | 35 | lower = candidate_generator.process("united states") 36 | string_list = candidate_generator.process(["united", "states"]) 37 | upper = candidate_generator.process(["United", "States"]) 38 | assert lower == upper == string_list 39 | 40 | 41 | class WikiReaderTest(AllenNlpTestCase): 42 | def test_wiki_linking_reader_with_wordnet(self): 43 | def _get_indexer(namespace): 44 | return TokenIndexer.from_params(Params({ 45 | "type": "characters_tokenizer", 46 | "tokenizer": { 47 | "type": "word", 48 | "word_splitter": {"type": "just_spaces"}, 49 | }, 50 | "namespace": namespace 51 | })) 52 | 53 | extra_generator = { 54 | 'wordnet': WordNetCandidateMentionGenerator( 55 | 'tests/fixtures/wordnet/entities_fixture.jsonl') 56 | } 57 | 58 | fake_entity_world = {"Germany":"11867", "United_Kingdom": "31717", "European_Commission": "42336"} 59 | candidate_generator = WikiCandidateMentionGenerator('tests/fixtures/linking/priors.txt', 60 | entity_world_path=fake_entity_world) 61 | train_file = 'tests/fixtures/linking/aida.txt' 62 | 63 | reader = LinkingReader(mention_generator=candidate_generator, 64 | entity_indexer=_get_indexer("entity_wiki"), 65 | extra_candidate_generators=extra_generator) 66 | instances = reader.read(train_file) 67 | 68 | assert len(instances) == 2 69 | 70 | 71 | def test_wiki_linking_reader(self): 72 | 73 | fake_entity_world = {"Germany":"11867", "United_Kingdom": "31717", "European_Commission": "42336"} 74 | candidate_generator = WikiCandidateMentionGenerator('tests/fixtures/linking/priors.txt', 75 | entity_world_path=fake_entity_world) 76 | train_file = 'tests/fixtures/linking/aida.txt' 77 | 78 | reader = LinkingReader(mention_generator=candidate_generator) 79 | instances = reader.read(train_file) 80 | 81 | instances = list(instances) 82 | 83 | fields = instances[0].fields 84 | 85 | text = [x.text for x in fields["tokens"].tokens] 86 | assert text == ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'] 87 | 88 | spans = fields["candidate_spans"].field_list 89 | span_starts, span_ends = zip(*[(field.span_start, field.span_end) for field in spans]) 90 | assert span_starts == (6, 2) 91 | assert span_ends == (6, 2) 92 | gold_ids = [x.text for x in fields["gold_entities"].tokens] 93 | assert gold_ids == ['United_Kingdom', 'Germany'] 94 | 95 | candidate_token_list = [x.text for x in 
fields["candidate_entities"].tokens] 96 | candidate_tokens = [] 97 | for x in candidate_token_list: 98 | candidate_tokens.extend(x.split(" ")) 99 | 100 | assert candidate_tokens == ['United_Kingdom', 'Germany'] 101 | 102 | numpy.testing.assert_array_almost_equal(fields["candidate_entity_prior"].array, numpy.array([[1.], [1.]])) 103 | fields = instances[1].fields 104 | text = [x.text for x in fields["tokens"].tokens] 105 | assert text ==['The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 'disagreed', 106 | 'with', 'German', 'advice', 'to', 'consumers', 'to', 'shun', 'British', 'lamb', 107 | 'until', 'scientists', 'determine', 'whether', 'it', 'is', 'dangerous'] 108 | 109 | spans = fields["candidate_spans"].field_list 110 | span_starts, span_ends = zip(*[(field.span_start, field.span_end) for field in spans]) 111 | assert span_starts == (15, 9) 112 | assert span_ends == (15, 9) 113 | gold_ids = [x.text for x in fields["gold_entities"].tokens] 114 | # id not inside our mini world, should be ignored 115 | assert "European_Commission" not in gold_ids 116 | assert gold_ids == ['United_Kingdom', 'Germany'] 117 | candidate_token_list = [x.text for x in fields["candidate_entities"].tokens] 118 | candidate_tokens = [] 119 | for x in candidate_token_list: 120 | candidate_tokens.extend(x.split(" ")) 121 | assert candidate_tokens == ['United_Kingdom', 'Germany'] 122 | 123 | numpy.testing.assert_array_almost_equal(fields["candidate_entity_prior"].array, numpy.array([[1.], [1.]])) 124 | 125 | class TestWikiCandidateMentionGenerator(AllenNlpTestCase): 126 | def test_wiki_candidate_generator_no_candidates(self): 127 | fake_entity_world = {"Germany":"11867", "United_Kingdom": "31717", "European_Commission": "42336"} 128 | 129 | candidate_generator = WikiCandidateMentionGenerator( 130 | 'tests/fixtures/linking/priors.txt', 131 | entity_world_path=fake_entity_world 132 | ) 133 | 134 | candidates = candidate_generator.get_mentions_raw_text(".") 135 | assert candidates['candidate_entities'] == [['@@PADDING@@']] 136 | 137 | def test_wiki_candidate_generator_simple(self): 138 | candidate_generator = WikiCandidateMentionGenerator( 139 | 'tests/fixtures/linking/priors.txt', 140 | ) 141 | s = "Mexico is bordered to the north by the United States." 
142 | 143 | # first candidate in each list 144 | candidates = candidate_generator.get_mentions_raw_text(s) 145 | first_prior = [span_candidates[0] for span_candidates in candidates['candidate_entities']] 146 | assert first_prior == ['Mexico', 'United_States'] 147 | 148 | # now do it randomly 149 | candidate_generator.random_candidates = True 150 | candidate_generator.p_e_m_keys_for_sampling = list(candidate_generator.p_e_m.keys()) 151 | candidates = candidate_generator.get_mentions_raw_text(s) 152 | first_prior = [span_candidates[0] for span_candidates in candidates['candidate_entities']] 153 | assert first_prior != ['Mexico', 'United_States'] 154 | -------------------------------------------------------------------------------- /training_config/downstream/entity_typing.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "ultra_fine", 4 | "entity_masking": "entity_markers", 5 | "tokenizer_and_candidate_generator": { 6 | "type": "bert_tokenizer_and_candidate_generator", 7 | "bert_model_type": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/bert-base-uncased-tacred-entity-markers-vocab.txt", 8 | "do_lower_case": true, 9 | "entity_candidate_generators": { 10 | "wiki": { 11 | "type": "wiki", 12 | }, 13 | "wordnet": { 14 | "type": "wordnet_mention_generator", 15 | "entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl" 16 | } 17 | }, 18 | "entity_indexers": { 19 | "wiki": { 20 | "type": "characters_tokenizer", 21 | "namespace": "entity_wiki", 22 | "tokenizer": { 23 | "type": "word", 24 | "word_splitter": { 25 | "type": "just_spaces" 26 | } 27 | } 28 | }, 29 | "wordnet": { 30 | "type": "characters_tokenizer", 31 | "namespace": "entity_wordnet", 32 | "tokenizer": { 33 | "type": "word", 34 | "word_splitter": { 35 | "type": "just_spaces" 36 | } 37 | } 38 | } 39 | } 40 | } 41 | }, 42 | "iterator": { 43 | "iterator": { 44 | "type": "basic", 45 | "batch_size": 32 46 | }, 47 | "type": "self_attn_bucket", 48 | "batch_size_schedule": "base-12gb-fp32" 49 | }, 50 | "model": { 51 | "model": { 52 | "type": "from_archive", 53 | "archive_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz", 54 | }, 55 | "type": "simple-classifier", 56 | "bert_dim": 768, 57 | "concat_word_a": true, 58 | "include_cls": false, 59 | "metric_a": { 60 | "type": "f1_set" 61 | }, 62 | "num_labels": 9, 63 | "task": "classification", 64 | "use_bce_loss": true 65 | }, 66 | "train_data_path": "/home/matthewp/data/thunlp/OpenEntity/train.json", 67 | "validation_data_path": "/home/matthewp/data/thunlp/OpenEntity/dev.json", 68 | "trainer": { 69 | "cuda_device": 0, 70 | "gradient_accumulation_batch_size": 32, 71 | "learning_rate_scheduler": { 72 | "type": "slanted_triangular", 73 | "num_epochs": 10, 74 | "num_steps_per_epoch": 62.5 75 | }, 76 | "num_epochs": 10, 77 | "num_serialized_models_to_keep": 1, 78 | "optimizer": { 79 | "type": "bert_adam", 80 | "b2": 0.98, 81 | "lr": 3e-05, 82 | "max_grad_norm": 1, 83 | "parameter_groups": [ 84 | [ 85 | [ 86 | "bias", 87 | "LayerNorm.bias", 88 | "LayerNorm.weight", 89 | "layer_norm.weight" 90 | ], 91 | { 92 | "weight_decay": 0 93 | } 94 | ] 95 | ], 96 | "t_total": -1, 97 | "weight_decay": 0.01 98 | }, 99 | "should_log_learning_rate": true, 100 | "validation_metric": "+f1" 101 | }, 102 | "vocabulary": { 103 | "directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/vocabulary_wordnet_wiki.tar.gz" 104 | 
} 105 | } 106 | -------------------------------------------------------------------------------- /training_config/downstream/semeval2010_task8.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "semeval2010_task8", 4 | "entity_masking": "entity_markers", 5 | "tokenizer_and_candidate_generator": { 6 | "type": "bert_tokenizer_and_candidate_generator", 7 | "bert_model_type": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/bert-base-uncased-tacred-entity-markers-vocab.txt", 8 | "do_lower_case": true, 9 | "entity_candidate_generators": { 10 | "wiki": { 11 | "type": "wiki", 12 | }, 13 | "wordnet": { 14 | "type": "wordnet_mention_generator", 15 | "entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl" 16 | } 17 | }, 18 | "entity_indexers": { 19 | "wiki": { 20 | "type": "characters_tokenizer", 21 | "namespace": "entity_wiki", 22 | "tokenizer": { 23 | "type": "word", 24 | "word_splitter": { 25 | "type": "just_spaces" 26 | } 27 | } 28 | }, 29 | "wordnet": { 30 | "type": "characters_tokenizer", 31 | "namespace": "entity_wordnet", 32 | "tokenizer": { 33 | "type": "word", 34 | "word_splitter": { 35 | "type": "just_spaces" 36 | } 37 | } 38 | } 39 | } 40 | } 41 | }, 42 | "iterator": { 43 | "iterator": { 44 | "type": "basic", 45 | "batch_size": 32 46 | }, 47 | "type": "self_attn_bucket", 48 | "batch_size_schedule": "base-12gb-fp32" 49 | }, 50 | "model": { 51 | "model": { 52 | "type": "from_archive", 53 | "archive_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz", 54 | }, 55 | "type": "simple-classifier", 56 | "bert_dim": 768, 57 | "concat_word_a_b": true, 58 | "include_cls": false, 59 | "metric_a": { 60 | "type": "semeval2010_task8_metric" 61 | }, 62 | "num_labels": 19, 63 | "task": "classification" 64 | }, 65 | "train_data_path": "/home/matthewp/data/semeval2010_task8/train.json", 66 | "validation_data_path": "/home/matthewp/data/semeval2010_task8/dev.json", 67 | "trainer": { 68 | "cuda_device": 0, 69 | "gradient_accumulation_batch_size": 32, 70 | "learning_rate_scheduler": { 71 | "type": "slanted_triangular", 72 | "num_epochs": 3, 73 | "num_steps_per_epoch": 234.375 74 | }, 75 | "num_epochs": 3, 76 | "num_serialized_models_to_keep": 1, 77 | "optimizer": { 78 | "type": "bert_adam", 79 | "b2": 0.98, 80 | "lr": 5e-05, 81 | "max_grad_norm": 1, 82 | "parameter_groups": [ 83 | [ 84 | [ 85 | "bias", 86 | "LayerNorm.bias", 87 | "LayerNorm.weight", 88 | "layer_norm.weight" 89 | ], 90 | { 91 | "weight_decay": 0 92 | } 93 | ] 94 | ], 95 | "t_total": -1, 96 | "weight_decay": 0.01 97 | }, 98 | "should_log_learning_rate": true, 99 | "validation_metric": "+f1" 100 | }, 101 | "vocabulary": { 102 | "directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/vocabulary_wordnet_wiki.tar.gz" 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /training_config/downstream/tacred.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "tacred", 4 | "entity_masking": "entity_markers/type", 5 | "tokenizer_and_candidate_generator": { 6 | "type": "bert_tokenizer_and_candidate_generator", 7 | "bert_model_type": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/bert-base-uncased-tacred-entity-markers-vocab.txt", 8 | "do_lower_case": true, 9 | "entity_candidate_generators": { 10 | "wiki": { 11 | 
"type": "wiki" 12 | }, 13 | "wordnet": { 14 | "type": "wordnet_mention_generator", 15 | "entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl" 16 | } 17 | }, 18 | "entity_indexers": { 19 | "wiki": { 20 | "type": "characters_tokenizer", 21 | "namespace": "entity_wiki", 22 | "tokenizer": { 23 | "type": "word", 24 | "word_splitter": { 25 | "type": "just_spaces" 26 | } 27 | } 28 | }, 29 | "wordnet": { 30 | "type": "characters_tokenizer", 31 | "namespace": "entity_wordnet", 32 | "tokenizer": { 33 | "type": "word", 34 | "word_splitter": { 35 | "type": "just_spaces" 36 | } 37 | } 38 | } 39 | } 40 | } 41 | }, 42 | "iterator": { 43 | "iterator": { 44 | "type": "basic", 45 | "batch_size": 32 46 | }, 47 | "type": "self_attn_bucket", 48 | "batch_size_schedule": "base-12gb-fp32" 49 | }, 50 | "model": { 51 | "model": { 52 | "type": "from_archive", 53 | "archive_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz", 54 | }, 55 | "type": "simple-classifier", 56 | "bert_dim": 768, 57 | "concat_word_a_b": true, 58 | "include_cls": false, 59 | "metric_a": { 60 | "type": "microf1", 61 | "negative_label": 0 62 | }, 63 | "num_labels": 42, 64 | "task": "classification" 65 | }, 66 | "train_data_path": "/home/matthewp/data/tacred/train.json", 67 | "validation_data_path": "/home/matthewp/data/tacred/dev.json", 68 | "trainer": { 69 | "cuda_device": 0, 70 | "gradient_accumulation_batch_size": 32, 71 | "learning_rate_scheduler": { 72 | "type": "slanted_triangular", 73 | "num_epochs": 3, 74 | "num_steps_per_epoch": 2128.875 75 | }, 76 | "num_epochs": 3, 77 | "num_serialized_models_to_keep": 1, 78 | "optimizer": { 79 | "type": "bert_adam", 80 | "b2": 0.98, 81 | "lr": 3e-05, 82 | "max_grad_norm": 1, 83 | "parameter_groups": [ 84 | [ 85 | [ 86 | "bias", 87 | "LayerNorm.bias", 88 | "LayerNorm.weight", 89 | "layer_norm.weight" 90 | ], 91 | { 92 | "weight_decay": 0 93 | } 94 | ] 95 | ], 96 | "t_total": -1, 97 | "weight_decay": 0.01 98 | }, 99 | "should_log_learning_rate": true, 100 | "validation_metric": "+micro_f1" 101 | }, 102 | "vocabulary": { 103 | "directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/vocabulary_wordnet_wiki.tar.gz" 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /training_config/downstream/wic.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "wic", 4 | "tokenizer_and_candidate_generator": { 5 | "type": "bert_tokenizer_and_candidate_generator", 6 | "bert_model_type": "bert-base-uncased", 7 | "do_lower_case": true, 8 | "entity_candidate_generators": { 9 | "wiki": { 10 | "type": "wiki" 11 | }, 12 | "wordnet": { 13 | "type": "wordnet_mention_generator", 14 | "entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl" 15 | } 16 | }, 17 | "entity_indexers": { 18 | "wiki": { 19 | "type": "characters_tokenizer", 20 | "namespace": "entity_wiki", 21 | "tokenizer": { 22 | "type": "word", 23 | "word_splitter": { 24 | "type": "just_spaces" 25 | } 26 | } 27 | }, 28 | "wordnet": { 29 | "type": "characters_tokenizer", 30 | "namespace": "entity_wordnet", 31 | "tokenizer": { 32 | "type": "word", 33 | "word_splitter": { 34 | "type": "just_spaces" 35 | } 36 | } 37 | } 38 | } 39 | } 40 | }, 41 | "iterator": { 42 | "iterator": { 43 | "type": "basic", 44 | "batch_size": 32 45 | }, 46 | "type": "self_attn_bucket", 47 | "batch_size_schedule": 
"base-12gb-fp32" 48 | }, 49 | "model": { 50 | "model": { 51 | "type": "from_archive", 52 | "archive_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_wordnet_model.tar.gz", 53 | }, 54 | "type": "simple-classifier", 55 | "bert_dim": 768, 56 | "metric_a": { 57 | "type": "categorical_accuracy" 58 | }, 59 | "num_labels": 2, 60 | "task": "classification" 61 | }, 62 | "train_data_path": "/home/matthewp/data/wic/train", 63 | "validation_data_path": "/home/matthewp/data/wic/dev", 64 | "trainer": { 65 | "cuda_device": 0, 66 | "gradient_accumulation_batch_size": 32, 67 | "learning_rate_scheduler": { 68 | "type": "slanted_triangular", 69 | "num_epochs": 5, 70 | "num_steps_per_epoch": 169.75 71 | }, 72 | "moving_average": { 73 | "decay": 0.95 74 | }, 75 | "num_epochs": 5, 76 | "num_serialized_models_to_keep": 1, 77 | "optimizer": { 78 | "type": "bert_adam", 79 | "lr": 1e-05, 80 | "max_grad_norm": 1, 81 | "parameter_groups": [ 82 | [ 83 | [ 84 | "bias", 85 | "LayerNorm.bias", 86 | "LayerNorm.weight", 87 | "layer_norm.weight" 88 | ], 89 | { 90 | "weight_decay": 0 91 | } 92 | ] 93 | ], 94 | "t_total": -1, 95 | "weight_decay": 0.01 96 | }, 97 | "should_log_learning_rate": true, 98 | "validation_metric": "+accuracy" 99 | }, 100 | "vocabulary": { 101 | "directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/vocabulary_wordnet_wiki.tar.gz" 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /training_config/pretraining/knowbert_wiki_linker.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "vocabulary": { 3 | "directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/vocabulary_wiki.tar.gz" 4 | }, 5 | 6 | "dataset_reader": { 7 | "type": "aida_wiki_linking", 8 | "entity_disambiguation_only": false, 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "bert-pretrained", 12 | "pretrained_model": "bert-base-uncased", 13 | "do_lowercase": true, 14 | "use_starting_offsets": true, 15 | "max_pieces": 512, 16 | }, 17 | }, 18 | "entity_indexer": { 19 | "type": "characters_tokenizer", 20 | "tokenizer": { 21 | "type": "word", 22 | "word_splitter": {"type": "just_spaces"}, 23 | }, 24 | "namespace": "entity", 25 | }, 26 | "should_remap_span_indices": false, 27 | }, 28 | 29 | "iterator": { 30 | "type": "self_attn_bucket", 31 | "batch_size_schedule": "base-12gb-fp32", 32 | "iterator": { 33 | "type": "cross_sentence_linking", 34 | "batch_size": 32, 35 | "entity_indexer": { 36 | "type": "characters_tokenizer", 37 | "tokenizer": { 38 | "type": "word", 39 | "word_splitter": {"type": "just_spaces"}, 40 | }, 41 | "namespace": "entity" 42 | }, 43 | "bert_model_type": "bert-base-uncased", 44 | "do_lower_case": true, 45 | // this is ignored 46 | "mask_candidate_strategy": "none", 47 | "max_predictions_per_seq": 0, 48 | "id_type": "wiki", 49 | "use_nsp_label": false, 50 | }, 51 | }, 52 | 53 | "train_data_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/aida_train.txt", 54 | "validation_data_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/aida_dev.txt", 55 | 56 | "model": { 57 | "type": "knowbert", 58 | "bert_model_name": "bert-base-uncased", 59 | "mode": "entity_linking", 60 | "soldered_layers": {"wiki": 9}, 61 | "soldered_kgs": { 62 | "wiki": { 63 | "type": "soldered_kg", 64 | "entity_linker": { 65 | "type": "entity_linking_with_candidate_mentions", 66 | "entity_embedding": { 67 | 
"vocab_namespace": "entity", 68 | "embedding_dim": 300, 69 | "pretrained_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/entities_glove_format.gz", 70 | "trainable": false, 71 | "sparse": false 72 | }, 73 | "contextual_embedding_dim": 768, 74 | "span_encoder_config": { 75 | "hidden_size": 300, 76 | "num_hidden_layers": 1, 77 | "num_attention_heads": 4, 78 | "intermediate_size": 1024 79 | }, 80 | }, 81 | "span_attention_config": { 82 | "hidden_size": 300, 83 | "num_hidden_layers": 1, 84 | "num_attention_heads": 4, 85 | "intermediate_size": 1024 86 | }, 87 | }, 88 | }, 89 | }, 90 | 91 | "trainer": { 92 | "optimizer": { 93 | "type": "bert_adam", 94 | "lr": 1e-3, 95 | "t_total": -1, 96 | "max_grad_norm": 1.0, 97 | "weight_decay": 0.01, 98 | "parameter_groups": [ 99 | [["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight"], {"weight_decay": 0.0}], 100 | ], 101 | }, 102 | "gradient_accumulation_batch_size": 32, 103 | "num_epochs": 10, 104 | 105 | "learning_rate_scheduler": { 106 | "type": "slanted_triangular", 107 | "num_epochs": 10, 108 | "num_steps_per_epoch": 434, 109 | }, 110 | "num_serialized_models_to_keep": 2, 111 | "should_log_learning_rate": true, 112 | "cuda_device": 0, 113 | "validation_metric": "+wiki_el_f1", 114 | } 115 | 116 | } 117 | -------------------------------------------------------------------------------- /training_config/pretraining/knowbert_wordnet_linker.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "vocabulary": { 3 | "directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/vocabulary_wordnet.tar.gz", 4 | }, 5 | 6 | "dataset_reader": { 7 | "type": "wordnet_fine_grained", 8 | "wordnet_entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "bert-pretrained", 12 | "pretrained_model": "bert-base-uncased", 13 | "do_lowercase": true, 14 | "use_starting_offsets": true, 15 | "max_pieces": 512, 16 | }, 17 | }, 18 | "entity_indexer": { 19 | "type": "characters_tokenizer", 20 | "tokenizer": { 21 | "type": "word", 22 | "word_splitter": {"type": "just_spaces"}, 23 | }, 24 | "namespace": "entity" 25 | }, 26 | "is_training": true, 27 | "should_remap_span_indices": false 28 | }, 29 | 30 | "train_data_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/semcor_and_wordnet_examples.json", 31 | 32 | "iterator": { 33 | "type": "cross_sentence_linking", 34 | "batch_size": 32, 35 | "entity_indexer": { 36 | "type": "characters_tokenizer", 37 | "tokenizer": { 38 | "type": "word", 39 | "word_splitter": {"type": "just_spaces"}, 40 | }, 41 | "namespace": "entity" 42 | }, 43 | "bert_model_type": "bert-base-uncased", 44 | "do_lower_case": true, 45 | // this is ignored 46 | "mask_candidate_strategy": "none", 47 | "max_predictions_per_seq": 0, 48 | "id_type": "wordnet", 49 | "use_nsp_label": false, 50 | }, 51 | 52 | "model": { 53 | "type": "knowbert", 54 | "bert_model_name": "bert-base-uncased", 55 | "mode": "entity_linking", 56 | "soldered_layers": {"wordnet": 9}, 57 | "soldered_kgs": { 58 | "wordnet": { 59 | "type": "soldered_kg", 60 | "entity_linker": { 61 | "type": "entity_linking_with_candidate_mentions", 62 | "loss_type": "softmax", 63 | "concat_entity_embedder": { 64 | "type": "wordnet_all_embeddings", 65 | "entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl", 66 | "embedding_file": 
"https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/wordnet_synsets_mask_null_vocab_embeddings_tucker_gensen.hdf5", 67 | "vocab_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/wordnet_synsets_mask_null_vocab.txt", 68 | "entity_dim": 200, 69 | "entity_h5_key": "tucker_gensen", 70 | }, 71 | "contextual_embedding_dim": 768, 72 | "span_encoder_config": { 73 | "hidden_size": 200, 74 | "num_hidden_layers": 1, 75 | "num_attention_heads": 4, 76 | "intermediate_size": 1024 77 | }, 78 | }, 79 | "span_attention_config": { 80 | "hidden_size": 200, 81 | "num_hidden_layers": 1, 82 | "num_attention_heads": 4, 83 | "intermediate_size": 1024 84 | }, 85 | }, 86 | }, 87 | }, 88 | 89 | "trainer": { 90 | "optimizer": { 91 | "type": "bert_adam", 92 | "lr": 1e-3, 93 | "t_total": -1, 94 | "max_grad_norm": 1.0, 95 | "weight_decay": 0.01, 96 | "parameter_groups": [ 97 | [["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight"], {"weight_decay": 0.0}], 98 | ], 99 | }, 100 | "num_epochs": 5, 101 | 102 | "learning_rate_scheduler": { 103 | "type": "slanted_triangular", 104 | "num_epochs": 5, 105 | // semcor + examples batch size=32 106 | "num_steps_per_epoch": 2470, 107 | }, 108 | "num_serialized_models_to_keep": 1, 109 | "should_log_learning_rate": true, 110 | "cuda_device": 0, 111 | } 112 | 113 | } 114 | -------------------------------------------------------------------------------- /training_config/pretraining/knowbert_wordnet_wiki_linker.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "vocabulary": { 3 | "directory_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/vocabulary_wordnet_wiki.tar.gz" 4 | }, 5 | 6 | "dataset_reader": { 7 | "type": "wordnet_fine_grained", 8 | "wordnet_entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "bert-pretrained", 12 | "pretrained_model": "bert-base-uncased", 13 | "do_lowercase": true, 14 | "use_starting_offsets": true, 15 | "max_pieces": 512, 16 | }, 17 | }, 18 | "entity_indexer": { 19 | "type": "characters_tokenizer", 20 | "tokenizer": { 21 | "type": "word", 22 | "word_splitter": {"type": "just_spaces"}, 23 | }, 24 | "namespace": "entity" 25 | }, 26 | "is_training": true, 27 | "should_remap_span_indices": false, 28 | "extra_candidate_generators": { 29 | "wiki": {"type": "wiki"}, 30 | }, 31 | }, 32 | 33 | "train_data_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/semcor_and_wordnet_examples.json", 34 | 35 | "iterator": { 36 | "type": "self_attn_bucket", 37 | "batch_size_schedule": "base-24gb-fp32", 38 | "iterator": { 39 | "type": "cross_sentence_linking", 40 | "batch_size": 32, 41 | "entity_indexer": { 42 | "type": "characters_tokenizer", 43 | "tokenizer": { 44 | "type": "word", 45 | "word_splitter": {"type": "just_spaces"}, 46 | }, 47 | "namespace": "entity_wordnet" 48 | }, 49 | "bert_model_type": "bert-base-uncased", 50 | "do_lower_case": true, 51 | // this is ignored 52 | "mask_candidate_strategy": "none", 53 | "max_predictions_per_seq": 0, 54 | "id_type": "wordnet", 55 | "use_nsp_label": false, 56 | "extra_id_type": "wiki", 57 | "extra_entity_indexer": { 58 | "type": "characters_tokenizer", 59 | "tokenizer": { 60 | "type": "word", 61 | "word_splitter": {"type": "just_spaces"}, 62 | }, 63 | "namespace": "entity_wiki" 64 | } 65 | }, 66 | }, 67 | 68 | "model": { 69 | "type": "knowbert", 70 | "bert_model_name": "bert-base-uncased", 71 | "model_archive": 
"https://allennlp.s3-us-west-2.amazonaws.com/knowbert/models/knowbert_wiki_model.tar.gz", 72 | "strict_load_archive": false, 73 | "mode": "entity_linking", 74 | "soldered_layers": {"wordnet": 10, "wiki": 9}, 75 | "soldered_kgs": { 76 | "wordnet": { 77 | "type": "soldered_kg", 78 | "entity_linker": { 79 | "type": "entity_linking_with_candidate_mentions", 80 | "namespace": "entity_wordnet", 81 | "loss_type": "softmax", 82 | "concat_entity_embedder": { 83 | "type": "wordnet_all_embeddings", 84 | "entity_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/entities.jsonl", 85 | "embedding_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/wordnet_synsets_mask_null_vocab_embeddings_tucker_gensen.hdf5", 86 | "vocab_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/wordnet_synsets_mask_null_vocab.txt", 87 | "entity_dim": 200, 88 | "entity_h5_key": "tucker_gensen", 89 | }, 90 | "contextual_embedding_dim": 768, 91 | "span_encoder_config": { 92 | "hidden_size": 200, 93 | "num_hidden_layers": 1, 94 | "num_attention_heads": 4, 95 | "intermediate_size": 1024 96 | }, 97 | }, 98 | "span_attention_config": { 99 | "hidden_size": 200, 100 | "num_hidden_layers": 1, 101 | "num_attention_heads": 4, 102 | "intermediate_size": 1024 103 | }, 104 | }, 105 | "wiki": { 106 | "type": "soldered_kg", 107 | // wiki component is pretrained so we freeze 108 | "should_init_kg_to_bert_inverse": false, 109 | "freeze": true, 110 | "entity_linker": { 111 | "type": "entity_linking_with_candidate_mentions", 112 | "namespace": "entity_wiki", 113 | "entity_embedding": { 114 | "vocab_namespace": "entity_wiki", 115 | "embedding_dim": 300, 116 | "pretrained_file": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wiki_entity_linking/entities_glove_format.gz", 117 | "trainable": false, 118 | "sparse": false 119 | }, 120 | "contextual_embedding_dim": 768, 121 | "span_encoder_config": { 122 | "hidden_size": 300, 123 | "num_hidden_layers": 1, 124 | "num_attention_heads": 4, 125 | "intermediate_size": 1024 126 | }, 127 | }, 128 | "span_attention_config": { 129 | "hidden_size": 300, 130 | "num_hidden_layers": 1, 131 | "num_attention_heads": 4, 132 | "intermediate_size": 1024 133 | }, 134 | }, 135 | }, 136 | }, 137 | 138 | "trainer": { 139 | "optimizer": { 140 | "type": "bert_adam", 141 | "lr": 1e-3, 142 | "t_total": -1, 143 | "max_grad_norm": 1.0, 144 | "weight_decay": 0.01, 145 | "parameter_groups": [ 146 | [["bias", "LayerNorm.bias", "LayerNorm.weight", "layer_norm.weight"], {"weight_decay": 0.0}], 147 | ], 148 | }, 149 | "gradient_accumulation_batch_size": 32, 150 | "num_epochs": 5, 151 | 152 | "learning_rate_scheduler": { 153 | "type": "slanted_triangular", 154 | "num_epochs": 5, 155 | "num_steps_per_epoch": 2470, 156 | }, 157 | "num_serialized_models_to_keep": 2, 158 | "model_save_interval": 600, 159 | "should_log_learning_rate": true, 160 | "cuda_device": 0, 161 | } 162 | 163 | } 164 | -------------------------------------------------------------------------------- /training_config/pretraining/wordnet_tucker.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "kg_tuple", 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "kg_tuple", 7 | "extra_files_for_gold_pairs": [ 8 | "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/relations_train99.txt", 9 | "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/relations_dev01.txt", 10 | ] 11 | }, 12 | 13 | "train_data_path": 
"https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/relations_train99.txt", 14 | "validation_data_path": "https://allennlp.s3-us-west-2.amazonaws.com/knowbert/wordnet/relations_dev01.txt", 15 | 16 | "model": { 17 | "type": "kg_tuple", 18 | "kg_tuple_predictor": { 19 | "type": "tucker", 20 | "num_entities": 324637 + 2, 21 | // *2 for _reverse, +2 for PADDING, UNKNOWN 22 | // +1 for synset_lemma 23 | "num_relations": 11 * 2 + 17 * 2 + 1 + 2, 24 | "entity_dim": 200, 25 | "relation_dim": 30 26 | } 27 | }, 28 | "iterator": { 29 | "type": "basic", 30 | "batch_size": 128 31 | }, 32 | 33 | "trainer": { 34 | "optimizer": { 35 | "type": "adam", 36 | "lr": 0.01, 37 | }, 38 | "validation_metric": "+mean_reciprocal_rank", 39 | "num_serialized_models_to_keep": 2, 40 | "num_epochs": 500, 41 | "patience": 10, 42 | "cuda_device": 0, 43 | } 44 | 45 | } 46 | --------------------------------------------------------------------------------