├── .github
│   ├── scripts
│   │   ├── run-tests.sh
│   │   └── tests-setup.sh
│   └── workflows
│       └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── downloadKnowledgeBase.sh
├── requirements.txt
├── setup.py
├── spacy_entity_linker
│   ├── DatabaseConnection.py
│   ├── EntityCandidates.py
│   ├── EntityClassifier.py
│   ├── EntityCollection.py
│   ├── EntityElement.py
│   ├── EntityLinker.py
│   ├── SpanInfo.py
│   ├── TermCandidate.py
│   ├── TermCandidateExtractor.py
│   ├── __init__.py
│   └── __main__.py
└── tests
    ├── test_EntityCollection.py
    ├── test_EntityElement.py
    ├── test_EntityLinker.py
    ├── test_TermCandidateExtractor.py
    ├── test_multiprocessing.py
    ├── test_multithreading.py
    ├── test_pipe.py
    └── test_serialize.py
/.github/scripts/run-tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | set -e
3 |
4 | echo "Running tests..."
5 | python -m unittest discover tests
6 | echo "Tests passed!"
--------------------------------------------------------------------------------
/.github/scripts/tests-setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # This script sets up the environment for the tests to run.
3 |
4 | set -e
5 | # Install the spacy models that are used in the tests.
6 | python -m spacy download en_core_web_sm
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | on:
3 | push:
4 | branches: [ master ]
5 | pull_request:
6 | branches: [ master ]
7 | jobs:
8 | test:
9 | name: Run tests
10 | runs-on: ubuntu-latest
11 | strategy:
12 | matrix:
13 | python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ]
14 | steps:
15 | - uses: actions/checkout@v3
16 | - name: python ${{ matrix.python-version }}
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: ${{ matrix.python-version }}
20 | cache: 'pip' # caching pip dependencies
21 | - name: Install Python dependencies
22 | uses: py-actions/py-dependency-install@v4
23 | - name: Install additional test dependencies
24 | run: ./.github/scripts/tests-setup.sh
25 | shell: bash
26 | - name: Run the tests
27 | run: ./.github/scripts/run-tests.sh
28 | shell: bash
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | .idea
3 | *.log
4 | .ipynb_checkpoints
5 | data_spacy_entity_linker
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
140 | # pytype static type analyzer
141 | .pytype/
142 |
143 | # Cython debug symbols
144 | cython_debug/
145 |
146 | experimental_notebooks
147 | settings.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Emanuel Gerber
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [Tests](https://github.com/egerber/spaCy-entity-linker/actions/workflows/tests.yml)
2 | [Downloads](https://pepy.tech/project/spacy-entity-linker)
3 | [Current Release Version](https://github.com/egerber/spaCy-entity-linker/releases)
4 | [PyPI Version](https://pypi.org/project/spacy-entity-linker/)
5 | # Spacy Entity Linker
6 |
7 | ## Introduction
8 |
9 | Spacy Entity Linker is a pipeline for spaCy that performs Linked Entity Extraction with Wikidata on a given Document.
10 | The Entity Linking System operates by matching potential candidates from each sentence
11 | (subject, object, prepositional phrase, compounds, etc.) to aliases from Wikidata. The package makes it easy to find the
12 | category behind each entity (e.g. "banana" is of type "food", "Microsoft" is of type "company"). It is therefore useful
13 | for information extraction and labeling tasks.
14 |
15 | The package was written before a working Linked Entity Solution existed inside spaCy. In comparison to spaCy's linked
16 | entity system, it has the following advantages:
17 |
18 | - no extensive training required (entity-matching via database)
19 | - knowledge base can be dynamically updated without retraining
20 | - entity categories can be easily resolved
21 | - grouping entities by category
22 |
23 | It also comes along with a number of disadvantages:
24 |
25 | - it is slower than the spaCy implementation due to the use of a database for finding entities
26 | - no context sensitivity due to the implementation of the "max-prior method" for entity disambiguation (an improved
27 | method for this is in progress)
28 |
29 |
30 | ## Installation
31 |
32 | To install the package, run:
33 | ```bash
34 | pip install spacy-entity-linker
35 | ```
36 |
37 | Afterwards, the knowledge base (Wikidata) must be downloaded. This can either be done manually by calling
38 |
39 | ```bash
40 | python -m spacy_entity_linker "download_knowledge_base"
41 | ```
42 |
43 | or it happens automatically the first time you access the entity linker through spaCy.
44 | Either way, this will download and extract a ~1.3GB file that contains a preprocessed version of Wikidata.
45 |
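The download can also be triggered from Python; a minimal sketch using the package's own helper (the same function that the automatic download calls internally):

```python
# downloads and extracts the knowledge base into the
# data_spacy_entity_linker directory next to the installed package
from spacy_entity_linker.__main__ import download_knowledge_base

download_knowledge_base()
```
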
46 | ## Use
47 |
48 | ```python
49 | import spacy # version 3.5
50 |
51 | # initialize language model
52 | nlp = spacy.load("en_core_web_md")
53 |
54 | # add pipeline (declared through entry_points in setup.py)
55 | nlp.add_pipe("entityLinker", last=True)
56 |
57 | doc = nlp("I watched the Pirates of the Caribbean last silvester")
58 |
59 | # returns all entities in the whole document
60 | all_linked_entities = doc._.linkedEntities
61 | # iterates over sentences and prints linked entities
62 | for sent in doc.sents:
63 | sent._.linkedEntities.pretty_print()
64 |
65 | # OUTPUT:
66 | # https://www.wikidata.org/wiki/Q194318 Pirates of the Caribbean Series of fantasy adventure films
67 | # https://www.wikidata.org/wiki/Q12525597 Silvester the day celebrated on 31 December (Roman Catholic Church) or 2 January (Eastern Orthodox Churches)
68 |
69 | # entities are also directly accessible through spans
70 | doc[3:7]._.linkedEntities.pretty_print()
71 | # OUTPUT:
72 | # https://www.wikidata.org/wiki/Q194318 Pirates of the Caribbean Series of fantasy adventure films
73 | ```
74 |
75 | ### EntityCollection
76 |
77 | contains an array of entity elements. It can be accessed like an array but also implements the following helper
78 | functions:
79 |
80 | - pretty_print(): prints out information about all contained entities
81 | - print_super_entities(): groups and prints all entities by their super class
82 |
83 | ```python
84 | doc = nlp("Elon Musk was born in South Africa. Bill Gates and Steve Jobs come from the United States")
85 | doc._.linkedEntities.print_super_entities()
86 | # OUTPUT:
87 | # human (3) : Elon Musk,Bill Gates,Steve Jobs
88 | # country (2) : South Africa,United States of America
89 | # sovereign state (2) : South Africa,United States of America
90 | # federal state (1) : United States of America
91 | # constitutional republic (1) : United States of America
92 | # democratic republic (1) : United States of America
93 | ```
94 |
95 | ### EntityElement
96 |
97 | each linked entity is an object of type EntityElement. Each entity provides the following methods:
98 |
99 | - get_description(): returns the description from Wikidata
100 | - get_id(): returns the Wikidata ID
101 | - get_label(): returns the Wikidata label
102 | - get_span(doc): returns the span from the spaCy document that contains the linked entity. You need to pass the current `doc` as an argument in order to receive an actual `spacy.tokens.Span` object; otherwise you will receive a `SpanInfo` object emulating the behaviour of a Span
103 | - get_url(): returns the URL of the corresponding Wikidata item
104 | - pretty_print(): prints out information about the entity element
105 | - get_sub_entities(limit=10): returns an EntityCollection of all entities that derive from the current entityElement (e.g. fruit -> apple, banana, etc.)
106 | - get_super_entities(limit=10): returns an EntityCollection of all entities that the current entityElement derives from (e.g. New England Patriots -> Football Team)
109 |
110 |
111 | Usage of the `get_span` method with `SpanInfo`:
112 |
113 | ```python
114 | import spacy
115 | nlp = spacy.load('en_core_web_md')
116 | nlp.add_pipe("entityLinker", last=True)
117 | text = 'Apple is competing with Microsoft.'
118 | doc = nlp(text)
119 | sents = list(doc.sents)
120 | ent = doc._.linkedEntities[0]
121 |
122 | # using the SpanInfo class
123 | span = ent.get_span()
124 | print(span.start, span.end, span.text) # behaves like a Span
125 |
126 | # check equivalence
127 | print(span == doc[0:1]) # True
128 | print(doc[0:1] == span) # TypeError: Argument 'other' has incorrect type (expected spacy.tokens.span.Span, got SpanInfo)
129 |
130 | # now get the real span
131 | span = ent.get_span(doc) # passing the doc instance here
132 | print(span.start, span.end, span.text)
133 |
134 | print(span == doc[0:1]) # True
135 | print(doc[0:1] == span) # True
136 | ```
137 |
138 | ## Example
139 |
140 | In the following example we will use SpacyEntityLinker to find the mentioned football team in our text and explore
141 | other football teams of the same type:
142 |
143 | ```python
144 |
145 | doc = nlp("I follow the New England Patriots")
146 |
147 | patriots_entity = doc._.linkedEntities[0]
148 | patriots_entity.pretty_print()
149 | # OUTPUT:
150 | # https://www.wikidata.org/wiki/Q193390
151 | # New England Patriots
152 | # National Football League franchise in Foxborough, Massachusetts
153 |
154 | football_team_entity = patriots_entity.get_super_entities()[0]
155 | football_team_entity.pretty_print()
156 | # OUTPUT:
157 | # https://www.wikidata.org/wiki/Q17156793
158 | # American football team
159 | # organization, in which a group of players are organized to compete as a team in American football
160 |
161 |
162 | for child in football_team_entity.get_sub_entities(limit=32):
163 | print(child)
164 | # OUTPUT:
165 | # New Orleans Saints
166 | # New York Giants
167 | # Pittsburgh Steelers
168 | # New England Patriots
169 | # Indianapolis Colts
170 | # Miami Seahawks
171 | # Dallas Cowboys
172 | # Chicago Bears
173 | # Washington Redskins
174 | # Green Bay Packers
175 | # ...
176 | ```
177 |
178 | ### Entity Linking Policy
179 |
180 | Currently, the only method for choosing an entity given different possible matches (e.g. Paris - city vs Paris -
181 | first name) is max-prior. This method achieves around 70% accuracy in predicting the correct entities behind link
182 | descriptions on Wikipedia.
183 |
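To make this concrete, here is a minimal sketch of the max-prior rule (not the package's exact code; the actual `EntityClassifier` first filters candidates by span length and casing similarity before applying it):

```python
# Max-prior in a nutshell: each candidate carries a prior (the view count of
# its Wikipedia page); the candidate with the highest prior wins, with no
# regard for the surrounding sentence context.
def select_max_prior(candidates):
    return max(candidates, key=lambda entity: entity.get_prior())
```
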
184 | ## Note
185 |
186 | The Entity Linker is still experimental in its current state and should not be used in production.
187 |
188 | ## Performance
189 |
190 | The current implementation supports only SQLite. This is advantageous for development because it does not require
191 | any special setup or configuration. However, for more performance-critical use cases, a different database with
192 | in-memory access (e.g. Redis) should be used. This may be implemented in the future.
193 |
194 | ## Data
195 | The knowledge base was derived from this dataset: https://www.kaggle.com/kenshoresearch/kensho-derived-wikimedia-data
196 |
197 | It was cleaned and post-processed, including filtering out entities of "overrepresented" categories such as:
198 | * village in China
199 | * train stations
200 | * stars in the Galaxy
201 | * etc.
202 |
203 | The purpose behind the knowledge base cleaning was to reduce the knowledge base size, while keeping the most useful entities for general purpose applications.
204 |
205 | Currently, the only way to change the knowledge base is a bit hacky and requires replacing or modifying the underlying SQLite database. You will find it under site-packages/data_spacy_entity_linker/wikidb_filtered.db. The database contains 3 tables (an example query is shown after the overview below):
206 | * aliases
207 | * en_alias (english alias)
208 | * en_alias_lowercase (english alias lowercased)
209 | * joined
210 | * en_label (label of the wikidata item)
211 | * views (number of views of the corresponding wikipedia page (in a given period of time))
212 | * inlinks (number of inlinks to the corresponding wikipedia page)
213 | * item_id (wikidata id)
214 | * description (description of the wikidata item)
215 | * statements
216 | * source_item_id (references item_id)
217 | * target_item_id (references item_id)
218 | * edge_property_id
219 | * 279=subclass of (https://www.wikidata.org/wiki/Property:P279)
220 | * 31=instance of (https://www.wikidata.org/wiki/Property:P31)
221 | * 361=part of (https://www.wikidata.org/wiki/Property:P361)
222 |
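As an example, resolving an alias by hand looks roughly like this (a sketch that assumes the default download location; the query mirrors the one used in `DatabaseConnection.py`):

```python
import sqlite3

conn = sqlite3.connect("data_spacy_entity_linker/wikidb_filtered.db")

# look up all Wikidata items whose lowercased alias matches "apple"
rows = conn.execute(
    """SELECT j.item_id, j.en_label, j.en_description, j.views, j.inlinks
       FROM aliases AS a
       LEFT JOIN joined AS j ON a.item_id = j.item_id
       WHERE a.en_alias_lowercase = ? AND j.item_id NOT NULL""",
    ["apple"],
).fetchall()

for item_id, label, description, views, inlinks in rows:
    print(item_id, label, description, views, inlinks)
```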
223 |
224 | ## Versions:
225 |
226 | - spacy_entity_linker>=0.0 (requires spacy>=2.2,<3.0)
227 | - spacy_entity_linker>=1.0 (requires spacy>=3.0)
228 |
229 | ## TODO
230 |
231 | - [ ] implement Entity Classifier based on sentence embeddings for improved accuracy
232 | - [ ] implement get_picture_urls() on EntityElement
233 | - [ ] retrieve statements for each EntityElement (inlinks + outlinks)
234 |
--------------------------------------------------------------------------------
/downloadKnowledgeBase.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget "https://huggingface.co/MartinoMensio/spaCy-entity-linker/resolve/main/knowledge_base.tar.gz" -O /tmp/knowledge_base.tar.gz
4 | # ensure the target directory exists before extracting
5 | mkdir -p ./data_spacy_entity_linker
6 | tar -xzf /tmp/knowledge_base.tar.gz --directory ./data_spacy_entity_linker
7 | rm /tmp/knowledge_base.tar.gz
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -e .
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | import os
6 |
7 | try:
8 | from setuptools import setup
9 | except ImportError:
10 | from distutils.core import setup
11 |
12 |
13 | def open_file(fname):
14 | return open(os.path.join(os.path.dirname(__file__), fname))
15 |
16 |
17 | with open("README.md", "r") as fh:
18 | long_description = fh.read()
19 |
20 | setup(
21 | name='spacy-entity-linker',
22 | version='1.0.3',
23 | author='Emanuel Gerber',
24 | author_email='emanuel.j.gerber@gmail.com',
25 | packages=['spacy_entity_linker'],
26 | url='https://github.com/egerber/spacy-entity-linker',
27 | license="MIT",
28 | classifiers=["Environment :: Console",
29 | "Intended Audience :: Developers",
30 | "Intended Audience :: Science/Research",
31 | "License :: OSI Approved :: MIT License",
32 | "Programming Language :: Cython",
33 | "Programming Language :: Python",
34 | "Programming Language :: Python :: 3.6"
35 | ],
36 | description='Linked Entity Pipeline for spaCy',
37 | long_description=long_description,
38 | long_description_content_type="text/markdown",
39 | zip_safe=True,
40 | install_requires=[
41 | 'spacy>=3.0.0',
42 | 'numpy>=1.0.0',
43 | 'tqdm'
44 | ],
45 | entry_points={
46 | 'spacy_factories': 'entityLinker = spacy_entity_linker.EntityLinker:EntityLinker'
47 | }
48 | )
49 |
--------------------------------------------------------------------------------
/spacy_entity_linker/DatabaseConnection.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | import os
3 |
4 | from .__main__ import download_knowledge_base
5 |
6 | MAX_DEPTH_CHAIN = 10
7 | P_INSTANCE_OF = 31
8 | P_SUBCLASS = 279
9 |
10 | MAX_ITEMS_CACHE = 100000
11 |
12 | conn = None
13 | entity_cache = {}
14 | chain_cache = {}
15 |
16 | DB_DEFAULT_PATH = os.path.abspath(os.path.join(__file__, "../../data_spacy_entity_linker/wikidb_filtered.db"))
17 |
18 | wikidata_instance = None
19 |
20 |
21 | def get_wikidata_instance():
22 | global wikidata_instance
23 |
24 | if wikidata_instance is None:
25 | wikidata_instance = WikidataQueryController()
26 |
27 | return wikidata_instance
28 |
29 |
30 | class WikidataQueryController:
31 |
32 | def __init__(self):
33 | self.conn = None
34 |
35 | self.cache = {
36 | "entity": {},
37 | "chain": {},
38 | "name": {}
39 | }
40 |
41 | self.init_database_connection()
42 |
43 | def _get_cached_value(self, cache_type, key):
44 | return self.cache[cache_type][key]
45 |
46 | def _is_cached(self, cache_type, key):
47 | return key in self.cache[cache_type]
48 |
49 | def _add_to_cache(self, cache_type, key, value):
50 | if len(self.cache[cache_type]) < MAX_ITEMS_CACHE:
51 | self.cache[cache_type][key] = value
52 |
53 | def init_database_connection(self, path=DB_DEFAULT_PATH):
54 | # check if the database exists
55 | if not os.path.exists(path):
56 | # Automatically download the knowledge base if it isn't already
57 | download_knowledge_base()
58 | self.conn = sqlite3.connect(path, check_same_thread=False)
59 |
60 | def clear_cache(self):
61 | self.cache["entity"].clear()
62 | self.cache["chain"].clear()
63 | self.cache["name"].clear()
64 |
65 | def get_entities_from_alias(self, alias):
66 | c = self.conn.cursor()
67 | if self._is_cached("entity", alias):
68 | return self._get_cached_value("entity", alias).copy()
69 |
70 | query_alias = """SELECT j.item_id,j.en_label, j.en_description,j.views,j.inlinks,a.en_alias
71 | FROM aliases as a LEFT JOIN joined as j ON a.item_id = j.item_id
72 | WHERE a.en_alias_lowercase = ? AND j.item_id NOT NULL"""
73 |
74 | c.execute(query_alias, [alias.lower()])
75 | fetched_rows = c.fetchall()
76 |
77 | self._add_to_cache("entity", alias, fetched_rows)
78 | return fetched_rows
79 |
80 | def get_instances_of(self, item_id, properties=[P_INSTANCE_OF, P_SUBCLASS], count=1000):
81 | query = "SELECT source_item_id from statements where target_item_id={} and edge_property_id IN ({}) LIMIT {}".format(
82 | item_id, ",".join([str(prop) for prop in properties]), count)
83 |
84 | c = self.conn.cursor()
85 | c.execute(query)
86 |
87 | res = c.fetchall()
88 |
89 | return [e[0] for e in res]
90 |
91 | def get_entity_name(self, item_id):
92 | if self._is_cached("name", item_id):
93 | return self._get_cached_value("name", item_id)
94 |
95 | c = self.conn.cursor()
96 | query = "SELECT en_label from joined WHERE item_id=?"
97 | c.execute(query, [item_id])
98 | res = c.fetchone()
99 |
100 | if res and len(res):
101 | if res[0] is None:
102 | self._add_to_cache("name", item_id, 'no label')
103 | else:
104 | self._add_to_cache("name", item_id, res[0])
105 | else:
106 | self._add_to_cache("name", item_id, '')
107 |
108 | return self._get_cached_value("name", item_id)
109 |
110 | def get_entity(self, item_id):
111 | c = self.conn.cursor()
112 | query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \
113 | "WHERE j.item_id=={}".format(item_id)
114 |
115 | res = c.execute(query)
116 |
117 | return res.fetchone()
118 |
119 | def get_children(self, item_id, limit=100):
120 | c = self.conn.cursor()
121 | query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \
122 | "JOIN statements as s on j.item_id=s.source_item_id " \
123 | "WHERE s.target_item_id={} and s.edge_property_id IN (279,31) LIMIT {}".format(item_id, limit)
124 |
125 | res = c.execute(query)
126 |
127 | return res.fetchall()
128 |
129 | def get_parents(self, item_id, limit=100):
130 | c = self.conn.cursor()
131 | query = "SELECT j.item_id,j.en_label,j.en_description,j.views,j.inlinks from joined as j " \
132 | "JOIN statements as s on j.item_id=s.target_item_id " \
133 | "WHERE s.source_item_id={} and s.edge_property_id IN (279,31) LIMIT {}".format(item_id, limit)
134 |
135 | res = c.execute(query)
136 |
137 | return res.fetchall()
138 |
139 | def get_categories(self, item_id, max_depth=10):
140 | chain = []
141 | edges = []
142 | self._append_chain_elements(item_id, 0, chain, edges, max_depth, [P_INSTANCE_OF, P_SUBCLASS])
143 | return [el[0] for el in chain]
144 |
145 | def get_chain(self, item_id, max_depth=10, property=P_INSTANCE_OF):
146 | chain = []
147 | edges = []
148 | self._append_chain_elements(item_id, 0, chain, edges, max_depth, property)
149 | return chain
150 |
151 | def get_recursive_edges(self, item_id):
152 | chain = []
153 | edges = []
154 | self._append_chain_elements(item_id, 0, chain, edges)
155 | return edges
156 |
157 | def _append_chain_elements(self, item_id, level=0, chain=None, edges=None, max_depth=10, prop=P_INSTANCE_OF):
158 | if chain is None:
159 | chain = []
160 | if edges is None:
161 | edges = []
162 | properties = prop
163 | if type(prop) != list:
164 | properties = [prop]
165 |
166 | if self._is_cached("chain", (item_id, max_depth)):
167 | chain += self._get_cached_value("chain", (item_id, max_depth)).copy()
168 | return
169 |
170 | # prevent infinite recursion
171 | if level >= max_depth:
172 | return
173 |
174 | c = self.conn.cursor()
175 |
176 | query = "SELECT target_item_id,edge_property_id from statements where source_item_id={} and edge_property_id IN ({})".format(
177 | item_id, ",".join([str(prop) for prop in properties]))
178 |
179 | # set value for current item in order to prevent infinite recursion
180 | self._add_to_cache("chain", (item_id, max_depth), [])
181 |
182 | for target_item in c.execute(query):
183 |
184 | chain_ids = [el[0] for el in chain]
185 |
186 | if not (target_item[0] in chain_ids):
187 | chain += [(target_item[0], level + 1)]
188 | edges.append((item_id, target_item[0], target_item[1]))
189 | self._append_chain_elements(target_item[0],
190 | level=level + 1,
191 | chain=chain,
192 | edges=edges,
193 | max_depth=max_depth,
194 | prop=prop)
195 |
196 | self._add_to_cache("chain", (item_id, max_depth), chain)
197 |
198 |
199 | if __name__ == '__main__':
200 | queryInstance = WikidataQueryController()
201 |
202 | queryInstance.init_database_connection()
203 | print(queryInstance.get_categories(13191, max_depth=1))
204 | print(queryInstance.get_categories(13191, max_depth=1))
205 |
--------------------------------------------------------------------------------
/spacy_entity_linker/EntityCandidates.py:
--------------------------------------------------------------------------------
1 | MAX_ITEMS_PREVIEW=20
2 |
3 | class EntityCandidates:
4 |
5 | def __init__(self, entity_elements):
6 | self.entity_elements = entity_elements
7 |
8 | def __iter__(self):
9 | for entity in self.entity_elements:
10 | yield entity
11 |
12 | def __len__(self):
13 | return len(self.entity_elements)
14 |
15 | def __getitem__(self, item):
16 | return self.entity_elements[item]
17 |
18 | def pretty_print(self):
19 | for entity in self.entity_elements:
20 | entity.pretty_print()
21 |
22 | def __repr__(self) -> str:
23 | preview_str=""
24 | for index,entity_element in enumerate(self):
25 | if index>MAX_ITEMS_PREVIEW:
26 | break
27 | preview_str+="{}\n".format(entity_element.get_preview_string())
28 |
29 | return preview_str
30 |
31 | def __str__(self):
32 | return str(["entity {}: {} (<{}>)".format(i, entity.get_label(), entity.get_description()) for i, entity in
33 | enumerate(self.entity_elements)])
34 |
--------------------------------------------------------------------------------
/spacy_entity_linker/EntityClassifier.py:
--------------------------------------------------------------------------------
1 | from itertools import groupby
2 | import numpy as np
3 |
4 |
5 | class EntityClassifier:
6 | def __init__(self):
7 | pass
8 |
9 | def _get_grouped_by_length(self, entities):
10 | sorted_by_len = sorted(entities, key=lambda entity: len(entity.get_span()), reverse=True)
11 |
12 | entities_by_length = {}
13 | for length, group in groupby(sorted_by_len, lambda entity: len(entity.get_span())):
14 | entities = list(group)
15 | entities_by_length[length] = entities
16 |
17 | return entities_by_length
18 |
19 | def _filter_max_length(self, entities):
20 | entities_by_length = self._get_grouped_by_length(entities)
21 | max_length = max(list(entities_by_length.keys()))
22 |
23 | return entities_by_length[max_length]
24 |
25 | def _select_max_prior(self, entities):
26 | priors = [entity.get_prior() for entity in entities]
27 | return entities[np.argmax(priors)]
28 |
29 | def _get_casing_difference(self, word1, original):
30 | difference = 0
31 | for w1, w2 in zip(word1, original):
32 | if w1 != w2:
33 | difference += 1
34 |
35 | return difference
36 |
37 | def _filter_most_similar(self, entities):
38 | similarities = np.array(
39 | [self._get_casing_difference(entity.get_span().text, entity.get_original_alias()) for entity in entities])
40 |
41 | min_indices = np.where(similarities == similarities.min())[0].tolist()
42 |
43 | return [entities[i] for i in min_indices]
44 |
45 | def __call__(self, entities):
46 | filtered_by_length = self._filter_max_length(entities)
47 | filtered_by_casing = self._filter_most_similar(filtered_by_length)
48 |
49 | return self._select_max_prior(filtered_by_casing)
50 |
--------------------------------------------------------------------------------
/spacy_entity_linker/EntityCollection.py:
--------------------------------------------------------------------------------
1 | import srsly
2 | from collections import Counter, defaultdict
3 |
4 | from .DatabaseConnection import get_wikidata_instance
5 |
6 | MAX_ITEMS_PREVIEW=20
7 |
8 |
9 | class EntityCollection:
10 |
11 | def __init__(self, entities=[]):
12 | self.entities = entities
13 |
14 | def __iter__(self):
15 | for entity in self.entities:
16 | yield entity
17 |
18 | def __getitem__(self, item):
19 | return self.entities[item]
20 |
21 | def __len__(self):
22 | return len(self.entities)
23 |
24 | def append(self, entity):
25 | self.entities.append(entity)
26 |
27 | def get_categories(self, max_depth=1):
28 | categories = []
29 | for entity in self.entities:
30 | categories += entity.get_categories(max_depth)
31 |
32 | return categories
33 |
34 | def print_super_entities(self, max_depth=1, limit=10):
35 | wikidataInstance = get_wikidata_instance()
36 |
37 | all_categories = []
38 | category_to_entites = defaultdict(list)
39 |
40 | for e in self.entities:
41 | for category in e.get_categories(max_depth):
42 | category_to_entites[category].append(e)
43 | all_categories.append(category)
44 |
45 | counter = Counter()
46 | counter.update(all_categories)
47 |
48 | for category, frequency in counter.most_common(limit):
49 | print("{} ({}) : {}".format(wikidataInstance.get_entity_name(category), frequency,
50 | ','.join([str(e) for e in category_to_entites[category]])))
51 |
52 | def __repr__(self) -> str:
53 | preview_str="<EntityCollection ({} entities):".format(len(self))
54 | for index,entity_element in enumerate(self):
55 | if index>MAX_ITEMS_PREVIEW:
56 | preview_str+="\n...{} more".format(len(self)-MAX_ITEMS_PREVIEW)
57 | break
58 | preview_str+="\n-{}".format(entity_element.get_preview_string())
59 |
60 | preview_str+=">"
61 | return preview_str
62 |
63 | def pretty_print(self):
64 | for entity in self.entities:
65 | entity.pretty_print()
66 |
67 | def grouped_by_super_entities(self, max_depth=1):
68 | counter = Counter()
69 | counter.update(self.get_categories(max_depth))
70 |
71 | return counter
72 |
73 | def get_distinct_categories(self, max_depth=1):
74 | return list(set(self.get_categories(max_depth)))
75 |
76 |
77 | @srsly.msgpack_encoders("EntityCollection")
78 | def serialize_obj(obj, chain=None):
79 | if isinstance(obj, EntityCollection):
80 | return {
81 | "entities": obj.entities,
82 | }
83 | # otherwise return the original object so another serializer can handle it
84 | return obj if chain is None else chain(obj)
85 |
86 |
87 | @srsly.msgpack_decoders("EntityCollection")
88 | def deserialize_obj(obj, chain=None):
89 | if "entities" in obj:
90 | return EntityCollection(entities=obj["entities"])
91 | # otherwise return the original object so another serializer can handle it
92 | return obj if chain is None else chain(obj)
--------------------------------------------------------------------------------
/spacy_entity_linker/EntityElement.py:
--------------------------------------------------------------------------------
1 | import spacy
2 | import srsly
3 |
4 | from .DatabaseConnection import get_wikidata_instance
5 | from .EntityCollection import EntityCollection
6 | from .SpanInfo import SpanInfo
7 |
8 | class EntityElement:
9 | def __init__(self, row, span):
10 | self.identifier = row[0]
11 | self.prior = 0
12 | self.original_alias = None
13 | self.in_degree = None
14 | self.label = None
15 | self.description = None
16 |
17 | if len(row) > 1:
18 | self.label = row[1]
19 | if len(row) > 2:
20 | self.description = row[2]
21 | if len(row) > 3 and row[3]:
22 | self.prior = row[3]
23 | if len(row) > 4 and row[4]:
24 | self.in_degree = row[4]
25 | if len(row) > 5 and row[5]:
26 | self.original_alias = row[5]
27 |
28 | self.url="https://www.wikidata.org/wiki/Q{}".format(self.get_id())
29 | if span:
30 | self.span_info = SpanInfo.from_span(span)
31 | else:
32 | # sometimes the constructor is called with None as second parameter (e.g. in get_sub_entities/get_super_entities)
33 | self.span_info = None
34 |
35 | self.chain = None
36 | self.chain_ids = None
37 |
38 | self.wikidata_instance = get_wikidata_instance()
39 |
40 | def get_in_degree(self):
41 | return self.in_degree
42 |
43 | def get_original_alias(self):
44 | return self.original_alias
45 |
46 | def is_singleton(self):
47 | return len(self.get_chain()) == 0
48 |
49 | def get_span(self, doc: spacy.tokens.Doc=None):
50 | """
51 | Returns the span of the entity in the document.
52 | :param doc: the document in which the entity is contained
53 | :return: the span of the entity in the document
54 |
55 | If the doc is not None, it returns a real spacy.tokens.Span.
56 | Otherwise it returns the instance of SpanInfo that emulates the behaviour of a spacy.tokens.Span
57 | """
58 | if doc is not None:
59 | # return a real spacy.tokens.Span
60 | return self.span_info.get_span(doc)
61 | # otherwise return the instance of SpanInfo that emulates the behaviour of a spacy.tokens.Span
62 | return self.span_info
63 |
64 | def get_label(self):
65 | return self.label
66 |
67 | def get_id(self):
68 | return self.identifier
69 |
70 | def get_prior(self):
71 | return self.prior
72 |
73 | def get_chain(self, max_depth=10):
74 | if self.chain is None:
75 | self.chain = self.wikidata_instance.get_chain(self.identifier, max_depth=max_depth, property=31)
76 | return self.chain
77 |
78 | def is_category(self):
79 | pass
80 |
81 | def is_leaf(self):
82 | pass
83 |
84 | def get_categories(self, max_depth=10):
85 | return self.wikidata_instance.get_categories(self.identifier, max_depth=max_depth)
86 |
87 | def get_sub_entities(self, limit=10):
88 | return EntityCollection(
89 | [EntityElement(row, None) for row in self.wikidata_instance.get_children(self.get_id(), limit)])
90 |
91 | def get_super_entities(self, limit=10):
92 | return EntityCollection(
93 | [EntityElement(row, None) for row in self.wikidata_instance.get_parents(self.get_id(), limit)])
94 |
95 | def get_subclass_hierarchy(self):
96 | chain = self.wikidata_instance.get_chain(self.identifier, max_depth=5, property=279)
97 | return [self.wikidata_instance.get_entity_name(el[0]) for el in chain]
98 |
99 | def get_instance_of_hierarchy(self):
100 | chain = self.wikidata_instance.get_chain(self.identifier, max_depth=5, property=31)
101 | return [self.wikidata_instance.get_entity_name(el[0]) for el in chain]
102 |
103 | def get_chain_ids(self, max_depth=10):
104 | if self.chain_ids is None:
105 | self.chain_ids = set([el[0] for el in self.get_chain(max_depth=max_depth)])
106 |
107 | return self.chain_ids
108 |
109 | def get_description(self):
110 | if self.description:
111 | return self.description
112 | else:
113 | return ""
114 |
115 | def is_intersecting(self, other_element):
116 | return len(self.get_chain_ids().intersection(other_element.get_chain_ids())) > 0
117 |
118 | def serialize(self):
119 | return {
120 | "id": self.get_id(),
121 | "label": self.get_label(),
122 | "span": self.get_span()
123 | }
124 |
125 | def pretty_print(self):
126 | print(self.__repr__())
127 |
128 | def get_url(self):
129 | return self.url
130 |
131 | def __repr__(self):
132 | return "".format(self.get_preview_string())
133 |
134 | def get_preview_string(self):
135 | return "{0:<10} {1:<25} {2:<50}".format(self.get_url(),self.get_label(),self.get_description()[:100])
136 |
137 | def pretty_string(self, description=False):
138 | if description:
139 | return "{} => {} <{}>".format(self.span_info, self.get_label(), self.get_description())
140 | else:
141 | return "{} => {}".format(self.span_info, self.get_label())
142 |
143 | # TODO: this method has never worked because the custom attribute is not registered properly
144 | # def save(self, category):
145 | # for span in self.span:
146 | # span.sent._.linked_entities.append(
147 | # {"id": self.identifier, "range": [span.start, span.end + 1], "category": category})
148 |
149 | def __str__(self):
150 | label = self.get_label()
151 | if label:
152 | return label
153 | else:
154 | return ""
155 |
156 | def __eq__(self, other):
157 | return isinstance(other, EntityElement) and other.get_id() == self.get_id()
158 |
159 |
160 | @srsly.msgpack_encoders("EntityElement")
161 | def serialize_obj(obj, chain=None):
162 | if isinstance(obj, EntityElement):
163 | result = {
164 | "identifier": obj.identifier,
165 | "label": obj.label,
166 | "description": obj.description,
167 | "prior": obj.prior,
168 | "in_degree": obj.in_degree,
169 | "original_alias": obj.original_alias,
170 | "span_info": obj.span_info,
171 | }
172 | return result
173 | # otherwise return the original object so another serializer can handle it
174 | return obj if chain is None else chain(obj)
175 |
176 |
177 | @srsly.msgpack_decoders("EntityElement")
178 | def deserialize_obj(obj, chain=None):
179 | if "identifier" in obj:
180 | row = [obj['identifier'], obj['label'], obj['description'], obj['prior'], obj['in_degree'], obj['original_alias']]
181 | span_info = obj['span_info']
182 | return EntityElement(row, span_info)
183 | # otherwise return the original object so another serializer can handle it
184 | return obj if chain is None else chain(obj)
--------------------------------------------------------------------------------
/spacy_entity_linker/EntityLinker.py:
--------------------------------------------------------------------------------
1 | from spacy.tokens import Doc, Span
2 | from spacy.language import Language
3 |
4 | from .EntityClassifier import EntityClassifier
5 | from .EntityCollection import EntityCollection
6 | from .TermCandidateExtractor import TermCandidateExtractor
7 |
8 | @Language.factory('entityLinker')
9 | class EntityLinker:
10 |
11 | def __init__(self, nlp, name):
12 | Doc.set_extension("linkedEntities", default=EntityCollection(), force=True)
13 | Span.set_extension("linkedEntities", default=None, force=True)
14 |
15 | def __call__(self, doc):
16 | tce = TermCandidateExtractor(doc)
17 | classifier = EntityClassifier()
18 |
19 | for sent in doc.sents:
20 | sent._.linkedEntities = EntityCollection([])
21 |
22 | entities = []
23 | for termCandidates in tce:
24 | entityCandidates = termCandidates.get_entity_candidates()
25 | if len(entityCandidates) > 0:
26 | entity = classifier(entityCandidates)
27 | span = doc[entity.span_info.start:entity.span_info.end]
28 | # Add the entity to the sentence-level EntityCollection
29 | span.sent._.linkedEntities.append(entity)
30 | # Also associate the token span with the entity
31 | span._.linkedEntities = entity
32 | # And finally append to the document-level collection
33 | entities.append(entity)
34 |
35 | doc._.linkedEntities = EntityCollection(entities)
36 |
37 | return doc
38 |
--------------------------------------------------------------------------------
/spacy_entity_linker/SpanInfo.py:
--------------------------------------------------------------------------------
1 | """
2 | SpanInfo class
3 | Stores the info of spacy.tokens.Span (start, end and text of a span) by making it serializable
4 | """
5 |
6 | import spacy
7 | import srsly
8 |
9 | class SpanInfo:
10 |
11 | @staticmethod
12 | def from_span(span: spacy.tokens.Span):
13 | return SpanInfo(span.start, span.end, span.text)
14 |
15 | def __init__(self, start: int, end: int, text: str):
16 | self.start = start
17 | self.end = end
18 | self.text = text
19 |
20 |
21 | def __repr__(self) -> str:
22 | return self.text
23 |
24 | def __len__(self):
25 | return self.end - self.start
26 |
27 | def __eq__(self, __o: object) -> bool:
28 | if isinstance(__o, SpanInfo) or isinstance(__o, spacy.tokens.Span):
29 | return self.start == __o.start and self.end == __o.end and self.text == __o.text
30 | return False
31 |
32 | def get_span(self, doc: spacy.tokens.Doc):
33 | """
34 | Returns the real spacy.tokens.Span of the doc from the stored info"""
35 | return doc[self.start:self.end]
36 |
37 |
38 | @srsly.msgpack_encoders("SpanInfo")
39 | def serialize_spaninfo(obj, chain=None):
40 | if isinstance(obj, SpanInfo):
41 | result = {
42 | "start": obj.start,
43 | "end": obj.end,
44 | "text": obj.text,
45 | }
46 | return result
47 | # otherwise return the original object so another serializer can handle it
48 | return obj if chain is None else chain(obj)
49 |
50 | @srsly.msgpack_decoders("SpanInfo")
51 | def deserialize_spaninfo(obj, chain=None):
52 | if "start" in obj:
53 | return SpanInfo(obj['start'], obj['end'], obj['text'])
54 | # otherwise return the original object so another serializer can handle it
55 | return obj if chain is None else chain(obj)
--------------------------------------------------------------------------------
/spacy_entity_linker/TermCandidate.py:
--------------------------------------------------------------------------------
1 | from .EntityCandidates import EntityCandidates
2 | from .EntityElement import EntityElement
3 | from .DatabaseConnection import get_wikidata_instance
4 |
5 |
6 | class TermCandidate:
7 | def __init__(self, span):
8 | self.variations = [span]
9 |
10 | def pretty_print(self):
11 | print("Term Candidates are [{}]".format(self))
12 |
13 | def append(self, span):
14 | self.variations.append(span)
15 |
16 | def has_plural(self, variation):
17 | return any([t.tag_ == "NNS" for t in variation])
18 |
19 | def get_singular(self, variation):
20 | return ' '.join([t.text if t.tag_ != "NNS" else t.lemma_ for t in variation])
21 |
22 | def __str__(self):
23 | return ', '.join([variation.text for variation in self.variations])
24 |
25 | def get_entity_candidates(self):
26 | wikidata_instance = get_wikidata_instance()
27 | entities_by_variation = {}
28 | for variation in self.variations:
29 | entities_by_variation[variation] = wikidata_instance.get_entities_from_alias(variation.text)
30 | if self.has_plural(variation):
31 | entities_by_variation[variation] += wikidata_instance.get_entities_from_alias(
32 | self.get_singular(variation))
33 |
34 | entity_elements = []
35 | for variation, entities in entities_by_variation.items():
36 | entity_elements += [EntityElement(entity, variation) for entity in entities]
37 |
38 | return EntityCandidates(entity_elements)
39 |
--------------------------------------------------------------------------------
/spacy_entity_linker/TermCandidateExtractor.py:
--------------------------------------------------------------------------------
1 | from .TermCandidate import TermCandidate
2 |
3 |
4 | class TermCandidateExtractor:
5 | def __init__(self, doc):
6 | self.doc = doc
7 |
8 | def __iter__(self):
9 | for sent in self.doc.sents:
10 | for candidate in self._get_candidates_in_sent(sent, self.doc):
11 | yield candidate
12 |
13 | def _get_candidates_in_sent(self, sent, doc):
14 | roots = list(filter(lambda token: token.dep_ == "ROOT", sent))
15 | if len(roots) < 1:
16 | return []
17 | root = roots[0]
18 |
19 | excluded_children = []
20 | candidates = []
21 |
22 | def get_candidates(node, doc):
23 |
24 | if (node.pos_ in ["PROPN", "NOUN"]) and node.pos_ not in ["PRON"]:
25 | term_candidates = TermCandidate(doc[node.i:node.i + 1])
26 |
27 | for child in node.children:
28 |
29 | start_index = min(node.i, child.i)
30 | end_index = max(node.i, child.i)
31 |
32 | if child.dep_ == "compound" or child.dep_ == "amod":
33 | subtree_tokens = list(child.subtree)
34 | if all([c.dep_ == "compound" for c in subtree_tokens]):
35 | start_index = min([c.i for c in subtree_tokens])
36 | term_candidates.append(doc[start_index:end_index + 1])
37 |
38 | if not child.dep_ == "amod":
39 | term_candidates.append(doc[start_index:start_index + 1])
40 | excluded_children.append(child)
41 |
42 | if child.dep_ == "prep" and child.text == "of":
43 | end_index = max([c.i for c in child.subtree])
44 | term_candidates.append(doc[start_index:end_index + 1])
45 |
46 | candidates.append(term_candidates)
47 |
48 | for child in node.children:
49 | if child in excluded_children:
50 | continue
51 | get_candidates(child, doc)
52 |
53 | get_candidates(root, doc)
54 |
55 | return candidates
56 |
--------------------------------------------------------------------------------
/spacy_entity_linker/__init__.py:
--------------------------------------------------------------------------------
1 | try: # Python 3.8
2 | import importlib.metadata as importlib_metadata
3 | except ImportError:
4 | import importlib_metadata # noqa: F401
5 |
6 | from .EntityLinker import EntityLinker
7 |
8 | pkg_meta = importlib_metadata.metadata(__name__.split(".")[0])
9 | __version__ = pkg_meta["version"]
10 | __all__ = ["EntityLinker"]
11 |
--------------------------------------------------------------------------------
/spacy_entity_linker/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import tarfile
3 | import urllib.request
4 | import tqdm
5 | import os
6 |
7 |
8 | class DownloadProgressBar(tqdm.tqdm):
9 | """
10 | Code taken from https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
11 | """
12 | def update_to(self, chunk_id=1, max_chunk_size=1, total_size=None):
13 | if total_size is not None:
14 | self.total = total_size
15 | self.update(chunk_id * max_chunk_size - self.n)
16 |
17 |
18 | def download_knowledge_base(
19 | file_url="https://huggingface.co/MartinoMensio/spaCy-entity-linker/resolve/main/knowledge_base.tar.gz"
20 | ):
21 | OUTPUT_TAR_FILE = os.path.abspath(
22 | os.path.dirname(__file__)) + '/../data_spacy_entity_linker/wikidb_filtered.tar.gz'
23 | OUTPUT_DB_PATH = os.path.abspath(os.path.dirname(__file__)) + '/../data_spacy_entity_linker'
24 | if not os.path.exists(OUTPUT_DB_PATH):
25 | os.makedirs(OUTPUT_DB_PATH)
26 | with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc='Downloading knowledge base') as dpb:
27 | urllib.request.urlretrieve(file_url, filename=OUTPUT_TAR_FILE, reporthook=dpb.update_to)
28 |
29 | tar = tarfile.open(OUTPUT_TAR_FILE)
30 | tar.extractall(OUTPUT_DB_PATH)
31 | tar.close()
32 |
33 | os.remove(OUTPUT_TAR_FILE)
34 |
35 |
36 | if __name__ == "__main__":
37 |
38 | if len(sys.argv) < 2:
39 | print("No arguments given.")
40 | sys.exit(1)
41 |
42 | command = sys.argv.pop(1)
43 |
44 | if command == "download_knowledge_base":
45 | download_knowledge_base()
46 | else:
47 | raise ValueError("Unrecognized command given. If you are trying to install the knowledge base, run "
48 | "'python -m spacy_entity_linker \"download_knowledge_base\"'.")
49 |
--------------------------------------------------------------------------------
/tests/test_EntityCollection.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 | from spacy_entity_linker.EntityCollection import EntityCollection
4 |
5 |
6 | class TestEntityCollection(unittest.TestCase):
7 |
8 | def __init__(self, arg, *args, **kwargs):
9 | super(TestEntityCollection, self).__init__(arg, *args, **kwargs)
10 | self.nlp = spacy.load('en_core_web_sm')
11 |
12 | def setUp(self):
13 | self.nlp.add_pipe("entityLinker", last=True)
14 | self.doc = self.nlp(
15 | "Elon Musk was born in South Africa. Bill Gates and Steve Jobs come from the United States")
16 |
17 | def tearDown(self):
18 | self.nlp.remove_pipe("entityLinker")
19 |
20 | def test_categories(self):
21 | doc = self.doc
22 |
23 | res = doc._.linkedEntities.get_distinct_categories()
24 | print(res)
25 | assert res != None
26 | assert len(res) > 0
27 |
28 | res = doc._.linkedEntities.grouped_by_super_entities()
29 | print(res)
30 | assert res != None
31 | assert len(res) > 0
32 |
33 | def test_printing(self):
34 | doc = self.doc
35 |
36 | # pretty print
37 | doc._.linkedEntities.pretty_print()
38 |
39 | # repr
40 | print(doc._.linkedEntities)
41 |
42 | def test_super_entities(self):
43 | doc = self.doc
44 |
45 | doc._.linkedEntities.print_super_entities()
46 |
47 | def test_iterable_indexable(self):
48 | doc = self.doc
49 |
50 | ents = list(doc._.linkedEntities)
51 | assert len(ents) > 0
52 |
53 | ent = doc._.linkedEntities[0]
54 | assert ent != None
55 |
56 | length = len(doc._.linkedEntities)
57 | assert length > 0
--------------------------------------------------------------------------------
/tests/test_EntityElement.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 |
4 |
5 | class TestEntityElement(unittest.TestCase):
6 |
7 | def __init__(self, arg, *args, **kwargs):
8 | super(TestEntityElement, self).__init__(arg, *args, **kwargs)
9 | self.nlp = spacy.load('en_core_web_sm')
10 |
11 | def setUp(self):
12 | self.nlp.add_pipe("entityLinker", last=True)
13 | self.doc = self.nlp(
14 | "Elon Musk was born in South Africa. Bill Gates and Steve Jobs come from the United States. The US are located in North America. A ship is made of wood.")
15 |
16 | def tearDown(self):
17 | self.nlp.remove_pipe("entityLinker")
18 |
19 | def test_get_in_degree(self):
20 | doc = self.doc
21 |
22 | all_linked_entities = doc._.linkedEntities
23 | in_degree = all_linked_entities[0].get_in_degree()
24 | assert in_degree > 0
25 |
26 | def test_get_original_alias(self):
27 | doc = self.doc
28 |
29 | all_linked_entities = doc._.linkedEntities
30 | original_alias = all_linked_entities[0].get_original_alias()
31 | assert original_alias == "Elon Musk"
32 |
33 | def test_is_singleton(self):
34 | doc = self.doc
35 |
36 | all_linked_entities = doc._.linkedEntities
37 | is_singleton = all_linked_entities[0].is_singleton()
38 | assert is_singleton == False
39 | is_singleton = all_linked_entities[-1].is_singleton()
40 | assert is_singleton == True
41 |
42 | def test_get_span(self):
43 | doc = self.doc
44 |
45 | all_linked_entities = doc._.linkedEntities
46 | span = all_linked_entities[0].get_span()
47 | real_span = doc[0:2]
48 | assert span.text == real_span.text
49 | assert span.start == real_span.start
50 | assert span.end == real_span.end
51 |
52 | def test_get_label(self):
53 | doc = self.doc
54 |
55 | all_linked_entities = doc._.linkedEntities
56 | label = all_linked_entities[0].get_label()
57 | assert label == "Elon Musk"
58 |
59 | def test_get_id(self):
60 | doc = self.doc
61 |
62 | all_linked_entities = doc._.linkedEntities
63 | id = all_linked_entities[0].get_id()
64 | assert id > 0
65 |
66 | def test_get_prior(self):
67 | doc = self.doc
68 |
69 | all_linked_entities = doc._.linkedEntities
70 | prior = all_linked_entities[0].get_prior()
71 | assert prior > 0
72 |
73 | def test_get_chain(self):
74 | doc = self.doc
75 |
76 | all_linked_entities = doc._.linkedEntities
77 | chain = all_linked_entities[0].get_chain()
78 | assert chain != None
79 | assert len(chain) > 0
80 |
81 | def test_get_categories(self):
82 | doc = self.doc
83 |
84 | all_linked_entities = doc._.linkedEntities
85 | categories = all_linked_entities[0].get_categories()
86 | assert categories != None
87 | assert len(categories) > 0
88 |
89 | def test_get_sub_entities(self):
90 | doc = self.doc
91 |
92 | all_linked_entities = doc._.linkedEntities
93 | # [-1] --> wood
94 | sub_entities = all_linked_entities[-1].get_sub_entities()
95 | assert sub_entities != None
96 | assert len(sub_entities) > 0
97 |
98 | def test_get_super_entities(self):
99 | doc = self.doc
100 |
101 | all_linked_entities = doc._.linkedEntities
102 | super_entities = all_linked_entities[0].get_super_entities()
103 | assert super_entities != None
104 | assert len(super_entities) > 0
105 |
106 | def test_get_subclass_hierarchy(self):
107 | doc = self.doc
108 |
109 | all_linked_entities = doc._.linkedEntities
110 | # [5] --> US
111 | hierarchy = all_linked_entities[5].get_subclass_hierarchy()
112 | assert hierarchy != None
113 | assert len(hierarchy) > 0
114 | assert 'country' in hierarchy
115 |
116 | def test_get_instance_of_hierarchy(self):
117 | doc = self.doc
118 |
119 | all_linked_entities = doc._.linkedEntities
120 | # [5] --> US
121 | hierarchy = all_linked_entities[5].get_instance_of_hierarchy()
122 | assert hierarchy != None
123 | assert len(hierarchy) > 0
124 | assert 'country' in hierarchy
125 |
126 | def test_get_chain_ids(self):
127 | doc = self.doc
128 |
129 | all_linked_entities = doc._.linkedEntities
130 | chain_ids = all_linked_entities[0].get_chain_ids()
131 | assert chain_ids != None
132 | assert len(chain_ids) > 0
133 |
134 | def test_get_description(self):
135 | doc = self.doc
136 |
137 | all_linked_entities = doc._.linkedEntities
138 | description = all_linked_entities[0].get_description()
139 | assert description != None
140 | assert len(description) > 0
141 |
142 | def test_is_intersecting(self):
143 | doc = self.doc
144 |
145 | all_linked_entities = doc._.linkedEntities
146 | assert not all_linked_entities[0].is_intersecting(all_linked_entities[1])
147 | # United States and US
148 | assert all_linked_entities[4].is_intersecting(all_linked_entities[5])
149 |
150 | def test_serialize(self):
151 | doc = self.doc
152 |
153 | all_linked_entities = doc._.linkedEntities
154 | serialized = all_linked_entities[0].serialize()
155 | assert serialized != None
156 | assert len(serialized) > 0
157 | assert 'id' in serialized
158 | assert 'label' in serialized
159 | assert 'span' in serialized
160 |
161 | def test_pretty_print(self):
162 | doc = self.doc
163 |
164 | all_linked_entities = doc._.linkedEntities
165 | all_linked_entities[0].pretty_print()
166 |
167 | def test_get_url(self):
168 | doc = self.doc
169 |
170 | all_linked_entities = doc._.linkedEntities
171 | url = all_linked_entities[0].get_url()
172 | assert url != None
173 | assert len(url) > 0
174 | assert 'wikidata.org/wiki/Q' in url
175 |
176 | def test___repr__(self):
177 | doc = self.doc
178 |
179 | all_linked_entities = doc._.linkedEntities
180 | repr = all_linked_entities[0].__repr__()
181 | assert repr != None
182 | assert len(repr) > 0
183 |
184 | def test___eq__(self):
185 | doc = self.doc
186 |
187 | all_linked_entities = doc._.linkedEntities
188 | assert not all_linked_entities[0] == all_linked_entities[1]
189 | assert all_linked_entities[4] == all_linked_entities[5]
190 |
--------------------------------------------------------------------------------
/tests/test_EntityLinker.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 | from spacy_entity_linker.EntityLinker import EntityLinker
4 |
5 |
6 | class TestEntityLinker(unittest.TestCase):
7 |
8 | def __init__(self, arg, *args, **kwargs):
9 | super(TestEntityLinker, self).__init__(arg, *args, **kwargs)
10 | self.nlp = spacy.load('en_core_web_sm')
11 |
12 | def test_initialization(self):
13 |
14 | self.nlp.add_pipe("entityLinker", last=True)
15 |
16 | doc = self.nlp(
17 | "Elon Musk was born in South Africa. Bill Gates and Steve Jobs come from in the United States")
18 |
19 | doc._.linkedEntities.pretty_print()
20 | doc._.linkedEntities.print_super_entities()
21 | for sent in doc.sents:
22 | sent._.linkedEntities.pretty_print()
23 |
24 | self.nlp.remove_pipe("entityLinker")
25 |
26 | def test_empty_root(self):
27 | # test empty lists of roots (#9)
28 | self.nlp.add_pipe("entityLinker", last=True)
29 |
30 | doc = self.nlp(
31 | 'I was right."\n\n "To that extent."\n\n "But that was all."\n\n "No, no, m')
32 | for sent in doc.sents:
33 | sent._.linkedEntities.pretty_print()
34 | # empty document
35 | doc = self.nlp('\n\n')
36 | for sent in doc.sents:
37 | sent._.linkedEntities.pretty_print()
38 |
39 | self.nlp.remove_pipe("entityLinker")
40 |
--------------------------------------------------------------------------------
/tests/test_TermCandidateExtractor.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 | import spacy_entity_linker.TermCandidateExtractor
4 |
5 |
6 | class TestCandidateExtractor(unittest.TestCase):
7 |
8 | def __init__(self, arg, *args, **kwargs):
9 | super(TestCandidateExtractor, self).__init__(arg, *args, **kwargs)
10 |
--------------------------------------------------------------------------------
/tests/test_multiprocessing.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 | from multiprocessing.pool import ThreadPool
4 |
5 |
6 | class TestMultiprocessing(unittest.TestCase):
7 |
8 | def __init__(self, arg, *args, **kwargs):
9 | super(TestMultiprocessing, self).__init__(arg, *args, **kwargs)
10 | self.nlp = spacy.load('en_core_web_sm')
11 |
12 | def test_is_pipe_multiprocessing_safe(self):
13 | self.nlp.add_pipe("entityLinker", last=True)
14 |
15 | ents = [
16 | 'Apple',
17 | 'Microsoft',
18 | 'Google',
19 | 'Amazon',
20 | 'Facebook',
21 | 'IBM',
22 | 'Twitter',
23 | 'Tesla',
24 | 'SpaceX',
25 | 'Alphabet',
26 | ]
27 | text = "{} is looking at buying U.K. startup for $1 billion"
28 |
29 | texts = [text.format(ent) for ent in ents]
30 | docs = self.nlp.pipe(texts, n_process=2)
31 | for doc in docs:
32 | print(doc)
33 | for ent in doc.ents:
34 | print(ent.text, ent.label_, ent._.linkedEntities)
35 |
36 | self.nlp.remove_pipe("entityLinker")
37 |
--------------------------------------------------------------------------------
/tests/test_multithreading.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 | from multiprocessing.pool import ThreadPool
4 |
5 |
6 | class TestMultiThreading(unittest.TestCase):
7 |
8 | def __init__(self, arg, *args, **kwargs):
9 | super(TestMultiThreading, self).__init__(arg, *args, **kwargs)
10 | self.nlp = spacy.load('en_core_web_sm')
11 |
12 | def test_is_multithread_safe(self):
13 | self.nlp.add_pipe("entityLinker", last=True)
14 |
15 | ents = [
16 | 'Apple',
17 | 'Microsoft',
18 | 'Google',
19 | 'Amazon',
20 | 'Facebook',
21 | 'IBM',
22 | 'Twitter',
23 | 'Tesla',
24 | 'SpaceX',
25 | 'Alphabet',
26 | ]
27 | text = "{} is looking at buying U.K. startup for $1 billion"
28 |
29 | def thread_func(i):
30 | doc = self.nlp(text.format(ents[i]))
31 | print(doc)
32 |
33 | for ent in doc.ents:
34 | print(ent.text, ent.label_, ent._.linkedEntities)
35 | return i
36 |
37 | with ThreadPool(10) as pool:
38 | for res in pool.imap_unordered(thread_func, range(10)):
39 | pass
40 |
41 | self.nlp.remove_pipe("entityLinker")
42 |
--------------------------------------------------------------------------------
/tests/test_pipe.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 | from multiprocessing.pool import ThreadPool
4 |
5 |
6 | class TestPipe(unittest.TestCase):
7 |
8 | def __init__(self, arg, *args, **kwargs):
9 | super(TestPipe, self).__init__(arg, *args, **kwargs)
10 | self.nlp = spacy.load('en_core_web_sm')
11 |
12 | def test_serialize(self):
13 | self.nlp.add_pipe("entityLinker", last=True)
14 |
15 | ents = [
16 | 'Apple',
17 | 'Microsoft',
18 | 'Google',
19 | 'Amazon',
20 | 'Facebook',
21 | 'IBM',
22 | 'Twitter',
23 | 'Tesla',
24 | 'SpaceX',
25 | 'Alphabet',
26 | ]
27 | text = "{} is looking at buying U.K. startup for $1 billion"
28 |
29 | texts = [text.format(ent) for ent in ents]
30 | docs = self.nlp.pipe(texts, n_process=2)
31 | for doc in docs:
32 | print(doc)
33 | for ent in doc.ents:
34 | print(ent.text, ent.label_, ent._.linkedEntities)
35 |
36 |
37 | self.nlp.remove_pipe("entityLinker")
38 |
--------------------------------------------------------------------------------
/tests/test_serialize.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import spacy
3 | from multiprocessing.pool import ThreadPool
4 |
5 |
6 | class TestSerialize(unittest.TestCase):
7 |
8 | def __init__(self, arg, *args, **kwargs):
9 | super(TestSerialize, self).__init__(arg, *args, **kwargs)
10 | self.nlp = spacy.load('en_core_web_sm')
11 |
12 | def test_serialize(self):
13 | self.nlp.add_pipe("entityLinker", last=True)
14 |
15 | text = "Apple is looking at buying U.K. startup for $1 billion"
16 | doc = self.nlp(text)
17 | serialised = doc.to_bytes()
18 |
19 | doc2 = spacy.tokens.Doc(doc.vocab).from_bytes(serialised)
20 | for ent, ent2 in zip(doc.ents, doc2.ents):
21 | assert ent.text == ent2.text
22 | assert ent.label_ == ent2.label_
23 | linked = ent._.linkedEntities
24 | linked2 = ent2._.linkedEntities
25 | if linked:
26 | assert linked.get_description() == linked2.get_description()
27 | assert linked.get_id() == linked2.get_id()
28 | assert linked.get_label() == linked2.get_label()
29 | assert linked.get_span() == linked2.get_span()
30 | assert linked.get_url() == linked2.get_url()
31 |
32 |
33 | self.nlp.remove_pipe("entityLinker")
34 |
--------------------------------------------------------------------------------