├── .gitignore ├── Chaitanya Joshi - Knowledge Graphs from Unstructured Text.pdf ├── README.md ├── img ├── kg.png ├── kg_aspirin.png ├── kg_basf.png ├── kg_bayer.png ├── kg_ig farben.png ├── kg_monsanto.png └── kg_pharmaceutical industry.png ├── kg_utils.py ├── main.ipynb ├── scraper_utils.py └── viz_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | neuralcoref-4.0.0/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /Chaitanya Joshi - Knowledge Graphs from Unstructured Text.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/Chaitanya Joshi - Knowledge Graphs from Unstructured Text.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Building Knowledge Graphs from Unstructured Text 2 | 3 | Challenge: Build a Knowledge Graph for the company **Bayer**, focused on their **Pharmacology business**. 4 | 5 | Refer to the notebook [`main.ipynb`](main.ipynb) for usage and visualizations of our results. Check out the [presentation slides here](Chaitanya%20Joshi%20-%20Knowledge%20Graphs%20from%20Unstructured%20Text.pdf). 
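The rough end-to-end flow (a minimal sketch only; the authoritative pipeline lives in `main.ipynb`) chains the utilities defined in `scraper_utils.py`, `kg_utils.py` and `viz_utils.py` below. The starting page, the domain-specific entity list and the pruning thresholds here are illustrative placeholders, not the exact values used in the notebook:

```python
import pandas as pd

from scraper_utils import wiki_scrape
from kg_utils import (extract_triplets, merge_duplicate_subjs,
                      prune_infreq_subjects, prune_infreq_objects, prune_self_loops)
from viz_utils import draw_kg, draw_kg_subgraph

# 1. Scrape the starting Wikipedia page and the pages it links to
wiki_data = wiki_scrape("Bayer")

# 2. Extract Subject-Relation-Object triplets per page
#    (the domain-specific entity list below is purely illustrative)
global_ents = ["bayer", "aspirin", "pharmacology"]
triplets = pd.concat(
    [extract_triplets(row.text, row.page, global_ents) for row in wiki_data.itertuples()],
    ignore_index=True,
)

# 3. Clean up the triplets before building the graph
triplets = merge_duplicate_subjs(triplets)
triplets = prune_infreq_subjects(triplets, threshold=2)
triplets = prune_infreq_objects(triplets, threshold=2)
triplets = prune_self_loops(triplets)

# 4. Visualize the full knowledge graph and a subgraph around one entity
draw_kg(triplets, save_fig=True)
draw_kg_subgraph(triplets, node="bayer", n_hops=2)
```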
6 | 7 | ![Bayer-Pharma Knowledge Graph](/img/kg.png) 8 | 9 | ## Installation 10 | 11 | ```sh 12 | # Create conda environment 13 | conda create -n nlp python=3.7 14 | conda activate nlp 15 | 16 | # Install and setup Spacy 17 | conda install -c conda-forge spacy==2.1.6 18 | python -m spacy download en 19 | python -m spacy download en_core_web_lg 20 | 21 | # Install neuralcoref (specific version, for spacy compatibility) 22 | conda install cython 23 | curl https://github.com/huggingface/neuralcoref/archive/4.0.0.zip -o neuralcoref-4.0.0.zip -J -L -k 24 | unzip neuralcoref-4.0.0.zip && cd neuralcoref-4.0.0 25 | python setup.py build_ext --inplace 26 | python setup.py install 27 | 28 | # Install additional packages 29 | pip install wikipedia-api 30 | conda install pandas networkx matplotlib seaborn 31 | conda install pytorch=1.2.0 cudatoolkit=10.0 -c pytorch 32 | pip install transformers 33 | 34 | # Install extras 35 | conda install ipywidgets nodejs -c conda-forge 36 | jupyter labextension install @jupyter-widgets/jupyterlab-manager 37 | ``` 38 | -------------------------------------------------------------------------------- /img/kg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/img/kg.png -------------------------------------------------------------------------------- /img/kg_aspirin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/img/kg_aspirin.png -------------------------------------------------------------------------------- /img/kg_basf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/img/kg_basf.png -------------------------------------------------------------------------------- /img/kg_bayer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/img/kg_bayer.png -------------------------------------------------------------------------------- /img/kg_ig farben.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/img/kg_ig farben.png -------------------------------------------------------------------------------- /img/kg_monsanto.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/img/kg_monsanto.png -------------------------------------------------------------------------------- /img/kg_pharmaceutical industry.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/knowledge-graphs/fd1130d506e033cdadabbdb6c8672e538570eae3/img/kg_pharmaceutical industry.png -------------------------------------------------------------------------------- /kg_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import re 3 | import spacy 4 | from transformers import AutoModelForTokenClassification, AutoTokenizer 5 | import torch 6 | import neuralcoref 7 | from tqdm import tqdm 8 | 9 | nlp = 
spacy.load('en_core_web_lg') 10 | neuralcoref.add_to_pipe(nlp) 11 | 12 | 13 | def extract_ner_bert(text, model=None, tokenizer=None): 14 | """Method to extract Named Entities from text using pre-trained BERT 15 | 16 | Parameters 17 | ---------- 18 | text : str 19 | Text document/paragraph/sentence from which NEs are extracted 20 | model : transformers.Model, optional 21 | Pre-trained BERT model 22 | tokenizer : tokenizers.Tokenizer, optional 23 | Tokenizer associated with BERT model 24 | 25 | Returns 26 | ------- 27 | ents : list 28 | List of NEs in text 29 | 30 | Reference 31 | --------- 32 | HuggingFace Transformers library tutorials: 33 | https://huggingface.co/transformers/usage.html#named-entity-recognition 34 | """ 35 | if model is None: 36 | model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english") 37 | if tokenizer is None: 38 | tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") 39 | 40 | label_list = [ 41 | "O", # Outside of a named entity 42 | "B-MISC", # Beginning of a miscellaneous entity right after another miscellaneous entity 43 | "I-MISC", # Miscellaneous entity 44 | "B-PER", # Beginning of a person's name right after another person's name 45 | "I-PER", # Person's name 46 | "B-ORG", # Beginning of an organisation right after another organisation 47 | "I-ORG", # Organisation 48 | "B-LOC", # Beginning of a location right after another location 49 | "I-LOC" # Location 50 | ] 51 | tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(text))) 52 | inputs = tokenizer.encode(text, return_tensors="pt") 53 | 54 | outputs = model(inputs)[0] 55 | predictions = torch.argmax(outputs, dim=2) 56 | 57 | # Build list of named entities 58 | ents = [] 59 | cur_ent = "" 60 | cur_label = "" 61 | for token, prediction in zip(tokens, predictions[0].tolist()): 62 | if label_list[prediction] != "O": 63 | if token[:2] == "##": 64 | # Append token to current NE 65 | cur_ent += token[2:] 66 | else: 67 | # Start of a new NE 68 | if cur_ent != "": 69 | ents.append(cur_ent) 70 | cur_ent = token 71 | cur_label = label_list[prediction] 72 | if cur_ent != "": 73 | ents.append(cur_ent) 74 | 75 | return ents 76 | 77 | 78 | def extract_triplets(text, title, global_ents_list, verbose=False, use_bert=False): 79 | """Method to extract Subject-Relation-Object triplets for KG construction 80 | 81 | Parameters 82 | ---------- 83 | text : str 84 | Raw text from Wikipedia article/document 85 | title : str 86 | Title of document, used as a default/fallback subject 87 | global_ents_list : list 88 | List of domain-specific entities which Spacy tools do not 89 | recognize as Named Entities or noun chunks 90 | verbose : bool, optional 91 | Flag for displaying progress bar and verbose output 92 | use_bert : bool, optional 93 | Flag for using pre-trained BERT for NER instead of Spacy (default) 94 | 95 | Returns 96 | ------- 97 | sro_triplets_df : pd.DataFrame 98 | Pandas dataframe with S-R-O triplets extracted from document 99 | """ 100 | # print("\nRaw text:\n", text) 101 | text = re.sub(r'\n+', '. 
', text) 102 | # print("\nRemove new-line:\n", text) 103 | text = re.sub(r'\[\d+\]', ' ', text) 104 | # print("\nRemove reference numbers:\n", text) 105 | text = re.sub(r'\([^()]*\)', ' ', text) 106 | # print("\nRemove parenthesis:\n", text) 107 | text = re.sub(r'(?<=[.,])(?=[^\s0-9])', r' ', text) 108 | # print("\nFix formatting:\n", text) 109 | 110 | # Resolve coreferences with Spacy+neuralcoref 111 | text = nlp(text) 112 | text = nlp(text._.coref_resolved) 113 | # print("\nResolve coreferences:\n", text) 114 | 115 | # Track (Subject, Relation, Object) triplets 116 | sro_triplets = [] 117 | 118 | # Temp variables to track previous sentence subject and object 119 | prev_subj = nlp("")[:] 120 | prev_obj = nlp("")[:] 121 | default_subj = nlp(title)[:] # Set default subject: title of document 122 | 123 | sentences = [sent.string.strip() for sent in text.sents] 124 | for sent in tqdm(sentences): 125 | prev_obj_end = 0 # Temp pointer to previous object end 126 | 127 | sent = nlp(sent) # Pass through Spacy pipeline 128 | 129 | # Retokenize to combine Named Entities into single tokens 130 | ents = list(sent.ents) 131 | spans = spacy.util.filter_spans(ents) 132 | with sent.retokenize() as retokenizer: 133 | for span in spans: retokenizer.merge(span) 134 | 135 | # Re-build Named Entities by categories 136 | if use_bert: 137 | # TODO: 138 | # Implement using BERT NER model in combination with spacy features. 139 | # There are several issues with BERT: 140 | # 1. At the moment, using BERT means we won't be able to use spans, lemmas, PoS tags, etc. 141 | # from spacy. If needed in the future, we can write a function to bridge the gap. 142 | # 2. Unfortunately, pre-trained BERT does not support detailed NE types beyond 143 | # (person, location, organization). On the other hand, spacy supports up to 18 NE types. 144 | raise NotImplementedError 145 | # ents = extract_ner_bert(sent) 146 | else: 147 | # Use spacy's NER model 148 | ents = list(sent.ents) 149 | main_ents = [] # Named Entities recognised by Spacy (Main) 150 | addn_ents = [] # Additional named entities (Date/Time/etc.) 
151 | for ent in ents: 152 | if ent.label_ in ("DATE", "TIME", "MONEY", "QUANTITY"): 153 | addn_ents.append(ent) 154 | elif ent.label_ in ("CARDINAL", "ORDINAL", "PERCENT"): 155 | # Ignore cardinal/ordinal numbers and percentages 156 | continue 157 | elif ent.label_ in ("PERSON", "NORP", "FAC", "ORG", 158 | "GPE", "LOC", "PRODUCT", "EVENT", 159 | "WORK_OF_ART", "LAW", "LANGUAGE"): 160 | main_ents.append(ent) 161 | # Identify Domain-specific/global named entities 162 | global_ents = [] 163 | for tok in sent: 164 | if tok.text.lower() in global_ents_list: 165 | global_ents.append(sent[tok.i:tok.i+1]) 166 | 167 | # Identify noun chunks besides Named Entities 168 | noun_chunks = list(sent.noun_chunks) 169 | 170 | # Identify verbs for forming relations 171 | verbs = [tok for tok in sent if tok.pos_ == "VERB"] 172 | 173 | if verbose: 174 | print("\n----------\n") 175 | print("\nSentence:\n", sent) 176 | print("\nNamed Entities:\n", main_ents) 177 | print("\nDomain-specific Entities:\n", global_ents) 178 | print("\nAdditional Entities:\n", addn_ents) 179 | print("\nNoun spans:\n", noun_chunks) 180 | print("\nVerbs:\n", verbs) 181 | 182 | for verb in verbs: 183 | 184 | # Identify Subject 185 | subj = None 186 | # Find the Main Ent closest to the left of the verb 187 | for ent in main_ents: 188 | if ent.end > verb.i: 189 | break 190 | elif ent.end > prev_obj_end: 191 | subj = ent 192 | rel_start = subj.end 193 | if subj is None: 194 | # Find the Global Ent closest to the left of the verb 195 | for ent in global_ents: 196 | if ent.end > verb.i: 197 | break 198 | elif ent.end > prev_obj_end: 199 | subj = ent 200 | rel_start = subj.end 201 | if subj is None: 202 | # Find the noun chunk closest to the left of the verb 203 | for noun_chunk in noun_chunks: 204 | if noun_chunk.end > verb.i: 205 | break 206 | elif noun_chunk.end > prev_obj_end: 207 | subj = noun_chunk 208 | rel_start = subj.end 209 | if subj is None: 210 | # Find the Additional Ent closest to the left of the verb 211 | for ent in addn_ents: 212 | if ent.end > verb.i: 213 | break 214 | elif ent.end > prev_obj_end: 215 | subj = ent 216 | rel_start = subj.end 217 | if subj is None: 218 | # If no subject found, assign default subject 219 | subj = default_subj 220 | rel_start = verb.i 221 | 222 | ########## 223 | 224 | # Identify Object 225 | obj = None 226 | # Find the Main Ent closest to the right of the verb 227 | for ent in main_ents[::-1]: 228 | if ent.end <= verb.i: 229 | break 230 | else: 231 | obj = ent 232 | rel_end = obj.start 233 | if obj is None: 234 | # Find the Global Ent closest to the right of the verb 235 | for ent in global_ents[::-1]: 236 | if ent.end <= verb.i: 237 | break 238 | elif ent.text.lower() != verb.text.lower(): 239 | # Additional check for global entity not being verb itself! 
240 | obj = ent 241 | rel_end = obj.start 242 | if obj is None: 243 | # Find the noun chunk closest to the right of the verb 244 | for noun_chunk in noun_chunks[::-1]: 245 | if noun_chunk.end <= verb.i: 246 | break 247 | else: 248 | obj = noun_chunk 249 | rel_end = obj.start 250 | if obj is None: 251 | # Find the Additional Ent closest to the right of the verb 252 | for ent in addn_ents[::-1]: 253 | if ent.end <= verb.i: 254 | break 255 | else: 256 | obj = ent 257 | rel_end = obj.start 258 | if obj is None: 259 | # If no object found, fall back to the previous object 260 | obj = prev_obj 261 | rel_end = verb.i + 1 262 | 263 | ########## 264 | 265 | # Identify and lemmatize the relationship span around the verb token 266 | triplet = ( 267 | # Subject 268 | " ".join(tok.text.lower() for tok in subj if 269 | (tok.is_stop == False and tok.is_punct == False)).strip(), 270 | # Relationship 271 | " ".join(tok.lemma_.lower() for tok in sent[rel_start:rel_end] if 272 | (tok == verb or (tok.is_stop == False and tok.is_punct == False))).strip(), 273 | # Object 274 | " ".join(tok.text.lower() for tok in obj if 275 | (tok.is_stop == False and tok.is_punct == False)).strip(), 276 | ) 277 | 278 | # Append valid SRO triplets to list 279 | if triplet[0] != "" and triplet[1] != "" and triplet[2] != "" and triplet[0] != triplet[2]: 280 | # Check for duplicate triplets within same sentence 281 | if subj == prev_subj and obj == prev_obj: 282 | prev_triplet = sro_triplets.pop() 283 | # Keep the longest relation span among duplicates 284 | if len(prev_triplet[1]) > len(triplet[1]): 285 | triplet = prev_triplet 286 | 287 | sro_triplets.append(triplet) 288 | if verbose: 289 | print("\nS-R-O:\n", triplet[0], "-", triplet[1], "-", triplet[2]) 290 | 291 | # Update previous subject and object variables 292 | prev_subj = subj 293 | prev_obj = obj 294 | prev_obj_end = obj.end 295 | 296 | # Convert to df 297 | sro_triplets_df = pd.DataFrame(sro_triplets, columns=['subject', 'relation', 'object']) 298 | return sro_triplets_df 299 | 300 | 301 | def merge_duplicate_subjs(triplets, title=None): 302 | """Helper function to merge duplicate subjects 303 | 304 | Duplicate subjects can be extensions/additional words joined to typical subjects, 305 | e.g. 'bayer ag', 'bayer healthcare', 'bayer pharmaceuticals' --> 'bayer' 306 | Note that when merging an extended subject (e.g. 'bayer healthcare') 307 | back to a subject ('bayer'), we append the extension ('healthcare') 308 | to the relation for the triplet and then replace the extended subject with the subject. 
309 | 310 | Parameters 311 | ---------- 312 | triplets : pd.DataFrame 313 | S-R-O triplets dataframe 314 | 315 | Returns 316 | ------- 317 | triplets : pd.DataFrame 318 | Updated dataframe with merged duplicate subjects 319 | """ 320 | subjects = sorted(list(triplets.subject.unique())) 321 | prev_subj = subjects[0] 322 | for subj in subjects[1:]: 323 | # TODO Use string edit distance between prev_subj and subj 324 | if prev_subj in subj: 325 | # Detect extension in subj compared to prev_subj and append it to relations of rows with subj 326 | triplets.loc[triplets.subject==subj, 'relation'] = \ 327 | subj.replace(prev_subj, '').strip() + ' ' + triplets[triplets.subject==subj].relation 328 | # Update subject from subj to prev_subj 329 | triplets.loc[triplets.subject==subj, 'subject'] = prev_subj 330 | 331 | else: 332 | # Update prev_subj 333 | prev_subj = subj 334 | 335 | return triplets 336 | 337 | 338 | def prune_infreq_subjects(triplets, threshold=2): 339 | """Helper function to prune triplets with infrequent subjects 340 | 341 | Parameters 342 | ---------- 343 | triplets : pd.DataFrame 344 | S-R-O triplets dataframe 345 | threshold : int 346 | Frequency threshold for pruning 347 | 348 | Returns 349 | ------- 350 | triplets : pd.DataFrame 351 | Updated dataframe with pruned rows 352 | """ 353 | # Count unique subjects 354 | subj_counts = triplets.subject.value_counts() 355 | # TODO: add more/smarter heuristics for pruning? 356 | # Drop subjects with counts below threshold 357 | triplets['subj_count'] = list(subj_counts[triplets.subject]) 358 | triplets.drop(triplets[triplets['subj_count'] < threshold].index, inplace=True) 359 | triplets = triplets.drop(columns='subj_count') 360 | return triplets 361 | 362 | 363 | def prune_infreq_objects(triplets, threshold=2): 364 | """Helper function to prune triplets with infrequent objects 365 | 366 | Parameters 367 | ---------- 368 | triplets : pd.DataFrame 369 | S-R-O triplets dataframe 370 | threshold : int 371 | Frequency threshold for pruning 372 | 373 | Returns 374 | ------- 375 | triplets : pd.DataFrame 376 | Updated dataframe with pruned rows 377 | """ 378 | # Count unique objects 379 | obj_counts = triplets.object.value_counts() 380 | # TODO: add more/smarter heuristics for pruning? 
381 | # Drop objects with counts below threshold 382 | triplets['obj_count'] = list(obj_counts[triplets.object]) 383 | triplets.drop(triplets[triplets['obj_count'] < threshold].index, inplace=True) 384 | triplets = triplets.drop(columns='obj_count') 385 | return triplets 386 | 387 | 388 | def prune_self_loops(triplets): 389 | """Helper function to prune triplets where subject is the same as object 390 | """ 391 | triplets.drop(triplets[triplets.subject==triplets.object].index, inplace=True) 392 | return triplets 393 | -------------------------------------------------------------------------------- /scraper_utils.py: -------------------------------------------------------------------------------- 1 | import wikipediaapi 2 | import pandas as pd 3 | import concurrent.futures 4 | from tqdm import tqdm 5 | 6 | 7 | def wiki_scrape(start_page_name, verbose=True): 8 | """Method to scrape Wikipedia pages associated with/linked to a starting page 9 | 10 | Parameters 11 | ---------- 12 | start_page_name : str 13 | Name of page to start scraping from 14 | verbose : bool, optional 15 | Flag for displaying progress bar and verbose output 16 | 17 | Returns 18 | ------- 19 | sources : pd.DataFrame 20 | DataFrame containing all scraped Wikipedia articles linked to start_page_name, 21 | with entries ('page', 'text', 'link', 'categories') 22 | 23 | References 24 | ---------- 25 | Modified from https://towardsdatascience.com/auto-generated-knowledge-graphs-92ca99a81121 26 | """ 27 | def follow_link(link): 28 | """Helper function to follow links using Wikipedia API 29 | """ 30 | try: 31 | page = wiki_api.page(link) 32 | if page.exists(): 33 | d = {'page': link, 'text': page.text, 'link': page.fullurl, 34 | 'categories': list(page.categories.keys())} 35 | return d 36 | else: 37 | return None 38 | except Exception: 39 | return None 40 | 41 | # Instantiate Wikipedia API 42 | wiki_api = wikipediaapi.Wikipedia(language='en', extract_format=wikipediaapi.ExtractFormat.WIKI) 43 | 44 | # Scrape starting page 45 | page_name = wiki_api.page(start_page_name) 46 | if not page_name.exists(): 47 | print('page does not exist') 48 | return 49 | 50 | # Initialize list of page dicts (to be converted to df) 51 | sources = [{ 52 | 'page': start_page_name, 53 | 'text': page_name.text, 54 | 'link': page_name.fullurl, 55 | 'categories': list(page_name.categories.keys()) 56 | }] 57 | 58 | page_links = set(page_name.links.keys()) 59 | # Multithreading to scrape multiple pages in parallel 60 | progress = tqdm(desc='Links Scraped', unit='', total=len(page_links)) if verbose else None 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: 62 | # Follow links from list of links 63 | future_link = {executor.submit(follow_link, link): link for link in page_links} 64 | for future in concurrent.futures.as_completed(future_link): 65 | data = future.result() 66 | progress.update(1) if verbose else None 67 | if data: 68 | # Append scraped page data to sources list 69 | sources.append(data) 70 | progress.close() if verbose else None 71 | 72 | # Convert list of dicts to df 73 | sources = pd.DataFrame(sources) 74 | 75 | # Filter out short pages and generic Wikipedia pages 76 | blacklist = ('Template', 'Help:', 'Category:', 'Portal:', 'Wikipedia:', 'Talk:') 77 | sources = sources[(sources['text'].str.len() > 20) 78 | & ~(sources['page'].str.startswith(blacklist))] 79 | sources['categories'] = sources.categories.apply(lambda x: [y[9:] for y in x]) 80 | 81 | return sources 82 | 83 | 84 | def build_category_whitelist(wiki_data, page_whitelist, cat_blacklist): 85 | """Helper function to build whitelist of 
page categories 86 | 87 | This method finds a set of page categories which we can use to 88 | reduce the number of pages we use for building knowledge graphs. 89 | Typically, we want to build KGs about particular domains and 90 | specific pages. 91 | 92 | Parameters 93 | ---------- 94 | wiki_data : pd.DataFrame 95 | DataFrame of scraped pages, as returned by wiki_scrape 96 | page_whitelist : list 97 | List of pages from whose categories we select a domain-specific subset 98 | cat_blacklist : list 99 | List of categories which we don't want to include in the whitelist 100 | 101 | Returns 102 | ------- 103 | cat_whitelist : set 104 | Set of categories which we want to build KGs about 105 | """ 106 | cat_whitelist = [] 107 | for page_name in page_whitelist: 108 | # Iterate over categories list for each page in the page whitelist 109 | categories = list(wiki_data[wiki_data.page==page_name].categories)[0] 110 | for cat in categories: 111 | relevant_cat = True 112 | for unwanted in cat_blacklist: 113 | # If given category is part of blacklisted categories, 114 | # do not add it to whitelist 115 | if unwanted in cat: 116 | relevant_cat = False 117 | break 118 | 119 | # All non-blacklisted categories from the page whitelist 120 | # are added to categories whitelist 121 | if relevant_cat: 122 | cat_whitelist.append(cat) 123 | 124 | return set(cat_whitelist) 125 | -------------------------------------------------------------------------------- /viz_utils.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import matplotlib.pyplot as plt 3 | import itertools 4 | 5 | 6 | def draw_kg(triplets, save_fig=False): 7 | """Method to plot and save full KG using networkx 8 | 9 | Parameters 10 | ---------- 11 | triplets : pd.DataFrame 12 | S-R-O triplets dataframe 13 | save_fig : bool, optional 14 | Flag for saving figure to /img directory 15 | 16 | References 17 | ---------- 18 | Adapted from: https://towardsdatascience.com/auto-generated-knowledge-graphs-92ca99a81121 19 | """ 20 | # Build networkx graph 21 | k_graph = nx.from_pandas_edgelist(triplets, source='subject', target='object', 22 | create_using=nx.MultiDiGraph()) 23 | # Compute node degrees, for resizing highly connected nodes in plot 24 | node_deg = nx.degree(k_graph) 25 | # Plot graph 26 | layout = nx.spring_layout(k_graph, k=0.15, iterations=20) 27 | plt.figure(num=None, figsize=(120, 90), dpi=80) 28 | nx.draw_networkx( 29 | k_graph, 30 | node_size=[int(deg[1]) * 500 for deg in node_deg], 31 | arrowsize=20, 32 | linewidths=1.5, 33 | pos=layout, 34 | edge_color='red', 35 | edgecolors='black', 36 | node_color='white', 37 | ) 38 | # Build edge/relationship labels 39 | labels = dict(zip( 40 | list(zip(triplets.subject, triplets.object)), 41 | triplets['relation'].tolist() 42 | )) 43 | # Add edge labels to plot 44 | nx.draw_networkx_edge_labels( 45 | k_graph, 46 | pos=layout, 47 | edge_labels=labels, 48 | font_color='red' 49 | ) 50 | plt.axis('off') 51 | if save_fig: 52 | plt.savefig("img/kg_full.png", format='png', bbox_inches='tight') 53 | plt.show() 54 | 55 | 56 | def draw_kg_subgraph(triplets, node, n_hops=2, verbose=True, save_fig=False): 57 | """Method to plot and save KG subgraph centered around a given node 58 | 59 | The subgraph around the node is built using all relationships 60 | that are at most `n_hops` hops away from the node in a DFS tree, 61 | i.e., those nodes that are reachable from the given node in `n_hops` hops. 
62 | 63 | Parameters 64 | ---------- 65 | triplets : pd.DataFrame 66 | S-R-O triplets dataframe 67 | node : str 68 | Node for which subgraph is computed 69 | n_hops : int, optional 70 | Number of hops for DFS neighborhood construction 71 | verbose : bool, optional 72 | Flag to print S-R-O triplets associated with node 73 | save_fig : bool, optional 74 | Flag for saving figure to /img directory 75 | 76 | References 77 | ---------- 78 | Adapted from: https://towardsdatascience.com/auto-generated-knowledge-graphs-92ca99a81121 79 | """ 80 | # Build networkx graph 81 | k_graph = nx.from_pandas_edgelist(triplets, source='subject', target='object', 82 | create_using=nx.MultiDiGraph()) 83 | # Build subgraph nodes list 84 | nodes = [node] 85 | # Add n-hop DFS successors (guard against DFS trees shallower than n_hops) 86 | dfs_suc = list(nx.dfs_successors(k_graph, node).values()) 87 | if len(dfs_suc) > 0: 88 | for hop in range(min(n_hops, len(dfs_suc))): 89 | nodes += dfs_suc[hop] 90 | # Build subgraph 91 | subgraph = k_graph.subgraph(nodes) 92 | # Plot subgraph 93 | layout = nx.circular_layout(subgraph) 94 | plt.figure(num=None, figsize=(10, 10), dpi=80) 95 | nx.draw_networkx( 96 | subgraph, 97 | node_size=1000, 98 | arrowsize=20, 99 | linewidths=1.5, 100 | pos=layout, 101 | edge_color='red', 102 | edgecolors='black', 103 | node_color='white' 104 | ) 105 | # Build edge/relationship labels 106 | labels = dict(zip( 107 | (list(zip(triplets.subject, triplets.object))), 108 | triplets['relation'].tolist() 109 | )) 110 | edges = tuple(subgraph.out_edges(data=False)) 111 | sublabels = {k: labels[k] for k in edges} 112 | if verbose: 113 | for pair in sublabels.keys(): 114 | print("\nS-R-O:\n", pair[0], "-", sublabels[pair], "-", pair[1]) 115 | # Add edge labels to plot 116 | nx.draw_networkx_edge_labels( 117 | subgraph, 118 | pos=layout, 119 | edge_labels=sublabels, 120 | font_color='red' 121 | ) 122 | plt.axis('off') 123 | if save_fig: 124 | plt.savefig(f"img/kg_{node.lower()}.png", format='png', bbox_inches='tight') 125 | plt.show() 126 | --------------------------------------------------------------------------------
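A quick illustration of the DFS-based neighborhood that `draw_kg_subgraph` builds (the toy triplets below are invented for this sketch and are not part of the repo's data):

```python
import pandas as pd
import networkx as nx

# Toy triplets, invented purely to illustrate the n-hop expansion in draw_kg_subgraph
toy_triplets = pd.DataFrame(
    [("bayer", "acquired", "monsanto"),
     ("monsanto", "produced", "glyphosate"),
     ("bayer", "invented", "aspirin")],
    columns=["subject", "relation", "object"],
)
g = nx.from_pandas_edgelist(toy_triplets, source="subject", target="object",
                            create_using=nx.MultiDiGraph())

# nx.dfs_successors returns {node: [children]} for the DFS tree rooted at 'bayer';
# draw_kg_subgraph concatenates the first n_hops of these child lists into its node set.
print(nx.dfs_successors(g, "bayer"))
# e.g. {'bayer': ['monsanto', 'aspirin'], 'monsanto': ['glyphosate']}
```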