├── .devcontainer └── devcontainer.json ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── llmner ├── __init__.py ├── data.py ├── models.py ├── templates.py └── utils.py ├── notebooks ├── 1-example.ipynb └── 2-conll.ipynb ├── setup.py └── tests.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/python 3 | { 4 | "name": "llmner", 5 | // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile 6 | "image": "mcr.microsoft.com/devcontainers/python:1-3.10-bullseye", 7 | "customizations": { 8 | "vscode": { 9 | "extensions": [ 10 | "ms-python.black-formatter", 11 | "ms-python.python", 12 | "GitHub.copilot", 13 | "ms-toolsai.jupyter" 14 | ] 15 | } 16 | }, 17 | 18 | // Features to add to the dev container. More info: https://containers.dev/features. 19 | // "features": {}, 20 | 21 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 22 | // "forwardPorts": [], 23 | 24 | // Use 'postCreateCommand' to run commands after the container is created. 25 | "postCreateCommand": "pip install -e .[dev]" 26 | 27 | // Configure tool-specific properties. 28 | // "customizations": {}, 29 | 30 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 31 | // "remoteUser": "root" 32 | } 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | data 3 | llmner.egg-info 4 | gpt_test 5 | json_parsing_test 6 | venv -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "basic", 3 | "python.testing.unittestArgs": [ 4 | "-v", 5 | "-s", 6 | ".", 7 | "-p", 8 | "test*.py" 9 | ], 10 | "python.testing.pytestEnabled": false, 11 | "python.testing.unittestEnabled": true 12 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llmNER: (Few|Zero)-Shot Named Entity Recognition without training data 2 | 3 | Exploit the power of Large Language Models (LLMs) to perform Zero-Shot or Few-Shot Named Entity Recognition (NER) without the need for annotated data. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install git+https://github.com/plncmm/llmner.git 9 | ``` 10 | 11 | ## Usage 12 | 13 | ### Zero-Shot NER 14 | 15 | ```python 16 | import os 17 | os.environ["OPENAI_API_KEY"] = "" 18 | from llmner import ZeroShotNer 19 | 20 | entities = { 21 | "person": "A person name, it can include first and last names, for example: John Kennedy and Bill Gates", 22 | "organization": "An organization name, it can be a company, a government agency, etc.", 23 | "location": "A location name, it can be a city, a country, etc.", 24 | } 25 | 26 | model = ZeroShotNer() 27 | model.contextualize(entities=entities) 28 | 29 | model.predict(["Pedro Pereira is the president of Perú and the owner of Walmart."]) 30 | ``` 31 | 32 | ### Few-Shot NER 33 | 34 | ```python 35 | import os 36 | os.environ["OPENAI_API_KEY"] = "" 37 | from llmner import FewShotNer 38 | from llmner.data import AnnotatedDocument, Annotation 39 | 40 | entities = { 41 | "person": "A person name, it can include first and last names, for example: John Kennedy and Bill Gates", 42 | "organization": "An organization name, it can be a company, a government agency, etc.", 43 | "location": "A location name, it can be a city, a country, etc.", 44 | } 45 | 46 | examples = [ 47 | AnnotatedDocument( 48 | text="Gabriel Boric is the president of Chile", 49 | annotations={ 50 | Annotation(start=34, end=39, label="location"), 51 | Annotation(start=0, end=13, label="person"), 52 | }, 53 | ), 54 | AnnotatedDocument( 55 | text="Elon Musk is the owner of the US company Tesla", 56 | annotations={ 57 | Annotation(start=30, end=32, label="location"), 58 | Annotation(start=0, end=9, label="person"), 59 | Annotation(start=41, end=46, label="organization"), 60 | }, 61 | ), 62 | AnnotatedDocument( 63 | text="Bill Gates is the owner of Microsoft", 64 | annotations={ 65 | Annotation(start=0, end=10, label="person"), 66 | Annotation(start=27, end=36, label="organization"), 67 | }, 68 | ), 69 | AnnotatedDocument( 70 | text="John is the president of Argentina", 71 | annotations={ 72 | Annotation(start=0, end=4, label="person"), 73 | Annotation(start=25, end=34, label="location"), 74 | }, 75 | ), 76 | ] 77 | 78 | model = FewShotNer() 79 | model.contextualize(entities=entities, examples=examples) 80 | 81 | model.predict(["Pedro Pereira is the president of Perú and the owner of Walmart."]) 82 | ``` 83 | 84 | ### Use your own LLM 85 | 86 | You need to set the following environment variables to use your own LLM (a minimal sketch follows the list): 87 | 88 | - `OPENAI_API_KEY`: Your API key if you need one, otherwise use a random one. 89 | - `OPENAI_API_BASE`: The API base URL
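For example, a minimal sketch assuming an OpenAI-compatible endpoint; the base URL, key, and model name below are placeholders, not values the library prescribes:

```python
import os

# Placeholders: point these at your own OpenAI-compatible endpoint
os.environ["OPENAI_API_KEY"] = "any-non-empty-string"
os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1"

from llmner import ZeroShotNer

entities = {
    "person": "A person name, it can include first and last names",
    "location": "A location name, it can be a city, a country, etc.",
}

# "my-deployed-model" is a placeholder for the name of your served model
model = ZeroShotNer(model="my-deployed-model")
model.contextualize(entities=entities)
model.predict(["Pedro Pereira is the president of Perú."])
```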
90 | 91 | ### If you belong to an OpenAI organization 92 | 93 | You need to set the following environment variables to use your organization: 94 | 95 | - `OPENAI_ORG_ID`: Your organization ID 96 | - `OPENAI_API_KEY`: Your OpenAI API key 97 | 98 | ### If you are using the model through Azure 99 | 100 | You need to set the following environment variables to use your LLM through Azure: 101 | 102 | - `OPENAI_API_BASE`: The Azure API base URL 103 | - `OPENAI_API_KEY`: Your Azure API key 104 | - `OPENAI_API_TYPE`: You need to set it to `azure` 105 | - `OPENAI_API_VERSION`: You need to set it to your Azure API version 106 | 107 | Also, when instantiating the model object, you need to pass `model_kwargs={"engine":""}` 108 | 109 | For example: 110 | 111 | ```python 112 | import os 113 | os.environ["OPENAI_API_KEY"] = "" 114 | os.environ["OPENAI_API_BASE"] = "" 115 | os.environ["OPENAI_API_TYPE"] = "azure" 116 | os.environ["OPENAI_API_VERSION"] = "" 117 | from llmner import ZeroShotNer 118 | 119 | entities = { 120 | "person": "A person name, it can include first and last names, for example: John Kennedy and Bill Gates", 121 | "organization": "An organization name, it can be a company, a government agency, etc.", 122 | "location": "A location name, it can be a city, a country, etc.", 123 | } 124 | 125 | model = ZeroShotNer(model_kwargs={"engine":""}) 126 | model.contextualize(entities=entities) 127 | 128 | model.predict(["Pedro Pereira is the president of Perú and the owner of Walmart."]) 129 | ``` 130 | 131 | ### If you are using Deep Infra 132 | 133 | You have to set `OPENAI_API_BASE` to `https://api.deepinfra.com/v1/openai` and `OPENAI_API_KEY` to your API key, and instantiate the models with the `model` argument set to the name of the deployed model. For example, if you want to perform Few-Shot NER using Llama 2 70B, you need to instantiate the model as follows: `FewShotNer(model="meta-llama/Llama-2-70b-chat-hf")`. 134 | -------------------------------------------------------------------------------- /llmner/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | sys.path.append(os.path.dirname(os.path.realpath(__file__))) 4 | from models import ZeroShotNer, FewShotNer 5 | -------------------------------------------------------------------------------- /llmner/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Set, Optional, List, Tuple, Literal, Union 3 | 4 | 5 | @dataclass 6 | class Document: 7 | """Document class. Used to represent a document. 8 | Args: 9 | text (str): Text of the document. 10 | """ 11 | 12 | text: str 13 | 14 | 15 | @dataclass() 16 | class Annotation: 17 | """Annotation class. Used to represent an annotation. An annotation is a labelled span of text. 18 | Args: 19 | start (int): Start index of the annotation. 20 | end (int): End index of the annotation. 21 | label (str): Label of the annotation. 22 | text (Optional[str], optional): Text content of the annotation. It is optional and defaults to None. 23 | """ 24 | 25 | start: int 26 | end: int 27 | label: str 28 | text: Optional[str] = None 29 | 30 | def __hash__(self): 31 | return hash((self.start, self.end, self.label)) 32 | 33 | 34 | @dataclass 35 | class AnnotatedDocument(Document): 36 | """AnnotatedDocument class. Used to represent an annotated document. 37 | Args: 38 | text (str): Text of the document. 
39 | annotations (Set[Annotation]): Set of annotations of the document. 40 | """ 41 | 42 | annotations: Set[Annotation] 43 | 44 | 45 | @dataclass 46 | class AnnotatedDocumentWithException(AnnotatedDocument): 47 | """AnnotatedDocumentWithException class. Used to represent an annotated document with an exception. 48 | Args: 49 | text (str): Text of the document. 50 | annotations (Set[Annotation]): Set of annotations of the document. 51 | exception (Exception): Exception of the document. 52 | """ 53 | 54 | exception: Exception 55 | 56 | 57 | class NotContextualizedError(Exception): 58 | pass 59 | 60 | 61 | class NotPerfectlyAlignedError(Exception): 62 | """Exception raised when the text cannot be perfectly aligned. 63 | Args: 64 | message (str): Message of the exception. 65 | removed_annotations (List[Annotation]): List of annotations that were removed. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | message: str, 71 | removed_annotations: List[Annotation] = [], 72 | completion_text: str = "", 73 | ): 74 | self.removed_annotations = removed_annotations 75 | self.message = message 76 | self.completion_text = completion_text 77 | super().__init__(self.message) 78 | 79 | 80 | Token = str 81 | Label = str 82 | Conll = List[Tuple[Token, Label]] 83 | 84 | 85 | @dataclass 86 | class PromptTemplate: 87 | """PromptTemplate class. Used to represent a prompt template. 88 | Args: 89 | inline_single_turn (str): Template for inline single turn. 90 | inline_multi_turn_default_delimiters (str): Template for inline multi turn with default delimiters. 91 | inline_multi_turn_custom_delimiters (str): Template for inline multi turn with custom delimiters. 92 | json_single_turn (str): Template for json single turn. 93 | json_multi_turn (str): Template for json multi turn. 94 | multi_turn_prefix (str): Prefix for multi turn. 95 | pos (str): Template for part of speech tagging. 96 | pos_answer_prefix (str): Prefix for part of speech tagging answer. 97 | final_message_prefix (str): Prefix for final message. 
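Note: llmner.templates.TEMPLATE_EN is the default instance of this class; a custom template can be passed to the NER models through their prompt_template argument.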
98 | """ 99 | 100 | inline_single_turn: str 101 | inline_multi_turn_default_delimiters: str 102 | inline_multi_turn_custom_delimiters: str 103 | json_single_turn: str 104 | json_multi_turn: str 105 | multi_turn_prefix: str 106 | pos: str 107 | pos_answer_prefix: str 108 | final_message_prefix: str 109 | -------------------------------------------------------------------------------- /llmner/models.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import ( 2 | HumanMessagePromptTemplate, 3 | SystemMessagePromptTemplate, 4 | ChatPromptTemplate, 5 | FewShotChatMessagePromptTemplate, 6 | AIMessagePromptTemplate, 7 | ) 8 | 9 | from langchain.schema.messages import AIMessage, HumanMessage 10 | 11 | from langchain.chat_models import ChatOpenAI 12 | 13 | from llmner.utils import ( 14 | dict_to_enumeration, 15 | inline_annotation_to_annotated_document, 16 | inline_special_tokens_annotation_to_annotated_document, 17 | json_annotation_to_annotated_document, 18 | align_annotation, 19 | annotated_document_to_single_turn_few_shot_example, 20 | annotated_document_to_multi_turn_few_shot_example, 21 | detokenizer, 22 | annotated_document_to_conll, 23 | annotated_document_to_inline_annotated_string, 24 | annotated_document_to_json_annotated_string, 25 | ) 26 | 27 | from llmner.templates import TEMPLATE_EN 28 | 29 | from typing import List, Dict, Union, Tuple, Callable, Literal 30 | from llmner.data import ( 31 | AnnotatedDocument, 32 | AnnotatedDocumentWithException, 33 | NotContextualizedError, 34 | Conll, 35 | Label, 36 | PromptTemplate, 37 | ) 38 | 39 | from tqdm import tqdm 40 | 41 | from concurrent.futures import ThreadPoolExecutor, as_completed 42 | import multiprocessing 43 | 44 | CPU_COUNT = multiprocessing.cpu_count() 45 | 46 | import logging 47 | 48 | logger = logging.getLogger(__name__) 49 | 50 | # pyright: reportUnboundVariable=false 51 | 52 | 53 | class BaseNer: 54 | """Base NER model class. All NER models should inherit from this class.""" 55 | 56 | def __init__( 57 | self, 58 | model: str = "gpt-3.5-turbo", 59 | max_tokens: int = 256, 60 | stop: List[str] = ["###"], 61 | temperature: float = 1.0, 62 | model_kwargs: Dict = {}, 63 | answer_shape: Literal["inline", "json"] = "inline", 64 | prompting_method: Literal["single_turn", "multi_turn"] = "single_turn", 65 | multi_turn_delimiters: Union[None, Tuple[str, str]] = None, 66 | final_message_with_all_entities: bool = False, 67 | augment_with_pos: Union[bool, Callable[[str], str]] = False, 68 | prompt_template: PromptTemplate = TEMPLATE_EN, 69 | system_message_as_user_message: bool = False, 70 | ): 71 | """NER model. Make sure you have at least the OPENAI_API_KEY environment variable set with your API key. Refer to the python openai library documentation for more information. 72 | 73 | Args: 74 | model (str, optional): Model name. Defaults to "gpt-3.5-turbo". 75 | max_tokens (int, optional): Max number of new tokens. Defaults to 256. 76 | stop (List[str], optional): List of strings that should stop generation. Defaults to ["###"]. 77 | temperature (float, optional): Temperature for the generation. Defaults to 1.0. 78 | model_kwargs (Dict, optional): Arguments to pass to the llm. Defaults to {}. Refer to the OpenAI python library documentation and OpenAI API documentation for more information. 79 | answer_shape (Literal["inline", "json"], optional): Shape of the answer. 
The inline answer shape encloses entities between inline tags, as in '<location>Washington</location>', and the json answer shape expects a valid JSON response from the model. Defaults to "inline". 80 | prompting_method (Literal["single_turn", "multi_turn"], optional): Prompting method. In multi_turn, we query the model for each entity and at the end we compile the annotated document. Defaults to "single_turn". 81 | multi_turn_delimiters (Union[None, Tuple[str, str]], optional): Delimiter symbols for multi-turn prompting, the first element of the tuple is the start delimiter and the second element of the tuple is the end delimiter, for example, if you want to enclose the mention between @, you need to set this argument to ('@', '@') or if you need to enclose the mention as in @mention# you need to set the argument to ('@', '#'). Defaults to None, which uses the entity class name as delimiters, as in <entity_class>mention</entity_class>. 82 | final_message_with_all_entities (bool, optional): If True, the final message will ask the AI to annotate the document with all entities, only valid when prompting_method=multi_turn. Defaults to False. 83 | augment_with_pos (Union[bool, Callable[[str], str]], optional): If True, the model will be augmented with the part-of-speech tagging of the document. If a function is passed, the function will be called with the document as the argument and the returned value will be used as the augmentation. Defaults to False. 84 | prompt_template (PromptTemplate, optional): Prompt template to send to the llm as the system message. Defaults to a prompt template for NER in English. 85 | system_message_as_user_message (bool, optional): If True, the system message will be sent as a user message. Defaults to False. 86 | """ 87 | self.max_tokens = max_tokens 88 | self.stop = stop 89 | self.model = model 90 | self.chat_template = None 91 | self.model_kwargs = model_kwargs 92 | self.temperature = temperature 93 | self.answer_shape = answer_shape 94 | self.prompting_method = prompting_method 95 | self.multi_turn_delimiters = multi_turn_delimiters 96 | self.final_message_with_all_entities = final_message_with_all_entities 97 | self.augment_with_pos = augment_with_pos 98 | self.prompt_template = prompt_template 99 | self.system_message_as_user_message = system_message_as_user_message 100 | 101 | self.multi_turn_prefix = self.prompt_template.multi_turn_prefix 102 | if self.multi_turn_delimiters: 103 | self.start_token = self.multi_turn_delimiters[0] 104 | self.end_token = self.multi_turn_delimiters[1] 105 | else: 106 | self.start_token = "###" 107 | self.end_token = "###" 108 | if (self.answer_shape == "inline") & (self.prompting_method == "single_turn"): 109 | current_prompt_template = self.prompt_template.inline_single_turn 110 | elif (self.answer_shape == "inline") & (self.prompting_method == "multi_turn"): 111 | if self.multi_turn_delimiters: 112 | current_prompt_template = ( 113 | self.prompt_template.inline_multi_turn_custom_delimiters 114 | ) 115 | else: 116 | current_prompt_template = ( 117 | self.prompt_template.inline_multi_turn_default_delimiters 118 | ) 119 | elif (self.answer_shape == "json") & (self.prompting_method == "single_turn"): 120 | current_prompt_template = self.prompt_template.json_single_turn 121 | elif (self.answer_shape == "json") & (self.prompting_method == "multi_turn"): 122 | current_prompt_template = self.prompt_template.json_multi_turn 123 | else: 124 | raise ValueError( 125 | "The answer shape and prompting method combination is not valid" 126 | ) 127 | if not self.system_message_as_user_message: 128 | 
self.system_template = SystemMessagePromptTemplate.from_template( 129 | current_prompt_template 130 | ) 131 | else: 132 | self.system_template = HumanMessagePromptTemplate.from_template( 133 | current_prompt_template 134 | ) 135 | 136 | def _query_model( 137 | self, 138 | messages: list, 139 | request_timeout: int = 600, 140 | remove_model_kwargs: bool = False, 141 | ): 142 | if remove_model_kwargs: 143 | model_kwargs = {} 144 | else: 145 | model_kwargs = self.model_kwargs 146 | chat = ChatOpenAI( 147 | model_name=self.model, # type: ignore 148 | max_tokens=self.max_tokens, 149 | temperature=self.temperature, 150 | model_kwargs=model_kwargs, 151 | request_timeout=request_timeout, 152 | ) 153 | completion = chat.invoke(messages, stop=self.stop) 154 | return completion 155 | 156 | 157 | class ZeroShotNer(BaseNer): 158 | """Zero-shot NER model class.""" 159 | 160 | def contextualize( 161 | self, 162 | entities: Dict[str, str], 163 | ): 164 | """Method to contextualize the zero-shot NER model. You don't need examples to contextualize this model. 165 | 166 | Args: 167 | entities (Dict[str, str]): Dict containing the entities to be recognized. The keys are the entity names and the values are the entity descriptions. 168 | """ 169 | self.entities = entities 170 | if self.multi_turn_delimiters: 171 | self.system_message = self.system_template.format( 172 | entities=dict_to_enumeration(entities), 173 | entity_list=list(entities.keys()), 174 | start_token=self.start_token, 175 | end_token=self.end_token, 176 | ) 177 | else: 178 | self.system_message = self.system_template.format( 179 | entities=dict_to_enumeration(entities), 180 | entity_list=list(entities.keys()), 181 | ) 182 | if self.augment_with_pos: 183 | self.chat_template = ChatPromptTemplate.from_messages( 184 | [ 185 | self.system_message, 186 | HumanMessagePromptTemplate.from_template( 187 | f"{self.prompt_template.pos_answer_prefix} {{pos}}" 188 | ), 189 | HumanMessagePromptTemplate.from_template("{x}"), 190 | ] 191 | ) 192 | else: 193 | self.chat_template = ChatPromptTemplate.from_messages( 194 | [ 195 | self.system_message, 196 | HumanMessagePromptTemplate.from_template("{x}"), 197 | ] 198 | ) 199 | 200 | def fit(self, *args, **kwargs): 201 | """Just a wrapper for the contextualize method. 
This method is here to be compatible with the sklearn API.""" 202 | return self.contextualize(*args, **kwargs) 203 | 204 | def _predict_pos(self, x: str, request_timeout: int) -> str: 205 | pos_chat_template = ChatPromptTemplate.from_messages( 206 | [ 207 | self.prompt_template.pos, 208 | HumanMessagePromptTemplate.from_template("{x}"), 209 | ] 210 | ) 211 | messages = pos_chat_template.format_messages(x=x) 212 | completion = self._query_model( 213 | messages, request_timeout, remove_model_kwargs=True 214 | ) 215 | return completion.content 216 | 217 | def _predict( 218 | self, x: str, request_timeout: int 219 | ) -> AnnotatedDocument | AnnotatedDocumentWithException: 220 | chat_template = self.chat_template 221 | if self.augment_with_pos: 222 | try: 223 | if callable(self.augment_with_pos): 224 | pos = self.augment_with_pos(x) 225 | else: 226 | pos = self._predict_pos(x, request_timeout) 227 | except Exception as e: 228 | logger.warning( 229 | f"The pos completion for the text '{x}' raised an exception: {e}" 230 | ) 231 | return AnnotatedDocumentWithException( 232 | text=x, annotations=set(), exception=e 233 | ) 234 | logger.debug(f"POS: {pos}") 235 | messages = chat_template.format_messages(x=x, pos=pos) 236 | else: 237 | messages = chat_template.format_messages(x=x) 238 | try: 239 | completion = self._query_model(messages, request_timeout) 240 | except Exception as e: 241 | logger.warning( 242 | f"The completion for the text '{x}' raised an exception: {e}" 243 | ) 244 | return AnnotatedDocumentWithException( 245 | text=x, annotations=set(), exception=e 246 | ) 247 | logger.debug(f"Completion: {completion}") 248 | annotated_document = AnnotatedDocument(text=x, annotations=set()) 249 | if self.answer_shape == "json": 250 | annotated_document = json_annotation_to_annotated_document( 251 | completion.content, list(self.entities.keys()), x 252 | ) 253 | elif self.answer_shape == "inline": 254 | annotated_document = inline_annotation_to_annotated_document( 255 | completion.content, list(self.entities.keys()) 256 | ) 257 | 258 | aligned_annotated_document = align_annotation(x, annotated_document) 259 | y = aligned_annotated_document 260 | return y 261 | 262 | def _predict_multi_turn( 263 | self, 264 | x: str, 265 | request_timeout: int, 266 | ) -> AnnotatedDocument | AnnotatedDocumentWithException: 267 | chat_template = self.chat_template 268 | annotated_documents = [] 269 | pos_added = False 270 | for entity in self.entities: 271 | human_msg_string = self.multi_turn_prefix + entity + ": " + x 272 | if bool(self.augment_with_pos) & (not pos_added): 273 | try: 274 | if callable(self.augment_with_pos): 275 | pos = self.augment_with_pos(x) 276 | else: 277 | pos = self._predict_pos(x, request_timeout) 278 | except Exception as e: 279 | logger.warning( 280 | f"The pos completion for the text '{x}' raised an exception: {e}" 281 | ) 282 | return AnnotatedDocumentWithException( 283 | text=x, annotations=set(), exception=e 284 | ) 285 | logger.debug(f"POS: {pos}") 286 | messages = chat_template.format_messages(x=human_msg_string, pos=pos) 287 | pos_added = True 288 | else: 289 | messages = chat_template.format_messages(x=human_msg_string) 290 | 291 | try: 292 | completion = self._query_model(messages, request_timeout) 293 | except Exception as e: 294 | logger.warning( 295 | f"The completion for the text '{x}' raised an exception: {e}" 296 | ) 297 | return AnnotatedDocumentWithException( 298 | text=x, annotations=set(), exception=e 299 | ) 300 | logger.debug( 301 | f"Human message: {human_msg_string} \n 
Completion: {completion}" 302 | ) 303 | 304 | annotated_document = AnnotatedDocument(text=x, annotations=set()) 305 | if self.answer_shape == "json": 306 | annotated_document = json_annotation_to_annotated_document( 307 | completion.content, list(self.entities.keys()), x 308 | ) 309 | elif self.answer_shape == "inline": 310 | if self.multi_turn_delimiters: 311 | annotated_document = ( 312 | inline_special_tokens_annotation_to_annotated_document( 313 | completion.content, entity, self.start_token, self.end_token 314 | ) 315 | ) 316 | else: 317 | annotated_document = inline_annotation_to_annotated_document( 318 | completion.content, list(self.entities.keys()) 319 | ) 320 | aligned_annotated_document = align_annotation(x, annotated_document) 321 | annotated_documents.append(aligned_annotated_document) 322 | if self.answer_shape == "inline": 323 | chat_template = ChatPromptTemplate.from_messages( 324 | messages=messages 325 | + [ 326 | AIMessage( 327 | content=annotated_document_to_inline_annotated_string( 328 | aligned_annotated_document, 329 | custom_delimiters=self.multi_turn_delimiters, 330 | ) 331 | ), 332 | HumanMessagePromptTemplate.from_template("{x}"), 333 | ] 334 | ) 335 | elif self.answer_shape == "json": 336 | chat_template = ChatPromptTemplate.from_messages( 337 | messages=messages 338 | + [ 339 | AIMessage( 340 | content=annotated_document_to_json_annotated_string( 341 | aligned_annotated_document 342 | ) 343 | ), 344 | HumanMessagePromptTemplate.from_template("{x}"), 345 | ] 346 | ) 347 | else: 348 | raise ValueError("The answer shape is not valid") 349 | 350 | if self.final_message_with_all_entities == False: 351 | final_annotated_document = annotated_documents[0] 352 | for annotated_document in annotated_documents[1:]: 353 | final_annotated_document.annotations.update(annotated_document.annotations) 354 | else: 355 | messages.append( 356 | HumanMessage( 357 | content = f"{self.prompt_template.final_message_prefix.format(entity_list=list(self.entities.keys()))}: {x}" 358 | ) 359 | ) 360 | try: 361 | completion = self._query_model(messages, request_timeout) 362 | except Exception as e: 363 | logger.warning( 364 | f"The completion for the text '{x}' raised an exception: {e}" 365 | ) 366 | return AnnotatedDocumentWithException( 367 | text=x, annotations=set(), exception=e 368 | ) 369 | if self.answer_shape == "json": 370 | last_annotated_document = json_annotation_to_annotated_document( 371 | completion.content, list(self.entities.keys()), x 372 | ) 373 | elif self.answer_shape == "inline": 374 | if self.multi_turn_delimiters: 375 | raise ValueError("The final message with all entities is not supported with custom delimiters") 376 | else: 377 | last_annotated_document = inline_annotation_to_annotated_document( 378 | completion.content, list(self.entities.keys()) 379 | ) 380 | final_annotated_document = align_annotation(x, last_annotated_document) 381 | 382 | return final_annotated_document 383 | 384 | def _predict_tokenized( 385 | self, x: List[str], request_timeout: int, only_return_labels: bool = False 386 | ) -> Conll | List[Label]: 387 | detokenized_text = detokenizer(x) 388 | annotated_document = AnnotatedDocument(text=detokenized_text, annotations=set()) 389 | if self.prompting_method == "single_turn": 390 | annotated_document = self._predict(detokenized_text, request_timeout) 391 | elif self.prompting_method == "multi_turn": 392 | annotated_document = self._predict_multi_turn( 393 | detokenized_text, request_timeout 394 | ) 395 | if isinstance(annotated_document, 
AnnotatedDocumentWithException): 396 | logger.warning( 397 | f"The completion for the text '{detokenized_text}' raised an exception: {annotated_document.exception}" 398 | ) 399 | conll = annotated_document_to_conll( 400 | annotated_document, only_return_labels=only_return_labels 401 | ) 402 | if not len(x) == len(conll): 403 | logger.warning( 404 | "The number of tokens and the number of conll tokens are different" 405 | ) 406 | return conll 407 | 408 | def _predict_parallel( 409 | self, x: List[str], max_workers: int, progress_bar: bool, request_timeout: int 410 | ) -> List[AnnotatedDocument | AnnotatedDocumentWithException]: 411 | y = [] 412 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 413 | if self.prompting_method == "single_turn": 414 | for annotated_document in tqdm( 415 | executor.map(lambda x: self._predict(x, request_timeout), x), 416 | disable=not progress_bar, 417 | unit=" example", 418 | total=len(x), 419 | ): 420 | y.append(annotated_document) 421 | 422 | elif self.prompting_method == "multi_turn": 423 | for annotated_document in tqdm( 424 | executor.map( 425 | lambda x: self._predict_multi_turn(x, request_timeout), x 426 | ), 427 | disable=not progress_bar, 428 | unit=" example", 429 | total=len(x), 430 | ): 431 | y.append(annotated_document) 432 | return y 433 | 434 | def _predict_tokenized_parallel( 435 | self, 436 | x: List[List[str]], 437 | max_workers: int, 438 | progress_bar: bool, 439 | request_timeout: int, 440 | only_return_labels: bool = False, 441 | ) -> List[Conll] | List[List[Label]]: 442 | y = [] 443 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 444 | for conll in tqdm( 445 | executor.map( 446 | lambda x: self._predict_tokenized( 447 | x, request_timeout, only_return_labels 448 | ), 449 | x, 450 | ), 451 | disable=not progress_bar, 452 | unit=" example", 453 | total=len(x), 454 | ): 455 | y.append(conll) 456 | return y 457 | 458 | def _predict_serial( 459 | self, x: List[str], progress_bar: bool, request_timeout: int 460 | ) -> List[AnnotatedDocument | AnnotatedDocumentWithException]: 461 | y = [] 462 | for text in tqdm(x, disable=not progress_bar, unit=" example"): 463 | annotated_document = AnnotatedDocument(text=text, annotations=set()) 464 | if self.prompting_method == "single_turn": 465 | annotated_document = self._predict(text, request_timeout) 466 | elif self.prompting_method == "multi_turn": 467 | annotated_document = self._predict_multi_turn(text, request_timeout) 468 | y.append(annotated_document) 469 | return y 470 | 471 | def _predict_tokenized_serial( 472 | self, 473 | x: List[List[str]], 474 | progress_bar: bool, 475 | request_timeout: int, 476 | only_return_labels: bool = False, 477 | ) -> List[Conll] | List[List[Label]]: 478 | y = [] 479 | for tokenized_text in tqdm(x, disable=not progress_bar, unit=" example"): 480 | conll = self._predict_tokenized( 481 | tokenized_text, request_timeout, only_return_labels 482 | ) 483 | y.append(conll) 484 | return y 485 | 486 | def predict( 487 | self, 488 | x: List[str], 489 | progress_bar: bool = True, 490 | max_workers: int = 1, 491 | request_timeout: int = 600, 492 | ) -> List[AnnotatedDocument | AnnotatedDocumentWithException]: 493 | """Method to perform NER on a list of strings. 494 | 495 | Args: 496 | x (List[str]): List of strings. 497 | progress_bar (bool, optional): If True, a progress bar will be displayed. Defaults to True. 498 | max_workers (int, optional): Number of workers to use for parallel processing. 
If -1, the number of workers will be equal to the number of CPU cores. Defaults to 1. 499 | request_timeout (int, optional): Timeout in seconds for the requests. Defaults to 600 seconds. 500 | 501 | Raises: 502 | NotContextualizedError: Error if the model is not contextualized before calling the predict method. 503 | ValueError: The input must be a list of strings. 504 | 505 | Returns: 506 | List[AnnotatedDocument | AnnotatedDocumentWithException]: List where each element is an AnnotatedDocument, or an AnnotatedDocumentWithException if the request for that document raised an exception. 507 | """ 508 | if self.chat_template is None: 509 | raise NotContextualizedError( 510 | "You must call the contextualize method before calling the predict method" 511 | ) 512 | if not isinstance(x, list): 513 | raise ValueError("x must be a list") 514 | if isinstance(x[0], str): 515 | if max_workers == -1: 516 | y = self._predict_parallel(x, CPU_COUNT, progress_bar, request_timeout) 517 | elif max_workers == 1: 518 | y = self._predict_serial(x, progress_bar, request_timeout) 519 | elif max_workers > 1: 520 | y = self._predict_parallel( 521 | x, max_workers, progress_bar, request_timeout 522 | ) 523 | else: 524 | raise ValueError("max_workers must be -1 or greater than 0") 525 | else: 526 | raise ValueError( 527 | "x must be a list of strings, maybe you want to use predict_tokenized instead?" 528 | ) 529 | return y 530 | 531 | def predict_tokenized( 532 | self, 533 | x: List[List[str]], 534 | progress_bar: bool = True, 535 | max_workers: int = 1, 536 | request_timeout: int = 600, 537 | only_return_labels: bool = False, 538 | ) -> List[Conll] | List[List[str]]: 539 | """Method to perform NER on a list of tokenized documents. 540 | 541 | Args: 542 | x (List[List[str]]): List of lists of tokens. 543 | progress_bar (bool, optional): If True, a progress bar will be displayed. Defaults to True. 544 | max_workers (int, optional): Number of workers to use for parallel processing. If -1, the number of workers will be equal to the number of CPU cores. Defaults to 1. 545 | request_timeout (int, optional): Timeout in seconds for the requests. Defaults to 600 seconds. 546 | only_return_labels (bool, optional): If True, only the labels will be returned. Defaults to False. 547 | 548 | Returns: 549 | List[Conll] | List[List[str]]: List of Conll objects if only_return_labels is False, a list of lists of labels if only_return_labels is True. 550 | """ 551 | if not isinstance(x, list): 552 | raise ValueError("x must be a list") 553 | if isinstance(x[0], list): 554 | if max_workers == -1: 555 | y = self._predict_tokenized_parallel( 556 | x, CPU_COUNT, progress_bar, request_timeout, only_return_labels 557 | ) 558 | elif max_workers == 1: 559 | y = self._predict_tokenized_serial( 560 | x, progress_bar, request_timeout, only_return_labels 561 | ) 562 | elif max_workers > 1: 563 | y = self._predict_tokenized_parallel( 564 | x, max_workers, progress_bar, request_timeout, only_return_labels 565 | ) 566 | else: 567 | raise ValueError("max_workers must be -1 or greater than 0") 568 | else: 569 | raise ValueError( 570 | "x must be a list of lists of tokens, maybe you want to use predict instead?" 571 | ) 572 | return y 573 | 574 | 575 | class FewShotNer(ZeroShotNer): 576 | def contextualize( 577 | self, entities: Dict[str, str], examples: List[AnnotatedDocument] 578 | ): 579 | """Method to contextualize the few-shot NER model. You need examples to contextualize this model. 
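The examples are rendered as few-shot chat messages that precede the target document in the prompt.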
580 | 581 | Args: 582 | entities (Dict[str, str]): Dict containing the entities to be recognized. The keys are the entity names and the values are the entity descriptions. 583 | examples (List[AnnotatedDocument]): List of AnnotatedDocument objects containing the annotated examples. 584 | """ 585 | self.entities = entities 586 | if self.multi_turn_delimiters: 587 | self.system_message = self.system_template.format( 588 | entities=dict_to_enumeration(entities), 589 | entity_list=list(entities.keys()), 590 | start_token=self.start_token, 591 | end_token=self.end_token, 592 | ) 593 | else: 594 | self.system_message = self.system_template.format( 595 | entities=dict_to_enumeration(entities), 596 | entity_list=list(entities.keys()), 597 | ) 598 | example_template = ChatPromptTemplate.from_messages( 599 | [("human", "{input}"), ("ai", "{output}")] 600 | ) 601 | 602 | if self.prompting_method == "multi_turn": 603 | few_shot_templates = [] 604 | for example in examples: 605 | few_shot_example = annotated_document_to_multi_turn_few_shot_example( 606 | annotated_document=example, 607 | multi_turn_prefix=self.multi_turn_prefix, 608 | answer_shape=self.answer_shape, # type: ignore 609 | entity_set=list(self.entities.keys()), 610 | custom_delimiters=self.multi_turn_delimiters, 611 | final_message_with_all_entities=self.final_message_with_all_entities, 612 | final_message_prefix=self.prompt_template.final_message_prefix, 613 | ) 614 | template = FewShotChatMessagePromptTemplate( 615 | examples=few_shot_example, 616 | example_prompt=example_template, 617 | ) 618 | few_shot_templates.append(template) 619 | if self.augment_with_pos: 620 | few_shot_template = [] 621 | for template, example in zip(few_shot_templates, examples): 622 | few_shot_template.append( 623 | HumanMessage( 624 | content=f"{self.prompt_template.pos_answer_prefix} {self._predict_pos(example.text, 600)}" 625 | ) 626 | ) 627 | few_shot_template.append(template) 628 | else: 629 | few_shot_template = few_shot_templates 630 | else: 631 | if self.augment_with_pos: 632 | example_template = ChatPromptTemplate.from_messages( 633 | [ 634 | ("human", f"{self.prompt_template.pos_answer_prefix} {{pos}}"), 635 | ("human", "{input}"), 636 | ("ai", "{output}"), 637 | ] 638 | ) 639 | few_shot_examples = [] 640 | for example in examples: 641 | few_shot_example = annotated_document_to_single_turn_few_shot_example(example, answer_shape=self.answer_shape) # type: ignore 642 | few_shot_example["pos"] = self._predict_pos(example.text, 600) 643 | few_shot_examples.append(few_shot_example) 644 | few_shot_template = [ 645 | FewShotChatMessagePromptTemplate( 646 | examples=few_shot_examples, 647 | example_prompt=example_template, 648 | ) 649 | ] 650 | 651 | else: 652 | few_shot_template = [ 653 | FewShotChatMessagePromptTemplate( 654 | examples=list( 655 | map( 656 | lambda x: annotated_document_to_single_turn_few_shot_example( 657 | x, answer_shape=self.answer_shape # type: ignore 658 | ), 659 | examples, 660 | ) 661 | ), 662 | example_prompt=example_template, 663 | ) 664 | ] 665 | 666 | if self.augment_with_pos: 667 | messages = ( 668 | [self.system_message] 669 | + few_shot_template 670 | + [ 671 | HumanMessagePromptTemplate.from_template( 672 | f"{self.prompt_template.pos_answer_prefix} {{pos}}" 673 | ), 674 | HumanMessagePromptTemplate.from_template("{x}"), 675 | ] 676 | ) 677 | self.chat_template = ChatPromptTemplate.from_messages(messages) 678 | else: 679 | messages = ( 680 | [self.system_message] 681 | + few_shot_template 682 | + 
[HumanMessagePromptTemplate.from_template("{x}")] 683 | ) 684 | 685 | self.chat_template = ChatPromptTemplate.from_messages(messages) 686 | -------------------------------------------------------------------------------- /llmner/templates.py: -------------------------------------------------------------------------------- 1 | from llmner.data import PromptTemplate 2 | 3 | TEMPLATE_EN = PromptTemplate( 4 | inline_single_turn="""You are a named entity recognizer that must detect the following entities: 5 | {entities} 6 | You must answer with the same input text, but with the named entities annotated with in-line tag annotations (<tag>text</tag>), where each tag corresponds to an entity name, for example: <name>John Doe</name> is the owner of <organization>ACME</organization>. 7 | The only available tags are: {entity_list}, you cannot add more tags than the ones included in that list. 8 | IMPORTANT: YOU SHOULD NOT CHANGE THE INPUT TEXT, ONLY ADD THE TAGS.""", 9 | inline_multi_turn_default_delimiters="""You are a named entity recognizer that must detect the following entities: 10 | {entities} 11 | You must answer with the same input text, but with a single entity annotated with in-line tag annotations (<tag>text</tag>), where the tag corresponds to an entity name, for example, first I ask you to annotate names: <name>John Doe</name> is the owner of ACME and then I ask you to annotate organizations: John Doe is the owner of <organization>ACME</organization>. 12 | The only available tags are: {entity_list}, you cannot add more tags than the ones included in that list. 13 | IMPORTANT: YOU SHOULD NOT CHANGE THE INPUT TEXT, ONLY ADD THE TAGS""", 14 | inline_multi_turn_custom_delimiters="""You are a named entity recognizer that must detect the following entities: 15 | {entities} 16 | You must answer with the same input text, but with a single entity annotated with in-line tag annotations ({start_token}text{end_token}), where the tag corresponds to an entity name, for example, first I ask you to annotate names: {start_token}John Doe{end_token} is the owner of ACME and then I ask you to annotate organizations: John Doe is the owner of {start_token}ACME{end_token}. 17 | The only available tags are: {entity_list}, you cannot add more tags than the ones included in that list. 18 | IMPORTANT: YOU SHOULD NOT CHANGE THE INPUT TEXT, ONLY ADD THE TAGS""", 19 | json_single_turn="""You are a named entity recognizer that must detect the following entities: 20 | {entities} 21 | You must answer with JSON format, where each key corresponds to an entity class, and the value is a list of the entity mentions, for example: {{"name": ["John Doe"], "organization": ["ACME"]}}. 22 | The only available tags are: {entity_list}, you cannot add more tags than the ones included in that list. 23 | IMPORTANT: YOUR OUTPUT SHOULD ONLY BE A JSON IN THE FORMAT {{"entity_class": ["entity_mention_1", "entity_mention_2"]}}. NO OTHER FORMAT IS ALLOWED.""", 24 | json_multi_turn="""You are a named entity recognizer that must detect the following entities: 25 | {entities} 26 | You must answer with JSON format, annotating a single entity class per turn, where the key corresponds to that entity class, for example, first I ask you to annotate names: {{"name": ["John Doe"]}} and then I ask you to annotate organizations: {{"organization": ["ACME"]}} 27 | The only available tags are: {entity_list}, you cannot add more tags than the ones included in that list. 28 | IMPORTANT: YOUR OUTPUT SHOULD ONLY BE A JSON IN THE FORMAT {{"entity_class": ["entity_mention_1", "entity_mention_2"]}}. 
NO OTHER FORMAT IS ALLOWED.""", 29 | multi_turn_prefix="""In the following text, annotate the entity """, 30 | pos="""You are a part-of-speech tagger that must detect part-of-speech tags. Respond with the same input text, but with the part-of-speech tags after each word, for example: John/NNP Doe/NNP is/VBZ the/DT owner/NN of/IN ACME/NNP.""", 31 | pos_answer_prefix="""This is the text with the part-of-speech tags:""", 32 | final_message_prefix="""Now, annotate the following document with all entities ({entity_list}):""" 33 | ) 34 | -------------------------------------------------------------------------------- /llmner/utils.py: -------------------------------------------------------------------------------- 1 | from llmner.data import ( 2 | Annotation, 3 | AnnotatedDocument, 4 | Conll, 5 | Label, 6 | NotPerfectlyAlignedError, 7 | AnnotatedDocumentWithException, 8 | ) 9 | from difflib import SequenceMatcher 10 | from copy import deepcopy 11 | import re 12 | from nltk.tokenize import TreebankWordTokenizer as twt 13 | from nltk.tokenize.treebank import TreebankWordDetokenizer as twd 14 | from typing import List, Tuple, Literal, Union 15 | import logging 16 | import json 17 | from collections import defaultdict 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def dict_to_enumeration(d: dict) -> str: 23 | enumeration = "" 24 | for key, value in d.items(): 25 | enumeration += f"- {key}: {value}\n" 26 | return enumeration.strip() 27 | 28 | 29 | def inline_annotation_to_annotated_document( 30 | inline_annotation: str, entity_set: List[str] 31 | ) -> AnnotatedDocument: 32 | annotations = set() 33 | offset = 0 34 | entities_pattern = [rf"<(.*?)>(.*?)</\1>"] 35 | all_matches = [ 36 | re.finditer(entity_pattern, inline_annotation) 37 | for entity_pattern in entities_pattern 38 | ] 39 | all_matches = [match for matches in all_matches for match in matches] 40 | # sort all matches by start position in order to change the offset correctly 41 | all_matches = sorted(all_matches, key=lambda x: x.start()) 42 | for match in all_matches: 43 | match_offset = len(match.group(0)) - len(match.group(2)) 44 | start = match.start() - offset 45 | end = match.end() - offset - match_offset 46 | offset += match_offset 47 | # getting the entity name 48 | entity_name = match.group(1) 49 | # add the annotation to the set of annotations 50 | if (entity_name in entity_set) | (len(entity_set) == 0): 51 | annotations.add(Annotation(start, end, entity_name, text=match.group(2))) 52 | for match in all_matches: 53 | inline_annotation = inline_annotation.replace(match.group(0), match.group(2)) 54 | annotated_document = AnnotatedDocument( 55 | text=inline_annotation, annotations=annotations 56 | ) 57 | return annotated_document 58 | 59 | 60 | def inline_special_tokens_annotation_to_annotated_document( 61 | inline_annotation: str, 62 | entity: str, 63 | start_pattern: str, 64 | end_pattern: str, 65 | ) -> AnnotatedDocument: 66 | annotations = set() 67 | offset = 0 68 | entities_pattern = [rf"{start_pattern}(.*?){end_pattern}"] 69 | all_matches = [ 70 | re.finditer(entity_pattern, inline_annotation) 71 | for entity_pattern in entities_pattern 72 | ] 73 | all_matches = [match for matches in all_matches for match in matches] 74 | # sort all matches by start position in order to change the offset correctly 75 | all_matches = sorted(all_matches, key=lambda x: x.start()) 76 | for match in all_matches: 77 | match_offset = len(match.group(0)) - len(match.group(1)) 78 
| start = match.start() - offset 79 | end = match.end() - offset - match_offset 80 | offset += match_offset 81 | entity_name = entity 82 | annotations.add(Annotation(start, end, entity_name, text=match.group(1))) 83 | for match in all_matches: 84 | inline_annotation = inline_annotation.replace(match.group(0), match.group(1)) 85 | annotated_document = AnnotatedDocument( 86 | text=inline_annotation, annotations=annotations 87 | ) 88 | return annotated_document 89 | 90 | 91 | def parse_json(text) -> dict: 92 | text = "".join(c for c in text if c.isprintable()) 93 | text = text.replace("{\n", "{") 94 | text = text.replace("}\n", "}") 95 | text = re.sub(r"'([^\"']+)'", r'"\1"', text) # all pairs as doublequote 96 | # text = re.sub(r"'([^\"']+)':", r'"\1":', text) # keys as doublequote 97 | # text = re.sub(r'"([^\'"]+)":', r"'\1':", text) # keys as singlequote 98 | # text = text.replace("'", '"') 99 | # text = text.replace("\'", '"') 100 | text = text.replace("\\", "") 101 | start_brace = text.find("{") 102 | if start_brace >= 0: 103 | obj_text = text[start_brace:] 104 | nesting = ["}"] 105 | cleaned = "{" 106 | in_string = False 107 | i = 1 108 | while i < len(obj_text) and len(nesting) > 0: 109 | ch = obj_text[i] 110 | if in_string: 111 | cleaned += ch 112 | if ch == "\\": 113 | i += 1 114 | if i < len(obj_text): 115 | cleaned += obj_text[i] 116 | else: 117 | return {} 118 | elif ch == '"': 119 | in_string = False 120 | else: 121 | if ch == '"': 122 | in_string = True 123 | elif ch == "{": 124 | nesting.append("}") 125 | elif ch == "[": 126 | nesting.append("]") 127 | elif ch == "}": 128 | close_object = nesting.pop() 129 | if close_object != "}": 130 | return {} 131 | elif ch == "]": 132 | close_array = nesting.pop() 133 | if close_array != "]": 134 | return {} 135 | elif ch == "<": 136 | ch = '"<' 137 | elif ch == ">": 138 | ch = '>"' 139 | cleaned += ch 140 | i += 1 141 | 142 | if len(nesting) > 0: 143 | cleaned += "".join(reversed(nesting)) 144 | 145 | obj = json.loads(cleaned) 146 | return obj 147 | else: 148 | return {} 149 | 150 | 151 | def json_annotation_to_annotated_document( 152 | json_annotation_str: str, entity_set: List[str], original_text: str 153 | ) -> AnnotatedDocument: 154 | text = original_text 155 | annotations = set() 156 | try: 157 | json_annotation = parse_json(json_annotation_str) 158 | except Exception as e: 159 | logger.warning( 160 | f"Failed to parse json annotation: {json_annotation_str} because {e}" 161 | ) 162 | json_annotation = {} 163 | # Check if the json annotation is in the form {"entity_name": ["entity_mention", "entity_mention", ...]} 164 | fixed_json_annotation = {} 165 | if isinstance(json_annotation, dict): 166 | for key, value in json_annotation.items(): 167 | if isinstance(value, list): 168 | for entity_mention in value: 169 | if isinstance(entity_mention, str): 170 | fixed_json_annotation[key] = value 171 | 172 | for entity_name, entity_mentions in fixed_json_annotation.items(): 173 | for entity_mention in entity_mentions: 174 | matches = list(re.finditer(re.escape(entity_mention), text)) 175 | if len(matches) == 0: 176 | logger.warning(f"Found 0 matches for {entity_mention} in {text}.") 177 | if len(matches) == 1: 178 | start = matches[0].start() 179 | if start != -1 and entity_name in entity_set: 180 | end = matches[0].end() 181 | annotations.add( 182 | Annotation(start, end, entity_name, text=entity_mention) 183 | ) 184 | elif len(matches) > 1: 185 | logger.warning( 186 | f"Found {len(matches)} matches for {entity_mention} in {text}. 
The first match will be used." 187 | ) 188 | start = matches[0].start() 189 | if start != -1 and entity_name in entity_set: 190 | end = matches[0].end() 191 | annotations.add( 192 | Annotation(start, end, entity_name, text=entity_mention) 193 | )
194 | return AnnotatedDocument(text=text, annotations=annotations) 195 | 196 |
197 | def align_annotation( 198 | original_text: str, chatgpt_annotated_document: AnnotatedDocument 199 | ) -> AnnotatedDocument | AnnotatedDocumentWithException:
200 | fixed_annotation = deepcopy(chatgpt_annotated_document) 201 | a = chatgpt_annotated_document.text 202 | b = original_text 203 |
204 | total_difs = [ 205 | (tag, i1, i2, j1, j2, a[i1:i2], b[j1:j2]) 206 | for tag, i1, i2, j1, j2 in SequenceMatcher(None, a, b).get_opcodes() 207 | ] 208 |
209 | replace_difs = [dif for dif in total_difs if dif[0] == "replace"] 210 | 211 | # fix the replace difs
212 | for dif in replace_difs: 213 | a = dif[5] 214 | b = dif[6]
215 | new_entity_difs = [ 216 | ( 217 | tag, 218 | i1 + dif[1], 219 | i2 + dif[1], 220 | j1 + dif[3], 221 | j2 + dif[3], 222 | a[i1:i2], 223 | b[j1:j2], 224 | ) 225 | for tag, i1, i2, j1, j2 in SequenceMatcher(None, a, b).get_opcodes() 226 | ]
227 | total_difs.remove(dif) 228 | total_difs += new_entity_difs 229 |
230 | for entity in fixed_annotation.annotations: 231 | difs = [dif for dif in total_difs if dif[1] <= entity.start] 232 | offset = sum([(dif[4] - dif[3]) - (dif[2] - dif[1]) for dif in difs]) 233 | entity.start += offset 234 | entity.end += offset 235 |
236 | fixed_annotation.text = original_text 237 |
238 | # remove annotations that do not exist in the original text
239 | # because the model sometimes adds or modifies text 240 |
241 | fixed_annotations_2 = list(fixed_annotation.annotations) 242 |
243 | perfect_align = True 244 | removed_annotations = []
245 | for annotation in fixed_annotations_2.copy():
246 | if annotation.text not in original_text: # type: ignore
247 | logger.warning( 248 | f"The text cannot be perfectly aligned: {annotation} was removed because the string is not in the text." 249 | )
250 | perfect_align = False 251 | removed_annotations.append(annotation) 252 | fixed_annotations_2.remove(annotation)
253 | elif (annotation.start < 0) | (annotation.end < 0):
254 | # check if annotation.text occurs exactly once in original_text
255 | if original_text.count(annotation.text) == 1: # type: ignore
256 | # find annotation.text indices in original_text
257 | start = original_text.find(annotation.text) # type: ignore
258 | end = start + len(annotation.text) # type: ignore
259 | # update the annotation indices in fixed_annotations_2
260 | for annotation_2 in fixed_annotations_2: 261 | if annotation_2 == annotation: 262 | annotation_2.start = start 263 | annotation_2.end = end
264 | else: 265 | logger.warning( 266 | f"The text cannot be perfectly aligned: {annotation} was removed because the string was found multiple times."
267 | ) 268 | perfect_align = False 269 | removed_annotations.append(annotation) 270 | fixed_annotations_2.remove(annotation)
271 | fixed_annotation.annotations = set(fixed_annotations_2) 272 |
273 | if perfect_align: 274 | if chatgpt_annotated_document.text != original_text: 275 | logger.info( 276 | f"The text was aligned: {chatgpt_annotated_document.text} -> {original_text}" 277 | ) 278 | return fixed_annotation
279 | else: 280 | return AnnotatedDocumentWithException( 281 | text=fixed_annotation.text, 282 | annotations=fixed_annotation.annotations, 283 | exception=NotPerfectlyAlignedError( 284 | "The text cannot be perfectly aligned", 285 | removed_annotations=removed_annotations, 286 | completion_text=chatgpt_annotated_document.text, 287 | ), 288 | ) 289 | 290 |
291 | def conll_to_inline_annotated_string(conll: Conll) -> str: 292 | annotated_string = "" 293 | current_entity = None 294 |
295 | for token, label in conll: 296 | if label.startswith("B-"): 297 | if current_entity: 298 | annotated_string = annotated_string[:-1] 299 | annotated_string += f"</{current_entity}> "
300 | entity_class = label[2:] 301 | annotated_string += f"<{entity_class}>{token}" 302 | current_entity = entity_class
303 | elif label.startswith("I-"): 304 | if current_entity: 305 | annotated_string += f"{token}"
306 | else: 307 | if current_entity: 308 | annotated_string = annotated_string[:-1] 309 | annotated_string += f"</{current_entity}> {token}" 310 | current_entity = None
311 | else: 312 | annotated_string += f"{token}" 313 | annotated_string += " " 314 |
315 | if current_entity: 316 | annotated_string = annotated_string[:-1] 317 | annotated_string += f"</{current_entity}>" 318 |
319 | return annotated_string.strip() 320 | 321 |
322 | def conll_to_annotated_document(conll: Conll) -> AnnotatedDocument: 323 | return inline_annotation_to_annotated_document( 324 | conll_to_inline_annotated_string(conll), entity_set=[] 325 | ) 326 | 327 |
328 | def annotated_document_to_conll( 329 | annotated_document: AnnotatedDocument, 330 | only_return_labels: bool = False, 331 | ) -> Conll | List[Label]:
332 | spans = list(twt().span_tokenize(annotated_document.text))
333 | tokens = [annotated_document.text[span[0] : span[1]] for span in spans]
334 | boundaries = [span[0] for span in spans] 335 | conll = ["O"] * len(spans)
336 | for annotation in annotated_document.annotations: 337 | start = annotation.start 338 | end = annotation.end 339 | label = annotation.label 340 | start_idx = 0
341 | for i, boundary in enumerate(boundaries): 342 | if start == boundary: 343 | start_idx = i 344 | break 345 | elif start < boundary: 346 | start_idx = i 347 | break
348 | end_idx = start_idx
349 | for i, boundary in list(enumerate(boundaries))[start_idx + 1 :]: 350 | if end == boundary: 351 | end_idx = i - 1 352 | break 353 | elif end < boundary: 354 | end_idx = i - 1 355 | break
356 | for i in range(start_idx, end_idx + 1): 357 | if i == start_idx: 358 | conll[i] = f"B-{label}" 359 | else: 360 | conll[i] = f"I-{label}"
361 | if only_return_labels: 362 | result = conll 363 | else: 364 | result = list(zip(tokens, conll)) 365 | return result 366 | 367 |
368 | def annotated_document_to_inline_annotated_string( 369 | annotated_document: AnnotatedDocument, 370 | custom_delimiters: Union[Tuple[str, str], None] = None, 371 | ) -> str:
372 | annotated_document = deepcopy(annotated_document) 373 | inline_annotated_string = annotated_document.text
374 | annotations = sorted(annotated_document.annotations, key=lambda x: x.start)
375 | for i in range(len(annotations)): 376 | annotation = annotations[i] 377 |
start = annotation.start 378 | end = annotation.end 379 | label = annotation.label 380 | text = inline_annotated_string[start:end]
381 | if custom_delimiters: 382 | inline_annotation = f"{custom_delimiters[0]}{text}{custom_delimiters[1]}"
383 | else: 384 | inline_annotation = f"<{label}>{text}</{label}>"
385 | inline_annotated_string = ( 386 | inline_annotated_string[:start] 387 | + inline_annotation 388 | + inline_annotated_string[end:] 389 | )
390 | for j in range(i, len(annotations)): 391 | annotations[j].start += len(inline_annotation) - len(text) 392 | annotations[j].end += len(inline_annotation) - len(text) 393 |
394 | return inline_annotated_string 395 | 396 |
397 | def annotated_document_to_json_annotated_string( 398 | annotated_document: AnnotatedDocument, 399 | ):
400 | annotations = defaultdict(list) 401 | for annotation in annotated_document.annotations: 402 | annotations[annotation.label].append( 403 | annotated_document.text[annotation.start : annotation.end] 404 | )
405 | return json.dumps(annotations, ensure_ascii=False) 406 | 407 |
408 | def annotated_document_to_single_turn_few_shot_example( 409 | annotated_document: AnnotatedDocument, 410 | answer_shape: Literal["inline", "json"] = "inline", 411 | custom_delimiters: Union[Tuple[str, str], None] = None, 412 | ) -> dict:
413 | if answer_shape == "inline": 414 | annotated_string = annotated_document_to_inline_annotated_string( 415 | annotated_document, 416 | custom_delimiters=custom_delimiters, 417 | )
418 | elif answer_shape == "json": 419 | annotated_string = annotated_document_to_json_annotated_string( 420 | annotated_document 421 | )
422 | else: 423 | raise ValueError( 424 | f"answer_shape should be 'inline' or 'json', but {answer_shape} was given." 425 | )
426 | return {"input": annotated_document.text, "output": annotated_string} 427 | 428 |
429 | def annotated_document_to_multi_turn_few_shot_example( 430 | annotated_document: AnnotatedDocument, 431 | multi_turn_prefix: str, 432 | final_message_prefix: str, 433 | answer_shape: Literal["inline", "json"] = "inline", 434 | entity_set: List[str] = [], 435 | custom_delimiters: Union[Tuple[str, str], None] = None, 436 | final_message_with_all_entities: bool = False, 437 | ) -> List[dict]:
438 | examples = []
439 | if answer_shape == "inline": 440 | for entity in entity_set: 441 | annotations = { 442 | annotation 443 | for annotation in annotated_document.annotations 444 | if annotation.label == entity 445 | }
446 | examples.append( 447 | { 448 | "input": f"{multi_turn_prefix} {entity}: {annotated_document.text}", 449 | "output": annotated_document_to_inline_annotated_string( 450 | AnnotatedDocument( 451 | text=annotated_document.text, 452 | annotations=annotations, 453 | ), 454 | custom_delimiters=custom_delimiters, 455 | ), 456 | } 457 | )
458 | if final_message_with_all_entities: 459 | examples.append( 460 | { 461 | "input": f"{final_message_prefix.format(entity_list=entity_set)}: {annotated_document.text}", 462 | "output": annotated_document_to_inline_annotated_string( 463 | annotated_document, custom_delimiters=custom_delimiters 464 | ), 465 | } 466 | )
467 | elif answer_shape == "json": 468 | for entity in entity_set: 469 | annotations = { 470 | annotation 471 | for annotation in annotated_document.annotations 472 | if annotation.label == entity 473 | }
474 | examples.append( 475 | { 476 | "input": f"{multi_turn_prefix} {entity}: {annotated_document.text}", 477 | "output": annotated_document_to_json_annotated_string( 478 | AnnotatedDocument( 479 | text=annotated_document.text,
480 | annotations=annotations, 481 | ) 482 | ), 483 | } 484 | ) 485 | if final_message_with_all_entities: 486 | examples.append( 487 | { 488 | "input": f"{final_message_prefix.format(entity_list=entity_set)}: {annotated_document.text}", 489 | "output": annotated_document_to_json_annotated_string( 490 | annotated_document 491 | ), 492 | } 493 | ) 494 | else: 495 | raise ValueError( 496 | f"answer_shape should be 'inline' or 'json', but {answer_shape} was given." 497 | ) 498 | return examples 499 | 500 | 501 | def annotated_document_to_multi_turn_chat( 502 | annotated_document: AnnotatedDocument, 503 | entity: str, 504 | parsing_method: str, 505 | human_msg: str, 506 | ): 507 | if parsing_method == "inline": 508 | inline_annotated_string = annotated_document_to_inline_annotated_string( 509 | annotated_document 510 | ) 511 | return {"input": human_msg, "output": inline_annotated_string} 512 | elif parsing_method == "json": 513 | json_annotation = {} 514 | json_annotation[entity] = [ 515 | annotation.text for annotation in annotated_document.annotations 516 | ] 517 | return {"input": human_msg, "output": json.dumps(json_annotation)} 518 | 519 | return {"input": "", "output": ""} 520 | 521 | 522 | def detokenizer(tokens: List[str]) -> str: 523 | return twd().detokenize(tokens) 524 | 525 | 526 | def tokenizer(text: str) -> List[str]: 527 | return twt().tokenize(text) 528 | -------------------------------------------------------------------------------- /notebooks/1-example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "\n", 11 | "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", 12 | "from llmner import ZeroShotNer" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 3, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "entities = {\n", 22 | " \"person\": \"A person name, it can include first and last names, for example: John Kennedy and Bill Gates\",\n", 23 | " \"organization\": \"An organization name, it can be a company, a government agency, etc.\",\n", 24 | " \"location\": \"A location name, it can be a city, a country, etc.\",\n", 25 | "}" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 4, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "model = ZeroShotNer(\n", 35 | " prompting_method=\"multi_turn\",\n", 36 | " answer_shape=\"inline\",\n", 37 | " final_message_with_all_entities=True,\n", 38 | ")\n", 39 | "model.contextualize(entities=entities)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 5, 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stderr", 49 | "output_type": "stream", 50 | "text": [ 51 | "100%|██████████| 1/1 [00:04<00:00, 4.39s/ example]\n" 52 | ] 53 | }, 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "[AnnotatedDocument(text='Pedro Pereira is the president of Perú and the owner of Walmart.', annotations={Annotation(start=34, end=38, label='location', text='Perú'), Annotation(start=56, end=63, label='organization', text='Walmart'), Annotation(start=0, end=13, label='person', text='Pedro Pereira')})]" 58 | ] 59 | }, 60 | "execution_count": 5, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "model.predict([\"Pedro Pereira is the president of Perú and the owner of Walmart.\"])" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 
72 | "display_name": "Python 3", 73 | "language": "python", 74 | "name": "python3" 75 | }, 76 | "language_info": { 77 | "codemirror_mode": { 78 | "name": "ipython", 79 | "version": 3 80 | }, 81 | "file_extension": ".py", 82 | "mimetype": "text/x-python", 83 | "name": "python", 84 | "nbconvert_exporter": "python", 85 | "pygments_lexer": "ipython3", 86 | "version": "3.12.0" 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 2 91 | } 92 | -------------------------------------------------------------------------------- /notebooks/2-conll.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "from llmner import ZeroShotNer, FewShotNer\n", 11 | "from datasets import load_dataset\n", 12 | "from seqeval.metrics import classification_report\n", 13 | "from llmner.data import PromptTemplate\n", 14 | "import json\n", 15 | "import numpy as np\n", 16 | "\n", 17 | "# We change api base for deepinfra \n", 18 | "os.environ[\"OPENAI_API_BASE\"] = \"https://api.deepinfra.com/v1/openai\"\n", 19 | "os.environ[\"OPENAI_API_KEY\"] = \"\"" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 4, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# Test with small data\n", 29 | "\n", 30 | "conll2003 = load_dataset(\"conll2003\", split=\"test[:5%]\")\n", 31 | "conll2002 = load_dataset(\"conll2002\", \"es\", split=\"test[:5%]\")" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 5, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# Mapping from Number to CoNLL-2003 Tag\n", 41 | "n_to_conll = {0:'O', 1:'B-PER', 2:'I-PER', 3:'B-ORG', 4:'I-ORG', 5:'B-LOC', 6:'I-LOC',7: 'B-MISC',8: 'I-MISC' }\n", 42 | "entity_set = [\"PER\", \"ORG\", \"LOC\", \"MISC\"]" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 6, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# Formatting annotations\n", 52 | "\n", 53 | "conll2003_annotations_conll = []\n", 54 | "conll2002_annotations_conll = []\n", 55 | "\n", 56 | "for i in range(len(conll2003)):\n", 57 | " tokens = conll2003[i][\"tokens\"]\n", 58 | " conll2003_annotations_conll.append([(tokens[j] , n_to_conll[conll2003[i][\"ner_tags\"][j]]) for j in range(len(tokens))])\n", 59 | "\n", 60 | "\n", 61 | "for i in range(len(conll2002)):\n", 62 | " tokens = conll2002[i][\"tokens\"]\n", 63 | " conll2002_annotations_conll.append([(tokens[j] , n_to_conll[conll2002[i][\"ner_tags\"][j]]) for j in range(len(tokens))])\n", 64 | "\n", 65 | "conll2003_annotations_seqeval = [ [annotation[j][1] for j in range(len(annotation))] for annotation in conll2003_annotations_conll]\n", 66 | "conll2002_annotations_seqeval = [ [annotation[j][1] for j in range(len(annotation))] for annotation in conll2002_annotations_conll]" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 7, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "from llmner.data import PromptTemplate\n", 76 | "\n", 77 | "template_es = PromptTemplate(\n", 78 | " inline_single_turn=\"\"\"Eres un reconocedor de entidades nombradas que debe detectar las siguientes entidades:\n", 79 | " {entities}\n", 80 | " Debes responder con el mismo texto de entrada, pero con las entidades nombradas anotadas con anotaciones de etiquetas en línea (texto), donde cada etiqueta corresponde a un nombre de entidad, por ejemplo: John Doe es el 
propietario de <organization>ACME</organization>.\n",
81 | " Las únicas etiquetas disponibles son: {entity_list}, no puedes agregar más etiquetas que las incluidas en esa lista.\n",
82 | " IMPORTANTE: NO DEBE CAMBIAR EL TEXTO DE ENTRADA, SOLO AGREGAR LAS ETIQUETAS.\"\"\",\n",
83 | " inline_multi_turn_default_delimiters=\"\"\"Eres un reconocedor de entidades nombradas que debe detectar las siguientes entidades:\n", 84 | " {entities}\n",
85 | " Debes responder con el mismo texto de entrada, pero con una sola entidad anotada con anotaciones de etiquetas en línea (<etiqueta>texto</etiqueta>), donde la etiqueta corresponde a un nombre de entidad, por ejemplo, primero te pido que anotes los nombres: <name>John Doe</name> es el propietario de ACME y luego te pido que anotes las organizaciones: John Doe es el propietario de <organization>ACME</organization>.\n",
86 | " Las únicas etiquetas disponibles son: {entity_list}, no puedes agregar más etiquetas que las incluidas en esa lista.\n",
87 | " IMPORTANTE: NO DEBE CAMBIAR EL TEXTO DE ENTRADA, SOLO AGREGAR LAS ETIQUETAS\"\"\",\n",
88 | " inline_multi_turn_custom_delimiters=\"\"\"Eres un reconocedor de entidades nombradas que debe detectar las siguientes entidades:\n", 89 | " {entities}\n",
90 | " Debes responder con el mismo texto de entrada, pero con una sola entidad anotada con anotaciones de etiquetas en línea ({start_token}texto{end_token}), donde la etiqueta corresponde a un nombre de entidad, por ejemplo, primero te pido que anotes los nombres: {start_token}John Doe{end_token} es el propietario de ACME y luego te pido que anotes las organizaciones: John Doe es el propietario de {start_token}ACME{end_token}.\n",
91 | " Las únicas etiquetas disponibles son: {entity_list}, no puedes agregar más etiquetas que las incluidas en esa lista.\n",
92 | " IMPORTANTE: NO DEBE CAMBIAR EL TEXTO DE ENTRADA, SOLO AGREGAR LAS ETIQUETAS\"\"\",\n",
93 | " json_single_turn=\"\"\"Eres un reconocedor de entidades nombradas que debe detectar las siguientes entidades:\n", 94 | " {entities}\n",
95 | " Debes responder con formato JSON, donde cada clave corresponde a una clase de entidad, y el valor es una lista de las menciones de la entidad, por ejemplo: {{\"name\": [\"John Doe\"], \"organization\": [\"ACME\"]}}.\n",
96 | " Las únicas etiquetas disponibles son: {entity_list}, no puedes agregar más etiquetas que las incluidas en esa lista.\n",
97 | " IMPORTANTE: SU SALIDA DEBE SER SOLO UN JSON EN EL FORMATO {{\"entity_class\": [\"entity_mention_1\", \"entity_mention_2\"]}}. NO SE PERMITE OTRO FORMATO.\"\"\",\n",
98 | " json_multi_turn=\"\"\"Eres un reconocedor de entidades nombradas que debe detectar las siguientes entidades:\n", 99 | " {entities}\n",
100 | " Debes responder con el mismo texto de entrada, pero con una sola entidad anotada con formato JSON, donde la clave corresponde a una clase de entidad, por ejemplo, primero te pido que anotes los nombres: {{\"name\": [\"John Doe\"]}} y luego te pido que anotes las organizaciones: {{\"organization\": [\"ACME\"]}}\n",
101 | " Las únicas etiquetas disponibles son: {entity_list}, no puedes agregar más etiquetas que las incluidas en esa lista.\n",
102 | " IMPORTANTE: SU SALIDA DEBE SER SOLO UN JSON EN EL FORMATO {{\"entity_class\": [\"entity_mention_1\", \"entity_mention_2\"]}}. NO SE PERMITE OTRO FORMATO.\"\"\",\n",
103 | " multi_turn_prefix=\"\"\"En el siguiente texto, anota la entidad \"\"\",\n",
104 | " pos=\"\"\"Eres un etiquetador de partes del discurso que debe detectar las etiquetas de partes del discurso.
Responda con el mismo texto de entrada, pero con las etiquetas de partes del discurso después de cada palabra, por ejemplo: John/NNP Doe/NNP es/VBZ el/DT propietario/NN de/IN ACME/NNP.\"\"\",\n",
105 | " pos_answer_prefix=\"\"\"Este es el texto con las etiquetas de partes del discurso:\"\"\",\n",
106 | " final_message_prefix=\"\"\"Ahora, anota el siguiente documento con todas las entidades ({entity_list}):\"\"\"\n", 107 | ")\n", 108 | "\n",
109 | "entities_en = {\n",
110 | " \"LOC\": \"roads, trajectories, regions, structures, natural locations, public places, commercial places, assorted buildings, abstract places (e.g. the free world)\",\n",
111 | " \"PER\": \"first, middle and last names of people, animals and fictional characters, aliases\",\n",
112 | " \"ORG\": \"companies, subdivisions of companies, brands, political movements, government bodies, publications, musical companies, public organisations, other collections of people\",\n",
113 | " \"MISC\": \"words of which one part is a location, organisation, miscellaneous, or person, adjectives and other words derived from a word which is location, organisation, miscellaneous, or person, religions, political ideologies, nationalities, languages, programs, events, wars, sports related names, titles, slogans, eras in time, types of objects\",\n", 114 | "}\n",
115 | "entities_es = {\n",
116 | " \"LOC\": \"carreteras, trayectorias, regiones, estructuras, lugares naturales, lugares públicos, lugares comerciales, edificios varios, lugares abstractos (por ejemplo, el mundo libre)\",\n",
117 | " \"PER\": \"nombres de personas, animales y personajes de ficción, alias\",\n",
118 | " \"ORG\": \"empresas, subdivisiones de empresas, marcas, movimientos políticos, organismos gubernamentales, publicaciones, empresas musicales, organizaciones públicas, otras colecciones de personas\",\n",
119 | " \"MISC\": \"palabras de las cuales una parte es una ubicación, organización, miscelánea o persona, adjetivos y otras palabras derivadas de una palabra que es una ubicación, organización, miscelánea o persona, religiones, ideologías políticas, nacionalidades, idiomas, programas, eventos, guerras, nombres relacionados con los deportes, títulos, eslóganes, épocas en el tiempo, tipos de objetos\",\n", 120 | "}" 121 | ] 122 | },
123 | { 124 | "cell_type": "code", 125 | "execution_count": 8, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [
129 | "# Drop annotation/prediction pairs whose lengths differ\n",
130 | "def get_different_length_annotations(annotations, predictions):\n", 131 | " annotation_filtered = []\n", 132 | " prediction_filtered = []\n", 133 | " for i in range(len(annotations)):\n", 134 | " if len(annotations[i]) == len(predictions[i]):\n", 135 | " annotation_filtered.append(annotations[i])\n", 136 | " prediction_filtered.append(predictions[i])\n", 137 | " return annotation_filtered, prediction_filtered, abs(len(annotations) - len(annotation_filtered))" 138 | ] 139 | },
140 | { 141 | "cell_type": "code", 142 | "execution_count": 9, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "name": "stderr", 147 | "output_type": "stream", 148 | "text": [ 149 | " 0%| | 0/173 [00:00=24.2.0", 18 | ],
19 | python_requires=">=3.10", 20 | extras_require={"dev": ["ipykernel==6.25.2"]},
21 | classifiers=[ 22 | "Development Status :: 1 - Planning", 23 | "Intended Audience :: Science/Research", 24 | "License :: OSI Approved :: Apache Software License", 25 | "Operating System :: POSIX :: Linux", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python ::
3.11", 28 | "Programming Language :: Python :: 3.12", 29 | ], 30 | ) 31 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from llmner import ZeroShotNer, FewShotNer 3 | from llmner.data import AnnotatedDocument, Annotation 4 | from llmner.utils import conll_to_annotated_document 5 | import logging 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | entities = { 12 | "person": "A person name, it can include first and last names, for example: Fabián Villena, Claudio or Luis Miranda", 13 | "organization": "An organization name, it can be a company, a government agency, etc.", 14 | "location": "A location name, it can be a city, a country, etc.", 15 | } 16 | 17 | examples = [ 18 | AnnotatedDocument( 19 | text="Gabriel Boric is the president of Chile", 20 | annotations={ 21 | Annotation(start=34, end=39, label="location"), 22 | Annotation(start=0, end=13, label="person"), 23 | }, 24 | ), 25 | AnnotatedDocument( 26 | text="Elon Musk is the owner of the US company Tesla", 27 | annotations={ 28 | Annotation(start=30, end=32, label="location"), 29 | Annotation(start=0, end=9, label="person"), 30 | Annotation(start=41, end=46, label="organization"), 31 | }, 32 | ), 33 | AnnotatedDocument( 34 | text="Bill Gates is the owner of Microsoft", 35 | annotations={ 36 | Annotation(start=0, end=10, label="person"), 37 | Annotation(start=27, end=36, label="organization"), 38 | }, 39 | ), 40 | AnnotatedDocument( 41 | text="John is the president of Argentina and he visited Chile last week", 42 | annotations={ 43 | Annotation(start=0, end=4, label="person"), 44 | Annotation(start=25, end=34, label="location"), 45 | Annotation(start=50, end=55, label="location"), 46 | }, 47 | ), 48 | ] 49 | 50 | x = [ 51 | "Pedro Pereira is the president of Perú and the owner of Walmart.", 52 | "John Kennedy was the president of the United States of America.", 53 | "Jeff Bezos is the owner of Amazon.", 54 | "Jocelyn Dunstan is a female scientist from Chile", 55 | ] 56 | 57 | x_tokenized = [ 58 | [ 59 | "Pedro", 60 | "Pereira", 61 | "is", 62 | "the", 63 | "president", 64 | "of", 65 | "Perú", 66 | "and", 67 | "the", 68 | "owner", 69 | "of", 70 | "Walmart", 71 | ".", 72 | ], 73 | [ 74 | "John", 75 | "Kennedy", 76 | "was", 77 | "the", 78 | "president", 79 | "of", 80 | "the", 81 | "United", 82 | "States", 83 | "of", 84 | "America", 85 | ".", 86 | ], 87 | ["Jeff", "Bezos", "is", "the", "owner", "of", "Amazon", "."], 88 | ["Jocelyn", "Dunstan", "is", "a", "female", "scientist", "from", "Chile"], 89 | ] 90 | 91 | y = [ 92 | AnnotatedDocument( 93 | text="Pedro Pereira is the president of Perú and the owner of Walmart.", 94 | annotations={ 95 | Annotation(start=34, end=38, label="location", text="Perú"), 96 | Annotation(start=0, end=13, label="person", text="Pedro Pereira"), 97 | Annotation(start=56, end=63, label="organization", text="Walmart"), 98 | }, 99 | ), 100 | AnnotatedDocument( 101 | text="John Kennedy was the president of the United States of America.", 102 | annotations={ 103 | Annotation( 104 | start=38, end=62, label="location", text="United States of America" 105 | ), 106 | Annotation(start=0, end=12, label="person", text="John Kennedy"), 107 | }, 108 | ), 109 | AnnotatedDocument( 110 | text="Jeff Bezos is the owner of Amazon.", 111 | annotations={ 112 | Annotation(start=0, end=10, label="person", text="Jeff Bezos"), 113 | 
Annotation(start=27, end=33, label="organization", text="Amazon"), 114 | }, 115 | ), 116 | AnnotatedDocument( 117 | text="Jocelyn Dunstan is a female scientist from Chile", 118 | annotations={ 119 | Annotation(start=43, end=48, label="location", text="Chile"), 120 | Annotation(start=0, end=15, label="person", text="Jocelyn Dunstan"), 121 | }, 122 | ), 123 | ] 124 | 125 | y_conll = [ 126 | [ 127 | ("Pedro", "B-person"), 128 | ("Pereira", "I-person"), 129 | ("is", "O"), 130 | ("the", "O"), 131 | ("president", "O"), 132 | ("of", "O"), 133 | ("Perú", "B-location"), 134 | ("and", "O"), 135 | ("the", "O"), 136 | ("owner", "O"), 137 | ("of", "O"), 138 | ("Walmart", "B-organization"), 139 | (".", "O"), 140 | ], 141 | [ 142 | ("John", "B-person"), 143 | ("Kennedy", "I-person"), 144 | ("was", "O"), 145 | ("the", "O"), 146 | ("president", "O"), 147 | ("of", "O"), 148 | ("the", "O"), 149 | ("United", "B-location"), 150 | ("States", "I-location"), 151 | ("of", "I-location"), 152 | ("America", "I-location"), 153 | (".", "O"), 154 | ], 155 | [ 156 | ("Jeff", "B-person"), 157 | ("Bezos", "I-person"), 158 | ("is", "O"), 159 | ("the", "O"), 160 | ("owner", "O"), 161 | ("of", "O"), 162 | ("Amazon", "B-organization"), 163 | (".", "O"), 164 | ], 165 | [ 166 | ("Jocelyn", "B-person"), 167 | ("Dunstan", "I-person"), 168 | ("is", "O"), 169 | ("a", "O"), 170 | ("female", "O"), 171 | ("scientist", "O"), 172 | ("from", "O"), 173 | ("Chile", "B-location"), 174 | ], 175 | ] 176 | 177 | 178 | def iou(annotations_true, annotations_predicted) -> float: 179 | # intersection over union 180 | intersection = annotations_true.intersection(annotations_predicted) 181 | union = annotations_true.union(annotations_predicted) 182 | return len(intersection) / len(union) 183 | 184 | 185 | def assert_equal_annotated_documents( 186 | annotated_documents_true, 187 | annotated_documents, 188 | iou_threshold: float = 1.0, 189 | tokenized: bool = False, 190 | ): 191 | if tokenized: 192 | annotated_documents_true = [ 193 | conll_to_annotated_document(doc) for doc in annotated_documents_true 194 | ] 195 | annotated_documents = [ 196 | conll_to_annotated_document(doc) for doc in annotated_documents 197 | ] 198 | close_enough = False 199 | annotations_not_equal = [] 200 | for annotated_document_true, annotated_document in zip( 201 | annotated_documents_true, annotated_documents 202 | ): 203 | iou_value = iou( 204 | annotated_document_true.annotations, annotated_document.annotations 205 | ) 206 | if iou_value == 0: 207 | raise AssertionError( 208 | f"Annotations are not equal. Expected: {annotated_document_true.annotations}, got: {annotated_document.annotations}" 209 | ) 210 | elif iou_value >= iou_threshold: 211 | close_enough = True 212 | else: 213 | annotations_not_equal.append( 214 | ( 215 | annotated_document_true.annotations, 216 | annotated_document.annotations, 217 | iou_value, 218 | ) 219 | ) 220 | 221 | error = "" 222 | for annotations_true, annotations_predicted, iou_value in annotations_not_equal: 223 | error += f"\nExpected: {annotations_true}, got: {annotations_predicted}, iou: {iou_value}" 224 | if not close_enough: 225 | raise AssertionError(f"Annotations are not equal. {error}") 226 | elif len(annotations_not_equal) > 0: 227 | logger.warning(f"Annotations are not perfectly equal. 
{error}") 228 | 229 | return True 230 | 231 | 232 | def test_model( 233 | few_shot: bool, 234 | model_kwargs: dict, 235 | contextualize_kwargs: dict, 236 | iou_threshold: float = 1.0, 237 | tokenized: bool = False, 238 | ) -> bool: 239 | if not few_shot: 240 | model = ZeroShotNer(**model_kwargs) 241 | model.contextualize(**contextualize_kwargs) 242 | else: 243 | model = FewShotNer(**model_kwargs) 244 | model.contextualize(**contextualize_kwargs) 245 | if not tokenized: 246 | annotated_documents = model.predict(x, max_workers=1) 247 | assert_equal_annotated_documents( 248 | y, annotated_documents, iou_threshold=iou_threshold 249 | ) 250 | if tokenized: 251 | annotated_documents_conll = model.predict_tokenized(x_tokenized, max_workers=-1) 252 | assert_equal_annotated_documents( 253 | y_conll, 254 | annotated_documents_conll, 255 | iou_threshold=iou_threshold, 256 | tokenized=True, 257 | ) 258 | return True 259 | 260 | 261 | class TestZeroShotNer(unittest.TestCase): 262 | # Single-turn test cases 263 | 264 | def test_zero_shot_inline_single_turn_posfalse(self): 265 | test_model( 266 | few_shot=False, 267 | model_kwargs=dict( 268 | answer_shape="inline", 269 | prompting_method="single_turn", 270 | multi_turn_delimiters=None, 271 | augment_with_pos=False, 272 | ), 273 | contextualize_kwargs=dict(entities=entities), 274 | ) 275 | 276 | def test_zero_shot_json_single_turn_posfalse(self): 277 | test_model( 278 | few_shot=False, 279 | model_kwargs=dict( 280 | answer_shape="json", 281 | prompting_method="single_turn", 282 | multi_turn_delimiters=None, 283 | augment_with_pos=False, 284 | ), 285 | contextualize_kwargs=dict(entities=entities), 286 | ) 287 | 288 | def test_zero_shot_inline_single_turn_postrue(self): 289 | test_model( 290 | few_shot=False, 291 | model_kwargs=dict( 292 | answer_shape="inline", 293 | prompting_method="single_turn", 294 | multi_turn_delimiters=None, 295 | augment_with_pos=True, 296 | ), 297 | contextualize_kwargs=dict(entities=entities), 298 | ) 299 | 300 | def test_zero_shot_json_single_turn_postrue(self): 301 | test_model( 302 | few_shot=False, 303 | model_kwargs=dict( 304 | answer_shape="json", 305 | prompting_method="single_turn", 306 | multi_turn_delimiters=None, 307 | augment_with_pos=True, 308 | ), 309 | contextualize_kwargs=dict(entities=entities), 310 | ) 311 | 312 | # Multi-turn test cases 313 | 314 | def test_zero_shot_inline_multi_turn_default_delimiters_posfalse(self): 315 | test_model( 316 | few_shot=False, 317 | model_kwargs=dict( 318 | answer_shape="inline", 319 | prompting_method="multi_turn", 320 | multi_turn_delimiters=None, 321 | augment_with_pos=False, 322 | ), 323 | contextualize_kwargs=dict(entities=entities), 324 | iou_threshold=1, 325 | ) 326 | 327 | def test_zero_shot_inline_multi_turn_custom_delimiters_posfalse(self): 328 | test_model( 329 | few_shot=False, 330 | model_kwargs=dict( 331 | answer_shape="inline", 332 | prompting_method="multi_turn", 333 | multi_turn_delimiters=("@", "@"), 334 | augment_with_pos=False, 335 | ), 336 | contextualize_kwargs=dict(entities=entities), 337 | iou_threshold=1.0, 338 | ) 339 | 340 | def test_zero_shot_inline_multi_turn_default_delimiters_postrue(self): 341 | test_model( 342 | few_shot=False, 343 | model_kwargs=dict( 344 | answer_shape="inline", 345 | prompting_method="multi_turn", 346 | multi_turn_delimiters=None, 347 | augment_with_pos=True, 348 | ), 349 | contextualize_kwargs=dict(entities=entities), 350 | iou_threshold=1.0, 351 | ) 352 | 353 | def 
test_zero_shot_inline_multi_turn_custom_delimiters_postrue(self): 354 | test_model( 355 | few_shot=False, 356 | model_kwargs=dict( 357 | answer_shape="inline", 358 | prompting_method="multi_turn", 359 | multi_turn_delimiters=("@", "@"), 360 | augment_with_pos=True, 361 | ), 362 | contextualize_kwargs=dict(entities=entities), 363 | iou_threshold=1.0, 364 | ) 365 | 366 | def test_zero_shot_json_multi_turn_posfalse(self): 367 | test_model( 368 | few_shot=False, 369 | model_kwargs=dict( 370 | answer_shape="json", 371 | prompting_method="multi_turn", 372 | multi_turn_delimiters=None, 373 | augment_with_pos=False, 374 | ), 375 | contextualize_kwargs=dict(entities=entities), 376 | iou_threshold=1.0, 377 | ) 378 | 379 | def test_zero_shot_json_multi_turn_postrue(self): 380 | test_model( 381 | few_shot=False, 382 | model_kwargs=dict( 383 | answer_shape="json", 384 | prompting_method="multi_turn", 385 | multi_turn_delimiters=None, 386 | augment_with_pos=True, 387 | ), 388 | contextualize_kwargs=dict(entities=entities), 389 | iou_threshold=1.0, 390 | ) 391 | 392 | def test_zero_shot_json_multi_turn_postrue_final_message(self): 393 | test_model( 394 | few_shot=False, 395 | model_kwargs=dict( 396 | answer_shape="json", 397 | prompting_method="multi_turn", 398 | multi_turn_delimiters=None, 399 | augment_with_pos=True, 400 | final_message_with_all_entities=True, 401 | ), 402 | contextualize_kwargs=dict(entities=entities), 403 | iou_threshold=1.0, 404 | ) 405 | def test_zero_shot_inline_multi_turn_postrue_final_message(self): 406 | test_model( 407 | few_shot=False, 408 | model_kwargs=dict( 409 | answer_shape="inline", 410 | prompting_method="multi_turn", 411 | multi_turn_delimiters=None, 412 | augment_with_pos=True, 413 | final_message_with_all_entities=True, 414 | ), 415 | contextualize_kwargs=dict(entities=entities), 416 | iou_threshold=1.0, 417 | ) 418 | 419 | 420 | class TestFewShotNer(unittest.TestCase): 421 | # Single-turn test cases 422 | 423 | def test_few_shot_inline_single_turn_posfalse(self): 424 | test_model( 425 | few_shot=True, 426 | model_kwargs=dict( 427 | answer_shape="inline", 428 | prompting_method="single_turn", 429 | multi_turn_delimiters=None, 430 | augment_with_pos=False, 431 | ), 432 | contextualize_kwargs=dict(entities=entities, examples=examples), 433 | ) 434 | 435 | def test_few_shot_json_single_turn_posfalse(self): 436 | test_model( 437 | few_shot=True, 438 | model_kwargs=dict( 439 | answer_shape="json", 440 | prompting_method="single_turn", 441 | multi_turn_delimiters=None, 442 | augment_with_pos=False, 443 | ), 444 | contextualize_kwargs=dict(entities=entities, examples=examples), 445 | ) 446 | 447 | def test_few_shot_inline_single_turn_postrue(self): 448 | test_model( 449 | few_shot=True, 450 | model_kwargs=dict( 451 | answer_shape="inline", 452 | prompting_method="single_turn", 453 | multi_turn_delimiters=None, 454 | augment_with_pos=True, 455 | ), 456 | contextualize_kwargs=dict(entities=entities, examples=examples), 457 | ) 458 | 459 | def test_few_shot_json_single_turn_postrue(self): 460 | test_model( 461 | few_shot=True, 462 | model_kwargs=dict( 463 | answer_shape="json", 464 | prompting_method="single_turn", 465 | multi_turn_delimiters=None, 466 | augment_with_pos=True, 467 | ), 468 | contextualize_kwargs=dict(entities=entities, examples=examples), 469 | ) 470 | 471 | # multi-turn test cases 472 | 473 | def test_few_shot_inline_multi_turn_default_delimiters_posfalse(self): 474 | test_model( 475 | few_shot=True, 476 | model_kwargs=dict( 477 | answer_shape="inline", 478 
| prompting_method="multi_turn", 479 | multi_turn_delimiters=None, 480 | augment_with_pos=False, 481 | ), 482 | contextualize_kwargs=dict(entities=entities, examples=examples), 483 | iou_threshold=1.0, 484 | ) 485 | 486 | def test_few_shot_inline_multi_turn_custom_delimiters_posfalse(self): 487 | test_model( 488 | few_shot=True, 489 | model_kwargs=dict( 490 | answer_shape="inline", 491 | prompting_method="multi_turn", 492 | multi_turn_delimiters=("@", "@"), 493 | augment_with_pos=False, 494 | ), 495 | contextualize_kwargs=dict(entities=entities, examples=examples), 496 | iou_threshold=1.0, 497 | ) 498 | 499 | def test_few_shot_inline_multi_turn_default_delimiters_postrue(self): 500 | test_model( 501 | few_shot=True, 502 | model_kwargs=dict( 503 | answer_shape="inline", 504 | prompting_method="multi_turn", 505 | multi_turn_delimiters=None, 506 | augment_with_pos=True, 507 | ), 508 | contextualize_kwargs=dict(entities=entities, examples=examples), 509 | iou_threshold=1.0, 510 | ) 511 | 512 | def test_few_shot_inline_multi_turn_custom_delimiters_postrue(self): 513 | test_model( 514 | few_shot=True, 515 | model_kwargs=dict( 516 | answer_shape="inline", 517 | prompting_method="multi_turn", 518 | multi_turn_delimiters=("@", "@"), 519 | augment_with_pos=True, 520 | ), 521 | contextualize_kwargs=dict(entities=entities, examples=examples), 522 | iou_threshold=1.0, 523 | ) 524 | 525 | def test_few_shot_json_multi_turn_posfalse(self): 526 | test_model( 527 | few_shot=True, 528 | model_kwargs=dict( 529 | answer_shape="json", 530 | prompting_method="multi_turn", 531 | multi_turn_delimiters=None, 532 | augment_with_pos=False, 533 | ), 534 | contextualize_kwargs=dict(entities=entities, examples=examples), 535 | iou_threshold=1.0, 536 | ) 537 | 538 | def test_few_shot_json_multi_turn_postrue(self): 539 | test_model( 540 | few_shot=True, 541 | model_kwargs=dict( 542 | answer_shape="json", 543 | prompting_method="multi_turn", 544 | multi_turn_delimiters=None, 545 | augment_with_pos=True, 546 | ), 547 | contextualize_kwargs=dict(entities=entities, examples=examples), 548 | iou_threshold=1.0, 549 | ) 550 | 551 | def test_few_shot_json_multi_turn_postrue_final_message(self): 552 | test_model( 553 | few_shot=True, 554 | model_kwargs=dict( 555 | answer_shape="json", 556 | prompting_method="multi_turn", 557 | multi_turn_delimiters=None, 558 | augment_with_pos=True, 559 | final_message_with_all_entities=True, 560 | ), 561 | contextualize_kwargs=dict(entities=entities, examples=examples), 562 | iou_threshold=1.0, 563 | ) 564 | 565 | def test_few_shot_inline_multi_turn_postrue_final_message(self): 566 | test_model( 567 | few_shot=True, 568 | model_kwargs=dict( 569 | answer_shape="inline", 570 | prompting_method="multi_turn", 571 | multi_turn_delimiters=None, 572 | augment_with_pos=True, 573 | final_message_with_all_entities=True, 574 | ), 575 | contextualize_kwargs=dict(entities=entities, examples=examples), 576 | iou_threshold=1.0, 577 | ) 578 | 579 | 580 | class TestPredictTokenized(unittest.TestCase): 581 | def test_zero_shot_inline_single_turn_posfalse_tokenized(self): 582 | test_model( 583 | few_shot=False, 584 | model_kwargs=dict( 585 | answer_shape="inline", 586 | prompting_method="single_turn", 587 | multi_turn_delimiters=None, 588 | augment_with_pos=False, 589 | ), 590 | contextualize_kwargs=dict(entities=entities), 591 | tokenized=True, 592 | ) 593 | 594 | def test_few_shot_json_multi_turn_postrue_tokenized(self): 595 | test_model( 596 | few_shot=True, 597 | model_kwargs=dict( 598 | answer_shape="json", 
599 | prompting_method="multi_turn", 600 | multi_turn_delimiters=None, 601 | augment_with_pos=True, 602 | ), 603 | contextualize_kwargs=dict(entities=entities, examples=examples), 604 | iou_threshold=1.0, 605 | tokenized=True, 606 | ) 607 | 608 |
609 | class TestCustomDelimiters(unittest.TestCase):
610 | def test_zero_shot_inline_multi_turn_custom_delimiters_postrue_doubleat(self): 611 | test_model( 612 | few_shot=False, 613 | model_kwargs=dict( 614 | answer_shape="inline", 615 | prompting_method="multi_turn", 616 | multi_turn_delimiters=("@@", "@@"), 617 | augment_with_pos=True, 618 | ), 619 | contextualize_kwargs=dict(entities=entities), 620 | iou_threshold=1.0, 621 | ) 622 |
623 | def test_few_shot_inline_multi_turn_custom_delimiters_postrue_doubleat(self): 624 | test_model( 625 | few_shot=True, 626 | model_kwargs=dict( 627 | answer_shape="inline", 628 | prompting_method="multi_turn", 629 | multi_turn_delimiters=("@@", "@@"), 630 | augment_with_pos=True, 631 | ), 632 | contextualize_kwargs=dict(entities=entities, examples=examples), 633 | iou_threshold=1.0, 634 | ) 635 |
636 | def test_zero_shot_inline_multi_turn_custom_delimiters_postrue_asymmetric(self): 637 | test_model( 638 | few_shot=False, 639 | model_kwargs=dict( 640 | answer_shape="inline", 641 | prompting_method="multi_turn", 642 | multi_turn_delimiters=("@", "#"), 643 | augment_with_pos=True, 644 | ), 645 | contextualize_kwargs=dict(entities=entities), 646 | iou_threshold=1.0, 647 | ) 648 |
649 | def test_few_shot_inline_multi_turn_custom_delimiters_postrue_asymmetric(self): 650 | test_model( 651 | few_shot=True, 652 | model_kwargs=dict( 653 | answer_shape="inline", 654 | prompting_method="multi_turn", 655 | multi_turn_delimiters=("@", "#"), 656 | augment_with_pos=True, 657 | ), 658 | contextualize_kwargs=dict(entities=entities, examples=examples), 659 | iou_threshold=1.0, 660 | ) --------------------------------------------------------------------------------
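Usage sketch: the snippet below is a minimal end-to-end example assembled only from the public API exercised in tests.py and notebooks/1-example.ipynb above (ZeroShotNer, contextualize, predict). It is illustrative and not a file from the repository; the entity descriptions and the input sentence are made-up assumptions, and OPENAI_API_KEY must hold a valid key before predict is called.

import os
from llmner import ZeroShotNer

os.environ["OPENAI_API_KEY"] = ""  # assumption: fill in a valid key, as in the notebooks

# Free-form entity descriptions keyed by label, as in tests.py; illustrative only.
entities = {
    "person": "A person name, it can include first and last names",
    "location": "A location name, it can be a city, a country, etc.",
}

# Same constructor arguments used in notebooks/1-example.ipynb.
model = ZeroShotNer(prompting_method="multi_turn", answer_shape="inline")
model.contextualize(entities=entities)

# predict takes a list of raw strings and returns AnnotatedDocument objects,
# whose .annotations hold Annotation(start, end, label, text) spans.
docs = model.predict(["John Kennedy was the president of the United States."])
print(docs[0].annotations)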