├── config.py ├── config.yaml ├── scripts │ ├── download_wikinews.sh │ ├── download_wiki.sh │ ├── download_wiki.py │ ├── preprocess_wikinews.py │ └── preprocess_wiki.py ├── README.md ├── generate_entity_embeddings.py ├── main.py └── train.py /config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | def read_config(config_file): 4 | with open(config_file, 'r') as f: 5 | config = yaml.safe_load(f) 6 | return config 7 | 8 | config_file = 'config.yaml' 9 | config = read_config(config_file) -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | file_path: wikipedia/wikipedia_data.json 3 | 4 | model: 5 | text_model_name: distilbert-base-uncased 6 | entity_model_name: distilbert-base-uncased 7 | text_encoder_path: bi_encoder_output/text_encoder 8 | entity_encoder_path: bi_encoder_output/entity_encoder 9 | tokenizer_path: bi_encoder_output/tokenizer 10 | index_path: entity_embeddings_hnsw.faiss -------------------------------------------------------------------------------- /scripts/download_wikinews.sh: -------------------------------------------------------------------------------- 1 | mkdir wikinews 2 | cd wikinews 3 | 4 | for LANG in ar bg bs ca cs de el en eo es fa fi fr he hu it ja ko nl no pl pt ro ru sd sq sr sv ta th tr uk zh 5 | do 6 | wget http://wikipedia.c3sl.ufpr.br/${LANG}wikinews/20191001/${LANG}wikinews-20191001-pages-articles-multistream.xml.bz2 7 | done 8 | 9 | for LANG in ar bg bs ca cs de el en eo es fa fi fr he hu it ja ko nl no pl pt ro ru sd sq sr sv ta th tr uk zh 10 | do 11 | wikiextractor ${LANG}wikinews-20191001-pages-articles-multistream.xml.bz2 -o ${LANG} --links --section_hierarchy --lists --sections 12 | done -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bi-encoder entity linking model 2 | 3 | A simple but efficient model for retrieving the entities mentioned in a given piece of text. The model uses a BERT encoder for both the text and the entities, and a fast HNSW index to retrieve the nearest entities. 4 | 5 | ## Usage: 6 | 7 | ``` 8 | python main.py "Nobel Prize-winning physicist who developed the theory of general relativity."
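# Expected output format (see main.py): a "Nearest Entities:" header followed by one "- <entity name> (score: <score>)" line per retrieved entity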
9 | ``` 10 | 11 | ## Train the model: 12 | 13 | ### Download the Wikipedia dump 14 | 15 | First we need to download and preprocess the Wikipedia data: 16 | ``` 17 | scripts/download_wiki.sh 18 | ``` 19 | 20 | Then preprocess the data: 21 | ``` 22 | python scripts/preprocess_wiki.py 23 | ``` 24 | 25 | ### Train the model 26 | 27 | ``` 28 | python train.py 29 | ``` 30 | 31 | ### Generate the entity embeddings index 32 | 33 | ``` 34 | python generate_entity_embeddings.py 35 | ``` -------------------------------------------------------------------------------- /scripts/download_wiki.sh: -------------------------------------------------------------------------------- 1 | mkdir wikipedia 2 | cd wikipedia 3 | 4 | for LANG in af am ar as az be bg bm bn br bs ca cs cy da de el en eo es et eu fa ff fi fr fy ga gd gl gn gu ha he hi hr ht hu hy id ig is it ja jv ka kg kk km kn ko ku ky la lg ln lo lt lv mg mk ml mn mr ms my ne nl no om or pa pl ps pt qu ro ru sa sd si sk sl so sq sr ss su sv sw ta te th ti tl tn tr uk ur uz vi wo xh yo zh 5 | do 6 | wget http://wikipedia.c3sl.ufpr.br/${LANG}wiki/20191001/${LANG}wiki-20191001-pages-articles-multistream.xml.bz2 7 | done 8 | 9 | for LANG in af am ar as az be bg bm bn br bs ca cs cy da de el en eo es et eu fa ff fi fr fy ga gd gl gn gu ha he hi hr ht hu hy id ig is it ja jv ka kg kk km kn ko ku ky la lg ln lo lt lv mg mk ml mn mr ms my ne nl no om or pa pl ps pt qu ro ru sa sd si sk sl so sq sr ss su sv sw ta te th ti tl tn tr uk ur uz vi wo xh yo zh 10 | do 11 | wikiextractor ${LANG}wiki-20191001-pages-articles-multistream.xml.bz2 -o ${LANG} --links --lists --sections 12 | done 13 | -------------------------------------------------------------------------------- /generate_entity_embeddings.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import faiss 4 | from transformers import AutoTokenizer, AutoModel 5 | 6 | from config import config 7 | 8 | # Load the JSON data 9 | def load_json_data(file_path): 10 | with open(file_path, 'r') as f: 11 | data = json.load(f) 12 | return data 13 | 14 | file_path = config['data']['file_path'] 15 | data = load_json_data(file_path) 16 | 17 | # Initialize the tokenizer and entity encoder 18 | entity_model_name = config['model']['entity_model_name'] 19 | tokenizer = AutoTokenizer.from_pretrained(entity_model_name) 20 | entity_encoder = AutoModel.from_pretrained(config['model']['entity_encoder_path']) 21 | 22 | def generate_entity_embeddings(data, tokenizer, entity_encoder): 23 | entity_embeddings = [] 24 | 25 | for page in data: 26 | entity_text = page['text'] 27 | encoded_input = tokenizer(entity_text, truncation=True, padding=True, return_tensors="pt") 28 | 29 | with torch.no_grad(): 30 | embeddings = entity_encoder(**encoded_input).last_hidden_state[:, 0, :] 31 | 32 | entity_embeddings.append(embeddings) 33 | 34 | return torch.cat(entity_embeddings, dim=0) 35 | 36 | entity_embeddings = generate_entity_embeddings(data, tokenizer, entity_encoder) 37 | 38 | # Save the generated embeddings to a torch file 39 | torch.save(entity_embeddings, 'entity_embeddings.pt') 40 | 41 | # Initialize the Faiss index 42 | embedding_size = entity_embeddings.shape[1] 43 | index = faiss.IndexHNSWFlat(embedding_size, 32) 44 | index.hnsw.efConstruction = 40 45 | index.verbose = True 46 | 47 | # L2-normalize the embeddings and add them to the Faiss index (normalize the same array that is added, not a temporary copy) 48 | entity_embeddings_np = entity_embeddings.numpy() 49 | faiss.normalize_L2(entity_embeddings_np) 50 | index.add(entity_embeddings_np) 51 | 52 | # Save the Faiss index 53 | faiss.write_index(index, 'entity_embeddings_hnsw.faiss')
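Note on the index metric (editor's sketch, not part of the original script): generate_entity_embeddings.py L2-normalizes the embeddings before adding them to faiss.IndexHNSWFlat, which defaults to the L2 metric. With unit-length vectors, L2 distance ranks neighbours in the same order as cosine similarity, so retrieval still works, but the values returned by the index are distances (lower is better). If cosine-style scores (higher is better) are preferred, the index could instead be built with an inner-product metric, assuming a FAISS build whose IndexHNSWFlat constructor accepts a metric argument:

```
import faiss
import torch

# Rebuild the HNSW index with an inner-product metric so that, for the
# L2-normalized vectors saved by generate_entity_embeddings.py, search
# scores are cosine similarities (higher is better).
entity_embeddings = torch.load('entity_embeddings.pt').numpy()
faiss.normalize_L2(entity_embeddings)

index = faiss.IndexHNSWFlat(entity_embeddings.shape[1], 32, faiss.METRIC_INNER_PRODUCT)
index.hnsw.efConstruction = 40
index.add(entity_embeddings)
faiss.write_index(index, 'entity_embeddings_hnsw.faiss')
```

With this variant, the scores printed by main.py would be similarities rather than distances.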
-------------------------------------------------------------------------------- /scripts/download_wiki.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | 4 | # Define the URL for the Wikipedia API 5 | WIKIPEDIA_API_URL = "https://en.wikipedia.org/w/api.php" 6 | LIMIT = None # optional cap on the number of pages to download 7 | 8 | # Define the parameters for the API query 9 | params = { 10 | "action": "query", 11 | "generator": "allpages", 12 | "prop": "extracts", 13 | "exintro": True, 14 | "explaintext": True, 15 | "format": "json", 16 | "gaplimit": 100, 17 | "gapfrom": "A" # Starting letter of Wikipedia pages to download 18 | } 19 | 20 | pages_count = 0 21 | 22 | # Open the output file in JSON Lines format 23 | with open("wikipedia_articles.jsonl", "w") as f: 24 | 25 | # Loop through the API results in batches of gaplimit (100) pages per request 26 | while True: 27 | 28 | # Send the API request and get the response 29 | response = requests.get(WIKIPEDIA_API_URL, params=params) 30 | data = response.json() 31 | 32 | # Loop through the pages in the response and write each one to the output file 33 | for page_id, page_data in data["query"]["pages"].items(): 34 | # Extract relevant fields 35 | title = page_data["title"] 36 | body = page_data["extract"] 37 | wiki_id = page_data["pageid"] 38 | 39 | # Write the fields to the output file in JSON Lines format 40 | f.write(json.dumps({"title": title, "body": body, "wikipedia_id": wiki_id}) + "\n") 41 | pages_count += 1 42 | 43 | if pages_count % 100 == 0: 44 | print(f"Downloaded {pages_count} pages") 45 | 46 | if LIMIT and pages_count >= LIMIT: 47 | break 48 | 49 | # Stop when the page limit is reached or there are no more pages to download 50 | if (LIMIT and pages_count >= LIMIT) or "continue" not in data: 51 | break 52 | 53 | params.update(data["continue"]) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import faiss 4 | import argparse 5 | from transformers import AutoTokenizer, AutoModel 6 | 7 | from config import config 8 | 9 | # Load the JSON data 10 | def load_json_data(file_path): 11 | with open(file_path, 'r') as f: 12 | data = json.load(f) 13 | return data 14 | 15 | file_path = config['data']['file_path'] 16 | data = load_json_data(file_path) 17 | 18 | # Load the trained model, tokenizer, and Faiss index 19 | text_model_name = config['model']['text_model_name'] 20 | tokenizer = AutoTokenizer.from_pretrained(text_model_name) 21 | text_encoder = AutoModel.from_pretrained(config['model']['text_encoder_path']) 22 | 23 | index = faiss.read_index(config['model']['index_path']) 24 | 25 | # Define a function to preprocess and embed the input text 26 | def embed_text(text, tokenizer, text_encoder): 27 | encoded_input = tokenizer(text, truncation=True, padding=True, return_tensors="pt") 28 | 29 | with torch.no_grad(): 30 | embeddings = text_encoder(**encoded_input).last_hidden_state[:, 0, :] 31 | 32 | return embeddings 33 | 34 | # Run the inference and search for the nearest entities 35 | def search_nearest_entities(text, tokenizer, text_encoder, index, data, k=5): 36 | text_embeddings = embed_text(text, tokenizer, text_encoder).numpy() 37 | faiss.normalize_L2(text_embeddings) # normalize the same array that is searched 38 | scores, indices = index.search(text_embeddings, k) # with an L2 HNSW index these scores are distances (lower is better) 39 | 40 | nearest_entities = [(data[i]['entity'], score) for i, score in zip(indices[0], scores[0])] 41 | 42 | return nearest_entities 43 | 44 | # Read text from the command line 45 | 
parser = argparse.ArgumentParser(description='Retrieve nearest entities for a given text.') 46 | parser.add_argument('text', type=str, help='Text to retrieve nearest entities for') 47 | args = parser.parse_args() 48 | input_text = args.text 49 | 50 | nearest_entities = search_nearest_entities(input_text, tokenizer, text_encoder, index, data) 51 | 52 | # Print the nearest entity names and scores 53 | print("Nearest Entities:") 54 | for entity, score in nearest_entities: 55 | print(f"- {entity} (score: {score:.4f})") -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | from sklearn.model_selection import train_test_split 3 | from transformers import AutoTokenizer, AutoModel, TrainingArguments, DataCollatorWithPadding, Trainer 4 | from datasets import Dataset 5 | 6 | 7 | from .config import config 8 | 9 | # Load the JSON data 10 | def load_json_data(file_path): 11 | with open(file_path, 'r') as f: 12 | data = json.load(f) 13 | return data 14 | 15 | file_path = config['data']['file_path'] 16 | data = load_json_data(file_path) 17 | 18 | # Preprocess the data 19 | def preprocess_data(data): 20 | processed_data = [] 21 | for page in data: 22 | for mention in page['mentions']: 23 | processed_data.append({ 24 | 'input_text': mention['text'], 25 | 'entity': mention['entity'], 26 | 'entity_text': page['text'] 27 | }) 28 | return processed_data 29 | 30 | processed_data = preprocess_data(data) 31 | train_data, test_data = train_test_split(processed_data, test_size=0.2, random_state=42) 32 | 33 | train_dataset = Dataset.from_dict(train_data) 34 | test_dataset = Dataset.from_dict(test_data) 35 | 36 | # Train the bi-encoder model 37 | text_model_name = config['model']['text_model_name'] 38 | entity_model_name = config['model']['entity_model_name'] 39 | tokenizer = AutoTokenizer.from_pretrained(text_model_name) 40 | text_encoder = AutoModel.from_pretrained(text_model_name) 41 | entity_encoder = AutoModel.from_pretrained(entity_model_name) 42 | 43 | 44 | def tokenize_data(example): 45 | input_text = example['input_text'] 46 | entity_text = example['entity_text'] 47 | input_encoding = tokenizer(input_text, truncation=True, padding=False) 48 | entity_encoding = tokenizer(entity_text, truncation=True, padding=False) 49 | return {**input_encoding, 'entity_ids': entity_encoding['input_ids'], 'entity_attention_mask': entity_encoding['attention_mask']} 50 | 51 | train_dataset = train_dataset.map(tokenize_data, batched=True) 52 | test_dataset = test_dataset.map(tokenize_data, batched=True) 53 | 54 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 55 | 56 | training_args = TrainingArguments( 57 | output_dir="bi_encoder_output", 58 | evaluation_strategy="epoch", 59 | save_strategy="epoch", 60 | per_device_train_batch_size=8, 61 | per_device_eval_batch_size=8, 62 | num_train_epochs=3, 63 | seed=42, 64 | learning_rate=5e-5, 65 | ) 66 | 67 | def compute_loss(text_encoder, entity_encoder, inputs, return_outputs=False): 68 | input_ids = inputs.pop("input_ids") 69 | attention_mask = inputs.pop("attention_mask") 70 | entity_ids = inputs.pop("entity_ids") 71 | entity_attention_mask = inputs.pop("entity_attention_mask") 72 | 73 | input_outputs = text_encoder(input_ids, attention_mask=attention_mask) 74 | entity_outputs = entity_encoder(entity_ids, attention_mask=entity_attention_mask) 75 | 76 | similarities = (input_outputs[0] * entity_outputs[0]).sum(dim=1) 77 | loss = 
-similarities.mean() 78 | 79 | if return_outputs: 80 | return loss, input_outputs 81 | return loss 82 | 83 | trainer = Trainer( 84 | model=(text_encoder, entity_encoder), 85 | args=training_args, 86 | train_dataset=train_dataset, 87 | eval_dataset=test_dataset, 88 | data_collator=data_collator, 89 | compute_loss=compute_loss, 90 | ) 91 | 92 | trainer.train() 93 | 94 | # Save and test the model 95 | text_encoder.save_pretrained(config['model']['text_encoder_path']) 96 | entity_encoder.save_pretrained(config['model']['entity_encoder_path']) 97 | tokenizer.save_pretrained(config['model']['tokenizer_path']) -------------------------------------------------------------------------------- /scripts/preprocess_wikinews.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import re 6 | from collections import defaultdict 7 | 8 | import jsonlines 9 | import numpy as np 10 | import pandas 11 | from tqdm.auto import tqdm, trange 12 | 13 | if __name__ == "__main__": 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument( 18 | "output_dir", 19 | type=str, 20 | ) 21 | parser.add_argument( 22 | "--base_wikinews", 23 | type=str, 24 | help="Base folder with Wikipedia data.", 25 | ) 26 | parser.add_argument( 27 | "--langs", 28 | type=str, 29 | default="ar|bg|bs|ca|cs|de|el|en|eo|es|fa|fi|fr|he|hu|it|ja|ko|nl|no|pl|pt|ro|ru|sd|sq|sr|sv|ta|th|tr|uk|zh", 30 | help="Pipe (|) separated list of language ID to use.", 31 | ) 32 | parser.add_argument( 33 | "-d", 34 | "--debug", 35 | help="Print lots of debugging statements", 36 | action="store_const", 37 | dest="loglevel", 38 | const=logging.DEBUG, 39 | default=logging.WARNING, 40 | ) 41 | parser.add_argument( 42 | "-v", 43 | "--verbose", 44 | help="Be verbose", 45 | action="store_const", 46 | dest="loglevel", 47 | const=logging.INFO, 48 | ) 49 | 50 | args, _ = parser.parse_known_args() 51 | 52 | logging.basicConfig(level=args.loglevel) 53 | 54 | for lang in args.langs.split("|"): 55 | filename = os.path.join(args.base_wikinews, lang, "{}wiki.pkl".format(lang)) 56 | logging.info("Loading {}".format(filename)) 57 | with open(filename, "rb") as f: 58 | wiki = pickle.load(f) 59 | 60 | kilt_dataset = [] 61 | for doc in tqdm(wiki.values()): 62 | for i, anchor in enumerate(doc["anchors"]): 63 | if len(anchor["wikidata_ids"]) == 1: 64 | meta = { 65 | "left_context": ( 66 | "".join(doc["paragraphs"][: anchor["paragraph_id"]]) 67 | + doc["paragraphs"][anchor["paragraph_id"]][ 68 | : anchor["start"] 69 | ] 70 | ), 71 | "mention": ( 72 | doc["paragraphs"][anchor["paragraph_id"]][ 73 | anchor["start"] : anchor["end"] 74 | ] 75 | ), 76 | "right_context": ( 77 | doc["paragraphs"][anchor["paragraph_id"]][anchor["end"] :] 78 | + "".join(doc["paragraphs"][anchor["paragraph_id"] :]) 79 | ), 80 | } 81 | item = { 82 | "id": "wikinews-{}-{}-{}".format(lang, doc["id"], i), 83 | "input": ( 84 | meta["left_context"] 85 | + " [START] " 86 | + meta["mention"] 87 | + " [END] " 88 | + meta["right_context"] 89 | ), 90 | "output": [{"answer": anchor["wikidata_ids"]}], 91 | "meta": meta, 92 | } 93 | kilt_dataset.append(item) 94 | 95 | filename = os.path.join(args.output_dir, "{}-kilt-all.jsonl".format(lang)) 96 | logging.info("Saving {}".format(filename)) 97 | with jsonlines.open(filename, "w") as f: 98 | f.write_all(kilt_dataset) 99 | 100 | if len(kilt_dataset) >= 10000: 101 | wiki_dict = defaultdict(list) 102 | for doc in kilt_dataset: 103 | 
wiki_dict[doc["id"].split("-")[2]].append(doc) 104 | 105 | test_set = [] 106 | dev_set = [] 107 | train_set = [] 108 | 109 | np.random.seed(0) 110 | for docs in np.random.permutation(list(wiki_dict.values())): 111 | if len(test_set) < len(kilt_dataset) // 10: 112 | test_set += docs 113 | elif len(dev_set) < len(kilt_dataset) // 10: 114 | dev_set += docs 115 | else: 116 | train_set += docs 117 | 118 | for split_name, split in zip( 119 | ("test", "dev", "train"), (test_set, dev_set, train_set) 120 | ): 121 | filename = os.path.join( 122 | args.output_dir, "{}-kilt-{}.jsonl".format(lang, split_name) 123 | ) 124 | logging.info("Saving {}".format(filename)) 125 | with jsonlines.open(filename, "w") as f: 126 | f.write_all(split) -------------------------------------------------------------------------------- /scripts/preprocess_wiki.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import json 4 | import logging 5 | import os 6 | import pickle 7 | from collections import defaultdict 8 | from concurrent.futures import ThreadPoolExecutor, as_completed 9 | from urllib.parse import unquote 10 | 11 | import jsonlines 12 | from tqdm.auto import tqdm, trange 13 | 14 | NOPAGE = [ 15 | "Q4167836", 16 | "Q24046192", 17 | "Q20010800", 18 | "Q11266439", 19 | "Q11753321", 20 | "Q19842659", 21 | "Q21528878", 22 | "Q17362920", 23 | "Q14204246", 24 | "Q21025364", 25 | "Q17442446", 26 | "Q26267864", 27 | "Q4663903", 28 | "Q15184295", 29 | # "Q4167410", 30 | ] 31 | 32 | if __name__ == "__main__": 33 | 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument( 37 | "step", 38 | type=str, 39 | choices=["compress", "normalize", "dicts", "redirects", "freebase"], 40 | ) 41 | parser.add_argument( 42 | "--base_wikidata", 43 | type=str, 44 | help="Base folder with Wikidata data.", 45 | ) 46 | parser.add_argument( 47 | "--normalized", 48 | action="store_true", 49 | ) 50 | parser.add_argument( 51 | "-d", 52 | "--debug", 53 | help="Print lots of debugging statements", 54 | action="store_const", 55 | dest="loglevel", 56 | const=logging.DEBUG, 57 | default=logging.WARNING, 58 | ) 59 | parser.add_argument( 60 | "-v", 61 | "--verbose", 62 | help="Be verbose", 63 | action="store_const", 64 | dest="loglevel", 65 | const=logging.INFO, 66 | ) 67 | 68 | args = parser.parse_args() 69 | 70 | logging.basicConfig(level=args.loglevel) 71 | 72 | if args.step == "compress": 73 | wikidata = 0 74 | with open( 75 | os.path.join(args.base_wikidata, "wikidata-all.json"), "r" 76 | ) as fi, jsonlines.open( 77 | os.path.join(args.base_wikidata, "wikidata-all-compressed.jsonl"), "w" 78 | ) as fo: 79 | 80 | iter_ = tqdm(fi) 81 | for i, line in enumerate(iter_): 82 | iter_.set_postfix(wikidata=wikidata, refresh=False) 83 | 84 | line = line.strip() 85 | if line[-1] == ",": 86 | line = line[:-1] 87 | 88 | if line == "[" or line == "]": 89 | continue 90 | 91 | line = json.loads(line) 92 | if line["type"] == "item": 93 | 94 | if any( 95 | e["mainsnak"]["datavalue"]["value"]["id"] in NOPAGE 96 | for e in line["claims"].get("P31", {}) 97 | if "datavalue" in e["mainsnak"] 98 | ): 99 | continue 100 | if any( 101 | e["mainsnak"]["datavalue"]["value"]["id"] in NOPAGE 102 | for e in line["claims"].get("P279", {}) 103 | if "datavalue" in e["mainsnak"] 104 | ): 105 | continue 106 | 107 | line["sitelinks"] = { 108 | k[:-4]: v["title"] 109 | for k, v in line["sitelinks"].items() 110 | if k.endswith("wiki") 111 | } 112 | if len(line["sitelinks"]) == 0: 113 | continue 114 | 115 | 
line["labels"] = {k: v["value"] for k, v in line["labels"].items()} 116 | line["descriptions"] = { 117 | k: v["value"] for k, v in line["descriptions"].items() 118 | } 119 | line["aliases"] = { 120 | k: [e["value"] for e in v] for k, v in line["aliases"].items() 121 | } 122 | 123 | for e in ("claims", "lastrevid", "type"): 124 | del line[e] 125 | 126 | fo.write(line) 127 | wikidata += 1 128 | 129 | elif args.step == "dicts": 130 | 131 | lang_title2wikidataID = defaultdict(set) 132 | wikidataID2lang_title = defaultdict(set) 133 | label_desc2wikidataID = defaultdict(set) 134 | wikidataID2label_desc_lang = defaultdict(set) 135 | label_or_alias2wikidataID = defaultdict(set) 136 | wikidataID2label_or_alias = defaultdict(set) 137 | wikidataID2lang2label_or_alias = defaultdict(lambda: defaultdict(set)) 138 | 139 | filename = os.path.join( 140 | args.base_wikidata, 141 | "wikidata-all-compressed{}.jsonl".format( 142 | "-normalized" if args.normalized else "" 143 | ), 144 | ) 145 | logging.info("Processing {}".format(filename)) 146 | with jsonlines.open(filename, "r") as f: 147 | 148 | for item in tqdm(f): 149 | for lang, title in item["sitelinks"].items(): 150 | lang_title2wikidataID[(lang, title)].add(item["id"]) 151 | wikidataID2lang_title[item["id"]].add((lang, title)) 152 | 153 | for lang, label in item["labels"].items(): 154 | if lang in item["descriptions"]: 155 | label_desc2wikidataID[(label, item["descriptions"][lang])].add( 156 | item["id"] 157 | ) 158 | wikidataID2label_desc_lang[item["id"]].add( 159 | (label, item["descriptions"][lang], lang) 160 | ) 161 | 162 | for lang, aliases in item["aliases"].items(): 163 | for alias in aliases: 164 | label_or_alias2wikidataID[alias.lower()].add(item["id"]) 165 | wikidataID2label_or_alias[item["id"]].add(alias) 166 | wikidataID2lang2label_or_alias[item["id"]][lang].add(alias) 167 | 168 | for lang, label in item["labels"].items(): 169 | label_or_alias2wikidataID[label.lower()].add(item["id"]) 170 | wikidataID2label_or_alias[item["id"]].add(label) 171 | wikidataID2lang2label_or_alias[item["id"]][lang].add(label) 172 | 173 | wikidataID2lang2label_or_alias = { 174 | wikidataID: dict(lang2label_or_alias) 175 | for wikidataID, lang2label_or_alias in wikidataID2lang2label_or_alias.items() 176 | } 177 | 178 | for data, name in zip( 179 | ( 180 | lang_title2wikidataID, 181 | wikidataID2lang_title, 182 | label_desc2wikidataID, 183 | wikidataID2label_desc_lang, 184 | label_or_alias2wikidataID, 185 | wikidataID2label_or_alias, 186 | wikidataID2lang2label_or_alias, 187 | ), 188 | ( 189 | "lang_title2wikidataID", 190 | "wikidataID2lang_title", 191 | "label_desc2wikidataID", 192 | "wikidataID2label_desc_lang", 193 | "label_or_alias2wikidataID", 194 | "wikidataID2label_or_alias", 195 | "wikidataID2lang2label_or_alias", 196 | ), 197 | ): 198 | 199 | filename = os.path.join( 200 | args.base_wikidata, 201 | "{}{}.pkl".format(name, "-normalized" if args.normalized else ""), 202 | ) 203 | logging.info("Saving {}".format(filename)) 204 | with open(filename, "wb") as f: 205 | pickle.dump(dict(data), f) 206 | 207 | elif args.step == "redirects": 208 | 209 | lang_redirect2title = {} 210 | for lang in set(wiki_langs).intersection(set(mbart100_langs)): 211 | with open( 212 | "wikipedia_redirect/target/{}wiki-redirects.txt".format(lang) 213 | ) as f: 214 | for row in tqdm(csv.reader(f, delimiter="\t"), desc=lang): 215 | title = unquote(row[1]).split("#")[0].replace("_", " ") 216 | if title: 217 | title = title[0].upper() + title[1:] 218 | assert (lang, row[0]) not in 
lang_redirect2title 219 | lang_redirect2title[(lang, row[0])] = title 220 | 221 | filename = os.path.join(args.base_wikidata, "lang_redirect2title.pkl") 222 | logging.info("Saving {}".format(filename)) 223 | with open(filename, "wb") as f: 224 | pickle.dump(lang_redirect2title, f) 225 | 226 | elif args.step == "freebase": 227 | 228 | wikidataID2freebaseID = defaultdict(list) 229 | freebaseID2wikidataID = defaultdict(list) 230 | 231 | with open(os.path.join(args.base_wikidata, "wikidata-all.json"), "r") as fi: 232 | 233 | iter_ = tqdm(fi) 234 | for i, line in enumerate(iter_): 235 | line = line.strip() 236 | if line[-1] == ",": 237 | line = line[:-1] 238 | 239 | if line == "[" or line == "]": 240 | continue 241 | 242 | line = json.loads(line) 243 | 244 | if line["type"] == "item": 245 | 246 | if any( 247 | e["mainsnak"]["datavalue"]["value"]["id"] in NOPAGE 248 | for e in line["claims"].get("P31", {}) 249 | if "datavalue" in e["mainsnak"] 250 | ): 251 | continue 252 | if any( 253 | e["mainsnak"]["datavalue"]["value"]["id"] in NOPAGE 254 | for e in line["claims"].get("P279", {}) 255 | if "datavalue" in e["mainsnak"] 256 | ): 257 | continue 258 | 259 | line["sitelinks"] = { 260 | k[:-4]: v["title"] 261 | for k, v in line["sitelinks"].items() 262 | if k.endswith("wiki") 263 | } 264 | if len(line["sitelinks"]) == 0: 265 | continue 266 | 267 | for freebaseID in [ 268 | e["mainsnak"]["datavalue"]["value"] 269 | for e in line["claims"].get("P646", {}) 270 | if "datavalue" in e["mainsnak"] 271 | ]: 272 | wikidataID2freebaseID[line["id"]].append(freebaseID) 273 | freebaseID2wikidataID[freebaseID].append(line["id"]) 274 | 275 | wikidataID2freebaseID = dict(wikidataID2freebaseID) 276 | filename = os.path.join(args.base_wikidata, "wikidataID2freebaseID.pkl") 277 | logging.info("Saving {}".format(filename)) 278 | with open(filename, "wb") as f: 279 | pickle.dump(wikidataID2freebaseID, f) 280 | 281 | freebaseID2wikidataID = dict(freebaseID2wikidataID) 282 | filename = os.path.join(args.base_wikidata, "freebaseID2wikidataID.pkl") 283 | logging.info("Saving {}".format(filename)) 284 | with open(filename, "wb") as f: 285 | pickle.dump(freebaseID2wikidataID, f) --------------------------------------------------------------------------------
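As a quick illustration of how the dictionaries written by scripts/preprocess_wiki.py (the dicts step) can be consumed, the sketch below loads lang_title2wikidataID.pkl and resolves a (language, title) pair to the corresponding Wikidata IDs. The folder name and the lookup key are only placeholders for this example; the actual keys depend on the processed Wikidata dump.

```
import os
import pickle

base_wikidata = "wikidata"  # the folder passed as --base_wikidata (placeholder)

# Mapping built by `preprocess_wiki.py dicts`: (language, page title) -> set of Wikidata IDs
with open(os.path.join(base_wikidata, "lang_title2wikidataID.pkl"), "rb") as f:
    lang_title2wikidataID = pickle.load(f)

# Example lookup (illustrative key): the English Wikipedia title of a page
print(lang_title2wikidataID.get(("en", "Douglas Adams"), set()))
```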