├── .github ├── scripts │ ├── run-tests.sh │ └── tests-setup.sh └── workflows │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── downloadKnowledgeBase.sh ├── requirements.txt ├── setup.py ├── spacy_entity_linker ├── DatabaseConnection.py ├── EntityCandidates.py ├── EntityClassifier.py ├── EntityCollection.py ├── EntityElement.py ├── EntityLinker.py ├── SpanInfo.py ├── TermCandidate.py ├── TermCandidateExtractor.py ├── __init__.py └── __main__.py └── tests ├── test_EntityCollection.py ├── test_EntityElement.py ├── test_EntityLinker.py ├── test_TermCandidateExtractor.py ├── test_multiprocessing.py ├── test_multithreading.py ├── test_pipe.py └── test_serialize.py /.github/scripts/run-tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | 4 | echo "Running tests..." 5 | python -m unittest discover tests 6 | echo "Tests passed!" -------------------------------------------------------------------------------- /.github/scripts/tests-setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # This scripts sets up the environment for the tests to run. 3 | 4 | set -e 5 | # Install the spacy models that are used in the tests. 
6 | python -m spacy download en_core_web_sm -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: 3 | push: 4 | branches: [ master ] 5 | pull_request: 6 | branches: [ master ] 7 | jobs: 8 | test: 9 | name: Run tests 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | cache: 'pip' # caching pip dependencies 21 | - name: Install Python dependencies 22 | uses: py-actions/py-dependency-install@v4 23 | - name: Install additional for testing 24 | run: ./.github/scripts/tests-setup.sh 25 | shell: bash 26 | - name: Run the tests 27 | run: ./.github/scripts/run-tests.sh 28 | shell: bash -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | .idea 3 | *.log 4 | .ipynb_checkpoints 5 | data_spacy_entity_linker 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # pytype static type analyzer 141 | .pytype/ 142 | 143 | # Cython debug symbols 144 | cython_debug/ 145 | 146 | experimental_notebooks 147 | settings.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Emanuel Gerber 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Tests](https://github.com/egerber/spaCy-entity-linker/actions/workflows/tests.yml/badge.svg)](https://github.com/egerber/spaCy-entity-linker/actions/workflows/tests.yml) 2 | [![Downloads](https://static.pepy.tech/badge/spacy-entity-linker)](https://pepy.tech/project/spacy-entity-linker) 3 | [![Current Release Version](https://img.shields.io/github/release/egerber/spaCy-entity-linker.svg?style=flat-square&logo=github)](https://github.com/egerber/spaCy-entity-linker/releases) 4 | [![pypi Version](https://img.shields.io/pypi/v/spacy-entity-linker.svg?style=flat-square&logo=pypi&logoColor=white)](https://pypi.org/project/spacy-entity-linker/) 5 | # Spacy Entity Linker 6 | 7 | ## Introduction 8 | 9 | Spacy Entity Linker is a pipeline for spaCy that performs Linked Entity Extraction with Wikidata on a given Document. 10 | The Entity Linking System operates by matching potential candidates from each sentence 11 | (subject, object, prepositional phrase, compounds, etc.) to aliases from Wikidata. The package allows to easily find the 12 | category behind each entity (e.g. "banana" is type "food" OR "Microsoft" is type "company"). It can is therefore useful 13 | for information extraction tasks and labeling tasks. 14 | 15 | The package was written before a working Linked Entity Solution existed inside spaCy. 
In comparison to spaCy's linked 16 | entity system, it has the following advantages: 17 | 18 | - no extensive training required (entity-matching via database) 19 | - knowledge base can be dynamically updated without retraining 20 | - entity categories can be easily resolved 21 | - grouping entities by category 22 | 23 | It also comes along with a number of disadvantages: 24 | 25 | - it is slower than the spaCy implementation due to the use of a database for finding entities 26 | - no context sensitivity due to the implementation of the "max-prior method" for entitiy disambiguation (an improved 27 | method for this is in progress) 28 | 29 | 30 | ## Installation 31 | 32 | To install the package, run: 33 | ```bash 34 | pip install spacy-entity-linker 35 | ``` 36 | 37 | Afterwards, the knowledge base (Wikidata) must be downloaded. This can be either be done by manually calling 38 | 39 | ```bash 40 | python -m spacy_entity_linker "download_knowledge_base" 41 | ``` 42 | 43 | or when you first access the entity linker through spacy. 44 | This will download and extract a ~1.3GB file that contains a preprocessed version of Wikidata. 
45 | 46 | ## Use 47 | 48 | ```python 49 | import spacy # version 3.5 50 | 51 | # initialize language model 52 | nlp = spacy.load("en_core_web_md") 53 | 54 | # add pipeline (declared through entry_points in setup.py) 55 | nlp.add_pipe("entityLinker", last=True) 56 | 57 | doc = nlp("I watched the Pirates of the Caribbean last silvester") 58 | 59 | # returns all entities in the whole document 60 | all_linked_entities = doc._.linkedEntities 61 | # iterates over sentences and prints linked entities 62 | for sent in doc.sents: 63 | sent._.linkedEntities.pretty_print() 64 | 65 | # OUTPUT: 66 | # https://www.wikidata.org/wiki/Q194318 Pirates of the Caribbean Series of fantasy adventure films 67 | # https://www.wikidata.org/wiki/Q12525597 Silvester the day celebrated on 31 December (Roman Catholic Church) or 2 January (Eastern Orthodox Churches) 68 | 69 | # entities are also directly accessible through spans 70 | doc[3:7]._.linkedEntities.pretty_print() 71 | # OUTPUT: 72 | # https://www.wikidata.org/wiki/Q194318 Pirates of the Caribbean Series of fantasy adventure films 73 | ``` 74 | 75 | ### EntityCollection 76 | 77 | contains an array of entity elements. It can be accessed like an array but also implements the following helper 78 | functions: 79 | 80 | - pretty_print() prints out information about all contained entities 81 | - print_super_classes() groups and prints all entites by their super class 82 | 83 | ```python 84 | doc = nlp("Elon Musk was born in South Africa. 
Bill Gates and Steve Jobs come from the United States") 85 | doc._.linkedEntities.print_super_entities() 86 | # OUTPUT: 87 | # human (3) : Elon Musk,Bill Gates,Steve Jobs 88 | # country (2) : South Africa,United States of America 89 | # sovereign state (2) : South Africa,United States of America 90 | # federal state (1) : United States of America 91 | # constitutional republic (1) : United States of America 92 | # democratic republic (1) : United States of America 93 | ``` 94 | 95 | ### EntityElement 96 | 97 | each linked Entity is an object of type EntityElement. Each entity contains the methods 98 | 99 | - get_description() returns description from Wikidata 100 | - get_id() returns Wikidata ID 101 | - get_label() returns Wikidata label 102 | - get_span(doc) returns the span from the spacy document that contains the linked entity. You need to provide the current `doc` as argument, in order to receive an actual `spacy.tokens.Span` object, otherwise you will receive a `SpanInfo` emulating the behaviour of a Span 103 | - get_url() returns the url to the corresponding Wikidata item 104 | - pretty_print() prints out information about the entity element 105 | - get_sub_entities(limit=10) returns EntityCollection of all entities that derive from the current 106 | entityElement (e.g. fruit -> apple, banana, etc.) 107 | - get_super_entities(limit=10) returns EntityCollection of all entities that the current entityElement 108 | derives from (e.g. New England Patriots -> Football Team)) 109 | 110 | 111 | Usage of the `get_span` method with `SpanInfo`: 112 | 113 | ```python 114 | import spacy 115 | nlp = spacy.load('en_core_web_md') 116 | nlp.add_pipe("entityLinker", last=True) 117 | text = 'Apple is competing with Microsoft.' 
118 | doc = nlp(text) 119 | sents = list(doc.sents) 120 | ent = doc._.linkedEntities[0] 121 | 122 | # using the SpanInfo class 123 | span = ent.get_span() 124 | print(span.start, span.end, span.text) # behaves like a Span 125 | 126 | # check equivalence 127 | print(span == doc[0:1]) # True 128 | print(doc[0:1] == span) # TypeError: Argument 'other' has incorrect type (expected spacy.tokens.span.Span, got SpanInfo) 129 | 130 | # now get the real span 131 | span = ent.get_span(doc) # passing the doc instance here 132 | print(span.start, span.end, span.text) 133 | 134 | print(span == doc[0:1]) # True 135 | print(doc[0:1] == span) # True 136 | ``` 137 | 138 | ## Example 139 | 140 | In the following example we will use SpacyEntityLinker to find find the mentioned Football Team in our text and explore 141 | other football teams of the same type 142 | 143 | ```python 144 | 145 | doc = nlp("I follow the New England Patriots") 146 | 147 | patriots_entity = doc._.linkedEntities[0] 148 | patriots_entity.pretty_print() 149 | # OUTPUT: 150 | # https://www.wikidata.org/wiki/Q193390 151 | # New England Patriots 152 | # National Football League franchise in Foxborough, Massachusetts 153 | 154 | football_team_entity = patriots_entity.get_super_entities()[0] 155 | football_team_entity.pretty_print() 156 | # OUTPUT: 157 | # https://www.wikidata.org/wiki/Q17156793 158 | # American football team 159 | # organization, in which a group of players are organized to compete as a team in American football 160 | 161 | 162 | for child in football_team_entity.get_sub_entities(limit=32): 163 | print(child) 164 | # OUTPUT: 165 | # New Orleans Saints 166 | # New York Giants 167 | # Pittsburgh Steelers 168 | # New England Patriots 169 | # Indianapolis Colts 170 | # Miami Seahawks 171 | # Dallas Cowboys 172 | # Chicago Bears 173 | # Washington Redskins 174 | # Green Bay Packers 175 | # ... 
176 | ``` 177 | 178 | ### Entity Linking Policy 179 | 180 | Currently the only method for choosing an entity given different possible matches (e.g. Paris - city vs Paris - 181 | firstname) is max-prior. This method achieves around 70% accuracy on predicting the correct entities behind link 182 | descriptions on wikipedia. 183 | 184 | ## Note 185 | 186 | The Entity Linker at the current state is still experimental and should not be used in production mode. 187 | 188 | ## Performance 189 | 190 | The current implementation supports only Sqlite. This is advantageous for development because it does not requirement 191 | any special setup and configuration. However, for more performance critical usecases, a different database with 192 | in-memory access (e.g. Redis) should be used. This may be implemented in the future. 193 | 194 | ## Data 195 | the knowledge base was derived from this dataset: https://www.kaggle.com/kenshoresearch/kensho-derived-wikimedia-data 196 | 197 | It was cleaned and post-procesed, including filtering out entities of "overrepresented" categories such as 198 | * village in China 199 | * train stations 200 | * stars in the Galaxy 201 | * etc. 202 | 203 | The purpose behind the knowledge base cleaning was to reduce the knowledge base size, while keeping the most useful entities for general purpose applications. 204 | 205 | Currently, the only way to change the knowledge base is a bit hacky and requires to replace or modify the underlying sqlite database. You will find it under site_packages/data_spacy_entity_linker/wikidb_filtered.db. 
The database contains 3 tables: 206 | * aliases 207 | * en_alias (english alias) 208 | * en_alias_lowercase (english alias lowercased) 209 | * joined 210 | * en_label (label of the wikidata item) 211 | * views (number of views of the corresponding wikipedia page (in a given period of time)) 212 | * inlinks (number of inlinks to the corresponding wikipedia page) 213 | * item_id (wikidata id) 214 | * description (description of the wikidata item) 215 | * statements 216 | * source_item_id (references item_id) 217 | * target_item_id (references item_id) 218 | * edge_property_id 219 | * 279=subclass of (https://www.wikidata.org/wiki/Property:P279) 220 | * 31=instance of (https://www.wikidata.org/wiki/Property:P31) 221 | * 361=part of (https://www.wikidata.org/wiki/Property:P361) 222 | 223 | 224 | ## Versions: 225 | 226 | - spacy_entity_linker>=0.0 (requires spacy>=2.2,<3.0) 227 | - spacy_entity_linker>=1.0 (requires spacy>=3.0) 228 | 229 | ## TODO 230 | 231 | - [ ] implement Entity Classifier based on sentence embeddings for improved accuracy 232 | - [ ] implement get_picture_urls() on EntityElement 233 | - [ ] retrieve statements for each EntityElement (inlinks + outlinks) 234 | -------------------------------------------------------------------------------- /downloadKnowledgeBase.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget "https://huggingface.co/MartinoMensio/spaCy-entity-linker/resolve/main/knowledge_base.tar.gz" -O /tmp/knowledge_base.tar.gz 4 | tar -xzf /tmp/knowledge_base.tar.gz --directory ./data_spacy_entity_linker 5 | rm /tmp/knowledge_base.tar.gz 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import os 6 | 7 | try: 8 | from setuptools import setup 9 | except ImportError: 10 | from distutils.core import setup 11 | 12 | 13 | def open_file(fname): 14 | return open(os.path.join(os.path.dirname(__file__), fname)) 15 | 16 | 17 | with open("README.md", "r") as fh: 18 | long_description = fh.read() 19 | 20 | setup( 21 | name='spacy-entity-linker', 22 | version='1.0.3', 23 | author='Emanuel Gerber', 24 | author_email='emanuel.j.gerber@gmail.com', 25 | packages=['spacy_entity_linker'], 26 | url='https://github.com/egerber/spacy-entity-linker', 27 | license="MIT", 28 | classifiers=["Environment :: Console", 29 | "Intended Audience :: Developers", 30 | "Intended Audience :: Science/Research", 31 | "License :: OSI Approved :: MIT License", 32 | "Programming Language :: Cython", 33 | "Programming Language :: Python", 34 | "Programming Language :: Python :: 3.6" 35 | ], 36 | description='Linked Entity Pipeline for spaCy', 37 | long_description=long_description, 38 | long_description_content_type="text/markdown", 39 | zip_safe=True, 40 | install_requires=[ 41 | 'spacy>=3.0.0', 42 | 'numpy>=1.0.0', 43 | 'tqdm' 44 | ], 45 | entry_points={ 46 | 'spacy_factories': 'entityLinker = spacy_entity_linker.EntityLinker:EntityLinker' 47 | } 48 | ) 49 | -------------------------------------------------------------------------------- /spacy_entity_linker/DatabaseConnection.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import os 3 | 4 | from .__main__ import download_knowledge_base 5 | 6 | MAX_DEPTH_CHAIN = 10 7 | P_INSTANCE_OF = 31 8 | P_SUBCLASS = 279 9 | 10 | MAX_ITEMS_CACHE = 100000 11 | 12 | conn = None 13 | entity_cache = {} 14 | chain_cache = {} 15 | 16 | DB_DEFAULT_PATH = 
os.path.abspath(os.path.join(__file__, "../../data_spacy_entity_linker/wikidb_filtered.db")) 17 | 18 | wikidata_instance = None 19 | 20 | 21 | def get_wikidata_instance(): 22 | global wikidata_instance 23 | 24 | if wikidata_instance is None: 25 | wikidata_instance = WikidataQueryController() 26 | 27 | return wikidata_instance 28 | 29 | 30 | class WikidataQueryController: 31 | 32 | def __init__(self): 33 | self.conn = None 34 | 35 | self.cache = { 36 | "entity": {}, 37 | "chain": {}, 38 | "name": {} 39 | } 40 | 41 | self.init_database_connection() 42 | 43 | def _get_cached_value(self, cache_type, key): 44 | return self.cache[cache_type][key] 45 | 46 | def _is_cached(self, cache_type, key): 47 | return key in self.cache[cache_type] 48 | 49 | def _add_to_cache(self, cache_type, key, value): 50 | if len(self.cache[cache_type]) < MAX_ITEMS_CACHE: 51 | self.cache[cache_type][key] = value 52 | 53 | def init_database_connection(self, path=DB_DEFAULT_PATH): 54 | # check if the database exists 55 | if not os.path.exists(DB_DEFAULT_PATH): 56 | # Automatically download the knowledge base if it isn't already 57 | download_knowledge_base() 58 | self.conn = sqlite3.connect(path, check_same_thread=False) 59 | 60 | def clear_cache(self): 61 | self.cache["entity"].clear() 62 | self.cache["chain"].clear() 63 | self.cache["name"].clear() 64 | 65 | def get_entities_from_alias(self, alias): 66 | c = self.conn.cursor() 67 | if self._is_cached("entity", alias): 68 | return self._get_cached_value("entity", alias).copy() 69 | 70 | query_alias = """SELECT j.item_id,j.en_label, j.en_description,j.views,j.inlinks,a.en_alias 71 | FROM aliases as a LEFT JOIN joined as j ON a.item_id = j.item_id 72 | WHERE a.en_alias_lowercase = ? 
AND j.item_id NOT NULL""" 73 | 74 | c.execute(query_alias, [alias.lower()]) 75 | fetched_rows = c.fetchall() 76 | 77 | self._add_to_cache("entity", alias, fetched_rows) 78 | return fetched_rows 79 | 80 | def get_instances_of(self, item_id, properties=[P_INSTANCE_OF, P_SUBCLASS], count=1000): 81 | query = "SELECT source_item_id from statements where target_item_id={} and edge_property_id IN ({}) LIMIT {}".format( 82 | item_id, ",".join([str(prop) for prop in properties]), count) 83 | 84 | c = self.conn.cursor() 85 | c.execute(query) 86 | 87 | res = c.fetchall() 88 | 89 | return [e[0] for e in res] 90 | 91 | def get_entity_name(self, item_id): 92 | if self._is_cached("name", item_id): 93 | return self._get_cached_value("name", item_id) 94 | 95 | c = self.conn.cursor() 96 | query = "SELECT en_label from joined WHERE item_id=?" 97 | c.execute(query, [item_id]) 98 | res = c.fetchone() 99 | 100 | if res and len(res): 101 | if res[0] is None: 102 | self._add_to_cache("name", item_id, 'no label') 103 | else: 104 | self._add_to_cache("name", item_id, res[0]) 105 | else: 106 | self._add_to_cache("name", item_id, '') 107 | 108 | return self._get_cached_value("name", item_id) 109 | 110 | def get_entity(self, item_id): 111 | c = self.conn.cursor() 112 | query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \ 113 | "WHERE j.item_id=={}".format(item_id) 114 | 115 | res = c.execute(query) 116 | 117 | return res.fetchone() 118 | 119 | def get_children(self, item_id, limit=100): 120 | c = self.conn.cursor() 121 | query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \ 122 | "JOIN statements as s on j.item_id=s.source_item_id " \ 123 | "WHERE s.target_item_id={} and s.edge_property_id IN (279,31) LIMIT {}".format(item_id, limit) 124 | 125 | res = c.execute(query) 126 | 127 | return res.fetchall() 128 | 129 | def get_parents(self, item_id, limit=100): 130 | c = self.conn.cursor() 131 | query = "SELECT 
j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \ 132 | "JOIN statements as s on j.item_id=s.target_item_id " \ 133 | "WHERE s.source_item_id={} and s.edge_property_id IN (279,31) LIMIT {}".format(item_id, limit) 134 | 135 | res = c.execute(query) 136 | 137 | return res.fetchall() 138 | 139 | def get_categories(self, item_id, max_depth=10): 140 | chain = [] 141 | edges = [] 142 | self._append_chain_elements(item_id, 0, chain, edges, max_depth, [P_INSTANCE_OF, P_SUBCLASS]) 143 | return [el[0] for el in chain] 144 | 145 | def get_chain(self, item_id, max_depth=10, property=P_INSTANCE_OF): 146 | chain = [] 147 | edges = [] 148 | self._append_chain_elements(item_id, 0, chain, edges, max_depth, property) 149 | return chain 150 | 151 | def get_recursive_edges(self, item_id): 152 | chain = [] 153 | edges = [] 154 | self._append_chain_elements(self, item_id, 0, chain, edges) 155 | return edges 156 | 157 | def _append_chain_elements(self, item_id, level=0, chain=None, edges=None, max_depth=10, prop=P_INSTANCE_OF): 158 | if chain is None: 159 | chain = [] 160 | if edges is None: 161 | edges = [] 162 | properties = prop 163 | if type(prop) != list: 164 | properties = [prop] 165 | 166 | if self._is_cached("chain", (item_id, max_depth)): 167 | chain += self._get_cached_value("chain", (item_id, max_depth)).copy() 168 | return 169 | 170 | # prevent infinite recursion 171 | if level >= max_depth: 172 | return 173 | 174 | c = self.conn.cursor() 175 | 176 | query = "SELECT target_item_id,edge_property_id from statements where source_item_id={} and edge_property_id IN ({})".format( 177 | item_id, ",".join([str(prop) for prop in properties])) 178 | 179 | # set value for current item in order to prevent infinite recursion 180 | self._add_to_cache("chain", (item_id, max_depth), []) 181 | 182 | for target_item in c.execute(query): 183 | 184 | chain_ids = [el[0] for el in chain] 185 | 186 | if not (target_item[0] in chain_ids): 187 | chain += [(target_item[0], 
level + 1)] 188 | edges.append((item_id, target_item[0], target_item[1])) 189 | self._append_chain_elements(target_item[0], 190 | level=level + 1, 191 | chain=chain, 192 | edges=edges, 193 | max_depth=max_depth, 194 | prop=prop) 195 | 196 | self._add_to_cache("chain", (item_id, max_depth), chain) 197 | 198 | 199 | if __name__ == '__main__': 200 | queryInstance = WikidataQueryController() 201 | 202 | queryInstance.init_database_connection() 203 | print(queryInstance.get_categories(13191, max_depth=1)) 204 | print(queryInstance.get_categories(13191, max_depth=1)) 205 | -------------------------------------------------------------------------------- /spacy_entity_linker/EntityCandidates.py: -------------------------------------------------------------------------------- 1 | MAX_ITEMS_PREVIEW=20 2 | 3 | class EntityCandidates: 4 | 5 | def __init__(self, entity_elements): 6 | self.entity_elements = entity_elements 7 | 8 | def __iter__(self): 9 | for entity in self.entity_elements: 10 | yield entity 11 | 12 | def __len__(self): 13 | return len(self.entity_elements) 14 | 15 | def __getitem__(self, item): 16 | return self.entity_elements[item] 17 | 18 | def pretty_print(self): 19 | for entity in self.entity_elements: 20 | entity.pretty_print() 21 | 22 | def __repr__(self) -> str: 23 | preview_str="" 24 | for index,entity_element in enumerate(self): 25 | if index>MAX_ITEMS_PREVIEW: 26 | break 27 | preview_str+="{}\n".format(entity_element.get_preview_string()) 28 | 29 | return preview_str 30 | 31 | def __str__(self): 32 | return str(["entity {}: {} (<{}>)".format(i, entity.get_label(), entity.get_description()) for i, entity in 33 | enumerate(self.entity_elements)]) 34 | -------------------------------------------------------------------------------- /spacy_entity_linker/EntityClassifier.py: -------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | import numpy as np 3 | 4 | 5 | class EntityClassifier: 6 | def 
__init__(self): 7 | pass 8 | 9 | def _get_grouped_by_length(self, entities): 10 | sorted_by_len = sorted(entities, key=lambda entity: len(entity.get_span()), reverse=True) 11 | 12 | entities_by_length = {} 13 | for length, group in groupby(sorted_by_len, lambda entity: len(entity.get_span())): 14 | entities = list(group) 15 | entities_by_length[length] = entities 16 | 17 | return entities_by_length 18 | 19 | def _filter_max_length(self, entities): 20 | entities_by_length = self._get_grouped_by_length(entities) 21 | max_length = max(list(entities_by_length.keys())) 22 | 23 | return entities_by_length[max_length] 24 | 25 | def _select_max_prior(self, entities): 26 | priors = [entity.get_prior() for entity in entities] 27 | return entities[np.argmax(priors)] 28 | 29 | def _get_casing_difference(self, word1, original): 30 | difference = 0 31 | for w1, w2 in zip(word1, original): 32 | if w1 != w2: 33 | difference += 1 34 | 35 | return difference 36 | 37 | def _filter_most_similar(self, entities): 38 | similarities = np.array( 39 | [self._get_casing_difference(entity.get_span().text, entity.get_original_alias()) for entity in entities]) 40 | 41 | min_indices = np.where(similarities == similarities.min())[0].tolist() 42 | 43 | return [entities[i] for i in min_indices] 44 | 45 | def __call__(self, entities): 46 | filtered_by_length = self._filter_max_length(entities) 47 | filtered_by_casing = self._filter_most_similar(filtered_by_length) 48 | 49 | return self._select_max_prior(filtered_by_casing) 50 | -------------------------------------------------------------------------------- /spacy_entity_linker/EntityCollection.py: -------------------------------------------------------------------------------- 1 | import srsly 2 | from collections import Counter, defaultdict 3 | 4 | from .DatabaseConnection import get_wikidata_instance 5 | 6 | MAX_ITEMS_PREVIEW=20 7 | 8 | 9 | class EntityCollection: 10 | 11 | def __init__(self, entities=[]): 12 | self.entities = entities 13 | 14 | 
def __iter__(self): 15 | for entity in self.entities: 16 | yield entity 17 | 18 | def __getitem__(self, item): 19 | return self.entities[item] 20 | 21 | def __len__(self): 22 | return len(self.entities) 23 | 24 | def append(self, entity): 25 | self.entities.append(entity) 26 | 27 | def get_categories(self, max_depth=1): 28 | categories = [] 29 | for entity in self.entities: 30 | categories += entity.get_categories(max_depth) 31 | 32 | return categories 33 | 34 | def print_super_entities(self, max_depth=1, limit=10): 35 | wikidataInstance = get_wikidata_instance() 36 | 37 | all_categories = [] 38 | category_to_entites = defaultdict(list) 39 | 40 | for e in self.entities: 41 | for category in e.get_categories(max_depth): 42 | category_to_entites[category].append(e) 43 | all_categories.append(category) 44 | 45 | counter = Counter() 46 | counter.update(all_categories) 47 | 48 | for category, frequency in counter.most_common(limit): 49 | print("{} ({}) : {}".format(wikidataInstance.get_entity_name(category), frequency, 50 | ','.join([str(e) for e in category_to_entites[category]]))) 51 | 52 | def __repr__(self) -> str: 53 | preview_str="MAX_ITEMS_PREVIEW: 56 | preview_str+="\n...{} more".format(len(self)-MAX_ITEMS_PREVIEW) 57 | break 58 | preview_str+="\n-{}".format(entity_element.get_preview_string()) 59 | 60 | preview_str+=">" 61 | return preview_str 62 | 63 | def pretty_print(self): 64 | for entity in self.entities: 65 | entity.pretty_print() 66 | 67 | def grouped_by_super_entities(self, max_depth=1): 68 | counter = Counter() 69 | counter.update(self.get_categories(max_depth)) 70 | 71 | return counter 72 | 73 | def get_distinct_categories(self, max_depth=1): 74 | return list(set(self.get_categories(max_depth))) 75 | 76 | 77 | @srsly.msgpack_encoders("EntityCollection") 78 | def serialize_obj(obj, chain=None): 79 | if isinstance(obj, EntityCollection): 80 | return { 81 | "entities": obj.entities, 82 | } 83 | # otherwise return the original object so another serializer 
can handle it 84 | return obj if chain is None else chain(obj) 85 | 86 | 87 | @srsly.msgpack_decoders("EntityCollection") 88 | def deserialize_obj(obj, chain=None): 89 | if "entities" in obj: 90 | return EntityCollection(entities=obj["entities"]) 91 | # otherwise return the original object so another serializer can handle it 92 | return obj if chain is None else chain(obj) -------------------------------------------------------------------------------- /spacy_entity_linker/EntityElement.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import srsly 3 | 4 | from .DatabaseConnection import get_wikidata_instance 5 | from .EntityCollection import EntityCollection 6 | from .SpanInfo import SpanInfo 7 | 8 | class EntityElement: 9 | def __init__(self, row, span): 10 | self.identifier = row[0] 11 | self.prior = 0 12 | self.original_alias = None 13 | self.in_degree = None 14 | self.label = None 15 | self.description = None 16 | 17 | if len(row) > 1: 18 | self.label = row[1] 19 | if len(row) > 2: 20 | self.description = row[2] 21 | if len(row) > 3 and row[3]: 22 | self.prior = row[3] 23 | if len(row) > 4 and row[4]: 24 | self.in_degree = row[4] 25 | if len(row) > 5 and row[5]: 26 | self.original_alias = row[5] 27 | 28 | self.url="https://www.wikidata.org/wiki/Q{}".format(self.get_id()) 29 | if span: 30 | self.span_info = SpanInfo.from_span(span) 31 | else: 32 | # sometimes the constructor is called with None as second parameter (e.g. 
in get_sub_entities/get_super_entities) 33 | self.span_info = None 34 | 35 | self.chain = None 36 | self.chain_ids = None 37 | 38 | self.wikidata_instance = get_wikidata_instance() 39 | 40 | def get_in_degree(self): 41 | return self.in_degree 42 | 43 | def get_original_alias(self): 44 | return self.original_alias 45 | 46 | def is_singleton(self): 47 | return len(self.get_chain()) == 0 48 | 49 | def get_span(self, doc: spacy.tokens.Doc=None): 50 | """ 51 | Returns the span of the entity in the document. 52 | :param doc: the document in which the entity is contained 53 | :return: the span of the entity in the document 54 | 55 | If the doc is not None, it returns a real spacy.tokens.Span. 56 | Otherwise it returns the instance of SpanInfo that emulates the behaviour of a spacy.tokens.Span 57 | """ 58 | if doc is not None: 59 | # return a real spacy.tokens.Span 60 | return self.span_info.get_span(doc) 61 | # otherwise return the instance of SpanInfo that emulates the behaviour of a spacy.tokens.Span 62 | return self.span_info 63 | 64 | def get_label(self): 65 | return self.label 66 | 67 | def get_id(self): 68 | return self.identifier 69 | 70 | def get_prior(self): 71 | return self.prior 72 | 73 | def get_chain(self, max_depth=10): 74 | if self.chain is None: 75 | self.chain = self.wikidata_instance.get_chain(self.identifier, max_depth=max_depth, property=31) 76 | return self.chain 77 | 78 | def is_category(self): 79 | pass 80 | 81 | def is_leaf(self): 82 | pass 83 | 84 | def get_categories(self, max_depth=10): 85 | return self.wikidata_instance.get_categories(self.identifier, max_depth=max_depth) 86 | 87 | def get_sub_entities(self, limit=10): 88 | return EntityCollection( 89 | [EntityElement(row, None) for row in self.wikidata_instance.get_children(self.get_id(), limit)]) 90 | 91 | def get_super_entities(self, limit=10): 92 | return EntityCollection( 93 | [EntityElement(row, None) for row in self.wikidata_instance.get_parents(self.get_id(), limit)]) 94 | 95 | def 
get_subclass_hierarchy(self): 96 | chain = self.wikidata_instance.get_chain(self.identifier, max_depth=5, property=279) 97 | return [self.wikidata_instance.get_entity_name(el[0]) for el in chain] 98 | 99 | def get_instance_of_hierarchy(self): 100 | chain = self.wikidata_instance.get_chain(self.identifier, max_depth=5, property=31) 101 | return [self.wikidata_instance.get_entity_name(el[0]) for el in chain] 102 | 103 | def get_chain_ids(self, max_depth=10): 104 | if self.chain_ids is None: 105 | self.chain_ids = set([el[0] for el in self.get_chain(max_depth=max_depth)]) 106 | 107 | return self.chain_ids 108 | 109 | def get_description(self): 110 | if self.description: 111 | return self.description 112 | else: 113 | return "" 114 | 115 | def is_intersecting(self, other_element): 116 | return len(self.get_chain_ids().intersection(other_element.get_chain_ids())) > 0 117 | 118 | def serialize(self): 119 | return { 120 | "id": self.get_id(), 121 | "label": self.get_label(), 122 | "span": self.get_span() 123 | } 124 | 125 | def pretty_print(self): 126 | print(self.__repr__()) 127 | 128 | def get_url(self): 129 | return self.url 130 | 131 | def __repr__(self): 132 | return "".format(self.get_preview_string()) 133 | 134 | def get_preview_string(self): 135 | return "{0:<10} {1:<25} {2:<50}".format(self.get_url(),self.get_label(),self.get_description()[:100]) 136 | 137 | def pretty_string(self, description=False): 138 | if description: 139 | return "{} => {} <{}>".format(self.span_info, self.get_label(), self.get_description()) 140 | else: 141 | return "{} => {}".format(self.span_info, self.get_label()) 142 | 143 | # TODO: this method has never worked because the custom attribute is not registered properly 144 | # def save(self, category): 145 | # for span in self.span: 146 | # span.sent._.linked_entities.append( 147 | # {"id": self.identifier, "range": [span.start, span.end + 1], "category": category}) 148 | 149 | def __str__(self): 150 | label = self.get_label() 151 | if 
label: 152 | return label 153 | else: 154 | return "" 155 | 156 | def __eq__(self, other): 157 | return isinstance(other, EntityElement) and other.get_id() == self.get_id() 158 | 159 | 160 | @srsly.msgpack_encoders("EntityElement") 161 | def serialize_obj(obj, chain=None): 162 | if isinstance(obj, EntityElement): 163 | result = { 164 | "identifier": obj.identifier, 165 | "label": obj.label, 166 | "description": obj.description, 167 | "prior": obj.prior, 168 | "in_degree": obj.in_degree, 169 | "original_alias": obj.original_alias, 170 | "span_info": obj.span_info, 171 | } 172 | return result 173 | # otherwise return the original object so another serializer can handle it 174 | return obj if chain is None else chain(obj) 175 | 176 | 177 | @srsly.msgpack_decoders("EntityElement") 178 | def deserialize_obj(obj, chain=None): 179 | if "identifier" in obj: 180 | row = [obj['identifier'], obj['label'], obj['description'], obj['prior'], obj['in_degree'], obj['original_alias']] 181 | span_info = obj['span_info'] 182 | return EntityElement(row, span_info) 183 | # otherwise return the original object so another serializer can handle it 184 | return obj if chain is None else chain(obj) -------------------------------------------------------------------------------- /spacy_entity_linker/EntityLinker.py: -------------------------------------------------------------------------------- 1 | from spacy.tokens import Doc, Span 2 | from spacy.language import Language 3 | 4 | from .EntityClassifier import EntityClassifier 5 | from .EntityCollection import EntityCollection 6 | from .TermCandidateExtractor import TermCandidateExtractor 7 | 8 | @Language.factory('entityLinker') 9 | class EntityLinker: 10 | 11 | def __init__(self, nlp, name): 12 | Doc.set_extension("linkedEntities", default=EntityCollection(), force=True) 13 | Span.set_extension("linkedEntities", default=None, force=True) 14 | 15 | def __call__(self, doc): 16 | tce = TermCandidateExtractor(doc) 17 | classifier = 
EntityClassifier() 18 | 19 | for sent in doc.sents: 20 | sent._.linkedEntities = EntityCollection([]) 21 | 22 | entities = [] 23 | for termCandidates in tce: 24 | entityCandidates = termCandidates.get_entity_candidates() 25 | if len(entityCandidates) > 0: 26 | entity = classifier(entityCandidates) 27 | span = doc[entity.span_info.start:entity.span_info.end] 28 | # Add the entity to the sentence-level EntityCollection 29 | span.sent._.linkedEntities.append(entity) 30 | # Also associate the token span with the entity 31 | span._.linkedEntities = entity 32 | # And finally append to the document-level collection 33 | entities.append(entity) 34 | 35 | doc._.linkedEntities = EntityCollection(entities) 36 | 37 | return doc 38 | -------------------------------------------------------------------------------- /spacy_entity_linker/SpanInfo.py: -------------------------------------------------------------------------------- 1 | """ 2 | SpanInfo class 3 | Stores the info of spacy.tokens.Span (start, end and text of a span) by making it serializable 4 | """ 5 | 6 | import spacy 7 | import srsly 8 | 9 | class SpanInfo: 10 | 11 | @staticmethod 12 | def from_span(span: spacy.tokens.Span): 13 | return SpanInfo(span.start, span.end, span.text) 14 | 15 | def __init__(self, start: int, end: int, text: str): 16 | self.start = start 17 | self.end = end 18 | self.text = text 19 | 20 | 21 | def __repr__(self) -> str: 22 | return self.text 23 | 24 | def __len__(self): 25 | return self.end - self.start 26 | 27 | def __eq__(self, __o: object) -> bool: 28 | if isinstance(__o, SpanInfo) or isinstance(__o, spacy.tokens.Span): 29 | return self.start == __o.start and self.end == __o.end and self.text == __o.text 30 | return False 31 | 32 | def get_span(self, doc: spacy.tokens.Doc): 33 | """ 34 | Returns the real spacy.tokens.Span of the doc from the stored info""" 35 | return doc[self.start:self.end] 36 | 37 | 38 | @srsly.msgpack_encoders("SpanInfo") 39 | def serialize_spaninfo(obj, chain=None): 
    if isinstance(obj, SpanInfo):
        result = {
            "start": obj.start,
            "end": obj.end,
            "text": obj.text,
        }
        return result
    # otherwise return the original object so another serializer can handle it
    return obj if chain is None else chain(obj)

@srsly.msgpack_decoders("SpanInfo")
def deserialize_spaninfo(obj, chain=None):
    if "start" in obj:
        return SpanInfo(obj['start'], obj['end'], obj['text'])
    # otherwise return the original object so another serializer can handle it
    return obj if chain is None else chain(obj)


--------------------------------------------------------------------------------
/spacy_entity_linker/TermCandidate.py:
--------------------------------------------------------------------------------

from .EntityCandidates import EntityCandidates
from .EntityElement import EntityElement
from .DatabaseConnection import get_wikidata_instance


class TermCandidate:
    # A surface term plus its spelling variations, collected before the
    # knowledge-base alias lookup.
    def __init__(self, span):
        self.variations = [span]

    def pretty_print(self):
        print("Term Candidates are [{}]".format(self))

    def append(self, span):
        self.variations.append(span)

    def has_plural(self, variation):
        # NNS is the Penn Treebank tag for plural common nouns.
        return any([t.tag_ == "NNS" for t in variation])

    def get_singular(self, variation):
        # Replace each plural token with its lemma to build a singular form.
        return ' '.join([t.text if t.tag_ != "NNS" else t.lemma_ for t in variation])

    def __str__(self):
        return ', '.join([variation.text for variation in self.variations])

    def get_entity_candidates(self):
        # Look up every variation (plus its singular form when plural) in the
        # knowledge base, then wrap all hits as EntityElement candidates.
        wikidata_instance = get_wikidata_instance()
        entities_by_variation = {}
        for variation in self.variations:
            entities_by_variation[variation] = wikidata_instance.get_entities_from_alias(variation.text)
            if self.has_plural(variation):
                entities_by_variation[variation] += wikidata_instance.get_entities_from_alias(
                    self.get_singular(variation))

        entity_elements = []
        for variation, entities in
entities_by_variation.items(): 36 | entity_elements += [EntityElement(entity, variation) for entity in entities] 37 | 38 | return EntityCandidates(entity_elements) 39 | -------------------------------------------------------------------------------- /spacy_entity_linker/TermCandidateExtractor.py: -------------------------------------------------------------------------------- 1 | from .TermCandidate import TermCandidate 2 | 3 | 4 | class TermCandidateExtractor: 5 | def __init__(self, doc): 6 | self.doc = doc 7 | 8 | def __iter__(self): 9 | for sent in self.doc.sents: 10 | for candidate in self._get_candidates_in_sent(sent, self.doc): 11 | yield candidate 12 | 13 | def _get_candidates_in_sent(self, sent, doc): 14 | roots = list(filter(lambda token: token.dep_ == "ROOT", sent)) 15 | if len(roots) < 1: 16 | return [] 17 | root = roots[0] 18 | 19 | excluded_children = [] 20 | candidates = [] 21 | 22 | def get_candidates(node, doc): 23 | 24 | if (node.pos_ in ["PROPN", "NOUN"]) and node.pos_ not in ["PRON"]: 25 | term_candidates = TermCandidate(doc[node.i:node.i + 1]) 26 | 27 | for child in node.children: 28 | 29 | start_index = min(node.i, child.i) 30 | end_index = max(node.i, child.i) 31 | 32 | if child.dep_ == "compound" or child.dep_ == "amod": 33 | subtree_tokens = list(child.subtree) 34 | if all([c.dep_ == "compound" for c in subtree_tokens]): 35 | start_index = min([c.i for c in subtree_tokens]) 36 | term_candidates.append(doc[start_index:end_index + 1]) 37 | 38 | if not child.dep_ == "amod": 39 | term_candidates.append(doc[start_index:start_index + 1]) 40 | excluded_children.append(child) 41 | 42 | if child.dep_ == "prep" and child.text == "of": 43 | end_index = max([c.i for c in child.subtree]) 44 | term_candidates.append(doc[start_index:end_index + 1]) 45 | 46 | candidates.append(term_candidates) 47 | 48 | for child in node.children: 49 | if child in excluded_children: 50 | continue 51 | get_candidates(child, doc) 52 | 53 | get_candidates(root, doc) 54 | 55 | 
return candidates 56 | -------------------------------------------------------------------------------- /spacy_entity_linker/__init__.py: -------------------------------------------------------------------------------- 1 | try: # Python 3.8 2 | import importlib.metadata as importlib_metadata 3 | except ImportError: 4 | import importlib_metadata # noqa: F401 5 | 6 | from .EntityLinker import EntityLinker 7 | 8 | pkg_meta = importlib_metadata.metadata(__name__.split(".")[0]) 9 | __version__ = pkg_meta["version"] 10 | __all__ = [EntityLinker] 11 | -------------------------------------------------------------------------------- /spacy_entity_linker/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tarfile 3 | import urllib.request 4 | import tqdm 5 | import os 6 | 7 | 8 | class DownloadProgressBar(tqdm.tqdm): 9 | """ 10 | Code taken from https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads 11 | """ 12 | def update_to(self, chunk_id=1, max_chunk_size=1, total_size=None): 13 | if total_size is not None: 14 | self.total = total_size 15 | self.update(chunk_id * max_chunk_size - self.n) 16 | 17 | 18 | def download_knowledge_base( 19 | file_url="https://huggingface.co/MartinoMensio/spaCy-entity-linker/resolve/main/knowledge_base.tar.gz" 20 | ): 21 | OUTPUT_TAR_FILE = os.path.abspath( 22 | os.path.dirname(__file__)) + '/../data_spacy_entity_linker/wikidb_filtered.tar.gz' 23 | OUTPUT_DB_PATH = os.path.abspath(os.path.dirname(__file__)) + '/../data_spacy_entity_linker' 24 | if not os.path.exists(OUTPUT_DB_PATH): 25 | os.makedirs(OUTPUT_DB_PATH) 26 | with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc='Downloading knowledge base') as dpb: 27 | urllib.request.urlretrieve(file_url, filename=OUTPUT_TAR_FILE, reporthook=dpb.update_to) 28 | 29 | tar = tarfile.open(OUTPUT_TAR_FILE) 30 | tar.extractall(OUTPUT_DB_PATH) 31 | tar.close() 32 | 33 | os.remove(OUTPUT_TAR_FILE) 34 
| 35 | 36 | if __name__ == "__main__": 37 | 38 | if len(sys.argv) < 2: 39 | print("No arguments given.") 40 | pass 41 | 42 | command = sys.argv.pop(1) 43 | 44 | if command == "download_knowledge_base": 45 | download_knowledge_base() 46 | else: 47 | raise ValueError("Unrecognized command given. If you are trying to install the knowledge base, run " 48 | "'python -m spacy_entity_linker \"download_knowledge_base\"'.") 49 | -------------------------------------------------------------------------------- /tests/test_EntityCollection.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | from spacy_entity_linker.EntityCollection import EntityCollection 4 | 5 | 6 | class TestEntityCollection(unittest.TestCase): 7 | 8 | def __init__(self, arg, *args, **kwargs): 9 | super(TestEntityCollection, self).__init__(arg, *args, **kwargs) 10 | self.nlp = spacy.load('en_core_web_sm') 11 | 12 | def setUp(self): 13 | self.nlp.add_pipe("entityLinker", last=True) 14 | self.doc = self.nlp( 15 | "Elon Musk was born in South Africa. 
Bill Gates and Steve Jobs come from the United States") 16 | 17 | def tearDown(self): 18 | self.nlp.remove_pipe("entityLinker") 19 | 20 | def test_categories(self): 21 | doc = self.doc 22 | 23 | res = doc._.linkedEntities.get_distinct_categories() 24 | print(res) 25 | assert res != None 26 | assert len(res) > 0 27 | 28 | res = doc._.linkedEntities.grouped_by_super_entities() 29 | print(res) 30 | assert res != None 31 | assert len(res) > 0 32 | 33 | def test_printing(self): 34 | doc = self.doc 35 | 36 | # pretty print 37 | doc._.linkedEntities.pretty_print() 38 | 39 | # repr 40 | print(doc._.linkedEntities) 41 | 42 | def test_super_entities(self): 43 | doc = self.doc 44 | 45 | doc._.linkedEntities.print_super_entities() 46 | 47 | def test_iterable_indexable(self): 48 | doc = self.doc 49 | 50 | ents = list(doc._.linkedEntities) 51 | assert len(ents) > 0 52 | 53 | ent = doc._.linkedEntities[0] 54 | assert ent != None 55 | 56 | length = len(doc._.linkedEntities) 57 | assert length > 0 -------------------------------------------------------------------------------- /tests/test_EntityElement.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | 4 | 5 | class TestEntityElement(unittest.TestCase): 6 | 7 | def __init__(self, arg, *args, **kwargs): 8 | super(TestEntityElement, self).__init__(arg, *args, **kwargs) 9 | self.nlp = spacy.load('en_core_web_sm') 10 | 11 | def setUp(self): 12 | self.nlp.add_pipe("entityLinker", last=True) 13 | self.doc = self.nlp( 14 | "Elon Musk was born in South Africa. Bill Gates and Steve Jobs come from the United States. The US are located in North America. 
A ship is made of wood.") 15 | 16 | def tearDown(self): 17 | self.nlp.remove_pipe("entityLinker") 18 | 19 | def test_get_in_degree(self): 20 | doc = self.doc 21 | 22 | all_linked_entities = doc._.linkedEntities 23 | in_degree = all_linked_entities[0].get_in_degree() 24 | assert in_degree > 0 25 | 26 | def test_get_original_alias(self): 27 | doc = self.doc 28 | 29 | all_linked_entities = doc._.linkedEntities 30 | original_alias = all_linked_entities[0].get_original_alias() 31 | assert original_alias == "Elon Musk" 32 | 33 | def test_is_singleton(self): 34 | doc = self.doc 35 | 36 | all_linked_entities = doc._.linkedEntities 37 | is_singleton = all_linked_entities[0].is_singleton() 38 | assert is_singleton == False 39 | is_singleton = all_linked_entities[-1].is_singleton() 40 | assert is_singleton == True 41 | 42 | def test_get_span(self): 43 | doc = self.doc 44 | 45 | all_linked_entities = doc._.linkedEntities 46 | span = all_linked_entities[0].get_span() 47 | real_span = doc[0:2] 48 | assert span.text == real_span.text 49 | assert span.start == real_span.start 50 | assert span.end == real_span.end 51 | 52 | def test_get_label(self): 53 | doc = self.doc 54 | 55 | all_linked_entities = doc._.linkedEntities 56 | label = all_linked_entities[0].get_label() 57 | assert label == "Elon Musk" 58 | 59 | def test_get_id(self): 60 | doc = self.doc 61 | 62 | all_linked_entities = doc._.linkedEntities 63 | id = all_linked_entities[0].get_id() 64 | assert id > 0 65 | 66 | def test_get_prior(self): 67 | doc = self.doc 68 | 69 | all_linked_entities = doc._.linkedEntities 70 | prior = all_linked_entities[0].get_prior() 71 | assert prior > 0 72 | 73 | def test_get_chain(self): 74 | doc = self.doc 75 | 76 | all_linked_entities = doc._.linkedEntities 77 | chain = all_linked_entities[0].get_chain() 78 | assert chain != None 79 | assert len(chain) > 0 80 | 81 | def test_get_categories(self): 82 | doc = self.doc 83 | 84 | all_linked_entities = doc._.linkedEntities 85 | categories = 
all_linked_entities[0].get_categories() 86 | assert categories != None 87 | assert len(categories) > 0 88 | 89 | def test_get_sub_entities(self): 90 | doc = self.doc 91 | 92 | all_linked_entities = doc._.linkedEntities 93 | # [-1] --> wood 94 | sub_entities = all_linked_entities[-1].get_sub_entities() 95 | assert sub_entities != None 96 | assert len(sub_entities) > 0 97 | 98 | def test_get_super_entities(self): 99 | doc = self.doc 100 | 101 | all_linked_entities = doc._.linkedEntities 102 | super_entities = all_linked_entities[0].get_super_entities() 103 | assert super_entities != None 104 | assert len(super_entities) > 0 105 | 106 | def test_get_subclass_hierarchy(self): 107 | doc = self.doc 108 | 109 | all_linked_entities = doc._.linkedEntities 110 | # [5] --> US 111 | hierarchy = all_linked_entities[5].get_subclass_hierarchy() 112 | assert hierarchy != None 113 | assert len(hierarchy) > 0 114 | assert 'country' in hierarchy 115 | 116 | def test_get_instance_of_hierarchy(self): 117 | doc = self.doc 118 | 119 | all_linked_entities = doc._.linkedEntities 120 | # [5] --> US 121 | hierarchy = all_linked_entities[5].get_instance_of_hierarchy() 122 | assert hierarchy != None 123 | assert len(hierarchy) > 0 124 | assert 'country' in hierarchy 125 | 126 | def test_get_chain_ids(self): 127 | doc = self.doc 128 | 129 | all_linked_entities = doc._.linkedEntities 130 | chain_ids = all_linked_entities[0].get_chain_ids() 131 | assert chain_ids != None 132 | assert len(chain_ids) > 0 133 | 134 | def test_get_description(self): 135 | doc = self.doc 136 | 137 | all_linked_entities = doc._.linkedEntities 138 | description = all_linked_entities[0].get_description() 139 | assert description != None 140 | assert len(description) > 0 141 | 142 | def test_is_intersecting(self): 143 | doc = self.doc 144 | 145 | all_linked_entities = doc._.linkedEntities 146 | assert not all_linked_entities[0].is_intersecting(all_linked_entities[1]) 147 | # United States and US 148 | assert 
all_linked_entities[4].is_intersecting(all_linked_entities[5]) 149 | 150 | def test_serialize(self): 151 | doc = self.doc 152 | 153 | all_linked_entities = doc._.linkedEntities 154 | serialized = all_linked_entities[0].serialize() 155 | assert serialized != None 156 | assert len(serialized) > 0 157 | assert 'id' in serialized 158 | assert 'label' in serialized 159 | assert 'span' in serialized 160 | 161 | def test_pretty_print(self): 162 | doc = self.doc 163 | 164 | all_linked_entities = doc._.linkedEntities 165 | all_linked_entities[0].pretty_print() 166 | 167 | def test_get_url(self): 168 | doc = self.doc 169 | 170 | all_linked_entities = doc._.linkedEntities 171 | url = all_linked_entities[0].get_url() 172 | assert url != None 173 | assert len(url) > 0 174 | assert 'wikidata.org/wiki/Q' in url 175 | 176 | def test___repr__(self): 177 | doc = self.doc 178 | 179 | all_linked_entities = doc._.linkedEntities 180 | repr = all_linked_entities[0].__repr__() 181 | assert repr != None 182 | assert len(repr) > 0 183 | 184 | def test___eq__(self): 185 | doc = self.doc 186 | 187 | all_linked_entities = doc._.linkedEntities 188 | assert not all_linked_entities[0] == all_linked_entities[1] 189 | assert all_linked_entities[4] == all_linked_entities[5] 190 | -------------------------------------------------------------------------------- /tests/test_EntityLinker.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | from spacy_entity_linker.EntityLinker import EntityLinker 4 | 5 | 6 | class TestEntityLinker(unittest.TestCase): 7 | 8 | def __init__(self, arg, *args, **kwargs): 9 | super(TestEntityLinker, self).__init__(arg, *args, **kwargs) 10 | self.nlp = spacy.load('en_core_web_sm') 11 | 12 | def test_initialization(self): 13 | 14 | self.nlp.add_pipe("entityLinker", last=True) 15 | 16 | doc = self.nlp( 17 | "Elon Musk was born in South Africa. 
Bill Gates and Steve Jobs come from in the United States") 18 | 19 | doc._.linkedEntities.pretty_print() 20 | doc._.linkedEntities.print_super_entities() 21 | for sent in doc.sents: 22 | sent._.linkedEntities.pretty_print() 23 | 24 | self.nlp.remove_pipe("entityLinker") 25 | 26 | def test_empty_root(self): 27 | # test empty lists of roots (#9) 28 | self.nlp.add_pipe("entityLinker", last=True) 29 | 30 | doc = self.nlp( 31 | 'I was right."\n\n "To that extent."\n\n "But that was all."\n\n "No, no, m') 32 | for sent in doc.sents: 33 | sent._.linkedEntities.pretty_print() 34 | # empty document 35 | doc = self.nlp('\n\n') 36 | for sent in doc.sents: 37 | sent._.linkedEntities.pretty_print() 38 | 39 | self.nlp.remove_pipe("entityLinker") 40 | -------------------------------------------------------------------------------- /tests/test_TermCandidateExtractor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | import spacy_entity_linker.TermCandidateExtractor 4 | 5 | 6 | class TestCandidateExtractor(unittest.TestCase): 7 | 8 | def __init__(self, arg, *args, **kwargs): 9 | super(TestCandidateExtractor, self).__init__(arg, *args, **kwargs) 10 | -------------------------------------------------------------------------------- /tests/test_multiprocessing.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | from multiprocessing.pool import ThreadPool 4 | 5 | 6 | class TestMultiprocessing(unittest.TestCase): 7 | 8 | def __init__(self, arg, *args, **kwargs): 9 | super(TestMultiprocessing, self).__init__(arg, *args, **kwargs) 10 | self.nlp = spacy.load('en_core_web_sm') 11 | 12 | def test_is_pipe_multiprocessing_safe(self): 13 | self.nlp.add_pipe("entityLinker", last=True) 14 | 15 | ents = [ 16 | 'Apple', 17 | 'Microsoft', 18 | 'Google', 19 | 'Amazon', 20 | 'Facebook', 21 | 'IBM', 22 | 'Twitter', 23 | 'Tesla', 24 | 'SpaceX', 25 | 
'Alphabet', 26 | ] 27 | text = "{} is looking at buying U.K. startup for $1 billion" 28 | 29 | texts = [text.format(ent) for ent in ents] 30 | docs = self.nlp.pipe(texts, n_process=2) 31 | for doc in docs: 32 | print(doc) 33 | for ent in doc.ents: 34 | print(ent.text, ent.label_, ent._.linkedEntities) 35 | 36 | self.nlp.remove_pipe("entityLinker") 37 | -------------------------------------------------------------------------------- /tests/test_multithreading.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | from multiprocessing.pool import ThreadPool 4 | 5 | 6 | class TestMultiThreading(unittest.TestCase): 7 | 8 | def __init__(self, arg, *args, **kwargs): 9 | super(TestMultiThreading, self).__init__(arg, *args, **kwargs) 10 | self.nlp = spacy.load('en_core_web_sm') 11 | 12 | def test_is_multithread_safe(self): 13 | self.nlp.add_pipe("entityLinker", last=True) 14 | 15 | ents = [ 16 | 'Apple', 17 | 'Microsoft', 18 | 'Google', 19 | 'Amazon', 20 | 'Facebook', 21 | 'IBM', 22 | 'Twitter', 23 | 'Tesla', 24 | 'SpaceX', 25 | 'Alphabet', 26 | ] 27 | text = "{} is looking at buying U.K. 
startup for $1 billion" 28 | 29 | def thread_func(i): 30 | doc = self.nlp(text.format(ents[i])) 31 | print(doc) 32 | 33 | for ent in doc.ents: 34 | print(ent.text, ent.label_, ent._.linkedEntities) 35 | return i 36 | 37 | with ThreadPool(10) as pool: 38 | for res in pool.imap_unordered(thread_func, range(10)): 39 | pass 40 | 41 | self.nlp.remove_pipe("entityLinker") 42 | -------------------------------------------------------------------------------- /tests/test_pipe.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | from multiprocessing.pool import ThreadPool 4 | 5 | 6 | class TestPipe(unittest.TestCase): 7 | 8 | def __init__(self, arg, *args, **kwargs): 9 | super(TestPipe, self).__init__(arg, *args, **kwargs) 10 | self.nlp = spacy.load('en_core_web_sm') 11 | 12 | def test_serialize(self): 13 | self.nlp.add_pipe("entityLinker", last=True) 14 | 15 | ents = [ 16 | 'Apple', 17 | 'Microsoft', 18 | 'Google', 19 | 'Amazon', 20 | 'Facebook', 21 | 'IBM', 22 | 'Twitter', 23 | 'Tesla', 24 | 'SpaceX', 25 | 'Alphabet', 26 | ] 27 | text = "{} is looking at buying U.K. 
startup for $1 billion" 28 | 29 | texts = [text.format(ent) for ent in ents] 30 | docs = self.nlp.pipe(texts, n_process=2) 31 | for doc in docs: 32 | print(doc) 33 | for ent in doc.ents: 34 | print(ent.text, ent.label_, ent._.linkedEntities) 35 | 36 | 37 | self.nlp.remove_pipe("entityLinker") 38 | -------------------------------------------------------------------------------- /tests/test_serialize.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spacy 3 | from multiprocessing.pool import ThreadPool 4 | 5 | 6 | class TestSerialize(unittest.TestCase): 7 | 8 | def __init__(self, arg, *args, **kwargs): 9 | super(TestSerialize, self).__init__(arg, *args, **kwargs) 10 | self.nlp = spacy.load('en_core_web_sm') 11 | 12 | def test_serialize(self): 13 | self.nlp.add_pipe("entityLinker", last=True) 14 | 15 | text = "Apple is looking at buying U.K. startup for $1 billion" 16 | doc = self.nlp(text) 17 | serialised = doc.to_bytes() 18 | 19 | doc2 = spacy.tokens.Doc(doc.vocab).from_bytes(serialised) 20 | for ent, ent2 in zip(doc.ents, doc2.ents): 21 | assert ent.text == ent2.text 22 | assert ent.label_ == ent2.label_ 23 | linked = ent._.linkedEntities 24 | linked2 = ent2._.linkedEntities 25 | if linked: 26 | assert linked.get_description() == linked2.get_description() 27 | assert linked.get_id() == linked2.get_id() 28 | assert linked.get_label() == linked2.get_label() 29 | assert linked.get_span() == linked2.get_span() 30 | assert linked.get_url() == linked2.get_url() 31 | 32 | 33 | self.nlp.remove_pipe("entityLinker") 34 | --------------------------------------------------------------------------------