├── bert_entity ├── __init__.py ├── preprocessing │ ├── __init__.py │ ├── wikiextractor.py │ ├── create_keyword_matcher.py │ ├── create_redirects.py │ ├── download_data.py │ ├── create_disambiguation_dict.py │ ├── create_resolve_to_wiki_dicts.py │ ├── create_integerized_aida_conll_training.py │ ├── postprocess_mention_entity_counts.py │ ├── collect_mention_entity_counts.py │ ├── preprocess_aida_conll_data.py │ └── create_integerized_wiki_training.py ├── vocab.py ├── pipeline_job.py ├── train.py ├── preprocess_all.py ├── metrics.py ├── train_util.py ├── misc.py ├── data_loader_conll.py ├── data_loader_wiki.py └── model.py ├── requirements.txt ├── setup_paths ├── docs ├── Bert-Entity.png └── preprocessing.png ├── downstream_experiments ├── prepare_fairseq.sh ├── run_fairseq_bert_ensemble.sh ├── ensemble_bert_modeling.py └── fairseq_patch_01.patch ├── .gitmodules ├── config ├── dummy__preprocess.yaml ├── conll2019__preprocess.yaml ├── dummy__train_on_aida_conll.yaml ├── conll2019__train_on_aida_conll.yaml ├── dummy__train_on_wiki.yaml └── conll2019__train_on_wiki.yaml ├── LICENSE ├── .gitignore └── README.md /bert_entity/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flashtext 2 | configargparse 3 | boto3 4 | torch==1.5.0 -------------------------------------------------------------------------------- /setup_paths: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:./bert_entity:./pytorch-pretrained-BERT/:./wikiextractor-wikimentions/ -------------------------------------------------------------------------------- /docs/Bert-Entity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuelbroscheit/entity_knowledge_in_bert/HEAD/docs/Bert-Entity.png -------------------------------------------------------------------------------- /docs/preprocessing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samuelbroscheit/entity_knowledge_in_bert/HEAD/docs/preprocessing.png -------------------------------------------------------------------------------- /downstream_experiments/prepare_fairseq.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | git clone https://github.com/pytorch/fairseq.git 4 | cd fairseq 5 | git checkout ec6f8ef99a8c6942133e01a610def197e1d6d9dd 6 | git apply ../fairseq_patch_01.patch 7 | git apply ../fairseq_patch_02.patch 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "wikiextractor-wikimentions"] 2 | path = wikiextractor-wikimentions 3 | url = git://github.com/samuelbroscheit/wikiextractor-wikimentions.git 4 | [submodule "pytorch-pretrained-BERT"] 5 | path = pytorch-pretrained-BERT 6 | url = git://github.com/huggingface/pytorch-pretrained-BERT.git 7 | -------------------------------------------------------------------------------- 
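The entries in setup_paths and the two submodules above are what make the flat imports used throughout the code base resolve (e.g. `from pipeline_job import PipelineJob`, `from WikiExtractor import main`, `from pytorch_pretrained_bert import BertTokenizer`). The snippet below is a small illustrative sanity check, not part of the repository; it mirrors what `source setup_paths` does from inside Python and assumes it is run from the repository root after `git submodule update --init` and `pip install -r requirements.txt`.

```python
# Hypothetical sanity check: replicate the PYTHONPATH additions from setup_paths and
# verify that the three source roots are importable from the repository root.
import sys

sys.path.extend(["./bert_entity", "./pytorch-pretrained-BERT", "./wikiextractor-wikimentions"])

import pipeline_job              # bert_entity/pipeline_job.py
import WikiExtractor             # top-level module of the wikiextractor-wikimentions submodule
import pytorch_pretrained_bert   # package inside the pytorch-pretrained-BERT submodule

print("import roots OK:", pipeline_job.__name__, WikiExtractor.__name__, pytorch_pretrained_bert.__name__)
```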
/config/dummy__preprocess.yaml: -------------------------------------------------------------------------------- 1 | debug: False 2 | wiki_lang_version: enwiki 3 | data_version_name: dummy 4 | download_data_only_dummy: True 5 | num_most_freq_entities: 50000 6 | add_missing_conll_entities: True 7 | uncased: True 8 | 9 | collect_mention_entities_num_workers: 10 10 | 11 | wikiextractor_num_workers: 10 12 | 13 | create_training_data_num_workers: 10 14 | create_training_data_num_entities_in_necessary_articles: 10 15 | 16 | create_integerized_training_num_workers: 10 17 | create_integerized_training_loc_file_name: 'data.loc' 18 | create_integerized_training_instance_text_length: 254 19 | create_integerized_training_instance_text_overlap: 20 20 | create_integerized_training_max_entity_per_shard_count: 10 21 | create_integerized_training_valid_size: 1000 22 | create_integerized_training_test_size: 1000 23 | 24 | -------------------------------------------------------------------------------- /config/conll2019__preprocess.yaml: -------------------------------------------------------------------------------- 1 | debug: False 2 | wiki_lang_version: enwiki 3 | data_version_name: conll2019 4 | download_data_only_dummy: False 5 | download_2017_enwiki: True 6 | num_most_freq_entities: 500000 7 | add_missing_conll_entities: True 8 | uncased: True 9 | 10 | collect_mention_entities_num_workers: 10 11 | 12 | wikiextractor_num_workers: 10 13 | 14 | create_training_data_num_workers: 10 15 | create_training_data_num_entities_in_necessary_articles: 10 16 | 17 | create_integerized_training_num_workers: 50 18 | create_integerized_training_loc_file_name: 'data.loc' 19 | create_integerized_training_instance_text_length: 254 20 | create_integerized_training_instance_text_overlap: 20 21 | create_integerized_training_max_entity_per_shard_count: 10 22 | create_integerized_training_valid_size: 1000 23 | create_integerized_training_test_size: 1000 24 | -------------------------------------------------------------------------------- /config/dummy__train_on_aida_conll.yaml: -------------------------------------------------------------------------------- 1 | logdir: data/checkpoints/dummy_aidaconll_00001/ 2 | resume_from_checkpoint: data/checkpoints/dummy_wiki_00001/best_f1-0.pt 3 | resume_reset_epoch: True 4 | data_path_conll: data/benchmarks/aida-yago2-dataset/conll_dataset_dummy_254-20.pickle 5 | data_version_name: dummy 6 | collect_most_popular_labels_steps: 1 7 | checkpoint_eval_steps: 10000 8 | checkpoint_save_steps: 1000 9 | device: 0 10 | eval_device: 0 11 | out_device: 0 12 | dataset: CONLLEDLDataset 13 | model: ConllNet 14 | batch_size: 2 15 | accumulate_batch_gradients: 4 16 | eval_batch_size: 3 17 | label_size: 1024 18 | topk_neg_examples: 20 19 | finetuning: 1 20 | sparse: True 21 | encoder_lr: 5e-5 22 | decoder_lr: 0 23 | n_epochs: 60 24 | segm_decoder_lr: 0 25 | segm_decoder_weight_decay: 1e-10 26 | learn_segmentation: False 27 | bert_dropout: 0.2 28 | maskout_entity_prob: 0 29 | eval_before_training: False 30 | project: False 31 | exclude_parameter_names_regex: 'embeddings|encoder\.layer\.[0-2]\.' 
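# Stage 2 of the dummy pipeline: fine-tune on AIDA-CONLL starting from the best checkpoint
# written by the dummy__train_on_wiki.yaml run (its logdir data/checkpoints/dummy_wiki_00001
# is where resume_from_checkpoint above points), with the epoch counter reset via resume_reset_epoch.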
32 | -------------------------------------------------------------------------------- /config/conll2019__train_on_aida_conll.yaml: -------------------------------------------------------------------------------- 1 | logdir: data/checkpoints/conll2019_aidaconll_00001/ 2 | resume_from_checkpoint: data/checkpoints/conll2019_wiki_00001/best_f1-0.pt 3 | resume_reset_epoch: True 4 | data_path_conll: data/benchmarks/aida-yago2-dataset/conll_dataset_dummy_254-20.pickle 5 | data_version_name: dummy 6 | collect_most_popular_labels_steps: 1 7 | checkpoint_eval_steps: 10000 8 | checkpoint_save_steps: 1000 9 | device: 0 10 | eval_device: 0 11 | out_device: 0 12 | dataset: CONLLEDLDataset 13 | model: ConllNet 14 | batch_size: 2 15 | accumulate_batch_gradients: 4 16 | eval_batch_size: 3 17 | label_size: 1024 18 | topk_neg_examples: 20 19 | finetuning: 1 20 | sparse: True 21 | encoder_lr: 5e-5 22 | decoder_lr: 0 23 | n_epochs: 60 24 | segm_decoder_lr: 0 25 | segm_decoder_weight_decay: 1e-10 26 | learn_segmentation: False 27 | bert_dropout: 0.2 28 | maskout_entity_prob: 0 29 | eval_before_training: False 30 | project: False 31 | exclude_parameter_names_regex: 'embeddings|encoder\.layer\.[0-2]\.' 32 | -------------------------------------------------------------------------------- /config/dummy__train_on_wiki.yaml: -------------------------------------------------------------------------------- 1 | logdir: data/checkpoints/dummy_wiki_00001 2 | debug: False 3 | device: 0 4 | eval_device: 0 5 | dataset: EDLDataset 6 | model: Net 7 | data_version_name: dummy 8 | wiki_lang_version: enwiki 9 | eval_on_test_only: False 10 | out_device: 0 11 | batch_size: 16 12 | eval_batch_size: 1 13 | accumulate_batch_gradients: 8 14 | sparse: True 15 | encoder_lr: 5e-05 16 | decoder_lr: 0.1 17 | maskout_entity_prob: 0.0 18 | encoder_weight_decay: 0.0 19 | decoder_weight_decay: 0.0 20 | segm_decoder_weight_decay: 0.0 21 | learn_segmentation: False 22 | label_size: 8192 23 | entity_embedding_size: 768 24 | project: False 25 | n_epochs: 100 26 | collect_most_popular_labels_steps: 1 27 | checkpoint_eval_steps: 1000 28 | checkpoint_save_steps: 100000 29 | finetuning: 3 30 | train_loc_file: train.loc 31 | valid_loc_file: valid.loc 32 | resume_reset_epoch: False 33 | resume_optimizer_from_checkpoint: False 34 | topk_neg_examples: 20 35 | dont_save_checkpoints: False 36 | data_workers: 24 37 | eval_before_training: False 38 | train_data_dir: data 39 | valid_data_dir: data 40 | test_data_dir: data 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 samuelbroscheit 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config/conll2019__train_on_wiki.yaml: -------------------------------------------------------------------------------- 1 | logdir: data/checkpoints/conll2019_wiki_00001 2 | debug: False 3 | device: 0 4 | eval_device: 0 5 | dataset: EDLDataset 6 | model: Net 7 | data_version_name: conll2019 8 | wiki_lang_version: enwiki 9 | eval_on_test_only: False 10 | out_device: 0 11 | batch_size: 6 12 | eval_batch_size: 1 13 | accumulate_batch_gradients: 8 14 | sparse: True 15 | encoder_lr: 5e-05 16 | decoder_lr: 0.1 17 | maskout_entity_prob: 0.0 18 | segm_decoder_lr: 0.001 19 | encoder_weight_decay: 0.0 20 | decoder_weight_decay: 0.0 21 | segm_decoder_weight_decay: 0.0 22 | learn_segmentation: False 23 | label_size: 8192 24 | entity_embedding_size: 768 25 | project: False 26 | n_epochs: 20 27 | collect_most_popular_labels_steps: 1 28 | checkpoint_eval_steps: 10000 29 | checkpoint_save_steps: 100000 30 | finetuning: 3 31 | top_rnns: False 32 | train_loc_file: train.loc 33 | valid_loc_file: valid.loc 34 | resume_reset_epoch: False 35 | resume_optimizer_from_checkpoint: False 36 | topk_neg_examples: 20 37 | dont_save_checkpoints: False 38 | data_workers: 24 39 | eval_before_training: False 40 | train_data_dir: data 41 | valid_data_dir: data 42 | test_data_dir: data 43 | -------------------------------------------------------------------------------- /downstream_experiments/run_fairseq_bert_ensemble.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | BERT_ENTITY_CHECKPOINT=$1 4 | 5 | cd "fairseq/" 6 | 7 | python train.py data-bin/bertencoder--wmt14_en_de --optimizer adam --lr 0.00005 --clip-norm 0.1 --dropout 0.2 --max-tokens 1000 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --lr-scheduler fixed --force-anneal 200 --arch bert_transformer_iwslt_en_de --save-dir checkpoints/bert_transformer_wmt14_en_de_ft_20 --fp16 --finetuning 2 --save-interval-updates 10000 --validate-interval 10000 8 | 9 | python train.py data-bin/bertencoder--wmt14_en_de --optimizer adam --lr 0.00005 --clip-norm 0.1 --dropout 0.2 --max-tokens 1000 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --lr-scheduler fixed --force-anneal 200 --arch bert_transformer_iwslt_en_de --save-dir checkpoints/bert_transformer_wmt14_en_de_ft_10-entity --fp16 --finetuning 1 --save-interval-updates 10000 --validate-interval 10000 --load_bert_checkpoint $BERT_ENTITY_CHECKPOINT 10 | 11 | python setup.py build develop 12 | 13 | python generate.py data-bin/bertencoder--wmt14_en_de --beam 5 --remove-bpe --batch-size 128 --path 14 | checkpoints/bert_transformer_wmt14_en_de_ft_10/checkpoint_best.pt | tee checkpoints/bert_transformer_wmt14_en_de_ft_10/gen.out 15 | 16 | python generate.py data-bin/bertencoder--wmt14_en_de --beam 5 --remove-bpe --batch-size 128 --path checkpoints/bert_transformer_wmt14_en_de_ft_10-entity/checkpoint_best.pt | tee 
checkpoints/bert_transformer_wmt14_en_de_ft_10-entity/gen.out -------------------------------------------------------------------------------- /bert_entity/vocab.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from pytorch_pretrained_bert import BertTokenizer 4 | 5 | 6 | class Vocab: 7 | def __init__(self, args=None): 8 | self.tag2idx = None 9 | self.idx2tag = None 10 | self.OUTSIDE_ID = None 11 | self.PAD_ID = None 12 | self.SPECIAL_TOKENS = None 13 | self.tokenizer = None 14 | if args is not None: 15 | self.load(args) 16 | 17 | def load(self, args, popular_entity_to_id_dict=None): 18 | 19 | if popular_entity_to_id_dict is None: 20 | with open(f"data/versions/{args.data_version_name}/indexes/popular_entity_to_id_dict.pickle", "rb") as f: 21 | popular_entity_to_id_dict = pickle.load(f) 22 | 23 | if args.uncased: 24 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) 25 | else: 26 | tokenizer = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False) 27 | 28 | self.tag2idx = popular_entity_to_id_dict 29 | 30 | self.OUTSIDE_ID = len(self.tag2idx) 31 | self.tag2idx["|||O|||"] = self.OUTSIDE_ID 32 | 33 | self.PAD_ID = len(self.tag2idx) 34 | self.tag2idx["|||PAD|||"] = self.PAD_ID 35 | 36 | self.SPECIAL_TOKENS = [self.OUTSIDE_ID, self.PAD_ID] 37 | 38 | self.idx2tag = {v: k for k, v in self.tag2idx.items()} 39 | 40 | self.tokenizer = tokenizer 41 | 42 | args.vocab_size = self.size() 43 | 44 | def size(self): 45 | return len(self.tag2idx) 46 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/wikiextractor.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import glob 3 | import io 4 | import json 5 | import multiprocessing 6 | import os 7 | import pickle 8 | import sys 9 | from collections import Counter 10 | from typing import Dict 11 | 12 | from WikiExtractor import main as wiki_extractor_main 13 | from misc import normalize_wiki_entity 14 | from pipeline_job import PipelineJob 15 | 16 | 17 | class Wikiextractor(PipelineJob): 18 | """ 19 | Run Wikiextractor on the Wikipedia dump and extract all the mentions from it. 
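    The job drives WikiExtractor programmatically with the same arguments as the command
    line call shown in the comment inside _run(); the --collect_links flag comes from the
    wikiextractor-wikimentions fork and keeps the internal link annotations that the later
    mention/entity counting jobs consume.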
20 | """ 21 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 22 | super().__init__( 23 | requires=[ 24 | f"data/versions/{opts.data_version_name}/downloads/{opts.wiki_lang_version}/", 25 | ], 26 | provides=[ 27 | f"data/versions/{opts.data_version_name}/wikiextractor_out/{opts.wiki_lang_version}/", 28 | ], 29 | preprocess_jobs=preprocess_jobs, 30 | opts=opts, 31 | ) 32 | 33 | def _run(self): 34 | 35 | self.log("Run WikiExtractor") 36 | 37 | # python wikiextractor-wikimentions/WikiExtractor.py --json --filter_disambig_pages --processes $WIKI_EXTRACTOR_NR_PROCESSES --collect_links $DOWNLOADS_DIR/$WIKI_RAW/$WIKI_FILE -o $WIKI_EXTRACTOR_OUTDIR/$WIKI_FILE 38 | 39 | for input_file in glob.glob( 40 | f"data/versions/{self.opts.data_version_name}/downloads/{self.opts.wiki_lang_version}/*" 41 | ): 42 | self.log(input_file) 43 | sys.argv = [ 44 | "", 45 | "--json", 46 | "--filter_disambig_pages", 47 | "--collect_links", 48 | "--processes", 49 | str(self.opts.wikiextractor_num_workers), 50 | input_file, 51 | "-o", 52 | f"data/versions/{self.opts.data_version_name}/wikiextractor_out/tmp/{os.path.basename(input_file)}", 53 | ] 54 | wiki_extractor_main() 55 | os.rename( 56 | f"data/versions/{self.opts.data_version_name}/wikiextractor_out/tmp/", 57 | f"data/versions/{self.opts.data_version_name}/wikiextractor_out/{self.opts.wiki_lang_version}/", 58 | ) 59 | 60 | self.log("WikiExtractor finished") 61 | 62 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/create_keyword_matcher.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, OrderedDict 2 | from typing import Dict 3 | import tqdm 4 | import pickle 5 | from flashtext import KeywordProcessor 6 | import sys 7 | 8 | sys.setrecursionlimit(10000) 9 | 10 | from pipeline_job import PipelineJob 11 | 12 | 13 | class CreateKeywordProcessor(PipelineJob): 14 | """ 15 | Create a matcher to detect mentions that we found with Wikiextractor in free text. 16 | We use this later to add more annotations to the text. However, as we do not know 17 | the true entity, we'll associate labels for all entities from the with their 18 | p(e|m) prior. 
19 | """ 20 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 21 | super().__init__( 22 | requires=[ 23 | f"data/versions/{opts.data_version_name}/indexes/mention_entity_counter_popular_entities.pickle", 24 | ], 25 | provides=[ 26 | f"data/versions/{opts.data_version_name}/indexes/keyword_processor.pickle" 27 | ], 28 | preprocess_jobs=preprocess_jobs, 29 | opts=opts, 30 | ) 31 | 32 | def _run(self): 33 | 34 | with open( 35 | f"data/versions/{self.opts.data_version_name}/indexes/mention_entity_counter_popular_entities.pickle", 36 | "rb", 37 | ) as f: 38 | all_mention_entity_counter_most_popular_entities = pickle.load(f) 39 | 40 | 41 | keyword_processor = KeywordProcessor(case_sensitive=False) 42 | 43 | for (k, v_most_common) in tqdm.tqdm( 44 | list(all_mention_entity_counter_most_popular_entities.items()) 45 | ): 46 | if ( 47 | len(v_most_common) == 0 48 | or v_most_common is None 49 | or v_most_common[0] is None 50 | or v_most_common[0][0] is None 51 | ): 52 | continue 53 | if v_most_common[0][0].startswith("List"): 54 | continue 55 | if v_most_common[0][0].startswith("Category:"): 56 | continue 57 | if v_most_common[0][1] < 50: 58 | continue 59 | keyword_processor.add_keyword(k.replace("_", " ")) 60 | keyword_processor.add_keyword(v_most_common[0][0].replace("_", " ")) 61 | 62 | with open( 63 | f"data/versions/{self.opts.data_version_name}/indexes/keyword_processor.pickle", 64 | "wb" 65 | ) as f: 66 | pickle.dump(keyword_processor, f) 67 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/create_redirects.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import io 3 | import os 4 | import pickle 5 | import re 6 | import urllib.request 7 | from typing import Dict 8 | 9 | import tqdm 10 | 11 | from pipeline_job import PipelineJob 12 | 13 | 14 | class CreateRedirects(PipelineJob): 15 | """ 16 | Create a dictionary containing redirects for Wikipedia page names. Here we use 17 | the already extracted mapping from DBPedia that was created from a 2016 dump. 18 | The redirects are used for the Wikipedia mention extractions as well as for 19 | the AIDA-CONLL benchmark. 20 | """ 21 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 22 | super().__init__( 23 | requires=[], 24 | provides=[ 25 | "data/indexes/redirects_en.ttl.bz2.dict", 26 | "data/downloads/redirects_en.ttl.bz2", 27 | ], 28 | preprocess_jobs=preprocess_jobs, 29 | opts=opts, 30 | ) 31 | 32 | def _run(self): 33 | self._download( 34 | "http://downloads.dbpedia.org/2016-10/core-i18n/en/redirects_en.ttl.bz2", 35 | "data/downloads/", 36 | ) 37 | 38 | redirects = dict() 39 | redirects_first_sweep = dict() 40 | 41 | redirect_matcher = re.compile( 42 | " ." 
43 | ) 44 | 45 | with bz2.BZ2File("data/downloads/redirects_en.ttl.bz2", "rb") as file: 46 | for line in tqdm.tqdm(file.readlines()): 47 | line_decoded = line.decode().strip() 48 | redirect_matcher_match = redirect_matcher.match(line_decoded) 49 | if redirect_matcher_match: 50 | redirects[ 51 | redirect_matcher_match.group(1) 52 | ] = redirect_matcher_match.group(2) 53 | redirects_first_sweep[ 54 | redirect_matcher_match.group(1) 55 | ] = redirect_matcher_match.group(2) 56 | # else: 57 | # print(line_decoded) 58 | 59 | with bz2.BZ2File("data/downloads/redirects_en.ttl.bz2", "rb") as file: 60 | for line in tqdm.tqdm(file.readlines()): 61 | line_decoded = line.decode().strip() 62 | redirect_matcher_match = redirect_matcher.match(line_decoded) 63 | if redirect_matcher_match: 64 | if redirect_matcher_match.group(2) in redirects: 65 | redirects[redirect_matcher_match.group(1)] = redirects[ 66 | redirect_matcher_match.group(2) 67 | ] 68 | 69 | with io.open("data/indexes/redirects_en.ttl.bz2.dict", "wb") as f: 70 | pickle.dump(redirects, f) 71 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/download_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from typing import Dict 4 | 5 | from pipeline_job import PipelineJob 6 | 7 | 8 | class DownloadWikiDump(PipelineJob): 9 | """ 10 | Download the current Wikipedia dump. Either download one file for a dummy / prototyping version 11 | (set download_data_only_dummy to True). Or all download files. 12 | """ 13 | 14 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 15 | super().__init__( 16 | requires=[], 17 | provides=[f"data/versions/{opts.data_version_name}/downloads/{opts.wiki_lang_version}/"], 18 | preprocess_jobs=preprocess_jobs, 19 | opts=opts, 20 | ) 21 | 22 | def _run(self): 23 | 24 | self.log(f"Downloading {self.opts.wiki_lang_version}") 25 | if self.opts.download_data_only_dummy: 26 | if self.opts.download_2017_enwiki: 27 | url = "https://archive.org/download/enwiki-20171001/" 28 | accept = "enwiki-20171001-pages-articles1.xml-p10p30302.bz2" 29 | else: 30 | url = f"https://dumps.wikimedia.org/{self.opts.wiki_lang_version}/latest/" 31 | accept = f"{self.opts.wiki_lang_version}-latest-pages-articles1.xml-*.bz2", 32 | 33 | subprocess.check_call( 34 | [ 35 | "wget", 36 | "-r", 37 | "-l1", 38 | "-np", 39 | "-nd", 40 | url, 41 | "-A", 42 | accept, 43 | "-P", 44 | f"data/versions/{self.opts.data_version_name}/downloads/tmp/", 45 | # f"data/versions/{self.opts.data_version_name}/downloads/{self.opts.wiki_lang_version}/", 46 | ] 47 | ) 48 | else: 49 | if self.opts.download_2017_enwiki: 50 | url = "https://archive.org/download/enwiki-20171001/" 51 | else: 52 | url = f"https://dumps.wikimedia.org/{self.opts.wiki_lang_version}/latest/" 53 | subprocess.check_call( 54 | [ 55 | "wget", 56 | "-r", 57 | "-l1", 58 | "-np", 59 | "-nd", 60 | url, 61 | "-A", 62 | f"{self.opts.wiki_lang_version}-latest-pages-articles*.xml-*.bz2", 63 | "-R", 64 | f"{self.opts.wiki_lang_version}-latest-pages-articles-multistream*.xml-*.bz2", 65 | "-P", 66 | f"data/versions/{self.opts.data_version_name}/downloads/tmp/", 67 | ] 68 | ) 69 | 70 | os.rename( 71 | f"data/versions/{self.opts.data_version_name}/downloads/tmp/", 72 | f"data/versions/{self.opts.data_version_name}/downloads/{self.opts.wiki_lang_version}/", 73 | ) 74 | 75 | self.log("Download finished ") 76 | 
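# A minimal usage sketch (assumed, following PipelineJob.run_jobs in pipeline_job.py) for
# running only the download step of the dummy configuration; the Namespace carries just the
# options this job reads:
#
#   from argparse import Namespace
#   from pipeline_job import PipelineJob
#   from preprocessing.download_data import DownloadWikiDump
#
#   opts = Namespace(data_version_name="dummy", wiki_lang_version="enwiki",
#                    download_data_only_dummy=True, download_2017_enwiki=True)
#   PipelineJob.run_jobs([DownloadWikiDump], opts)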
-------------------------------------------------------------------------------- /bert_entity/preprocessing/create_disambiguation_dict.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import pickle 3 | import re 4 | from collections import defaultdict 5 | from typing import Dict 6 | 7 | import tqdm 8 | 9 | from pipeline_job import PipelineJob 10 | 11 | 12 | class CreateDisambiguationDict(PipelineJob): 13 | """ 14 | Create a dictionary containing disambiguations for Wikipedia page names. 15 | Here we use the already extracted mapping from DBPedia that was created from 16 | a 2016 dump. The disambiguations are used to detect entity annotations in 17 | the AIDA-CONLL benchmark that have become incompatble for newer Wikipedia 18 | versions (I was using a Wikipedia dump from 2017. This dictionary might not 19 | be that fitting for the current wiki dump). 20 | """ 21 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 22 | super().__init__( 23 | requires=["data/indexes/redirects_en.ttl.bz2.dict"], 24 | provides=[ 25 | "data/indexes/disambiguations_en.ttl.bz2.dict", 26 | "data/downloads/disambiguations_en.ttl.bz2", 27 | ], 28 | preprocess_jobs=preprocess_jobs, 29 | opts=opts, 30 | ) 31 | 32 | def _create_dict( 33 | self, 34 | redirects, 35 | url, 36 | matcher_pattern, 37 | postproc_key=lambda x: x, 38 | postproc_val=lambda x: x, 39 | match_key=1, 40 | match_val=2, 41 | ): 42 | downloaded = self._download(url, "data/downloads/",) 43 | matcher = re.compile(matcher_pattern) 44 | 45 | a_to_b = defaultdict(list) 46 | with bz2.BZ2File(downloaded, "rb") as file: 47 | for line in tqdm.tqdm(file): 48 | line_decoded = line.decode().strip() 49 | matcher_match = matcher.match(line_decoded) 50 | if matcher_match: 51 | if matcher_match.group(match_val) in redirects: 52 | a_to_b[ 53 | postproc_key(matcher_match.group(match_key)) 54 | ].append(redirects[postproc_val(matcher_match.group(match_val))]) 55 | else: 56 | a_to_b[ 57 | postproc_key(matcher_match.group(match_key)) 58 | ].append(postproc_val(matcher_match.group(match_val))) 59 | return a_to_b 60 | 61 | def _run(self): 62 | 63 | with open("data/indexes/redirects_en.ttl.bz2.dict", "rb") as f: 64 | redirects = pickle.load(f) 65 | 66 | fb_to_wikiname_dict = self._create_dict( 67 | redirects=redirects, 68 | url="http://downloads.dbpedia.org/2016-10/core-i18n/en/disambiguations_en.ttl.bz2", 69 | matcher_pattern=" .", 70 | ) 71 | with open("data/indexes/disambiguations_en.ttl.bz2.dict", "wb") as f: 72 | pickle.dump(fb_to_wikiname_dict, f) 73 | 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/dictionaries 10 | .idea/**/shelf 11 | 12 | # Sensitive or high-churn files 13 | .idea/**/dataSources/ 14 | .idea/**/dataSources.ids 15 | .idea/**/dataSources.local.xml 16 | .idea/**/sqlDataSources.xml 17 | .idea/**/dynamic.xml 18 | .idea/**/uiDesigner.xml 19 | .idea/**/dbnavigator.xml 20 | 21 | # Gradle 22 | .idea/**/gradle.xml 23 | .idea/**/libraries 24 | 25 | # CMake 26 | cmake-build-debug/ 27 | 
cmake-build-release/ 28 | 29 | # Mongo Explorer plugin 30 | .idea/**/mongoSettings.xml 31 | 32 | # File-based project format 33 | *.iws 34 | 35 | # IntelliJ 36 | out/ 37 | 38 | # mpeltonen/sbt-idea plugin 39 | .idea_modules/ 40 | 41 | # JIRA plugin 42 | atlassian-ide-plugin.xml 43 | 44 | # Cursive Clojure plugin 45 | .idea/replstate.xml 46 | 47 | # Crashlytics plugin (for Android Studio and IntelliJ) 48 | com_crashlytics_export_strings.xml 49 | crashlytics.properties 50 | crashlytics-build.properties 51 | fabric.properties 52 | 53 | # Editor-based Rest Client 54 | .idea/httpRequests 55 | ### Python template 56 | # Byte-compiled / optimized / DLL files 57 | __pycache__/ 58 | *.py[cod] 59 | *$py.class 60 | 61 | # C extensions 62 | *.so 63 | 64 | # Distribution / packaging 65 | .Python 66 | build/ 67 | develop-eggs/ 68 | dist/ 69 | downloads/ 70 | eggs/ 71 | .eggs/ 72 | lib/ 73 | lib64/ 74 | parts/ 75 | sdist/ 76 | var/ 77 | wheels/ 78 | *.egg-info/ 79 | .installed.cfg 80 | *.egg 81 | MANIFEST 82 | 83 | # PyInstaller 84 | # Usually these files are written by a python script from a template 85 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 86 | *.manifest 87 | *.spec 88 | 89 | # Installer logs 90 | pip-log.txt 91 | pip-delete-this-directory.txt 92 | 93 | # Unit test / coverage reports 94 | htmlcov/ 95 | .tox/ 96 | .coverage 97 | .coverage.* 98 | .cache 99 | nosetests.xml 100 | coverage.xml 101 | *.cover 102 | .hypothesis/ 103 | .pytest_cache/ 104 | 105 | # Translations 106 | *.mo 107 | *.pot 108 | 109 | # Django stuff: 110 | *.log 111 | local_settings.py 112 | db.sqlite3 113 | 114 | # Flask stuff: 115 | instance/ 116 | .webassets-cache 117 | 118 | # Scrapy stuff: 119 | .scrapy 120 | 121 | # Sphinx documentation 122 | docs/_build/ 123 | 124 | # PyBuilder 125 | target/ 126 | 127 | # Jupyter Notebook 128 | .ipynb_checkpoints 129 | 130 | # pyenv 131 | .python-version 132 | 133 | # celery beat schedule file 134 | celerybeat-schedule 135 | 136 | # SageMath parsed files 137 | *.sage.py 138 | 139 | # Environments 140 | .env 141 | .venv 142 | env/ 143 | venv/ 144 | ENV/ 145 | env.bak/ 146 | venv.bak/ 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | 161 | -------------------------------------------------------------------------------- /bert_entity/pipeline_job.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import urllib 4 | import datetime 5 | from argparse import Namespace 6 | from typing import List, Any, Dict 7 | 8 | class Job: 9 | 10 | def run(self, jobs: Dict[str, Any]): 11 | raise NotImplementedError 12 | 13 | def _get_msg(self, msg): 14 | return f"{datetime.datetime.now()} [{self.__class__.__name__}] {msg}" 15 | 16 | def log(self, msg): 17 | logging.info(self._get_msg(msg)) 18 | 19 | def error(self, msg): 20 | logging.error(self._get_msg(msg)) 21 | 22 | def debug(self, msg): 23 | logging.debug(self._get_msg(msg)) 24 | 25 | def _run(self): 26 | raise NotImplementedError 27 | 28 | 29 | class PipelineJob(Job): 30 | def __init__(self, requires, provides, preprocess_jobs, opts: Namespace, rerun_job=False): 31 | 32 | self.requires: List[str] = requires 33 | self.provides: List[str] = provides 34 | self.add_provides(preprocess_jobs) 35 | self.opts = opts 36 | self.rerun_job = rerun_job 37 | 38 | def run(self, pipeline_jobs: 
Dict[str, Any]): 39 | self.log(f"Checking requirements for {self.__class__.__name__}") 40 | self.check_required_exist(pipeline_jobs) 41 | self.create_out_directories() 42 | if not self.provides_exists(): 43 | self.log(f"Start running {self.__class__.__name__}") 44 | self._run() 45 | self.log(f"Finished running {self.__class__.__name__}") 46 | else: 47 | self.log(f"{self.__class__.__name__} is already finished") 48 | 49 | def _download(self, url, folder): 50 | if not os.path.exists(f"{folder}/{os.path.basename(url)}"): 51 | self.log( 52 | f"Downloading {url}" 53 | ) 54 | urllib.request.urlretrieve( 55 | url, 56 | f"{folder}/{os.path.basename(url)}", 57 | ) 58 | self.log("Download finished ") 59 | return f"{folder}/{os.path.basename(url)}" 60 | 61 | def add_provides(self, preprocess_jobs: Dict[str, Job]): 62 | for file_name in self.provides: 63 | preprocess_jobs[file_name] = self 64 | 65 | def check_required_exist(self, preprocess_jobs: Dict[str, Job]): 66 | for file_name in self.requires: 67 | if not os.path.exists(file_name): 68 | try: 69 | preprocess_jobs[file_name].run(preprocess_jobs) 70 | except: 71 | self.error( 72 | f"Cannot find required {file_name} and there is no preprocess job to create it" 73 | ) 74 | raise Exception 75 | 76 | def provides_exists(self,): 77 | if self.rerun_job: 78 | return False 79 | for file_name in self.provides: 80 | if not os.path.exists(file_name): 81 | return False 82 | return True 83 | 84 | def create_out_directories(self,): 85 | for file_name in self.provides: 86 | if len(file_name) - len(os.path.dirname(file_name)) in [0, 1]: 87 | os.makedirs(os.path.dirname(os.path.dirname(file_name)), exist_ok=True) 88 | else: 89 | os.makedirs(os.path.dirname(file_name), exist_ok=True) 90 | 91 | @staticmethod 92 | def run_jobs(job_classes: List, opts): 93 | jobs_dict = dict() 94 | job_list = list() 95 | for job_class in job_classes: 96 | job_list.append(job_class(jobs_dict, opts)) 97 | for job in job_list: 98 | job.run(jobs_dict) 99 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/create_resolve_to_wiki_dicts.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import pickle 3 | import re 4 | from collections import defaultdict 5 | from typing import Dict 6 | 7 | import tqdm 8 | 9 | from pipeline_job import PipelineJob 10 | 11 | 12 | class CreateResolveToWikiNameDicts(PipelineJob): 13 | """ 14 | Create a dictionary containing mapping Freebase Ids and Wikipedia pages ids 15 | to Wikipedia page names. 16 | Here we use the already extracted mapping from DBPedia that was created from 17 | a 2016 dump. The disambiguations are used to detect entity annotations in 18 | the AIDA-CONLL benchmark that have become incompatble for newer Wikipedia 19 | versions (Please note that in the expermiments for the paper a Wikipedia 20 | dump from 2017 was used. This dictionary might not adequate for the latest 21 | wiki dump). 
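    Both output dictionaries are keyed by the foreign identifier (a Freebase MID rewritten
    to the /m/... form, or a numeric Wikipedia page id) and map it to a Wikipedia page name,
    with redirects already resolved via redirects_en.ttl.bz2.dict.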
22 | """ 23 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 24 | super().__init__( 25 | requires=["data/indexes/redirects_en.ttl.bz2.dict"], 26 | provides=[ 27 | "data/indexes/freebase_links_en.ttl.bz2.dict", 28 | "data/downloads/freebase_links_en.ttl.bz2", 29 | "data/indexes/page_ids_en.ttl.bz2.dict", 30 | "data/downloads/page_ids_en.ttl.bz2", 31 | ], 32 | preprocess_jobs=preprocess_jobs, 33 | opts=opts, 34 | ) 35 | 36 | def _create_dict( 37 | self, 38 | redirects, 39 | url, 40 | matcher_pattern, 41 | postproc_key=lambda x: x, 42 | postproc_val=lambda x: x, 43 | match_key=2, 44 | match_val=1, 45 | ): 46 | downloaded = self._download(url, "data/downloads/",) 47 | matcher = re.compile(matcher_pattern) 48 | 49 | a_to_b = defaultdict(list) 50 | with bz2.BZ2File(downloaded, "rb") as file: 51 | for line in tqdm.tqdm(file): 52 | line_decoded = line.decode().strip() 53 | matcher_match = matcher.match(line_decoded) 54 | if matcher_match: 55 | if matcher_match.group(match_val) in redirects: 56 | a_to_b[ 57 | postproc_key(matcher_match.group(match_key)) 58 | ] = redirects[postproc_val(matcher_match.group(match_val))] 59 | else: 60 | a_to_b[ 61 | postproc_key(matcher_match.group(match_key)) 62 | ] = postproc_val(matcher_match.group(match_val)) 63 | return a_to_b 64 | 65 | def _run(self): 66 | 67 | with open("data/indexes/redirects_en.ttl.bz2.dict", "rb") as f: 68 | redirects = pickle.load(f) 69 | 70 | fb_to_wikiname_dict = self._create_dict( 71 | redirects=redirects, 72 | url="http://downloads.dbpedia.org/2016-10/core-i18n/en/freebase_links_en.ttl.bz2", 73 | matcher_pattern=" .", 74 | postproc_key=lambda x: "/" + x.replace(".", "/"), 75 | ) 76 | with open("data/indexes/freebase_links_en.ttl.bz2.dict", "wb") as f: 77 | pickle.dump(fb_to_wikiname_dict, f) 78 | 79 | page_id_to_wikiname_dict = self._create_dict( 80 | redirects=redirects, 81 | url="http://downloads.dbpedia.org/2016-10/core-i18n/en/page_ids_en.ttl.bz2", 82 | matcher_pattern=' "(.*)"\^\^ .', 83 | ) 84 | with open("data/indexes/page_ids_en.ttl.bz2.dict", "wb") as f: 85 | pickle.dump(page_id_to_wikiname_dict, f) 86 | 87 | -------------------------------------------------------------------------------- /bert_entity/train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import time 5 | 6 | import torch.cuda 7 | import torch.nn as nn 8 | 9 | from metrics import Metrics 10 | from data_loader_conll import CONLLEDLDataset 11 | from data_loader_wiki import EDLDataset 12 | from model import Net 13 | from model_conll import ConllNet 14 | from train_util import get_args 15 | from vocab import Vocab 16 | 17 | 18 | class Datasets: 19 | EDLDataset = EDLDataset 20 | CONLLEDLDataset = CONLLEDLDataset 21 | 22 | 23 | class Models: 24 | Net = Net 25 | ConllNet = ConllNet 26 | 27 | 28 | if __name__ == "__main__": 29 | 30 | args = get_args() 31 | 32 | if args.debug: 33 | logging.basicConfig(level=logging.DEBUG) 34 | else: 35 | logging.basicConfig(level=logging.INFO) 36 | 37 | logging.info(str(("Devices", args.device, args.eval_device, args.out_device))) 38 | 39 | # set up the model 40 | vocab = Vocab(args) 41 | model_class = getattr(Models, args.model) 42 | model = model_class(args=args, vocab_size=vocab.size()) 43 | checkpoint = None 44 | if args.resume_from_checkpoint is not None: 45 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") 46 | model.load_state_dict(checkpoint["model"], strict=False) 47 | if args.device != "cpu": 48 | 
torch.cuda.empty_cache() 49 | model.to(args.device, args.out_device) 50 | print(model) 51 | 52 | # set up the optimizers and the loss 53 | optimizers, lr_schedulers = model.get_optimizers(args, checkpoint=checkpoint) 54 | criterion = nn.BCEWithLogitsLoss() 55 | 56 | # set up the datasets and dataloaders 57 | if not args.eval_on_test_only: 58 | train_dataset = getattr(Datasets, args.dataset)( 59 | args, split="train", vocab=vocab, device=args.device, label_size=args.label_size 60 | ) 61 | train_iter = train_dataset.get_data_iter(args=args, batch_size=args.batch_size, vocab=vocab, train=True) 62 | eval_dataset = getattr(Datasets, args.dataset)(args, split="valid", vocab=vocab, device=args.eval_device) 63 | eval_iter = eval_dataset.get_data_iter(args=args, batch_size=args.eval_batch_size, vocab=vocab, train=False) 64 | else: 65 | eval_dataset = getattr(Datasets, args.dataset)(args, split="test", vocab=vocab, device=args.eval_device) 66 | eval_iter = eval_dataset.get_data_iter(args=args, batch_size=args.eval_batch_size, vocab=vocab, train=False) 67 | 68 | start_epoch = 1 69 | if checkpoint and not args.resume_reset_epoch: 70 | start_epoch = checkpoint["epoch"] 71 | 72 | metrics = Metrics() 73 | 74 | if args.eval_before_training or args.eval_on_test_only: 75 | cloned_args = copy.deepcopy(args) 76 | cloned_args.dont_save_checkpoints = True 77 | metrics = model_class.evaluate( 78 | cloned_args, 79 | model, 80 | eval_iter, 81 | optimizers=optimizers, 82 | step=0, 83 | epoch=0, 84 | save_checkpoint=False, 85 | save_csv=args.eval_on_test_only, 86 | vocab=vocab, 87 | metrics=metrics, 88 | ) 89 | 90 | if not args.eval_on_test_only: 91 | for epoch in range(start_epoch, args.n_epochs + 1): 92 | 93 | start = time.time() 94 | 95 | model.finetuning = epoch >= args.finetuning if args.finetuning >= 0 else False 96 | 97 | metrics = model_class.train_one_epoch( 98 | args=args, 99 | model=model, 100 | train_iter=train_iter, 101 | optimizers=optimizers, 102 | criterion=criterion, 103 | vocab=vocab, 104 | eval_iter=eval_iter, 105 | epoch=epoch, 106 | metrics=metrics, 107 | ) 108 | 109 | logging.info(f"Evaluate in epoch {epoch}") 110 | metrics = model_class.evaluate( 111 | args, model, eval_iter, optimizers=optimizers, step=0, epoch=epoch, vocab=vocab, metrics=metrics, 112 | ) 113 | 114 | logging.info(f"{time.time() - start} per epoch") 115 | 116 | if lr_schedulers: 117 | for lr_scheduler in lr_schedulers: 118 | lr_scheduler.step(metrics.get_model_selection_metric()) 119 | -------------------------------------------------------------------------------- /bert_entity/preprocess_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from preprocessing.create_integerized_aida_conll_training import CreateIntegerizedCONLLTrainingData 4 | from preprocessing.preprocess_aida_conll_data import CreateAIDACONLL 5 | from preprocessing.create_integerized_wiki_training import CreateIntegerizedWikiTrainingData 6 | from preprocessing.create_keyword_matcher import CreateKeywordProcessor 7 | from preprocessing.create_disambiguation_dict import CreateDisambiguationDict 8 | from preprocessing.create_resolve_to_wiki_dicts import CreateResolveToWikiNameDicts 9 | from preprocessing.create_wiki_training_data import CreateWikiTrainingData 10 | from preprocessing.postprocess_mention_entity_counts import PostProcessMentionEntityCounts 11 | from pipeline_job import PipelineJob 12 | from preprocessing.collect_mention_entity_counts import CollectMentionEntityCounts 13 | from 
preprocessing.create_redirects import CreateRedirects 14 | from preprocessing.download_data import DownloadWikiDump 15 | from preprocessing.wikiextractor import Wikiextractor 16 | from misc import argparse_bool_type 17 | import configargparse as argparse 18 | import logging 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("-c", "--config", is_config_file=True, help="config file path") 23 | parser.add_argument("--debug", type=argparse_bool_type, default=False) 24 | parser.add_argument("--wiki_lang_version", type=str, help="wiki language version", default="enwiki") 25 | parser.add_argument("--data_version_name", type=str, help="data identifier/version") 26 | parser.add_argument("--download_data_only_dummy", type=argparse_bool_type, help="only download one wiki file") 27 | parser.add_argument("--download_2017_enwiki", type=argparse_bool_type, help="download the enwiki 2017 dump to reproduce the experiments for the CONLL 2019 paper", default=True) 28 | parser.add_argument("--num_most_freq_entities", type=int, help="") 29 | parser.add_argument("--add_missing_conll_entities", type=argparse_bool_type, help="") 30 | parser.add_argument("--uncased", type=argparse_bool_type, default=True) 31 | 32 | parser.add_argument("--collect_mention_entities_num_workers", type=int, default="10") 33 | 34 | parser.add_argument("--wikiextractor_num_workers", type=int, help="") 35 | 36 | parser.add_argument("--create_training_data_num_workers", type=int, default="10") 37 | parser.add_argument("--create_training_data_num_entities_in_necessary_articles", type=int, help="") 38 | parser.add_argument("--create_training_data_discount_nil_strategy", type=str, help="the discount strategy either 'hacky' or 'prop'", default="prop") 39 | 40 | parser.add_argument("--create_integerized_training_num_workers", type=int, default="10") 41 | parser.add_argument("--create_integerized_training_loc_file_name", type=str, default="data.loc") 42 | parser.add_argument("--create_integerized_training_instance_text_length", type=int, default="254") 43 | parser.add_argument("--create_integerized_training_instance_text_overlap", type=int, default="20") 44 | parser.add_argument("--create_integerized_training_max_entity_per_shard_count", type=int, default="10") 45 | parser.add_argument("--create_integerized_training_valid_size", type=int, default="1000") 46 | parser.add_argument("--create_integerized_training_test_size", type=int, default="1000") 47 | 48 | args = parser.parse_args() 49 | 50 | if args.download_2017_enwiki: 51 | if len(args.wiki_lang_version) > 0 and args.wiki_lang_version != 'enwiki': 52 | raise Exception(f"The configuration was set to 'download_2017_enwiki=True' but wiki_lang_version was set to {args.wiki_lang_version}.") 53 | 54 | if args.debug: 55 | logging.basicConfig(level=logging.DEBUG) 56 | else: 57 | logging.basicConfig(level=logging.INFO) 58 | 59 | for k, v in args.__dict__.items(): 60 | logging.info(f"{k}: {v}") 61 | if v == "None": 62 | args.__dict__[k] = None 63 | 64 | os.makedirs(f"data/versions/{args.data_version_name}/", exist_ok=True) 65 | 66 | with open(f"data/versions/{args.data_version_name}/config.yaml", "w") as f: 67 | f.writelines(["{}: {}\n".format(k, v) for k, v in args.__dict__.items()]) 68 | 69 | PipelineJob.run_jobs([ 70 | CreateRedirects, 71 | CreateResolveToWikiNameDicts, 72 | CreateDisambiguationDict, 73 | DownloadWikiDump, 74 | Wikiextractor, 75 | CollectMentionEntityCounts, 76 | PostProcessMentionEntityCounts, 77 | CreateAIDACONLL, 78 | CreateKeywordProcessor, 79 | 
CreateWikiTrainingData, 80 | CreateIntegerizedWikiTrainingData, 81 | CreateIntegerizedCONLLTrainingData, 82 | ], args) 83 | -------------------------------------------------------------------------------- /bert_entity/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import pandas 5 | from itertools import cycle 6 | from operator import gt, lt 7 | 8 | 9 | class Metrics: 10 | 11 | meta = OrderedDict( 12 | [ 13 | ("epoch", {"comp": gt, "type": int, "str": lambda a: a}), 14 | ("step", {"comp": gt, "type": int, "str": lambda a: a}), 15 | ("num_gold", {"comp": gt, "type": int, "str": lambda a: a}), 16 | ("num_correct", {"comp": gt, "type": int, "str": lambda a: a}), 17 | ("num_proposed", {"comp": gt, "type": int, "str": lambda a: a}), 18 | ("f1", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 19 | ("f05", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 20 | ("precision", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 21 | ("recall", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 22 | ("span_f1", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 23 | ("span_precision", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 24 | ("span_recall", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 25 | ("lenient_span_f1", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 26 | ("lenient_span_precision", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 27 | ("lenient_span_recall", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 28 | ("precision_gold_mentions", {"comp": gt, "type": float, "str": lambda a: f"{a:.5f}"}), 29 | ("avg_loss", {"comp": lt, "type": float, "str": lambda a: f"{a:.5f}"}), 30 | ] 31 | ) 32 | 33 | def __init__(self, epoch=0, step=0, num_correct=0, num_gold=0, num_proposed=0, model_selection="f1", num_best_checkpoints=4, **kwargs): 34 | 35 | self.epoch = epoch 36 | self.step = step 37 | self.num_correct = num_correct 38 | self.num_gold = num_gold 39 | self.num_proposed = num_proposed 40 | 41 | self.precision = Metrics.compute_precision(num_correct, num_proposed) 42 | self.recall = Metrics.compute_recall(num_correct, num_gold) 43 | self.f1 = Metrics.compute_fmeasure(self.precision, self.recall) 44 | self.f05 = Metrics.compute_fmeasure(self.precision, self.recall, weight=1.5) 45 | 46 | self.avg_loss = float("inf") 47 | 48 | for k,v in kwargs.items(): 49 | if k in self.meta: 50 | self.__dict__[k] = v 51 | 52 | self.model_selection = model_selection 53 | self.checkpoint_cycle = cycle(range(num_best_checkpoints),) 54 | 55 | @staticmethod 56 | def compute_precision(correct, proposed): 57 | try: 58 | precision = correct/proposed 59 | except ZeroDivisionError: 60 | precision = 0.0 61 | return precision 62 | 63 | @staticmethod 64 | def compute_recall(correct, gold): 65 | try: 66 | recall = correct/ gold 67 | except ZeroDivisionError: 68 | recall = 0.0 69 | return recall 70 | 71 | @staticmethod 72 | def compute_fmeasure(precision, recall, weight=2.0): 73 | try: 74 | f = weight * precision * recall / (precision + recall) 75 | except ZeroDivisionError: 76 | f = 0.0 77 | return f 78 | 79 | def was_improved(self, other: "Metrics"): 80 | return Metrics.meta[self.model_selection]["comp"]( 81 | other.get_model_selection_metric(), self.get_model_selection_metric() 82 | ) 83 | 84 | def update(self, other: "Metrics"): 85 | if self.was_improved(other): 86 | for key, val in other.__dict__.items(): 87 | 
self.__setattr__(key, other.__dict__.get(key)) 88 | 89 | def get_model_selection_metric(self): 90 | return self.__dict__.get(self.model_selection) 91 | 92 | def get_best_checkpoint_filename(self): 93 | return f"best_{self.model_selection}-{next(self.checkpoint_cycle)}" 94 | 95 | def to_csv(self, epoch, step, args): 96 | header = ( 97 | [k for k in list(self.meta.keys()) if k in self.__dict__] 98 | if not os.path.exists("{}/log.csv".format(args.logdir)) 99 | else False 100 | ) 101 | pandas.DataFrame( 102 | [[self.__dict__[k] for k in list(self.meta.keys()) if k in self.__dict__]] 103 | ).to_csv(f"{args.logdir}/log.csv", mode="a", header=header) 104 | 105 | def dict(self): 106 | return self.__dict__ 107 | 108 | def __repr__(self): 109 | return str(self.__dict__) 110 | 111 | def report(self, filter=None): 112 | if not filter: 113 | filter = set(self.meta.keys()) 114 | return [f"{k}: {Metrics.meta[k]['str'](self.__dict__[k])}" for k in list(self.meta.keys()) if k in filter and k in self.__dict__] 115 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/create_integerized_aida_conll_training.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Dict 3 | 4 | from pipeline_job import PipelineJob 5 | from pytorch_pretrained_bert import BertTokenizer 6 | 7 | 8 | class CreateIntegerizedCONLLTrainingData(PipelineJob): 9 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 10 | super().__init__( 11 | requires=[ 12 | "data/benchmarks/aida-yago2-dataset/conll_dataset.pickle", 13 | f"data/versions/{opts.data_version_name}/indexes/popular_entity_to_id_dict.pickle", 14 | ], 15 | provides=[f"data/benchmarks/aida-yago2-dataset/conll_dataset_{opts.data_version_name}_{opts.create_integerized_training_instance_text_length}-{opts.create_integerized_training_instance_text_overlap}.pickle",], 16 | preprocess_jobs=preprocess_jobs, 17 | opts=opts, 18 | ) 19 | 20 | def _run(self): 21 | 22 | with open("data/benchmarks/aida-yago2-dataset/conll_dataset.pickle", "rb") as f: 23 | conll_dataset = pickle.load(f) 24 | 25 | with open(f"data/versions/{self.opts.data_version_name}/indexes/popular_entity_to_id_dict.pickle", "rb") as f: 26 | popular_entity_to_id_dict = pickle.load(f) 27 | 28 | tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) 29 | 30 | # collect train valid test portionas and transform into our format 31 | 32 | train = list() 33 | valid = list() 34 | test = list() 35 | 36 | doc = None 37 | 38 | last_tok = None 39 | ents_in_doc = False 40 | split = None 41 | current_split_name = None 42 | bio_id = {'B': 0, 'I': 1, 'O': 2, } 43 | 44 | for item in conll_dataset: 45 | 46 | if current_split_name in ['train', 'valid', 'test'] and current_split_name != item['split'] and ents_in_doc: 47 | split.append(doc) 48 | doc = list() 49 | 50 | wiki_name = None 51 | bio = item['bio'] 52 | 53 | if item['split'] == 'train': 54 | split = train 55 | current_split_name = 'train' 56 | if item['split'] == 'valid': 57 | split = valid 58 | current_split_name = 'valid' 59 | if item['split'] == 'test': 60 | split = test 61 | current_split_name = 'test' 62 | 63 | if item['sent_nr'] > 0 and item['tok_nr'] == 0 and last_tok != '.': 64 | doc.append(('.', tokenizer.convert_tokens_to_ids(['.'])[0], 'O', bio_id['O'], None, -1, len(doc))) 65 | 66 | if item['doc_start']: 67 | if doc is not None: 68 | split.append(doc) 69 | doc = list() 70 | continue 71 | 72 | if 
item['wiki_name'] is not None and item['wiki_name'] != '--NME--': 73 | bio = 'O' 74 | for ent, c in item['wiki_name'].items(): 75 | # this statement only would not fire if add_missing_conll_entities was set to False 76 | # during preprocessing 77 | if ent in popular_entity_to_id_dict: 78 | wiki_name = ent 79 | bio = item['bio'] 80 | break 81 | 82 | # take care of all uppercase entity names, i.e. "LONDON" 83 | item_tok = item['tok'] 84 | if item_tok.isupper() and item_tok not in tokenizer.vocab.keys(): 85 | if (item_tok[0] + item_tok[1:].lower()) in tokenizer.vocab.keys(): 86 | item_tok = item_tok[0] + item_tok[1:].lower() 87 | elif item_tok.lower() in tokenizer.vocab.keys(): 88 | item_tok = item_tok.lower() 89 | 90 | last_bio = None 91 | for tok, tok_id in zip(tokenizer.tokenize(item_tok), tokenizer.convert_tokens_to_ids(tokenizer.tokenize(item_tok))): 92 | wiki_id = popular_entity_to_id_dict[wiki_name] if wiki_name else -1 93 | if wiki_id >= 0: 94 | if last_bio == 'B': 95 | bio = 'I' 96 | else: 97 | bio = 'O' 98 | doc.append((tok, tok_id, bio, bio_id[bio], wiki_name, wiki_id, len(doc))) 99 | last_bio = bio 100 | 101 | last_tok = item['tok'] 102 | 103 | def create_overlapping_chunks(a_list, n, overlap): 104 | for i in range(0, len(a_list), n - overlap): 105 | yield a_list[i:i + n] 106 | 107 | train_overlapped = list() 108 | valid_overlapped = list() 109 | test_overlapped = list() 110 | 111 | for doc in train: 112 | train_overlapped.extend( 113 | create_overlapping_chunks(doc, 114 | self.opts.create_integerized_training_instance_text_length, 115 | self.opts.create_integerized_training_instance_text_overlap)) 116 | 117 | for doc in valid: 118 | valid_overlapped.extend( 119 | create_overlapping_chunks(doc, 120 | self.opts.create_integerized_training_instance_text_length, 121 | self.opts.create_integerized_training_instance_text_overlap)) 122 | 123 | for doc in test: 124 | test_overlapped.extend( 125 | create_overlapping_chunks(doc, 126 | self.opts.create_integerized_training_instance_text_length, 127 | self.opts.create_integerized_training_instance_text_overlap)) 128 | 129 | with open(f"data/benchmarks/aida-yago2-dataset/conll_dataset_{self.opts.data_version_name}" 130 | f"_{self.opts.create_integerized_training_instance_text_length}" 131 | f"-{self.opts.create_integerized_training_instance_text_overlap}" 132 | f".pickle", "wb") as f: 133 | pickle.dump((train_overlapped, valid_overlapped, test_overlapped), f) -------------------------------------------------------------------------------- /bert_entity/train_util.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import os 3 | 4 | import configargparse as argparse 5 | import torch.cuda 6 | import yaml 7 | 8 | from misc import argparse_bool_type 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("-c", "--config", is_config_file=True, help="config file path") 12 | parser.add_argument("--debug", type=argparse_bool_type, default=False) 13 | parser.add_argument("--device", default=0) 14 | parser.add_argument("--eval_device", default=None) 15 | parser.add_argument("--dataset", default="EDLDataset") 16 | parser.add_argument("--model", default="Net") 17 | parser.add_argument("--data_version_name") 18 | parser.add_argument("--wiki_lang_version") 19 | parser.add_argument("--eval_on_test_only", type=argparse_bool_type, default=False) 20 | parser.add_argument("--out_device", default=None) 21 | parser.add_argument("--batch_size", type=int, default=128) 22 | 
parser.add_argument("--eval_batch_size", type=int, default=128) 23 | parser.add_argument("--accumulate_batch_gradients", type=int, default=1) 24 | parser.add_argument("--sparse", dest="sparse", type=argparse_bool_type) 25 | parser.add_argument("--encoder_lr", type=float, default=5e-5) 26 | parser.add_argument("--decoder_lr", type=float, default=1e-3) 27 | parser.add_argument("--maskout_entity_prob", type=float, default=0) 28 | parser.add_argument("--segm_decoder_lr", type=float, default=1e-3) 29 | parser.add_argument("--encoder_weight_decay", type=float, default=0) 30 | parser.add_argument("--decoder_weight_decay", type=float, default=0) 31 | parser.add_argument("--segm_decoder_weight_decay", type=float, default=0) 32 | parser.add_argument("--learn_segmentation", type=argparse_bool_type, default=False) 33 | parser.add_argument("--label_size", type=int) 34 | parser.add_argument("--vocab_size", type=int) 35 | parser.add_argument("--entity_embedding_size", type=int, default=768) 36 | parser.add_argument("--project", type=argparse_bool_type, default=False) 37 | parser.add_argument("--n_epochs", type=int, default=1000) 38 | parser.add_argument("--collect_most_popular_labels_steps", type=int, default=100) 39 | parser.add_argument("--checkpoint_eval_steps", type=int, default=1000) 40 | parser.add_argument("--checkpoint_save_steps", type=int, default=50000) 41 | parser.add_argument("--finetuning", dest="finetuning", type=int, default=9999999999) 42 | parser.add_argument("--top_rnns", dest="top_rnns", type=argparse_bool_type) 43 | parser.add_argument("--logdir", type=str) 44 | parser.add_argument("--train_loc_file", type=str, default="train.loc") 45 | parser.add_argument("--valid_loc_file", type=str, default="valid.loc") 46 | parser.add_argument("--test_loc_file", type=str, default="test.loc") 47 | parser.add_argument("--resume_from_checkpoint", type=str) 48 | parser.add_argument("--resume_reset_epoch", type=argparse_bool_type, default=False) 49 | parser.add_argument("--resume_optimizer_from_checkpoint", type=argparse_bool_type, default=False) 50 | parser.add_argument("--topk_neg_examples", type=int, default=3) 51 | parser.add_argument("--dont_save_checkpoints", type=argparse_bool_type, default=False) 52 | parser.add_argument("--data_workers", type=int, default=8) 53 | parser.add_argument("--bert_dropout", type=float, default=None) 54 | parser.add_argument("--encoder_lr_scheduler", type=str, default=None) 55 | parser.add_argument("--encoder_lr_scheduler_config", default=None) 56 | parser.add_argument("--decoder_lr_scheduler", type=str, default=None) 57 | parser.add_argument("--decoder_lr_scheduler_config", default=None) 58 | parser.add_argument("--segm_decoder_lr_scheduler", type=str, default=None) 59 | parser.add_argument("--segm_decoder_lr_scheduler_config", default=None) 60 | parser.add_argument("--eval_before_training", type=argparse_bool_type, default=False) 61 | parser.add_argument("--data_path_conll", type=str,) 62 | parser.add_argument("--train_data_dir", type=str, default="data") 63 | parser.add_argument("--valid_data_dir", type=str, default="data") 64 | parser.add_argument("--test_data_dir", type=str, default="data") 65 | parser.add_argument("--exclude_parameter_names_regex", type=str) 66 | 67 | 68 | def get_args(): 69 | 70 | args = parser.parse_args() 71 | 72 | for k, v in args.__dict__.items(): 73 | print(k, ":", v) 74 | if v == "None": 75 | args.__dict__[k] = None 76 | 77 | args.device = ( 78 | int(args.device) if args.device is not None and args.device != "cpu" and 
torch.cuda.is_available() else "cpu" 79 |     ) 80 |     if args.eval_device is not None: 81 |         if args.eval_device != "cpu": 82 |             args.eval_device = int(args.eval_device) 83 |         else: 84 |             args.eval_device = "cpu" 85 |     else: 86 |         args.eval_device = args.device 87 |     if args.out_device is not None: 88 |         if args.out_device != "cpu": 89 |             args.out_device = int(args.out_device) 90 |         else: 91 |             args.out_device = "cpu" 92 |     else: 93 |         args.out_device = args.device 94 | 95 |     if args.encoder_lr_scheduler_config: 96 |         args.encoder_lr_scheduler_config = ast.literal_eval(args.encoder_lr_scheduler_config) 97 |     if args.decoder_lr_scheduler_config: 98 |         args.decoder_lr_scheduler_config = ast.literal_eval(args.decoder_lr_scheduler_config) 99 |     if args.segm_decoder_lr_scheduler_config: 100 |         args.segm_decoder_lr_scheduler_config = ast.literal_eval(args.segm_decoder_lr_scheduler_config) 101 | 102 |     args.eval_batch_size = args.eval_batch_size if args.eval_batch_size else args.batch_size 103 | 104 |     if not args.logdir: 105 |         raise Exception("set args.logdir") 106 | 107 |     if not os.path.exists(args.logdir): 108 |         os.makedirs(args.logdir) 109 | 110 |     if not args.eval_on_test_only: 111 |         config_fname = os.path.join(args.logdir, "config") 112 |         with open(f"{config_fname}.yaml", "w") as f: 113 |             f.writelines( 114 |                 [ 115 |                     "{}: {}\n".format(k, v) 116 |                     for k, v in args.__dict__.items() 117 |                     if isinstance(v, str) and len(v.strip()) > 0 or not isinstance(v, str) and v is not None 118 |                 ] 119 |             ) 120 | 121 |     with open(f"data/versions/{args.data_version_name}/config.yaml") as f: 122 |         dataset = yaml.load(f, Loader=yaml.UnsafeLoader) 123 | 124 |     for k, v in dataset.items(): 125 |         if k != "debug": 126 |             args.__setattr__(k, v) 127 | 128 |     return args 129 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/postprocess_mention_entity_counts.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | from collections import Counter, OrderedDict 4 | from typing import Dict 5 | 6 | import tqdm 7 | 8 | sys.setrecursionlimit(10000) 9 | 10 | from pipeline_job import PipelineJob 11 | 12 | 13 | class PostProcessMentionEntityCounts(PipelineJob): 14 |     """ 15 |     Create entity indexes that will later be used in the creation of the Wikipedia 16 |     training data. 17 |     First, based on the configuration key "num_most_freq_entities" the top k most 18 |     popular entities are selected. Based on those, other mappings are created to only 19 |     contain counts and priors concerning the top k popular entities. Later the top k 20 |     popular entities will also restrict the training data to only contain instances 21 |     that contain popular entities. 22 |     Also, if "add_missing_conll_entities" is set, the entities of the AIDA-CoNLL 23 |     benchmark that are missing from the top k popular entities are added back to the 24 |     entity vocabulary, to ensure that the evaluation measures are comparable to prior work.
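
    A toy illustration (hypothetical counts and entity ids, not actual pipeline
    output) of the p(e|m) prior that this job derives for each mention over the
    ids of the popular entities:

        all_mention_entity_counter["Paris"] == Counter({"Paris": 90, "Paris_Hilton": 10})
        popular_entity_to_id_dict["Paris"] == 17
        popular_entity_to_id_dict["Paris_Hilton"] == 4711
        # both entities are popular and pass the count filter, so after _run():
        mention_to_popular_entity_id_probabilies_dicts_dict["Paris"] == {17: 0.9, 4711: 0.1}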
25 | """ 26 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 27 | super().__init__( 28 | requires=[ 29 | f"data/versions/{opts.data_version_name}/indexes/mention_entity_counter.pickle", 30 | f"data/versions/{opts.data_version_name}/indexes/entity_counter.pickle", 31 | f"data/versions/{opts.data_version_name}/indexes/linked_mention_counter.pickle", 32 | f"data/versions/{opts.data_version_name}/indexes/found_conll_entities.pickle", 33 | ], 34 | provides=[ 35 | f"data/versions/{opts.data_version_name}/indexes/mention_entity_counter_popular_entities.pickle", 36 | f"data/versions/{opts.data_version_name}/indexes/popular_entity_counter_dict.pickle", 37 | f"data/versions/{opts.data_version_name}/indexes/popular_entity_to_id_dict.pickle", 38 | f"data/versions/{opts.data_version_name}/indexes/mention_to_popular_entity_id_probabilies_dicts_dict.pickle", 39 | ], 40 | preprocess_jobs=preprocess_jobs, 41 | opts=opts, 42 | ) 43 | 44 | def _run(self): 45 | 46 | with open( 47 | f"data/versions/{self.opts.data_version_name}/indexes/entity_counter.pickle", 48 | "rb", 49 | ) as f: 50 | all_entity_counter = pickle.load(f) 51 | 52 | with open( 53 | f"data/versions/{self.opts.data_version_name}/indexes/mention_entity_counter.pickle", 54 | "rb", 55 | ) as f: 56 | all_mention_entity_counter = pickle.load(f) 57 | 58 | with open( 59 | f"data/versions/{self.opts.data_version_name}/indexes/found_conll_entities.pickle", 60 | "rb", 61 | ) as f: 62 | all_found_conll_entities = pickle.load(f) 63 | 64 | # Create the index over the most popular entities (configured by num_most_freq_entities) 65 | # Then create the mention entity index based on that 66 | 67 | # Create the index over the most popular entities 68 | popular_entity_counter_dict = dict( 69 | all_entity_counter.most_common()[: self.opts.num_most_freq_entities] 70 | ) 71 | if self.opts.add_missing_conll_entities: 72 | # add entities required for the Aida-CoNLL benchmark dataset 73 | count = 0 74 | for ent in all_found_conll_entities: 75 | if ent in all_entity_counter: 76 | popular_entity_counter_dict[ent] = all_entity_counter[ent] 77 | count += 1 78 | self.log(f"Added {count} entities from the conll data back to the most popular entities vocabulary.") 79 | 80 | # Create the mention entity index based on that 81 | # TODO: filter rare entities for mentions / hacky heuristic to have cleaner data, can be improved 82 | mention_entity_counter_popular_entities = dict() 83 | for mention, entities in tqdm.tqdm(all_mention_entity_counter.items()): 84 | mention_entity_counter_popular_entities[mention] = Counter( 85 | { 86 | k: v 87 | for k, v in filter( 88 | lambda t: t[0] in popular_entity_counter_dict and t[1] > 9, 89 | entities.items(), 90 | ) 91 | } 92 | ).most_common() 93 | 94 | # Create a mapping from entities to ids 95 | popular_entity_to_id_dict = OrderedDict( 96 | [ 97 | (k, eid) 98 | for eid, (k, v) in enumerate( 99 | Counter(popular_entity_counter_dict).most_common() 100 | ) 101 | ] 102 | ) 103 | 104 | # Create a dictionary for the prior probablities p(e|m) of mentions to 105 | # ids of popular entities 106 | mention_to_popular_entity_id_probabilies_dicts_dict = { 107 | m: { 108 | popular_entity_to_id_dict[ename]: count 109 | / sum([val for key, val in entities]) 110 | for ename, count in entities 111 | if ename in popular_entity_to_id_dict 112 | } 113 | for m, entities in mention_entity_counter_popular_entities.items() 114 | } 115 | 116 | with open( 117 | 
f"data/versions/{self.opts.data_version_name}/indexes/mention_entity_counter_popular_entities.pickle", 118 | "wb", 119 | ) as f: 120 | pickle.dump(mention_entity_counter_popular_entities, f) 121 | 122 | with open( 123 | f"data/versions/{self.opts.data_version_name}/indexes/popular_entity_counter_dict.pickle", 124 | "wb", 125 | ) as f: 126 | pickle.dump(popular_entity_counter_dict, f) 127 | 128 | with open( 129 | f"data/versions/{self.opts.data_version_name}/indexes/popular_entity_to_id_dict.pickle", 130 | "wb", 131 | ) as f: 132 | pickle.dump(popular_entity_to_id_dict, f) 133 | 134 | with open( 135 | f"data/versions/{self.opts.data_version_name}/indexes/mention_to_popular_entity_id_probabilies_dicts_dict.pickle", 136 | "wb", 137 | ) as f: 138 | pickle.dump(mention_to_popular_entity_id_probabilies_dicts_dict, f) 139 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/collect_mention_entity_counts.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import glob 3 | import json 4 | import io 5 | import pickle 6 | import sys 7 | import multiprocessing 8 | from collections import Counter 9 | from typing import Dict 10 | 11 | import pandas 12 | 13 | from tqdm import tqdm 14 | from misc import normalize_wiki_entity 15 | from pipeline_job import PipelineJob 16 | 17 | 18 | class CollectMentionEntityCounts(PipelineJob): 19 | """ 20 | Collect mention entity counts from the Wikiextractor files. 21 | """ 22 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 23 | super().__init__( 24 | requires=[ 25 | "data/indexes/redirects_en.ttl.bz2.dict", 26 | f"data/versions/{opts.data_version_name}/wikiextractor_out/{opts.wiki_lang_version}/", 27 | ], 28 | provides=[ 29 | f"data/versions/{opts.data_version_name}/indexes/mention_entity_counter.pickle", 30 | f"data/versions/{opts.data_version_name}/indexes/entity_counter.pickle", 31 | f"data/versions/{opts.data_version_name}/indexes/linked_mention_counter.pickle", 32 | ], 33 | preprocess_jobs=preprocess_jobs, 34 | opts=opts, 35 | ) 36 | 37 | def _run(self): 38 | 39 | pandas.set_option("display.max_rows", 5000) 40 | sys.setrecursionlimit(10000) 41 | 42 | self.log("Load ./data/indexes/redirects_en.ttl.bz2.dict") 43 | with open(f"data/indexes/redirects_en.ttl.bz2.dict", "rb") as f: 44 | redirects_en = pickle.load(f) 45 | 46 | in_queue = multiprocessing.Queue() 47 | out_queue = multiprocessing.Queue() 48 | 49 | workers = list() 50 | 51 | list_dir_string = f"data/versions/{self.opts.data_version_name}/wikiextractor_out/{self.opts.wiki_lang_version}/{self.opts.wiki_lang_version}*pages-articles*/*/wiki_*" 52 | 53 | # 54 | # start the workers in individual processes 55 | # 56 | for id in range(self.opts.collect_mention_entities_num_workers): 57 | worker = Worker(in_queue, out_queue, redirects_en) 58 | worker.start() 59 | workers.append(worker) 60 | 61 | self.log("Fill queue") 62 | # fill the queue 63 | for file_nr, extracted_wiki_file in enumerate(tqdm(glob.glob(list_dir_string))): 64 | in_queue.put(extracted_wiki_file) 65 | self.debug("put {} in queue".format(extracted_wiki_file)) 66 | 67 | all_linked_mention_counter = Counter() 68 | all_entity_counter = Counter() 69 | all_mention_entity_counter = dict() 70 | 71 | self.log("Collect the output") 72 | # collect the output 73 | for file_nr, extracted_wiki_file in enumerate(tqdm(glob.glob(list_dir_string))): 74 | ( 75 | ( 76 | local_linked_mention_counter, 77 | local_entity_counter, 78 | 
local_mention_entity_counter, 79 | ), 80 | in_file_name, 81 | ) = out_queue.get() 82 | all_linked_mention_counter.update(local_linked_mention_counter) 83 | all_entity_counter.update(local_entity_counter) 84 | for k, v in local_mention_entity_counter.items(): 85 | if k not in all_mention_entity_counter: 86 | all_mention_entity_counter[k] = Counter() 87 | all_mention_entity_counter[k].update(v) 88 | 89 | # for file_name in outputs: 90 | # print(file_name) 91 | # pass 92 | 93 | # put the None into the queue so the loop in the run() function of the worker stops 94 | for worker in workers: 95 | in_queue.put(None) 96 | out_queue.put(None) 97 | 98 | # terminate the process 99 | for worker in workers: 100 | worker.join() 101 | 102 | with open(f"data/versions/{self.opts.data_version_name}/indexes/linked_mention_counter.pickle", "wb") as f: 103 | pickle.dump(all_linked_mention_counter, f) 104 | 105 | with open(f"data/versions/{self.opts.data_version_name}/indexes/entity_counter.pickle", "wb") as f: 106 | pickle.dump(all_entity_counter, f) 107 | 108 | with open(f"data/versions/{self.opts.data_version_name}/indexes/mention_entity_counter.pickle", "wb") as f: 109 | pickle.dump(all_mention_entity_counter, f) 110 | 111 | 112 | class Worker(multiprocessing.Process): 113 | def __init__(self, in_queue, out_queue, redirects_en): 114 | super().__init__() 115 | self.in_queue = in_queue 116 | self.out_queue = out_queue 117 | self.redirects_en = redirects_en 118 | 119 | def run(self): 120 | # this loop will run until it receives None form the in_queue, if the queue is empty 121 | # the loop will wait until it gets something 122 | for next_item in iter(self.in_queue.get, None): 123 | file_name = next_item 124 | self.out_queue.put((self.extract_data(next_item), file_name)) 125 | 126 | def extract_data(self, file_name): 127 | 128 | local_entity_counter = Counter() 129 | local_linked_mention_counter = Counter() 130 | local_mention_entity_counter = Counter() 131 | 132 | with open(file_name) as f: 133 | 134 | for i, wiki_article in enumerate(f.readlines()): 135 | 136 | wiki_article = json.loads(wiki_article) 137 | 138 | start_offset_dict = dict() 139 | 140 | for ((start, end), (mention, wiki_page_name)) in pickle.loads( 141 | base64.b64decode(wiki_article["internal_links"].encode("utf-8")) 142 | ).items(): 143 | start_offset_dict[start] = (end, (mention, wiki_page_name)) 144 | if mention.startswith("Category:"): 145 | continue 146 | normalized_wiki_entity = normalize_wiki_entity( 147 | wiki_page_name, replace_ws=True 148 | ) 149 | entity = self.redirects_en.get( 150 | normalized_wiki_entity, normalized_wiki_entity 151 | ) 152 | local_linked_mention_counter[mention] += 1 153 | local_entity_counter[entity] += 1 154 | if mention not in local_mention_entity_counter: 155 | local_mention_entity_counter[mention] = Counter() 156 | local_mention_entity_counter[mention][entity] += 1 157 | 158 | return ( 159 | local_linked_mention_counter, 160 | local_entity_counter, 161 | local_mention_entity_counter, 162 | ) 163 | -------------------------------------------------------------------------------- /bert_entity/misc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | from collections import Counter 4 | 5 | import torch.optim 6 | from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau 7 | 8 | 9 | def capitalize(text: str) -> str: 10 | return text[0].upper() + text[1:] 11 | 12 | 13 | def snip(string, search, keep, keep_search): 14 | pos = 
string.find(search) 15 | if pos != -1: 16 | if keep == "left": 17 | if keep_search: 18 | pos += len(search) 19 | string = string[:pos] 20 | if keep == "right": 21 | if not keep_search: 22 | pos += len(search) 23 | string = string[pos:] 24 | return string 25 | 26 | 27 | def snip_anchor(text: str) -> str: 28 | return snip(text, "#", keep="left", keep_search=False) 29 | 30 | 31 | def normalize_wiki_entity(i, replace_ws=False): 32 | i = snip_anchor(i) 33 | if len(i) == 0: 34 | return None 35 | i = capitalize(i) 36 | if replace_ws: 37 | return i.replace(" ", "_") 38 | return i 39 | 40 | 41 | # most frequent English words from English Wikipedia 42 | stopwords = { 43 | "a", 44 | "also", 45 | "an", 46 | "are", 47 | "as", 48 | "at", 49 | "be", 50 | "by", 51 | "city", 52 | "company", 53 | "film", 54 | "first", 55 | "for", 56 | "from", 57 | "had", 58 | "has", 59 | "her", 60 | "his", 61 | "in", 62 | "is", 63 | "its", 64 | "john", 65 | "national", 66 | "new", 67 | "of", 68 | "on", 69 | "one", 70 | "people", 71 | "school", 72 | "state", 73 | "the", 74 | "their", 75 | "these", 76 | "this", 77 | "time", 78 | "to", 79 | "two", 80 | "university", 81 | "was", 82 | "were", 83 | "with", 84 | "world", 85 | } 86 | 87 | 88 | def get_stopwordless_token_set(s): 89 | result = set(s.lower().split(" ")) 90 | result_minus_stopwords = result.difference(stopwords) 91 | if len(result_minus_stopwords) == 0: 92 | return result 93 | else: 94 | return result_minus_stopwords 95 | 96 | 97 | def argparse_bool_type(v): 98 | "Type for argparse that correctly treats Boolean values" 99 | if isinstance(v, bool): 100 | return v 101 | if v.lower() in ("yes", "true", "t", "y", "1"): 102 | return True 103 | elif v.lower() in ("no", "false", "f", "n", "0"): 104 | return False 105 | else: 106 | raise argparse.ArgumentTypeError("Boolean value expected.") 107 | 108 | 109 | def get_gpu_memory_map(): 110 | """Get the current gpu usage. 111 | 112 | Returns 113 | ------- 114 | usage: dict 115 | Keys are device ids as integers. 116 | Values are memory usage as integers in MB. 
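
    Example
    -------
    Illustrative output on a hypothetical machine with two GPUs (the numbers are
    simply whatever nvidia-smi reports at call time):

    >>> get_gpu_memory_map()
    {0: 3254, 1: 11}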
117 | """ 118 | result = subprocess.check_output( 119 | ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"], encoding="utf-8" 120 | ) 121 | # Convert lines into a dictionary 122 | gpu_memory = [int(x) for x in result.strip().split("\n")] 123 | gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory)) 124 | return gpu_memory_map 125 | 126 | 127 | def create_chunks(a_list, n): 128 | for i in range(0, len(a_list), n): 129 | yield a_list[i : i + n] 130 | 131 | 132 | def unescape(s): 133 | if s.startswith('"'): 134 | s = s[1:-1] 135 | return s.replace('""""', '"').replace('""', '"') 136 | 137 | 138 | def create_overlapping_chunks(a_list, n, overlap): 139 | for i in range(0, len(a_list), n - overlap): 140 | yield a_list[i : i + n] 141 | 142 | 143 | def running_mean(new, old=None, momentum=0.9): 144 | if old is None: 145 | return new 146 | else: 147 | return momentum * old + (1 - momentum) * new 148 | 149 | 150 | def get_topk_ids_aggregated_from_seq_prediction(logits, topk_per_token, topk_from_batch): 151 | topk_logit_per_token, topk_eids_per_token = logits.topk(topk_per_token, sorted=False, dim=-1) 152 | 153 | i = torch.cat( 154 | [ 155 | topk_eids_per_token.view(1, -1), 156 | torch.zeros(topk_eids_per_token.view(-1).size(), dtype=torch.long, device=topk_eids_per_token.device).view( 157 | 1, -1 158 | ), 159 | ], 160 | dim=0, 161 | ) 162 | v = topk_logit_per_token.view(-1) 163 | st = torch.sparse.FloatTensor(i, v) 164 | stc = st.coalesce() 165 | topk_indices = stc._values().sort(descending=True)[1][:topk_from_batch] 166 | result = stc._indices()[0, topk_indices] 167 | 168 | return result.cpu().tolist() 169 | 170 | 171 | def get_entity_annotations(t, outside_id): 172 | annos = list() 173 | begin = -1 174 | in_entity = -1 175 | for i, j in enumerate(t): 176 | if j < outside_id and begin == -1: 177 | begin = i 178 | in_entity = j.item() 179 | elif j < outside_id and j != in_entity: 180 | annos.append((tuple(range(begin, i)), in_entity)) 181 | begin = i 182 | in_entity = j.item() 183 | elif j == outside_id and begin != -1: 184 | annos.append((tuple(range(begin, i)), in_entity)) 185 | begin = -1 186 | return annos 187 | 188 | 189 | def get_entity_annotations_with_gold_spans(t, t_gold, outside_id): 190 | annos = list() 191 | begin = -1 192 | in_gold_entity = -1 193 | collected_entities_in_span = Counter() 194 | for i, (j, j_gold) in enumerate(zip(t, t_gold)): 195 | if j_gold < outside_id and begin == -1: 196 | begin = i 197 | in_gold_entity = j_gold.item() 198 | collected_entities_in_span[j.item()] += 1 199 | elif j_gold != in_gold_entity and begin != -1: 200 | in_entity = collected_entities_in_span.most_common()[0][0] 201 | annos.append((tuple(range(begin, i)), in_entity)) 202 | collected_entities_in_span = Counter() 203 | begin = i 204 | in_gold_entity = j_gold.item() 205 | collected_entities_in_span[j.item()] += 1 206 | elif j_gold == outside_id and begin != -1: 207 | in_entity = collected_entities_in_span.most_common()[0][0] 208 | annos.append((tuple(range(begin, i)), in_entity)) 209 | collected_entities_in_span = Counter() 210 | begin = -1 211 | return annos 212 | 213 | 214 | class DummyOptimizer(torch.optim.Optimizer): 215 | def step(self, closure=None): 216 | pass 217 | 218 | 219 | class LRMilestones(_LRScheduler): 220 | """Set the learning rate of each parameter group to the initial lr decayed 221 | by gamma once the number of epoch reaches one of the milestones. When 222 | last_epoch=-1, sets initial lr as lr. 
223 | 224 | Args: 225 | optimizer (Optimizer): Wrapped optimizer. 226 | milestones (list): List of epoch indices. Must be increasing. 227 | gamma (float): Multiplicative factor of learning rate decay. 228 | Default: 0.1. 229 | last_epoch (int): The index of last epoch. Default: -1. 230 | 231 | Example: 232 | >>> # Assuming optimizer uses lr = 0.05 for all groups 233 | >>> # lr = 0.05 if epoch < 30 234 | >>> # lr = 0.005 if 30 <= epoch < 80 235 | >>> # lr = 0.0005 if epoch >= 80 236 | >>> scheduler = LRMilestones(optimizer, milestones=[(30, 0.1), (80, 0.2), ]) 237 | >>> for epoch in range(100): 238 | >>> scheduler.step() 239 | >>> train(...) 240 | >>> validate(...) 241 | """ 242 | 243 | def __init__(self, optimizer, milestones, last_epoch=-1): 244 | super().__init__(optimizer) 245 | if not list(milestones) == sorted(milestones): 246 | raise ValueError("Milestones should be a list of" " increasing integers. Got {}", milestones) 247 | self.milestones = milestones 248 | super(LRMilestones, self).__init__(optimizer, last_epoch) 249 | 250 | def get_lr(self): 251 | for ep, lr in self.milestones: 252 | if self.last_epoch >= ep: 253 | print("Set lr to {} in epoch {}".format(lr, ep)) 254 | return lr 255 | 256 | 257 | def pad_to(arr, max_len, pad_id, cls_id, sep_id): 258 | return [cls_id] + arr + [sep_id] + [pad_id] * (max_len - len(arr) - 2) 259 | 260 | 261 | def set_out_id(t, repl, dummy=-1): 262 | t[(t == dummy)] = repl 263 | return t 264 | 265 | 266 | class LRSchedulers: 267 | ReduceLROnPlateau = ReduceLROnPlateau 268 | LRMilestones = LRMilestones 269 | -------------------------------------------------------------------------------- /bert_entity/data_loader_conll.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import OrderedDict 3 | 4 | import numpy 5 | import torch 6 | from torch.utils import data 7 | 8 | from misc import pad_to, set_out_id 9 | from vocab import Vocab 10 | 11 | 12 | 13 | class CONLLEDLDataset(data.Dataset): 14 | def __init__(self, args, split, vocab, device, label_size=None): 15 | 16 | if split == "train": 17 | train_valid_test_int = 0 18 | if split == "small_valid" or split == "valid": 19 | train_valid_test_int = 1 20 | if split == "test": 21 | train_valid_test_int = 2 22 | 23 | chunk_len = args.create_integerized_training_instance_text_length 24 | chunk_overlap = args.create_integerized_training_instance_text_overlap 25 | 26 | self.item_locs = None 27 | self.device = device 28 | with open(args.data_path_conll, "rb") as f: 29 | train_valid_test = pickle.load(f) 30 | self.conll_docs = torch.LongTensor( 31 | [ 32 | [ 33 | pad_to( 34 | [tok_id for _, tok_id, _, _, _, _, _ in doc], 35 | max_len=chunk_len + 2, 36 | pad_id=0, 37 | cls_id=101, 38 | sep_id=102, 39 | ) 40 | for doc in train_valid_test[train_valid_test_int] 41 | ], 42 | [ 43 | pad_to( 44 | [bio_id for _, _, _, bio_id, _, _, _ in doc], 45 | max_len=chunk_len + 2, 46 | pad_id=2, 47 | cls_id=2, 48 | sep_id=2, 49 | ) 50 | for doc in train_valid_test[train_valid_test_int] 51 | ], 52 | [ 53 | pad_to( 54 | [wiki_id for _, _, _, _, _, wiki_id, _ in doc], 55 | max_len=chunk_len + 2, 56 | pad_id=vocab.PAD_ID, 57 | cls_id=vocab.PAD_ID, 58 | sep_id=vocab.PAD_ID, 59 | ) 60 | for doc in train_valid_test[train_valid_test_int] 61 | ], 62 | [ 63 | pad_to( 64 | [doc_id for _, _, _, _, _, _, doc_id in doc], 65 | max_len=chunk_len + 2, 66 | pad_id=0, 67 | cls_id=0, 68 | sep_id=0, 69 | ) 70 | for doc in train_valid_test[train_valid_test_int] 71 | ], 72 | ] 73 | 
).permute(1, 0, 2) 74 | self.conll_docs[:, 2] = set_out_id(self.conll_docs[:, 2], vocab.OUTSIDE_ID) 75 | 76 | self.pad_token_id = vocab.PAD_ID 77 | self.label_size = label_size 78 | self.labels = None 79 | self.train_valid_test_int = train_valid_test_int 80 | 81 | def get_data_iter( 82 | self, args, batch_size, vocab, train, 83 | ): 84 | return data.DataLoader( 85 | dataset=self.conll_docs, 86 | batch_size=batch_size, 87 | shuffle=train, 88 | num_workers=args.data_workers, 89 | collate_fn=self.collate_func( 90 | args, 91 | return_labels=args.collect_most_popular_labels_steps is not None 92 | and args.collect_most_popular_labels_steps > 0 93 | if train 94 | else True, 95 | vocab=vocab, 96 | ), 97 | ) 98 | 99 | def collate_func(self, args, vocab, return_labels): 100 | def collate(batch): 101 | return CONLLEDLDataset_collate_func( 102 | batch=batch, 103 | labels_with_high_model_score=self.labels, 104 | args=args, 105 | return_labels=return_labels, 106 | vocab=vocab, 107 | is_training=self.train_valid_test_int == 0, 108 | ) 109 | 110 | return collate 111 | 112 | 113 | def CONLLEDLDataset_collate_func( 114 | batch, labels_with_high_model_score, args, return_labels, vocab: Vocab, is_training=False, 115 | ): 116 | drop_entity_mentions_prob = args.maskout_entity_prob 117 | # print([b[0] for b in batch]) 118 | label_size = args.label_size 119 | batch_token_ids = torch.LongTensor([b[0].tolist() for b in batch]) 120 | batch_bio_ids = [b[1].tolist() for b in batch] 121 | batch_entity_ids = [b[2].tolist() for b in batch] 122 | batch_doc_ids = [b[3, 0].item() for b in batch] 123 | 124 | if return_labels: 125 | 126 | all_batch_entity_ids = OrderedDict() 127 | 128 | for batch_offset, one_item_entity_ids in enumerate(batch_entity_ids): 129 | for tok_id, eid in enumerate(one_item_entity_ids): 130 | if eid not in all_batch_entity_ids: 131 | all_batch_entity_ids[eid] = len(all_batch_entity_ids) 132 | 133 | if label_size is not None: 134 | 135 | batch_shared_label_ids = all_batch_entity_ids.keys() 136 | negative_samples = set() 137 | if labels_with_high_model_score is not None: 138 | # print(labels_with_high_model_score) 139 | negative_samples = set(labels_with_high_model_score) 140 | # else: 141 | # negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False)) 142 | if len(negative_samples) < label_size: 143 | random_negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False)) 144 | negative_samples = negative_samples.union(random_negative_samples) 145 | 146 | negative_samples.difference_update(batch_shared_label_ids) 147 | 148 | if len(batch_shared_label_ids) + len(negative_samples) < label_size: 149 | negative_samples.difference_update( 150 | set(numpy.random.choice(vocab.OUTSIDE_ID, label_size, replace=False)) 151 | ) 152 | 153 | batch_shared_label_ids = (list(batch_shared_label_ids) + list(negative_samples))[:label_size] 154 | label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), len(batch_shared_label_ids)) 155 | bio_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), 3) 156 | 157 | else: 158 | 159 | batch_shared_label_ids = list(all_batch_entity_ids.keys()) 160 | label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), args.vocab_size) 161 | bio_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), 3) 162 | 163 | drop_probs = None 164 | if drop_entity_mentions_prob > 0 and is_training: 165 | drop_probs = torch.rand((batch_token_ids.size(0), batch_token_ids.size(1)),) < 
drop_entity_mentions_prob 166 | 167 | for batch_offset, (one_item_entity_ids, one_item_bio_ids) in enumerate(zip(batch_entity_ids, batch_bio_ids)): 168 | for tok_id, one_entity_ids in enumerate(one_item_entity_ids): 169 | 170 | if ( 171 | is_training 172 | and vocab.OUTSIDE_ID != one_entity_ids 173 | and drop_entity_mentions_prob > 0 174 | and drop_probs[batch_offset][tok_id].item() == 1 175 | ): 176 | batch_token_ids[batch_offset][tok_id] = vocab.tokenizer.vocab["[MASK]"] 177 | 178 | if label_size is not None: 179 | label_probs[batch_offset][tok_id][torch.LongTensor([all_batch_entity_ids[one_entity_ids]])] = 1.0 180 | else: 181 | label_probs[batch_offset][tok_id][torch.LongTensor(one_entity_ids)] = 1.0 182 | bio_probs[batch_offset][tok_id][torch.LongTensor(one_item_bio_ids)] = 1.0 183 | 184 | label_ids = torch.LongTensor(batch_shared_label_ids) 185 | 186 | return ( 187 | batch_token_ids, 188 | label_ids, 189 | torch.LongTensor(batch_bio_ids), 190 | torch.FloatTensor(label_probs), 191 | bio_probs, 192 | None, 193 | {v: k for k, v in all_batch_entity_ids.items()}, 194 | batch_entity_ids, 195 | batch_doc_ids, 196 | batch, 197 | ) 198 | 199 | else: 200 | 201 | return batch_token_ids, None, None, None, None, None, None, None, None, batch 202 | 203 | -------------------------------------------------------------------------------- /bert_entity/data_loader_wiki.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | from operator import itemgetter 5 | 6 | import numpy 7 | import torch 8 | from torch.utils import data 9 | from tqdm import tqdm 10 | 11 | from vocab import Vocab 12 | 13 | 14 | class EDLDataset(data.Dataset): 15 | def __init__(self, args, split, vocab, device, label_size=None): 16 | 17 | if split == "train": 18 | loc_file_name = args.train_loc_file 19 | self.data_dir = args.train_data_dir 20 | elif split == "valid": 21 | loc_file_name = args.valid_loc_file 22 | self.data_dir = args.valid_data_dir 23 | elif split == "test": 24 | loc_file_name = args.test_loc_file 25 | self.data_dir = args.test_data_dir 26 | 27 | 28 | self.data_path = f"data/versions/{args.data_version_name}/wiki_training/integerized/{args.wiki_lang_version}/" 29 | self.item_locs = None 30 | self.device = device 31 | if os.path.exists("{}.pickle".format(self.data_path + loc_file_name)): 32 | with open("{}.pickle".format(self.data_path + loc_file_name), "rb") as f: 33 | self.item_locs = pickle.load(f) 34 | else: 35 | with open(self.data_path + loc_file_name) as f: 36 | self.item_locs = list(map(lambda x: list(map(int, x.strip().split())), tqdm(f.readlines()))) 37 | with open("{}.pickle".format(self.data_path + loc_file_name), "wb") as f: 38 | pickle.dump(self.item_locs, f) 39 | self.pad_token_id = vocab.PAD_ID 40 | self.label_size = label_size 41 | self.is_training = split == "train" 42 | 43 | def get_data_iter( 44 | self, args, batch_size, vocab, train, 45 | ): 46 | return data.DataLoader( 47 | dataset=self.item_locs, 48 | batch_size=batch_size, 49 | shuffle=train, 50 | num_workers=args.data_workers, 51 | collate_fn=self.collate_func( 52 | args=args, 53 | vocab=vocab, 54 | return_labels=args.collect_most_popular_labels_steps is not None 55 | and args.collect_most_popular_labels_steps > 0 56 | if train 57 | else True, 58 | ), 59 | ) 60 | 61 | # def collate_func(self, args, vocab, return_labels, shards, shards_locks): 62 | def collate_func( 63 | self, args, vocab, return_labels, in_queue=None, out_queue=None, 64 | ): 
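        # Each element that the DataLoader hands to `collate` below is one entry of
        # self.item_locs, i.e. a pair of integers read from the .loc file: a shard id
        # (pointing to "<shard>.dat" under self.data_path) and a byte offset into that
        # shard. The integerized training instance itself is only unpickled from the
        # shard file inside the collate call (see EDLDataset_collate_func).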
65 | def collate(batch): 66 | return EDLDataset_collate_func( 67 | batch=batch, 68 | labels_with_high_model_score=None, 69 | args=args, 70 | return_labels=return_labels, 71 | data_path=self.data_path, 72 | vocab=vocab, 73 | is_training=self.is_training, 74 | ) 75 | 76 | return collate 77 | 78 | 79 | def EDLDataset_collate_func( 80 | batch, 81 | labels_with_high_model_score, 82 | args, 83 | return_labels, 84 | vocab: Vocab, 85 | data_path=None, 86 | is_training=True, 87 | drop_entity_mentions_prob=0.0, 88 | loaded_batch=None, 89 | ): 90 | if loaded_batch is None: 91 | batch_dict_list = list() 92 | for shard, offset in batch: 93 | # print('{}/{}.dat'.format(data_path, shard), offset) 94 | with open("{}/{}.dat".format(data_path, shard), "rb") as f: 95 | f.seek(offset) 96 | ( 97 | token_ids_chunk, 98 | mention_entity_ids_chunk, 99 | mention_entity_probs_chunk, 100 | mention_probs_chunk, 101 | ) = pickle.load(f) 102 | try: 103 | eval_mask = list(map(is_a_wikilink_or_keyword, mention_probs_chunk)) 104 | mention_entity_ids_chunk = list(map(itemgetter(0), mention_entity_ids_chunk)) 105 | mention_entity_probs_chunk = list(map(itemgetter(0), mention_entity_probs_chunk)) 106 | batch_dict_list.append( 107 | { 108 | "token_ids": token_ids_chunk, 109 | "entity_ids": mention_entity_ids_chunk, 110 | "entity_probs": mention_entity_probs_chunk, 111 | "eval_mask": eval_mask, 112 | } 113 | ) 114 | except Exception as e: 115 | print(f"pickle.load(shards[shard]) failed {e}") 116 | print(mention_entity_ids_chunk) 117 | print(mention_entity_probs_chunk) 118 | raise e 119 | 120 | f = lambda x: [sample[x] for sample in batch_dict_list] 121 | # print(batch) 122 | batch_token_ids = f("token_ids") 123 | batch_entity_ids = f("entity_ids") 124 | batch_entity_probs = f("entity_probs") 125 | eval_mask = f("eval_mask") 126 | maxlen = max([len(chunk) for chunk in batch_token_ids]) 127 | 128 | eval_mask = torch.LongTensor([sample + [0] * (maxlen - len(sample)) for sample in eval_mask]) 129 | 130 | # create dictionary mapping the vocabulary entity id to a batch label id 131 | # 132 | # e.g. 133 | # all_batch_entity_ids[324] = 0 134 | # all_batch_entity_ids[24] = 1 135 | # all_batch_entity_ids[2] = 2 136 | # all_batch_entity_ids[987] = 3 137 | # 138 | all_batch_entity_ids = OrderedDict() 139 | 140 | for batch_offset, (batch_item_token_item_entity_ids, batch_item_token_entity_probs) in enumerate( 141 | zip(batch_entity_ids, batch_entity_probs) 142 | ): 143 | for tok_id, (token_entity_ids, token_entity_probs) in enumerate( 144 | zip(batch_item_token_item_entity_ids, batch_item_token_entity_probs) 145 | ): 146 | for eid in token_entity_ids: 147 | if eid not in all_batch_entity_ids: 148 | all_batch_entity_ids[eid] = len(all_batch_entity_ids) 149 | 150 | loaded_batch = ( 151 | batch_token_ids, 152 | batch_entity_ids, 153 | batch_entity_probs, 154 | eval_mask, 155 | all_batch_entity_ids, 156 | maxlen, 157 | ) 158 | 159 | else: 160 | (batch_token_ids, batch_entity_ids, batch_entity_probs, eval_mask, all_batch_entity_ids, maxlen,) = loaded_batch 161 | 162 | batch_token_ids = torch.LongTensor([sample + [0] * (maxlen - len(sample)) for sample in batch_token_ids]) 163 | 164 | if return_labels: 165 | 166 | # if labels for each token should be over 167 | # a. the whole entity vocabulary 168 | # b. 
a reduced set of entities composed of: 169 | # set of batch's true entities, entities 170 | # set of entities with the largest logits 171 | # set of negative samples 172 | 173 | if args.label_size is None: 174 | 175 | batch_shared_label_ids = list(all_batch_entity_ids.keys()) 176 | label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), args.vocab_size) 177 | 178 | else: 179 | 180 | # batch_shared_label_ids are constructing by incrementally concatenating 181 | # set of batch's true entities, entities 182 | # set of entities with the largest logits 183 | # set of negative samples 184 | 185 | batch_shared_label_ids = list(all_batch_entity_ids.keys()) 186 | 187 | if len(batch_shared_label_ids) < args.label_size and labels_with_high_model_score is not None: 188 | # print(labels_with_high_model_score) 189 | negative_examples = set(labels_with_high_model_score) 190 | negative_examples.difference_update(batch_shared_label_ids) 191 | batch_shared_label_ids += list(negative_examples) 192 | 193 | if len(batch_shared_label_ids) < args.label_size: 194 | negative_samples = set(numpy.random.choice(vocab.OUTSIDE_ID, args.label_size, replace=False)) 195 | negative_samples.difference_update(batch_shared_label_ids) 196 | batch_shared_label_ids += list(negative_samples) 197 | 198 | batch_shared_label_ids = batch_shared_label_ids[: args.label_size] 199 | 200 | label_probs = torch.zeros(batch_token_ids.size(0), batch_token_ids.size(1), len(batch_shared_label_ids)) 201 | 202 | drop_probs = None 203 | if drop_entity_mentions_prob > 0 and is_training: 204 | drop_probs = torch.rand((batch_token_ids.size(0), batch_token_ids.size(1)),) < drop_entity_mentions_prob 205 | 206 | # loop through the batch x tokens x (label_ids, label_probs) 207 | for batch_offset, (batch_item_token_item_entity_ids, batch_item_token_entity_probs) in enumerate( 208 | zip(batch_entity_ids, batch_entity_probs) 209 | ): 210 | # loop through tokens x (label_ids, label_probs) 211 | for tok_id, (token_entity_ids, token_entity_probs) in enumerate( 212 | zip(batch_item_token_item_entity_ids, batch_item_token_entity_probs) 213 | ): 214 | if drop_entity_mentions_prob > 0 and is_training and drop_probs[batch_offset][tok_id].item() == 1: 215 | batch_token_ids[batch_offset][tok_id] = vocab.tokenizer.vocab["[MASK]"] 216 | 217 | if args.label_size is None: 218 | label_probs[batch_offset][tok_id][torch.LongTensor(token_entity_ids)] = torch.Tensor( 219 | batch_item_token_item_entity_ids 220 | ) 221 | else: 222 | label_probs[batch_offset][tok_id][ 223 | torch.LongTensor(list(map(all_batch_entity_ids.__getitem__, token_entity_ids))) 224 | ] = torch.Tensor(token_entity_probs) 225 | 226 | label_ids = torch.LongTensor(batch_shared_label_ids) 227 | 228 | return ( 229 | batch_token_ids, 230 | label_ids, 231 | label_probs, 232 | torch.LongTensor(eval_mask), 233 | {v: k for k, v in all_batch_entity_ids.items()}, 234 | batch_entity_ids, 235 | batch, 236 | loaded_batch, 237 | ) 238 | 239 | else: 240 | 241 | return batch_token_ids, None, None, None, None, None, batch, loaded_batch 242 | 243 | # hack to detect if an entity annotation was a 244 | # wikilink (== only one entity label) or a 245 | # keyword matcher annotation (== multiple entity labels) 246 | def is_a_wikilink_or_keyword(item): 247 | if len(item) == 1: 248 | return 1 249 | else: 250 | return 0 251 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/preprocess_aida_conll_data.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import subprocess 4 | from collections import Counter 5 | from typing import Dict 6 | 7 | import tqdm 8 | 9 | from pipeline_job import PipelineJob 10 | 11 | 12 | class CreateAIDACONLL(PipelineJob): 13 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 14 | super().__init__( 15 | requires=[ 16 | "data/indexes/redirects_en.ttl.bz2.dict", 17 | "data/indexes/freebase_links_en.ttl.bz2.dict", 18 | "data/indexes/page_ids_en.ttl.bz2.dict", 19 | "data/indexes/disambiguations_en.ttl.bz2.dict", 20 | "data/benchmarks/aida-yago2-dataset/AIDA-YAGO2-dataset.tsv", 21 | ], 22 | provides=[ 23 | "data/benchmarks/aida-yago2-dataset/conll_dataset.pickle", 24 | f"data/versions/{opts.data_version_name}/indexes/found_conll_entities.pickle", 25 | f"data/versions/{opts.data_version_name}/indexes/not_found_conll_entities.pickle", 26 | ], 27 | preprocess_jobs=preprocess_jobs, 28 | opts=opts, 29 | ) 30 | 31 | def _run(self): 32 | 33 | with open("data/indexes/redirects_en.ttl.bz2.dict", "rb") as f: 34 | redirects_en = pickle.load(f) 35 | 36 | redirects_en_values = set(redirects_en.values()) 37 | 38 | with open("data/indexes/freebase_links_en.ttl.bz2.dict", "rb") as f: 39 | fb_to_wikiname_dict = pickle.load(f) 40 | 41 | with open("data/indexes/disambiguations_en.ttl.bz2.dict", "rb") as f: 42 | disambiguations_dict = pickle.load(f) 43 | 44 | with open("data/indexes/page_ids_en.ttl.bz2.dict", "rb") as f: 45 | page_id_to_wikiname_dict = pickle.load(f) 46 | 47 | conll2003_ner_en = self._download( 48 | url="https://www.clips.uantwerpen.be/conll2003/ner.tgz", 49 | folder="data/downloads", 50 | ) 51 | 52 | subprocess.check_call( 53 | [ 54 | "tar", 55 | "xzf", 56 | conll2003_ner_en, 57 | "-C", 58 | "data/benchmarks/aida-yago2-dataset/", 59 | ] 60 | ) 61 | 62 | try: 63 | subprocess.call( 64 | [ 65 | "cat data/benchmarks/aida-yago2-dataset/ner/etc/tags.eng" 66 | " data/benchmarks/aida-yago2-dataset/ner/etc/tags.eng.testb" 67 | " >" 68 | " data/benchmarks/aida-yago2-dataset/ner/eng.all" 69 | ], 70 | shell=True) 71 | except subprocess.CalledProcessError as e: 72 | print(e.output) 73 | 74 | # merge AIDA-YAGO2-dataset.tsv and the CONLL2003 NER dataset eng.all to 75 | # also have BIO-NER tags 76 | with open("data/benchmarks/aida-yago2-dataset/AIDA-YAGO2-dataset.tsv") as f1: 77 | with open("data/benchmarks/aida-yago2-dataset/ner/eng.all") as f2: 78 | merged_el = list() 79 | el = f1.readlines() 80 | ner = f2.readlines() 81 | ner_i = 0 82 | last_entity = None 83 | for el_i, el_line in enumerate(el): 84 | # if len(merged_el) > 1 and len(merged_el) < 100: print(merged_el[-1]) 85 | ner_line = ner[ner_i] 86 | if len(el_line.strip()) == 0: 87 | merged_el.append([""]) 88 | continue 89 | while len(el_line.strip()) > 0 and len(ner_line.strip()) == 0: 90 | ner_i += 1 91 | ner_line = ner[ner_i] 92 | 93 | if el_line.startswith("-DOCSTART-"): 94 | merged_el.append([el_line.strip()]) 95 | ner_i += 1 96 | continue 97 | 98 | el_fields = el_line.strip().split("\t") 99 | ner_fields = ner_line.strip().split() 100 | 101 | if len(el_fields) == 1: 102 | bio_ner = "O" 103 | rest = [] 104 | else: 105 | bio, etype = ner_fields[2].split("-") 106 | if last_entity != el_fields[3]: 107 | bio = "B" 108 | rest = el_fields[2:] 109 | last_entity = el_fields[3] 110 | bio_ner = f"{bio}-{etype}" 111 | 112 | merged_el.append([el_fields[0], bio_ner] + rest) 113 | ner_i += 1 114 | 115 | conll_dataset = list() 116 | mentions = list() 117 | 
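        # The merged lines parsed below come in a small number of shapes; the branches
        # on len(line_items) further down correspond to these cases (the concrete
        # values shown here are hypothetical examples, the real ones come from
        # AIDA-YAGO2-dataset.tsv merged with the NER tags above):
        #
        #   ['']                                        sentence boundary
        #   ['-DOCSTART- (1 EU)']                       document boundary
        #   ['Friday', 'O']                             token outside any mention
        #   ['LONDON', 'B-LOC', 'LONDON', '--NME--']    mention without a KB entity (NIL)
        #   ['London', 'B-LOC', 'London', 'London',
        #    'http://en.wikipedia.org/wiki/London', '17867', '/m/04jpl']
        #                                               linked mention (6 columns when
        #                                               the Freebase id is missing)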
sentence_nr = 0 118 | tok_nr = 0 119 | split = "train" 120 | 121 | print(merged_el[:10]) 122 | 123 | for line_nr, line_items in enumerate(tqdm.tqdm(merged_el)): 124 | 125 | if ( 126 | len(line_items) > 0 127 | and line_items[0].startswith("-DOCSTART-") 128 | and "testa" in line_items[0] 129 | ): 130 | split = "valid" 131 | 132 | if ( 133 | len(line_items) > 0 134 | and line_items[0].startswith("-DOCSTART-") 135 | and "testb" in line_items[0] 136 | ): 137 | split = "test" 138 | 139 | if len(line_items) == 1 and line_items[0] == '': 140 | sentence_nr += 1 141 | tok_nr = 0 142 | elif len(line_items) > 0 and line_items[0].startswith("-DOCSTART-"): 143 | sentence_nr = 0 144 | conll_dataset.append( 145 | { 146 | "tok": " ".join(line_items), 147 | "bio-tag": None, 148 | "bio": None, 149 | "tag": None, 150 | "mention": None, 151 | "yago_name": None, 152 | "wiki_name": None, 153 | "wiki_id": None, 154 | "fb_id": None, 155 | "doc_start": True, 156 | "is_nil": None, 157 | "sent_nr": sentence_nr, 158 | "tok_nr": tok_nr, 159 | "split": split, 160 | } 161 | ) 162 | elif len(line_items) == 2: 163 | tok_nr += 1 164 | conll_dataset.append( 165 | { 166 | "tok": line_items[0], 167 | "bio-tag": "O", 168 | "bio": "O", 169 | "tag": None, 170 | "mention": None, 171 | "yago_name": None, 172 | "wiki_name": None, 173 | "wiki_id": None, 174 | "fb_id": None, 175 | "doc_start": False, 176 | "is_nil": None, 177 | "sent_nr": sentence_nr, 178 | "tok_nr": tok_nr, 179 | "split": split, 180 | } 181 | ) 182 | elif len(line_items) == 4: 183 | tok_nr += 1 184 | conll_dataset.append( 185 | { 186 | "tok": line_items[0], 187 | "bio-tag": line_items[1], 188 | "bio": line_items[1].split("-")[0], 189 | "tag": line_items[1].split("-")[1], 190 | "mention": line_items[2], 191 | "yago_name": line_items[3], 192 | "wiki_name": line_items[3], 193 | "wiki_id": line_items[3], 194 | "fb_id": line_items[3], 195 | "doc_start": False, 196 | "is_nil": True, 197 | "sent_nr": sentence_nr, 198 | "tok_nr": tok_nr, 199 | "split": split, 200 | } 201 | ) 202 | elif len(line_items) == 6 or len(line_items) == 7: 203 | tok_nr += 1 204 | conll_dataset.append( 205 | { 206 | "tok": line_items[0], 207 | "bio-tag": line_items[1], 208 | "bio": line_items[1].split("-")[0], 209 | "tag": line_items[1].split("-")[1], 210 | "mention": line_items[2], 211 | "yago_name": line_items[3], 212 | "wiki_name": Counter({line_items[4].split("/")[-1]: 1}), 213 | "wiki_id": line_items[5], 214 | "fb_id": line_items[6] if len(line_items) == 7 else None, 215 | "doc_start": False, 216 | "is_nil": False, 217 | "sent_nr": sentence_nr, 218 | "tok_nr": tok_nr, 219 | "split": split, 220 | } 221 | ) 222 | 223 | if conll_dataset[-1]["fb_id"] in fb_to_wikiname_dict: 224 | key = fb_to_wikiname_dict[conll_dataset[-1]["fb_id"]] 225 | if key in redirects_en: 226 | key = redirects_en[key] 227 | conll_dataset[-1]["wiki_name"][key] += 1 228 | 229 | if conll_dataset[-1]["wiki_id"] in page_id_to_wikiname_dict: 230 | key = page_id_to_wikiname_dict[conll_dataset[-1]["wiki_id"]] 231 | if key in redirects_en: 232 | key = redirects_en[key] 233 | conll_dataset[-1]["wiki_name"][key] += 1 234 | 235 | for wn in set( 236 | map(redirects_en.get, conll_dataset[-1]["wiki_name"].keys()) 237 | ): 238 | if wn not in conll_dataset[-1]["wiki_name"]: 239 | conll_dataset[-1]["wiki_name"][wn] += 1 240 | 241 | if conll_dataset[-1]["mention"].replace(" ", "_") in redirects_en: 242 | conll_dataset[-1]["wiki_name"][ 243 | redirects_en[conll_dataset[-1]["mention"].replace(" ", "_")] 244 | ] += 1 245 | 246 | for wn in 
set(conll_dataset[-1]["wiki_name"].keys()): 247 | if wn in disambiguations_dict: 248 | conll_dataset[-1]["wiki_name"][wn] = 0 249 | 250 | for wn in set(conll_dataset[-1]["wiki_name"].keys()): 251 | if wn in redirects_en_values: 252 | conll_dataset[-1]["wiki_name"][wn] += 5 253 | 254 | else: 255 | raise Exception("Error {}".format(line_items)) 256 | 257 | if ( 258 | len(conll_dataset) > 0 259 | and conll_dataset[-1]["bio"] == "B" 260 | and not conll_dataset[-1]["is_nil"] 261 | ): 262 | mentions.append(conll_dataset[-1]) 263 | 264 | # if len(conll_dataset) > 0: 265 | # print(conll_dataset[-1]) 266 | # if line_nr > 1000: 267 | # break 268 | 269 | with open("data/benchmarks/aida-yago2-dataset/conll_dataset.pickle", "wb") as f: 270 | pickle.dump(conll_dataset, f) 271 | 272 | with open( 273 | f"data/versions/{self.opts.data_version_name}/indexes/entity_counter.pickle", 274 | "rb", 275 | ) as f: 276 | all_entity_counter = pickle.load(f) 277 | 278 | all_found_conll_entities = set() 279 | all_conll_entities = set() 280 | all_not_found_conll_entities = set() 281 | 282 | for item in conll_dataset: 283 | if not item["is_nil"] and item["bio"] == "B": 284 | name, count = item["wiki_name"].most_common()[0] 285 | if name in all_entity_counter: 286 | all_found_conll_entities.add(name) 287 | else: 288 | all_not_found_conll_entities.add(name) 289 | all_conll_entities.add(name) 290 | 291 | with open( 292 | f"data/versions/{self.opts.data_version_name}/indexes/found_conll_entities.pickle", 293 | "wb", 294 | ) as f: 295 | pickle.dump(all_found_conll_entities, f) 296 | 297 | with open( 298 | f"data/versions/{self.opts.data_version_name}/indexes/not_found_conll_entities.pickle", 299 | "wb", 300 | ) as f: 301 | pickle.dump(all_not_found_conll_entities, f) 302 | 303 | self.log(f"Found {len(all_found_conll_entities)} and not found {len(all_not_found_conll_entities)} of AIDA-CoNLL entities in the entities dictionary.") 304 | 305 | -------------------------------------------------------------------------------- /downstream_experiments/ensemble_bert_modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import CrossEntropyLoss 5 | 6 | from pytorch_pretrained_bert import BertForQuestionAnswering, BertForSequenceClassification, BertForMultipleChoice 7 | 8 | 9 | class EnsembleBertForSequenceClassification(nn.Module): 10 | """BERT model for classification. 11 | This module is composed of the BERT model with a linear layer on top of 12 | the pooled output. 13 | 14 | Params: 15 | `config`: a BertConfig class instance with the configuration to build a new model. 16 | `num_labels`: the number of classes for the classifier. Default = 2. 17 | 18 | Inputs: 19 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 20 | with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts 21 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 22 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 23 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 24 | a `sentence B` token (see BERT paper for more details). 25 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 26 | selected in [0, 1]. 
It's a mask to be used if the input sequence length is smaller than the max 27 | input sequence length in the current batch. It's the mask that we typically use for attention when 28 | a batch has varying length sentences. 29 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 30 | with indices selected in [0, ..., num_labels]. 31 | 32 | Outputs: 33 | if `labels` is not `None`: 34 | Outputs the CrossEntropy classification loss of the output with the labels. 35 | if `labels` is `None`: 36 | Outputs the classification logits of shape [batch_size, num_labels]. 37 | 38 | Example usage: 39 | ```python 40 | # Already been converted into WordPiece token ids 41 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 42 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 43 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 44 | 45 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 46 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 47 | 48 | num_labels = 2 49 | 50 | model = BertForSequenceClassification(config, num_labels) 51 | logits = model(input_ids, token_type_ids, input_mask) 52 | ``` 53 | """ 54 | def __init__(self, num_labels, model_1:BertForSequenceClassification, model_2:BertForSequenceClassification): 55 | super().__init__() 56 | self.num_labels = num_labels 57 | self.model_1 = model_1 58 | self.model_2 = model_2 59 | self.config = self.model_1.config 60 | self.dropout = nn.Dropout(self.model_1.dropout.p) 61 | self.classifier = nn.Linear(self.model_1.classifier.weight.size(1)*2, num_labels) 62 | 63 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 64 | _, pooled_output_model_1 = self.model_1.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 65 | _, pooled_output_model_2 = self.model_2.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 66 | # print(pooled_output_model_1.size(), pooled_output_model_2.size(), self.classifier.weight.size()) 67 | pooled_output = self.dropout(torch.cat([pooled_output_model_1, pooled_output_model_2], dim=-1)) 68 | logits = self.classifier(pooled_output) 69 | 70 | if labels is not None: 71 | loss_fct = CrossEntropyLoss() 72 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 73 | return loss 74 | else: 75 | return logits 76 | 77 | 78 | class EnsembleBertForQuestionAnswering(nn.Module): 79 | """BERT model for Question Answering (span extraction). 80 | This module is composed of the BERT model with a linear layer on top of 81 | the sequence output that computes start_logits and end_logits 82 | 83 | Params: 84 | `config`: a BertConfig class instance with the configuration to build a new model. 85 | 86 | Inputs: 87 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 88 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 89 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 90 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 91 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 92 | a `sentence B` token (see BERT paper for more details). 93 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 94 | selected in [0, 1]. 
It's a mask to be used if the input sequence length is smaller than the max 95 | input sequence length in the current batch. It's the mask that we typically use for attention when 96 | a batch has varying length sentences. 97 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 98 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 99 | into account for computing the loss. 100 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 101 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 102 | into account for computing the loss. 103 | 104 | Outputs: 105 | if `start_positions` and `end_positions` are not `None`: 106 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. 107 | if `start_positions` or `end_positions` is `None`: 108 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 109 | position tokens of shape [batch_size, sequence_length]. 110 | 111 | Example usage: 112 | ```python 113 | # Already been converted into WordPiece token ids 114 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 115 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 116 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 117 | 118 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 119 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 120 | 121 | model = BertForQuestionAnswering(config) 122 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 123 | ``` 124 | """ 125 | def __init__(self, model_1:BertForQuestionAnswering, model_2:BertForQuestionAnswering): 126 | super().__init__() 127 | self.model_1 = model_1 128 | self.model_2 = model_2 129 | self.config = self.model_1.config 130 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 131 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 132 | self.qa_outputs = nn.Linear(self.model_1.qa_outputs.weight.size(1)*2, 2) 133 | 134 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): 135 | sequence_output_model_1, _ = self.model_1.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 136 | sequence_output_model_2, _ = self.model_2.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) 137 | logits = self.qa_outputs(torch.cat([sequence_output_model_1, sequence_output_model_2], dim=-1)) 138 | start_logits, end_logits = logits.split(1, dim=-1) 139 | start_logits = start_logits.squeeze(-1) 140 | end_logits = end_logits.squeeze(-1) 141 | 142 | if start_positions is not None and end_positions is not None: 143 | # If we are on multi-GPU, split add a dimension 144 | if len(start_positions.size()) > 1: 145 | start_positions = start_positions.squeeze(-1) 146 | if len(end_positions.size()) > 1: 147 | end_positions = end_positions.squeeze(-1) 148 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 149 | ignored_index = start_logits.size(1) 150 | start_positions.clamp_(0, ignored_index) 151 | end_positions.clamp_(0, ignored_index) 152 | 153 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 154 | start_loss = loss_fct(start_logits, start_positions) 155 | 
end_loss = loss_fct(end_logits, end_positions) 156 | total_loss = (start_loss + end_loss) / 2 157 | return total_loss 158 | else: 159 | return start_logits, end_logits 160 | 161 | 162 | class EnsembleBertForMultipleChoice(nn.Module): 163 | """BERT model for multiple choice tasks. 164 | This module is composed of the BERT model with a linear layer on top of 165 | the pooled output. 166 | 167 | Params: 168 | `config`: a BertConfig class instance with the configuration to build a new model. 169 | `num_choices`: the number of classes for the classifier. Default = 2. 170 | 171 | Inputs: 172 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 173 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 174 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 175 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 176 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 177 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 178 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 179 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 180 | input sequence length in the current batch. It's the mask that we typically use for attention when 181 | a batch has varying length sentences. 182 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 183 | with indices selected in [0, ..., num_choices]. 184 | 185 | Outputs: 186 | if `labels` is not `None`: 187 | Outputs the CrossEntropy classification loss of the output with the labels. 188 | if `labels` is `None`: 189 | Outputs the classification logits of shape [batch_size, num_labels]. 
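
    Note (specific to this ensemble wrapper): the pooled outputs of `model_1` and
    `model_2` are concatenated, so the classification head operates on
    2 * hidden_size features instead of hidden_size (see `__init__` and `forward`
    below).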
190 | 191 | Example usage: 192 | ```python 193 | # Already been converted into WordPiece token ids 194 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 195 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 196 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 197 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 198 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 199 | 200 | num_choices = 2 201 | 202 | model = BertForMultipleChoice(config, num_choices) 203 | logits = model(input_ids, token_type_ids, input_mask) 204 | ``` 205 | """ 206 | def __init__(self, model_1:BertForMultipleChoice, model_2:BertForMultipleChoice): 207 | super().__init__() 208 | self.num_choices = model_1.num_choices 209 | self.dropout = nn.Dropout(model_1.config.hidden_dropout_prob) 210 | self.classifier = nn.Linear(model_1.config.hidden_size*2, 1) 211 | self.model_1 = model_1 212 | self.model_2 = model_2 213 | self.config = self.model_1.config 214 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 215 | 216 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): 217 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 218 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 219 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 220 | pooled_output = torch.cat([ 221 | self.model_1.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)[1], 222 | self.model_2.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)[1] 223 | ], dim=-1) 224 | pooled_output = self.dropout(pooled_output) 225 | logits = self.classifier(pooled_output) 226 | reshaped_logits = logits.view(-1, self.num_choices) 227 | 228 | if labels is not None: 229 | loss_fct = CrossEntropyLoss() 230 | loss = loss_fct(reshaped_logits, labels) 231 | return loss 232 | else: 233 | return reshaped_logits 234 | -------------------------------------------------------------------------------- /bert_entity/preprocessing/create_integerized_wiki_training.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | import pickle 4 | import random 5 | from collections import Counter 6 | from itertools import cycle 7 | from typing import Dict 8 | 9 | from tqdm import tqdm 10 | 11 | from misc import unescape, create_chunks, create_overlapping_chunks 12 | from pipeline_job import PipelineJob 13 | from vocab import Vocab 14 | 15 | 16 | class CreateIntegerizedWikiTrainingData(PipelineJob): 17 | """ 18 | Create overlapping chunks of the Wikipedia articles. Outputs are stored as 19 | Python lists with integer ids. Configured by "create_integerized_training_instance_text_length" 20 | and "create_integerized_training_instance_text_overlap". 21 | 22 | Each worker creates his own shard, i.e., the number of shards is determined by 23 | "create_integerized_training_num_workers". 24 | 25 | Only save a training instance (a chunk of a Wikipedia article) if at least one entity in that 26 | chunk has not been seen more than "create_integerized_training_max_entity_per_shard_count" times. 27 | This downsamples highly frequent entities. 
Has to be set in relation to "create_integerized_training_num_workers" 28 | and "num_most_freq_entities". For the CONLL 2019 paper experiments the setting was 29 | 30 | create_integerized_training_max_entity_per_shard_count = 10 31 | create_integerized_training_num_workers = 40 32 | num_most_freq_entities = 500000 33 | """ 34 | def __init__(self, preprocess_jobs: Dict[str, PipelineJob], opts): 35 | super().__init__( 36 | requires=[ 37 | "data/indexes/redirects_en.ttl.bz2.dict", 38 | f"data/versions/{opts.data_version_name}/indexes/keyword_processor.pickle", 39 | f"data/versions/{opts.data_version_name}/indexes/popular_entity_counter_dict.pickle", 40 | f"data/versions/{opts.data_version_name}/indexes/mention_entity_counter_popular_entities.pickle", 41 | f"data/versions/{opts.data_version_name}/indexes/popular_entity_to_id_dict.pickle", 42 | f"data/versions/{opts.data_version_name}/indexes/mention_entity_discounted_probs.pickle", 43 | f"data/versions/{opts.data_version_name}/indexes/necessary_articles.pickle", 44 | ], 45 | provides=[f"data/versions/{opts.data_version_name}/wiki_training/integerized/{opts.wiki_lang_version}/"], 46 | preprocess_jobs=preprocess_jobs, 47 | opts=opts, 48 | ) 49 | 50 | def _run(self): 51 | 52 | with open("data/indexes/redirects_en.ttl.bz2.dict", "rb") as f: 53 | redirects_en = pickle.load(f) 54 | 55 | with open(f"data/versions/{self.opts.data_version_name}/indexes/keyword_processor.pickle", "rb",) as f: 56 | keyword_processor = pickle.load(f) 57 | 58 | with open( 59 | f"data/versions/{self.opts.data_version_name}/indexes/popular_entity_counter_dict.pickle", "rb", 60 | ) as f: 61 | most_popular_entity_counter_dict = pickle.load(f) 62 | 63 | with open( 64 | f"data/versions/{self.opts.data_version_name}/indexes/mention_entity_counter_popular_entities.pickle", "rb", 65 | ) as f: 66 | mention_entity_counter_popular_entities = pickle.load(f) 67 | 68 | with open(f"data/versions/{self.opts.data_version_name}/indexes/popular_entity_to_id_dict.pickle", "rb") as f: 69 | popular_entity_to_id_dict = pickle.load(f) 70 | 71 | with open( 72 | f"data/versions/{self.opts.data_version_name}/indexes/mention_entity_discounted_probs.pickle", "rb" 73 | ) as f: 74 | me_pop_p = pickle.load(f) 75 | 76 | with open(f"data/versions/{self.opts.data_version_name}/indexes/necessary_articles.pickle", "rb") as f: 77 | necessary_articles = pickle.load(f) 78 | 79 | out_dir = f"data/versions/{self.opts.data_version_name}/wiki_training/integerized/tmp/" 80 | 81 | vocab = Vocab() 82 | vocab.load(self.opts, popular_entity_to_id_dict=popular_entity_to_id_dict) 83 | 84 | in_queue = multiprocessing.Queue() 85 | out_queue = multiprocessing.Queue() 86 | 87 | workers = list() 88 | 89 | # 90 | # start the workers in individual processes 91 | # 92 | 93 | shards = cycle(range(self.opts.create_integerized_training_num_workers)) 94 | 95 | for id in range(self.opts.create_integerized_training_num_workers): 96 | worker = Worker( 97 | in_queue, 98 | out_queue, 99 | shard=next(shards), 100 | opts=self.opts, 101 | out_dir=out_dir, 102 | redirects_en=redirects_en, 103 | keyword_processor=keyword_processor, 104 | popular_entity_counter_dict=most_popular_entity_counter_dict, 105 | mention_entity_counter_popular_entities=mention_entity_counter_popular_entities, 106 | me_pop_p=me_pop_p, 107 | vocab=vocab, 108 | ) 109 | worker.start() 110 | workers.append(worker) 111 | 112 | self.log("Fill queue") 113 | 114 | submitted_jobs = 0 115 | 116 | for file_nr, extracted_wiki_file in enumerate(tqdm(necessary_articles)): 117 | 
submitted_jobs += 1 118 | in_queue.put((extracted_wiki_file)) 119 | 120 | self.log("Collect the output") 121 | 122 | joined_data_loc_list = list() 123 | 124 | for _ in tqdm(range(submitted_jobs)): 125 | (local_joined_data_loc_list), in_file_name = out_queue.get() 126 | if ( 127 | local_joined_data_loc_list is not None 128 | and len(local_joined_data_loc_list) > 0 129 | and local_joined_data_loc_list[0] is not None 130 | ): 131 | joined_data_loc_list.extend(local_joined_data_loc_list) 132 | 133 | with open(f"{out_dir}/data.loc", "w") as f_loc: 134 | f_loc.writelines(joined_data_loc_list) 135 | 136 | random.shuffle(joined_data_loc_list) 137 | 138 | with open(f"{out_dir}/train.loc", "w") as f_loc: 139 | f_loc.writelines( 140 | joined_data_loc_list[ 141 | : -( 142 | self.opts.create_integerized_training_valid_size 143 | + self.opts.create_integerized_training_test_size 144 | ) 145 | ] 146 | ) 147 | 148 | with open(f"{out_dir}/valid.loc", "w") as f_loc: 149 | f_loc.writelines( 150 | joined_data_loc_list[ 151 | -( 152 | self.opts.create_integerized_training_valid_size 153 | + self.opts.create_integerized_training_test_size 154 | ) : -self.opts.create_integerized_training_test_size 155 | ] 156 | ) 157 | 158 | with open(f"{out_dir}/test.loc", "w") as f_loc: 159 | f_loc.writelines(joined_data_loc_list[-self.opts.create_integerized_training_test_size :]) 160 | 161 | # put the None into the queue so the loop in the run() function of the worker stops 162 | for worker in workers: 163 | in_queue.put(None) 164 | out_queue.put(None) 165 | 166 | # terminate the process 167 | for worker in workers: 168 | worker.join() 169 | 170 | os.rename( 171 | f"data/versions/{self.opts.data_version_name}/wiki_training/integerized/tmp/", 172 | f"data/versions/{self.opts.data_version_name}/wiki_training/integerized/{self.opts.wiki_lang_version}/", 173 | ) 174 | 175 | 176 | class Worker(multiprocessing.Process): 177 | def __init__( 178 | self, 179 | in_queue, 180 | out_queue, 181 | shard, 182 | opts, 183 | out_dir, 184 | redirects_en, 185 | keyword_processor, 186 | popular_entity_counter_dict, 187 | mention_entity_counter_popular_entities, 188 | me_pop_p, 189 | vocab, 190 | ): 191 | super().__init__() 192 | self.in_queue = in_queue 193 | self.out_queue = out_queue 194 | self.shard = shard 195 | self.redirects_en = redirects_en 196 | self.keyword_processor = keyword_processor 197 | self.popular_entity_counter_dict = popular_entity_counter_dict 198 | self.mention_entity_counter_popular_entities = mention_entity_counter_popular_entities 199 | self.me_pop_p = me_pop_p 200 | self.vocab = vocab 201 | self.opts = opts 202 | out_file = f"{out_dir}/{self.shard}.dat" 203 | os.makedirs(out_dir, exist_ok=True) 204 | self.pickle_file = open(out_file, "wb") 205 | self.entity_counts = Counter() 206 | 207 | def run(self): 208 | # this loop will run until it receives None form the in_queue, if the queue is empty 209 | for next_item in iter(self.in_queue.get, None): 210 | file_name = next_item 211 | self.out_queue.put((self.extract_data(next_item), file_name)) 212 | self.pickle_file.close() 213 | 214 | def extract_data(self, file_name): 215 | 216 | instance_text_length = self.opts.create_integerized_training_instance_text_length 217 | instance_text_overlap = self.opts.create_integerized_training_instance_text_overlap 218 | max_entity_per_shard_count = self.opts.create_integerized_training_max_entity_per_shard_count 219 | 220 | local_joined_data_loc_list = list() 221 | 222 | if os.path.getsize(file_name) == 0: 223 | return None, None, None, 
None 224 | 225 | def map_func(line): 226 | items = line.strip().split("\t") 227 | if len(items) != 4: 228 | return "[UNK]", "O", "-" 229 | else: 230 | _, tok, ent, ment = items 231 | tok = unescape(tok) 232 | ent = unescape(ent) 233 | return tok if tok else "[UNK]", ent if ent else "O", ment if ment else "-" 234 | 235 | with open(file_name) as f: 236 | toks, ents, ments = zip(*map(map_func, f.readlines())) 237 | 238 | token_ids = list() 239 | for chunk in create_chunks(toks, 512): 240 | token_ids.extend(self.vocab.tokenizer.convert_tokens_to_ids(chunk)) 241 | 242 | mention_entity_ids, mention_entity_probs, mention_probs, is_entity = list(), list(), list(), list() 243 | 244 | # if df.isnull().values.any(): 245 | # print(file_name, '\n', df) 246 | # return None, None, None, None 247 | 248 | for i, (entity, mention) in enumerate(zip(ents, ments)): 249 | if ( 250 | entity == "O" 251 | or entity == "UNK" 252 | and ( 253 | mention not in self.mention_entity_counter_popular_entities 254 | or len(self.mention_entity_counter_popular_entities[mention]) == 0 255 | ) 256 | ): 257 | mention_entity_ids.append(([self.vocab.OUTSIDE_ID],)) # 258 | mention_entity_probs.append(([1.0],)) 259 | mention_probs.append((1.0,)) 260 | is_entity.append(1000000000) 261 | elif entity == "UNK" and mention in self.me_pop_p: 262 | this_mention_entity_ids, this_mention_entity_probs = zip(*list(self.me_pop_p[mention])) 263 | mention_entity_ids.append((list(map(self.vocab.tag2idx.__getitem__, this_mention_entity_ids)),)) 264 | mention_entity_probs.append((this_mention_entity_probs,)) 265 | mention_probs.append((1.0, 1.0,)) 266 | is_entity.append(1000000000) 267 | else: 268 | mention_entity_ids.append(([self.vocab.tag2idx[entity]],)) 269 | mention_entity_probs.append(([1.0],)) 270 | mention_probs.append((1.0,)) 271 | is_entity.append(self.vocab.tag2idx[entity]) 272 | 273 | for ( 274 | token_ids_chunk, 275 | mention_entity_ids_chunk, 276 | mention_entity_probs_chunk, 277 | mention_probs_chunk, 278 | is_entity_chunk, 279 | ) in zip( 280 | create_overlapping_chunks(token_ids, instance_text_length, instance_text_overlap), 281 | create_overlapping_chunks(mention_entity_ids, instance_text_length, instance_text_overlap), 282 | create_overlapping_chunks(mention_entity_probs, instance_text_length, instance_text_overlap), 283 | create_overlapping_chunks(mention_probs, instance_text_length, instance_text_overlap), 284 | create_overlapping_chunks(is_entity, instance_text_length, instance_text_overlap), 285 | ): 286 | # Only save a training instance (a chunk of a Wikipedia article) if at least one entity in that 287 | # chunk has not been seen more than "max_entity_per_shard_count" times. This downsamples highly 288 | # frequent entities. 
289 | if ( 290 | sum(map(lambda i: 1 if self.entity_counts[i] < (max_entity_per_shard_count + 1) else 0, is_entity_chunk,)) 291 | > 0 292 | ): 293 | self.entity_counts.update(is_entity_chunk) 294 | local_joined_data_loc_list.append( 295 | str("{}\t{}\n".format(self.shard, self.pickle_file.tell())) 296 | ) # remember row byte offset 297 | pickle.dump( 298 | (token_ids_chunk, mention_entity_ids_chunk, mention_entity_probs_chunk, mention_probs_chunk), 299 | self.pickle_file, 300 | ) # write new row 301 | self.pickle_file.flush() 302 | return local_joined_data_loc_list 303 | 304 | 305 | -------------------------------------------------------------------------------- /bert_entity/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from itertools import chain 4 | 5 | import numpy 6 | import torch 7 | import torch.nn as nn 8 | import tqdm 9 | from torch import optim 10 | from tqdm import trange 11 | 12 | from metrics import Metrics 13 | from data_loader_wiki import EDLDataset_collate_func 14 | from misc import running_mean, get_topk_ids_aggregated_from_seq_prediction, DummyOptimizer, LRSchedulers 15 | from pytorch_pretrained_bert import BertModel 16 | 17 | 18 | class Net(nn.Module): 19 | def __init__( 20 | self, args, vocab_size=None, 21 | ): 22 | super().__init__() 23 | if args.uncased: 24 | self.bert = BertModel.from_pretrained("bert-base-uncased") 25 | else: 26 | self.bert = BertModel.from_pretrained("bert-base-cased") 27 | 28 | self.top_rnns = args.top_rnns 29 | if args.top_rnns: 30 | self.rnn = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=768 // 2, batch_first=True) 31 | self.fc = None 32 | if args.project: 33 | self.fc = nn.Linear(768, args.entity_embedding_size) 34 | self.out = nn.Embedding(num_embeddings=vocab_size, embedding_dim=args.entity_embedding_size, sparse=args.sparse) 35 | # torch.nn.init.normal_(self.out, std=0.1) 36 | 37 | self.device = args.device 38 | self.out_device = args.out_device 39 | self.finetuning = args.finetuning == 0 40 | self.vocab_size = vocab_size 41 | 42 | def to(self, device, out_device): 43 | self.bert.to(device) 44 | if self.fc: 45 | self.fc.to(device) 46 | self.out.to(out_device) 47 | self.device = device 48 | self.out_device = out_device 49 | 50 | def forward(self, x, y=None, probs=None, enc=None): 51 | """ 52 | x: (N, T). int64 53 | y: (N, T). 
int64 54 | 55 | Returns 56 | enc: (N, T, VOCAB) 57 | """ 58 | if y is not None: 59 | y = y.to(self.out_device) 60 | if probs is not None: 61 | probs = probs.to(self.out_device) 62 | 63 | # fake_y = torch.Tensor(range(10)).long().to(self.device) 64 | 65 | if enc is None: 66 | x = x.to(self.device) 67 | if self.training: 68 | if self.finetuning: 69 | # print("->bert.train()") 70 | self.bert.train() 71 | encoded_layers, _ = self.bert(x) 72 | enc = encoded_layers[-1] 73 | else: 74 | self.bert.eval() 75 | with torch.no_grad(): 76 | encoded_layers, _ = self.bert(x) 77 | enc = encoded_layers[-1] 78 | else: 79 | encoded_layers, _ = self.bert(x) 80 | enc = encoded_layers[-1] 81 | 82 | if self.top_rnns: 83 | enc, _ = self.rnn(enc) 84 | 85 | if self.fc: 86 | enc = self.fc(enc) 87 | 88 | enc = enc.to(self.out_device) 89 | 90 | if y is not None: 91 | out = self.out(y) 92 | logits = enc.matmul(out.transpose(0, 1)) 93 | y_hat = logits.argmax(-1) 94 | return logits, y, y_hat, probs, out, enc 95 | else: 96 | with torch.no_grad(): 97 | out = self.out.weight 98 | logits = enc.matmul(out.transpose(0, 1)) 99 | y_hat = logits.argmax(-1) 100 | return logits, None, y_hat, None, None, enc 101 | 102 | @staticmethod 103 | def train_one_epoch( 104 | args, 105 | model, 106 | train_iter, 107 | optimizers, 108 | criterion, 109 | eval_iter, 110 | vocab, 111 | epoch, 112 | metrics=Metrics(), 113 | loss_aggr=None, 114 | ): 115 | labels_with_high_model_score = None 116 | 117 | with trange(len(train_iter)) as t: 118 | for iter, batch in enumerate(train_iter): 119 | 120 | model.to( 121 | args.device, args.out_device, 122 | ) 123 | model.train() 124 | 125 | batch_token_ids, label_ids, label_probs, eval_mask, _, _, orig_batch, loaded_batch = batch 126 | 127 | enc = None 128 | 129 | if ( 130 | args.collect_most_popular_labels_steps is not None 131 | and args.collect_most_popular_labels_steps > 0 132 | and iter > 0 133 | and iter % args.collect_most_popular_labels_steps == 0 134 | ): 135 | model.to(args.device, args.eval_device) 136 | with torch.no_grad(): 137 | logits_, _, _, _, _, enc = model( 138 | batch_token_ids, None, None, 139 | ) # logits: (N, T, VOCAB), y: (N, T) 140 | labels_with_high_model_score = get_topk_ids_aggregated_from_seq_prediction( 141 | logits_, topk_from_batch=args.label_size, topk_per_token=args.topk_neg_examples 142 | ) 143 | batch_token_ids, label_ids, label_probs, eval_mask, _, _, _, _ = EDLDataset_collate_func( 144 | args=args, 145 | labels_with_high_model_score=labels_with_high_model_score, 146 | batch=orig_batch, 147 | return_labels=True, 148 | vocab=vocab, 149 | is_training=False, 150 | loaded_batch=loaded_batch, 151 | ) 152 | 153 | # if args.label_size is not None: 154 | logits, y, y_hat, label_probs, sparse_params, _ = model( 155 | batch_token_ids, label_ids, label_probs, enc=enc 156 | ) # logits: (N, T, VOCAB), y: (N, T) 157 | logits = logits.view(-1) # (N*T, VOCAB) 158 | label_probs = label_probs.view(-1) # (N*T,) 159 | 160 | loss = criterion(logits, label_probs) 161 | 162 | loss.backward() 163 | 164 | if (iter + 1) % args.accumulate_batch_gradients == 0: 165 | for optimizer in optimizers: 166 | optimizer.step() 167 | optimizer.zero_grad() 168 | 169 | if iter == 0: 170 | logging.debug(f"Sanity check") 171 | logging.debug("x:", batch_token_ids.cpu().numpy()[0]) 172 | logging.debug("tokens:", vocab.tokenizer.convert_ids_to_tokens(batch_token_ids.cpu().numpy()[0])) 173 | logging.debug("y:", label_probs.cpu().numpy()[0]) 174 | 175 | loss_aggr = running_mean(loss.detach().item(), loss_aggr) 176 | 177 
| if iter > 0 and iter % args.checkpoint_eval_steps == 0: 178 | metrics = Net.evaluate( 179 | args=args, 180 | model=model, 181 | iterator=eval_iter, 182 | optimizers=optimizers, 183 | step=iter, 184 | epoch=epoch, 185 | save_checkpoint=iter % args.checkpoint_save_steps == 0, 186 | sampled_evaluation=False, 187 | metrics=metrics, 188 | vocab=vocab, 189 | ) 190 | 191 | t.set_postfix( 192 | loss=loss_aggr, 193 | nr_labels=len(label_ids), 194 | aggr_labels=len(labels_with_high_model_score) if labels_with_high_model_score else 0, 195 | last_eval=metrics.report(filter={"f1", "num_proposed", "epoch", "step"}), 196 | ) 197 | t.update() 198 | 199 | for optimizer in optimizers: 200 | optimizer.step() 201 | optimizer.zero_grad() 202 | 203 | return metrics 204 | 205 | @staticmethod 206 | def evaluate( 207 | args, 208 | model, 209 | iterator, 210 | vocab, 211 | optimizers, 212 | step=0, 213 | epoch=0, 214 | save_checkpoint=True, 215 | save_predictions=True, 216 | save_csv=True, 217 | sampled_evaluation=False, 218 | metrics=Metrics(), 219 | ): 220 | 221 | print() 222 | logging.info(f"Start evaluation on split {'test' if args.eval_on_test_only else 'valid'}") 223 | 224 | model.eval() 225 | model.to(args.device, args.eval_device) 226 | 227 | all_words, all_tags, all_y, all_y_hat, all_predicted, all_token_ids = [], [], [], [], [], [] 228 | with torch.no_grad(): 229 | for iter, batch in enumerate(tqdm.tqdm(iterator)): 230 | ( 231 | batch_token_ids, 232 | label_ids, 233 | label_probs, 234 | eval_mask, 235 | label_id_to_entity_id_dict, 236 | batch_entity_ids, 237 | orig_batch, 238 | _, 239 | ) = batch 240 | 241 | logits, y, y_hat, probs, _, _ = model(batch_token_ids, None, None) # logits: (N, T, VOCAB), y: (N, T) 242 | 243 | tags = list() 244 | predtags = list() 245 | y_resolved_list = list() 246 | y_hat_resolved_list = list() 247 | token_list = list() 248 | 249 | chunk_len = args.create_integerized_training_instance_text_length 250 | chunk_overlap = args.create_integerized_training_instance_text_overlap 251 | 252 | for batch_id, seq in enumerate(label_probs.max(-1)[1]): 253 | for tok_id, label_id in enumerate(seq[chunk_overlap : -chunk_overlap]): 254 | y_resolved = ( 255 | vocab.PAD_ID 256 | if eval_mask[batch_id][tok_id + chunk_overlap] == 0 257 | else label_ids[label_id].item() 258 | ) 259 | y_resolved_list.append(y_resolved) 260 | tags.append(vocab.idx2tag[y_resolved]) 261 | if sampled_evaluation: 262 | y_hat_resolved = ( 263 | vocab.PAD_ID 264 | if eval_mask[batch_id][tok_id + chunk_overlap] == 0 265 | else label_ids[y_hat[batch_id][tok_id + chunk_overlap]].item() 266 | ) 267 | else: 268 | y_hat_resolved = y_hat[batch_id][tok_id + chunk_overlap].item() 269 | y_hat_resolved_list.append(y_hat_resolved) 270 | predtags.append(vocab.idx2tag[y_hat_resolved]) 271 | token_list.append(batch_token_ids[batch_id][tok_id + chunk_overlap].item()) 272 | 273 | all_y.append(y_resolved_list) 274 | all_y_hat.append(y_hat_resolved_list) 275 | all_tags.append(tags) 276 | all_predicted.append(predtags) 277 | all_words.append(vocab.tokenizer.convert_ids_to_tokens(token_list)) 278 | all_token_ids.append(token_list) 279 | 280 | ## calc metric 281 | y_true = numpy.array(list(chain(*all_y))) 282 | y_pred = numpy.array(list(chain(*all_y_hat))) 283 | all_token_ids = numpy.array(list(chain(*all_token_ids))) 284 | 285 | num_proposed = len(y_pred[(vocab.OUTSIDE_ID > y_pred) & (all_token_ids > 0)]) 286 | num_correct = (((y_true == y_pred) & (vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0))).astype(numpy.int).sum() 287 | num_gold = 
len(y_true[(vocab.OUTSIDE_ID > y_true) & (all_token_ids > 0)]) 288 | 289 | new_metrics = Metrics( 290 | epoch=epoch, step=step, num_correct=num_correct, num_proposed=num_proposed, num_gold=num_gold, 291 | ) 292 | 293 | if save_predictions: 294 | final = args.logdir + "/%s.P%.2f_R%.2f_F%.2f" % ( 295 | "{}-{}".format(str(epoch), str(step)), 296 | new_metrics.precision, 297 | new_metrics.recall, 298 | new_metrics.f1, 299 | ) 300 | with open(final, "w") as fout: 301 | 302 | for words, tags, y_hat, preds in zip(all_words, all_tags, all_y_hat, all_predicted): 303 | assert len(preds) == len(words) == len(tags) 304 | for w, t, p in zip(words, tags, preds): 305 | fout.write(f"{w}\t{t}\t{p}\n") 306 | fout.write("\n") 307 | 308 | fout.write(f"num_proposed:{num_proposed}\n") 309 | fout.write(f"num_correct:{num_correct}\n") 310 | fout.write(f"num_gold:{num_gold}\n") 311 | fout.write(f"precision={new_metrics.precision}\n") 312 | fout.write(f"recall={new_metrics.recall}\n") 313 | fout.write(f"f1={new_metrics.f1}\n") 314 | 315 | if not args.dont_save_checkpoints: 316 | 317 | if save_checkpoint and metrics.was_improved(new_metrics): 318 | config = { 319 | "args": args, 320 | "optimizer_dense": optimizers[0].state_dict(), 321 | "optimizer_sparse": optimizers[1].state_dict(), 322 | "model": model.state_dict(), 323 | "epoch": epoch, 324 | "step": step, 325 | "performance": new_metrics.dict(), 326 | } 327 | fname = os.path.join(args.logdir, "{}-{}".format(str(epoch), str(step))) 328 | torch.save(config, f"{fname}.pt") 329 | fname = os.path.join(args.logdir, new_metrics.get_best_checkpoint_filename()) 330 | torch.save(config, f"{fname}.pt") 331 | logging.info(f"weights were saved to {fname}.pt") 332 | 333 | if save_csv: 334 | new_metrics.to_csv(epoch=epoch, step=step, args=args) 335 | 336 | if metrics.was_improved(new_metrics): 337 | metrics.update(new_metrics) 338 | 339 | logging.info("Finished evaluation") 340 | 341 | return metrics 342 | 343 | def get_optimizers(self, args, checkpoint): 344 | 345 | optimizers = list() 346 | 347 | if args.encoder_lr > 0: 348 | optimizer_encoder = optim.Adam( 349 | list(self.bert.parameters()) + list(self.fc.parameters() if args.project else list()), 350 | lr=args.encoder_lr, 351 | ) 352 | if args.resume_from_checkpoint is not None: 353 | optimizer_encoder.load_state_dict(checkpoint["optimizer_dense"]) 354 | optimizer_encoder.param_groups[0]["lr"] = args.encoder_lr 355 | optimizer_encoder.param_groups[0]["weight_decay"] = args.encoder_weight_decay 356 | optimizers.append(optimizer_encoder) 357 | else: 358 | optimizers.append(DummyOptimizer(self.out.parameters(), defaults={})) 359 | 360 | if args.decoder_lr > 0: 361 | if args.sparse: 362 | optimizer_decoder = optim.SparseAdam(self.out.parameters(), lr=args.decoder_lr) 363 | else: 364 | optimizer_decoder = optim.Adam(self.out.parameters(), lr=args.decoder_lr) 365 | if args.resume_from_checkpoint is not None: 366 | optimizer_decoder.load_state_dict(checkpoint["optimizer_sparse"]) 367 | if "weight_decay" not in optimizer_decoder.param_groups[0]: 368 | optimizer_decoder.param_groups[0]["weight_decay"] = 0 369 | optimizer_decoder.param_groups[0]["lr"] = args.decoder_lr 370 | if not args.sparse: 371 | optimizer_decoder.param_groups[0]["weight_decay"] = args.decoder_weight_decay 372 | optimizers.append(optimizer_decoder) 373 | else: 374 | optimizers.append(DummyOptimizer(self.out.parameters(), defaults={})) 375 | 376 | lr_schedulers = [ 377 | getattr(LRSchedulers, lr_scheduler)(optimizer=optimizer, **lr_scheduler_config) 378 | for 
optimizer, (lr_scheduler, lr_scheduler_config) in zip( 379 | optimizers, 380 | [ 381 | (args.encoder_lr_scheduler, args.encoder_lr_scheduler_config), 382 | (args.decoder_lr_scheduler, args.decoder_lr_scheduler_config), 383 | ], 384 | ) 385 | if lr_scheduler is not None # and not isinstance(optimizer, DummyOptimizer) 386 | ] 387 | 388 | return tuple(optimizers), tuple(lr_schedulers) 389 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # "Investigating Entity Knowledge in BERT with Simple Neural End-To-End Entity Linking" 2 | 3 | This repository contains the code for the CONLL 2019 paper [**"Investigating Entity Knowledge in BERT with Simple Neural End-To-End Entity Linking"**](https://arxiv.org/abs/2003.05473). The code is provided as a documentation for the paper and also for follow-up research. 4 | 5 | #

![Bert-Entity](docs/Bert-Entity.png)

6 | 7 | The content of this page covers the following topics: 8 | 9 | 1. [Quick start](#quick-start) 10 | 2. [Preparation and Installation](#preparation-and-installation) 11 | 3. [Preprocessing of Wikipedia and the AIDA-CONLL entity linking benchmark into a sequence tagging format](#preprocessing-data) 12 | 4. [Finetuning/Training a BERT-Entity model on Wikipedia](#training) 13 | 5. [Finetuning a BERT-Entity model on the AIDA-CONLL entity linking benchmark](#training) 14 | 6. [Using a BERT-Entity model in a downstream task](#evalation-on-downstream-tasks) 15 | 7. [Issues and possible improvements](#issues-and-possible-improvements) 16 | 17 | ## Quick start 18 | 19 | Here are all the steps until the finetuning and evaluation on the AIDA-CoNLL benchmark in a prototyping setting (i.e. a smaller model pretrained on reduced Wikipedia data): 20 | 21 | - The project is installed as follows: 22 | 23 | ``` 24 | git clone https://github.com/samuelbroscheit/entity_knowledge_in_bert.git 25 | cd entity_knowledge_in_bert 26 | pip install -r requirements.txt 27 | git submodule update --init 28 | ``` 29 | 30 | - Add paths to environment 31 | 32 | ``` 33 | source setup_paths 34 | ``` 35 | 36 | - Create directory 37 | 38 | ``` 39 | mkdir -p data/benchmarks/ 40 | ``` 41 | 42 | The AIDA-CoNLL benchmark file should be located under `data/benchmarks/aida-yago2-dataset/AIDA-YAGO2-dataset.tsv`. Get it from https://www.mpi-inf.mpg.de/departments/databases-and-information-systems/research/yago-naga/aida/downloads/ . If you get it from somewhere else, then please make sure that you have the correct file with 6 columns: Token, Mention, Yago Name, Wiki Name, Wiki Id, Freebase Id. 43 | 44 | - Run preprocessing 45 | 46 | ``` 47 | python bert_entity/preprocess_all.py -c config/dummy__preprocess.yaml 48 | ``` 49 | 50 | - Run pretraining on Wikipedia 51 | 52 | ``` 53 | python bert_entity/train.py -c config/dummy__train_on_wiki.yaml 54 | ``` 55 | 56 | - Finetune on AIDA-CoNLL benchmark 57 | 58 | ``` 59 | python bert_entity/train.py -c config/dummy__train_on_aida_conll.yaml 60 | ``` 61 | 62 | - Evaluate the best model on the AIDA-CoNLL benchmark 63 | 64 | ``` 65 | python bert_entity/train.py -c config/dummy__train_on_aida_conll.yaml --eval_on_test_only True --resume_from_checkpoint data/checkpoints/dummy_aidaconll_00001/best_f1-0.pt 66 | ``` 67 | 68 | 69 | 70 | ## Preparation and Installation 71 | 72 | For downloading and processing the full data and for storing checkpoints you should have at least 500GB of free space in the respective filesystem. If you just want to prototype there are also prepared configurations that need less space (~100GB). 73 | 74 | ### Installation 75 | 76 | To install run the following commands: 77 | 78 | ``` 79 | git clone https://github.com/samuelbroscheit/entity_knowledge_in_bert.git 80 | cd entity_knowledge_in_bert 81 | pip install -r requirements.txt 82 | git submodule update --init 83 | ``` 84 | ### Setup Paths 85 | 86 | **Every time** you run the code you have to setup up the paths for python with 87 | 88 | ``` 89 | source setup_paths 90 | ``` 91 | 92 | ### Prepare AIDA CoNLL-YAGO2 benchmark data 93 | 94 | First create the directory 95 | 96 | ``` 97 | mkdir -p data/benchmarks/ 98 | ``` 99 | 100 | and then retrieve the AIDA CoNLL-YAGO2 benchmark from https://www.mpi-inf.mpg.de/departments/databases-and-information-systems/research/yago-naga/aida/downloads/ (the benchmark is referred to as AIDA Conll throughout the code). 
The resulting file should be located under `data/benchmarks/aida-yago2-dataset/AIDA-YAGO2-dataset.tsv`. If you get it from somewhere else, please make sure that you have the correct file with 6 columns: Token, Mention, Yago Name, Wiki Name, Wiki Id, Freebase Id. 101 | 102 | 103 | ## Preprocessing data 104 | 105 | 1. [Run preprocessing](#run-preprocessing) 106 | 2. [Prepared configurations](#prepared-configurations) 107 | 3. [Available options](#available-options) 108 | 4. [Preprocessing tasks](#preprocessing-tasks) 109 | 110 | ### Run preprocessing 111 | 112 | The preprocessing pipeline will take care of all downloads and processing of the data. You run the preprocessing with: 113 | 114 | ``` 115 | python bert_entity/preprocess_all.py -c PREPROC_CONFIG_FILE_NAME 116 | ``` 117 | 118 | PREPROC_CONFIG_FILE_NAME is a yaml file, but all options can also be given on the command line. 119 | 120 | ### Prepared configurations 121 | 122 | In the config folder you will find two configurations: 123 | 124 | - [config/dummy__preprocess.yaml](config/dummy__preprocess.yaml) is a setting for prototyping and testing the preprocessing pipeline and for prototyping the BERT-Entity training. 125 | 126 | - [config/conll2019__preprocess.yaml](config/conll2019__preprocess.yaml) is the setting "Setting 2" that was used in the the CoNLL 2019 paper for the BERT-Entity model with ~500K entities. 127 | 128 | 129 | ### Available options 130 | 131 | The PREPROC_CONFIG_FILE_NAME supports the following configurations: 132 | 133 | 134 | ``` 135 | 136 | --debug Print debug messages 137 | 138 | # General settings 139 | 140 | --wiki_lang_version Wiki language version, e.g. enwiki 141 | 142 | --data_version_name Data identifier/version; e.g. if you experiment with different 143 | preprocessing options you should use different names here to create 144 | new directories. 145 | 146 | --num_most_freq_entities Number of most frequent entities that should be collected from the 147 | entity set collected from the Wikipedia dump 148 | 149 | --add_missing_conll_entities Whether entities for the AIDA CONLL benchmark that are missing from 150 | the most frequent entities collected from the Wikipedia dump should 151 | be added to the entity vocabulary 152 | 153 | --uncased Should the input token dictionary be uncased 154 | 155 | 156 | # DownloadWikiDump 157 | 158 | --download_data_only_dummy Only download one wiki file 159 | 160 | --download_2017_enwiki Download the enwiki 2017 dump to reproduce the experiments for the 161 | CONLL 2019 paper 162 | 163 | 164 | # CollectMentionEntityCounts 165 | 166 | --collect_mention_entities_num_workers 167 | Number of worker for parallel processing of the Wikipedia dump to 168 | collect mention entities. 169 | 170 | 171 | # WikiExtractor 172 | 173 | --wikiextractor_num_workers Number of worker for parallel processing of the Wikipedia dump 174 | 175 | 176 | # CreateWikiTrainingData 177 | 178 | --create_training_data_num_workers 179 | Number of worker for parallel processing to create the sequence 180 | tagging training data 181 | 182 | --create_training_data_num_entities_in_necessary_articles 183 | Threshold on the #entities in necessary articles (i.e. 
articles that 184 | contain entities in the most frequent entity vocabulary) that should 185 | be considered for training data 186 | 187 | 188 | # CreateIntegerizedWikiTrainingData 189 | 190 | --create_integerized_training_num_workers 191 | Number of workers for parallel processing to create the integerized 192 | sequence tagging training data, also determines the number of created 193 | shards. 194 | 195 | --create_integerized_training_instance_text_length 196 | Text length of the integerized training instances 197 | 198 | --create_integerized_training_instance_text_overlap 199 | Overlap between integerized training instances 200 | 201 | --create_integerized_training_max_entity_per_shard_count 202 | Max count per entity in each shard. A training instance is only kept if it contains at least one entity that has not yet exceeded this count in its shard. 203 | 204 | --create_integerized_training_valid_size 205 | Sample size for validation data. 206 | 207 | --create_integerized_training_test_size 208 | Sample size for test data. 209 | 210 | 211 | ``` 212 | 213 | ### Preprocessing tasks 214 | 215 | Preprocessing consists of the following tasks (the respective code is in `bert_entity/preprocessing`): 216 | 217 | - CreateRedirects 218 | - Create a dictionary containing redirects for Wikipedia page names [(*)](#footnote). The redirects are used for the Wikipedia mention extractions as well as for the AIDA-CONLL benchmark. 219 | 220 | ``` 221 | "AccessibleComputing": "Computer_accessibility" 222 | ``` 223 | 224 | - CreateResolveToWikiNameDicts 225 | - Create dictionaries that map Freebase ids and Wikipedia page ids to Wikipedia page names [(*)](#footnote). These mappings are used to detect entity annotations in the AIDA-CONLL benchmark that have become incompatible with newer Wikipedia versions. 226 | 227 | ``` 228 | "/m/01009ly3": "Dysdera_ancora" 229 | ``` 230 | 231 | ``` 232 | "10": "Computer_accessibility" 233 | ``` 234 | 235 | - CreateDisambiguationDict 236 | - Create a dictionary containing disambiguations for Wikipedia page names [(*)](#footnote). The disambiguations are used to detect entity annotations in 237 | the AIDA-CONLL benchmark that have become incompatible with newer Wikipedia 238 | versions. 239 | 240 | ``` 241 | "Alien": ["Alien_(law)", "Alien_(software)", ... ] 242 | ``` 243 | 244 | - DownloadWikiDump 245 | - Download the current Wikipedia dump. Set `download_data_only_dummy` to True to download just one file for a dummy / prototyping version, otherwise all files are downloaded. Set `download_2017_enwiki` to True if not the latest dump but a 2017 dump should be retrieved, as in the paper. 246 | 247 | - Wikiextractor 248 | - Run Wikiextractor on the Wikipedia dump and extract all the mentions from it. 249 | 250 | - CollectMentionEntityCounts 251 | - Collect mention entity counts from the Wikiextractor files. 252 | 253 | - PostProcessMentionEntityCounts 254 | - Create entity indexes that will later be used in the creation of the Wikipedia training data. First, based on the configuration key `num_most_freq_entities` the **top k most popular entities** are selected. Based on those, other mappings are created to only 255 | contain counts and priors concerning the top k popular entities. Later the top k popular entities will also restrict the training data to only contain instances that contain popular entities. 256 | Also, if `add_missing_conll_entities` is set, the entity ids necessary for the AIDA-CONLL benchmark that are missing in the top k popular entities are added. This is to ensure that the evaluation measures are comparable to prior work.
257 | 258 | - CreateAIDACONLL 259 | - Read in the AIDA-CONLL benchmark dataset and merge it with the NER annotations. Requires you to provide `data/benchmarks/aida-yago2-dataset/AIDA-YAGO2-dataset.tsv`. Please make sure that you have the correct file with 6 columns: Token, Mention, Yago Name, Wiki Name, Wiki Id, Freebase Id. 260 | 261 | - CreateKeywordProcessor 262 | - Create a trie-based matcher to detect possible mentions of our known entities. We use this later to add automatic annotations to the text. However, as we do not know the true entity for those mentions, they will have multiple labels, i.e. all entities from the p(e|m) prior. 263 | 264 | - CreateWikiTrainingData 265 | - Create sequence labelling data. Tokenization is done with BertTokenizer. Tokens either have a label when they have an associated Wikipedia link, or when they are in spans detected by the keyword matcher. Subsequently, we count the mentions in this data and create a discounted prior p(e|m) and the set of necessary Wikipedia articles, i.e. all the articles that contain links to the top k popular entities. 266 | 267 | - CreateIntegerizedWikiTrainingData 268 | - Create overlapping chunks of the Wikipedia articles. Outputs are stored as Python lists with integer ids. Configured by `create_integerized_training_instance_text_length` 269 | and `create_integerized_training_instance_text_overlap`. 270 | 271 | Each worker creates its own shard, i.e., the number of shards is determined by `create_integerized_training_num_workers`. 272 | 273 | Only saves a training instance (a chunk of a Wikipedia article) if at least one entity in that chunk has not been seen more than `create_integerized_training_max_entity_per_shard_count` times. This downsamples highly frequent entities. Has to be set in relation to `create_integerized_training_num_workers` 274 | and `num_most_freq_entities`. For the CONLL 2019 paper experiments the setting was 275 | 276 | ``` 277 | create_integerized_training_max_entity_per_shard_count = 10 278 | create_integerized_training_num_workers = 50 279 | num_most_freq_entities = 500000 280 | ``` 281 | 282 | - CreateIntegerizedCONLLTrainingData 283 | - Create overlapping chunks of the benchmark articles. Outputs are stored as Python lists with integer ids. Configured by `create_integerized_training_instance_text_length` 284 | and `create_integerized_training_instance_text_overlap`. 285 | 286 | ###### Footnote 287 | _Here we use an already extracted mapping provided by DBPedia that was created from a 2016 dump. Please note that in the experiments for the paper a Wikipedia dump from 2017 was used. The DBPedia dictionaries might not be adequate for the latest wiki dumps._ 288 | 289 | 290 | 291 | ## Training 292 | 293 | Once you have all the preprocessing done, you can run the training on Wikipedia to learn a BERT-Entity model. After you have learned a BERT-Entity model on Wikipedia, you can resume from it to finetune it on the AIDA-CONLL benchmark. 294 | 295 | 1. [Run training](#run-training) 296 | 2. [Prepared configurations](#prepared-configurations) 297 | 3. [Available options](#available-options) 298 | 299 | ### Run training 300 | 301 | Run the training with: 302 | 303 | ``` 304 | python bert_entity/train.py -c TRAIN_CONFIG_FILE_NAME 305 | ``` 306 | 307 | TRAIN_CONFIG_FILE_NAME is a yaml file, but all options can also be given on the command line.
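For example, since every configuration key is also accepted as a command-line flag, individual values from the yaml file can be overridden ad hoc. The flags and values below (`--batch_size`, `--accumulate_batch_gradients`, `--n_epochs`) are only an illustrative pick from the option list further down, not a recommended setting:

```
python bert_entity/train.py -c config/dummy__train_on_wiki.yaml --batch_size 4 --accumulate_batch_gradients 8 --n_epochs 1
```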
308 | 309 | 310 | ### Run evaluation 311 | 312 | Run evaluation with: 313 | 314 | ``` 315 | python bert_entity/train.py -c TRAIN_CONFIG_FILE_NAME --eval_on_test_only True --resume_from_checkpoint LOGDIR/best_f1-0.pt 316 | ``` 317 | 318 | LOGDIR is the path that was set in TRAIN_CONFIG_FILE_NAME with the key `logdir`. 319 | 320 | ### Prepared configurations 321 | 322 | In the config folder you will find the following configurations: 323 | 324 | - [config/dummy__train_on_wiki.yaml](config/dummy__train_on_wiki.yaml) is a setting for prototyping the BERT-Entity training on Wikipedia. 325 | 326 | - [config/dummy__train_on_aida_conll.yaml](config/dummy__train_on_aida_conll.yaml) is a setting to finetune the best found model from `dummy__train_on_wiki.yaml` on the aida-yago2 dataset. 327 | 328 | - [config/conll2019__train_on_wiki.yaml](config/conll2019__train_on_wiki.yaml) is a setting for reproducing the BERT-Entity training on Wikipedia from the CoNLL 2019 paper. 329 | 330 | - [config/conll2019__train_on_aida_conll.yaml](config/conll2019__train_on_aida_conll.yaml) is a setting to finetune the best found model from `conll2019__train_on_wiki.yaml` on the aida-yago2 dataset to reproduce the CoNLL 2019 paper. 331 | 332 | 333 | 334 | ### Available options 335 | 336 | The TRAIN_CONFIG_FILE_NAME supports the following configurations: 337 | 338 | ``` 339 | --debug DEBUG 340 | 341 | --logdir LOGDIR; the output dir where checkpoints and logfiles are stored 342 | 343 | --data_workers number of data workers to prepare training instances 344 | 345 | --data_version_name use the same identifier that was used for the same key in 346 | preprocessing 347 | 348 | 349 | --device GPU device used for training 350 | --eval_device GPU device used for evaluation 351 | --out_device GPU device used to collect the most probable entities in the batch 352 | 353 | --dataset 'EDLDataset' for training on Wikipedia or 'CONLLEDLDataset' 354 | for training on AIDA-CONLL 355 | 356 | --model Either 'Net' for training on Wikipedia or 'ConllNet' for 357 | training on AIDA-CONLL 358 | 359 | --eval_on_test_only only run evaluation on test (requires --resume_from_checkpoint) 360 | 361 | --batch_size batch size in training 362 | --eval_batch_size batch size in evaluation 363 | --accumulate_batch_gradients accumulate gradients over this many batches 364 | 365 | --n_epochs max number of epochs 366 | --finetuning start finetuning after this many epochs 367 | --checkpoint_eval_steps evaluate every this many steps 368 | --checkpoint_save_steps save a checkpoint every this many steps 369 | --dont_save_checkpoints do not save any checkpoints 370 | 371 | --sparse Use a sparse embedding layer 372 | 373 | --encoder_lr encoder learning rate 374 | --decoder_lr decoder learning rate 375 | --encoder_weight_decay encoder weight decay 376 | --decoder_weight_decay decoder weight decay 377 | --bert_dropout BERT_DROPOUT 378 | 379 | --label_size nr of entities considered in the label vector for each instance 380 | --topk_neg_examples TOPK_NEG_EXAMPLES 381 | --entity_embedding_size entity_embedding_size 382 | --project project entity embedding 383 | 384 | --resume_from_checkpoint path of checkpoint to resume from 385 | --resume_reset_epoch reset the epochs from the checkpoint for resuming (e.g.
386 | training on AIDA-CONLL) 387 | --resume_optimizer_from_checkpoint resume optimizer from checkpoint 388 | 389 | --eval_before_training evaluate once before training 390 | --data_path_conll path to the conll file that was created in data_version_name 391 | --exclude_parameter_names_regex regex to exclude params from training, i.e. freeze them 392 | ``` 393 | 394 | 395 | ## Evalation on downstream tasks 396 | 397 | See the files in `downstream_experiments`. Documentation is a TODO. 398 | 399 | ## Issues and possible improvements 400 | 401 | - The code is currently poorly documented and not always nice to read. 402 | - Currently all checkpoints are kept, which requires a lot of disk space; it should be configurable to only keep the K most recent checkpoints. 403 | - Training is slow because of accumulating the most probable entities per batch. This could be sped up with an adaptive-softmax-like change, i.e. first decide whether a token is an entity at all (see the sketch after this list). 404 | - Resuming currently only works on the epoch level. To enable resuming in between epochs, the shuffled indexes for an epoch would have to be stored in the LOGDIR or checkpoint. 405 | - When AIDA CoNLL entities are missing from the top k popular entities and are added to the vocabulary, we should make sure that confounders with high BPE token overlap are added as well. 406 | - The configuration to create the integerized Wiki training data is hard to grasp, i.e. which training instances are included depends on `create_integerized_training_max_entity_per_shard_count`, `create_integerized_training_num_workers` and `num_most_freq_entities`. Their influence is difficult to describe and should be made more straightforward. 407 | - It would be much better to produce an annotated document and evaluate it with https://github.com/wikilinks/neleval, which is widely used and has all the necessary metrics implemented. 408 | - Investigate shrinking the model size and improving prediction speed.
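The following is a minimal sketch of the adaptive-softmax-like idea mentioned in the list above. It is not code from this repository; the class, parameter names and dimensions are made up for illustration. A cheap binary gate first decides whether a token is an entity mention at all, and the expensive scoring against the full entity vocabulary is only computed for the tokens that pass the gate:

```
import torch
import torch.nn as nn


class TwoStageEntityHead(nn.Module):
    """Illustrative two-stage head: a binary is-entity gate followed by the full entity scoring."""

    def __init__(self, hidden_size: int, num_entities: int):
        super().__init__()
        # cheap decision: is this token part of an entity mention at all?
        self.is_entity_gate = nn.Linear(hidden_size, 2)
        # expensive decision: which entity from the large vocabulary is it?
        self.entity_scores = nn.Linear(hidden_size, num_entities)

    def forward(self, token_encodings):
        # token_encodings: (batch, seq_len, hidden_size), e.g. the last BERT layer
        gate_logits = self.is_entity_gate(token_encodings)
        is_mention = gate_logits.argmax(-1).bool()  # (batch, seq_len)

        entity_logits = torch.zeros(
            *token_encodings.shape[:2],
            self.entity_scores.out_features,
            device=token_encodings.device,
            dtype=token_encodings.dtype,
        )
        # only score the (typically few) tokens that the gate flags as mentions
        if is_mention.any():
            entity_logits[is_mention] = self.entity_scores(token_encodings[is_mention])
        return gate_logits, entity_logits
```

This only illustrates where the saving at prediction time would come from; training such a gate would additionally need binary is-entity labels, which could presumably be derived from the existing distinction between the `O` label and entity labels in the training data.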
409 | 410 | ## Citation 411 | 412 | if you find this code useful for your research please cite 413 | 414 | ``` 415 | @inproceedings{broscheit-2019-investigating, 416 | title = "Investigating Entity Knowledge in {BERT} with Simple Neural End-To-End Entity Linking", 417 | author = "Broscheit, Samuel", 418 | booktitle = "Proceedings of the 23rd Conference on Computational Natural Language Learning (CoNLL)", 419 | month = nov, 420 | year = "2019", 421 | address = "Hong Kong, China", 422 | publisher = "Association for Computational Linguistics", 423 | url = "https://www.aclweb.org/anthology/K19-1063", 424 | doi = "10.18653/v1/K19-1063", 425 | pages = "677--685", 426 | } 427 | ``` 428 | -------------------------------------------------------------------------------- /downstream_experiments/fairseq_patch_01.patch: -------------------------------------------------------------------------------- 1 | Index: fairseq/models/transformer.py 2 | IDEA additional info: 3 | Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP 4 | <+>UTF-8 5 | =================================================================== 6 | --- fairseq/models/transformer.py (revision ec6f8ef99a8c6942133e01a610def197e1d6d9dd) 7 | +++ fairseq/models/transformer.py (revision 09236665549d8a98aea0a367d70bf3c902950ee7) 8 | @@ -14,6 +14,8 @@ 9 | from fairseq import options 10 | from fairseq import utils 11 | 12 | +from pytorch_pretrained_bert import BertModel 13 | + 14 | from fairseq.modules import ( 15 | AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention, 16 | SinusoidalPositionalEmbedding 17 | @@ -56,6 +58,8 @@ 18 | help='dropout probability for attention weights') 19 | parser.add_argument('--relu-dropout', type=float, metavar='D', 20 | help='dropout probability after ReLU in FFN') 21 | + parser.add_argument('--use-bert-encoder', type=str, metavar='STR', 22 | + help='use bert encoder') 23 | parser.add_argument('--encoder-embed-path', type=str, metavar='STR', 24 | help='path to pre-trained encoder embedding') 25 | parser.add_argument('--encoder-embed-dim', type=int, metavar='N', 26 | @@ -143,8 +147,10 @@ 27 | decoder_embed_tokens = build_embedding( 28 | tgt_dict, args.decoder_embed_dim, args.decoder_embed_path 29 | ) 30 | - 31 | - encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens) 32 | + if args.use_bert_encoder: 33 | + encoder = BertTransformerEncoder(args) 34 | + else: 35 | + encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens) 36 | decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens) 37 | return TransformerModel(encoder, decoder) 38 | 39 | @@ -287,7 +293,7 @@ 40 | 41 | self.layers = nn.ModuleList([]) 42 | self.layers.extend([ 43 | - TransformerEncoderLayer(args) 44 | + TransformerEncoder(args) 45 | for i in range(args.encoder_layers) 46 | ]) 47 | self.register_buffer('version', torch.Tensor([2])) 48 | @@ -377,6 +383,100 @@ 49 | return state_dict 50 | 51 | 52 | +class BertTransformerEncoder(FairseqEncoder): 53 | + """ 54 | + Transformer encoder consisting of *args.encoder_layers* layers. Each layer 55 | + is a :class:`TransformerEncoderLayer`. 56 | + 57 | + Args: 58 | + args (argparse.Namespace): parsed command-line arguments 59 | + dictionary (~fairseq.data.Dictionary): encoding dictionary 60 | + embed_tokens (torch.nn.Embedding): input embedding 61 | + left_pad (bool, optional): whether the input is left-padded 62 | + (default: True). 
63 | + """ 64 | + 65 | + def __init__(self, args): 66 | + # def __init__(self, args, dictionary, embed_tokens, left_pad=True): 67 | + super().__init__(None) 68 | + self.bert = BertModel.from_pretrained('bert-base-cased') 69 | + self.padding_idx = 0 70 | + self.max_source_positions = args.max_source_positions 71 | + 72 | + def forward(self, src_tokens, src_lengths): 73 | + """ 74 | + Args: 75 | + src_tokens (LongTensor): tokens in the source language of shape 76 | + `(batch, src_len)` 77 | + src_lengths (torch.LongTensor): lengths of each source sentence of 78 | + shape `(batch)` 79 | + 80 | + Returns: 81 | + dict: 82 | + - **encoder_out** (Tensor): the last encoder layer's output of 83 | + shape `(src_len, batch, embed_dim)` 84 | + - **encoder_padding_mask** (ByteTensor): the positions of 85 | + padding elements of shape `(batch, src_len)` 86 | + """ 87 | + # embed tokens and positions 88 | + 89 | + # B x T x C -> T x B x C 90 | + # x = x.transpose(0, 1) 91 | + 92 | + # compute padding mask 93 | + encoder_padding_mask = src_tokens.eq(self.padding_idx) 94 | + if not encoder_padding_mask.any(): 95 | + encoder_padding_mask = None 96 | + 97 | + # encoder layers 98 | + 99 | + encoded_layers, _ = self.bert(src_tokens) 100 | + x = encoded_layers[-1] 101 | + 102 | + return { 103 | + 'encoder_out': x, # T x B x C 104 | + 'encoder_padding_mask': encoder_padding_mask, # B x T 105 | + } 106 | + 107 | + def reorder_encoder_out(self, encoder_out, new_order): 108 | + """ 109 | + Reorder encoder output according to *new_order*. 110 | + 111 | + Args: 112 | + encoder_out: output from the ``forward()`` method 113 | + new_order (LongTensor): desired order 114 | + 115 | + Returns: 116 | + *encoder_out* rearranged according to *new_order* 117 | + """ 118 | + if encoder_out['encoder_out'] is not None: 119 | + encoder_out['encoder_out'] = \ 120 | + encoder_out['encoder_out'].index_select(1, new_order) 121 | + if encoder_out['encoder_padding_mask'] is not None: 122 | + encoder_out['encoder_padding_mask'] = \ 123 | + encoder_out['encoder_padding_mask'].index_select(0, new_order) 124 | + return encoder_out 125 | + 126 | + def max_positions(self): 127 | + """Maximum input length supported by the encoder.""" 128 | + return self.max_source_positions 129 | + 130 | + def upgrade_state_dict_named(self, state_dict, name): 131 | + raise Exception('upgrade_state_dict_named not implemented') 132 | + # """Upgrade a (possibly old) state dict for new versions of fairseq.""" 133 | + # if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): 134 | + # weights_key = '{}.embed_positions.weights'.format(name) 135 | + # if weights_key in state_dict: 136 | + # del state_dict[weights_key] 137 | + # state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) 138 | + # version_key = '{}.version'.format(name) 139 | + # if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: 140 | + # # earlier checkpoints did not normalize after the stack of layers 141 | + # self.layer_norm = None 142 | + # self.normalize = False 143 | + # state_dict[version_key] = torch.Tensor([1]) 144 | + # return state_dict 145 | + 146 | class TransformerDecoder(FairseqIncrementalDecoder): 147 | """ 148 | Transformer decoder consisting of *args.decoder_layers* layers. 
Each layer 149 | @@ -880,6 +980,16 @@ 150 | base_architecture(args) 151 | 152 | 153 | +@register_model_architecture('transformer', 'bert_transformer_iwslt_en_de') 154 | +def transformer_iwslt_de_en(args): 155 | + args.encoder_embed_dim = getattr(args, 'decoder_embed_dim', 768) 156 | + args.use_bert_encoder = getattr(args, 'use_bert_encoder', True) 157 | + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 768) 158 | + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024) 159 | + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4) 160 | + args.decoder_layers = getattr(args, 'decoder_layers', 6) 161 | + base_architecture(args) 162 | + 163 | @register_model_architecture('transformer', 'transformer_wmt_en_de') 164 | def transformer_wmt_en_de(args): 165 | base_architecture(args) 166 | Index: preprocess.py 167 | IDEA additional info: 168 | Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP 169 | <+>UTF-8 170 | =================================================================== 171 | --- preprocess.py (revision ec6f8ef99a8c6942133e01a610def197e1d6d9dd) 172 | +++ preprocess.py (revision 09236665549d8a98aea0a367d70bf3c902950ee7) 173 | @@ -72,7 +72,7 @@ 174 | 175 | print(args) 176 | os.makedirs(args.destdir, exist_ok=True) 177 | - target = not args.only_source 178 | + not_only_source = not args.only_source 179 | 180 | def train_path(lang): 181 | return "{}{}".format(args.trainpref, ("." + lang) if lang else "") 182 | @@ -105,7 +105,7 @@ 183 | args.trainpref 184 | ), "--trainpref must be set if --srcdict is not specified" 185 | src_dict = build_dictionary([train_path(args.source_lang)], args.workers) 186 | - if target: 187 | + if not_only_source: 188 | if args.tgtdict: 189 | tgt_dict = dictionary.Dictionary.load(args.tgtdict) 190 | else: 191 | @@ -122,7 +122,7 @@ 192 | padding_factor=args.padding_factor, 193 | ) 194 | src_dict.save(dict_path(args.source_lang)) 195 | - if target: 196 | + if not_only_source: 197 | if not args.joined_dictionary: 198 | tgt_dict.finalize( 199 | threshold=args.thresholdtgt, 200 | @@ -220,7 +220,7 @@ 201 | make_dataset(testpref, outprefix, lang) 202 | 203 | make_all(args.source_lang) 204 | - if target: 205 | + if not_only_source: 206 | make_all(args.target_lang) 207 | 208 | print("| Wrote preprocessed data to {}".format(args.destdir)) 209 | Index: preprocess_4_bert_encoder.py 210 | IDEA additional info: 211 | Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP 212 | <+>UTF-8 213 | =================================================================== 214 | --- preprocess_4_bert_encoder.py (revision 09236665549d8a98aea0a367d70bf3c902950ee7) 215 | +++ preprocess_4_bert_encoder.py (revision 09236665549d8a98aea0a367d70bf3c902950ee7) 216 | @@ -0,0 +1,352 @@ 217 | +#!/usr/bin/env python3 218 | +# Copyright (c) 2017-present, Facebook, Inc. 219 | +# All rights reserved. 220 | +# 221 | +# This source code is licensed under the license found in the LICENSE file in 222 | +# the root directory of this source tree. An additional grant of patent rights 223 | +# can be found in the PATENTS file in the same directory. 224 | +""" 225 | +Data pre-processing: build vocabularies and binarize training data. 
226 | +""" 227 | + 228 | +import argparse 229 | +from collections import Counter 230 | +from itertools import zip_longest 231 | +import os 232 | +import shutil 233 | + 234 | +from fairseq.data import indexed_dataset, dictionary 235 | +from fairseq.tokenizer import Tokenizer, tokenize_line 236 | +from multiprocessing import Pool 237 | +from pytorch_pretrained_bert import BertModel 238 | + 239 | +from fairseq.utils import import_user_module 240 | + 241 | + 242 | +def get_parser(): 243 | + parser = argparse.ArgumentParser() 244 | + # fmt: off 245 | + parser.add_argument("-s", "--source-lang", default=None, metavar="SRC", 246 | + help="source language") 247 | + parser.add_argument("-t", "--target-lang", default=None, metavar="TARGET", 248 | + help="target language") 249 | + parser.add_argument("--trainpref", metavar="FP", default=None, 250 | + help="train file prefix") 251 | + parser.add_argument("--validpref", metavar="FP", default=None, 252 | + help="comma separated, valid file prefixes") 253 | + parser.add_argument("--testpref", metavar="FP", default=None, 254 | + help="comma separated, test file prefixes") 255 | + parser.add_argument("--destdir", metavar="DIR", default="data-bin", 256 | + help="destination dir") 257 | + parser.add_argument("--thresholdtgt", metavar="N", default=0, type=int, 258 | + help="map words appearing less than threshold times to unknown") 259 | + parser.add_argument("--thresholdsrc", metavar="N", default=0, type=int, 260 | + help="map words appearing less than threshold times to unknown") 261 | + parser.add_argument("--tgtdict", metavar="FP", 262 | + help="reuse given target dictionary") 263 | + parser.add_argument("--srcdict", metavar="FP", 264 | + help="reuse given source dictionary") 265 | + parser.add_argument("--nwordstgt", metavar="N", default=-1, type=int, 266 | + help="number of target words to retain") 267 | + parser.add_argument("--nwordssrc", metavar="N", default=-1, type=int, 268 | + help="number of source words to retain") 269 | + parser.add_argument("--alignfile", metavar="ALIGN", default=None, 270 | + help="an alignment file (optional)") 271 | + parser.add_argument("--output-format", metavar="FORMAT", default="binary", 272 | + choices=["binary", "raw"], 273 | + help="output format (optional)") 274 | + parser.add_argument("--joined-dictionary", action="store_true", 275 | + help="Generate joined dictionary") 276 | + parser.add_argument("--only-source", action="store_true", 277 | + help="Only process the source language") 278 | + parser.add_argument("--padding-factor", metavar="N", default=8, type=int, 279 | + help="Pad dictionary size to be multiple of N") 280 | + parser.add_argument("--workers", metavar="N", default=1, type=int, 281 | + help="number of parallel workers") 282 | + # fmt: on 283 | + return parser 284 | + 285 | + 286 | +def main(args): 287 | + import_user_module(args) 288 | + 289 | + print(args) 290 | + os.makedirs(args.destdir, exist_ok=True) 291 | + not_only_source = not args.only_source 292 | + 293 | + def train_path(lang): 294 | + return "{}{}".format(args.trainpref, ("." 
+ lang) if lang else "") 295 | + 296 | + def file_name(prefix, lang): 297 | + fname = prefix 298 | + if lang is not None: 299 | + fname += ".{lang}".format(lang=lang) 300 | + return fname 301 | + 302 | + def dest_path(prefix, lang): 303 | + return os.path.join(args.destdir, file_name(prefix, lang)) 304 | + 305 | + def dict_path(lang): 306 | + return dest_path("dict", lang) + ".txt" 307 | + 308 | + if args.joined_dictionary: 309 | + # assert not args.srcdict, "cannot combine --srcdict and --joined-dictionary" 310 | + # assert not args.tgtdict, "cannot combine --tgtdict and --joined-dictionary" 311 | + # src_dict = build_dictionary( 312 | + # {train_path(lang) for lang in [args.source_lang, args.target_lang]}, 313 | + # args.workers, 314 | + # ) 315 | + # tgt_dict = src_dict 316 | + raise Exception('joined_dictionary not implemented') 317 | + else: 318 | + if not_only_source: 319 | + if args.tgtdict: 320 | + tgt_dict = dictionary.Dictionary.load(args.tgtdict) 321 | + else: 322 | + assert ( 323 | + args.trainpref 324 | + ), "--trainpref must be set if --tgtdict is not specified" 325 | + tgt_dict = build_dictionary( 326 | + [train_path(args.target_lang)], args.workers 327 | + ) 328 | + 329 | + from pytorch_pretrained_bert import BertTokenizer 330 | + tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False) 331 | + 332 | + def save (f): 333 | + if isinstance(f, str): 334 | + os.makedirs(os.path.dirname(f), exist_ok=True) 335 | + with open(f, 'w', encoding='utf-8') as fd: 336 | + return save(fd) 337 | + for symbol, index in tokenizer.vocab.items(): 338 | + print('{} {}'.format(symbol, len(tokenizer.vocab)-index), file=f) 339 | + 340 | + save(dict_path(args.source_lang)) 341 | + 342 | + if not_only_source: 343 | + if not args.joined_dictionary: 344 | + tgt_dict.finalize( 345 | + threshold=args.thresholdtgt, 346 | + nwords=args.nwordstgt, 347 | + padding_factor=args.padding_factor, 348 | + ) 349 | + tgt_dict.save(dict_path(args.target_lang)) 350 | + 351 | + def make_binary_dataset(input_prefix, output_prefix, lang, num_workers): 352 | + dict = dictionary.Dictionary.load(dict_path(lang)) 353 | + print("| [{}] Dictionary: {} types".format(lang, len(dict) - 1)) 354 | + n_seq_tok = [0, 0] 355 | + replaced = Counter() 356 | + 357 | + def merge_result(worker_result): 358 | + replaced.update(worker_result["replaced"]) 359 | + n_seq_tok[0] += worker_result["nseq"] 360 | + n_seq_tok[1] += worker_result["ntok"] 361 | + 362 | + input_file = "{}{}".format( 363 | + input_prefix, ("." 
+ lang) if lang is not None else "" 364 | + ) 365 | + offsets = Tokenizer.find_offsets(input_file, num_workers) 366 | + pool = None 367 | + if num_workers > 1: 368 | + pool = Pool(processes=num_workers - 1) 369 | + for worker_id in range(1, num_workers): 370 | + prefix = "{}{}".format(output_prefix, worker_id) 371 | + pool.apply_async( 372 | + binarize, 373 | + ( 374 | + args, 375 | + input_file, 376 | + dict, 377 | + prefix, 378 | + lang, 379 | + offsets[worker_id], 380 | + offsets[worker_id + 1], 381 | + ), 382 | + callback=merge_result, 383 | + ) 384 | + pool.close() 385 | + 386 | + ds = indexed_dataset.IndexedDatasetBuilder( 387 | + dataset_dest_file(args, output_prefix, lang, "bin") 388 | + ) 389 | + merge_result( 390 | + Tokenizer.binarize( 391 | + input_file, dict, lambda t: ds.add_item(t), offset=0, end=offsets[1] 392 | + ) 393 | + ) 394 | + if num_workers > 1: 395 | + pool.join() 396 | + for worker_id in range(1, num_workers): 397 | + prefix = "{}{}".format(output_prefix, worker_id) 398 | + temp_file_path = dataset_dest_prefix(args, prefix, lang) 399 | + ds.merge_file_(temp_file_path) 400 | + os.remove(indexed_dataset.data_file_path(temp_file_path)) 401 | + os.remove(indexed_dataset.index_file_path(temp_file_path)) 402 | + 403 | + ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) 404 | + 405 | + print( 406 | + "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format( 407 | + lang, 408 | + input_file, 409 | + n_seq_tok[0], 410 | + n_seq_tok[1], 411 | + 100 * sum(replaced.values()) / n_seq_tok[1], 412 | + dict.unk_word, 413 | + ) 414 | + ) 415 | + 416 | + def make_dataset(input_prefix, output_prefix, lang, num_workers=1): 417 | + if args.output_format == "binary": 418 | + make_binary_dataset(input_prefix, output_prefix, lang, num_workers) 419 | + elif args.output_format == "raw": 420 | + # Copy original text file to destination folder 421 | + output_text_file = dest_path( 422 | + output_prefix + ".{}-{}".format(args.source_lang, args.target_lang), 423 | + lang, 424 | + ) 425 | + shutil.copyfile(file_name(input_prefix, lang), output_text_file) 426 | + 427 | + def make_all(lang): 428 | + if args.trainpref: 429 | + make_dataset(args.trainpref, "train", lang, num_workers=args.workers) 430 | + if args.validpref: 431 | + for k, validpref in enumerate(args.validpref.split(",")): 432 | + outprefix = "valid{}".format(k) if k > 0 else "valid" 433 | + make_dataset(validpref, outprefix, lang) 434 | + if args.testpref: 435 | + for k, testpref in enumerate(args.testpref.split(",")): 436 | + outprefix = "test{}".format(k) if k > 0 else "test" 437 | + make_dataset(testpref, outprefix, lang) 438 | + 439 | + make_all(args.source_lang) 440 | + if not_only_source: 441 | + make_all(args.target_lang) 442 | + 443 | + print("| Wrote preprocessed data to {}".format(args.destdir)) 444 | + 445 | + if args.alignfile: 446 | + assert args.trainpref, "--trainpref must be set if --alignfile is specified" 447 | + src_file_name = train_path(args.source_lang) 448 | + tgt_file_name = train_path(args.target_lang) 449 | + src_dict = dictionary.Dictionary.load(dict_path(args.source_lang)) 450 | + tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang)) 451 | + freq_map = {} 452 | + with open(args.alignfile, "r", encoding='utf-8') as align_file: 453 | + with open(src_file_name, "r", encoding='utf-8') as src_file: 454 | + with open(tgt_file_name, "r", encoding='utf-8') as tgt_file: 455 | + for a, s, t in zip_longest(align_file, src_file, tgt_file): 456 | + si = Tokenizer.tokenize(s, src_dict, 
add_if_not_exist=False) 457 | + ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False) 458 | + ai = list(map(lambda x: tuple(x.split("-")), a.split())) 459 | + for sai, tai in ai: 460 | + srcidx = si[int(sai)] 461 | + tgtidx = ti[int(tai)] 462 | + if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk(): 463 | + assert srcidx != src_dict.pad() 464 | + assert srcidx != src_dict.eos() 465 | + assert tgtidx != tgt_dict.pad() 466 | + assert tgtidx != tgt_dict.eos() 467 | + 468 | + if srcidx not in freq_map: 469 | + freq_map[srcidx] = {} 470 | + if tgtidx not in freq_map[srcidx]: 471 | + freq_map[srcidx][tgtidx] = 1 472 | + else: 473 | + freq_map[srcidx][tgtidx] += 1 474 | + 475 | + align_dict = {} 476 | + for srcidx in freq_map.keys(): 477 | + align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get) 478 | + 479 | + with open( 480 | + os.path.join( 481 | + args.destdir, 482 | + "alignment.{}-{}.txt".format(args.source_lang, args.target_lang), 483 | + ), 484 | + "w", encoding='utf-8' 485 | + ) as f: 486 | + for k, v in align_dict.items(): 487 | + print("{} {}".format(src_dict[k], tgt_dict[v]), file=f) 488 | + 489 | + 490 | +def build_and_save_dictionary( 491 | + train_path, output_path, num_workers, freq_threshold, max_words, dict_cls=dictionary.Dictionary, 492 | +): 493 | + dict = build_dictionary([train_path], num_workers, dict_cls) 494 | + dict.finalize(threshold=freq_threshold, nwords=max_words) 495 | + dict_path = os.path.join(output_path, "dict.txt") 496 | + dict.save(dict_path) 497 | + return dict_path 498 | + 499 | + 500 | +def build_dictionary( 501 | + filenames, 502 | + workers, 503 | + dict_cls=dictionary.Dictionary, 504 | +): 505 | + d = dict_cls() 506 | + for filename in filenames: 507 | + Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, workers) 508 | + return d 509 | + 510 | + 511 | +def binarize(args, filename, dict, output_prefix, lang, offset, end): 512 | + ds = indexed_dataset.IndexedDatasetBuilder( 513 | + dataset_dest_file(args, output_prefix, lang, "bin") 514 | + ) 515 | + 516 | + def consumer(tensor): 517 | + ds.add_item(tensor) 518 | + 519 | + res = Tokenizer.binarize(filename, dict, consumer, offset=offset, end=end) 520 | + ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) 521 | + return res 522 | + 523 | + 524 | +def binarize_with_load( 525 | + args, 526 | + filename, 527 | + dict_path, 528 | + output_prefix, 529 | + lang, 530 | + offset, 531 | + end, 532 | + dict_cls=dictionary.Dictionary, 533 | +): 534 | + dict = dict_cls.load(dict_path) 535 | + binarize(args, filename, dict, output_prefix, lang, offset, end) 536 | + return dataset_dest_prefix(args, output_prefix, lang) 537 | + 538 | + 539 | +def dataset_dest_prefix(args, output_prefix, lang): 540 | + base = "{}/{}".format(args.destdir, output_prefix) 541 | + lang_part = ( 542 | + ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) if lang is not None else "" 543 | + ) 544 | + return "{}{}".format(base, lang_part) 545 | + 546 | + 547 | +def dataset_dest_file(args, output_prefix, lang, extension): 548 | + base = dataset_dest_prefix(args, output_prefix, lang) 549 | + return "{}.{}".format(base, extension) 550 | + 551 | + 552 | +def get_offsets(input_file, num_workers): 553 | + return Tokenizer.find_offsets(input_file, num_workers) 554 | + 555 | + 556 | +def merge_files(files, outpath): 557 | + ds = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath)) 558 | + for file in files: 559 | + ds.merge_file_(file) 560 | + 
os.remove(indexed_dataset.data_file_path(file)) 561 | + os.remove(indexed_dataset.index_file_path(file)) 562 | + ds.finalize("{}.idx".format(outpath)) 563 | + 564 | + 565 | +if __name__ == "__main__": 566 | + parser = get_parser() 567 | + args = parser.parse_args() 568 | + main(args) 569 | --------------------------------------------------------------------------------