├── zshot ├── tests │ ├── __init__.py │ ├── linker │ │ ├── __init__.py │ │ ├── test_gliner_linker.py │ │ ├── test_smxm_linker.py │ │ ├── test_relik_linker.py │ │ ├── test_ensemble_linker.py │ │ ├── test_blinker_linker.py │ │ ├── test_linker.py │ │ └── test_regen_linker.py │ ├── utils │ │ ├── __init__.py │ │ ├── test_mappings.py │ │ ├── test_ensembler.py │ │ ├── test_description_enrichment.py │ │ ├── test_displacy.py │ │ └── test_data_models.py │ ├── evaluation │ │ ├── __init__.py │ │ └── test_datasets.py │ ├── knowledge_extractor │ │ ├── __init__.py │ │ ├── test_knowledge_extractor.py │ │ └── test_knowgl_knowledge_extractor.py │ ├── mentions_extractor │ │ ├── __init__.py │ │ ├── test_smxm_mentions_extractor.py │ │ ├── test_gliner_mentions_extractor.py │ │ ├── test_mention_extractor.py │ │ ├── test_spacy_mentions_extractor.py │ │ ├── test_tars_mentions_extractor.py │ │ └── test_flair_mentions_extractor.py │ ├── relations_extractor │ │ ├── __init__.py │ │ ├── test_zsrc_relations_extractor.py │ │ └── test_relations_extractor.py │ └── config.py ├── utils │ ├── displacy │ │ ├── __init__.py │ │ ├── colors.py │ │ └── displacy.py │ ├── models │ │ ├── __init__.py │ │ ├── smxm │ │ │ ├── __init__.py │ │ │ └── model.py │ │ └── tars │ │ │ ├── __init__.py │ │ │ └── utils.py │ ├── __init__.py │ ├── data_models │ │ ├── __init__.py │ │ ├── relation.py │ │ ├── entity.py │ │ ├── relation_span.py │ │ └── span.py │ ├── enrichment │ │ └── __init__.py │ ├── file_utils.py │ ├── download_models.py │ ├── mappings.py │ └── ensembler.py ├── evaluation │ ├── metrics │ │ ├── __init__.py │ │ ├── _seqeval │ │ │ └── __init__.py │ │ └── rel_eval.py │ ├── dataset │ │ ├── fewrel │ │ │ ├── __init__.py │ │ │ └── fewrel.py │ │ ├── ontonotes │ │ │ ├── __init__.py │ │ │ └── onto_notes.py │ │ ├── med_mentions │ │ │ ├── __init__.py │ │ │ └── med_mentions.py │ │ ├── pile_ner_biomed │ │ │ ├── __init__.py │ │ │ └── pile_ner_biomed.py │ │ ├── __init__.py │ │ └── dataset.py │ ├── __init__.py │ └── pipeline.py ├── linker │ ├── linker_regen │ │ ├── __init__.py │ │ ├── trie.py │ │ └── utils.py │ ├── linker_ensemble │ │ ├── __init__.py │ │ ├── utils.py │ │ └── linker_ensemble.py │ ├── __init__.py │ ├── linker_gliner.py │ ├── linker_smxm.py │ ├── linker_relik.py │ ├── linker_tars.py │ └── linker.py ├── relation_extractor │ ├── zsrc │ │ ├── __init__.py │ │ ├── decide_entity_order.py │ │ └── zero_shot_rel_class.py │ ├── __init__.py │ ├── relations_extractor.py │ └── relation_extractor_zsrc.py ├── knowledge_extractor │ ├── knowgl │ │ ├── __init__.py │ │ └── knowledge_extractor_knowgl.py │ ├── __init__.py │ ├── knowledge_extractor_relik.py │ └── knowledge_extractor.py ├── mentions_extractor │ ├── utils │ │ ├── __init__.py │ │ └── ExtractorType.py │ ├── __init__.py │ ├── mentions_extractor_gliner.py │ ├── mentions_extractor_smxm.py │ ├── mentions_extractor_spacy.py │ ├── mentions_extractor.py │ ├── mentions_extractor_tars.py │ └── mentions_extractor_flair.py ├── config.py ├── __init__.py └── pipeline_config.py ├── requirements.txt ├── docs ├── index.md ├── img │ ├── blink.png │ ├── graph.png │ ├── annotations.png │ └── zshot-header.png ├── flair_mentions_extractor.md ├── spacy_mentions_extractor.md ├── relation_extractor.md ├── knowledge_extractor.md ├── smxm_mentions_extractor.md ├── relik_knowledge_extractor.md ├── gliner_linker.md ├── tars_mentions_extractor.md ├── gliner_mentions_extractor.md ├── tars_linker.md ├── relik_linker.md ├── smxm_linker.md ├── regen.md ├── zsbert_relations_extractor.md ├── blink.md ├── knowgl_knowledge_extractor.md 
├── mentions_extractor.md ├── entity_linking.md └── evaluation.md ├── setup.cfg ├── requirements ├── devel.txt └── test.txt ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── workflows │ ├── publish-pages-doc.yml │ ├── python-publish.yml │ ├── python-tests.yml │ └── codeql.yml └── pull_request_template.md ├── LICENSE ├── setup.py ├── .gitignore └── mkdocs.yml /zshot/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /zshot/tests/linker/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/utils/displacy/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/utils/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/linker/linker_regen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/tests/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/utils/models/smxm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/utils/models/tars/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/fewrel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/relation_extractor/zsrc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/tests/knowledge_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/tests/mentions_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/tests/relations_extractor/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/ontonotes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/evaluation/metrics/_seqeval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/knowledge_extractor/knowgl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/med_mentions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/pile_ner_biomed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | {% 2 | include-markdown "../README.md" 3 | %} -------------------------------------------------------------------------------- /docs/img/blink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/zshot/HEAD/docs/img/blink.png -------------------------------------------------------------------------------- /docs/img/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/zshot/HEAD/docs/img/graph.png -------------------------------------------------------------------------------- /docs/img/annotations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/zshot/HEAD/docs/img/annotations.png -------------------------------------------------------------------------------- /zshot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.utils.file_utils import download_file # noqa: F401 2 | -------------------------------------------------------------------------------- /docs/img/zshot-header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/zshot/HEAD/docs/img/zshot-header.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [egg_info] 2 | tag_svn_revision = true 3 | 4 | [metadata] 5 | version = attr: zshot.__version__ -------------------------------------------------------------------------------- /docs/flair_mentions_extractor.md: -------------------------------------------------------------------------------- 1 | # Flair Mentions Extractor 2 | 3 | ::: zshot.mentions_extractor.MentionsExtractorFlair -------------------------------------------------------------------------------- /docs/spacy_mentions_extractor.md: -------------------------------------------------------------------------------- 1 | # Spacy Mentions Extractor 2 | 3 | ::: zshot.mentions_extractor.MentionsExtractorSpacy 
-------------------------------------------------------------------------------- /zshot/linker/linker_ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.linker.linker_ensemble.linker_ensemble import LinkerEnsemble # noqa: F401 2 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.mentions_extractor.utils.ExtractorType import ExtractorType # noqa: F401 2 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/utils/ExtractorType.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | 4 | class ExtractorType(Enum): 5 | NER = auto() 6 | POS = auto() 7 | -------------------------------------------------------------------------------- /requirements/devel.txt: -------------------------------------------------------------------------------- 1 | # install all mandatory dependencies 2 | -r ../requirements.txt 3 | 4 | # extended list of dependencies for development and run lint and tests 5 | -r ./test.txt -------------------------------------------------------------------------------- /zshot/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | MODELS_CACHE_PATH = os.getenv("MODELS_CACHE_PATH") if "MODELS_CACHE_PATH" in os.environ \ 5 | else f"{pathlib.Path.home()}/.cache/zshot/" 6 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | pytest>=7.0 2 | pytest-cov>=3.0.0 3 | setuptools>=65.5.1 4 | scipy<1.13.0 5 | flair>=0.13 6 | gliner>=0.2.9 7 | flake8>=4.0.1 8 | coverage>=6.4.1 9 | pydantic==1.9.2 10 | relik==1.0.5 11 | IPython -------------------------------------------------------------------------------- /zshot/relation_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.relation_extractor.relations_extractor import RelationsExtractor # noqa: F401 2 | from zshot.relation_extractor.relation_extractor_zsrc import RelationsExtractorZSRC # noqa: F401 3 | -------------------------------------------------------------------------------- /zshot/utils/data_models/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.utils.data_models.entity import Entity # noqa: F401 2 | from zshot.utils.data_models.relation import Relation # noqa: F401 3 | from zshot.utils.data_models.span import Span # noqa: F401 4 | -------------------------------------------------------------------------------- /zshot/utils/enrichment/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.utils.enrichment.description_enrichment import ParaphrasingStrategy, \ 2 | FineTunedLMExtensionStrategy, PreTrainedLMExtensionStrategy, SummarizationStrategy, \ 3 | EntropyHeuristic # noqa: F401 4 | -------------------------------------------------------------------------------- /zshot/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig, \ 2 | RelationsExtractor, KnowledgeExtractor # noqa: F401 
3 | from zshot.utils.displacy.displacy import displacy # noqa: F401 4 | 5 | __version__ = '0.0.11' 6 | -------------------------------------------------------------------------------- /docs/relation_extractor.md: -------------------------------------------------------------------------------- 1 | # Relations Extractor 2 | 3 | The **relations extractor** will extract relations among different entities *previously* extracted by a **linker**. 4 | 5 | Currently, there is only one Relation Extractor available: ZS-BERT 6 | 7 | 8 | ::: zshot.RelationsExtractor -------------------------------------------------------------------------------- /zshot/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.evaluation.dataset import load_medmentions_zs # noqa: F401 2 | from zshot.evaluation.dataset import load_ontonotes_zs # noqa: F401 3 | from zshot.evaluation.dataset import load_few_rel_zs # noqa: F401 4 | from zshot.evaluation.dataset import load_pile_ner_biomed_zs # noqa: F401 5 | -------------------------------------------------------------------------------- /zshot/knowledge_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.knowledge_extractor.knowledge_extractor import KnowledgeExtractor # noqa: F401 2 | from zshot.knowledge_extractor.knowgl.knowledge_extractor_knowgl import KnowGL # noqa: F401 3 | from zshot.knowledge_extractor.knowledge_extractor_relik import KnowledgeExtractorRelik # noqa: F401 4 | -------------------------------------------------------------------------------- /docs/knowledge_extractor.md: -------------------------------------------------------------------------------- 1 | # Knowledge Extractor 2 | 3 | The **knowledge extractor** will perform, at the same time, the extraction and classification of named entities and the extraction of relations among them.
4 | 5 | Currently, there are only two Knowledge Extractors available: KnowGL and ReLiK 6 | 7 | 8 | ::: zshot.KnowledgeExtractor -------------------------------------------------------------------------------- /zshot/utils/data_models/relation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import zlib 4 | from pydantic import BaseModel 5 | 6 | 7 | class Relation(BaseModel): 8 | name: str 9 | description: Optional[str] = None 10 | 11 | def __hash__(self): 12 | self_repr = f"{self.__class__.__name__}.{str(self.__dict__)}" 13 | return zlib.crc32(self_repr.encode()) 14 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.evaluation.dataset.med_mentions.med_mentions import load_medmentions_zs # noqa: F401 2 | from zshot.evaluation.dataset.ontonotes.onto_notes import load_ontonotes_zs # noqa: F401 3 | from zshot.evaluation.dataset.fewrel.fewrel import load_few_rel_zs # noqa: F401 4 | from zshot.evaluation.dataset.pile_ner_biomed.pile_ner_biomed import load_pile_ner_biomed_zs # noqa: F401 5 | -------------------------------------------------------------------------------- /zshot/utils/data_models/entity.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | from typing import Optional, List 3 | from pydantic import BaseModel 4 | 5 | 6 | class Entity(BaseModel): 7 | name: str 8 | description: Optional[str] = None 9 | vocabulary: Optional[List[str]] = None 10 | 11 | def __hash__(self): 12 | self_repr = f"{self.__class__.__name__}.{str(self.__dict__)}" 13 | return zlib.crc32(self_repr.encode()) 14 | -------------------------------------------------------------------------------- /zshot/utils/displacy/colors.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | 3 | 4 | def light_color_from_label(label: str): 5 | channel_min = 100 6 | hash_s = zlib.crc32(label.encode()) 7 | r = ((hash_s & 0xFF0000) >> 16) % (255 - channel_min) + channel_min 8 | g = ((hash_s & 0x00FF00) >> 8) % (255 - channel_min) + channel_min 9 | b = (hash_s & 0x0000FF) % (255 - channel_min) + channel_min 10 | return '#%02x%02x%02x' % (r, g, b) 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug]" 5 | labels: bug 6 | assignees: "" 7 | 8 | --- 9 | 10 | # Summary 11 | 12 | **Describe the bug** 13 | A clear and concise description of what the bug is. 14 | 15 | **To Reproduce** 16 | Steps to reproduce the behavior: 17 | 18 | **Expected behavior** 19 | A clear and concise description of what you expected to happen.
-------------------------------------------------------------------------------- /.github/workflows/publish-pages-doc.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs via GitHub Pages 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - uses: actions/setup-python@v2 14 | with: 15 | python-version: 3.x 16 | - run: pip install mkdocs mkdocs-material mkdocstrings[python] mkdocs-markdownextradata-plugin mdx_include mkdocs-include-markdown-plugin 17 | - run: mkdocs gh-deploy --force 18 | -------------------------------------------------------------------------------- /zshot/linker/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.linker.linker_blink import LinkerBlink # noqa: F401 2 | from zshot.linker.linker_regen.linker_regen import LinkerRegen # noqa: F401 3 | from zshot.linker.linker import Linker # noqa: F401 4 | from zshot.linker.linker_smxm import LinkerSMXM # noqa: F401 5 | from zshot.linker.linker_tars import LinkerTARS # noqa: F401 6 | from zshot.linker.linker_ensemble import LinkerEnsemble # noqa: F401 7 | from zshot.linker.linker_relik import LinkerRelik # noqa: F401 8 | from zshot.linker.linker_gliner import LinkerGLINER # noqa: F401 9 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | from zshot.mentions_extractor.mentions_extractor_flair import MentionsExtractorFlair # noqa: F401 2 | from zshot.mentions_extractor.mentions_extractor_spacy import MentionsExtractorSpacy # noqa: F401 3 | from zshot.mentions_extractor.mentions_extractor_smxm import MentionsExtractorSMXM # noqa: F401 4 | from zshot.mentions_extractor.mentions_extractor_tars import MentionsExtractorTARS # noqa: F401 5 | from zshot.mentions_extractor.mentions_extractor_gliner import MentionsExtractorGLINER # noqa: F401 6 | from zshot.mentions_extractor.mentions_extractor import MentionsExtractor # noqa: F401 7 | -------------------------------------------------------------------------------- /zshot/utils/models/tars/utils.py: -------------------------------------------------------------------------------- 1 | from zshot.utils.data_models import Span 2 | 3 | 4 | def tars_predict(model, sentences, batch_size): 5 | kwargs = {'mini_batch_size': batch_size} if batch_size else {} 6 | model.predict(sentences, **kwargs) 7 | 8 | spans_annotations = [] 9 | for sent in sentences: 10 | sent_mentions = sent.get_spans('ner') 11 | spans = [ 12 | Span(mention.start_position, mention.end_position, mention.tag, mention.score) 13 | for mention in sent_mentions 14 | ] 15 | spans_annotations.append(spans) 16 | 17 | return spans_annotations 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Feature or enhancement for the project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | # Scenario summary 11 | 12 | A clear and concise description of scenario. 13 | **Is your feature request related to a problem? 
Please describe.** 14 | 15 | # Proposed solution 16 | 17 | **Describe the solution you'd like** 18 | A clear and concise description of what you want to happen. 19 | 20 | **Describe alternatives you've considered** 21 | A clear and concise description of any alternative solutions or features you've considered. 22 | -------------------------------------------------------------------------------- /docs/smxm_mentions_extractor.md: -------------------------------------------------------------------------------- 1 | # SMXM Mentions Extractor 2 | When the mentions to be extracted are known by the user, they can be specified in order to improve the performance of the system. Based on the SMXM linker, the SMXM **mentions extractor** uses the description of the mentions to give the model information about the mentions to be extracted. By using the descriptions, the SMXM model is able to understand the mention. 3 | 4 | The SMXM **mentions extractor** will use the **mentions** specified in the `zshot.PipelineConfig`. 5 | 6 | - [Paper](https://aclanthology.org/2021.acl-long.120/) 7 | - [Original Source Code](https://github.com/Raldir/Zero-shot-NERC) 8 | 9 | ::: zshot.mentions_extractor.MentionsExtractorSMXM -------------------------------------------------------------------------------- /docs/relik_knowledge_extractor.md: -------------------------------------------------------------------------------- 1 | # ReLiK Knowledge Extractor 2 | 3 | ReLiK is a lightweight and fast model for Entity Linking and Relation Extraction. It is composed of two main components: a retriever and a reader. The retriever is responsible for retrieving relevant documents from a large collection, while the reader is responsible for extracting entities and relations from the retrieved documents. 4 | 5 | In **Zshot**, we created a Knowledge Extractor to use ReLiK and extract relations directly, without having to specify any entities or relation names. 6 | 7 | - [Paper](https://arxiv.org/abs/2408.00103) 8 | - [Original Source Code](https://github.com/SapienzaNLP/relik) 9 | 10 | ::: zshot.knowledge_extractor.KnowledgeExtractorRelik -------------------------------------------------------------------------------- /docs/gliner_linker.md: -------------------------------------------------------------------------------- 1 | # GLiNER Linker 2 | 3 | GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios. 4 | 5 | The GLiNER **linker** will use the **entities** specified in the `zshot.PipelineConfig`; it just uses the names of the entities, not their descriptions.
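A minimal usage sketch, following the pattern used in the zshot tests; the entity names, descriptions and example sentence are illustrative only:

```python
import spacy

from zshot import PipelineConfig
from zshot.linker import LinkerGLINER
from zshot.utils.data_models import Entity

# Illustrative entities: this linker only looks at the entity names, not the descriptions
entities = [
    Entity(name="company", description="A business organization"),
    Entity(name="location", description="A country, city or other geographical place"),
]

nlp = spacy.blank("en")
config = PipelineConfig(linker=LinkerGLINER(), entities=entities)
nlp.add_pipe("zshot", config=config, last=True)

doc = nlp("IBM opened a new research lab in Dublin.")
print([(ent.text, ent.label_) for ent in doc.ents])
```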
6 | 7 | 8 | - [Paper](https://arxiv.org/abs/2311.08526) 9 | - [Original Source Code](https://github.com/urchade/GLiNER) 10 | 11 | ::: zshot.linker.LinkerGLINER -------------------------------------------------------------------------------- /zshot/evaluation/dataset/pile_ner_biomed/pile_ner_biomed.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | 3 | from zshot.evaluation.dataset.dataset import DatasetWithEntities 4 | from zshot.utils.data_models import Entity 5 | 6 | REPO_ID = "disi-unibo-nlp/Pile-NER-biomed-IOB" 7 | ENTITIES_REPO_ID = "disi-unibo-nlp/Pile-NER-biomed-descriptions" 8 | 9 | 10 | def load_pile_ner_biomed_zs(**kwargs) -> DatasetWithEntities: 11 | dataset = load_dataset(REPO_ID, split='train', **kwargs) 12 | entities = load_dataset(ENTITIES_REPO_ID, split="train") 13 | 14 | entities_split = [Entity(name=e['entity_type'], description=e['description']) for e in entities] 15 | dataset = DatasetWithEntities(dataset.data, entities=entities_split) 16 | 17 | return dataset 18 | -------------------------------------------------------------------------------- /docs/tars_mentions_extractor.md: -------------------------------------------------------------------------------- 1 | # TARS Mentions Extractor 2 | 3 | When the mentions to be extracted are known by the user, they can be specified in order to improve the performance of the system. Based on the TARS linker, the TARS **mentions extractor** uses the labels of the mentions to give the model information about the mentions to be extracted. TARS doesn't need the descriptions of the entities, so if you can't provide them, this may be the approach you're looking for. 4 | 5 | The TARS **mentions extractor** will use the **mentions** specified in the `zshot.PipelineConfig`. 6 | 7 | - [Paper](https://kishaloyhalder.github.io/pdfs/tars_coling2020.pdf) 8 | - [Original Source Code](https://github.com/flairNLP/flair) 9 | 10 | ::: zshot.mentions_extractor.MentionsExtractorTARS -------------------------------------------------------------------------------- /docs/gliner_mentions_extractor.md: -------------------------------------------------------------------------------- 1 | # GLiNER Mentions Extractor 2 | 3 | GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios. 4 | 5 | The GLiNER **mentions extractor** will use the **mentions** specified in the `zshot.PipelineConfig`; it just uses the names of the mentions, not their descriptions. 6 | 7 | 8 | - [Paper](https://arxiv.org/abs/2311.08526) 9 | - [Original Source Code](https://github.com/urchade/GLiNER) 10 | 11 | ::: zshot.mentions_extractor.MentionsExtractorGLINER -------------------------------------------------------------------------------- /docs/tars_linker.md: -------------------------------------------------------------------------------- 1 | # TARS Linker 2 | 3 | Task-aware representation of sentences (TARS) is a simple and effective method for few-shot and even zero-shot learning for text classification, and it has been extended to perform Zero-Shot NERC.
4 | 5 | Basically, TARS casts the task as a binary classification problem, predicting whether a given text belongs to a specific class. 6 | 7 | TARS doesn't need the descriptions of the entities, so if you can't provide them, this may be the approach you're looking for. 8 | 9 | The TARS **linker** will use the **entities** specified in the `zshot.PipelineConfig`. 10 | 11 | 12 | - [Paper](https://kishaloyhalder.github.io/pdfs/tars_coling2020.pdf) 13 | - [Original Source Code](https://github.com/flairNLP/flair) 14 | 15 | ::: zshot.linker.LinkerTARS -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | > ⚠️ NOTE: Use notes like this to emphasize something important about the PR. 3 | > 4 | > This could include other PRs this PR is built on top of; API breaking changes; reasons for why the PR is on hold; or anything else you would like to draw attention to. 5 | 6 | | Status | Type | ⚠️ Core Change | Issue | 7 | | :---: | :---: | :---: | :--: | 8 | | Ready/Hold | Feature/Bug/Tooling/Refactor/Hotfix | Yes/No | [Link]() | 9 | 10 | ## Problem 11 | 12 | _What problem are you trying to solve?_ 13 | 14 | 15 | ## Solution 16 | 17 | _How did you solve the problem?_ 18 | 19 | 20 | ## Other changes (e.g. bug fixes, small refactors) 21 | 22 | 23 | **New scripts**: 24 | 25 | - `script` : script details 26 | 27 | **New dependencies**: 28 | 29 | - `dependency` : dependency details -------------------------------------------------------------------------------- /docs/relik_linker.md: -------------------------------------------------------------------------------- 1 | # ReLiK Linker 2 | ReLiK is a lightweight and fast model for Entity Linking and Relation Extraction. It is composed of two main components: a retriever and a reader. The retriever is responsible for retrieving relevant documents from a large collection, while the reader is responsible for extracting entities and relations from the retrieved documents. 3 | 4 | In **Zshot**, we created a linker to use ReLiK; it works both with and without providing entities, and it can also use entity descriptions. 5 | 6 | This is an *end-to-end* model, so there is no need to use a **mentions extractor** before. 7 | 8 | The ReLiK **linker** will use the **entities** specified in the `zshot.PipelineConfig`, if any.
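A minimal usage sketch, assuming the same `PipelineConfig` pattern as the other linkers; the entities, their descriptions and the example sentence are illustrative, and the `entities` list can be omitted since ReLiK also works without it:

```python
import spacy

from zshot import PipelineConfig
from zshot.linker import LinkerRelik
from zshot.utils.data_models import Entity

nlp = spacy.blank("en")
# Illustrative configuration: ReLiK can use names and descriptions, or no entity list at all
config = PipelineConfig(
    linker=LinkerRelik(),
    entities=[
        Entity(name="person", description="A human being"),
        Entity(name="organization", description="A company or institution"),
    ],
)
nlp.add_pipe("zshot", config=config, last=True)

doc = nlp("Ada Lovelace collaborated with Charles Babbage on the Analytical Engine.")
print([(ent.text, ent.label_) for ent in doc.ents])
```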
9 | 10 | - [Paper](https://arxiv.org/abs/2408.00103) 11 | - [Original Source Code](https://github.com/SapienzaNLP/relik) 12 | 13 | ::: zshot.linker.LinkerRelik -------------------------------------------------------------------------------- /zshot/linker/linker_regen/trie.py: -------------------------------------------------------------------------------- 1 | from typing import Collection 2 | 3 | 4 | class Trie(object): 5 | def __init__(self, sequences: Collection[Collection[int]] = []): 6 | self.trie_dict = {} 7 | for sequence in sequences: 8 | self.add(sequence) 9 | 10 | def add(self, sequence: Collection[int]): 11 | trie = self.trie_dict 12 | for idx in sequence: 13 | if idx not in trie: 14 | trie[idx] = {} 15 | trie = trie[idx] 16 | 17 | def postfix(self, prefix_sequence: Collection[int]): 18 | if len(prefix_sequence) == 1: 19 | return list(self.trie_dict.keys()) 20 | trie = self.trie_dict 21 | for pfx in prefix_sequence[1:]: 22 | if pfx not in trie: 23 | return [] 24 | trie = trie[pfx] 25 | return list(trie.keys()) 26 | -------------------------------------------------------------------------------- /zshot/tests/utils/test_mappings.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | 5 | from zshot.utils.mappings import spans_to_wikipedia, spans_to_dbpedia 6 | from zshot.utils.data_models import Span 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @pytest.mark.skip(reason="Too expensive to run on every commit") 12 | def test_span_to_dbpedia(): # pragma: no cover 13 | s = Span(label="Surfing", start=0, end=10) 14 | db_links = spans_to_dbpedia([s]) 15 | assert len(db_links) > 0 16 | assert db_links[0].startswith("http://dbpedia.org/resource") 17 | 18 | 19 | @pytest.mark.skip(reason="Too expensive to run on every commit") 20 | def test_span_to_wiki(): # pragma: no cover 21 | s = Span(label="Surfing", start=0, end=10) 22 | wiki_links = spans_to_wikipedia([s]) 23 | assert len(wiki_links) > 0 24 | assert wiki_links[0].startswith("https://en.wikipedia.org/wiki?curid=") 25 | -------------------------------------------------------------------------------- /docs/smxm_linker.md: -------------------------------------------------------------------------------- 1 | # SMXM Linker 2 | When there is no labelled data (i.e.: Zero-Shot approaches), the performance usually decreases because the model doesn't really know what the entity represents. To address this problem, the SMXM model uses the description of the entities to give the model information about the entities. 3 | 4 | By using the descriptions, the SMXM model is able to understand the entity. Although this approach is Zero-Shot, as it doesn't need to have seen the entities during training, the user still has to specify the descriptions of the entities. 5 | 6 | This is an *end-to-end* model, so there is no need to use a **mentions extractor** before. 7 | 8 | The SMXM **linker** will use the **entities** specified in the `zshot.PipelineConfig`. 9 | 10 | - [Paper](https://aclanthology.org/2021.acl-long.120/) 11 | - [Original Source Code](https://github.com/Raldir/Zero-shot-NERC) 12 | 13 | ::: zshot.linker.LinkerSMXM -------------------------------------------------------------------------------- /docs/regen.md: -------------------------------------------------------------------------------- 1 | # GENRE 2 | Regen is based on GENRE.
GENRE is also an entity linking model released by Facebook, but in this case it uses a different approach by considering the NERC task as a sequence-to-sequence problem, and retrieves the entities by using a constrained beam search to force the model to generate the entities. 3 | 4 | In a nutshell, (m)GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on fine-tuned [BART](https://arxiv.org/abs/1910.13461). GENRE performs retrieval by generating the unique entity name conditioned on the input text, using constrained beam search to only generate valid identifiers. 5 | Although there is an *end-to-end* version of GENRE, it is not currently supported in ZShot (but it will be). 6 | 7 | The REGEN **linker** will use the **entities** specified in the `zshot.PipelineConfig`. 8 | 9 | - [Paper](https://arxiv.org/pdf/2010.00904.pdf) 10 | - [Original Source Code](https://github.com/facebookresearch/GENRE) 11 | 12 | ::: zshot.linker.LinkerRegen -------------------------------------------------------------------------------- /docs/zsbert_relations_extractor.md: -------------------------------------------------------------------------------- 1 | # ZS-BERT Relations Extractor 2 | 3 | The ZS-BERT model is a novel multi-task learning model to directly predict unseen relations without hand-crafted attribute labeling and multiple pairwise classifications. Given training instances consisting of input sentences and the descriptions of their relations, ZS-BERT learns two functions that project sentences and relation descriptions into an embedding space by jointly minimizing the distances between them and classifying seen relations. By generating the embeddings of unseen relations and new-coming sentences based on such two functions, we use nearest neighbor search to obtain the prediction of unseen relations. 4 | 5 | This `RelationsExtractor` uses pre-defined relations along with their descriptions, added into the `PipelineConfig` using the `Relation` data model. 6 | 7 | - [Paper](https://arxiv.org/abs/2104.04697) 8 | - [Original Repo](https://github.com/dinobby/ZS-BERT) 9 | 10 | ::: zshot.relation_extractor.RelationsExtractorZSRC -------------------------------------------------------------------------------- /docs/blink.md: -------------------------------------------------------------------------------- 1 | ### BLINK 2 | BLINK is an Entity Linking model released by Facebook that uses Wikipedia as the target knowledge base. The process of linking entities to Wikipedia is also known as [Wikification](https://en.wikipedia.org/wiki/Wikification). 3 | 4 | In a nutshell, BLINK uses a two-stage approach for entity linking, based on fine-tuned BERT architectures. In the first stage, BLINK performs retrieval in a dense space defined by a bi-encoder that independently embeds the mention context and the entity descriptions. Each candidate is then examined more carefully with a cross-encoder that concatenates the mention and entity text. BLINK achieves state-of-the-art results on multiple datasets. 5 | 6 | ![BLINK Overview](./img/blink.png) 7 | 8 | The BLINK knowledge base (entity library) is based on the 2019/08/01 Wikipedia dump, so the target entities are Wikipedia entities or articles.
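A minimal usage sketch, assuming the same `PipelineConfig` pattern as the other linkers and assuming that BLINK, unlike the end-to-end linkers, consumes mentions detected upstream; the mentions extractor choice and the example sentence are illustrative:

```python
import spacy

from zshot import PipelineConfig
from zshot.linker import LinkerBlink
from zshot.mentions_extractor import MentionsExtractorFlair
from zshot.mentions_extractor.utils import ExtractorType

nlp = spacy.blank("en")
# Assumption: mentions are detected upstream and then linked against Wikipedia by BLINK,
# so no entity list is configured for the linker.
config = PipelineConfig(
    mentions_extractor=MentionsExtractorFlair(ExtractorType.NER),
    linker=LinkerBlink(),
)
nlp.add_pipe("zshot", config=config, last=True)

doc = nlp("Surfing originated in Polynesia and became popular in California.")
print([(ent.text, ent.label_) for ent in doc.ents])
```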
9 | 10 | - [Paper](https://arxiv.org/pdf/1911.03814.pdf) 11 | - [Original Source Code](https://github.com/facebookresearch/BLINK) 12 | 13 | ::: zshot.linker.LinkerBlink -------------------------------------------------------------------------------- /docs/knowgl_knowledge_extractor.md: -------------------------------------------------------------------------------- 1 | # KnowGL Knowledge Extractor 2 | 3 | The knowgl-large model is trained by combining Wikidata with an extended version of the training data in the REBEL dataset. Given a sentence, KnowGL generates triple(s) in the following format: 4 | ``` 5 | [(subject mention # subject label # subject type) | relation label | (object mention # object label # object type)] 6 | ``` 7 | If more than one triple is generated, they are separated by $ in the output. The model achieves state-of-the-art results for relation extraction on the REBEL dataset. The generated labels (for the subject, relation, and object) and their types can be directly mapped to Wikidata IDs associated with them. 8 | 9 | This `KnowledgeExtractor` does not use any pre-defined entities or relations. 10 | 11 | - [Paper Rossiello et al. (AAAI 2023)](https://arxiv.org/pdf/2210.13952.pdf) 12 | - [Paper Mihindukulasooriya et al. (ISWC 2022)](https://arxiv.org/pdf/2207.05188.pdf) 13 | - [Original Model](https://huggingface.co/ibm/knowgl-large) 14 | 15 | ::: zshot.knowledge_extractor.KnowGL -------------------------------------------------------------------------------- /zshot/tests/relations_extractor/test_zsrc_relations_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import spacy 4 | from spacy.tokens import Doc 5 | 6 | from zshot import PipelineConfig, Linker 7 | from zshot.relation_extractor import RelationsExtractorZSRC 8 | from zshot.tests.config import EX_RELATIONS, EX_DATASET_RELATIONS, EX_DOCS 9 | from zshot.utils.data_models import Span 10 | 11 | 12 | class DummyLinkerEnd2End(Linker): 13 | def predict(self, docs: Iterator[Doc], batch_size=None): 14 | return [[Span(187, 165, label='label', score=0.9), 15 | Span(111, 129, label='label', score=0.9)] for doc in docs] 16 | 17 | 18 | def test_zsrc_with_entities_config_dummy_annotator(): 19 | nlp = spacy.blank("en") 20 | config_zshot = PipelineConfig( 21 | linker=DummyLinkerEnd2End(), 22 | relations_extractor=RelationsExtractorZSRC(), 23 | relations=EX_RELATIONS, 24 | ) 25 | nlp.add_pipe("zshot", config=config_zshot, last=True) 26 | doc = nlp(EX_DOCS[1]) 27 | assert len(doc._.relations) == 0 28 | doc = nlp(EX_DATASET_RELATIONS['sentences'][0]) 29 | assert len(doc._.relations) == 1 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 International Business Machines 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /zshot/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import os 4 | import pathlib 5 | import shutil 6 | from urllib.request import urlopen 7 | 8 | import requests 9 | from tqdm.auto import tqdm 10 | 11 | 12 | def download_file(url, output_dir=".") -> pathlib.Path: 13 | """ 14 | Utility for downloading a file 15 | :param url: the file url 16 | :param output_dir: the output dir 17 | :return: 18 | """ 19 | filename = url.rsplit('/', 1)[1] 20 | path = pathlib.Path(os.path.join(output_dir, filename)).resolve() 21 | path.parent.mkdir(parents=True, exist_ok=True) 22 | with requests.get(url, stream=True) as r: 23 | logging.info(f"Downloading {url}") 24 | total_length = int(urlopen(url=url).info().get('Content-Length', 0)) 25 | if path.exists() and os.path.getsize(path) == total_length: 26 | return path 27 | r.raw.read = functools.partial(r.raw.read, decode_content=True) 28 | with tqdm.wrapattr(r.raw, "read", total=total_length, desc=f"Downloading {filename}") as raw: 29 | with path.open("wb") as output: 30 | shutil.copyfileobj(raw, output) 31 | return path 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from pathlib import Path 3 | 4 | this_directory = Path(__file__).parent 5 | long_description = (this_directory / "README.md").read_text() 6 | 7 | 8 | setup(name='zshot', 9 | description="Zero and Few shot named entity recognition", 10 | long_description_content_type='text/markdown', 11 | long_description=long_description, 12 | classifiers=[], 13 | keywords='NER Zero-Shot Few-Shot', 14 | author='IBM Research', 15 | author_email='', 16 | url='https://ibm.github.io/zshot', 17 | license='MIT', 18 | packages=find_packages(exclude=['ez_setup', 'examples', 'tests']), 19 | include_package_data=True, 20 | zip_safe=False, 21 | install_requires=[ 22 | "spacy>=3.4.1", 23 | "requests>=2.28", 24 | "tqdm>=4.62.3", 25 | "setuptools>=65.5.1", # Needed to install dynamic packages from source (e.g. 
Blink) 26 | "prettytable>=3.4", 27 | "torch>=1", 28 | "transformers>=4.20", 29 | "datasets>=2.9.1", 30 | "evaluate>=0.3.0", 31 | "seqeval>=1.2.2", 32 | ], 33 | entry_points=""" 34 | # -*- Entry points: -*- 35 | """, 36 | ) 37 | -------------------------------------------------------------------------------- /zshot/utils/data_models/relation_span.py: -------------------------------------------------------------------------------- 1 | import zlib 2 | 3 | from spacy.tokens import Span 4 | 5 | from zshot.utils.data_models import Relation 6 | 7 | 8 | class RelationSpan: 9 | def __init__(self, start: Span, end: Span, relation: Relation, score: float = None, kb_id: str = None): 10 | """ Create a RelationSpan that relates two entities 11 | 12 | :param start: Entity acting as subject in the relation 13 | :param end: Entity acting as object in the relation 14 | :param relation: Relation 15 | :param score: Score of the relation classification 16 | :param kb_id: ID of the Relation in a KB 17 | """ 18 | self.start = start 19 | self.end = end 20 | self.relation = relation 21 | self.score = score 22 | self.kb_id = kb_id 23 | 24 | def __repr__(self) -> str: 25 | return f"{self.relation.name}, {self.start}, {self.end}, {self.score}" 26 | 27 | def __hash__(self): 28 | return zlib.crc32(self.__repr__().encode()) 29 | 30 | def __eq__(self, other): 31 | return (type(other) is type(self) 32 | and self.start == other.start 33 | and self.end == other.end 34 | and self.relation == other.relation) 35 | -------------------------------------------------------------------------------- /zshot/utils/download_models.py: -------------------------------------------------------------------------------- 1 | from zshot.linker import LinkerRegen, LinkerSMXM, LinkerTARS, LinkerGLINER 2 | from zshot.mentions_extractor import MentionsExtractorFlair 3 | from zshot.mentions_extractor.utils import ExtractorType 4 | from zshot.relation_extractor.relation_extractor_zsrc import RelationsExtractorZSRC 5 | 6 | 7 | def load_all(): 8 | try: 9 | LinkerSMXM().load_models() 10 | except RuntimeError: 11 | pass 12 | try: 13 | LinkerRegen().load_models() 14 | except RuntimeError: 15 | pass 16 | try: 17 | LinkerTARS().load_models() 18 | except RuntimeError: 19 | pass 20 | try: 21 | LinkerGLINER().load_models() 22 | except RuntimeError: 23 | pass 24 | # try: 25 | # LinkerRelik().load_models() 26 | # except RuntimeError: 27 | # pass 28 | try: 29 | RelationsExtractorZSRC().load_models() 30 | except RuntimeError: 31 | pass 32 | try: 33 | MentionsExtractorFlair(ExtractorType.NER).load_models() 34 | except RuntimeError: 35 | pass 36 | try: 37 | MentionsExtractorFlair(ExtractorType.POS).load_models() 38 | except RuntimeError: 39 | pass 40 | 41 | 42 | if __name__ == "__main__": 43 | load_all() 44 | -------------------------------------------------------------------------------- /zshot/linker/linker_ensemble/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from zshot.utils.data_models import Span 4 | 5 | 6 | def sub_span_scoring_per_description(union_spans, spans): 7 | for k in union_spans.keys(): 8 | for span in spans: 9 | labels = {} 10 | for p in span: 11 | if k[0] <= p.start and k[1] >= p.end: 12 | if k[0] < p.start or k[1] > p.end: 13 | if p.label not in labels: 14 | labels[p.label] = p 15 | elif labels[p.label].score < p.score: 16 | labels[p.label] = p 17 | for p in labels.values(): 18 | union_spans[k].append(Span(label=p.label, score=p.score, start=k[0], end=k[1])) 19 | 20 | 21 | 
def normalize_group(group, require_length): 22 | group.extend(random.choices(group, k=require_length - len(group))) 23 | 24 | 25 | def get_enhance_entities(entities): 26 | entities_groups = [[ent for ent in entities if ent.name == name] for name in set([ent.name for ent in entities])] 27 | max_length = max([len(group) for group in entities_groups]) 28 | for group in entities_groups: 29 | normalize_group(group, max_length) 30 | 31 | return [list(g) for g in zip(*entities_groups)] 32 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | # Allows to run this workflow manually from the Actions tab 16 | workflow_dispatch: 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | deploy: 23 | 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - uses: actions/checkout@v3 28 | - name: Set up Python 29 | uses: actions/setup-python@v3 30 | with: 31 | python-version: '3.x' 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install build 36 | - name: Build package 37 | run: python -m build 38 | - name: Publish package 39 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 40 | with: 41 | user: __token__ 42 | password: ${{ secrets.PYPI_API_TOKEN }} 43 | -------------------------------------------------------------------------------- /zshot/relation_extractor/zsrc/decide_entity_order.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from transformers import pipeline 3 | 4 | classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") 5 | 6 | 7 | def softmax(x): 8 | return np.exp(x) / sum(np.exp(x)) 9 | 10 | 11 | def score(premise, e1_text, e2_text, rel_name): 12 | sentence1 = "{} {} {}".format(e1_text, rel_name, e2_text) 13 | sentence2 = "{} {} {}".format(e2_text, rel_name, e1_text) 14 | output = classifier(premise, (sentence1, sentence2)) 15 | scores = output["scores"] 16 | if np.argmax(scores) == output["labels"].index(sentence1): 17 | return e1_text, e2_text 18 | else: 19 | return e2_text, e1_text 20 | 21 | 22 | def has_negation(premise, e1_text, e2_text, rel_name): 23 | if "is " in rel_name: 24 | negated_rel_name = rel_name.replace("is", "is not", 1) 25 | else: 26 | negated_rel_name = "does not " + rel_name 27 | negated = "{} {} {}".format(e1_text, negated_rel_name, e2_text) 28 | positive = "{} {} {}".format(e1_text, rel_name, e2_text) 29 | output = classifier(premise, (negated, positive)) 30 | return np.argmax(output["scores"]) == output["labels"].index(negated) 31 | 32 | 33 | def get_entity_order(e1_text, e2_text, rel_name, sentence): 34 | return score(sentence, e1_text, e2_text, rel_name) 35 | 36 | 37 | if __name__ == "__main__": 38 | print("models downloaded") 39 | 
-------------------------------------------------------------------------------- /zshot/evaluation/dataset/med_mentions/med_mentions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from typing import Union, Optional 4 | 5 | from datasets import load_dataset, Split, Dataset, DatasetDict 6 | from huggingface_hub import hf_hub_download 7 | 8 | from zshot.evaluation.dataset.dataset import DatasetWithEntities 9 | from zshot.utils.data_models import Entity 10 | 11 | REPO_ID = "ibm/MedMentions-ZS" 12 | ENTITIES_FN = "entities.json" 13 | 14 | 15 | def load_medmentions_zs(split: Optional[Union[str, Split]] = None, **kwargs) -> Union[DatasetDict, Dataset]: 16 | dataset = load_dataset(REPO_ID, split=split, **kwargs) 17 | entities_file = hf_hub_download(repo_id=REPO_ID, 18 | repo_type='dataset', 19 | filename=ENTITIES_FN) 20 | with open(entities_file, "r") as f: 21 | entities = json.load(f) 22 | 23 | if split: 24 | entities_split = [Entity(name=k, description=v) for k, v in entities[get_simple_split(split)].items()] 25 | dataset = DatasetWithEntities(dataset.data, entities=entities_split) 26 | else: 27 | for split in dataset: 28 | entities_split = [Entity(name=k, description=v) for k, v in entities[split].items()] 29 | dataset[split] = DatasetWithEntities(dataset[split].data, entities=entities_split) 30 | 31 | return dataset 32 | 33 | 34 | def get_simple_split(split: str) -> str: 35 | first_not_alph = re.search(r'\W+', split) 36 | first_not_alph_chr = first_not_alph.start() if first_not_alph else len(split) 37 | return split[: first_not_alph_chr] 38 | -------------------------------------------------------------------------------- /docs/mentions_extractor.md: -------------------------------------------------------------------------------- 1 | # MentionsExtractor 2 | The **mentions extractor** will detect the possible entities (a.k.a. mentions), which will then be linked to a data source (e.g.: Wikidata) by the **linker**. 3 | 4 | Currently, there are 7 different **mentions extractors** supported: 2 based on *SpaCy*, 2 based on *Flair*, plus TARS, SMXM and GLiNER. The two versions for *SpaCy* and *Flair* are similar: one is based on NERC and the other on linguistics (i.e.: using PoS and dependency parsing). The TARS and SMXM models can be used when the user wants to specify the mentions to be extracted. 5 | 6 | The NERC approach will use NERC models to detect all the entities that have to be linked. This approach depends on the model that is being used, and the entities the model has been trained on, so depending on the use case and the target entities it may not be the best approach, as the entities may not be recognized by the NERC model and thus won't be linked. 7 | 8 | The linguistic approach relies on the idea that mentions will usually be a syntagma or a noun. Therefore, this approach detects nouns that are included in a syntagma and that act like objects, subjects, etc. This approach does not depend on the model (although the performance does): a noun in a text should always be a noun, regardless of the dataset the model has been trained on. 9 | 10 | The SMXM model uses the description of the mentions to give the model information about them. 11 | 12 | The TARS model will use the labels of the mentions to detect them. 13 | 14 | The GLiNER model will use the labels of the mentions to detect them.
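A minimal configuration sketch; the `mentions_extractor` keyword is assumed to follow the same `PipelineConfig` pattern used for linkers, and the example sentence is illustrative:

```python
import spacy

from zshot import PipelineConfig
from zshot.mentions_extractor import MentionsExtractorFlair
from zshot.mentions_extractor.utils import ExtractorType

# NERC-based extraction; ExtractorType.POS selects the linguistics-based variant instead
mentions_extractor = MentionsExtractorFlair(ExtractorType.NER)

nlp = spacy.blank("en")
config = PipelineConfig(mentions_extractor=mentions_extractor)
nlp.add_pipe("zshot", config=config, last=True)

# The detected mentions are then consumed by whichever linker is configured in the same pipeline
doc = nlp("IBM headquarters are located in Armonk, New York.")
```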
15 | 16 | ::: zshot.MentionsExtractor -------------------------------------------------------------------------------- /zshot/tests/linker/test_gliner_linker.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig, Linker 8 | from zshot.linker import LinkerGLINER 9 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @pytest.fixture(scope="module", autouse=True) 15 | def teardown(): 16 | logger.warning("Starting gliner tests") 17 | yield True 18 | gc.collect() 19 | 20 | 21 | def test_gliner_download(): 22 | linker = LinkerGLINER() 23 | linker.load_models() 24 | assert isinstance(linker, Linker) 25 | del linker.model, linker 26 | 27 | 28 | def test_gliner_linker(): 29 | nlp = spacy.blank("en") 30 | gliner_config = PipelineConfig( 31 | linker=LinkerGLINER(), 32 | entities=EX_ENTITIES 33 | ) 34 | nlp.add_pipe("zshot", config=gliner_config, last=True) 35 | assert "zshot" in nlp.pipe_names 36 | 37 | doc = nlp(EX_DOCS[1]) 38 | assert len(doc.ents) > 0 39 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 40 | assert all(len(doc.ents) > 0 for doc in docs) 41 | del nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 42 | nlp.remove_pipe('zshot') 43 | del doc, nlp, gliner_config 44 | 45 | 46 | def test_gliner_linker_no_entities(): 47 | nlp = spacy.blank("en") 48 | gliner_config = PipelineConfig( 49 | linker=LinkerGLINER(), 50 | entities=[] 51 | ) 52 | nlp.add_pipe("zshot", config=gliner_config, last=True) 53 | assert "zshot" in nlp.pipe_names 54 | 55 | doc = nlp(EX_DOCS[1]) 56 | assert len(doc.ents) == 0 57 | del nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 58 | nlp.remove_pipe('zshot') 59 | del doc, nlp, gliner_config 60 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python tests 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: "3.10" 26 | cache: 'pip' # caching pip dependencies 27 | - name: Cache models 28 | uses: actions/cache@v3 29 | with: 30 | key: ${{ runner.os }}-build-models-cache 31 | path: | 32 | ~/.cache/huggingface 33 | ~/.cache/zshot 34 | ~/.pytest_cache 35 | - name: Install dependencies 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install -r requirements/devel.txt 39 | - name: Lint with flake8 40 | run: | 41 | flake8 --ignore E501,W503 zshot/ 42 | - name: Install Spacy pipeline and download models 43 | run: | 44 | python -m spacy download en_core_web_sm 45 | python -m zshot.utils.download_models 46 | - name: Test with pytest 47 | run: | 48 | python -m pytest --cov -v --cov-report xml:/home/runner/coverage.xml 49 | timeout-minutes: 30 50 | - name: Remove cache 51 | run: | 52 | rm -rf ~/.cache/huggingface 53 | rm -rf ~/.cache/zshot 54 | - name: Upload coverage 
to Codecov 55 | uses: codecov/codecov-action@v3.1.1 56 | with: 57 | files: /home/runner/coverage.xml 58 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from datasets import Dataset 4 | from datasets.table import Table 5 | 6 | from zshot.utils.data_models import Entity, Relation 7 | 8 | 9 | class DatasetWithRelations(Dataset): 10 | 11 | def __init__(self, arrow_table: Table, relations: List[Relation] = None, **kwargs): 12 | super().__init__(arrow_table=arrow_table, **kwargs) 13 | self.relations = relations 14 | 15 | def __repr__(self): 16 | return f"Dataset({{\n features: {list(self.features.keys())},\n num_rows: {self.num_rows}," \ 17 | f"\n relations: {[rel.name for rel in self.relations if self.relations is not None]}\n}})" 18 | 19 | 20 | class DatasetWithEntities(Dataset): 21 | 22 | def __init__(self, arrow_table: Table, entities: List[Entity] = None, **kwargs): 23 | super().__init__(arrow_table=arrow_table, **kwargs) 24 | self.entities = entities 25 | 26 | def __repr__(self): 27 | return f"Dataset({{\n features: {list(self.features.keys())},\n num_rows: {self.num_rows}," \ 28 | f"\n entities: {[ent.name for ent in self.entities if self.entities is not None]}\n}})" 29 | 30 | 31 | def create_dataset(gt: List[List[str]], sentences: List[str], entities) -> DatasetWithEntities: 32 | """ Create a simple dataset with entities from sentences and ground truth 33 | 34 | :param gt: Ground truth to use as labels. List of sentences in BIO format 35 | :param sentences: List of sentences 36 | :param entities: List of entities 37 | :return: Dataset with entities 38 | """ 39 | data_dict = { 40 | "tokens": [s.split(" ") for s in sentences], 41 | "ner_tags": gt, 42 | } 43 | dataset = DatasetWithEntities.from_dict(data_dict) 44 | dataset.entities = entities 45 | return dataset 46 | -------------------------------------------------------------------------------- /zshot/tests/linker/test_smxm_linker.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig, Linker 8 | from zshot.linker import LinkerSMXM 9 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @pytest.fixture(scope="module", autouse=True) 15 | def teardown(): 16 | logger.warning("Starting smxm tests") 17 | yield True 18 | gc.collect() 19 | 20 | 21 | def test_smxm_download(): 22 | linker = LinkerSMXM() 23 | linker.load_models() 24 | assert isinstance(linker, Linker) 25 | del linker.tokenizer, linker.model, linker 26 | 27 | 28 | def test_smxm_linker(): 29 | nlp = spacy.blank("en") 30 | smxm_config = PipelineConfig( 31 | linker=LinkerSMXM(), 32 | entities=EX_ENTITIES 33 | ) 34 | nlp.add_pipe("zshot", config=smxm_config, last=True) 35 | assert "zshot" in nlp.pipe_names 36 | 37 | doc = nlp(EX_DOCS[1]) 38 | assert len(doc.ents) > 0 39 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 40 | assert all(len(doc.ents) > 0 for doc in docs) 41 | del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 42 | nlp.remove_pipe('zshot') 43 | del doc, nlp, smxm_config 44 | 45 | 46 | def test_smxm_linker_no_entities(): 47 | nlp = spacy.blank("en") 48 | smxm_config = PipelineConfig( 49 | linker=LinkerSMXM(), 50 | entities=[] 51 | ) 52 
| nlp.add_pipe("zshot", config=smxm_config, last=True) 53 | assert "zshot" in nlp.pipe_names 54 | 55 | doc = nlp(EX_DOCS[1]) 56 | assert len(doc.ents) == 0 57 | del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 58 | nlp.remove_pipe('zshot') 59 | del doc, nlp, smxm_config 60 | -------------------------------------------------------------------------------- /zshot/tests/linker/test_relik_linker.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig, Linker 8 | from zshot.linker import LinkerRelik 9 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @pytest.fixture(scope="module", autouse=True) 15 | def teardown(): 16 | logger.warning("Starting relik tests") 17 | yield True 18 | gc.collect() 19 | 20 | 21 | @pytest.mark.skip(reason="Too expensive to run on every commit") 22 | def test_relik_download(): 23 | linker = LinkerRelik() 24 | linker.load_models() 25 | assert isinstance(linker, Linker) 26 | del linker.model, linker 27 | 28 | 29 | @pytest.mark.skip(reason="Too expensive to run on every commit") 30 | def test_relik_linker(): 31 | nlp = spacy.blank("en") 32 | relik_config = PipelineConfig( 33 | linker=LinkerRelik(), 34 | entities=EX_ENTITIES 35 | ) 36 | nlp.add_pipe("zshot", config=relik_config, last=True) 37 | assert "zshot" in nlp.pipe_names 38 | 39 | doc = nlp(EX_DOCS[1]) 40 | assert len(doc.ents) > 0 41 | del nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 42 | nlp.remove_pipe('zshot') 43 | del doc, nlp, relik_config 44 | 45 | 46 | @pytest.mark.skip(reason="Too expensive to run on every commit") 47 | def test_relik_linker_no_entities(): 48 | nlp = spacy.blank("en") 49 | relik_config = PipelineConfig( 50 | linker=LinkerRelik(), 51 | entities=[] 52 | ) 53 | nlp.add_pipe("zshot", config=relik_config, last=True) 54 | assert "zshot" in nlp.pipe_names 55 | 56 | doc = nlp(EX_DOCS[1]) 57 | assert len(doc.ents) == 0 58 | del nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 59 | nlp.remove_pipe('zshot') 60 | del doc, nlp, relik_config 61 | -------------------------------------------------------------------------------- /docs/entity_linking.md: -------------------------------------------------------------------------------- 1 | # Linker 2 | 3 | The **linker** will link the detected entities to an existing set of labels. Some of the **linkers**, however, are *end-to-end*, i.e. they don't need the **mentions extractor**, as they detect and link the entities at the same time. 4 | 5 | There are currently 6 **linkers** available: 4 of them are *end-to-end* and 2 are not.
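Before the comparison table below, here is a minimal usage sketch with the SMXM linker, which is *end-to-end* and therefore needs no **mentions extractor**. The entity names, descriptions and example sentence are illustrative placeholders, not values shipped with the library.

```python
import spacy

from zshot import PipelineConfig
from zshot.linker import LinkerSMXM
from zshot.utils.data_models import Entity

nlp = spacy.blank("en")
config = PipelineConfig(
    linker=LinkerSMXM(),
    # Illustrative entity definitions (placeholders, not part of the library)
    entities=[
        Entity(name="company", description="Name of a company, such as IBM or NASA"),
        Entity(name="location", description="Name of a geographical location, such as a city or a country"),
    ],
)
nlp.add_pipe("zshot", config=config, last=True)

doc = nlp("IBM headquarters are located in Armonk, New York.")
print([(ent.text, ent.label_) for ent in doc.ents])  # entities linked to the configured labels
```

Non *end-to-end* linkers (e.g. Blink or GENRE) additionally require a **mentions extractor** to be set in the `PipelineConfig`.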
6 | 7 | | Linker Name | end-to-end | Source Code | Paper | 8 | |:----------------------------------------------------:|:----------:|----------------------------------------------------------|--------------------------------------------------------------------| 9 | | [Blink](https://ibm.github.io/zshot/blink_linker/) | X | [Source Code](https://github.com/facebookresearch/BLINK) | [Paper](https://arxiv.org/pdf/1911.03814.pdf) | 10 | | [GENRE](https://ibm.github.io/zshot/genre_linker/) | X | [Source Code](https://github.com/facebookresearch/GENRE) | [Paper](https://arxiv.org/pdf/2010.00904.pdf) | 11 | | [SMXM](https://ibm.github.io/zshot/smxm_linker/) | ✓ | [Source Code](https://github.com/Raldir/Zero-shot-NERC) | [Paper](https://aclanthology.org/2021.acl-long.120/) | 12 | | [TARS](https://ibm.github.io/zshot/tars_linker/) | ✓ | [Source Code](https://github.com/flairNLP/flair) | [Paper](https://kishaloyhalder.github.io/pdfs/tars_coling2020.pdf) | 13 | | [GLINER](https://ibm.github.io/zshot/gliner_linker/) | ✓ | [Source Code](https://github.com/urchade/GLiNER) | [Paper](https://arxiv.org/abs/2311.08526) | 14 | | [RELIK](https://ibm.github.io/zshot/relik_linker/) | ✓ | [Source Code](https://github.com/SapienzaNLP/relik) | [Paper](https://arxiv.org/abs/2408.00103) | 15 | 16 | 17 | ::: zshot.Linker -------------------------------------------------------------------------------- /zshot/tests/linker/test_ensemble_linker.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | from zshot import PipelineConfig 4 | from zshot.linker.linker_ensemble import LinkerEnsemble 5 | from zshot.tests.linker.test_linker import DummyLinkerEnd2End 6 | from zshot.utils.data_models import Entity 7 | 8 | 9 | def test_ensemble_linker_max(): 10 | nlp = spacy.blank("en") 11 | nlp.add_pipe("zshot", config=PipelineConfig( 12 | entities=[ 13 | Entity(name="fruits", description="The sweet and fleshy product of a tree or other plant."), 14 | Entity(name="fruits", description="Names of fruits such as banana, oranges") 15 | ], 16 | linker=LinkerEnsemble( 17 | linkers=[ 18 | DummyLinkerEnd2End(), 19 | DummyLinkerEnd2End(), 20 | ] 21 | ) 22 | ), last=True) 23 | doc = nlp('Apple is a company name not a fruits like apples or orange') 24 | assert "zshot" in nlp.pipe_names 25 | assert len(doc.ents) > 0 26 | assert len(doc._.spans) > 0 27 | assert all([bool(ent.label_) for ent in doc.ents]) 28 | del doc, nlp 29 | 30 | 31 | def test_ensemble_linker_count(): 32 | nlp = spacy.blank("en") 33 | nlp.add_pipe("zshot", config=PipelineConfig( 34 | entities=[ 35 | Entity(name="fruits", description="The sweet and fleshy product of a tree or other plant."), 36 | Entity(name="fruits", description="Names of fruits such as banana, oranges") 37 | ], 38 | linker=LinkerEnsemble( 39 | linkers=[ 40 | DummyLinkerEnd2End(), 41 | DummyLinkerEnd2End(), 42 | ], 43 | strategy='count' 44 | ) 45 | ), last=True) 46 | 47 | doc = nlp('Apple is a company name not a fruits like apples or orange') 48 | assert "zshot" in nlp.pipe_names 49 | assert len(doc.ents) > 0 50 | assert len(doc._.spans) > 0 51 | assert all([bool(ent.label_) for ent in doc.ents]) 52 | del doc, nlp 53 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/mentions_extractor_gliner.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional, Union 2 | 3 | import pkgutil 4 | 5 | from spacy.tokens import Doc 6 | from 
gliner import GLiNER 7 | 8 | from zshot.mentions_extractor.mentions_extractor import MentionsExtractor 9 | from zshot.config import MODELS_CACHE_PATH 10 | from zshot.utils.data_models import Span 11 | 12 | 13 | MODEL_NAME = "urchade/gliner_mediumv2.1" 14 | 15 | 16 | class MentionsExtractorGLINER(MentionsExtractor): 17 | """ GLiNER Mentions Extractor """ 18 | 19 | def __init__(self, model_name=MODEL_NAME): 20 | super().__init__() 21 | 22 | if not pkgutil.find_loader("gliner"): 23 | raise Exception("GLINER module not installed. You need to install gliner in order to use the GLINER Linker." 24 | "Install it with: pip install gliner") 25 | 26 | self.model_name = model_name 27 | self.model = None 28 | 29 | def load_models(self): 30 | """ Load GLINER model """ 31 | if self.model is None: 32 | self.model = GLiNER.from_pretrained(self.model_name, cache_dir=MODELS_CACHE_PATH).to(self.device) 33 | self.model.eval() 34 | 35 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 36 | """ 37 | Perform the entity prediction 38 | :param docs: A list of spacy Document 39 | :param batch_size: The batch size 40 | :return: List Spans for each Document in docs 41 | """ 42 | if not self._mentions: 43 | return [] 44 | 45 | labels = [ent.name for ent in self._mentions] 46 | sentences = [doc.text for doc in docs] 47 | 48 | self.load_models() 49 | span_annotations = [] 50 | for sent in sentences: 51 | entities = self.model.predict_entities(sent, labels, threshold=0.5) 52 | span_annotations.append([Span.from_dict(ent) for ent in entities]) 53 | 54 | return span_annotations 55 | -------------------------------------------------------------------------------- /zshot/linker/linker_gliner.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional, Union 2 | 3 | import pkgutil 4 | 5 | from spacy.tokens import Doc 6 | from gliner import GLiNER 7 | 8 | from zshot.config import MODELS_CACHE_PATH 9 | from zshot.linker.linker import Linker 10 | from zshot.utils.data_models import Span 11 | 12 | 13 | MODEL_NAME = "urchade/gliner_mediumv2.1" 14 | 15 | 16 | class LinkerGLINER(Linker): 17 | """ GLINER linker """ 18 | 19 | def __init__(self, model_name=MODEL_NAME): 20 | super().__init__() 21 | 22 | if not pkgutil.find_loader("gliner"): 23 | raise Exception("GLINER module not installed. You need to install gliner in order to use the GLINER Linker." 
24 | "Install it with: pip install gliner") 25 | 26 | self.model_name = model_name 27 | self.model = None 28 | 29 | @property 30 | def is_end2end(self) -> bool: 31 | """ GLINER is end2end model""" 32 | return True 33 | 34 | def load_models(self): 35 | """ Load GLINER model """ 36 | if self.model is None: 37 | self.model = GLiNER.from_pretrained(self.model_name, cache_dir=MODELS_CACHE_PATH).to(self.device) 38 | self.model.eval() 39 | 40 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 41 | """ 42 | Perform the entity prediction 43 | :param docs: A list of spacy Document 44 | :param batch_size: The batch size 45 | :return: List Spans for each Document in docs 46 | """ 47 | if not self._entities: 48 | return [] 49 | 50 | labels = [ent.name for ent in self._entities] 51 | sentences = [doc.text for doc in docs] 52 | 53 | self.load_models() 54 | span_annotations = [] 55 | for sent in sentences: 56 | entities = self.model.predict_entities(sent, labels, threshold=0.5) 57 | span_annotations.append([Span.from_dict(ent) for ent in entities]) 58 | 59 | return span_annotations 60 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/fewrel/fewrel.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Dict 2 | 3 | from datasets import load_dataset, Split, Dataset 4 | from tqdm import tqdm 5 | 6 | from zshot.evaluation.dataset.dataset import DatasetWithRelations 7 | from zshot.utils.data_models import Relation 8 | 9 | 10 | def get_entity_data(e, tokenized_sentence): 11 | d = {"start": None, "end": None, "label": e["type"]} 12 | token_indices = e["indices"][0] 13 | s = "" 14 | curr_idx = 0 15 | for idx, token in enumerate(tokenized_sentence): 16 | if idx == token_indices[0]: 17 | d["start"] = curr_idx 18 | s += token + " " 19 | curr_idx = len(s.strip()) 20 | if idx == token_indices[-1]: 21 | d["end"] = curr_idx 22 | d["sentence"] = s.strip() 23 | return d 24 | 25 | 26 | def load_few_rel_zs(split: Optional[Union[str, Split]] = "val_wiki") -> Union[Dict[DatasetWithRelations, 27 | Dataset], Dataset]: 28 | dataset = load_dataset("few_rel", split=split, trust_remote_code=True) 29 | relations_descriptions = dataset["names"] 30 | tokenized_sentences = dataset["tokens"] 31 | sentences = [" ".join(tokens) for tokens in tokenized_sentences] 32 | gt = [item[0] for item in relations_descriptions] 33 | heads = dataset["head"] 34 | tails = dataset["tail"] 35 | entities_data = [] 36 | for idx in tqdm(range(len(tokenized_sentences))): 37 | e1 = heads[idx] 38 | e2 = tails[idx] 39 | entities_data.append( 40 | [ 41 | get_entity_data(e1, tokenized_sentences[idx]), 42 | get_entity_data(e2, tokenized_sentences[idx]), 43 | ] 44 | ) 45 | relations = [Relation(name=name, description=desc) for name, desc in 46 | set([(i, j) for i, j in relations_descriptions])] 47 | dataset = Dataset.from_dict({ 48 | "sentences": sentences, 49 | "sentence_entities": entities_data, 50 | "labels": gt, 51 | }) 52 | dataset.relations = relations 53 | return dataset 54 | -------------------------------------------------------------------------------- /zshot/tests/relations_extractor/test_relations_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional, Union, List 2 | 3 | import spacy 4 | from spacy.tokens.doc import Doc 5 | 6 | from zshot import PipelineConfig, RelationsExtractor 7 | from 
zshot.tests.config import EX_DOCS, EX_ENTITIES, EX_RELATIONS 8 | from zshot.tests.linker.test_linker import DummyLinkerEnd2End 9 | from zshot.utils.data_models import Relation 10 | from zshot.utils.data_models.relation_span import RelationSpan 11 | 12 | 13 | class DummyRelationsExtractor(RelationsExtractor): 14 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[RelationSpan]]: 15 | relations_pred = [] 16 | for doc in docs: 17 | relations = [] 18 | for span in doc._.spans: 19 | relations.append(RelationSpan(start=span, end=span, 20 | relation=Relation(name="rel", description="desc"), score=1)) 21 | relations_pred.append(relations) 22 | return relations_pred 23 | 24 | 25 | def test_dummy_relations_extractor_with_entities_config(): 26 | nlp = spacy.blank("en") 27 | config_zshot = PipelineConfig( 28 | linker=DummyLinkerEnd2End(), 29 | relations_extractor=DummyRelationsExtractor(), 30 | entities=EX_ENTITIES, 31 | relations=EX_RELATIONS, 32 | ) 33 | nlp.add_pipe("zshot", config=config_zshot, last=True) 34 | assert "zshot" in nlp.pipe_names 35 | doc = nlp(EX_DOCS[0]) 36 | assert len(doc.ents) > 0 37 | assert len(doc._.relations) > 0 38 | 39 | 40 | def test_dummy_relations_extractor_device(): 41 | nlp = spacy.blank("en") 42 | config_zshot = PipelineConfig( 43 | linker=DummyLinkerEnd2End(), 44 | relations_extractor=DummyRelationsExtractor(), 45 | entities=EX_ENTITIES, 46 | relations=EX_RELATIONS, 47 | device="cpu", 48 | ) 49 | nlp.add_pipe("zshot", config=config_zshot, last=True) 50 | assert "zshot" in nlp.pipe_names 51 | doc = nlp(EX_DOCS[0]) 52 | assert len(doc.ents) > 0 53 | assert len(doc._.relations) > 0 54 | -------------------------------------------------------------------------------- /zshot/tests/mentions_extractor/test_smxm_mentions_extractor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig, MentionsExtractor 8 | from zshot.mentions_extractor import MentionsExtractorSMXM 9 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @pytest.fixture(scope="module", autouse=True) 15 | def teardown(): 16 | logger.warning("Starting smxm tests") 17 | yield True 18 | gc.collect() 19 | 20 | 21 | def test_smxm_download(): 22 | mentions_extractor = MentionsExtractorSMXM() 23 | mentions_extractor.load_models() 24 | assert isinstance(mentions_extractor, MentionsExtractor) 25 | del mentions_extractor 26 | 27 | 28 | def test_smxm_mentions_extractor(): 29 | nlp = spacy.blank("en") 30 | smxm_config = PipelineConfig( 31 | mentions_extractor=MentionsExtractorSMXM(), 32 | mentions=EX_ENTITIES 33 | ) 34 | nlp.add_pipe("zshot", config=smxm_config, last=True) 35 | assert "zshot" in nlp.pipe_names 36 | 37 | doc = nlp(EX_DOCS[1]) 38 | assert len(doc._.mentions) > 0 39 | nlp.remove_pipe('zshot') 40 | del doc, nlp 41 | 42 | 43 | def test_smxm_mentions_extractor_pipeline(): 44 | nlp = spacy.blank("en") 45 | smxm_config = PipelineConfig( 46 | mentions_extractor=MentionsExtractorSMXM(), 47 | mentions=EX_ENTITIES 48 | ) 49 | nlp.add_pipe("zshot", config=smxm_config, last=True) 50 | assert "zshot" in nlp.pipe_names 51 | 52 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 53 | assert all(len(doc._.mentions) > 0 for doc in docs) 54 | nlp.remove_pipe('zshot') 55 | del docs, nlp 56 | 57 | 58 | def test_smxm_mentions_extractor_no_entities(): 59 | nlp = spacy.blank("en") 60 | 
smxm_config = PipelineConfig( 61 | mentions_extractor=MentionsExtractorSMXM(), 62 | mentions=[] 63 | ) 64 | nlp.add_pipe("zshot", config=smxm_config, last=True) 65 | assert "zshot" in nlp.pipe_names 66 | 67 | doc = nlp(EX_DOCS[1]) 68 | assert len(doc._.mentions) == 0 69 | nlp.remove_pipe('zshot') 70 | del doc, nlp 71 | -------------------------------------------------------------------------------- /zshot/tests/mentions_extractor/test_gliner_mentions_extractor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig, MentionsExtractor 8 | from zshot.mentions_extractor import MentionsExtractorGLINER 9 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @pytest.fixture(scope="module", autouse=True) 15 | def teardown(): 16 | logger.warning("Starting gliner tests") 17 | yield True 18 | gc.collect() 19 | 20 | 21 | def test_gliner_download(): 22 | mentions_extractor = MentionsExtractorGLINER() 23 | mentions_extractor.load_models() 24 | assert isinstance(mentions_extractor, MentionsExtractor) 25 | del mentions_extractor 26 | 27 | 28 | def test_gliner_mentions_extractor(): 29 | nlp = spacy.blank("en") 30 | gliner_config = PipelineConfig( 31 | mentions_extractor=MentionsExtractorGLINER(), 32 | mentions=EX_ENTITIES 33 | ) 34 | nlp.add_pipe("zshot", config=gliner_config, last=True) 35 | assert "zshot" in nlp.pipe_names 36 | 37 | doc = nlp(EX_DOCS[1]) 38 | assert len(doc._.mentions) > 0 39 | nlp.remove_pipe('zshot') 40 | del doc, nlp 41 | 42 | 43 | def test_gliner_mentions_extractor_pipeline(): 44 | nlp = spacy.blank("en") 45 | gliner_config = PipelineConfig( 46 | mentions_extractor=MentionsExtractorGLINER(), 47 | mentions=EX_ENTITIES 48 | ) 49 | nlp.add_pipe("zshot", config=gliner_config, last=True) 50 | assert "zshot" in nlp.pipe_names 51 | 52 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 53 | assert all(len(doc._.mentions) > 0 for doc in docs) 54 | nlp.remove_pipe('zshot') 55 | del docs, nlp 56 | 57 | 58 | def test_gliner_mentions_extractor_no_entities(): 59 | nlp = spacy.blank("en") 60 | gliner_config = PipelineConfig( 61 | mentions_extractor=MentionsExtractorGLINER(), 62 | mentions=[] 63 | ) 64 | nlp.add_pipe("zshot", config=gliner_config, last=True) 65 | assert "zshot" in nlp.pipe_names 66 | 67 | doc = nlp(EX_DOCS[1]) 68 | assert len(doc._.mentions) == 0 69 | nlp.remove_pipe('zshot') 70 | del doc, nlp 71 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/mentions_extractor_smxm.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional, Union 2 | 3 | from spacy.tokens import Doc 4 | from transformers import BertTokenizerFast 5 | 6 | from zshot.config import MODELS_CACHE_PATH 7 | from zshot.mentions_extractor.mentions_extractor import MentionsExtractor 8 | from zshot.utils.data_models import Span 9 | from zshot.utils.models.smxm.model import BertTaggerMultiClass 10 | from zshot.utils.models.smxm.utils import ( 11 | get_entities_names_descriptions, 12 | smxm_predict, 13 | ) 14 | 15 | ONTONOTES_MODEL_NAME = "ibm/smxm" 16 | 17 | 18 | class MentionsExtractorSMXM(MentionsExtractor): 19 | """ SMXM Mentions Extractor """ 20 | 21 | def __init__(self, model_name=ONTONOTES_MODEL_NAME): 22 | super().__init__() 23 | 24 | self.tokenizer = BertTokenizerFast.from_pretrained( 25 | 
"bert-large-cased", truncation_side="left", cache_dir=MODELS_CACHE_PATH 26 | ) 27 | 28 | self.model_name = model_name 29 | self.model = None 30 | 31 | def load_models(self): 32 | """ Load SMXM model """ 33 | if self.model is None: 34 | self.model = BertTaggerMultiClass.from_pretrained( 35 | self.model_name, output_hidden_states=True, cache_dir=MODELS_CACHE_PATH 36 | ).to(self.device) 37 | self.model.eval() 38 | 39 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 40 | """ 41 | Perform the entity prediction 42 | :param docs: A list of spacy Document 43 | :param batch_size: The batch size 44 | :return: List Spans for each Document in docs 45 | """ 46 | if not self._mentions: 47 | return [] 48 | 49 | entity_labels, entity_descriptions = get_entities_names_descriptions(self._mentions) 50 | sentences = [doc.text for doc in docs] 51 | 52 | self.load_models() 53 | self.model.eval() 54 | 55 | span_annotations = smxm_predict(self.model, self.tokenizer, 56 | sentences, entity_labels, entity_descriptions, 57 | batch_size) 58 | 59 | return span_annotations 60 | -------------------------------------------------------------------------------- /zshot/utils/mappings.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, List 3 | 4 | from huggingface_hub import hf_hub_download 5 | 6 | from zshot.config import MODELS_CACHE_PATH 7 | from zshot.utils.data_models import Span 8 | 9 | 10 | REPO_ID = "ibm/regen-disambiguation" 11 | WIKIPEDIA_MAP = "wikipedia_map_id.json" 12 | DBPEDIA_MAP = "dbpedia_map_id.json" 13 | 14 | 15 | def load_wikipedia_mapping() -> Dict[str, str]: # pragma: no cover 16 | """ 17 | Load the wikipedia trie from the HB hub 18 | :return: The Wikipedia trie 19 | """ 20 | wikipedia_map = hf_hub_download(repo_id=REPO_ID, 21 | repo_type='model', 22 | filename=WIKIPEDIA_MAP, 23 | cache_dir=MODELS_CACHE_PATH) 24 | with open(wikipedia_map, "r") as f: 25 | wikipedia_map = json.load(f) 26 | return wikipedia_map 27 | 28 | 29 | def spans_to_wikipedia(spans: List[Span]) -> List[str]: # pragma: no cover 30 | """ 31 | Generate wikipedia link for spans 32 | :return: The list of generated links 33 | """ 34 | links = [] 35 | wikipedia_map = load_wikipedia_mapping() 36 | for s in spans: 37 | if s.label in wikipedia_map: 38 | links.append(f"https://en.wikipedia.org/wiki?curid={wikipedia_map[s.label]}") 39 | else: 40 | links.append(None) 41 | return links 42 | 43 | 44 | def load_dbpedia_mapping() -> Dict[str, str]: # pragma: no cover 45 | """ 46 | Load the dbpedia trie from the HB hub 47 | :return: The DBpedia trie 48 | """ 49 | dbpedia_map = hf_hub_download(repo_id=REPO_ID, 50 | repo_type='model', 51 | filename=DBPEDIA_MAP, 52 | cache_dir=MODELS_CACHE_PATH) 53 | with open(dbpedia_map, "r") as f: 54 | dbpedia_map = json.load(f) 55 | return dbpedia_map 56 | 57 | 58 | def spans_to_dbpedia(spans: List[Span]) -> List[str]: # pragma: no cover 59 | """ 60 | Generate dbpedia link for spans 61 | :return: The list of generated links 62 | """ 63 | dbpedia_map = load_dbpedia_mapping() 64 | links = [dbpedia_map[s.label] for s in spans if s.label in dbpedia_map] 65 | return links 66 | -------------------------------------------------------------------------------- /zshot/linker/linker_smxm.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional, Union 2 | 3 | from spacy.tokens import Doc 4 | from transformers import 
BertTokenizerFast 5 | 6 | from zshot.config import MODELS_CACHE_PATH 7 | from zshot.linker.linker import Linker 8 | from zshot.utils.data_models import Span 9 | from zshot.utils.models.smxm.model import BertTaggerMultiClass 10 | from zshot.utils.models.smxm.utils import ( 11 | get_entities_names_descriptions, 12 | smxm_predict 13 | ) 14 | 15 | ONTONOTES_MODEL_NAME = "ibm/smxm" 16 | 17 | 18 | class LinkerSMXM(Linker): 19 | """ SMXM linker """ 20 | 21 | def __init__(self, model_name=ONTONOTES_MODEL_NAME): 22 | super().__init__() 23 | 24 | self.tokenizer = BertTokenizerFast.from_pretrained( 25 | "bert-large-cased", truncation_side="left", cache_dir=MODELS_CACHE_PATH 26 | ) 27 | 28 | self.model_name = model_name 29 | self.model = None 30 | 31 | @property 32 | def is_end2end(self) -> bool: 33 | """ SMXM is end2end model""" 34 | return True 35 | 36 | def load_models(self): 37 | """ Load SMXM model """ 38 | if self.model is None: 39 | self.model = BertTaggerMultiClass.from_pretrained( 40 | self.model_name, output_hidden_states=True, cache_dir=MODELS_CACHE_PATH 41 | ).to(self.device) 42 | self.model.eval() 43 | 44 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 45 | """ 46 | Perform the entity prediction 47 | :param docs: A list of spacy Document 48 | :param batch_size: The batch size 49 | :return: List Spans for each Document in docs 50 | """ 51 | if not self._entities: 52 | return [] 53 | 54 | entity_labels, entity_descriptions = get_entities_names_descriptions(self._entities) 55 | sentences = [doc.text for doc in docs] 56 | 57 | self.load_models() 58 | self.model.eval() 59 | 60 | span_annotations = smxm_predict(self.model, self.tokenizer, 61 | sentences, entity_labels, entity_descriptions, 62 | batch_size) 63 | 64 | return span_annotations 65 | -------------------------------------------------------------------------------- /zshot/evaluation/metrics/rel_eval.py: -------------------------------------------------------------------------------- 1 | import evaluate 2 | from sklearn.metrics import precision_recall_fscore_support 3 | from sklearn.metrics import accuracy_score 4 | import datasets 5 | 6 | _KWARGS_DESCRIPTION = """ 7 | Produces labelling scores along with its sufficient statistics 8 | from a source against one or more references. 
9 | Args: 10 | predictions: List of List of predicted labels (Estimated targets as returned by a tagger) 11 | references: List of List of reference labels (Ground truth (correct) target values) 12 | """ 13 | 14 | 15 | class RelEval(evaluate.Metric): 16 | def _info(self): 17 | return evaluate.MetricInfo( 18 | description="RelEval is a framework for relation extraction methods evaluation.", 19 | inputs_description=_KWARGS_DESCRIPTION, 20 | citation="alp@ibm.com", 21 | features=datasets.Features( 22 | { 23 | "predictions": datasets.Value("string", id="label"), 24 | "references": datasets.Value("string", id="label"), 25 | } 26 | ), 27 | ) 28 | 29 | def _compute( 30 | self, 31 | predictions, 32 | references, 33 | ): 34 | scores = {} 35 | p, r, f1, _ = precision_recall_fscore_support( 36 | references, predictions, average="micro" 37 | ) 38 | scores["overall_precision_micro"] = p 39 | scores["overall_recall_micro"] = r 40 | scores["overall_f1_micro"] = f1 41 | 42 | p, r, f1, _ = precision_recall_fscore_support( 43 | references, predictions, average="macro" 44 | ) 45 | scores["overall_precision_macro"] = p 46 | scores["overall_recall_macro"] = r 47 | scores["overall_f1_macro"] = f1 48 | 49 | acc = accuracy_score(references, predictions, normalize=False) 50 | scores["overall_accuracy"] = acc 51 | 52 | lab = sorted(list(set(references))) 53 | p, r, f1, supp = precision_recall_fscore_support( 54 | references, predictions, average=None, labels=lab 55 | ) 56 | for idx, lab in enumerate(lab): 57 | scores[lab] = {'precision': p[idx], 'recall': r[idx], 58 | 'f1': f1[idx], 'number': supp[idx]} 59 | return scores 60 | -------------------------------------------------------------------------------- /zshot/linker/linker_regen/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from huggingface_hub import hf_hub_download 4 | 5 | from zshot.config import MODELS_CACHE_PATH 6 | from zshot.linker.linker_regen.trie import Trie 7 | 8 | REPO_ID = "ibm/regen-disambiguation" 9 | WIKIPEDIA_TRIE_FILE_NAME = "wikipedia_trie.pkl" 10 | DBPEDIA_TRIE_FILE_NAME = "dbpedia_trie.pkl" 11 | 12 | 13 | def create_input(sentence, max_length, start_delimiter, end_delimiter): 14 | sent_list = sentence.split(" ") 15 | if len(sent_list) < max_length: 16 | return sentence 17 | else: 18 | end_delimiter_index = sent_list.index(end_delimiter) 19 | start_delimiter_index = sent_list.index(start_delimiter) 20 | half_context = (max_length - (end_delimiter_index - start_delimiter_index)) // 2 21 | left_index = max(0, start_delimiter_index - half_context) 22 | right_index = min(len(sent_list), 23 | end_delimiter_index + half_context + (half_context - (start_delimiter_index - left_index))) 24 | if right_index == end_delimiter_index: 25 | right_index += 1 26 | 27 | left_index = left_index - max(0, (half_context - (right_index - end_delimiter_index))) 28 | return " ".join(sent_list[left_index:right_index]) 29 | 30 | 31 | def load_wikipedia_trie() -> Trie: # pragma: no cover 32 | """ 33 | Load the wikipedia trie from the HB hub 34 | :return: The Wikipedia trie 35 | """ 36 | wikipedia_trie_file = hf_hub_download(repo_id=REPO_ID, 37 | repo_type='model', 38 | filename=WIKIPEDIA_TRIE_FILE_NAME, 39 | cache_dir=MODELS_CACHE_PATH) 40 | with open(wikipedia_trie_file, "rb") as f: 41 | wikipedia_trie = pickle.load(f) 42 | return wikipedia_trie 43 | 44 | 45 | def load_dbpedia_trie() -> Trie: # pragma: no cover 46 | """ 47 | Load the dbpedia trie from the HB hub 48 | :return: The DBpedia trie 49 | 
""" 50 | dbpedia_trie_file = hf_hub_download(repo_id=REPO_ID, 51 | repo_type='model', 52 | filename=DBPEDIA_TRIE_FILE_NAME, 53 | cache_dir=MODELS_CACHE_PATH) 54 | with open(dbpedia_trie_file, "rb") as f: 55 | dbpedia_trie = pickle.load(f) 56 | return dbpedia_trie 57 | -------------------------------------------------------------------------------- /zshot/evaluation/pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LinkerPipeline: 5 | def __init__(self, nlp, batch_size=100): 6 | self.nlp = nlp 7 | self.task = "token-classification" 8 | self.batch_size = batch_size 9 | 10 | def __call__(self, *args, **kwargs): 11 | res = [] 12 | docs = self.nlp.pipe(args[0], batch_size=self.batch_size) 13 | for doc in docs: 14 | res_doc = [] 15 | for span in doc._.spans: 16 | label = { 17 | "entity": span.label, 18 | "score": span.score, 19 | "word": doc.text[span.start: span.end], 20 | "start": span.start, 21 | "end": span.end, 22 | } 23 | res_doc.append(label) 24 | res.append(res_doc) 25 | 26 | return res 27 | 28 | 29 | class MentionsExtractorPipeline: 30 | def __init__(self, nlp, batch_size=100): 31 | self.nlp = nlp 32 | self.task = 'token-classification' 33 | self.batch_size = batch_size 34 | 35 | def __call__(self, *args, **kwargs): 36 | res = [] 37 | docs = self.nlp.pipe(args[0], batch_size=self.batch_size) 38 | for doc in docs: 39 | res_doc = [] 40 | for span in doc._.mentions: 41 | label = { 42 | 'entity': "MENTION", 43 | 'word': doc.text[span.start:span.end], 44 | 'start': span.start, 'end': span.end 45 | } 46 | res_doc.append(label) 47 | res.append(res_doc) 48 | 49 | return res 50 | 51 | 52 | class RelationExtractorPipeline: 53 | def __init__(self, nlp, batch_size=100): 54 | self.nlp = nlp 55 | self.task = "relation-extraction" 56 | self.batch_size = batch_size 57 | 58 | def __call__(self, *args, **kwargs): 59 | res = [] 60 | # pdb.set_trace() 61 | docs = self.nlp.pipe(args[0], batch_size=self.batch_size) 62 | for doc in docs: 63 | probs = [] 64 | rels = [] 65 | # pdb.set_trace() 66 | for r in doc._.relations: 67 | probs.append(r.score) 68 | rels.append(r.relation) 69 | best_idx = np.argmax(probs) 70 | rel = rels[best_idx] 71 | res.append(rel.name) 72 | return res 73 | -------------------------------------------------------------------------------- /zshot/tests/mentions_extractor/test_mention_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import spacy 4 | from spacy.tokens.doc import Doc 5 | 6 | from zshot import MentionsExtractor, PipelineConfig 7 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 8 | from zshot.utils.data_models import Span 9 | 10 | 11 | class DummyMentionsExtractor(MentionsExtractor): 12 | def predict(self, docs: Iterator[Doc], batch_size=None): 13 | return [ 14 | [Span(0, len(doc.text) - 1)] 15 | for doc in docs 16 | ] 17 | 18 | 19 | class DummyMentionsExtractorWithNER(MentionsExtractor): 20 | @property 21 | def require_existing_ner(self) -> bool: 22 | return True 23 | 24 | def predict(self, docs: Iterator[Doc], batch_size=None): 25 | return [ 26 | [Span(0, len(doc.text) - 1)] 27 | for doc in docs 28 | ] 29 | 30 | 31 | class DummyMentionsExtractorWithEntities(MentionsExtractor): 32 | def predict(self, docs: Iterator[Doc], batch_size=None): 33 | return [ 34 | [Span(0, len(doc.text) - 1)] 35 | for idx, doc in enumerate(docs) 36 | ] 37 | 38 | 39 | def test_dummy_mentions_extractor(): 40 | nlp = 
spacy.blank("en") 41 | config_zshot = PipelineConfig(mentions_extractor=DummyMentionsExtractorWithEntities()) 42 | nlp.add_pipe("zshot", config=config_zshot, last=True) 43 | assert "zshot" in nlp.pipe_names 44 | doc = nlp(EX_DOCS[1]) 45 | assert doc.ents == () 46 | assert len(doc._.mentions) > 0 47 | del doc, nlp 48 | 49 | 50 | def test_dummy_mentions_extractor_device(): 51 | nlp = spacy.blank("en") 52 | config_zshot = PipelineConfig(mentions_extractor=DummyMentionsExtractorWithEntities(), device="cpu") 53 | nlp.add_pipe("zshot", config=config_zshot, last=True) 54 | assert "zshot" in nlp.pipe_names 55 | doc = nlp(EX_DOCS[1]) 56 | assert doc.ents == () 57 | assert len(doc._.mentions) > 0 58 | del doc, nlp 59 | 60 | 61 | def test_dummy_mentions_extractor_with_entities_config(): 62 | nlp = spacy.blank("en") 63 | config_zshot = PipelineConfig(mentions_extractor=DummyMentionsExtractorWithEntities(), 64 | mentions=EX_ENTITIES) 65 | nlp.add_pipe("zshot", config=config_zshot, last=True) 66 | assert "zshot" in nlp.pipe_names 67 | doc = nlp(EX_DOCS[1]) 68 | assert doc.ents == () 69 | assert len(doc._.mentions) > 0 70 | del doc, nlp 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | .DS_Store 132 | codecov 133 | -------------------------------------------------------------------------------- /zshot/tests/linker/test_blinker_linker.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import pkgutil 4 | 5 | import pytest 6 | import spacy 7 | 8 | from zshot import PipelineConfig, Linker 9 | from zshot.linker import LinkerBlink 10 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | @pytest.fixture(scope="module", autouse=True) 16 | def teardown(): 17 | logger.warning("Starting blink tests") 18 | yield True 19 | gc.collect() 20 | 21 | 22 | @pytest.mark.skipif(not pkgutil.find_loader("blink"), reason="BLINK is not installed") 23 | def test_blink(): 24 | linker = LinkerBlink() 25 | with pytest.raises(Exception): 26 | assert len(linker.entities_list) > 1 27 | with pytest.raises(Exception): 28 | assert len(linker.local_id2wikipedia_id) > 1 29 | with pytest.raises(Exception): 30 | assert linker.local_name2wikipedia_url('IBM').startswith("https://en.wikipedia.org/wiki") 31 | 32 | 33 | @pytest.mark.skip(reason="Too expensive to run on every commit") 34 | def test_blink_download(): 35 | linker = LinkerBlink() 36 | linker.load_models() 37 | assert isinstance(linker, Linker) 38 | del linker.tokenizer, linker.model, linker 39 | 40 | 41 | @pytest.mark.skip(reason="Too expensive to run on every commit") 42 | def test_blink_linker(): 43 | nlp = spacy.blank("en") 44 | blink_config = PipelineConfig( 45 | linker=LinkerBlink(), 46 | entities=EX_ENTITIES 47 | ) 48 | nlp.add_pipe("zshot", config=blink_config, last=True) 49 | assert "zshot" in nlp.pipe_names 50 | 51 | doc = nlp(EX_DOCS[1]) 52 | assert len(doc.ents) > 0 53 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 54 | assert all(len(doc.ents) > 0 for doc in docs) 55 | del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 56 | nlp.remove_pipe('zshot') 57 | del doc, nlp, blink_config 58 | 59 | 60 | @pytest.mark.skip(reason="Too expensive to run on every commit") 61 | def test_blink_linker_no_entities(): 62 | nlp = spacy.blank("en") 63 | blink_config = PipelineConfig( 64 | linker=LinkerBlink(), 65 | entities=[] 66 | ) 67 | nlp.add_pipe("zshot", config=blink_config, last=True) 68 | assert "zshot" in nlp.pipe_names 69 | 70 | doc = nlp(EX_DOCS[1]) 71 | assert len(doc.ents) == 0 72 | del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 73 | nlp.remove_pipe('zshot') 74 | del doc, nlp, blink_config 75 | -------------------------------------------------------------------------------- /zshot/utils/data_models/span.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | 4 | import zlib 5 | 6 | import spacy 7 | from spacy.tokens import Doc 8 | 9 | 10 | 
class Span: 11 | def __init__(self, start: int, end: int, label: str = None, score: float = None, kb_id: str = None): 12 | """ Class for handling Spans with scores 13 | 14 | :param start: Start char idx of the span 15 | :param end: End char idx of the span 16 | :param label: Label of the span (category it belongs to, e.g.: PER) 17 | :param score: Score of the prediction 18 | :param kb_id: ID to Knowledge base (e.g.: wikipedia) 19 | """ 20 | self.start = start 21 | self.end = end 22 | self.label = label 23 | self.score = score 24 | self.kb_id = kb_id 25 | 26 | def __repr__(self) -> str: 27 | return f"{self.label}, {self.start}, {self.end}, {self.score}" 28 | 29 | def __hash__(self): 30 | return zlib.crc32(self.__repr__().encode()) 31 | 32 | def __eq__(self, other: Any): 33 | return (type(other) is type(self) 34 | and self.start == other.start 35 | and self.end == other.end 36 | and self.label == other.label 37 | and self.score == other.score) 38 | 39 | def to_spacy_span(self, doc: Doc) -> spacy.tokens.Span: 40 | kwargs = { 41 | 'alignment_mode': 'expand' 42 | } 43 | if self.kb_id: 44 | kwargs.update({'kb_id': self.kb_id}) 45 | if self.label: 46 | kwargs.update({'label': self.label}) 47 | 48 | return doc.char_span(self.start, self.end, **kwargs) 49 | 50 | @staticmethod 51 | def from_spacy_span(spacy_span: spacy.tokens.Span, score=None) -> "Span": 52 | return Span(spacy_span.start_char, spacy_span.end_char, spacy_span.label_, score=score, 53 | kb_id=str(spacy_span.kb_id)) 54 | 55 | @staticmethod 56 | def from_dict(d: Dict[str, Any]) -> "Span": 57 | start = d.get('start', None) if 'start' in d else d.get('start_char', None) 58 | end = d.get('end', None) if 'end' in d else d.get('end_char', None) 59 | label = d.get('label', None) 60 | score = d.get('score', None) 61 | kb_id = d.get('kb_id', '') 62 | if start is None: 63 | raise ValueError('One of [start, start_char] must be defined in dict.') 64 | if end is None: 65 | raise ValueError('One of [end, end_char] must be defined in dict.') 66 | if not label: 67 | raise ValueError('Label must be defined in dict.') 68 | 69 | return Span(start, end, label=label, score=score, 70 | kb_id=str(kb_id)) 71 | -------------------------------------------------------------------------------- /zshot/tests/mentions_extractor/test_spacy_mentions_extractor.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | from zshot import PipelineConfig 4 | from zshot.mentions_extractor import MentionsExtractorSpacy 5 | from zshot.mentions_extractor.mentions_extractor_spacy import ExtractorType 6 | from zshot.tests.config import EX_DOCS 7 | 8 | 9 | def test_spacy_ner_mentions_extractor(): 10 | nlp = spacy.load("en_core_web_sm") 11 | 12 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorSpacy(ExtractorType.NER)) 13 | nlp.add_pipe("zshot", config=config_zshot, last=True) 14 | assert "zshot" in nlp.pipe_names and "ner" in nlp.pipe_names 15 | 16 | doc = nlp(EX_DOCS[1]) 17 | assert doc.ents == () 18 | assert len(doc._.mentions) > 0 19 | del doc, nlp 20 | 21 | 22 | def test_custom_spacy_mentions_extractor(): 23 | nlp = spacy.load("en_core_web_sm") 24 | 25 | custom_component = MentionsExtractorSpacy(ExtractorType.NER) 26 | config_zshot = PipelineConfig(mentions_extractor=custom_component, disable_default_ner=False) 27 | nlp.add_pipe("zshot", config=config_zshot, last=True) 28 | assert "zshot" in nlp.pipe_names and "ner" in nlp.pipe_names 29 | 30 | doc = nlp(EX_DOCS[1]) 31 | assert doc.ents == () 32 | assert 
len(doc._.mentions) > 0 33 | del doc, nlp 34 | 35 | 36 | def test_spacy_pos_mentions_extractor(): 37 | nlp = spacy.load("en_core_web_sm") 38 | 39 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorSpacy(ExtractorType.POS)) 40 | nlp.add_pipe("zshot", config=config_zshot, last=True) 41 | assert "zshot" in nlp.pipe_names and "ner" not in nlp.pipe_names 42 | doc = nlp(EX_DOCS[1]) 43 | assert doc.ents == () 44 | assert len(doc._.mentions) > 0 45 | del doc, nlp 46 | 47 | 48 | def test_spacy_ner_mentions_extractor_pipeline(): 49 | nlp = spacy.load("en_core_web_sm") 50 | 51 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorSpacy(ExtractorType.NER)) 52 | nlp.add_pipe("zshot", config=config_zshot, last=True) 53 | assert "zshot" in nlp.pipe_names and "ner" in nlp.pipe_names 54 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 55 | assert all(doc.ents == () for doc in docs) 56 | assert all(len(doc._.mentions) > 0 for doc in docs) 57 | del docs, nlp 58 | 59 | 60 | def test_spacy_pos_mentions_extractor_pipeline(): 61 | nlp = spacy.load("en_core_web_sm") 62 | 63 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorSpacy(ExtractorType.POS)) 64 | nlp.add_pipe("zshot", config=config_zshot, last=True) 65 | assert "zshot" in nlp.pipe_names and "ner" not in nlp.pipe_names 66 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 67 | assert all(doc.ents == () for doc in docs) 68 | assert all(len(doc._.mentions) > 0 for doc in docs) 69 | del docs, nlp 70 | -------------------------------------------------------------------------------- /zshot/tests/knowledge_extractor/test_knowledge_extractor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import tempfile 4 | from typing import Iterator, Optional, Union, List, Tuple 5 | 6 | import spacy 7 | from spacy.tokens.doc import Doc 8 | 9 | from zshot import PipelineConfig 10 | from zshot.knowledge_extractor import KnowledgeExtractor 11 | from zshot.tests.config import EX_DOCS 12 | from zshot.utils.data_models import Relation 13 | from zshot.utils.data_models import Span 14 | from zshot.utils.data_models.relation_span import RelationSpan 15 | 16 | 17 | class DummyKnowledgeExtractor(KnowledgeExtractor): 18 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) \ 19 | -> List[List[Tuple[Span, RelationSpan, Span]]]: 20 | docs_preds = [] 21 | for doc in docs: 22 | e1 = doc[0] 23 | e2 = doc[1] 24 | s1 = Span(e1.idx, e1.idx + len(e1.text), "subject") 25 | s2 = Span(e2.idx, e2.idx + len(e2.text), "object") 26 | preds = [(s1, RelationSpan(s1, s2, Relation(name="relation")), s2)] 27 | docs_preds.append(preds) 28 | return docs_preds 29 | 30 | 31 | def test_dummy_knowledge_extractor(): 32 | nlp = spacy.blank("en") 33 | config_zshot = PipelineConfig( 34 | knowledge_extractor=DummyKnowledgeExtractor(), 35 | ) 36 | nlp.add_pipe("zshot", config=config_zshot, last=True) 37 | assert "zshot" in nlp.pipe_names 38 | doc = nlp(EX_DOCS[0]) 39 | assert len(doc.ents) > 0 40 | assert len(doc._.spans) > 0 41 | assert len(doc._.relations) > 0 42 | 43 | 44 | def test_dummy_knowledge_extractor_device(): 45 | nlp = spacy.blank("en") 46 | device = 'cpu' 47 | config_zshot = PipelineConfig( 48 | knowledge_extractor=DummyKnowledgeExtractor(), 49 | device=device, 50 | ) 51 | nlp.add_pipe("zshot", config=config_zshot, last=True) 52 | assert "zshot" in nlp.pipe_names 53 | assert nlp.get_pipe("zshot").device == device 54 | 55 | 56 | def 
test_serialization_knowledge_extractor(): 57 | nlp = spacy.blank("en") 58 | config_zshot = PipelineConfig(knowledge_extractor=DummyKnowledgeExtractor()) 59 | nlp.add_pipe("zshot", config=config_zshot, last=True) 60 | assert "zshot" in nlp.pipe_names 61 | assert "ner" not in nlp.pipe_names 62 | pipes = [p for p in nlp.pipe_names if p != "zshot"] 63 | 64 | d = tempfile.TemporaryDirectory() 65 | nlp.to_disk(d.name, exclude=pipes) 66 | config_fn = os.path.join(d.name, "zshot", "config.cfg") 67 | assert os.path.exists(config_fn) 68 | with open(config_fn, "r") as f: 69 | config = json.load(f) 70 | assert "disable_default_ner" in config and config["disable_default_ner"] 71 | nlp2 = spacy.load(d.name) 72 | assert "zshot" in nlp2.pipe_names 73 | assert isinstance(nlp2.get_pipe("zshot").knowledge_extractor, DummyKnowledgeExtractor) 74 | -------------------------------------------------------------------------------- /zshot/tests/utils/test_ensembler.py: -------------------------------------------------------------------------------- 1 | from zshot.utils.data_models import Span 2 | from zshot.utils.ensembler import Ensembler 3 | 4 | 5 | def test_ensemble_max(): 6 | ensembler = Ensembler(num_voters=2) 7 | assert ensembler.ensemble_max([ 8 | Span(start=0, end=6, label='fruits', score=0.9), 9 | Span(start=0, end=6, label='NEG', score=0.1) 10 | ]) == Span(start=0, end=6, label='fruits', score=0.45) 11 | assert ensembler.ensemble_max([ 12 | Span(start=7, end=9, label='fruits', score=0.1), 13 | Span(start=7, end=9, label='NEG', score=0.8), 14 | ]) == Span(start=7, end=9, label='NEG', score=0.4) 15 | assert ensembler.ensemble_max([ 16 | Span(start=10, end=17, label='fruits', score=0.8), 17 | Span(start=10, end=17, label='NEG', score=0.2), 18 | ]) == Span(start=10, end=17, label='fruits', score=0.4) 19 | assert ensembler.ensemble_max([ 20 | Span(start=40, end=41, label='fruits', score=0.4), 21 | Span(start=40, end=41, label='NEG', score=0.6) 22 | ]) == Span(start=40, end=41, label='NEG', score=0.3) 23 | 24 | 25 | def test_ensemble_count(): 26 | ensembler = Ensembler(num_voters=3) 27 | assert ensembler.ensemble_count([ 28 | Span(start=0, end=6, label='fruits', score=0.9), 29 | Span(start=0, end=6, label='fruits', score=0.9), 30 | Span(start=0, end=6, label='NEG', score=0.1) 31 | ]) == Span(start=0, end=6, label='fruits', score=2 / 3) 32 | assert ensembler.ensemble_count([ 33 | Span(start=7, end=9, label='fruits', score=0.1), 34 | Span(start=7, end=9, label='NEG', score=0.8), 35 | Span(start=7, end=9, label='NEG', score=0.8) 36 | ]) == Span(start=7, end=9, label='NEG', score=2 / 3) 37 | assert ensembler.ensemble_count([ 38 | Span(start=10, end=17, label='fruits', score=0.8), 39 | Span(start=10, end=17, label='fruits', score=0.8), 40 | Span(start=10, end=17, label='NEG', score=0.2) 41 | ]) == Span(start=10, end=17, label='fruits', score=2 / 3) 42 | assert ensembler.ensemble_count([ 43 | Span(start=40, end=41, label='fruits', score=0.4), 44 | Span(start=40, end=41, label='NEG', score=0.6), 45 | Span(start=40, end=41, label='NEG', score=0.6) 46 | ]) == Span(start=40, end=41, label='NEG', score=2 / 3) 47 | 48 | 49 | def test_select_best(): 50 | ensembler = Ensembler(num_voters=3) 51 | assert ensembler.select_best({ 52 | 'fruits': 0.3, 53 | 'NEG': 0.03 54 | }) == (0.3, 'fruits') 55 | assert ensembler.select_best({ 56 | 'fruits': 0.1, 57 | 'NEG': 0.3 58 | }) == (0.3, 'NEG') 59 | 60 | 61 | def test_inclusive(): 62 | ensembler = Ensembler(num_voters=3) 63 | spans = [ 64 | Span(start=40, end=41, label='fruits', 
score=0.4), 65 | Span(start=40, end=41, label='NEG', score=0.6), 66 | Span(start=40, end=42, label='NEG', score=0.6) 67 | ] 68 | assert ensembler.inclusive(spans) == [ 69 | Span(start=40, end=42, label='NEG', score=0.6) 70 | ] 71 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '20 18 * * 6' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 
66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /zshot/knowledge_extractor/knowledge_extractor_relik.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import pkgutil 4 | from typing import List, Tuple, Iterator, Optional, Union 5 | 6 | from relik import Relik 7 | from relik.inference.data.objects import RelikOutput 8 | from spacy.tokens import Doc 9 | 10 | from zshot.config import MODELS_CACHE_PATH 11 | from zshot.knowledge_extractor.knowledge_extractor import KnowledgeExtractor 12 | from zshot.utils.data_models import Span, Relation 13 | from zshot.utils.data_models.relation_span import RelationSpan 14 | 15 | logging.getLogger("relik").setLevel(logging.ERROR) 16 | 17 | MODEL_NAME = "sapienzanlp/relik-relation-extraction-nyt-large" 18 | 19 | 20 | class KnowledgeExtractorRelik(KnowledgeExtractor): 21 | def __init__(self, model_name=MODEL_NAME): 22 | """ Instantiate the KnowGL Knowledge Extractor """ 23 | super().__init__() 24 | 25 | if not pkgutil.find_loader("relik"): 26 | raise Exception("relik module not installed. " 27 | "You need to install relik in order to use the relik Knowledge Extractor." 28 | "Install it with: pip install relik") 29 | 30 | self.model_name = model_name 31 | self.model = None 32 | 33 | def load_models(self): 34 | """ Load relik model """ 35 | # Remove RELIK print 36 | with contextlib.redirect_stdout(None): 37 | if self.model is None: 38 | self.model = Relik.from_pretrained(self.model_name, 39 | cache_dir=MODELS_CACHE_PATH, device=self.device) 40 | 41 | def parse_result(self, relik_out: RelikOutput, doc: Doc) -> List[Tuple[Span, RelationSpan, Span]]: 42 | triples = [] 43 | for triple in relik_out.triplets: 44 | subject = Span(triple.subject.start, triple.subject.end, triple.subject.label) 45 | object_ = Span(triple.object.start, triple.object.end, triple.object.label) 46 | 47 | relation = Relation(name=triple.label, description="") 48 | triples.append((subject, 49 | RelationSpan(start=subject, end=object_, relation=relation), 50 | object_)) 51 | return triples 52 | 53 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) \ 54 | -> List[List[Tuple[Span, RelationSpan, Span]]]: 55 | """ Extract triples from docs 56 | 57 | :param docs: Spacy Docs to process 58 | :param batch_size: Batch size for processing 59 | :return: Triples (subject, relation, object) extracted for each document 60 | """ 61 | if not self.model: 62 | self.load_models() 63 | 64 | texts = [d.text for d in docs] 65 | relik_out = self.model(texts) 66 | if type(relik_out) is RelikOutput: 67 | relik_out = [relik_out] 68 | 69 | triples = [] 70 | for doc, output in zip(docs, relik_out): 71 | triples.append(self.parse_result(output, doc)) 72 | 73 | return triples 74 | -------------------------------------------------------------------------------- /zshot/tests/evaluation/test_datasets.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from zshot.evaluation import load_ontonotes_zs, load_medmentions_zs, load_few_rel_zs, load_pile_ner_biomed_zs 7 | from zshot.evaluation.dataset.dataset 
import create_dataset 8 | from zshot.utils.data_models import Entity 9 | 10 | ENTITIES = [ 11 | Entity(name="FAC", description="A facility"), 12 | Entity(name="LOC", description="A location"), 13 | ] 14 | 15 | 16 | @pytest.fixture(scope="module", autouse=True) 17 | def teardown(): 18 | yield True 19 | shutil.rmtree(f"{Path.home()}/.cache/huggingface", ignore_errors=True) 20 | shutil.rmtree(f"{Path.home()}/.cache/zshot", ignore_errors=True) 21 | 22 | 23 | @pytest.mark.skip(reason="Too expensive to run on every commit") 24 | def test_ontonotes_zs(): 25 | dataset = load_ontonotes_zs() 26 | assert 'train' in dataset 27 | assert 'test' in dataset 28 | assert 'validation' in dataset 29 | assert dataset['train'].num_rows == 41475 30 | assert dataset['test'].num_rows == 426 31 | assert dataset['validation'].num_rows == 1358 32 | del dataset 33 | 34 | 35 | @pytest.mark.skip(reason="Too expensive to run on every commit") 36 | def test_ontonotes_zs_split(): 37 | dataset = load_ontonotes_zs(split='test') 38 | assert dataset.num_rows == 426 39 | del dataset 40 | 41 | 42 | @pytest.mark.skip(reason="Too expensive to run on every commit") 43 | def test_ontonotes_zs_sub_split(): 44 | dataset = load_ontonotes_zs(split='test[0:10]') 45 | assert dataset.num_rows > 0 46 | del dataset 47 | 48 | 49 | @pytest.mark.skip(reason="Too expensive to run on every commit") 50 | def test_medmentions_zs(): 51 | dataset = load_medmentions_zs() 52 | assert 'train' in dataset 53 | assert 'test' in dataset 54 | assert 'validation' in dataset 55 | 56 | assert dataset['train'].num_rows == 26770 57 | assert dataset['test'].num_rows == 1048 58 | assert dataset['validation'].num_rows == 1289 59 | del dataset 60 | 61 | 62 | @pytest.mark.skip(reason="Too expensive to run on every commit") 63 | def test_medmentions_zs_split(): 64 | dataset = load_medmentions_zs(split='test') 65 | assert dataset.num_rows == 1048 66 | del dataset 67 | 68 | 69 | @pytest.mark.skip(reason="Too expensive to run on every commit") 70 | def test_pile_bioner(): 71 | dataset = load_pile_ner_biomed_zs() 72 | assert dataset.num_rows == 58861 73 | assert len(dataset.entities) == 3912 74 | del dataset 75 | 76 | 77 | @pytest.mark.skip(reason="Too expensive to run on every commit") 78 | def test_few_rel_zs(): 79 | dataset = load_few_rel_zs() 80 | assert dataset.num_rows == 11200 81 | 82 | dataset = load_few_rel_zs("val_wiki[0:5]") 83 | assert dataset.num_rows == 5 84 | del dataset 85 | 86 | 87 | def test_create_dataset(): 88 | sentences = ["New York is beautiful", "New York is beautiful"] 89 | gt = [["B-FAC", "I-FAC", "O", "O"], ["B-FAC", "I-FAC", "O", "O"]] 90 | 91 | dataset = create_dataset(gt, sentences, ENTITIES) 92 | assert dataset.num_rows == len(sentences) 93 | assert dataset.entities == ENTITIES 94 | del dataset 95 | -------------------------------------------------------------------------------- /zshot/pipeline_config.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional, Union, List 3 | 4 | import spacy 5 | 6 | from zshot.knowledge_extractor.knowledge_extractor import KnowledgeExtractor 7 | from zshot.linker import Linker 8 | from zshot.mentions_extractor import MentionsExtractor 9 | from zshot.relation_extractor import RelationsExtractor 10 | from zshot.utils.data_models import Entity, Relation 11 | 12 | 13 | class PipelineConfig(dict): 14 | 15 | def __init__(self, 16 | mentions_extractor: Optional[MentionsExtractor] = None, 17 | linker: Optional[Union[Linker, str]] = None, 18 
| relations_extractor: Optional[Union[RelationsExtractor, str]] = None, 19 | knowledge_extractor: Optional[Union[KnowledgeExtractor, str]] = None, 20 | mentions: Optional[Union[List[Entity], List[str], str]] = None, 21 | entities: Optional[Union[List[Entity], List[str], str]] = None, 22 | relations: Optional[Union[List[Relation], str]] = None, 23 | disable_default_ner: Optional[bool] = True, 24 | device: Optional[str] = None) -> None: 25 | config = {} 26 | 27 | if mentions_extractor: 28 | mention_extractor_id = PipelineConfig.param(mentions_extractor) 29 | config.update({'mentions_extractor': mention_extractor_id}) 30 | 31 | if linker: 32 | linker_id = PipelineConfig.param(linker) 33 | config.update({'linker': linker_id}) 34 | 35 | if relations_extractor: 36 | relation_extractor_id = PipelineConfig.param(relations_extractor) 37 | config.update({'relations_extractor': relation_extractor_id}) 38 | 39 | if knowledge_extractor: 40 | knowledge_extractor_id = PipelineConfig.param(knowledge_extractor) 41 | config.update({'knowledge_extractor': knowledge_extractor_id}) 42 | 43 | if entities: 44 | entities_id = PipelineConfig.param(entities) 45 | config.update({'entities': entities_id}) 46 | 47 | if mentions: 48 | mentions_id = PipelineConfig.param(mentions) 49 | config.update({'mentions': mentions_id}) 50 | 51 | if relations: 52 | relations_id = PipelineConfig.param(relations) 53 | config.update({'relations': relations_id}) 54 | 55 | if disable_default_ner: 56 | config.update({'disable_default_ner': disable_default_ner}) 57 | 58 | if device: 59 | config.update({'device': device}) 60 | 61 | super().__init__(**config) 62 | 63 | @staticmethod 64 | def param(param) -> str: 65 | if isinstance(param, list): 66 | params_to_hash = random.sample(param, k=min(len(param), 10)) 67 | instance_hash = hash(sum([hash(param_to_hash) for param_to_hash in params_to_hash])) 68 | else: 69 | instance_hash = hash(param) 70 | 71 | instance_id = f"zshot.{param.__class__.__name__}.{instance_hash}" 72 | 73 | @spacy.registry.misc(instance_id) 74 | def create_custom_component(): 75 | return param 76 | 77 | return instance_id 78 | -------------------------------------------------------------------------------- /zshot/tests/mentions_extractor/test_tars_mentions_extractor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import pkgutil 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig 8 | from zshot.mentions_extractor import MentionsExtractorTARS 9 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 10 | 11 | OVERLAP_TEXT = "Senator McConnell in addition to this the New York Times editors wrote in reaction to the " \ 12 | "Supreme Court 's decision striking down the military tribunal set up to private detainees " \ 13 | "being held in Guantanamo bay it is far more than a narrow ruling on the issue of military courts . " \ 14 | "It is an important and welcome reaffirmation That even in times of war the law is what the " \ 15 | "constitution the statuette books and the Geneva convention say it is . " \ 16 | "Not what the President wants it to be /." 17 | INCOMPLETE_SPANS_TEXT = "-LSB- -LSB- They attacked small bridges and small districts , " \ 18 | "and generally looted these stations . 
-RSB- -RSB-" 19 | 20 | 21 | @pytest.fixture(scope="module", autouse=True) 22 | def teardown(): 23 | yield True 24 | gc.collect() 25 | 26 | 27 | @pytest.mark.xfail(pkgutil.resolve_name("flair").__version__ == '0.12.2', reason='Bug in TARS models in Flair 0.12.2') 28 | def test_tars_mentions_extractor_with_entities(): 29 | if not pkgutil.find_loader("flair"): 30 | return 31 | nlp = spacy.blank("en") 32 | 33 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorTARS(), mentions=EX_ENTITIES) 34 | nlp.add_pipe("zshot", config=config_zshot, last=True) 35 | assert "zshot" in nlp.pipe_names 36 | doc = nlp(EX_DOCS[1]) 37 | assert doc._.mentions != () 38 | 39 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 40 | assert all(doc._.mentions != () for doc in docs) 41 | nlp.remove_pipe('zshot') 42 | del docs, doc, nlp 43 | 44 | 45 | @pytest.mark.xfail(pkgutil.resolve_name("flair").__version__ == '0.12.2', reason='Bug in TARS models in Flair 0.12.2') 46 | def test_tars_mentions_extractor_overlap(): 47 | if not pkgutil.find_loader("flair"): 48 | return 49 | nlp = spacy.blank("en") 50 | 51 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorTARS(), 52 | mentions=["company", "location", "organic compound"]) 53 | nlp.add_pipe("zshot", config=config_zshot, last=True) 54 | assert "zshot" in nlp.pipe_names 55 | doc = nlp(OVERLAP_TEXT) 56 | assert len(doc._.mentions) > 0 57 | nlp.remove_pipe('zshot') 58 | del doc, nlp 59 | 60 | 61 | @pytest.mark.xfail(pkgutil.resolve_name("flair").__version__ == '0.12.2', reason='Bug in TARS models in Flair 0.12.2') 62 | def test_tars_end2end_incomplete_spans(): 63 | if not pkgutil.find_loader("flair"): 64 | return 65 | nlp = spacy.blank("en") 66 | 67 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorTARS()) 68 | nlp.add_pipe("zshot", config=config_zshot, last=True) 69 | assert "zshot" in nlp.pipe_names 70 | doc = nlp(INCOMPLETE_SPANS_TEXT) 71 | assert len(doc._.mentions) >= 0 72 | nlp.remove_pipe('zshot') 73 | del doc, nlp 74 | -------------------------------------------------------------------------------- /zshot/linker/linker_relik.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import pkgutil 4 | from typing import Iterator, List, Optional, Union 5 | 6 | from relik import Relik 7 | from relik.inference.data.objects import RelikOutput 8 | from relik.retriever.indexers.document import Document 9 | from spacy.tokens import Doc 10 | 11 | from zshot.config import MODELS_CACHE_PATH 12 | from zshot.linker.linker import Linker 13 | from zshot.utils.data_models import Span 14 | 15 | logging.getLogger("relik").setLevel(logging.ERROR) 16 | 17 | MODEL_NAME = "sapienzanlp/relik-entity-linking-large" 18 | 19 | 20 | class LinkerRelik(Linker): 21 | """ Relik linker """ 22 | 23 | def __init__(self, model_name=MODEL_NAME): 24 | super().__init__() 25 | 26 | if not pkgutil.find_loader("relik"): 27 | raise Exception("relik module not installed. You need to install relik in order to use the relik Linker." 
28 | "Install it with: pip install relik") 29 | 30 | self.model_name = model_name 31 | self.model = None 32 | # self.device = { 33 | # "retriever_device": self.device, 34 | # "index_device": self.device, 35 | # "reader_device": self.device 36 | # } 37 | 38 | @property 39 | def is_end2end(self) -> bool: 40 | """ relik is end2end """ 41 | return True 42 | 43 | def load_models(self): 44 | """ Load relik model """ 45 | # Remove RELIK print 46 | with contextlib.redirect_stdout(None): 47 | if self.model is None: 48 | if self._entities: 49 | self.model = Relik.from_pretrained(self.model_name, 50 | cache_dir=MODELS_CACHE_PATH, 51 | retriever=None, device=self.device) 52 | else: 53 | self.model = Relik.from_pretrained(self.model_name, 54 | cache_dir=MODELS_CACHE_PATH, device=self.device, 55 | index_device='cpu') 56 | 57 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 58 | """ 59 | Perform the entity prediction 60 | :param docs: A list of spacy Document 61 | :param batch_size: The batch size 62 | :return: List Spans for each Document in docs 63 | """ 64 | candidates = None 65 | if self._entities: 66 | candidates = [ 67 | Document(text=ent.name, id=i, metadata={'definition': ent.description}) 68 | for i, ent in enumerate(self._entities) 69 | ] 70 | 71 | sentences = [doc.text for doc in docs] 72 | 73 | self.load_models() 74 | span_annotations = [] 75 | for sent in sentences: 76 | relik_out: RelikOutput = self.model(sent, candidates=candidates) 77 | span_annotations.append([Span(start=relik_span.start, end=relik_span.end, label=relik_span.label) 78 | for relik_span in relik_out.spans]) 79 | 80 | return span_annotations 81 | -------------------------------------------------------------------------------- /zshot/relation_extractor/relations_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import zlib 4 | from abc import ABC, abstractmethod 5 | from typing import List, Iterator, Optional, Union 6 | 7 | import torch 8 | from spacy.tokens import Doc 9 | from spacy.util import ensure_path 10 | 11 | from zshot.utils.data_models.relation import Relation 12 | from zshot.utils.data_models.relation_span import RelationSpan 13 | 14 | 15 | class RelationsExtractor(ABC): 16 | 17 | def __init__(self, device: Optional[Union[str, torch.device]] = None): 18 | self._relations = None 19 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 20 | 21 | def set_device(self, device: Union[str, torch.device]): 22 | """ 23 | Set the device to use 24 | :param device: 25 | :return: 26 | """ 27 | self.device = device 28 | 29 | def set_relations(self, relations: Iterator[Relation]): 30 | """ 31 | Set relationships that the relations extractor can use 32 | :param relations: The list of relationship 33 | """ 34 | self._relations = relations 35 | 36 | @property 37 | def relations(self) -> List[Relation]: 38 | return self._relations 39 | 40 | def load_models(self): 41 | """ 42 | Load the model 43 | :return: 44 | """ 45 | pass 46 | 47 | @abstractmethod 48 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[RelationSpan]]: 49 | """ 50 | Perform the relations extraction. 
51 | :param docs: A list of spacy Document 52 | :param batch_size: The batch size 53 | :return: the predicted relations 54 | """ 55 | pass 56 | 57 | def extract_relations(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None): 58 | """ 59 | Perform the relations extraction. Call the predict function and add the mentions to the Spacy Doc 60 | :param docs: A list of spacy Document 61 | :param batch_size: The batch size 62 | :return: 63 | """ 64 | predicted_relations = self.predict(docs, batch_size) 65 | for d, preds in zip(docs, predicted_relations): 66 | d._.relations = preds 67 | 68 | @staticmethod 69 | def version() -> str: 70 | return "v1" 71 | 72 | @staticmethod 73 | def _get_serialize_file(path): 74 | return os.path.join(path, "mentions_extractor.pkl") 75 | 76 | @staticmethod 77 | def _get_config_file(path): 78 | path = os.path.join(path, "mentions_extractor.json") 79 | path = ensure_path(path) 80 | return path 81 | 82 | @classmethod 83 | def from_disk(cls, path, exclude=()): 84 | serialize_file = cls._get_serialize_file(path) 85 | with open(serialize_file, "rb") as f: 86 | return pkl.load(f) 87 | 88 | def to_disk(self, path): 89 | serialize_file = self._get_serialize_file(path) 90 | with open(serialize_file, "wb") as f: 91 | return pkl.dump(self, f) 92 | 93 | def __hash__(self): 94 | self_repr = f"{self.__class__.__name__}.{self.version()}.{str(self.__dict__)}" 95 | return zlib.crc32(self_repr.encode()) 96 | -------------------------------------------------------------------------------- /zshot/tests/mentions_extractor/test_flair_mentions_extractor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import pkgutil 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig 8 | from zshot.mentions_extractor import MentionsExtractorFlair 9 | from zshot.mentions_extractor.mentions_extractor_flair import ExtractorType 10 | from zshot.tests.config import EX_DOCS 11 | 12 | 13 | @pytest.fixture(scope="module", autouse=True) 14 | def teardown(): 15 | yield True 16 | gc.collect() 17 | 18 | 19 | def test_flair_ner_mentions_extractor(): 20 | if not pkgutil.find_loader("flair"): 21 | return 22 | nlp = spacy.blank("en") 23 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorFlair(ExtractorType.NER)) 24 | nlp.add_pipe("zshot", config=config_zshot, last=True) 25 | assert "zshot" in nlp.pipe_names 26 | doc = nlp(EX_DOCS[1]) 27 | assert doc.ents == () 28 | assert len(doc._.mentions) > 0 29 | nlp.remove_pipe('zshot') 30 | del doc, nlp 31 | 32 | 33 | def test_custom_flair_mentions_extractor(): 34 | nlp = spacy.blank("en") 35 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorFlair(ExtractorType.NER)) 36 | nlp.add_pipe("zshot", config=config_zshot, last=True) 37 | assert "zshot" in nlp.pipe_names 38 | doc = nlp(EX_DOCS[1]) 39 | assert doc.ents == () 40 | assert len(doc._.mentions) > 0 41 | nlp.remove_pipe('zshot') 42 | del doc, nlp 43 | 44 | 45 | @pytest.mark.xfail(reason='Chunk models not working in Flair. 
See https://github.com/flairNLP/flair/issues/3418') 46 | def test_flair_pos_mentions_extractor(): 47 | if not pkgutil.find_loader("flair"): 48 | return 49 | nlp = spacy.blank("en") 50 | 51 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorFlair(ExtractorType.POS)) 52 | nlp.add_pipe("zshot", config=config_zshot, last=True) 53 | assert "zshot" in nlp.pipe_names 54 | doc = nlp(EX_DOCS[1]) 55 | assert doc.ents == () 56 | assert len(doc._.mentions) > 0 57 | nlp.remove_pipe('zshot') 58 | del doc, nlp 59 | 60 | 61 | def test_flair_ner_mentions_extractor_pipeline(): 62 | if not pkgutil.find_loader("flair"): 63 | return 64 | nlp = spacy.blank("en") 65 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorFlair(ExtractorType.NER)) 66 | nlp.add_pipe("zshot", config=config_zshot, last=True) 67 | assert "zshot" in nlp.pipe_names 68 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 69 | assert all(doc.ents == () for doc in docs) 70 | assert all(len(doc._.mentions) > 0 for doc in docs) 71 | nlp.remove_pipe('zshot') 72 | del docs, nlp 73 | 74 | 75 | @pytest.mark.xfail(reason='Chunk models not working in Flair. See https://github.com/flairNLP/flair/issues/3418') 76 | def test_flair_pos_mentions_extractor_pipeline(): 77 | if not pkgutil.find_loader("flair"): 78 | return 79 | nlp = spacy.blank("en") 80 | config_zshot = PipelineConfig(mentions_extractor=MentionsExtractorFlair(ExtractorType.POS)) 81 | nlp.add_pipe("zshot", config=config_zshot, last=True) 82 | assert "zshot" in nlp.pipe_names 83 | docs = [doc for doc in nlp.pipe(EX_DOCS)] 84 | assert all(doc.ents == () for doc in docs) 85 | assert all(len(doc._.mentions) > 0 for doc in docs) 86 | nlp.remove_pipe('zshot') 87 | del docs, nlp 88 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Zshot 2 | site_description: Zero and Few shot Named Entities and Relationships recognition 3 | watch: [zshot] 4 | 5 | nav: 6 | - Home: 7 | - Overview: index.md 8 | - Usage: usage.md 9 | - Evaluation: evaluation.md 10 | - Code reference: mentions_extractor/ 11 | - Code reference: 12 | - Mentions Extraction: 13 | - mentions_extractor.md 14 | - spacy_mentions_extractor.md 15 | - flair_mentions_extractor.md 16 | - smxm_mentions_extractor.md 17 | - tars_mentions_extractor.md 18 | - gliner_mentions_extractor.md 19 | - Entity Linking: 20 | - entity_linking.md 21 | - blink.md 22 | - regen.md 23 | - smxm_linker.md 24 | - tars_linker.md 25 | - relik_linker.md 26 | - gliner_linker.md 27 | - Relations Extractor: 28 | - relation_extractor.md 29 | - zsbert_relations_extractor.md 30 | - Knowledge Extractor: 31 | - knowledge_extractor.md 32 | - knowgl_knowledge_extractor.md 33 | - relik_knowledge_extractor.md 34 | 35 | markdown_extensions: 36 | - attr_list 37 | - pymdownx.emoji: 38 | emoji_index: !!python/name:materialx.emoji.twemoji 39 | emoji_generator: !!python/name:materialx.emoji.to_svg 40 | theme: 41 | name: material 42 | features: 43 | - content.code.annotate 44 | - navigation.tabs 45 | - navigation.top 46 | palette: 47 | - media: "(prefers-color-scheme: light)" 48 | scheme: default 49 | primary: black 50 | accent: purple 51 | toggle: 52 | icon: material/weather-sunny 53 | name: Switch to light mode 54 | - media: "(prefers-color-scheme: dark)" 55 | scheme: slate 56 | primary: black 57 | accent: lime 58 | toggle: 59 | icon: material/weather-night 60 | name: Switch to dark mode 61 | features: 62 | - search.suggest 63 | 
- search.highlight 64 | - content.tabs.link 65 | icon: 66 | repo: fontawesome/brands/github-alt 67 | language: en 68 | repo_name: IBM/zshot 69 | repo_url: https://github.com/IBM/zshot 70 | edit_uri: '' 71 | plugins: 72 | - search 73 | - include-markdown 74 | - mkdocstrings: 75 | handlers: 76 | python: 77 | import: 78 | - https://docs.python.org/3/objects.inv 79 | - https://installer.readthedocs.io/en/stable/objects.inv # demonstration purpose in the docs 80 | - https://mkdocstrings.github.io/autorefs/objects.inv 81 | options: 82 | show_source: false 83 | docstring_style: sphinx 84 | merge_init_into_class: yes 85 | show_submodules: yes 86 | - markdownextradata: 87 | data: data 88 | markdown_extensions: 89 | - toc: 90 | permalink: true 91 | - markdown.extensions.codehilite: 92 | guess_lang: false 93 | - mdx_include: 94 | base_path: docs 95 | - admonition 96 | - codehilite 97 | - extra 98 | - pymdownx.superfences: 99 | custom_fences: 100 | - name: mermaid 101 | class: mermaid 102 | format: !!python/name:pymdownx.superfences.fence_code_format '' 103 | - pymdownx.tabbed: 104 | alternate_style: true 105 | - attr_list 106 | - md_in_html 107 | extra: 108 | social: 109 | - icon: fontawesome/brands/github-alt 110 | link: https://github.com/IBM/zshot -------------------------------------------------------------------------------- /docs/evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | Evaluation is an important process for continuously improving the performance of the models, which is why ZShot allows you to evaluate its components with two predefined datasets: OntoNotes and MedMentions, in a Zero-Shot version in which the entities of the test and validation splits don't appear in the train set. 3 | 4 | #### OntoNotes 5 | OntoNotes Release 5.0 is the final release of the OntoNotes project, a collaborative effort between BBN Technologies, the University of Colorado, the University of Pennsylvania and the University of Southern California's Information Sciences Institute. The goal of the project was to annotate a large corpus comprising various genres of text (news, conversational telephone speech, weblogs, usenet newsgroups, broadcast, talk shows) in three languages (English, Chinese, and Arabic) with structural information (syntax and predicate argument structure) and shallow semantics (word sense linked to an ontology and coreference). 6 | 7 | In ZShot the version taken from the [Huggingface datasets](https://huggingface.co/datasets/conll2012_ontonotesv5) is preprocessed to adapt it to the Zero-Shot version. 8 | 9 | [Link](https://catalog.ldc.upenn.edu/LDC2013T19) 10 | 11 | #### MedMentions 12 | Corpus: The MedMentions corpus consists of 4,392 papers (Titles and Abstracts) randomly selected from among papers released on PubMed in 2016, that were in the biomedical field, published in the English language, and had both a Title and an Abstract. 13 | 14 | Annotators: We recruited a team of professional annotators with rich experience in biomedical content curation to exhaustively annotate all UMLS® (2017AA full version) entity mentions in these papers. 15 | 16 | Annotation quality: We did not collect stringent IAA (Inter-annotator agreement) data. To gain insight on the annotation quality of MedMentions, we randomly selected eight papers from the annotated corpus, containing a total of 469 concepts. Two biologists ('Reviewer') who did not participate in the annotation task then each reviewed four papers.
The agreement between Reviewers and Annotators, an estimate of the Precision of the annotations, was 97.3%. 17 | 18 | In ZShot the data is downloaded from the original repository and preprocessed to convert it into the Zero-Shot version. 19 | 20 | [Link](https://github.com/chanzuckerberg/MedMentions) 21 | 22 | ### How to evaluate ZShot 23 | 24 | The package `evaluation` contains all the functionality needed to evaluate the ZShot components. The main function is `zshot.evaluation.zshot_evaluate.evaluate`, which takes as input the SpaCy `nlp` model and the dataset to evaluate, and returns the evaluation results, which can be rendered as a table with `prettify_evaluate_report`. For instance, the evaluation of the TARS linker in ZShot on the *Ontonotes validation* set would be: 25 | 26 | ```python 27 | import spacy 28 | 29 | from zshot import PipelineConfig 30 | from zshot.linker import LinkerTARS 31 | from zshot.evaluation.dataset import load_ontonotes_zs 32 | from zshot.evaluation.zshot_evaluate import evaluate, prettify_evaluate_report 33 | from zshot.evaluation.metrics._seqeval._seqeval import Seqeval 34 | 35 | ontonotes_zs = load_ontonotes_zs('validation') 36 | 37 | nlp = spacy.blank("en") 38 | nlp_config = PipelineConfig( 39 | linker=LinkerTARS(), 40 | entities=ontonotes_zs.entities 41 | ) 42 | 43 | nlp.add_pipe("zshot", config=nlp_config, last=True) 44 | 45 | evaluation = evaluate(nlp, ontonotes_zs, metric=Seqeval()) 46 | prettify_evaluate_report(evaluation) 47 | ``` 48 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/mentions_extractor_spacy.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Iterator 2 | 3 | from spacy.tokens.doc import Doc 4 | 5 | from zshot.utils.data_models import Span 6 | from zshot.mentions_extractor.mentions_extractor import MentionsExtractor 7 | from zshot.mentions_extractor.utils import ExtractorType 8 | 9 | 10 | class MentionsExtractorSpacy(MentionsExtractor): 11 | """ SpaCy mentions extractor """ 12 | ALLOWED_POS = ("NOUN", "PROPN") 13 | ALLOWED_DEP = ("compound", "pobj", "dobj", "nsubj", "attr", "appos") 14 | COMPOUND_DEP = "compound" 15 | 16 | EXCLUDE_NER = ("CARDINAL", "DATE", "ORDINAL", "PERCENT", "QUANTITY", "TIME") 17 | 18 | def __init__(self, extractor_type: Optional[ExtractorType] = ExtractorType.NER): 19 | """ 20 | :param extractor_type: Type of extractor to get mentions.
One of: 21 | - NER: to use Named Entity Recognition model to get the mentions 22 | - POS: to get the mentions based on the linguistics 23 | """ 24 | super(MentionsExtractorSpacy, self).__init__() 25 | self.extractor_type = extractor_type 26 | 27 | @property 28 | def require_existing_ner(self) -> bool: 29 | """ If the type of the extractor is NER the existing NER is required """ 30 | return self.extractor_type == ExtractorType.NER 31 | 32 | def predict_pos_mentions(self, docs: Iterator[Doc], batch_size: Optional[int] = None): 33 | """ Predict mentions of docs using POS linguistics 34 | 35 | :param docs: Documents to get mentions of 36 | :param batch_size: Batch size to use 37 | :return: Spans of the mentions 38 | """ 39 | spans = [] 40 | for doc in docs: 41 | skip = -1 42 | spans_tmp = [] 43 | for i, tok in enumerate(doc): 44 | if 0 < i < skip: 45 | continue 46 | 47 | if tok.pos_ in self.ALLOWED_POS and tok.dep_ in self.ALLOWED_DEP: 48 | if tok.dep_ == self.COMPOUND_DEP: 49 | spans_tmp.append(Span(tok.idx, tok.head.idx + len(tok.head))) 50 | skip = tok.head.i + 1 51 | else: 52 | spans_tmp.append(Span(tok.idx, tok.idx + len(tok))) 53 | spans.append(spans_tmp) 54 | 55 | return spans 56 | 57 | def predict_ner_mentions(self, docs: Iterator[Doc], batch_size: Optional[int] = None): 58 | """ Predict mentions of docs using NER model 59 | 60 | :param docs: Documents to get mentions of 61 | :param batch_size: Batch size to use 62 | :return: Spans of the mentions 63 | """ 64 | spans = [ 65 | [ 66 | Span(ent.start_char, ent.end_char) 67 | for ent in doc.ents if ent.label_ not in self.EXCLUDE_NER 68 | ] 69 | for doc in docs 70 | ] 71 | for doc in docs: 72 | doc.ents = [] 73 | 74 | return spans 75 | 76 | def predict(self, docs: Iterator[Doc], batch_size=None): 77 | """ Predict mentions of docs 78 | 79 | :param docs: Documents to get mentions of 80 | :param batch_size: Batch size to use 81 | :return: Spans of the mentions 82 | """ 83 | if self.extractor_type == ExtractorType.NER: 84 | return self.predict_ner_mentions(docs, batch_size) 85 | else: 86 | return self.predict_pos_mentions(docs, batch_size) 87 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/mentions_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import warnings 4 | 5 | import zlib 6 | from abc import ABC, abstractmethod 7 | 8 | import torch 9 | from spacy.tokens import Doc 10 | from typing import List, Iterator, Optional, Union 11 | 12 | from spacy.util import ensure_path 13 | 14 | from zshot.utils.data_models import Entity 15 | from zshot.utils.data_models import Span 16 | 17 | 18 | class MentionsExtractor(ABC): 19 | 20 | def __init__(self, device: Optional[Union[str, torch.device]] = None): 21 | self._mentions = None 22 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 23 | 24 | def set_device(self, device: Union[str, torch.device]): 25 | """ 26 | Set the device to use 27 | :param device: 28 | :return: 29 | """ 30 | self.device = device 31 | 32 | def set_kg(self, mentions: Iterator[Entity]): 33 | """ 34 | Set entities that mention extractor can use 35 | :param mentions: The list of entities 36 | """ 37 | self._mentions = mentions 38 | 39 | @property 40 | def mentions(self) -> List[Entity]: 41 | return self._mentions 42 | 43 | def load_models(self): 44 | """ 45 | Load the model 46 | :return: 47 | """ 48 | pass 49 | 50 | def 
extract_mentions(self, docs: Iterator[Doc], batch_size=None): 51 | """ 52 | Perform the mentions extraction. Call the predict function and add the mentions to the Spacy Doc 53 | :param docs: A list of spacy Document 54 | :param batch_size: The batch size 55 | :return: 56 | """ 57 | predictions_spans = self.predict(docs, batch_size) 58 | for doc, doc_preds in zip(docs, predictions_spans): 59 | for pred in doc_preds: 60 | try: 61 | doc._.mentions += (pred,) 62 | except (TypeError, ValueError): 63 | warnings.warn("Entity couldn't be added.") 64 | 65 | @abstractmethod 66 | def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[Span]]: 67 | """ 68 | Perform the mentions prediction 69 | :param docs: A list of spacy Document 70 | :param batch_size: The batch size 71 | :return: 72 | """ 73 | pass 74 | 75 | @staticmethod 76 | def version() -> str: 77 | return "v1" 78 | 79 | @property 80 | def require_existing_ner(self) -> bool: 81 | return False 82 | 83 | @staticmethod 84 | def _get_serialize_file(path): 85 | return os.path.join(path, "mentions_extractor.pkl") 86 | 87 | @staticmethod 88 | def _get_config_file(path): 89 | path = os.path.join(path, "mentions_extractor.json") 90 | path = ensure_path(path) 91 | return path 92 | 93 | @classmethod 94 | def from_disk(cls, path, exclude=()): 95 | serialize_file = cls._get_serialize_file(path) 96 | with open(serialize_file, "rb") as f: 97 | return pkl.load(f) 98 | 99 | def to_disk(self, path): 100 | serialize_file = self._get_serialize_file(path) 101 | with open(serialize_file, "wb") as f: 102 | return pkl.dump(self, f) 103 | 104 | def __hash__(self): 105 | self_repr = f"{self.__class__.__name__}.{self.version()}.{str(self.__dict__)}" 106 | return zlib.crc32(self_repr.encode()) 107 | -------------------------------------------------------------------------------- /zshot/tests/linker/test_linker.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Iterator 3 | 4 | import pytest 5 | import spacy 6 | from spacy.tokens import Doc 7 | 8 | from zshot import Linker, PipelineConfig 9 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 10 | from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor 11 | from zshot.utils.data_models import Span 12 | 13 | 14 | class DummyLinker(Linker): 15 | 16 | def predict(self, docs: Iterator[Doc], batch_size=None): 17 | return [ 18 | [Span(mention.start, mention.end, label='label') for mention in doc._.mentions] 19 | for doc in docs 20 | ] 21 | 22 | 23 | class DummyLinkerEnd2End(Linker): 24 | 25 | @property 26 | def is_end2end(self) -> bool: 27 | return True 28 | 29 | def predict(self, docs: Iterator[Doc], batch_size=None): 30 | return [[Span(0, len(doc.text) - 1, label='label', score=0.9)] for doc in docs] 31 | 32 | 33 | class DummyLinkerWithEntities(Linker): 34 | 35 | def predict(self, docs: Iterator[Doc], batch_size=None): 36 | entities = self.entities 37 | return [ 38 | [ 39 | Span(mention.start, mention.end, label=entities[idx].name) 40 | for idx, mention in enumerate(doc._.mentions) 41 | ] 42 | for doc in docs 43 | ] 44 | 45 | 46 | @pytest.fixture(scope="module", autouse=True) 47 | def teardown(): 48 | yield True 49 | gc.collect() 50 | 51 | 52 | def test_dummy_linker(): 53 | nlp = spacy.blank("en") 54 | config = PipelineConfig( 55 | mentions_extractor=DummyMentionsExtractor(), 56 | linker=DummyLinker()) 57 | nlp.add_pipe("zshot", config=config, last=True) 58 | assert "zshot" in nlp.pipe_names 59 | doc =
nlp(EX_DOCS[1]) 60 | assert len(doc._.mentions) > 0 61 | assert len(doc.ents) > 0 62 | assert len(doc._.spans) > 0 63 | del doc, nlp 64 | 65 | 66 | def test_dummy_linker_device(): 67 | nlp = spacy.blank("en") 68 | config = PipelineConfig( 69 | mentions_extractor=DummyMentionsExtractor(), 70 | linker=DummyLinker(), 71 | device="cpu") 72 | nlp.add_pipe("zshot", config=config, last=True) 73 | assert "zshot" in nlp.pipe_names 74 | doc = nlp(EX_DOCS[1]) 75 | assert len(doc._.mentions) > 0 76 | assert len(doc.ents) > 0 77 | assert len(doc._.spans) > 0 78 | del doc, nlp 79 | 80 | 81 | def test_dummy_linker_with_entities_config(): 82 | nlp = spacy.blank("en") 83 | 84 | nlp.add_pipe("zshot", config=PipelineConfig( 85 | mentions_extractor=DummyMentionsExtractor(), 86 | linker=DummyLinker(), 87 | entities=EX_ENTITIES), last=True) 88 | 89 | assert "zshot" in nlp.pipe_names 90 | doc = nlp(EX_DOCS[1]) 91 | 92 | assert len(doc._.mentions) > 0 93 | assert len(doc.ents) > 0 94 | assert len(doc._.spans) > 0 95 | assert all([bool(ent.label_) for ent in doc.ents]) 96 | del doc, nlp 97 | 98 | 99 | def test_dummy_linker_end2end(): 100 | nlp = spacy.blank("en") 101 | 102 | nlp.add_pipe("zshot", config=PipelineConfig( 103 | mentions_extractor=DummyMentionsExtractor(), 104 | linker=DummyLinkerEnd2End(), 105 | entities=EX_ENTITIES), last=True) 106 | 107 | assert "zshot" in nlp.pipe_names 108 | doc = nlp(EX_DOCS[1]) 109 | 110 | assert len(doc._.mentions) == 0 111 | assert len(doc.ents) > 0 112 | assert len(doc._.spans) > 0 113 | del doc, nlp 114 | -------------------------------------------------------------------------------- /zshot/linker/linker_tars.py: -------------------------------------------------------------------------------- 1 | import pkgutil 2 | from typing import Iterator, Optional, Union, List 3 | 4 | from spacy.tokens.doc import Doc 5 | 6 | from zshot.linker.linker import Linker 7 | from zshot.utils.models.tars.utils import tars_predict 8 | from zshot.utils.data_models import Entity, Span 9 | 10 | 11 | class LinkerTARS(Linker): 12 | """ TARS end2end Linker """ 13 | def __init__(self, default_entities: Optional[str] = "conll-short"): 14 | """ 15 | :param default_entities: Default entities to use in case no custom ones are set 16 | One of: 17 | - 'conll-short' 18 | - 'ontonotes-long' 19 | - 'ontonotes-short' 20 | - 'wnut_17-long' 21 | - 'wnut_17-short' 22 | """ 23 | super().__init__() 24 | if not pkgutil.find_loader("flair"): 25 | raise Exception("Flair module not installed. You need to install Flair for using this class." 
26 | "Install it with: pip install flair>=0.13") 27 | 28 | self.is_end2end = True 29 | self.default_entities = default_entities 30 | self.model = None 31 | self.task = None 32 | 33 | def set_kg(self, entities: Iterator[Entity]): 34 | """ Set new entities in the model 35 | 36 | :param entities: New entities to use 37 | """ 38 | old_entities = self.entities 39 | super().set_kg(entities) 40 | self.flat_entities() 41 | if old_entities != entities: 42 | self.task = f'zshot.ner.{hash(tuple(self.entities))}' 43 | if not self.model: 44 | self.load_models() 45 | self.model.add_and_switch_to_new_task(self.task, 46 | self.entities, label_type='ner') 47 | 48 | def flat_entities(self): 49 | """ As TARS use only the labels, take just the name of the entities and not the description """ 50 | if isinstance(self.entities, dict): 51 | self._entities = list(self.entities.keys()) 52 | if isinstance(self.entities, list): 53 | self._entities = [e.name if type(e) is Entity else e for e in self.entities] 54 | if self.entities is None: 55 | self._entities = [] 56 | 57 | def load_models(self): 58 | """ Load TARS model if its not initialized""" 59 | if not self.model: 60 | from flair.models import TARSTagger 61 | self.model = TARSTagger.load('tars-ner') 62 | 63 | if not self.entities: 64 | self.model.switch_to_task(self.default_entities) 65 | self.task = self.default_entities 66 | else: 67 | self.flat_entities() 68 | self.task = f'zshot.ner.{hash(tuple(self.entities))}' 69 | self.model.add_and_switch_to_new_task(self.task, 70 | self.entities, label_type='ner') 71 | 72 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 73 | """ 74 | Perform the entity prediction 75 | :param docs: A list of spacy Document 76 | :param batch_size: The batch size 77 | :return: List Spans for each Document in docs 78 | """ 79 | from flair.data import Sentence 80 | 81 | self.load_models() 82 | 83 | sentences = [ 84 | Sentence(str(doc), use_tokenizer=True) for doc in docs 85 | ] 86 | 87 | spans_annotations = tars_predict(self.model, sentences, batch_size) 88 | 89 | return spans_annotations 90 | -------------------------------------------------------------------------------- /zshot/tests/knowledge_extractor/test_knowgl_knowledge_extractor.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import shutil 4 | from pathlib import Path 5 | 6 | import pytest 7 | import spacy 8 | from transformers import AutoTokenizer 9 | 10 | from zshot import PipelineConfig 11 | from zshot.knowledge_extractor import KnowGL 12 | from zshot.knowledge_extractor.knowgl.utils import ranges, find_sub_list, get_words_mappings, get_spans, get_triples 13 | from zshot.tests.config import TEXTS 14 | from zshot.utils.data_models import Span, Relation 15 | from zshot.utils.data_models.relation_span import RelationSpan 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | @pytest.fixture(scope="module", autouse=True) 21 | def teardown(): 22 | logger.warning("Starting regen tests") 23 | yield True 24 | logger.warning("Removing cache") 25 | shutil.rmtree(f"{Path.home()}/.cache/huggingface", ignore_errors=True) 26 | shutil.rmtree(f"{Path.home()}/.cache/zshot", ignore_errors=True) 27 | gc.collect() 28 | 29 | 30 | def test_knowgl_knowledge_extractor(): 31 | nlp = spacy.blank("en") 32 | config = PipelineConfig( 33 | knowledge_extractor=KnowGL() 34 | ) 35 | nlp.add_pipe("zshot", config=config, last=True) 36 | assert "zshot" in nlp.pipe_names 37 | 38 
| doc = nlp(TEXTS[0]) 39 | assert len(doc.ents) > 0 40 | assert len(doc._.spans) > 0 41 | assert len(doc._.relations) > 0 42 | doc = nlp("") 43 | assert len(doc.ents) == 0 44 | assert len(doc._.spans) == 0 45 | assert len(doc._.relations) == 0 46 | docs = [doc for doc in nlp.pipe(TEXTS)] 47 | assert all(len(doc.ents) > 0 for doc in docs) 48 | assert all(len(doc._.spans) > 0 for doc in docs) 49 | assert all(len(doc._.relations) > 0 for doc in docs) 50 | nlp.remove_pipe('zshot') 51 | del doc, nlp, config 52 | 53 | 54 | def test_ranges(): 55 | numbers = [0, 1, 2, 3, 7] 56 | assert list(ranges(numbers)) == [[0, 1, 2, 3], [7]] 57 | 58 | 59 | def test_find_sub_list(): 60 | numbers = [0, 1, 2, 3, 7] 61 | sl = [1, 2, 3] 62 | results = find_sub_list(sl, numbers) 63 | assert type(results) is list 64 | init, end = results[0] 65 | assert init == 1 and end == 3 66 | 67 | 68 | def test_get_spans(): 69 | tokenizer = AutoTokenizer.from_pretrained("ibm/knowgl-large") 70 | input_data = tokenizer(TEXTS, 71 | truncation=True, 72 | padding=True, 73 | return_tensors="pt") 74 | words_mapping, char_mapping = get_words_mappings(input_data.encodings[0], TEXTS[0]) 75 | assert words_mapping and char_mapping 76 | spans = get_spans("LICIACube", "CubeSat", tokenizer, input_data.encodings[0], 77 | words_mapping, char_mapping) 78 | assert spans == [Span(78, 87, 'CubeSat')] 79 | 80 | words_mapping, char_mapping = get_words_mappings(input_data.encodings[1], TEXTS[1]) 81 | assert words_mapping and char_mapping 82 | spans = get_spans("CH2O2", "CH2O2", tokenizer, input_data.encodings[1], 83 | words_mapping, char_mapping) 84 | assert spans == [Span(0, 5, 'CH2O2')] 85 | 86 | 87 | def test_get_triples(): 88 | s1 = Span(78, 87, 'CubeSat') 89 | o1 = Span(41, 48, 'small satellite') 90 | o2 = Span(30, 38, 'Satellite') 91 | objects = [o1, o2] 92 | rel = "instance of" 93 | triples = get_triples([s1], rel, objects) 94 | assert len(triples) == 2 95 | for obj, triple in zip(objects, triples): 96 | assert triple[0] == s1 and triple[2] == obj 97 | assert triple[1] == RelationSpan(start=s1, end=obj, relation=Relation(name=rel, description="")) 98 | -------------------------------------------------------------------------------- /zshot/knowledge_extractor/knowgl/knowledge_extractor_knowgl.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Iterator, Optional, Union 2 | 3 | from spacy.tokens import Doc 4 | from tokenizers import Encoding 5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 6 | 7 | from zshot.knowledge_extractor.knowgl.utils import get_words_mappings, get_spans, get_triples 8 | from zshot.knowledge_extractor.knowledge_extractor import KnowledgeExtractor 9 | from zshot.utils.data_models import Span 10 | from zshot.utils.data_models.relation_span import RelationSpan 11 | 12 | 13 | class KnowGL(KnowledgeExtractor): 14 | def __init__(self, model_name="ibm/knowgl-large"): 15 | """ Instantiate the KnowGL Knowledge Extractor """ 16 | super().__init__() 17 | 18 | self.model_name = model_name 19 | self.model = None 20 | self.tokenizer = None 21 | 22 | def load_models(self): 23 | """ Load KnowGL model """ 24 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 25 | self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name) 26 | self.model.to(self.device) 27 | 28 | def parse_result(self, result: str, doc: Doc, 29 | encodings: Encoding) -> List[Tuple[Span, RelationSpan, Span]]: 30 | """ Parse the text result into a list of triples 31 | 
32 | :param result: Text generate by the KnowGL model 33 | :param doc: Spacy doc 34 | :param encodings: Encodings result of the tokenization 35 | :return: List of triples (subject, relation, object) 36 | """ 37 | words_mapping, char_mapping = get_words_mappings(encodings, doc.text) 38 | triples = [] 39 | for triple in result.split("$"): 40 | subject_, relation, object_ = triple.split("|") 41 | s_mention, s_label, s_type = subject_.strip("[()]").split("#") 42 | o_mention, o_label, o_type = object_.strip("[()]").split("#") 43 | s_type = s_label if s_label != "None" else s_type 44 | o_type = o_label if o_label != "None" else o_type 45 | subject_spans = get_spans(s_mention, s_type, self.tokenizer, encodings, 46 | words_mapping, char_mapping) 47 | object_spans = get_spans(o_mention, o_type, self.tokenizer, encodings, 48 | words_mapping, char_mapping) 49 | triples += get_triples(subject_spans, relation, object_spans) 50 | 51 | return triples 52 | 53 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) \ 54 | -> List[List[Tuple[Span, RelationSpan, Span]]]: 55 | """ Extract triples from docs 56 | 57 | :param docs: Spacy Docs to process 58 | :param batch_size: Batch size for processing 59 | :return: Triples (subject, relation, object) extracted for each document 60 | """ 61 | if not self.model: 62 | self.load_models() 63 | 64 | texts = [d.text for d in docs] 65 | input_data = self.tokenizer(texts, 66 | truncation=True, 67 | padding=True, 68 | return_tensors="pt") 69 | input_ids = input_data.input_ids.to(self.model.device) 70 | outputs = self.model.generate(inputs=input_ids) 71 | 72 | triples = [] 73 | for doc, output, encodings in zip(docs, outputs, input_data.encodings): 74 | result = self.tokenizer.decode(token_ids=output, skip_special_tokens=True) 75 | triples.append(self.parse_result(result, doc, encodings)) 76 | 77 | return triples 78 | -------------------------------------------------------------------------------- /zshot/utils/displacy/displacy.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Union, Iterable, Any, Optional 3 | 4 | from spacy import displacy as s_displacy 5 | from spacy.errors import Warnings 6 | from spacy.tokens import Doc 7 | from spacy.util import is_in_jupyter 8 | 9 | from zshot.utils.displacy.colors import light_color_from_label 10 | from zshot.utils.displacy.relations_render import RelationsRenderer, parse_rels 11 | 12 | 13 | def ents_colors(docs: Union[Iterable[Union[Doc]], Doc]): 14 | """ 15 | Can be used to derive colors for entities in a Spacy document. 
16 | A color for each entity type is generated, using the entity label hash 17 | :param docs: A list of Spacy document with entities 18 | :return: A colors dictionary containing a color for each entity type 19 | """ 20 | 21 | if isinstance(docs, Doc): 22 | docs = [docs] 23 | labels = set([ent.label_ for doc in docs for ent in doc.ents]) 24 | colors = dict([(ent, light_color_from_label(ent)) for ent in labels]) 25 | return colors 26 | 27 | 28 | class displacy: 29 | 30 | @staticmethod 31 | def render(docs: Union[Iterable[Union[Doc]], Doc], style: str = "dep", options: Dict = None, **kwargs) -> str: 32 | return displacy._call_displacy(docs, style, "render", options=options, **kwargs) 33 | 34 | @staticmethod 35 | def serve(docs: Union[Iterable[Union[Doc]], Doc], style: str = "dep", options: Dict = None, **kwargs): 36 | return displacy._call_displacy(docs, style, "serve", options=options, **kwargs) 37 | 38 | @staticmethod 39 | def _call_displacy(docs: Union[Iterable[Union[Doc]], Doc], style: str, method: str, options: Dict[str, Any] = {}, 40 | port: int = 5000, host: str = "0.0.0.0", page: bool = True, minify: bool = False, 41 | jupyter: Optional[bool] = None, 42 | **kwargs) -> str: 43 | if isinstance(docs, Doc): 44 | docs = [docs] 45 | if options is None: 46 | options = {} 47 | if style == "ent": 48 | options.update({'colors': ents_colors(docs)}) 49 | if style == "rel": 50 | re_renderer = RelationsRenderer(options=options) 51 | parsed = [parse_rels(doc) for doc in docs] 52 | html = re_renderer.render(parsed, page=page, minify=minify) 53 | s_displacy._html["parsed"] = html 54 | if "serve" in method: 55 | from wsgiref import simple_server 56 | if is_in_jupyter(): 57 | warnings.warn(Warnings.W011) 58 | httpd = simple_server.make_server(host=host, port=port, app=s_displacy.app) 59 | print(f"\nUsing the '{style}' visualizer") 60 | print(f"Serving on http://{host}:{port} ...\n") 61 | try: 62 | httpd.serve_forever() 63 | except KeyboardInterrupt: 64 | print(f"Shutting down server on port {port}.") 65 | finally: 66 | httpd.server_close() 67 | else: 68 | if jupyter or (jupyter is None and is_in_jupyter()): 69 | # return HTML rendered by IPython display() 70 | # See #4840 for details on span wrapper to disable mathjax 71 | from IPython.core.display import display, HTML 72 | return display(HTML('{}'.format(html))) 73 | return html 74 | 75 | if method == "render": 76 | kwargs.update({"jupyter": jupyter}) 77 | 78 | disp = getattr(s_displacy, method) 79 | return disp(docs, style=style, options=options, **kwargs) 80 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/mentions_extractor_tars.py: -------------------------------------------------------------------------------- 1 | import pkgutil 2 | from typing import Iterator, Optional, Union, List 3 | 4 | from spacy.tokens.doc import Doc 5 | 6 | from zshot.mentions_extractor.mentions_extractor import MentionsExtractor 7 | from zshot.utils.models.tars.utils import tars_predict 8 | from zshot.utils.data_models import Entity, Span 9 | 10 | 11 | class MentionsExtractorTARS(MentionsExtractor): 12 | """ TARS Mentions Extractor """ 13 | def __init__(self, default_entities: Optional[str] = "conll-short"): 14 | """ 15 | :param default_entities: Default entities to use in case no custom ones are set 16 | One of: 17 | - 'conll-short' 18 | - 'ontonotes-long' 19 | - 'ontonotes-short' 20 | - 'wnut_17-long' 21 | - 'wnut_17-short' 22 | """ 23 | super().__init__() 24 | if not pkgutil.find_loader("flair"): 25 | raise
Exception("Flair module not installed. You need to install Flair for using this class." 26 | "Install it with: pip install flair>=0.13") 27 | 28 | self.is_end2end = True 29 | self.default_entities = default_entities 30 | self.model = None 31 | self.task = None 32 | 33 | def set_kg(self, mentions: Iterator[Entity]): 34 | """ Set new entities in the model 35 | 36 | :param mentions: New entities to use 37 | """ 38 | old_entities = self._mentions 39 | super().set_kg(mentions) 40 | if old_entities != mentions: 41 | self.flat_entities() 42 | self.task = f'zshot.ner.{hash(tuple(self._mentions))}' 43 | if not self.model: 44 | self.load_models() 45 | self.model.add_and_switch_to_new_task(self.task, 46 | self._mentions, label_type='ner') 47 | 48 | def flat_entities(self): 49 | """ As TARS use only the labels, take just the name of the entities and not the description """ 50 | if isinstance(self._mentions, dict): 51 | self._mentions = list(self._mentions.keys()) 52 | if isinstance(self._mentions, list): 53 | self._mentions = [e.name if type(e) is Entity else e for e in self._mentions] 54 | if self._mentions is None: 55 | self._mentions = [] 56 | 57 | def load_models(self): 58 | """ Load TARS model if its not initialized""" 59 | if not self.model: 60 | from flair.models import TARSTagger 61 | self.model = TARSTagger.load('tars-ner') 62 | 63 | if not self.mentions: 64 | self.model.switch_to_task(self.default_entities) 65 | self.task = self.default_entities 66 | else: 67 | self.flat_entities() 68 | self.task = f'zshot.ner.{hash(tuple(self._mentions))}' 69 | self.model.add_and_switch_to_new_task(self.task, 70 | self._mentions, label_type='ner') 71 | 72 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 73 | """ 74 | Perform the entity prediction 75 | :param docs: A list of spacy Document 76 | :param batch_size: The batch size 77 | :return: List Spans for each Document in docs 78 | """ 79 | from flair.data import Sentence 80 | 81 | self.load_models() 82 | 83 | sentences = [ 84 | Sentence(str(doc), use_tokenizer=True) for doc in docs 85 | ] 86 | 87 | spans_annotations = tars_predict(self.model, sentences, batch_size) 88 | 89 | return spans_annotations 90 | -------------------------------------------------------------------------------- /zshot/tests/linker/test_regen_linker.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | 4 | import pytest 5 | import spacy 6 | 7 | from zshot import PipelineConfig 8 | from zshot.linker.linker_regen.linker_regen import LinkerRegen 9 | from zshot.linker.linker_regen.trie import Trie 10 | from zshot.linker.linker_regen.utils import load_wikipedia_trie, load_dbpedia_trie, create_input 11 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 12 | from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | @pytest.fixture(scope="module", autouse=True) 18 | def teardown(): 19 | logger.warning("Starting regen tests") 20 | yield True 21 | gc.collect() 22 | 23 | 24 | def test_regen_linker(): 25 | nlp = spacy.blank("en") 26 | config = PipelineConfig( 27 | mentions_extractor=DummyMentionsExtractor(), 28 | linker=LinkerRegen(), 29 | entities=EX_ENTITIES 30 | ) 31 | nlp.add_pipe("zshot", config=config, last=True) 32 | assert "zshot" in nlp.pipe_names 33 | 34 | doc = nlp(EX_DOCS[1]) 35 | assert len(doc.ents) > 0 36 | doc = nlp("") 37 | assert len(doc.ents) == 0 38 | docs 
= [doc for doc in nlp.pipe(EX_DOCS)] 39 | assert all(len(doc.ents) > 0 for doc in docs) 40 | del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp 41 | del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \ 42 | nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 43 | nlp.remove_pipe('zshot') 44 | del doc, nlp, config 45 | 46 | 47 | def test_regen_linker_wikification(): 48 | nlp = spacy.blank("en") 49 | trie = Trie() 50 | trie.add([794, 536, 1]) 51 | trie.add([794, 357, 1]) 52 | config = PipelineConfig( 53 | mentions_extractor=DummyMentionsExtractor(), 54 | linker=LinkerRegen(trie=trie), 55 | ) 56 | nlp.add_pipe("zshot", config=config, last=True) 57 | assert "zshot" in nlp.pipe_names 58 | 59 | doc = nlp(EX_DOCS[1]) 60 | assert len(doc.ents) > 0 61 | del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp 62 | del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \ 63 | nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker 64 | nlp.remove_pipe('zshot') 65 | del doc, nlp, config 66 | 67 | 68 | @pytest.mark.skip(reason="Too expensive to run on every commit") 69 | def test_load_wikipedia_trie(): # pragma: no cover 70 | trie = load_wikipedia_trie() 71 | assert len(list(trie.trie_dict.keys())) == 6952 72 | 73 | 74 | @pytest.mark.skip(reason="Too expensive to run on every commit") 75 | def test_load_dbpedia_trie(): # pragma: no cover 76 | trie = load_dbpedia_trie() 77 | assert len(list(trie.trie_dict.keys())) == 7156 78 | 79 | 80 | def test_create_input(): 81 | start_delimiter = "[START]" 82 | end_delimiter = "[END]" 83 | max_length = 10 84 | 85 | times_rep = 6 86 | sentence = "[START]" + " test" * times_rep + " [END]" 87 | input_sentence = create_input(sentence, max_length, start_delimiter, end_delimiter) 88 | assert input_sentence == sentence 89 | times_rep = 12 90 | sentence = "[START]" + " test" * times_rep + " [END]" 91 | input_sentence = create_input(sentence, max_length, start_delimiter, end_delimiter) 92 | assert input_sentence == " ".join(["test" for i in range(9)]) 93 | 94 | text = f"IBM headquarters are located in {start_delimiter} New York {end_delimiter} ." 
95 | input_ = create_input(text, max_length=4, start_delimiter=start_delimiter, end_delimiter=end_delimiter) 96 | assert start_delimiter in input_ and end_delimiter in input_ 97 | -------------------------------------------------------------------------------- /zshot/tests/utils/test_description_enrichment.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import spacy 3 | 4 | from zshot import PipelineConfig 5 | from zshot.linker import LinkerSMXM 6 | from zshot.utils.data_models import Entity 7 | from zshot.utils.enrichment.description_enrichment import PreTrainedLMExtensionStrategy, \ 8 | FineTunedLMExtensionStrategy, SummarizationStrategy, ParaphrasingStrategy, EntropyHeuristic 9 | 10 | 11 | @pytest.mark.skip(reason="Too expensive to run on every commit") 12 | def test_pretrained_lm_extension_strategy(): 13 | description = "The name of a company" 14 | strategy = PreTrainedLMExtensionStrategy() 15 | num_variations = 3 16 | 17 | desc_variations = strategy.alter_description( 18 | description, num_variations=num_variations 19 | ) 20 | 21 | assert len(desc_variations) == 3 and len(set(desc_variations + [description])) == 4 22 | 23 | 24 | @pytest.mark.skip(reason="Too expensive to run on every commit") 25 | def test_finetuned_lm_extension_strategy(): 26 | description = "The name of a company" 27 | strategy = FineTunedLMExtensionStrategy() 28 | num_variations = 3 29 | 30 | desc_variations = strategy.alter_description( 31 | description, num_variations=num_variations 32 | ) 33 | 34 | assert len(desc_variations) == 3 and len(set(desc_variations + [description])) == 4 35 | 36 | 37 | @pytest.mark.skip(reason="Too expensive to run on every commit") 38 | def test_summarization_strategy(): 39 | description = "The name of a company" 40 | strategy = SummarizationStrategy() 41 | num_variations = 3 42 | 43 | desc_variations = strategy.alter_description( 44 | description, num_variations=num_variations 45 | ) 46 | 47 | assert len(desc_variations) == 3 and len(set(desc_variations + [description])) == 4 48 | 49 | 50 | @pytest.mark.skip(reason="Too expensive to run on every commit") 51 | def test_paraphrasing_strategy(): 52 | description = "The name of a company" 53 | strategy = ParaphrasingStrategy() 54 | num_variations = 3 55 | 56 | desc_variations = strategy.alter_description( 57 | description, num_variations=num_variations 58 | ) 59 | 60 | assert len(desc_variations) == 3 and len(set(desc_variations + [description])) == 4 61 | 62 | 63 | @pytest.mark.skip(reason="Too expensive to run on every commit") 64 | def test_entropy_heuristic(): 65 | def check_is_tuple(x): 66 | return isinstance(x, tuple) and len(x) == 2 and isinstance(x[0], str) and isinstance(x[1], float) 67 | 68 | entropy_heuristic = EntropyHeuristic() 69 | dataset = [ 70 | {'tokens': ['IBM', 'headquarters', 'are', 'located', 'in', 'Armonk', '.'], 71 | 'ner_tags': ['B-company', 'O', 'O', 'O', 'O', 'B-location', 'O']} 72 | ] 73 | entities = [ 74 | Entity(name="company", description="The name of a company"), 75 | Entity(name="location", description="A physical location"), 76 | ] 77 | 78 | nlp = spacy.blank("en") 79 | nlp_config = PipelineConfig( 80 | linker=LinkerSMXM(), 81 | entities=entities 82 | ) 83 | nlp.add_pipe("zshot", config=nlp_config, last=True) 84 | strategy = ParaphrasingStrategy() 85 | num_variations = 3 86 | 87 | variations = entropy_heuristic.evaluate_variations_strategy(dataset, 88 | entities=entities, 89 | alter_strategy=strategy, 90 | num_variations=num_variations, 91 | 
nlp_pipeline=nlp) 92 | 93 | assert len(variations) == 2 94 | assert len(variations[0]) == 3 and len(variations[1]) == 3 95 | assert all([check_is_tuple(x) for x in variations[0]]) 96 | assert all([check_is_tuple(x) for x in variations[1]]) 97 | -------------------------------------------------------------------------------- /zshot/relation_extractor/zsrc/zero_shot_rel_class.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers import BertConfig, BertModel, BertPreTrainedModel 7 | 8 | from zshot.config import MODELS_CACHE_PATH 9 | from zshot.utils.file_utils import download_file 10 | 11 | MODEL_REMOTE_URL = 'https://huggingface.co/albep/zsrc/resolve/main/zsrc' 12 | MODEL_PATH = os.path.join(MODELS_CACHE_PATH, 'zsrc') 13 | 14 | 15 | def load_model(device: Optional[Union[str, torch.device]] = None): 16 | model = ZSBert(device) 17 | if not os.path.isfile(MODEL_PATH): 18 | download_file(MODEL_REMOTE_URL, MODELS_CACHE_PATH) 19 | 20 | model.load_state_dict(torch.load(MODEL_PATH), strict=False) 21 | model.to(device) 22 | model.eval() 23 | return model 24 | 25 | 26 | class ZSBert(BertPreTrainedModel): 27 | def __init__(self, device: Optional[Union[str, torch.device]] = 'cpu'): 28 | bertconfig = BertConfig.from_pretrained('bert-large-cased', num_labels=2, finetuning_task='fewrel-zero-shot', 29 | device=device) 30 | bertconfig.relation_emb_dim = 1024 31 | super().__init__(bertconfig) 32 | self.bert = BertModel(bertconfig) 33 | self.num_labels = 2 34 | self.relation_emb_dim = 1024 35 | self.dropout = nn.Dropout(bertconfig.hidden_dropout_prob) 36 | self.fclayer = nn.Linear(bertconfig.hidden_size * 3, self.relation_emb_dim) 37 | self.classifier = nn.Linear( 38 | self.relation_emb_dim, bertconfig.num_labels) 39 | self.batch_size = 4 40 | self.init_weights() 41 | self.bert.to(device) 42 | 43 | def forward( 44 | self, 45 | input_ids=None, 46 | attention_mask=None, 47 | token_type_ids=None, 48 | position_ids=None, 49 | e1_mask=None, 50 | e2_mask=None, 51 | head_mask=None, 52 | inputs_embeds=None, 53 | labels=None, 54 | ): 55 | 56 | outputs = self.bert( 57 | input_ids, 58 | attention_mask=attention_mask, 59 | token_type_ids=token_type_ids, 60 | position_ids=position_ids, 61 | head_mask=head_mask, 62 | inputs_embeds=inputs_embeds, 63 | ) 64 | 65 | # Sequence of hidden-states of the last layer. 66 | sequence_output = outputs[0] 67 | # Last layer hidden-state of the [CLS] token further processed 68 | pooled_output = outputs[1] 69 | # by a Linear layer and a Tanh activation function. 
70 | 71 | def extract_entity(sequence_output, e_mask): 72 | extended_e_mask = e_mask.unsqueeze(1) 73 | extended_e_mask = torch.bmm( 74 | extended_e_mask.float(), sequence_output).squeeze(1) 75 | return extended_e_mask.float() 76 | 77 | e1_h = extract_entity(sequence_output, e1_mask) 78 | e2_h = extract_entity(sequence_output, e2_mask) 79 | context = self.dropout(pooled_output) 80 | pooled_output = torch.cat([context, e1_h, e2_h], dim=-1) 81 | pooled_output = torch.tanh(pooled_output) 82 | pooled_output = self.fclayer(pooled_output) 83 | sent_embedding = torch.tanh(pooled_output) 84 | sent_embedding = self.dropout(sent_embedding) 85 | 86 | # [batch_size x hidden_size] 87 | logits = self.classifier(sent_embedding).to(self.bert.device) 88 | # add hidden states and attention if they are here 89 | 90 | outputs = (torch.softmax(logits, -1),) + outputs[2:] 91 | if labels is not None: 92 | ce_loss = nn.CrossEntropyLoss() 93 | labels = labels.to(self.bert.device) 94 | loss = (ce_loss(logits.view(-1, self.num_labels), labels.view(-1))) 95 | outputs = (loss,) + outputs 96 | 97 | return outputs 98 | -------------------------------------------------------------------------------- /zshot/linker/linker_ensemble/linker_ensemble.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional, List 2 | 3 | from spacy.tokens import Doc 4 | 5 | from zshot.linker import Linker 6 | from zshot.linker import LinkerSMXM 7 | from zshot.utils.ensembler import Ensembler 8 | from zshot.utils.data_models import Entity 9 | from zshot.linker.linker_ensemble.utils import sub_span_scoring_per_description, get_enhance_entities 10 | 11 | 12 | class LinkerEnsemble(Linker): 13 | def __init__(self, 14 | linkers: Optional[List[Linker]] = None, 15 | strategy: Optional[str] = 'max', 16 | threshold: Optional[float] = 0.5): 17 | """ Ensemble of linkers and entities to improve performance. 18 | Each combination of linker with entity will be a voter. 19 | 20 | :param linkers: Linkers to use in the ensemble 21 | :param strategy: Strategy to use. Options: max; count 22 | When `max` choose the label with max total vote score 23 | When `count` choose the label with max total vote count 24 | :param threshold: Threshold to use. 
Proportion of voters voting the entity 25 | """ 26 | super(LinkerEnsemble, self).__init__() 27 | if linkers is not None: 28 | self.linkers = linkers 29 | else: 30 | # default options 31 | self.linkers = [ 32 | LinkerSMXM() 33 | ] 34 | self.enhance_entities = [] 35 | self.strategy = strategy 36 | self.threshold = threshold 37 | self.ensembler = None 38 | 39 | def set_smxm_model(self, smxm_model): 40 | for linker in self.linkers: 41 | if isinstance(linker, LinkerSMXM): 42 | linker.model_name = smxm_model 43 | 44 | def set_kg(self, entities: Iterator[Entity]): 45 | """ 46 | Set entities that linker can use 47 | :param entities: The list of entities 48 | """ 49 | super().set_kg(entities) 50 | self.enhance_entities = get_enhance_entities(self.entities) 51 | self.ensembler = Ensembler(len(self.linkers), 52 | len(self.enhance_entities) if self.enhance_entities is not None else -1, 53 | threshold=self.threshold) 54 | for linker in self.linkers: 55 | linker.set_kg(entities) 56 | 57 | def predict(self, docs: Iterator[Doc], batch_size=None): 58 | """ 59 | Perform the entity prediction 60 | :param docs: A list of spacy Document 61 | :param batch_size: The batch size 62 | :return: List Spans for each Document in docs 63 | """ 64 | spans = [] 65 | for entities in self.enhance_entities: 66 | self.set_kg(entities) 67 | for linker in self.linkers: 68 | span_prediction = linker.predict(docs, batch_size) 69 | spans.append(span_prediction) 70 | 71 | return self.prediction_ensemble(spans) 72 | 73 | def prediction_ensemble(self, spans): 74 | doc_ensemble_spans = [] 75 | num_doc = len(spans[0]) 76 | for doc_idx in range(num_doc): 77 | union_spans = {} 78 | span_per_descriptions = [] 79 | for span in spans: 80 | span_per_descriptions.append(span[doc_idx]) 81 | for s in span[doc_idx]: 82 | span_pos = (s.start, s.end) 83 | if span_pos not in union_spans: 84 | union_spans[span_pos] = [s] 85 | else: 86 | union_spans[span_pos].append(s) 87 | 88 | sub_span_scoring_per_description(union_spans, span_per_descriptions) 89 | all_union_spans = self.ensembler.ensemble(union_spans) 90 | doc_ensemble_spans.append(all_union_spans) 91 | 92 | return doc_ensemble_spans 93 | -------------------------------------------------------------------------------- /zshot/utils/models/smxm/model.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from transformers import BertModel, BertPreTrainedModel, logging 5 | 6 | logging.set_verbosity_error() 7 | 8 | 9 | class BertTaggerMultiClass(BertPreTrainedModel): 10 | def __init__(self, config): 11 | 12 | super().__init__(config) 13 | 14 | self.bert = BertModel(config) 15 | self.drop = torch.nn.Dropout(config.finetuning_task["dropout_prob"]) 16 | self.bert_output_size = config.hidden_size 17 | self.linear = torch.nn.Linear(self.bert_output_size, 1) 18 | self.linear_zero2 = torch.nn.Linear(self.bert_output_size, 1) 19 | self.linear_zero3 = torch.nn.Linear(self.bert_output_size, 1) 20 | 21 | self.init_weights() 22 | 23 | def forward( 24 | self, 25 | *args, 26 | input_ids, 27 | attention_mask, 28 | token_type_ids, 29 | sep_index, 30 | seq_mask, 31 | split, 32 | **kwargs, 33 | ): 34 | sep_index_max = torch.max(sep_index) 35 | predictions = [] 36 | predictions_zero = [] 37 | predictions_zero_base = [] 38 | for j in range(input_ids.size(0)): 39 | if j == 0: 40 | inp_zero = torch.stack( 41 | [ 42 | input_ids[j][i, : sep_index_max.item()] 43 | for i in range(sep_index.size(0)) 44 | ] 45 | ) 46 | att_zero = torch.stack( 47 | [ 48 | 
attention_mask[j][i, : sep_index_max.item()] 49 | for i in range(sep_index.size(0)) 50 | ] 51 | ) 52 | tok_type_zero = torch.stack( 53 | [ 54 | token_type_ids[j][i, : sep_index_max.item()] 55 | for i in range(sep_index.size(0)) 56 | ] 57 | ) 58 | with torch.no_grad(): 59 | words_out = self.bert( 60 | input_ids=inp_zero, 61 | attention_mask=att_zero, 62 | token_type_ids=tok_type_zero, 63 | )[0] 64 | 65 | pooled_out = self.drop(words_out) 66 | logits = self.linear_zero3(pooled_out) 67 | predictions_zero_base.append(logits) 68 | else: 69 | with torch.no_grad(): 70 | words_out = self.bert( 71 | input_ids=input_ids[j], 72 | attention_mask=attention_mask[j], 73 | token_type_ids=token_type_ids[j], 74 | )[0] 75 | 76 | words_out = torch.stack( 77 | [ 78 | words_out[i, : sep_index_max.item(), :] 79 | for i in range(words_out.size(0)) 80 | ] 81 | ) 82 | pooled_out = self.drop(words_out) 83 | predictions_zero.append(self.linear_zero2(pooled_out)) 84 | logits = self.linear(pooled_out) 85 | predictions.append(logits) 86 | 87 | random.shuffle(predictions_zero) 88 | predictions_zero = torch.stack(predictions_zero_base + predictions_zero) 89 | predictions_zero = predictions_zero.transpose(0, 1).transpose(1, 2) 90 | predictions_zero = predictions_zero.contiguous().view( 91 | predictions_zero.size(0), predictions_zero.size(1), -1 92 | ) 93 | 94 | predictions_zero = torch.max(predictions_zero, dim=2)[0].unsqueeze(2) 95 | 96 | predictions = torch.stack(predictions) 97 | predictions = predictions.transpose(0, 1).transpose(1, 2).squeeze(3) 98 | 99 | logits = torch.cat((predictions_zero, predictions), dim=2) 100 | 101 | return logits 102 | -------------------------------------------------------------------------------- /zshot/linker/linker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | import zlib 4 | from abc import ABC, abstractmethod 5 | from typing import Iterator, List, Optional, Union 6 | 7 | import torch 8 | from spacy.tokens import Doc 9 | from spacy.util import ensure_path 10 | 11 | from zshot.utils.data_models import Entity, Span 12 | from zshot.utils.alignment_utils import filter_overlapping_spans, spacy_token_offsets 13 | 14 | 15 | class Linker(ABC): 16 | """ 17 | Linker define a standard interface for entity linking. 
A Linker may relay on existing 18 | extracted mentions or perform end-2-end extraction 19 | """ 20 | 21 | def __init__(self, device: Optional[Union[str, torch.device]] = None): 22 | self._entities = None 23 | self._is_end2end = False 24 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 25 | 26 | def set_device(self, device: Union[str, torch.device]): 27 | """ 28 | Set the device to use 29 | :param device: 30 | :return: 31 | """ 32 | self.device = device 33 | 34 | def set_kg(self, entities: Iterator[Entity]): 35 | """ 36 | Set entities that linker can use 37 | :param entities: The list of entities 38 | """ 39 | self._entities = entities 40 | 41 | @property 42 | def entities(self) -> List[Entity]: 43 | """ Entities to link to """ 44 | return self._entities 45 | 46 | def load_models(self): 47 | """ 48 | Load the model 49 | :return: 50 | """ 51 | pass 52 | 53 | @abstractmethod 54 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) -> List[List[Span]]: 55 | """ 56 | Perform the entity prediction 57 | :param docs: A list of spacy Document 58 | :param batch_size: The batch size 59 | :return: List Spans for each Document in docs 60 | """ 61 | pass 62 | 63 | @property 64 | def is_end2end(self) -> bool: 65 | return self._is_end2end 66 | 67 | @is_end2end.setter 68 | def is_end2end(self, value): 69 | self._is_end2end = value 70 | 71 | @staticmethod 72 | def version() -> str: 73 | return "v1" 74 | 75 | def link(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None): 76 | """ 77 | Perform the entity linking. Call the predict function and add entities to the Spacy Docs 78 | :param docs: A list of spacy Document 79 | :param batch_size: The batch size 80 | :return: 81 | """ 82 | predictions_spans = self.predict(docs, batch_size) 83 | 84 | for d, preds in zip(docs, predictions_spans): 85 | d._.spans = preds 86 | d.ents = map(lambda p: p.to_spacy_span(d), filter_overlapping_spans(preds, list(d), 87 | tokens_offsets=spacy_token_offsets(d))) 88 | # d.spans = map(lambda p: p.to_spacy_span(d), preds) 89 | 90 | @staticmethod 91 | def _get_serialize_file(path): 92 | return os.path.join(path, "linker.pkl") 93 | 94 | @staticmethod 95 | def _get_config_file(path): 96 | path = os.path.join(path, "linker.json") 97 | path = ensure_path(path) 98 | return path 99 | 100 | @classmethod 101 | def from_disk(cls, path, exclude=()): 102 | serialize_file = cls._get_serialize_file(path) 103 | with open(serialize_file, "rb") as f: 104 | return pkl.load(f) 105 | 106 | def to_disk(self, path): 107 | serialize_file = self._get_serialize_file(path) 108 | with open(serialize_file, "wb") as f: 109 | return pkl.dump(self, f) 110 | 111 | def __hash__(self): 112 | self_repr = f"{self.__class__.__name__}.{self.version()}.{str(self.__dict__)}" 113 | return zlib.crc32(self_repr.encode()) 114 | -------------------------------------------------------------------------------- /zshot/tests/config.py: -------------------------------------------------------------------------------- 1 | from zshot.utils.data_models import Entity, Relation 2 | 3 | EX_DOCS = ["The Domain Name System (DNS) is the hierarchical and decentralized naming system used to identify" 4 | " computers, services, and other resources reachable through the Internet or other Internet Protocol" 5 | " (IP) networks.", 6 | "International Business Machines Corporation (IBM) is an American multinational technology corporation" 7 | " headquartered in Armonk, New York, with 
operations in over 171 countries."] 8 | 9 | EX_ENTITIES = \ 10 | [ 11 | Entity(name="apple", description="the apple fruit"), 12 | Entity(name="DNS", description="domain name system", vocabulary=["DNS", "Domain Name System"]), 13 | Entity(name="IBM", description="technology corporation", vocabulary=["IBM", "International Business machine"]), 14 | Entity(name="NYC", description="New York city"), 15 | Entity(name="Florida", description="southeasternmost U.S. state"), 16 | Entity(name="Paris", description="Paris is located in northern central France, " 17 | "in a north-bending arc of the river Seine"), 18 | ] 19 | 20 | EX_RELATIONS = \ 21 | [ 22 | Relation(name="parent", description="Is the parent of someone"), 23 | Relation(name="child", description="Is the child of someone"), 24 | Relation(name="sibling", description="Is the sibling of someone"), 25 | Relation(name='crosses', 26 | description='obstacle (body of water, road, ...) ' 27 | 'which this bridge crosses over or this tunnel goes under'), 28 | ] 29 | 30 | EX_DATASET_RELATIONS = { 31 | 'sentences': [ 32 | 'In June 1987 , the Missouri Highway and Transportation Department approved design location of a ' 33 | 'new four - lane Mississippi River bridge to replace the deteriorating Cape Girardeau Bridge .', 34 | 'Wilton Bridge was a major crossing of the River Wye and was protected by Wilton Castle .'], 35 | 'sentence_entities': [[{'end': 187, 36 | 'label': 'Q5034838', 37 | 'sentence': 'In June 1987 , the Missouri Highway and Transportation Department ' 38 | 'approved design location of a new four - lane ' 39 | 'Mississippi River bridge to replace the deteriorating Cape Girardeau Bridge .', 40 | 'start': 165}, 41 | {'end': 129, 42 | 'label': 'Q1497', 43 | 'sentence': 'In June 1987 , the Missouri Highway and Transportation Department ' 44 | 'approved design location of a new four - lane Mississippi River ' 45 | 'bridge to replace the deteriorating Cape Girardeau Bridge .', 46 | 'start': 111}], 47 | [{'end': 13, 48 | 'label': 'Q8023362', 49 | 'sentence': 'Wilton Bridge was a major crossing of the River Wye ' 50 | 'and was protected by Wilton Castle .', 51 | 'start': 0}, 52 | {'end': 55, 53 | 'label': 'Q19695', 54 | 'sentence': 'Wilton Bridge was a major crossing of the River Wye ' 55 | 'and was protected by Wilton Castle .', 56 | 'start': 45}]], 57 | 'labels': ['crosses', 'crosses']} 58 | 59 | TEXT_CUBESAT = "The Italian Space Agency’s Light Italian CubeSat for Imaging of Asteroids, or LICIACube, will fly by Dimorphos to capture images and video of the impact plume as it sprays up off the asteroid and maybe even spy the crater it could leave behind." 60 | TEXT_ACETAMIDE = "CH2O2 is a chemical compound similar to Acetamide used in International Business Machines Corporation (IBM)." 
61 | TEXTS = [TEXT_CUBESAT, TEXT_ACETAMIDE] 62 | -------------------------------------------------------------------------------- /zshot/relation_extractor/relation_extractor_zsrc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | from zshot.relation_extractor.relations_extractor import RelationsExtractor 5 | from zshot.relation_extractor.zsrc import data_helper 6 | from zshot.relation_extractor.zsrc.zero_shot_rel_class import load_model 7 | import numpy as np 8 | from tqdm import tqdm 9 | from typing import Iterator, List 10 | from spacy.tokens import Doc 11 | 12 | from zshot.utils.data_models.relation_span import RelationSpan 13 | 14 | 15 | class RelationsExtractorZSRC(RelationsExtractor): 16 | def __init__(self, thr=0.5): 17 | super().__init__() 18 | self.model = None 19 | self.load_models() 20 | self.thr = thr 21 | 22 | def load_models( 23 | self, 24 | ): 25 | if self.model is None: 26 | self.model = load_model(self.device) 27 | 28 | def predict(self, docs: Iterator[Doc], batch_size=None) -> List[List[RelationSpan]]: 29 | relations_pred = [] 30 | for doc in tqdm(docs, desc='classifying documents'): 31 | relations_doc = [] 32 | items_to_process = [] 33 | for i, e1 in enumerate(doc._.spans): 34 | for j, e2 in enumerate(doc._.spans): 35 | if ( 36 | i == j or (e1, e2) in items_to_process or ( 37 | e2, e1) in items_to_process 38 | ): 39 | continue 40 | else: 41 | items_to_process.append((e1, e2)) 42 | 43 | relations_probs = [] 44 | if self.relations is not None: 45 | for rel in self.relations: 46 | _, probs = self._predict_internal( 47 | [(e1, e2, doc.text)], 48 | rel.description, 49 | batch_size, 50 | ) 51 | relations_probs.append(probs[0]) 52 | pred_class_idx = np.argmax(np.array(relations_probs)) 53 | p = relations_probs[pred_class_idx] 54 | if p >= self.thr: 55 | relations_doc.append( 56 | RelationSpan( 57 | start=e1, end=e2, score=p, relation=self.relations[pred_class_idx]) 58 | ) 59 | relations_pred.append(relations_doc) 60 | return relations_pred 61 | 62 | def _predict_internal(self, items_to_process, relation_description, batch_size=4): 63 | trainset = data_helper.ZSDataset( 64 | 'test', items_to_process, relation_description) 65 | trainloader = DataLoader(trainset, batch_size=batch_size, 66 | collate_fn=data_helper.create_mini_batch_fewrel_aio, shuffle=False) 67 | all_preds = [] 68 | all_probs = [] 69 | for data in trainloader: 70 | tokens_tensors, segments_tensors, marked_e1, marked_e2, masks_tensors, labels = [ 71 | t.to(self.device) for t in data] 72 | if tokens_tensors.shape[1] <= 512: 73 | with torch.no_grad(): 74 | outputs = self.model(input_ids=tokens_tensors, 75 | token_type_ids=segments_tensors, 76 | e1_mask=marked_e1, 77 | e2_mask=marked_e2, 78 | attention_mask=masks_tensors, 79 | labels=labels) 80 | preds = outputs[1] 81 | probs = preds.detach().cpu().numpy()[:, 1] 82 | all_probs.extend(probs) 83 | all_preds.extend([item >= 0.5 for item in probs]) 84 | else: 85 | all_probs.extend([-1] * tokens_tensors.shape[0]) 86 | all_preds.extend([False] * tokens_tensors.shape[0]) 87 | 88 | return all_preds, all_probs 89 | -------------------------------------------------------------------------------- /zshot/knowledge_extractor/knowledge_extractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle as pkl 3 | from abc import ABC, abstractmethod 4 | from typing import List, Iterator, Optional, Union, Tuple 5 | 6 | 
import torch 7 | import zlib 8 | from spacy.tokens import Doc 9 | 10 | from zshot.utils.alignment_utils import filter_overlapping_spans, spacy_token_offsets 11 | from zshot.utils.data_models import Span 12 | from zshot.utils.data_models.relation_span import RelationSpan 13 | 14 | 15 | class KnowledgeExtractor(ABC): 16 | 17 | def __init__(self, device: Optional[Union[str, torch.device]] = None): 18 | """ Instantiate the Knowledge Extractor 19 | 20 | :param device: Device to be used for computation 21 | """ 22 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device 23 | 24 | def set_device(self, device: Union[str, torch.device]): 25 | """ 26 | Set the device to use 27 | :param device: 28 | :return: 29 | """ 30 | self.device = device 31 | 32 | def load_models(self): 33 | """ 34 | Load the model 35 | :return: 36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | def predict(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None) \ 41 | -> List[List[Tuple[Span, RelationSpan, Span]]]: 42 | """ 43 | Perform the knowledge extraction. 44 | :param docs: A list of spacy Document 45 | :param batch_size: The batch size 46 | :return: the predicted triples 47 | """ 48 | pass 49 | 50 | def parse_triples(self, preds: List[Tuple[Span, RelationSpan, Span]]) -> Tuple[List[Span], List[RelationSpan]]: 51 | """ Parse the triples into lists of entities and relations 52 | 53 | :param preds: Predicted triples 54 | :return: Tuple with list of entities and list of relations 55 | """ 56 | entities = [] 57 | relations = [] 58 | for triple in preds: 59 | entities.append(triple[0]) 60 | entities.append(triple[2]) 61 | relations.append(triple[1]) 62 | 63 | return list(set(entities)), list(set(relations)) 64 | 65 | def extract_knowledge(self, docs: Iterator[Doc], batch_size: Optional[Union[int, None]] = None): 66 | """ 67 | Perform the relations extraction. 
Call the predict function and add the mentions to the Spacy Doc 68 | :param docs: A list of spacy Document 69 | :param batch_size: The batch size 70 | :return: 71 | """ 72 | predicted_triples = self.predict(docs, batch_size) 73 | for d, preds in zip(docs, predicted_triples): 74 | entities, relations = self.parse_triples(preds) 75 | d._.relations = relations 76 | d._.spans = entities 77 | d.ents = map(lambda p: p.to_spacy_span(d), filter_overlapping_spans(entities, list(d), 78 | tokens_offsets=spacy_token_offsets(d))) 79 | 80 | @staticmethod 81 | def version() -> str: 82 | return "v1" 83 | 84 | @staticmethod 85 | def _get_serialize_file(path): 86 | """ Get full filepath of the serialization file """ 87 | return os.path.join(path, "knowledge_extractor.pkl") 88 | 89 | @classmethod 90 | def from_disk(cls, path, exclude=()): 91 | """ Load component from disk """ 92 | serialize_file = cls._get_serialize_file(path) 93 | with open(serialize_file, "rb") as f: 94 | return pkl.load(f) 95 | 96 | def to_disk(self, path): 97 | """ Save component into disk """ 98 | serialize_file = self._get_serialize_file(path) 99 | with open(serialize_file, "wb") as f: 100 | return pkl.dump(self, f) 101 | 102 | def __hash__(self): 103 | """ Get hash representation of the component """ 104 | self_repr = f"{self.__class__.__name__}.{self.version()}.{str(self.__dict__)}" 105 | return zlib.crc32(self_repr.encode()) 106 | -------------------------------------------------------------------------------- /zshot/mentions_extractor/mentions_extractor_flair.py: -------------------------------------------------------------------------------- 1 | import pkgutil 2 | from typing import Optional, Iterator 3 | 4 | from spacy.tokens.doc import Doc 5 | 6 | from zshot.utils.data_models import Span 7 | from zshot.mentions_extractor.mentions_extractor import MentionsExtractor 8 | from zshot.mentions_extractor.utils import ExtractorType 9 | 10 | 11 | class MentionsExtractorFlair(MentionsExtractor): 12 | """ Flair Mentions extractor """ 13 | ALLOWED_CHUNKS = ("NP",) 14 | 15 | def __init__(self, extractor_type: Optional[ExtractorType] = ExtractorType.NER): 16 | """ 17 | * Requires flair package to be installed * 18 | 19 | :param extractor_type: Type of extractor to get mentions. One of: 20 | - NER: to use Named Entity Recognition model to get the mentions 21 | - POS: to get the mentions based on the linguistics 22 | """ 23 | if not pkgutil.find_loader("flair"): 24 | raise Exception("Flair module not installed. You need to install Flair for using this class." 
25 | "Install it with: pip install flair>=0.13") 26 | 27 | super(MentionsExtractorFlair, self).__init__() 28 | 29 | self.extractor_type = extractor_type 30 | self.model = None 31 | 32 | def load_models(self): 33 | """ Load Flair model to perform the mentions extraction """ 34 | if self.model is None: 35 | from flair.models import SequenceTagger 36 | if self.extractor_type == ExtractorType.NER: 37 | self.model = SequenceTagger.load("ner") 38 | else: 39 | self.model = SequenceTagger.load("chunk") 40 | 41 | def predict_pos_mentions(self, docs: Iterator[Doc], batch_size: Optional[int] = None): 42 | """ Predict mentions of docs using POS linguistics 43 | 44 | :param docs: Documents to get mentions of 45 | :param batch_size: Batch size to use 46 | :return: Spans of the mentions 47 | """ 48 | from flair.data import Sentence 49 | sentences = [ 50 | Sentence(str(doc), use_tokenizer=True) for doc in docs 51 | ] 52 | kwargs = {'mini_batch_size': batch_size} if batch_size else {} 53 | self.model.predict(sentences, **kwargs) 54 | 55 | spans = [] 56 | for sent, doc in zip(sentences, docs): 57 | spans_tmp = [] 58 | for i in range(len(sent.labels)): 59 | if sent.labels[i].value in self.ALLOWED_CHUNKS: 60 | spans_tmp.append(Span(sent.labels[i].data_point.start_position, 61 | sent.labels[i].data_point.end_position)) 62 | 63 | spans.append(spans_tmp) 64 | 65 | return spans 66 | 67 | def predict_ner_mentions(self, docs: Iterator[Doc], batch_size: Optional[int] = None): 68 | """ Predict mentions of docs using NER model 69 | 70 | :param docs: Documents to get mentions of 71 | :param batch_size: Batch size to use 72 | :return: Spans of the mentions 73 | """ 74 | from flair.data import Sentence 75 | sentences = [ 76 | Sentence(str(doc), use_tokenizer=True) for doc in docs 77 | ] 78 | kwargs = {'mini_batch_size': batch_size} if batch_size else {} 79 | self.model.predict(sentences, **kwargs) 80 | 81 | spans = [] 82 | for sent, doc in zip(sentences, docs): 83 | sent_mentions = sent.get_spans('ner') 84 | spans_tmp = [ 85 | Span(mention.start_position, mention.end_position, score=mention.score) 86 | for mention in sent_mentions 87 | ] 88 | spans.append(spans_tmp) 89 | 90 | return spans 91 | 92 | def predict(self, docs: Iterator[Doc], batch_size=None): 93 | """ Predict mentions in each document 94 | 95 | :param docs: Documents to get mentions of 96 | :param batch_size: Batch size to use 97 | :return: Spans of the mentions 98 | """ 99 | self.load_models() 100 | if self.extractor_type == ExtractorType.NER: 101 | return self.predict_ner_mentions(docs, batch_size) 102 | else: 103 | return self.predict_pos_mentions(docs, batch_size) 104 | -------------------------------------------------------------------------------- /zshot/utils/ensembler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Tuple, List 2 | 3 | from zshot.utils.data_models import Span 4 | 5 | 6 | class Ensembler: 7 | 8 | def __init__(self, 9 | num_voters: int, 10 | num_enhance_entities: Optional[int] = -1, 11 | strategy: Optional[str] = 'max', 12 | threshold: Optional[float] = 0.5): 13 | """ Ensembler to improve performance. 14 | 15 | :param num_voters: Number of voters (combination of linker/mention extractor and entity) 16 | :param num_enhance_entities: Number of entities 17 | :param strategy: Strategy to use. Options: max; count. 18 | When `max` choose the label with max total vote score. 19 | When `count` choose the label with max total vote count. 
20 | :param threshold: Threshold to use. Proportion of voters voting the entity. 21 | """ 22 | self.number_pipelines = num_voters 23 | if num_enhance_entities > 0: 24 | self.number_pipelines *= self.number_pipelines 25 | self.strategy = strategy 26 | self.threshold = threshold 27 | 28 | def ensemble(self, spans: List[Span]) -> List[Span]: 29 | """ Ensemble the spans 30 | 31 | :param spans: Spans to ensemble 32 | """ 33 | if self.strategy == 'max': 34 | all_union_spans = [self.ensemble_max(s) for k, s in spans.items()] 35 | else: 36 | all_union_spans = [self.ensemble_count(s) for k, s in spans.items()] 37 | all_union_spans = [s for s in all_union_spans if s.score > self.threshold] 38 | all_union_spans = self.inclusive(all_union_spans) 39 | return all_union_spans 40 | 41 | def ensemble_max(self, spans: List[Span]) -> Span: 42 | """ Ensemble the spans with the max strategy, choosing the label with max total vote score 43 | 44 | :param spans: Spans to ensemble 45 | """ 46 | votes = {} 47 | for s in spans: 48 | if s.label not in votes: 49 | votes[s.label] = s.score / self.number_pipelines 50 | else: 51 | votes[s.label] += s.score / self.number_pipelines 52 | 53 | max_score, best_label = self.select_best(votes) 54 | s = spans[0] 55 | 56 | return Span(label=best_label, score=max_score, start=s.start, end=s.end) 57 | 58 | def ensemble_count(self, spans: List[Span]) -> Span: 59 | """ Ensemble the spans with the max strategy, choosing the label with max total vote count 60 | 61 | :param spans: Spans to ensemble 62 | """ 63 | votes = {} 64 | for s in spans: 65 | if s.label not in votes: 66 | votes[s.label] = 1.0 / self.number_pipelines 67 | else: 68 | votes[s.label] += 1.0 / self.number_pipelines 69 | 70 | max_score, best_label = self.select_best(votes) 71 | s = spans[0] 72 | 73 | return Span(label=best_label, score=max_score, start=s.start, end=s.end) 74 | 75 | @staticmethod 76 | def select_best(votes: Dict[str, float]) -> Tuple[float, str]: 77 | """ Select the best entity based on the votes. 
78 | 79 | :param votes: Votes to select the best one of 80 | """ 81 | max_score = -1.0 82 | best_label = None 83 | for label, score in votes.items(): 84 | if best_label is None: 85 | best_label = label 86 | max_score = score 87 | elif max_score < score: 88 | best_label = label 89 | max_score = score 90 | 91 | return max_score, best_label 92 | 93 | @staticmethod 94 | def inclusive(spans: List[Span]) -> List[Span]: 95 | n = len(spans) 96 | non_overlapping_spans = [] 97 | for i in range(n): 98 | is_overlapping = False 99 | for j in range(n): 100 | if spans[i].start >= spans[j].start and spans[i].end <= spans[j].end: 101 | if spans[i].start > spans[j].start or spans[i].end < spans[j].end: 102 | is_overlapping = True 103 | break 104 | if not is_overlapping: 105 | non_overlapping_spans.append(spans[i]) 106 | return non_overlapping_spans 107 | -------------------------------------------------------------------------------- /zshot/evaluation/dataset/ontonotes/onto_notes.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, Union, Optional 3 | 4 | from datasets import ClassLabel, load_dataset, DatasetDict, Split, Dataset 5 | 6 | from zshot.evaluation.dataset.dataset import DatasetWithEntities 7 | from zshot.evaluation.dataset.ontonotes.entities import ONTONOTES_ENTITIES 8 | 9 | LABELS = ONTONOTES_ENTITIES 10 | labels = ClassLabel(num_classes=37, 11 | names=["O", "B-PERSON", "I-PERSON", "B-NORP", "I-NORP", "B-FAC", "I-FAC", 12 | "B-ORG", "I-ORG", "B-GPE", "I-GPE", "B-LOC", "I-LOC", 13 | "B-PRODUCT", "I-PRODUCT", "B-DATE", "I-DATE", "B-TIME", "I-TIME", 14 | "B-PERCENT", "I-PERCENT", "B-MONEY", "I-MONEY", "B-QUANTITY", "I-QUANTITY", 15 | "B-ORDINAL", "I-ORDINAL", "B-CARDINAL", "I-CARDINAL", "B-EVENT", "I-EVENT", 16 | "B-WORK_OF_ART", "I-WORK_OF_ART", "B-LAW", "I-LAW", "B-LANGUAGE", "I-LANGUAGE"]) 17 | 18 | CLASSES_PER_SPLIT = { 19 | "train": ["PERSON", "GPE", "ORG", "DATE"], 20 | "validation": ["NORP", "MONEY", 'ORDINAL', "PERCENT", "EVENT", "PRODUCT", "LAW"], 21 | "test": ["CARDINAL", "TIME", "LOC", "WORK_OF_ART", "FAC", "QUANTITY", "LANGUAGE"] 22 | } 23 | TRIVIAL_CLASSES = ["ORDINAL", "QUANTITY", "MONEY", "PERCENT", "CARDINAL", "LANGUAGE", "TIME"] 24 | 25 | 26 | def remove_other_tasks(sentence): 27 | if 'pos_tags' in sentence: 28 | del sentence['pos_tags'] 29 | if 'parse_tree' in sentence: 30 | del sentence['parse_tree'] 31 | if 'predicate_framenet_ids' in sentence: 32 | del sentence['predicate_framenet_ids'] 33 | if 'word_senses' in sentence: 34 | del sentence['word_senses'] 35 | if 'speaker' in sentence: 36 | del sentence['speaker'] 37 | if 'predicate_lemmas' in sentence: 38 | del sentence['predicate_lemmas'] 39 | if 'coref_spans' in sentence: 40 | del sentence['coref_spans'] 41 | if 'srl_frames' in sentence: 42 | del sentence['srl_frames'] 43 | return sentence 44 | 45 | 46 | def is_not_empty(sentence): 47 | return not all([s == 0 for s in sentence['named_entities']]) 48 | 49 | 50 | def remove_out_of_split(sentence, split): 51 | for i, ent in enumerate(sentence['named_entities']): 52 | label = labels.int2str(ent) 53 | if label == 'O' or label[2:] in TRIVIAL_CLASSES or label[2:] not in CLASSES_PER_SPLIT[split]: 54 | sentence['named_entities'][i] = 0 55 | return sentence 56 | 57 | 58 | def load_ontonotes_zs(split: Optional[Union[str, Split]] = None, **kwargs) -> Union[Dict[DatasetWithEntities, 59 | Dataset], Dataset]: 60 | dataset_zs = load_dataset("conll2012_ontonotesv5", "english_v12", 61 | split=split, 
verification_mode='no_checks', trust_remote_code=True, **kwargs) 62 | if split: 63 | ontonotes_zs = preprocess_spit(dataset_zs, get_simple_split(split)) 64 | else: 65 | ontonotes_zs = DatasetDict() 66 | for split in dataset_zs: 67 | ontonotes_zs[split] = preprocess_spit(dataset_zs[split], split) 68 | return ontonotes_zs 69 | 70 | 71 | def preprocess_spit(dataset, split) -> DatasetWithEntities: 72 | dataset = dataset.map(lambda example, idx: { 73 | "sentences": [remove_out_of_split(s, split) for s in example['sentences']] 74 | }, with_indices=True) 75 | dataset = dataset.map(lambda example, idx: { 76 | "sentences": list(filter(is_not_empty, example['sentences'])) 77 | }, with_indices=True) 78 | tokens = [] 79 | ner_tags = [] 80 | for example in dataset: 81 | tokens += [s['words'] for s in example['sentences']] 82 | ner_tags += [[labels.int2str(ent) for ent in s['named_entities']] for s in example['sentences']] 83 | split_entities = [ent for ent in ONTONOTES_ENTITIES 84 | if ent.name in ['NEG'] + CLASSES_PER_SPLIT[split] and ent.name not in TRIVIAL_CLASSES] 85 | dataset = Dataset.from_dict({ 86 | 'tokens': tokens, 87 | 'ner_tags': ner_tags 88 | }, split=split) 89 | dataset.entities = split_entities 90 | return dataset 91 | 92 | 93 | def get_simple_split(split: str) -> str: 94 | first_not_alph = re.search(r'\W+', split) 95 | first_not_alph_chr = first_not_alph.start() if first_not_alph else len(split) 96 | return split[: first_not_alph_chr] 97 | -------------------------------------------------------------------------------- /zshot/tests/utils/test_displacy.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | from spacy.tokens import Doc 3 | 4 | from zshot import displacy, PipelineConfig 5 | from zshot.tests.config import EX_DOCS, EX_ENTITIES 6 | from zshot.tests.linker.test_linker import DummyLinkerEnd2End 7 | from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor 8 | from zshot.utils.data_models import Span, Relation 9 | from zshot.utils.data_models.relation_span import RelationSpan 10 | 11 | 12 | def test_displacy_render(): 13 | nlp = spacy.blank("en") 14 | nlp.add_pipe("zshot", config=PipelineConfig( 15 | mentions_extractor=DummyMentionsExtractor(), 16 | linker=DummyLinkerEnd2End(), 17 | entities=EX_ENTITIES), last=True) 18 | doc = nlp(EX_DOCS[1]) 19 | assert len(doc.ents) > 0 20 | assert len(doc._.spans) > 0 21 | res = displacy.render(doc, style="ent", jupyter=False) 22 | assert res is not None 23 | 24 | 25 | def test_displacy_render_notebook(): 26 | nlp = spacy.blank("en") 27 | nlp.add_pipe("zshot", config=PipelineConfig( 28 | mentions_extractor=DummyMentionsExtractor(), 29 | linker=DummyLinkerEnd2End(), 30 | entities=EX_ENTITIES), last=True) 31 | doc = nlp(EX_DOCS[1]) 32 | assert len(doc.ents) > 0 33 | assert len(doc._.spans) > 0 34 | res = displacy.render(doc, style="ent", jupyter=True) 35 | assert res is None 36 | 37 | 38 | def test_displacy_rel_style(): 39 | nlp = spacy.load("en_core_web_sm") 40 | doc = nlp(EX_DOCS[1]) 41 | relations = [ 42 | RelationSpan(start=Span(0, 43, "IBM", -0.007964816875755787), end=Span(45, 48, "IBM", -0.00017413603200111538), 43 | relation=Relation(name="is_in", description="is inside"), score=0.7), 44 | RelationSpan(start=Span(0, 43, "IBM", -0.007964816875755787), 45 | end=Span(127, 135, "New York", -2.3538105487823486), 46 | relation=Relation(name="has_headquarters", description="has headquarters"), score=0.3) 47 | ] 48 | spans = [Span(0, 43, "IBM", 
-0.007964816875755787), Span(45, 48, "IBM", -0.00017413603200111538), 49 | Span(56, 64, "American", -5.8533525466918945), Span(119, 125, "Armonk", -2.1522278785705566), 50 | Span(127, 135, "New York", -2.3538105487823486)] 51 | if not Doc.has_extension("spans"): 52 | Doc.set_extension("spans", default=[]) 53 | if not Doc.has_extension("relations"): 54 | Doc.set_extension("relations", default=[]) 55 | doc._.relations = relations 56 | doc._.spans = spans 57 | html = displacy.render(doc, style="rel") 58 | assert html is not None 59 | assert "IBM" in html 60 | assert "American" in html 61 | assert "New York" in html 62 | assert "is_in" in html 63 | assert "has_headquarters" in html 64 | assert "displacy-token" in html 65 | assert "displacy-tag" in html 66 | assert "displacy-arrow" in html 67 | 68 | 69 | def test_displacy_rel_compact_style(): 70 | nlp = spacy.load("en_core_web_sm") 71 | doc = nlp(EX_DOCS[1]) 72 | relations = [ 73 | RelationSpan(start=Span(45, 48, "IBM", -0.00017413603200111538), end=Span(0, 43, "IBM", -0.007964816875755787), 74 | relation=Relation(name="is_in", description="is inside"), score=0.7), 75 | RelationSpan(start=Span(0, 43, "IBM", -0.007964816875755787), 76 | end=Span(127, 135, "New York", -2.3538105487823486), 77 | relation=Relation(name="has_headquarters", description="has headquarters"), score=0.3) 78 | ] 79 | spans = [Span(0, 43, "IBM", -0.007964816875755787), Span(45, 48, "IBM", -0.00017413603200111538), 80 | Span(56, 64, "American", -5.8533525466918945), Span(119, 125, "Armonk", -2.1522278785705566), 81 | Span(127, 135, "New York", -2.3538105487823486)] 82 | if not Doc.has_extension("spans"): 83 | Doc.set_extension("spans", default=[]) 84 | if not Doc.has_extension("relations"): 85 | Doc.set_extension("relations", default=[]) 86 | doc._.relations = relations 87 | doc._.spans = spans 88 | html = displacy.render(doc, style="rel", options={"compact": True}) 89 | assert html is not None 90 | assert "IBM" in html 91 | assert "American" in html 92 | assert "New York" in html 93 | assert "is_in" in html 94 | assert "has_headquarters" in html 95 | assert "displacy-token" in html 96 | assert "displacy-tag" in html 97 | assert "displacy-arrow" in html 98 | -------------------------------------------------------------------------------- /zshot/tests/utils/test_data_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import spacy 3 | 4 | from zshot.tests.config import EX_DATASET_RELATIONS 5 | from zshot.utils.data_models import Entity, Relation, Span 6 | from zshot.utils.data_models.relation_span import RelationSpan 7 | 8 | 9 | def test_span(): 10 | # Full 11 | s = Span(start=0, end=10, label='E', score=1, kb_id='e1') 12 | assert type(s) is Span 13 | assert s.start == 0 14 | assert s.end == 10 15 | assert s.label == 'E' 16 | assert s.score == 1 17 | assert s.kb_id == 'e1' 18 | 19 | # No score/KB Id 20 | s = Span(start=165, end=187, label='Q5034838') 21 | assert type(s) is Span 22 | assert s.start == 165 23 | assert s.end == 187 24 | assert s.label == 'Q5034838' 25 | 26 | # Check hash 27 | assert hash(s) == 10737688 28 | 29 | # Check repr 30 | assert repr(s) == f"{s.label}, {s.start}, {s.end}, {s.score}" 31 | 32 | # From Dict 33 | with pytest.raises(ValueError): 34 | s1 = Span.from_dict({}) 35 | with pytest.raises(ValueError): 36 | s1 = Span.from_dict({'start': 0}) 37 | with pytest.raises(ValueError): 38 | s1 = Span.from_dict({'start': 0, 'end': 0}) 39 | 40 | s1 = 
Span.from_dict(EX_DATASET_RELATIONS['sentence_entities'][0][0]) 41 | assert type(s1) is Span 42 | assert s1.start == 165 43 | assert s1.end == 187 44 | assert s1.label == 'Q5034838' 45 | 46 | s1 = Span.from_dict(EX_DATASET_RELATIONS['sentence_entities'][0][0]) 47 | assert type(s1) is Span 48 | assert s1.start == 165 49 | assert s1.end == 187 50 | assert s1.label == 'Q5034838' 51 | 52 | # Check eq 53 | assert s == s1 54 | 55 | # From/To SpaCy Span 56 | nlp = spacy.blank('en') 57 | doc = nlp(EX_DATASET_RELATIONS['sentence_entities'][0][0]['sentence']) 58 | spacy_span = s.to_spacy_span(doc) 59 | assert type(spacy_span) is spacy.tokens.Span 60 | assert spacy_span.start == 26 61 | assert spacy_span.end == 29 62 | assert spacy_span.label_ == s.label 63 | 64 | s1 = Span.from_spacy_span(spacy_span) 65 | assert type(s1) is Span 66 | assert s1.start == 166 67 | assert s1.end == 187 68 | assert s1.label == 'Q5034838' 69 | 70 | 71 | def test_entity(): 72 | # Full 73 | e = Entity(name='E', description='Entity', vocabulary=['Vocab']) 74 | assert type(e) is Entity 75 | assert e.name == 'E' 76 | assert e.description == 'Entity' 77 | assert len(e.vocabulary) == 1 and e.vocabulary[0] == 'Vocab' 78 | 79 | # No description 80 | e = Entity(name='E', vocabulary=['Vocab']) 81 | assert type(e) is Entity 82 | assert e.name == 'E' 83 | assert len(e.vocabulary) == 1 and e.vocabulary[0] == 'Vocab' 84 | 85 | # No vocabulary 86 | e = Entity(name='E', description='Entity') 87 | assert type(e) is Entity 88 | assert e.name == 'E' 89 | assert e.description == 'Entity' 90 | 91 | # Check hash 92 | e = Entity(name='E') 93 | assert hash(e) == 3095248369 94 | 95 | 96 | def test_relation_span(): 97 | # Full 98 | s1 = Span.from_dict(EX_DATASET_RELATIONS['sentence_entities'][0][0]) 99 | s2 = Span.from_dict(EX_DATASET_RELATIONS['sentence_entities'][0][1]) 100 | rs = RelationSpan(start=s1, end=s2, relation=Relation(name=EX_DATASET_RELATIONS['labels'][0]), score=1, kb_id='P1') 101 | assert type(rs) is RelationSpan 102 | assert type(rs.start) is Span 103 | assert type(rs.end) is Span 104 | assert type(rs.relation) is Relation 105 | assert rs.score == 1 106 | assert rs.kb_id == 'P1' 107 | 108 | # No Score/KB Id 109 | rs = RelationSpan(start=s1, end=s2, relation=Relation(name=EX_DATASET_RELATIONS['labels'][0])) 110 | assert type(rs) is RelationSpan 111 | assert type(rs.start) is Span 112 | assert type(rs.end) is Span 113 | assert type(rs.relation) is Relation 114 | 115 | # Check hash 116 | assert hash(rs) == 1864423560 117 | 118 | # Check repr 119 | assert repr(rs) == f"{rs.relation.name}, {rs.start}, {rs.end}, {rs.score}" 120 | 121 | # Check eq 122 | rs2 = RelationSpan(end=s2, start=s1, relation=Relation(name=EX_DATASET_RELATIONS['labels'][0])) 123 | assert rs == rs2 124 | 125 | 126 | def test_relation(): 127 | # Full 128 | r = Relation(name='R', description='Relation') 129 | assert type(r) is Relation 130 | assert r.name == 'R' 131 | assert r.description == 'Relation' 132 | 133 | # No description 134 | r = Relation(name='R') 135 | assert type(r) is Relation 136 | assert r.name == 'R' 137 | 138 | # Check hash 139 | assert hash(r) == 3502422000 140 | --------------------------------------------------------------------------------
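Usage sketch for the Ensembler listed in zshot/utils/ensembler.py: with the 'max' strategy each voter contributes score / number_pipelines to its label, the winning label is kept, spans whose combined score does not exceed the threshold are dropped, and nested spans are removed by inclusive(). Note that, as linker_ensemble.py shows, ensemble() is fed a dict keyed by (start, end) positions rather than the List[Span] suggested by its type hint. This is a minimal sketch using only the Span and Ensembler classes shown above; the labels and scores are illustrative.

from zshot.utils.data_models import Span
from zshot.utils.ensembler import Ensembler

# Two voters both found the span (0, 3) but disagree on the label.
union_spans = {
    (0, 3): [
        Span(start=0, end=3, label="company", score=0.9),
        Span(start=0, end=3, label="location", score=0.6),
    ]
}

# num_enhance_entities=-1 keeps number_pipelines equal to num_voters (no squaring).
ensembler = Ensembler(num_voters=2, num_enhance_entities=-1,
                      strategy="max", threshold=0.4)

spans = ensembler.ensemble(union_spans)
# "company" wins with 0.9 / 2 = 0.45, which clears the 0.4 threshold.
print(spans)  # [company, 0, 3, 0.45]  (Span repr is "label, start, end, score")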
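The same voting can be driven from a spaCy pipeline through linker_ensemble.py, where each (linker, enhanced entity set) pair acts as a voter and the Ensembler merges the resulting spans per document. A minimal sketch, assuming the LinkerEnsemble import path matches its file location, that the default LinkerSMXM weights can be downloaded, and with illustrative entity names and descriptions; it is not a confirmed end-to-end recipe from the repository.

import spacy

from zshot import PipelineConfig
from zshot.linker.linker_ensemble.linker_ensemble import LinkerEnsemble
from zshot.utils.data_models import Entity

nlp = spacy.blank("en")
config = PipelineConfig(
    # With no explicit linkers, LinkerEnsemble falls back to a single LinkerSMXM voter,
    # which will download its model weights on first use.
    linker=LinkerEnsemble(strategy="max", threshold=0.5),
    entities=[
        Entity(name="company", description="The name of a company"),
        Entity(name="location", description="A physical location"),
    ],
)
nlp.add_pipe("zshot", config=config, last=True)

doc = nlp("IBM headquarters are located in Armonk.")
print(doc.ents)     # linked spaCy entity spans
print(doc._.spans)  # zshot Span objects with scores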