├── .gitignore ├── LICENSE ├── README.md ├── README_Extended.md ├── RELEASE.md ├── configs ├── config.yaml ├── config_biencoder.yaml ├── config_span.yaml └── config_token.yaml ├── convert_to_onnx.py ├── custom_train.py ├── data ├── process_nuner.py └── process_pilener.py ├── demo.jpg ├── demo.py ├── eval.py ├── examples ├── convert_to_onnx.ipynb ├── exal_example_conll.ipynb ├── finetune.ipynb ├── gliner_spacy_demo.ipynb ├── load_local_model.ipynb ├── quickstart.ipynb ├── sample_data.json └── synthetic_data_generation.ipynb ├── gliner ├── __init__.py ├── config.py ├── data_processing │ ├── __init__.py │ ├── collator.py │ ├── dataset.py │ ├── processor.py │ ├── tokenizer.py │ └── utils.py ├── decoding │ ├── __init__.py │ ├── decoder.py │ └── utils.py ├── evaluation │ ├── __init__.py │ ├── evaluate.py │ └── evaluator.py ├── model.py ├── modeling │ ├── __init__.py │ ├── base.py │ ├── encoder.py │ ├── layers.py │ ├── loss_functions.py │ ├── scorers.py │ └── span_rep.py ├── multitask │ ├── __init__.py │ ├── base.py │ ├── classification.py │ ├── open_extraction.py │ ├── question_answering.py │ ├── relation_extraction.py │ └── summarization.py ├── onnx │ ├── __init__.py │ └── model.py ├── training │ ├── __init__.py │ └── trainer.py └── utils.py ├── image.png ├── logo ├── FI Group.png └── FI_COMPLET_CW.png ├── pyproject.toml ├── requirements.txt ├── tests ├── test_features_selection.py └── test_models.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | ### Python ### 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | #data 12 | data.json 13 | 14 | #logs 15 | logs/ 16 | models/ 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ 170 | 171 | ### Python Patch ### 172 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 173 | poetry.toml 174 | 175 | # ruff 176 | .ruff_cache/ 177 | 178 | # LSP config files 179 | pyrightconfig.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 👑 GLiNER: Generalist and Lightweight Model for Named Entity Recognition 2 | 3 | GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios. 4 | 5 |

6 | 7 | Python 8 | Version 9 | 10 |

11 | 12 |

13 | 📄 Paper 14 |   •   15 | 📢 Discord 16 |   •   17 | 🤗 Demo 18 |   •   19 | 🤗 Available models 20 | 21 | 22 | 23 | 24 |

25 |
26 | ## Example Notebooks
27 |
28 | Explore various examples including finetuning, ONNX conversion, and synthetic data generation.
29 |
30 | - [Example Notebooks](https://github.com/urchade/GLiNER/tree/main/examples)
31 | - [Finetune on Colab](https://colab.research.google.com/drive/1HNKd74cmfS9tGvWrKeIjSxBt01QQS7bq?usp=sharing)
32 | ## 🛠 Installation & Usage
33 |
34 | ### Installation
35 | ```bash
36 | !pip install gliner
37 | ```
38 |
39 | ### Usage
40 | After installing the GLiNER library, import the `GLiNER` class. You can then load your chosen model with `GLiNER.from_pretrained` and use `predict_entities` to identify entities in your text.
41 |
42 | ```python
43 | from gliner import GLiNER
44 |
45 | # Initialize GLiNER with the base model
46 | model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
47 |
48 | # Sample text for entity prediction
49 | text = """
50 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
51 | """
52 |
53 | # Labels for entity prediction
54 | # Most GLiNER models should work best when entity types are in lower case or title case
55 | labels = ["person", "award", "date", "competitions", "teams"]
56 |
57 | # Perform entity prediction
58 | entities = model.predict_entities(text, labels, threshold=0.5)
59 |
60 | # Display predicted entities and their labels
61 | for entity in entities:
62 |     print(entity["text"], "=>", entity["label"])
63 | ```
64 |
65 | #### Expected Output
66 |
67 | ```
68 | Cristiano Ronaldo dos Santos Aveiro => person
69 | 5 February 1985 => date
70 | Al Nassr => teams
71 | Portugal national team => teams
72 | Ballon d'Or => award
73 | UEFA Men's Player of the Year Awards => award
74 | European Golden Shoes => award
75 | UEFA Champions Leagues => competitions
76 | UEFA European Championship => competitions
77 | UEFA Nations League => competitions
78 | European Championship => competitions
79 | ```
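`from_pretrained` also accepts a local folder and, as shown in `examples/convert_to_onnx.ipynb`, an ONNX export. A minimal sketch; the `gliner_medium` path is illustrative and assumes you have already saved or converted a model there:

```python
from gliner import GLiNER

# Load a model saved locally with model.save_pretrained("gliner_medium")
model = GLiNER.from_pretrained("gliner_medium", load_tokenizer=True)

# Load the ONNX export produced by convert_to_onnx.py instead of the PyTorch weights
onnx_model = GLiNER.from_pretrained("gliner_medium", load_onnx_model=True, load_tokenizer=True)
```

## 🌟 Maintainers
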
83 | <table>
84 | <tr>
85 | <td align="center">
86 | Urchade Zaratiana
87 | PhD Student at LIPN
88 | LinkedIn
89 | </td>
90 | <td align="center">
91 | Ihor Stepanov
92 | Co-Founder at Knowledgator
93 | LinkedIn
94 | </td>
95 | </tr>
96 | </table>
97 |
98 | 99 | ## 👨‍💻 Model Authors 100 | The model authors are: 101 | * [Urchade Zaratiana](https://huggingface.co/urchade) 102 | * Nadi Tomeh 103 | * Pierre Holat 104 | * Thierry Charnois 105 | 106 | ## 📚 Citation 107 | 108 | If you find GLiNER useful in your research, please consider citing our paper: 109 | 110 | ```bibtex 111 | @inproceedings{zaratiana-etal-2024-gliner, 112 | title = "{GL}i{NER}: Generalist Model for Named Entity Recognition using Bidirectional Transformer", 113 | author = "Zaratiana, Urchade and 114 | Tomeh, Nadi and 115 | Holat, Pierre and 116 | Charnois, Thierry", 117 | editor = "Duh, Kevin and 118 | Gomez, Helena and 119 | Bethard, Steven", 120 | booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)", 121 | month = jun, 122 | year = "2024", 123 | address = "Mexico City, Mexico", 124 | publisher = "Association for Computational Linguistics", 125 | url = "https://aclanthology.org/2024.naacl-long.300", 126 | doi = "10.18653/v1/2024.naacl-long.300", 127 | pages = "5364--5376", 128 | abstract = "Named Entity Recognition (NER) is essential in various Natural Language Processing (NLP) applications. Traditional NER models are effective but limited to a set of predefined entity types. In contrast, Large Language Models (LLMs) can extract arbitrary entities through natural language instructions, offering greater flexibility. However, their size and cost, particularly for those accessed via APIs like ChatGPT, make them impractical in resource-limited scenarios. In this paper, we introduce a compact NER model trained to identify any type of entity. Leveraging a bidirectional transformer encoder, our model, GLiNER, facilitates parallel entity extraction, an advantage over the slow sequential token generation of LLMs. Through comprehensive testing, GLiNER demonstrate strong performance, outperforming both ChatGPT and fine-tuned LLMs in zero-shot evaluations on various NER benchmarks.", 129 | } 130 | ``` 131 | ## Support and funding 132 | 133 | This project has been supported and funded by **F.initiatives** and **Laboratoire Informatique de Paris Nord**. 134 | 135 | F.initiatives has been an expert in public funding strategies for R&D, Innovation, and Investments (R&D&I) for over 20 years. With a team of more than 200 qualified consultants, F.initiatives guides its clients at every stage of developing their public funding strategy: from structuring their projects to submitting their aid application, while ensuring the translation of their industrial and technological challenges to public funders. Through its continuous commitment to excellence and integrity, F.initiatives relies on the synergy between methods and tools to offer tailored, high-quality, and secure support. 136 | 137 |

138 | FI Group 139 |

140 |
141 | We also extend our heartfelt gratitude to the open-source community for their invaluable contributions, which have been instrumental in the success of this project.
142 |
143 |
144 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # A guide to making a release
2 |
3 | This guide collects the steps we do in GLiNER to make a release on PyPI. They result from (variations of) hard-learned lessons, and while following this guide is completely optional, it's strongly recommended to do so. 🙂 This is a truncated version of the [SetFit](https://github.com/huggingface/setfit/blob/main/RELEASE.md) release guide, which is more exhaustive and does some additional steps.
4 |
5 | ### Preparation
6 |
7 | To be able to make a release for a given project, you'll need an account on [PyPI](https://pypi.org/) and on [Test PyPI](https://test.pypi.org/). If you are making a release for an existing project, your username will need to be added to that project by one of the current maintainers on PyPI. Note that we strongly recommend enabling two-factor authentication on PyPI.
8 |
9 | You will also need to install twine in your Python environment with `pip install twine`.
10 |
11 | Additionally, it can be nice to familiarize yourself with [Semantic Versioning](https://semver.org/). This is a fairly strict document, but it provides a useful summary that library maintainers should follow:
12 |
13 | > Given a version number MAJOR.MINOR.PATCH, increment the:
14 | >
15 | > 1. MAJOR version when you make incompatible API changes
16 | > 2. MINOR version when you add functionality in a backward compatible manner
17 | > 3. PATCH version when you make backward compatible bug fixes
18 | >
19 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format.
20 |
21 | The very first release should be "0.1.0".
22 |
23 | ## Releases
24 |
25 | ### Step 1: Adjust the version of your package
26 |
27 | You should have the current version specified in [`gliner/__init__.py`](gliner/__init__.py). This version should be a dev version (e.g. `0.1.0.dev`) before you release; change it to the version you are releasing:
28 |
29 | ```diff
30 | - __version__ = "0.4.0.dev"
31 | + __version__ = "0.4.0"
32 | ```
33 |
34 | Commit the changes on your release branch and push them:
35 |
36 | ```bash
37 | git add gliner
38 | git commit -m "Release: v{VERSION}"
39 | git push -u origin main
40 | ```
41 |
42 | ### Step 2: (Optional) Make sure all tests pass
43 |
44 | If you add tests, then you should also add CI, e.g. like this [`tests.yaml`](https://github.com/tomaarsen/SpanMarkerNER/blob/main/.github/workflows/tests.yaml) file. This will automatically run the tests whenever you make changes, which can be very useful. Make sure all the tests you have pass before proceeding to the next step.
45 |
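You can also run the suite locally before tagging; a minimal sketch, assuming the tests under `tests/` are written for pytest (the repository does not pin a test runner):

```bash
python -m pip install pytest
python -m pytest tests/ -v
```

46 | ### Step 3: Add a tag for your release
47 |
48 | A tag will flag the exact commit associated with your release (and be easier to remember than the commit hash!). The tag should be `v<VERSION>`, so for instance `v4.12.0`.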
49 |
50 | Here is how you can create and push your tag:
51 |
52 | ```bash
53 | git tag v<VERSION>
54 | git push --tags origin main
55 | ```
56 |
57 | ### Step 4: (Optional) Prepare the release notes
58 |
59 | You can then put your release notes in a Draft Release on GitHub, in [https://github.com/urchade/GLiNER/releases](https://github.com/urchade/GLiNER/releases), and write a small paragraph highlighting each of the new features this release is adding.
60 |
61 | You can use the previously created tag to let GitHub auto-generate some release notes based on recent pull requests.
62 |
63 | ### Step 5: Create the wheels for your release
64 |
65 | This is what you'll upload on PyPI and what everyone will download each time they `pip install` your package.
66 |
67 | Clean previous builds by deleting the `build` and `dist` directories or by running:
68 |
69 | ```
70 | rm -rf build && rm -rf dist
71 | ```
72 |
73 | Then run:
74 |
75 | ```bash
76 | python -m build
77 | ```
78 |
79 | This will create two folders, `build` and `dist`, with the new version of your package. These contain 1) a source distribution and 2) a wheel.
80 |
81 | ### Step 6: Upload your package on PyPI test
82 |
83 | **DO NOT SKIP THIS STEP!**
84 |
85 | This is the most important check before actually releasing your package in the wild. Upload the package on PyPI test and check that you can properly install it.
86 |
87 | To upload it:
88 |
89 | ```bash
90 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
91 | ```
92 |
93 | You will be prompted for your username and password. If that doesn't work, you can create an API Token for your Test PyPI account and create a `~/.pypirc` file if it doesn't already exist, with:
94 |
95 | ```
96 | [distutils]
97 | index-servers =
98 |     gliner_test
99 |
100 | [gliner_test]
101 | repository = https://test.pypi.org/legacy/
102 | username = __token__
103 | password = pypi-...
104 | ```
105 | (some more details on this [here](https://pypi.org/help/#apitoken))
106 |
107 | And then run:
108 | ```bash
109 | twine upload dist/* -r gliner_test
110 | ```
111 |
112 | Once that has uploaded the package, in a fresh environment containing all the dependencies you need (tip: you can use Google Colab for this!), try to install your new package from the PyPI test server. First install all dependencies, and then your package.
113 |
114 | ```bash
115 | python -m pip install torch transformers huggingface_hub flair tqdm
116 | python -m pip install -i https://testpypi.python.org/pypi gliner
117 | ```
118 |
119 | If everything works, you should be able to run this code:
120 |
121 | ```python
122 | from gliner import GLiNER
123 |
124 | model = GLiNER.from_pretrained("urchade/gliner_base")
125 |
126 | text = """
127 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
128 | """
129 |
130 | labels = ["person", "award", "date", "competitions", "teams"]
131 |
132 | entities = model.predict_entities(text, labels, threshold=0.5)
133 |
134 | for entity in entities:
135 |     print(entity["text"], "=>", entity["label"])
136 | ```
137 |
138 | ### Step 7: Publish on PyPI
139 |
140 | This cannot be undone if you mess up, so make sure you have run Step 6!
141 |
142 | Once you're fully ready, upload your package on PyPI:
143 |
144 | ```bash
145 | twine upload dist/* -r pypi
146 | ```
147 |
148 | You will be prompted for your username and password, unless you're using the recommended [PyPI API token](https://pypi.org/help/#apitoken).
149 |
150 | ### Step 8: (Optional) Publish your release notes
151 |
152 | Go back to the draft you created in Step 4 ([https://github.com/urchade/GLiNER/releases](https://github.com/urchade/GLiNER/releases)) and publish them.
153 |
154 | ### Step 9: Bump the dev version on the main branch
155 |
156 | You're almost done! Just go back to the `main` branch and change the dev version in [`gliner/__init__.py`](gliner/__init__.py) to the new version you're developing, for instance `4.13.0.dev` if you just released `4.12.0`.
157 |
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | # Model Configuration
2 | model_name: microsoft/deberta-v3-small # Hugging Face model
3 | labels_encoder: "BAAI/bge-small-en-v1.5"
4 | name: "span level gliner"
5 | max_width: 12
6 | hidden_size: 768
7 | dropout: 0.3
8 | fine_tune: true
9 | subtoken_pooling: first
10 | fuse_layers: false
11 | post_fusion_schema: "l2l-l2t-t2t"
12 | span_mode: markerV0
13 |
14 | # Training Parameters
15 | num_steps: 100000
16 | train_batch_size: 8
17 | eval_every: 5000
18 | warmup_ratio: 0.05
19 | scheduler_type: "cosine"
20 |
21 | # loss function
22 | loss_alpha: 0.75
23 | loss_gamma: 0
24 | label_smoothing: 0
25 | loss_reduction: "sum"
26 |
27 | # Learning Rate and weight decay Configuration
28 | lr_encoder: 1e-5
29 | lr_others: 3e-5
30 | weight_decay_encoder: 0.1
31 | weight_decay_other: 0.01
32 |
33 | max_grad_norm: 10.0
34 |
35 | # Directory Paths
36 | root_dir: gliner_logs
37 | train_data: "data.json" #"data/nuner_train.json" # see https://github.com/urchade/GLiNER/tree/main/data
38 | val_data_dir: "none"
39 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
40 |
41 | # Pretrained Model Path
42 | # Use "none" if no pretrained model is being used
43 | prev_path: null
44 |
45 | save_total_limit: 3 #maximum amount of checkpoints to save
46 |
47 | # Advanced Training Settings
48 | size_sup: -1
49 | max_types: 100
50 | shuffle_types: true
51 | random_drop: true
52 | max_neg_type_ratio: 1
53 | max_len: 512
54 | freeze_token_rep: false
55 |
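These YAML files configure the training entry points (`train.py` and `custom_train.py` at the repository root). A minimal sketch of loading the file above and inspecting a few of its keys; the key names come straight from the file, and PyYAML is assumed to be installed:

```python
import yaml  # PyYAML: pip install pyyaml

with open("configs/config.yaml") as f:
    config = yaml.safe_load(f)

# A few of the settings defined above
print(config["model_name"])  # microsoft/deberta-v3-small
print(config["num_steps"])   # 100000
print(config["span_mode"])   # markerV0
```

--------------------------------------------------------------------------------
/configs/config_biencoder.yaml:
--------------------------------------------------------------------------------
1 | #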
Model Configuration 2 | model_name: microsoft/deberta-v3-small # Hugging Face model 3 | labels_encoder: "microsoft/deberta-v3-small" 4 | name: "span level gliner" 5 | max_width: 12 6 | hidden_size: 768 7 | dropout: 0.4 8 | fine_tune: true 9 | subtoken_pooling: first 10 | fuse_layers: false 11 | post_fusion_schema: "" 12 | span_mode: markerV0 13 | 14 | # Training Parameters 15 | num_steps: 30000 16 | train_batch_size: 8 17 | eval_every: 1000 18 | warmup_ratio: 0.1 19 | scheduler_type: "cosine" 20 | 21 | # loss function 22 | loss_alpha: -1 23 | loss_gamma: 0 24 | label_smoothing: 0 25 | loss_reduction: "sum" 26 | 27 | # Learning Rate and weight decay Configuration 28 | lr_encoder: 1e-5 29 | lr_others: 5e-5 30 | weight_decay_encoder: 0.01 31 | weight_decay_other: 0.01 32 | 33 | max_grad_norm: 10.0 34 | 35 | # Directory Paths 36 | root_dir: gliner_logs 37 | train_data: "data.json" #"data/nuner_train.json" # see https://github.com/urchade/GLiNER/tree/main/data 38 | val_data_dir: "none" 39 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view" 40 | 41 | # Pretrained Model Path 42 | # Use "none" if no pretrained model is being used 43 | prev_path: null 44 | 45 | save_total_limit: 3 #maximum amount of checkpoints to save 46 | 47 | # Advanced Training Settings 48 | size_sup: -1 49 | max_types: 25 50 | shuffle_types: true 51 | random_drop: true 52 | max_neg_type_ratio: 1 53 | max_len: 386 54 | freeze_token_rep: false 55 | -------------------------------------------------------------------------------- /configs/config_span.yaml: -------------------------------------------------------------------------------- 1 | # Model Configuration 2 | model_name: microsoft/deberta-v3-small # Hugging Face model 3 | name: "span level gliner" 4 | max_width: 12 5 | hidden_size: 768 6 | dropout: 0.4 7 | fine_tune: true 8 | subtoken_pooling: first 9 | span_mode: markerV0 10 | 11 | # Training Parameters 12 | num_steps: 30000 13 | train_batch_size: 8 14 | eval_every: 5000 15 | warmup_ratio: 0.1 16 | scheduler_type: "cosine" 17 | 18 | # loss function 19 | loss_alpha: -1 # focal loss alpha, if -1, no focal loss 20 | loss_gamma: 0 # focal loss gamma, if 0, no focal loss 21 | label_smoothing: 0 22 | loss_reduction: "sum" 23 | 24 | # Learning Rate and weight decay Configuration 25 | lr_encoder: 1e-5 26 | lr_others: 5e-5 27 | weight_decay_encoder: 0.01 28 | weight_decay_other: 0.01 29 | 30 | max_grad_norm: 1.0 31 | 32 | # Directory Paths 33 | root_dir: span_gliner_logs 34 | train_data: "data.json" # see https://github.com/urchade/GLiNER/tree/main/data 35 | val_data_dir: "none" 36 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view" 37 | 38 | # Pretrained Model Path 39 | # Use "none" if no pretrained model is being used 40 | prev_path: "none" 41 | 42 | save_total_limit: 10 #maximum amount of checkpoints to save 43 | 44 | # Advanced Training Settings 45 | size_sup: -1 46 | max_types: 25 47 | shuffle_types: true 48 | random_drop: true 49 | max_neg_type_ratio: 1 50 | max_len: 384 51 | freeze_token_rep: false 52 | -------------------------------------------------------------------------------- /configs/config_token.yaml: -------------------------------------------------------------------------------- 1 | # Model Configuration 2 | model_name: microsoft/deberta-v3-small # Hugging Face model 3 | name: "token level gliner" 4 | max_width: 100 5 | hidden_size: 768 6 | 
dropout: 0.1
7 | fine_tune: true
8 | subtoken_pooling: first
9 | span_mode: token_level
10 |
11 | # Training Parameters
12 | num_steps: 30000
13 | train_batch_size: 8
14 | eval_every: 5000
15 | warmup_ratio: 0.1
16 | scheduler_type: "cosine"
17 |
18 | # loss function
19 | loss_alpha: -1 # focal loss alpha, if -1, no focal loss
20 | loss_gamma: 0 # focal loss gamma, if 0, no focal loss
21 | label_smoothing: 0
22 | loss_reduction: "sum"
23 |
24 | # Learning Rate and weight decay Configuration
25 | lr_encoder: 1e-5
26 | lr_others: 5e-5
27 | weight_decay_encoder: 0.01
28 | weight_decay_other: 0.01
29 |
30 | max_grad_norm: 1.0
31 |
32 | # Directory Paths
33 | root_dir: gliner_logs
34 | train_data: "train.json" # see https://github.com/urchade/GLiNER/tree/main/data
35 | val_data_dir: "NER_datasets"
36 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
37 |
38 | # Pretrained Model Path
39 | # Use "none" if no pretrained model is being used
40 | prev_path: "none"
41 |
42 | save_total_limit: 10 #maximum amount of checkpoints to save
43 |
44 | # Advanced Training Settings
45 | size_sup: -1
46 | max_types: 25
47 | shuffle_types: true
48 | random_drop: true
49 | max_neg_type_ratio: 1
50 | max_len: 384
51 | freeze_token_rep: false
52 |
53 |
--------------------------------------------------------------------------------
/convert_to_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | from gliner import GLiNER
6 |
7 | import torch
8 | from onnxruntime.quantization import quantize_dynamic, QuantType
9 |
10 | if __name__ == "__main__":
11 |     parser = argparse.ArgumentParser()
12 |     parser.add_argument('--model_path', type=str, default="logs/model_12000")
13 |     parser.add_argument('--save_path', type=str, default='model/')
14 |     parser.add_argument('--quantize', action=argparse.BooleanOptionalAction, default=True)  # type=bool treats any non-empty string (even "False") as True; BooleanOptionalAction (Python 3.9+) gives proper --quantize / --no-quantize flags
15 |     args = parser.parse_args()
16 |
17 |     if not os.path.exists(args.save_path):
18 |         os.makedirs(args.save_path)
19 |
20 |     onnx_save_path = os.path.join(args.save_path, "model.onnx")
21 |
22 |     print("Loading a model...")
23 |     gliner_model = GLiNER.from_pretrained(args.model_path, load_tokenizer=True)
24 |
25 |     text = "ONNX is an open-source format designed to enable the interoperability of AI models across various frameworks and tools."
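    # Note: torch.onnx.export below traces the model with this one sample input;
    # the text/labels only fix the shapes seen during tracing, while the
    # dynamic_axes mapping declared further down marks the batch and sequence
    # dimensions as variable so other input sizes work at inference time.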
26 | labels = ['format', 'model', 'tool', 'cat'] 27 | 28 | inputs, _ = gliner_model.prepare_model_inputs([text], labels) 29 | 30 | if gliner_model.config.span_mode == 'token_level': 31 | all_inputs = (inputs['input_ids'], inputs['attention_mask'], 32 | inputs['words_mask'], inputs['text_lengths']) 33 | input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths'] 34 | dynamic_axes={ 35 | "input_ids": {0: "batch_size", 1: "sequence_length"}, 36 | "attention_mask": {0: "batch_size", 1: "sequence_length"}, 37 | "words_mask": {0: "batch_size", 1: "sequence_length"}, 38 | "text_lengths": {0: "batch_size", 1: "value"}, 39 | "logits": {0: "position", 1: "batch_size", 2: "sequence_length", 3: "num_classes"}, 40 | } 41 | else: 42 | all_inputs = (inputs['input_ids'], inputs['attention_mask'], 43 | inputs['words_mask'], inputs['text_lengths'], 44 | inputs['span_idx'], inputs['span_mask']) 45 | input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths', 'span_idx', 'span_mask'] 46 | dynamic_axes={ 47 | "input_ids": {0: "batch_size", 1: "sequence_length"}, 48 | "attention_mask": {0: "batch_size", 1: "sequence_length"}, 49 | "words_mask": {0: "batch_size", 1: "sequence_length"}, 50 | "text_lengths": {0: "batch_size", 1: "value"}, 51 | "span_idx": {0: "batch_size", 1: "num_spans", 2: "idx"}, 52 | "span_mask": {0: "batch_size", 1: "num_spans"}, 53 | "logits": {0: "batch_size", 1: "sequence_length", 2: "num_spans", 3: "num_classes"}, 54 | } 55 | print('Converting the model...') 56 | torch.onnx.export( 57 | gliner_model.model, 58 | all_inputs, 59 | f=onnx_save_path, 60 | input_names=input_names, 61 | output_names=["logits"], 62 | dynamic_axes=dynamic_axes, 63 | opset_version=14, 64 | ) 65 | 66 | if args.quantize: 67 | quantized_save_path = os.path.join(args.save_path, "model_quantized.onnx") 68 | # Quantize the ONNX model 69 | print("Quantizing the model...") 70 | quantize_dynamic( 71 | onnx_save_path, # Input model 72 | quantized_save_path, # Output model 73 | weight_type=QuantType.QUInt8 # Quantize weights to 8-bit integers 74 | ) 75 | print("Done!") -------------------------------------------------------------------------------- /data/process_nuner.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import re 3 | import ast 4 | import json 5 | from tqdm import tqdm 6 | 7 | 8 | def tokenize_text(text): 9 | """Tokenizes the input text into a list of tokens.""" 10 | return re.findall(r'\w+(?:[-_]\w+)*|\S', text) 11 | 12 | 13 | def process_entities(dataset): 14 | """Processes entities in the dataset to extract tokenized text and named entity spans.""" 15 | all_data = [] 16 | for el in tqdm(dataset["entity"]): 17 | try: 18 | tokenized_text = tokenize_text(el["input"]) 19 | parsed_output = ast.literal_eval(el["output"]) 20 | entity_texts, entity_types = zip(*[i.split(" <> ") for i in parsed_output]) 21 | 22 | entity_spans = [] 23 | for j, entity_text in enumerate(entity_texts): 24 | entity_tokens = tokenize_text(entity_text) 25 | matches = [] 26 | for i in range(len(tokenized_text) - len(entity_tokens) + 1): 27 | if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower(): 28 | matches.append((i, i + len(entity_tokens) - 1, entity_types[j])) 29 | if matches: 30 | entity_spans.extend(matches) 31 | 32 | except Exception as e: 33 | continue 34 | 35 | all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans}) 36 | return all_data 37 | 38 | 39 | def 
save_data_to_file(data, filepath): 40 | """Saves the processed data to a JSON file.""" 41 | with open(filepath, 'w') as f: 42 | json.dump(data, f) 43 | 44 | 45 | if __name__ == "__main__": 46 | dataset = load_dataset("numind/NuNER") 47 | processed_data = process_entities(dataset) 48 | 49 | save_data_to_file(processed_data, 'nuner_train.json') 50 | 51 | print("dataset size:", len(processed_data)) -------------------------------------------------------------------------------- /data/process_pilener.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import ast 4 | from tqdm import tqdm 5 | 6 | def load_data(filepath): 7 | """Loads data from a JSON file.""" 8 | with open(filepath, 'r') as f: 9 | data = json.load(f) 10 | return data 11 | 12 | def tokenize_text(text): 13 | """Tokenizes the input text into a list of tokens.""" 14 | return re.findall(r'\w+(?:[-_]\w+)*|\S', text) 15 | 16 | def extract_entity_spans(entry): 17 | """Extracts entity spans from an entry.""" 18 | len_start = len("What describes ") 19 | len_end = len(" in the text?") 20 | entity_types, entity_texts, negative = [], [], [] 21 | 22 | for c in entry['conversations']: 23 | if c['from'] == 'human' and c['value'].startswith('Text: '): 24 | text = c['value'][len('Text: '):] 25 | tokenized_text = tokenize_text(text) 26 | elif c['from'] == 'human' and c['value'].startswith('What describes '): 27 | entity_type = c['value'][len_start:-len_end] 28 | entity_types.append(entity_type) 29 | elif c['from'] == 'gpt' and c['value'].startswith('['): 30 | if c['value'] == '[]': 31 | negative.append(entity_types.pop()) 32 | continue 33 | texts_ents = ast.literal_eval(c['value']) 34 | entity_texts.extend(texts_ents) 35 | num_repeat = len(texts_ents) - 1 36 | entity_types.extend([entity_types[-1]] * num_repeat) 37 | 38 | entity_spans = [] 39 | for j, entity_text in enumerate(entity_texts): 40 | entity_tokens = tokenize_text(entity_text) 41 | matches = [] 42 | for i in range(len(tokenized_text) - len(entity_tokens) + 1): 43 | if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower(): 44 | matches.append((i, i + len(entity_tokens) - 1, entity_types[j])) 45 | if matches: 46 | entity_spans.extend(matches) 47 | 48 | return {"tokenized_text": tokenized_text, "ner": entity_spans, "negative": negative} 49 | 50 | def process_data(data): 51 | """Processes a list of data entries to extract entity spans.""" 52 | all_data = [extract_entity_spans(entry) for entry in tqdm(data)] 53 | return all_data 54 | 55 | def save_data_to_file(data, filepath): 56 | """Saves the processed data to a JSON file.""" 57 | with open(filepath, 'w') as f: 58 | json.dump(data, f) 59 | 60 | if __name__ == "__main__": 61 | # download the pile-ner data: "wget https://huggingface.co/datasets/Universal-NER/Pile-NER-type/blob/main/train.json" 62 | path_pile_ner = 'train.json' 63 | data = load_data(path_pile_ner) 64 | processed_data = process_data(data) 65 | save_data_to_file(processed_data, 'pilener_train.json') 66 | 67 | print("dataset size:", len(processed_data)) -------------------------------------------------------------------------------- /demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/demo.jpg -------------------------------------------------------------------------------- /demo.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | from gliner import GLiNER 3 | import gradio as gr 4 | 5 | model = GLiNER.from_pretrained("model/", load_tokenizer=True) 6 | 7 | examples = [ 8 | [ 9 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.", 10 | "person, book, location, date, actor, character", 11 | 0.3, 12 | True, 13 | ], 14 | [ 15 | """ 16 | * Data Scientist, Data Analyst, or Data Engineer with 1+ years of experience. 17 | * Experience with technologies such as Docker, Kubernetes, or Kubeflow 18 | * Machine Learning experience preferred 19 | * Experience with programming languages such as Python, C++, or SQL preferred 20 | * Experience with technologies such as Databricks, Qlik, TensorFlow, PyTorch, Python, Dash, Pandas, or NumPy preferred 21 | * BA or BS degree 22 | * Active Secret OR Active Top Secret or Active TS/SCI clearance 23 | """, 24 | "software package, programing language, software tool, degree, job title", 25 | 0.3, 26 | False, 27 | ], 28 | [ 29 | "However, both models lack other frequent DM symptoms including the fibre-type dependent atrophy, myotonia, cataract and male-infertility.", 30 | "disease, symptom", 31 | 0.3, 32 | False, 33 | ], 34 | [ 35 | "Synergy between signal transduction pathways is obligatory for expression of c-fos in B and T cell lines: implication for c-fos control via surface immunoglobulin and T cell antigen receptors.", 36 | "DNA, RNA, cell line, cell type, protein", 37 | 0.3, 38 | False, 39 | ], 40 | [ 41 | "The choice of the encoder and decoder modules of dnpg can be quite flexible, for instance long short term memory networks (lstm) or convolutional neural network (cnn).", 42 | "short acronym, long acronym", 43 | 0.3, 44 | False, 45 | ], 46 | [ 47 | "Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.", 48 | "person, company, location, airplane", 49 | 0.3, 50 | True, 51 | ], 52 | [ 53 | "Feldman is a contributor to NBC Sports Boston's ``State of the Revs`` and ``Revolution Postgame Live`` programs as well as to 98.5 the SportsHub, SiriusXM FC's MLS coverage and to other New England and national radio outlets and podcasts.", 54 | "person, company, location", 55 | 0.3, 56 | False, 57 | ], 58 | [ 59 | "On 25 July 1948, on the 39th anniversary of Bleriot's crossing of the English Channel, the Type 618 Nene-Viking flew Heathrow to Paris (Villacoublay) in the morning carrying letters to Bleriot's widow and son (secretary of the FAI), who met it at the airport.", 60 | "date, location, person, organization", 61 | 0.3, 62 | False, 63 | ], 64 | [ 65 | "Leo & Ian won the 1962 Bathurst Six Hour Classic at Mount Panorama driving a Daimler SP250 sports car, (that year the 500 mile race for touring cars were held at Phillip Island)", 66 | "person, date, location, organization, competition", 67 | 0.3, 68 | False, 69 | ], 70 | [ 71 | "The Shore Line route of the CNS & M until 1955 served, from south to north, the Illinois communities of Chicago, Evanston, Wilmette, Kenilworth, Winnetka, Glencoe, Highland 
Park, Highwood, Fort Sheridan, Lake Forest, Lake Bluff, North Chicago, Waukegan, Zion, and Winthrop Harbor as well as Kenosha, Racine, and Milwaukee (the ``KRM'') in Wisconsin.", 72 | "location, organization, date", 73 | 0.3, 74 | False, 75 | ], 76 | [ 77 | "Comet C/2006 M4 (SWAN) is a non-periodic comet discovered in late June 2006 by Robert D. Matson of Irvine, California and Michael Mattiazzo of Adelaide, South Australia in publicly available images of the Solar and Heliospheric Observatory (SOHO).", 78 | "person, organization, date, location", 79 | 0.3, 80 | False, 81 | ], 82 | [ 83 | "From November 29, 2011 to March 31, 2012, Karimloo returned to ``Les Misérables`` to play the lead role of Jean Valjean at The Queen's Theatre, London, for which he won the 2013 Theatregoers' Choice Award for Best Takeover in a Role.", 84 | "person, actor, award, date, location", 85 | 0.3, 86 | False, 87 | ], 88 | [ 89 | "A Mexicali health clinic supported by former Baja California gubernatorial candidate Enrique Acosta Fregoso (PRI) was closed on June 15 after selling a supposed COVID-19 ``cure'' for between MXN $10,000 and $50,000.", 90 | "location, organization, person, date, currency", 91 | 0.3, 92 | False, 93 | ], 94 | [ 95 | "Built in 1793, it was the home of Mary Young Pickersgill when she moved to Baltimore in 1806 and the location where she later sewed the ``Star Spangled Banner'', in 1813, the huge out-sized garrison flag that flew over Fort McHenry at Whetstone Point in Baltimore Harbor in the summer of 1814 during the British Royal Navy attack in the Battle of Baltimore during the War of 1812.", 96 | "date, person, location, organization, event, flag", 97 | 0.3, 98 | False, 99 | ], 100 | ] 101 | 102 | 103 | def ner( 104 | text, labels: str, threshold: float, nested_ner: bool 105 | ) -> Dict[str, Union[str, int, float]]: 106 | labels = labels.split(",") 107 | return { 108 | "text": text, 109 | "entities": [ 110 | { 111 | "entity": entity["label"], 112 | "word": entity["text"], 113 | "start": entity["start"], 114 | "end": entity["end"], 115 | "score": 0, 116 | } 117 | for entity in model.predict_entities( 118 | text, labels, flat_ner=not nested_ner, threshold=threshold 119 | ) 120 | ], 121 | } 122 | 123 | 124 | with gr.Blocks(title="GLiNER-M-v2.1") as demo: 125 | gr.Markdown( 126 | """ 127 | # GLiNER-base 128 | GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios. 129 | ## Links 130 | * Model: https://huggingface.co/urchade/gliner_multi-v2.1 131 | * All GLiNER models: https://huggingface.co/models?library=gliner 132 | * Paper: https://arxiv.org/abs/2311.08526 133 | * Repository: https://github.com/urchade/GLiNER 134 | """ 135 | ) 136 | with gr.Accordion("How to run this model locally", open=False): 137 | gr.Markdown( 138 | """ 139 | ## Installation 140 | To use this model, you must install the GLiNER Python library: 141 | ``` 142 | !pip install gliner 143 | ``` 144 | 145 | ## Usage 146 | Once you've downloaded the GLiNER library, you can import the GLiNER class. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`. 
147 | """ 148 | ) 149 | gr.Code( 150 | ''' 151 | from gliner import GLiNER 152 | model = GLiNER.from_pretrained("urchade/gliner_mediumv2.1") 153 | text = """ 154 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time. 155 | """ 156 | labels = ["person", "award", "date", "competitions", "teams"] 157 | entities = model.predict_entities(text, labels) 158 | for entity in entities: 159 | print(entity["text"], "=>", entity["label"]) 160 | ''', 161 | language="python", 162 | ) 163 | gr.Code( 164 | """ 165 | Cristiano Ronaldo dos Santos Aveiro => person 166 | 5 February 1985 => date 167 | Al Nassr => teams 168 | Portugal national team => teams 169 | Ballon d'Or => award 170 | UEFA Men's Player of the Year Awards => award 171 | European Golden Shoes => award 172 | UEFA Champions Leagues => competitions 173 | UEFA European Championship => competitions 174 | UEFA Nations League => competitions 175 | Champions League => competitions 176 | European Championship => competitions 177 | """ 178 | ) 179 | 180 | input_text = gr.Textbox( 181 | value=examples[0][0], label="Text input", placeholder="Enter your text here" 182 | ) 183 | with gr.Row() as row: 184 | labels = gr.Textbox( 185 | value=examples[0][1], 186 | label="Labels", 187 | placeholder="Enter your labels here (comma separated)", 188 | scale=2, 189 | ) 190 | threshold = gr.Slider( 191 | 0, 192 | 1, 193 | value=0.3, 194 | step=0.01, 195 | label="Threshold", 196 | info="Lower the threshold to increase how many entities get predicted.", 197 | scale=1, 198 | ) 199 | nested_ner = gr.Checkbox( 200 | value=examples[0][2], 201 | label="Nested NER", 202 | info="Allow for nested NER?", 203 | scale=0, 204 | ) 205 | output = gr.HighlightedText(label="Predicted Entities") 206 | submit_btn = gr.Button("Submit") 207 | examples = gr.Examples( 208 | examples, 209 | fn=ner, 210 | inputs=[input_text, labels, threshold, nested_ner], 211 | outputs=output, 212 | cache_examples=True, 213 | ) 214 | 215 | # Submitting 216 | input_text.submit( 217 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output 218 | ) 219 | labels.submit( 220 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output 221 | ) 222 | threshold.release( 223 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output 224 | ) 225 | submit_btn.click( 226 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output 227 | ) 228 | nested_ner.change( 229 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], 
-------------------------------------------------------------------------------- /eval.py: --------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | from gliner import GLiNER
4 | from gliner.evaluation import get_for_all_path
5 | 
6 | 
7 | def create_parser():
8 |     parser = argparse.ArgumentParser(description="Span-based NER")
9 |     parser.add_argument("--model", type=str, default="logs/model_12000", help="Path to model folder")
10 |     parser.add_argument("--log_dir", type=str, default="logs", help="Path to the logs directory")
11 |     parser.add_argument('--data', type=str, default='data/ie_data/NER/', help='Path to the eval datasets directory')
12 |     return parser
13 | 
14 | 
15 | if __name__ == "__main__":
16 |     parser = create_parser()
17 |     args = parser.parse_args()
18 | 
19 |     model = GLiNER.from_pretrained(args.model, load_tokenizer=True).to("cuda:0")
20 |     get_for_all_path(model, -1, args.log_dir, args.data)
-------------------------------------------------------------------------------- /examples/convert_to_onnx.ipynb: --------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "# !pip install onnx"
10 |    ]
11 |   },
12 |   {
13 |    "cell_type": "code",
14 |    "execution_count": null,
15 |    "metadata": {},
16 |    "outputs": [],
17 |    "source": [
18 |     "import torch\n",
19 |     "from gliner import GLiNER"
20 |    ]
21 |   },
22 |   {
23 |    "cell_type": "code",
24 |    "execution_count": null,
25 |    "metadata": {},
26 |    "outputs": [],
27 |    "source": [
28 |     "model = GLiNER.from_pretrained(\"urchade/gliner_medium\")"
29 |    ]
30 |   },
31 |   {
32 |    "cell_type": "code",
33 |    "execution_count": null,
34 |    "metadata": {},
35 |    "outputs": [],
36 |    "source": [
37 |     "# save\n",
38 |     "\n",
39 |     "model.save_pretrained(\"gliner_medium\")"
40 |    ]
41 |   },
42 |   {
43 |    "cell_type": "code",
44 |    "execution_count": null,
45 |    "metadata": {},
46 |    "outputs": [],
47 |    "source": [
48 |     "gliner_model = GLiNER.from_pretrained(\"gliner_medium\", load_tokenizer=True)"
49 |    ]
50 |   },
51 |   {
52 |    "cell_type": "code",
53 |    "execution_count": null,
54 |    "metadata": {},
55 |    "outputs": [],
56 |    "source": [
57 |     "import os\n",
58 |     "\n",
59 |     "onnx_save_path = os.path.join(\"gliner_medium\", \"model.onnx\")"
60 |    ]
61 |   },
62 |   {
63 |    "cell_type": "code",
64 |    "execution_count": null,
65 |    "metadata": {},
66 |    "outputs": [],
67 |    "source": [
68 |     "text = \"ONNX is an open-source format designed to enable the interoperability of AI models across various frameworks and tools.\"\n",
69 |     "labels = ['format', 'model', 'tool', 'cat']\n",
70 |     "\n",
71 |     "inputs, _ = gliner_model.prepare_model_inputs([text], labels)"
72 |    ]
73 |   },
74 |   {
75 |    "cell_type": "code",
76 |    "execution_count": null,
77 |    "metadata": {},
78 |    "outputs": [],
79 |    "source": [
80 |     "if gliner_model.config.span_mode == 'token_level':\n",
81 |     "    all_inputs = (inputs['input_ids'], inputs['attention_mask'], \n",
82 |     "                  inputs['words_mask'], inputs['text_lengths'])\n",
83 |     "    input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths']\n",
84 |     "    dynamic_axes={\n",
85 |     "        \"input_ids\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
86 |     "        \"attention_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
87 |     "        \"words_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
88 |     "        \
\"text_lengths\": {0: \"batch_size\", 1: \"value\"},\n", 89 | " \"logits\": {0: \"position\", 1: \"batch_size\", 2: \"sequence_length\", 3: \"num_classes\"},\n", 90 | " }\n", 91 | "else:\n", 92 | " all_inputs = (inputs['input_ids'], inputs['attention_mask'], \n", 93 | " inputs['words_mask'], inputs['text_lengths'],\n", 94 | " inputs['span_idx'], inputs['span_mask'])\n", 95 | " input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths', 'span_idx', 'span_mask']\n", 96 | " dynamic_axes={\n", 97 | " \"input_ids\": {0: \"batch_size\", 1: \"sequence_length\"},\n", 98 | " \"attention_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n", 99 | " \"words_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n", 100 | " \"text_lengths\": {0: \"batch_size\", 1: \"value\"},\n", 101 | " \"span_idx\": {0: \"batch_size\", 1: \"num_spans\", 2: \"idx\"},\n", 102 | " \"span_mask\": {0: \"batch_size\", 1: \"num_spans\"},\n", 103 | " \"logits\": {0: \"batch_size\", 1: \"sequence_length\", 2: \"num_spans\", 3: \"num_classes\"},\n", 104 | " }\n", 105 | "print('Converting the model...')\n", 106 | "all_inputs = dict(zip(input_names,all_inputs))\n", 107 | "\n", 108 | "torch.onnx.export(\n", 109 | " gliner_model.model,\n", 110 | " all_inputs,\n", 111 | " f=onnx_save_path,\n", 112 | " input_names=input_names,\n", 113 | " output_names=[\"logits\"],\n", 114 | " dynamic_axes=dynamic_axes,\n", 115 | " opset_version=14,\n", 116 | ")\n" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "#quantize model\n", 126 | "from onnxruntime.quantization import quantize_dynamic, QuantType\n", 127 | "\n", 128 | "quantized_save_path = os.path.join(\"gliner_medium\", \"model_quantized.onnx\")\n", 129 | "# Quantize the ONNX model\n", 130 | "print(\"Quantizing the model...\")\n", 131 | "quantize_dynamic(\n", 132 | " onnx_save_path, # Input model\n", 133 | " quantized_save_path, # Output model\n", 134 | " weight_type=QuantType.QUInt8 # Quantize weights to 8-bit integers\n", 135 | ")" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# load onnx model\n", 145 | "model = GLiNER.from_pretrained(\"gliner_medium\", load_onnx_model=True, load_tokenizer=True)\n" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "text = \"\"\"\n", 155 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. 
Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n", 156 | "\"\"\"\n", 157 | "\n", 158 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n", 159 | "\n", 160 | "entities = model.predict_entities(text, labels, threshold=0.4)\n", 161 | "\n", 162 | "for entity in entities:\n", 163 | " print(entity[\"text\"], \"=>\", entity[\"label\"])" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "# load quantized model\n", 173 | "model = GLiNER.from_pretrained(\"gliner_medium\", load_onnx_model=True, load_tokenizer=True, onnx_model_file=\"model_quantized.onnx\")\n" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "text = \"\"\"\n", 183 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n", 184 | "\"\"\"\n", 185 | "\n", 186 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n", 187 | "\n", 188 | "entities = model.predict_entities(text, labels, threshold=0.4)\n", 189 | "\n", 190 | "for entity in entities:\n", 191 | " print(entity[\"text\"], \"=>\", entity[\"label\"])" 192 | ] 193 | } 194 | ], 195 | "metadata": { 196 | "kernelspec": { 197 | "display_name": "base", 198 | "language": "python", 199 | "name": "python3" 200 | }, 201 | "language_info": { 202 | "codemirror_mode": { 203 | "name": "ipython", 204 | "version": 3 205 | }, 206 | "file_extension": ".py", 207 | "mimetype": "text/x-python", 208 | "name": "python", 209 | "nbconvert_exporter": "python", 210 | "pygments_lexer": "ipython3", 211 | "version": "3.10.10" 212 | }, 213 | "orig_nbformat": 4 214 | }, 215 | "nbformat": 4, 216 | "nbformat_minor": 2 217 | } 218 | -------------------------------------------------------------------------------- /examples/exal_example_conll.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "f2087a56", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install datasets" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 10, 16 | "id": "24a58336-3b16-491b-8646-2f54e93a8964", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from datasets import load_dataset" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 11, 26 | "id": "df670cb5-24cb-4683-849e-2e27769dd762", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "def ner_tags_to_spans(samples, tag_to_id):\n", 31 | " \"\"\"\n", 32 | " Converts NER tags in the dataset samples to spans (start, end, entity type).\n", 33 | " \n", 34 | " Args:\n", 35 | " samples (dict): A dictionary containing the tokens and NER tags.\n", 36 | " tag_to_id (dict): A dictionary mapping NER tags to IDs.\n", 37 | " \n", 38 | " Returns:\n", 39 | " dict: A dictionary containing tokenized text and corresponding NER spans.\n", 40 | " \"\"\"\n", 41 | " ner_tags = 
samples[\"ner_tags\"]\n", 42 | " id_to_tag = {v: k for k, v in tag_to_id.items()}\n", 43 | " spans = []\n", 44 | " start_pos = None\n", 45 | " entity_name = None\n", 46 | "\n", 47 | " for i, tag in enumerate(ner_tags):\n", 48 | " if tag == 0: # 'O' tag\n", 49 | " if entity_name is not None:\n", 50 | " spans.append((start_pos, i - 1, entity_name))\n", 51 | " entity_name = None\n", 52 | " start_pos = None\n", 53 | " else:\n", 54 | " tag_name = id_to_tag[tag]\n", 55 | " if tag_name.startswith('B-'):\n", 56 | " if entity_name is not None:\n", 57 | " spans.append((start_pos, i - 1, entity_name))\n", 58 | " entity_name = tag_name[2:]\n", 59 | " start_pos = i\n", 60 | " elif tag_name.startswith('I-'):\n", 61 | " continue\n", 62 | "\n", 63 | " # Handle the last entity if the sentence ends with an entity\n", 64 | " if entity_name is not None:\n", 65 | " spans.append((start_pos, len(samples[\"tokens\"]) - 1, entity_name))\n", 66 | " \n", 67 | " return {\"tokenized_text\": samples[\"tokens\"], \"ner\": spans}" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "971f92b9-ece2-460d-99f8-73277b5d3081", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# step 1: load data\n", 78 | "dataset = load_dataset(\"eriktks/conll2003\")" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 13, 84 | "id": "67a18f87-1571-4e8c-8253-6e0305bfa0cb", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# Step 2: Define NER tag-to-ID mapping\n", 89 | "tag_to_id = {\n", 90 | " 'O': 0, 'B-person': 1, 'I-person': 2, 'B-organization': 3, 'I-organization': 4,\n", 91 | " 'B-location': 5, 'I-location': 6, 'B-others': 7, 'I-others': 8\n", 92 | "}" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 14, 98 | "id": "354aae86-2e5f-4a82-821b-6baba9438532", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# Convert NER tags to spans for the training data\n", 103 | "gliner_data_conll = [ner_tags_to_spans(i, tag_to_id) for i in dataset['train']]" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 15, 109 | "id": "7c717148-7c98-4998-90fc-c244e11d7b67", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# Load the pre-trained GLiNER model\n", 114 | "from gliner import GLiNER\n", 115 | "import torch\n", 116 | "\n", 117 | "model = GLiNER.from_pretrained(\"urchade/gliner_small\", load_tokenizer=True) #true if a model was trained from scratch with new code base\n", 118 | "\n", 119 | "if torch.cuda.is_available():\n", 120 | " device = \"cuda\"\n", 121 | "else:\n", 122 | " device = \"cpu\"\n", 123 | "\n", 124 | "model = model.to(device)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 17, 130 | "id": "601c2e03-2fe7-481f-b769-c8c874bee9c6", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# Evaluate the model on the first 100 samples\n", 135 | "evaluation_results = model.evaluate(\n", 136 | " gliner_data_conll[:100], flat_ner=True, entity_types=[\"person\", \"organization\", \"location\", \"others\"]\n", 137 | ")" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 18, 143 | "id": "273d79be-2f0f-4191-bc8e-641854ffa540", 144 | "metadata": {}, 145 | "outputs": [ 146 | { 147 | "name": "stdout", 148 | "output_type": "stream", 149 | "text": [ 150 | "('P: 63.13%\\tR: 71.43%\\tF1: 67.02%\\n', 0.6702412868632708)\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "print(evaluation_results)" 156 | ] 157 | }, 158 | { 159 | 
"cell_type": "code", 160 | "execution_count": null, 161 | "id": "e8c6190d-6a63-43c0-9010-27d141970877", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [] 165 | } 166 | ], 167 | "metadata": { 168 | "kernelspec": { 169 | "display_name": "Python 3 (ipykernel)", 170 | "language": "python", 171 | "name": "python3" 172 | }, 173 | "language_info": { 174 | "codemirror_mode": { 175 | "name": "ipython", 176 | "version": 3 177 | }, 178 | "file_extension": ".py", 179 | "mimetype": "text/x-python", 180 | "name": "python", 181 | "nbconvert_exporter": "python", 182 | "pygments_lexer": "ipython3", 183 | "version": "3.8.18" 184 | } 185 | }, 186 | "nbformat": 4, 187 | "nbformat_minor": 5 188 | } 189 | -------------------------------------------------------------------------------- /examples/gliner_spacy_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/Applications/anaconda3/envs/gliner-spacy/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import spacy\n", 19 | "from gliner_spacy.pipeline import GlinerSpacy" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "/Applications/anaconda3/envs/gliner-spacy/lib/python3.10/site-packages/transformers/convert_slow_tokenizer.py:550: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n", 32 | " warnings.warn(\n" 33 | ] 34 | }, 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "" 39 | ] 40 | }, 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "output_type": "execute_result" 44 | } 45 | ], 46 | "source": [ 47 | "nlp = spacy.load(\"en_core_web_sm\")\n", 48 | "nlp.add_pipe(\"gliner_spacy\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "ent 250\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "text = \"This is a text about Bill Gates and Microsoft.\"\n", 66 | "doc = nlp(text)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "from spacy import displacy" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "data": { 85 | "text/html": [ 86 | "
This is a text about \n", 87 | "\n", 88 | " Bill Gates\n", 89 | " person\n", 90 | "\n", 91 | " and \n", 92 | "\n", 93 | " Microsoft\n", 94 | " organization\n", 95 | "\n", 96 | ".
" 97 | ], 98 | "text/plain": [ 99 | "" 100 | ] 101 | }, 102 | "metadata": {}, 103 | "output_type": "display_data" 104 | } 105 | ], 106 | "source": [ 107 | "displacy.render(doc, style=\"ent\")" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 6, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "Bill Gates person\n", 120 | "Microsoft organization\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "for ent in doc.ents:\n", 126 | " print(ent.text, ent.label_)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "gliner-spacy", 140 | "language": "python", 141 | "name": "python3" 142 | }, 143 | "language_info": { 144 | "codemirror_mode": { 145 | "name": "ipython", 146 | "version": 3 147 | }, 148 | "file_extension": ".py", 149 | "mimetype": "text/x-python", 150 | "name": "python", 151 | "nbconvert_exporter": "python", 152 | "pygments_lexer": "ipython3", 153 | "version": "3.10.13" 154 | } 155 | }, 156 | "nbformat": 4, 157 | "nbformat_minor": 2 158 | } 159 | -------------------------------------------------------------------------------- /examples/load_local_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "f2d5e279-1cbc-4291-985f-9e23af3a6ecc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "from gliner import GLiNER" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "3baf9b40-daba-4638-b4a2-cc82c8a9ed99", 18 | "metadata": { 19 | "scrolled": true 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "# first load your model\n", 24 | "\n", 25 | "model = GLiNER.from_pretrained(\"gliner-community/gliner_medium-v2.5\")" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "id": "5351bc8d-1182-4398-8be7-de61e6b24936", 31 | "metadata": {}, 32 | "source": [ 33 | "## Option 1" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "c25fbf7b-b10c-4808-995e-2431f6c0356f", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "# save\n", 44 | "\n", 45 | "model.save_pretrained(\"gliner_Med\")" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "f46d16c1-ab18-4300-b21f-e78c1da81df3", 52 | "metadata": { 53 | "scrolled": true 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "# load\n", 58 | "\n", 59 | "loaded_model = GLiNER.from_pretrained(\"gliner_Med\", load_tokenizer = True, local_files_only=True)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "4041910a-ee0e-470d-b718-bc151a2666eb", 65 | "metadata": {}, 66 | "source": [ 67 | "## Option 2" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "fe7e3b71-1d15-4739-9b41-fcc279046950", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def save_model(current_model, path):\n", 78 | " config = current_model.config\n", 79 | " dict_save = {\"model_weights\": current_model.state_dict(), \"config\": config}\n", 80 | " torch.save(dict_save, path)\n", 81 | "\n", 82 | "\n", 83 | "def load_model(path, model_name=None):\n", 84 | " \n", 85 | " dict_load = torch.load(path, map_location=torch.device('cpu'))\n", 86 | " config = dict_load[\"config\"]\n", 
87 | "\n", 88 | " print(f\"'{config.model_name}' should be available for local processing\")\n", 89 | "\n", 90 | " if model_name is not None:\n", 91 | " config.model_name = model_name\n", 92 | "\n", 93 | " loaded_model = GLiNER(config)\n", 94 | " loaded_model.load_state_dict(dict_load[\"model_weights\"])\n", 95 | " return loaded_model" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "e513be85-3178-449c-adec-1a609e38b580", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "# save the model weight\n", 106 | "\n", 107 | "save_model(model, \"model_weight.pt\")" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "628eb872-ff3d-4c59-ac20-9b229797090f", 114 | "metadata": { 115 | "scrolled": true 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "# load model weight\n", 120 | "\n", 121 | "loaded_model = load_model(\"model_weight.pt\")\n", 122 | "print(\"success !!\")" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "e057d7ec-1756-4c97-a1d9-e5fdcb60e20a", 128 | "metadata": {}, 129 | "source": [ 130 | "## Testing" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "2827009e-bdb8-44b2-92b5-e6bdcc17f08e", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "text = \"\"\"\n", 141 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. 
Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n", 142 | "\"\"\"\n", 143 | "\n", 144 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n", 145 | "\n", 146 | "entities = loaded_model.predict_entities(text, labels, threshold=0.4)\n", 147 | "\n", 148 | "for entity in entities:\n", 149 | " print(entity[\"text\"], \"=>\", entity[\"label\"])" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "id": "839336f8-e5a0-471d-ace9-ef7f7e1c5c97", 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "Python 3 (ipykernel)", 164 | "language": "python", 165 | "name": "python3" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 3 171 | }, 172 | "file_extension": ".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython3", 177 | "version": "3.8.10" 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 5 182 | } 183 | -------------------------------------------------------------------------------- /examples/quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "7037f111-e8eb-4270-8e69-b013d075b751", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from gliner import GLiNER" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "2c86b85f-ab71-4918-95c5-909a79ba7158", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# available models: https://huggingface.co/urchade\n", 23 | "\n", 24 | "model = GLiNER.from_pretrained(\"urchade/gliner_medium\")\n", 25 | "model.eval()\n", 26 | "print(\"ok\")" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "f823dbf3-2462-4a67-8c4b-9a45ec580c1d", 33 | "metadata": { 34 | "tags": [] 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "text = \"\"\"\n", 39 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. 
Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n", 40 | "\"\"\"\n", 41 | "\n", 42 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n", 43 | "\n", 44 | "entities = model.predict_entities(text, labels, threshold=0.4)\n", 45 | "\n", 46 | "for entity in entities:\n", 47 | " print(entity[\"text\"], \"=>\", entity[\"label\"])" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "5f5d1377-f073-485f-bd4f-35b5750ba020", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [] 57 | } 58 | ], 59 | "metadata": { 60 | "kernelspec": { 61 | "display_name": "Python 3 (ipykernel)", 62 | "language": "python", 63 | "name": "python3" 64 | }, 65 | "language_info": { 66 | "codemirror_mode": { 67 | "name": "ipython", 68 | "version": 3 69 | }, 70 | "file_extension": ".py", 71 | "mimetype": "text/x-python", 72 | "name": "python", 73 | "nbconvert_exporter": "python", 74 | "pygments_lexer": "ipython3", 75 | "version": "3.10.10" 76 | } 77 | }, 78 | "nbformat": 4, 79 | "nbformat_minor": 5 80 | } 81 | -------------------------------------------------------------------------------- /gliner/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.20" 2 | 3 | from .model import GLiNER 4 | from .config import GLiNERConfig 5 | # from .multitask import (GLiNERClassifier, GLiNERQuestionAnswerer, GLiNEROpenExtractor, 6 | # GLiNERRelationExtractor, GLiNERSummarizer, GLiNERSquadEvaluator, 7 | # GLiNERDocREDEvaluator) 8 | 9 | __all__ = ["GLiNER"] 10 | -------------------------------------------------------------------------------- /gliner/config.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from transformers import PretrainedConfig 3 | from transformers.models.auto import CONFIG_MAPPING 4 | 5 | class GLiNERConfig(PretrainedConfig): 6 | model_type = "gliner" 7 | is_composition = True 8 | def __init__(self, 9 | model_name: str = "microsoft/deberta-v3-small", 10 | labels_encoder: str = None, 11 | name: str = "span level gliner", 12 | max_width: int = 12, 13 | hidden_size: int = 512, 14 | dropout: float = 0.4, 15 | fine_tune: bool = True, 16 | subtoken_pooling: str = "first", 17 | span_mode: str = "markerV0", 18 | post_fusion_schema: str = '', #l2l-l2t-t2t 19 | num_post_fusion_layers: int = 1, 20 | vocab_size: int = -1, 21 | max_neg_type_ratio: int = 1, 22 | max_types: int = 25, 23 | max_len: int = 384, 24 | words_splitter_type: str = "whitespace", 25 | has_rnn: bool = True, 26 | fuse_layers: bool = False, 27 | embed_ent_token: bool = True, 28 | class_token_index: int = -1, 29 | encoder_config: Optional[dict] = None, 30 | labels_encoder_config: Optional[dict] = None, 31 | ent_token = "<>", 32 | sep_token = "<>", 33 | _attn_implementation = None, 34 | **kwargs): 35 | super().__init__(**kwargs) 36 | if isinstance(encoder_config, dict): 37 | encoder_config["model_type"] = (encoder_config["model_type"] 38 | if "model_type" in encoder_config 39 | else "deberta-v2") 40 | encoder_config = CONFIG_MAPPING[encoder_config["model_type"]](**encoder_config) 41 | self.encoder_config = encoder_config 42 | 43 | if isinstance(labels_encoder_config, dict): 44 | labels_encoder_config["model_type"] = (labels_encoder_config["model_type"] 45 | if "model_type" in labels_encoder_config 46 | else "deberta-v2") 47 | labels_encoder_config = CONFIG_MAPPING[labels_encoder_config["model_type"]](**labels_encoder_config) 48 | 
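        # When a saved config is reloaded from JSON, the nested encoder configs
        # arrive as plain dicts; CONFIG_MAPPING rehydrates them into proper
        # PretrainedConfig objects, defaulting to "deberta-v2" when no
        # model_type was recorded.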
self.labels_encoder_config = labels_encoder_config 49 | 50 | self.model_name = model_name 51 | self.labels_encoder = labels_encoder 52 | self.name = name 53 | self.max_width = max_width 54 | self.hidden_size = hidden_size 55 | self.dropout = dropout 56 | self.fine_tune = fine_tune 57 | self.subtoken_pooling = subtoken_pooling 58 | self.span_mode = span_mode 59 | self.post_fusion_schema = post_fusion_schema 60 | self.num_post_fusion_layers = num_post_fusion_layers 61 | self.vocab_size = vocab_size 62 | self.max_neg_type_ratio = max_neg_type_ratio 63 | self.max_types = max_types 64 | self.max_len = max_len 65 | self.words_splitter_type = words_splitter_type 66 | self.has_rnn = has_rnn 67 | self.fuse_layers = fuse_layers 68 | self.class_token_index = class_token_index 69 | self.embed_ent_token = embed_ent_token 70 | self.ent_token = ent_token 71 | self.sep_token = sep_token 72 | self._attn_implementation = _attn_implementation 73 | 74 | # Register the configuration 75 | from transformers import CONFIG_MAPPING 76 | CONFIG_MAPPING.update({"gliner": GLiNERConfig}) -------------------------------------------------------------------------------- /gliner/data_processing/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import SpanProcessor, SpanBiEncoderProcessor, TokenProcessor, TokenBiEncoderProcessor 2 | from .collator import DataCollator 3 | from .tokenizer import WordsSplitter 4 | from .dataset import GLiNERDataset -------------------------------------------------------------------------------- /gliner/data_processing/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | import torch.nn.functional as F 4 | from .processor import SpanProcessor, TokenProcessor 5 | from .utils import pad_2d_tensor 6 | 7 | class DataCollator: 8 | def __init__(self, config, tokenizer=None, words_splitter=None, data_processor=None, 9 | return_tokens: bool = False, 10 | return_id_to_classes: bool = False, 11 | return_entities: bool = False, 12 | prepare_labels: bool = False, 13 | entity_types = None): 14 | self.config=config 15 | if data_processor is None: 16 | if config.span_mode == "token_level": 17 | self.data_processor = TokenProcessor(config, tokenizer, words_splitter) 18 | else: 19 | self.data_processor = SpanProcessor(config, tokenizer, words_splitter) 20 | else: 21 | self.data_processor = data_processor 22 | self.prepare_labels = prepare_labels 23 | self.return_tokens = return_tokens 24 | self.return_id_to_classes = return_id_to_classes 25 | self.return_entities = return_entities 26 | self.entity_types = entity_types 27 | 28 | def __call__(self, input_x): 29 | raw_batch = self.data_processor.collate_raw_batch(input_x, entity_types = self.entity_types) 30 | 31 | model_input = self.data_processor.collate_fn(raw_batch, prepare_labels=self.prepare_labels) 32 | model_input.update({"span_idx": raw_batch['span_idx'] if 'span_idx' in raw_batch else None, 33 | "span_mask": raw_batch["span_mask"] if 'span_mask' in raw_batch else None, 34 | "text_lengths": raw_batch['seq_length']}) 35 | if self.return_tokens: 36 | model_input['tokens'] = raw_batch['tokens'] 37 | if self.return_id_to_classes: 38 | model_input['id_to_classes'] = raw_batch['id_to_classes'] 39 | if self.return_entities: 40 | model_input['entities'] = raw_batch['entities'] 41 | model_input = {k:v for k, v in model_input.items() if v is not None} 42 | return model_input 43 | 44 | class 
DataCollatorWithPadding:
45 |     def __init__(self, config=None):
46 |         """
47 |         Initialize the data collator with its config.
48 |         """
49 |         self.config = config
50 | 
51 |     def __call__(self, batch):
52 |         if not batch:
53 |             raise ValueError("Batch cannot be empty")
54 |         batch = [item for item in batch if item is not None]
55 |         # Extract all keys from the first item
56 |         keys = batch[0].keys()
57 | 
58 |         # Create a dictionary to hold padded data
59 |         padded_batch = {key: [] for key in keys}
60 | 
61 |         for key in keys:
62 |             if key in {'tokens', 'id_to_classes', 'entities'}:
63 |                 padded_batch[key] = [item[key] for item in batch]
64 |                 continue
65 |             # Collect data for the current key
66 |             key_data = [item[key].squeeze(0) for item in batch]
67 | 
68 |             if isinstance(key_data[0], torch.Tensor):
69 |                 if key_data[0].dim() == 1:
70 |                     # For 1D tensors, use pad_sequence
71 |                     if key == 'span_label':
72 |                         span_label = pad_sequence(key_data, batch_first=True, padding_value=-1)
73 |                         span_mask = span_label != -1
74 |                         padded_batch[key] = span_mask
75 |                     else:
76 |                         padded_batch[key] = pad_sequence(key_data, batch_first=True)
77 |                 elif key_data[0].dim() == 2:  # span_idx case
78 |                     padded_batch[key] = self._pad_2d_tensor(key_data)
79 |                 elif key == 'labels' and self.config.span_mode == 'token_level':
80 |                     padded_batch[key] = self.pad_token_labels(key_data)
81 |                 else:
82 |                     raise TypeError(f"Unsupported tensor rank for key '{key}'")
83 |             elif isinstance(key_data[0], list):
84 |                 # Pad list-like data; tensors are kept on CPU, since the collator
85 |                 # has no device attribute -- the training loop moves batches.
86 |                 max_length = max(len(seq) for seq in key_data)
87 |                 padded_batch[key] = torch.tensor(
88 |                     [seq + [0] * (max_length - len(seq)) for seq in key_data],
89 |                     dtype=torch.float32
90 |                 )
91 |             elif isinstance(key_data[0], (int, float)):
92 |                 # Directly convert numeric data to tensors (also kept on CPU)
93 |                 padded_batch[key] = torch.tensor(key_data, dtype=torch.float32)
94 |             else:
95 |                 raise TypeError(f"Unsupported data type for key '{key}': {type(key_data[0])}")
96 |         padded_batch = {k: v for k, v in padded_batch.items() if v is not None}
97 |         return padded_batch
98 | 
99 |     def _pad_2d_tensor(self, key_data):
100 |         padded_tensors = pad_2d_tensor(key_data)
101 |         return padded_tensors
102 | 
103 |     def pad_token_labels(self, key_data):
104 |         if not key_data:
105 |             raise ValueError("The input list 'key_data' should not be empty.")
106 | 
107 |         # Determine the maximum sequence length and number of classes
108 |         max_seq_len = max(tensor.shape[2] for tensor in key_data)
109 |         max_num_classes = max(tensor.shape[3] for tensor in key_data)
110 | 
111 |         padded_tensors = []
112 | 
113 |         for tensor in key_data:
114 |             current_seq_len = tensor.shape[2]
115 |             current_num_classes = tensor.shape[3]
116 | 
117 |             seq_padding = max_seq_len - current_seq_len
118 |             class_padding = max_num_classes - current_num_classes
119 | 
120 |             # Pad tensor to the maximum sequence length and number of classes
121 |             padded_tensor = F.pad(tensor, (0, class_padding, 0, seq_padding), mode='constant', value=0)
122 |             padded_tensors.append(padded_tensor)
123 | 
124 |         # Concatenate the tensors along the batch dimension
125 |         concatenated_labels = torch.cat(padded_tensors, dim=1)
126 | 
127 |         return concatenated_labels
-------------------------------------------------------------------------------- /gliner/data_processing/dataset.py: --------------------------------------------------------------------------------
1 | import random
2 | from tqdm import tqdm
3 | from typing import Optional, List
4 | from torch.utils.data import Dataset
5 | from transformers import 
AutoTokenizer 6 | 7 | from . import TokenProcessor, SpanProcessor, WordsSplitter 8 | from ..config import GLiNERConfig 9 | 10 | class GLiNERDataset(Dataset): 11 | def __init__(self, examples, 12 | config: Optional[GLiNERConfig], 13 | tokenizer: Optional[AutoTokenizer] = None, 14 | words_splitter: Optional[WordsSplitter] = None, 15 | data_processor = None, 16 | entities = None, 17 | get_negatives:bool=True): 18 | self._data = examples 19 | self.config=config 20 | if data_processor is not None: 21 | self.data_processor = data_processor 22 | else: 23 | if config.span_mode == "token_level": 24 | self.data_processor = TokenProcessor(config, tokenizer, words_splitter, preprocess_text=True) 25 | else: 26 | self.data_processor = SpanProcessor(config, tokenizer, words_splitter, preprocess_text=True) 27 | 28 | self.max_neg_type_ratio = int(self.config.max_neg_type_ratio) 29 | self.get_negatives = get_negatives 30 | if not entities: 31 | self.all_entities = self._collect_all_entities() 32 | else: 33 | self.all_entities = entities 34 | self.max_negatives = min(50, len(self.all_entities)) 35 | 36 | def _get_entities_from_example(self, example): 37 | entities = {ner[-1] for ner in example['ner']} 38 | return entities 39 | 40 | def _collect_all_entities(self): 41 | print("Collecting all entities...") 42 | all_entities = set() 43 | for example in tqdm(self._data): 44 | curr_entities = self._get_entities_from_example(example) 45 | all_entities.update(curr_entities) 46 | print('Total number of entity classes: ', len(all_entities)) 47 | return list(all_entities) 48 | 49 | def _get_negatives(self): 50 | negatives = random.sample(self.all_entities, k=self.max_negatives) 51 | random.shuffle(negatives) 52 | return negatives 53 | 54 | def __len__(self): 55 | return len(self._data) 56 | 57 | def __getitem__(self, idx): 58 | try: 59 | example = self._data[idx] 60 | if self.get_negatives: 61 | curr_negatives = self._get_negatives() 62 | else: 63 | curr_negatives = None 64 | 65 | raw_batch = self.data_processor.collate_raw_batch([example], negatives = curr_negatives) 66 | 67 | model_input = self.data_processor.collate_fn(raw_batch, prepare_labels=True) 68 | if 'span_idx' in raw_batch: 69 | model_input['span_idx'] = raw_batch['span_idx'] 70 | if 'span_mask' in raw_batch: 71 | model_input['span_mask'] = raw_batch['span_mask'] 72 | if 'seq_length' in raw_batch: 73 | model_input['text_lengths'] = raw_batch['seq_length'] 74 | return model_input 75 | except Exception as e: 76 | print(f"Skipping getting item due to error: {e}") 77 | return None -------------------------------------------------------------------------------- /gliner/data_processing/tokenizer.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | class TokenSplitterBase(): 5 | def __init__(self): 6 | pass 7 | 8 | def __call__(self, text) -> (str, int, int): 9 | pass 10 | 11 | 12 | class WhitespaceTokenSplitter(TokenSplitterBase): 13 | def __init__(self): 14 | self.whitespace_pattern = re.compile(r'\w+(?:[-_]\w+)*|\S') 15 | 16 | def __call__(self, text): 17 | for match in self.whitespace_pattern.finditer(text): 18 | yield match.group(), match.start(), match.end() 19 | 20 | 21 | class SpaCyTokenSplitter(TokenSplitterBase): 22 | def __init__(self, lang=None): 23 | try: 24 | import spacy # noqa 25 | except ModuleNotFoundError as error: 26 | raise error.__class__( 27 | "Please install spacy with: `pip install spacy`" 28 | ) 29 | if lang is None: 30 | lang = 'en' # Default to English if no language is 
specified 31 | self.nlp = spacy.blank(lang) 32 | 33 | def __call__(self, text): 34 | doc = self.nlp(text) 35 | for token in doc: 36 | yield token.text, token.idx, token.idx + len(token.text) 37 | 38 | 39 | class MecabKoTokenSplitter(TokenSplitterBase): 40 | def __init__(self): 41 | try: 42 | import mecab # noqa 43 | except ModuleNotFoundError as error: 44 | raise error.__class__( 45 | "Please install python-mecab-ko with: `pip install python-mecab-ko`" 46 | ) 47 | self.tagger = mecab.MeCab() 48 | 49 | def __call__(self, text): 50 | tokens = self.tagger.morphs(text) 51 | 52 | last_idx = 0 53 | for morph in tokens: 54 | start_idx = text.find(morph, last_idx) 55 | end_idx = start_idx + len(morph) 56 | last_idx = end_idx 57 | yield morph, start_idx, end_idx 58 | 59 | class JiebaTokenSplitter(TokenSplitterBase): 60 | def __init__(self): 61 | try: 62 | import jieba # noqa 63 | except ModuleNotFoundError as error: 64 | raise error.__class__( 65 | "Please install jieba with: `pip install jieba`" 66 | ) 67 | self.tagger = jieba 68 | 69 | def __call__(self, text): 70 | tokens = self.tagger.cut(text) 71 | last_idx = 0 72 | for token in tokens: 73 | start_idx = text.find(token, last_idx) 74 | end_idx = start_idx + len(token) 75 | last_idx = end_idx 76 | yield token, start_idx, end_idx 77 | 78 | class HanLPTokenSplitter(TokenSplitterBase): 79 | def __init__(self, model_name="FINE_ELECTRA_SMALL_ZH"): 80 | try: 81 | import hanlp # noqa 82 | import hanlp.pretrained 83 | except ModuleNotFoundError as error: 84 | raise error.__class__( 85 | "Please install hanlp with: `pip install hanlp`" 86 | ) 87 | 88 | models = hanlp.pretrained.tok.ALL 89 | if model_name not in models: 90 | raise ValueError(f"HanLP: {model_name} is not available, choose between {models.keys()}") 91 | url = models[model_name] 92 | self.tagger = hanlp.load(url) 93 | 94 | def __call__(self, text): 95 | tokens = self.tagger(text) 96 | last_idx = 0 97 | for token in tokens: 98 | start_idx = text.find(token, last_idx) 99 | end_idx = start_idx + len(token) 100 | last_idx = end_idx 101 | yield token, start_idx, end_idx 102 | 103 | class WordsSplitter(TokenSplitterBase): 104 | def __init__(self, splitter_type='whitespace'): 105 | if splitter_type=='whitespace': 106 | self.splitter = WhitespaceTokenSplitter() 107 | elif splitter_type == 'spacy': 108 | self.splitter = SpaCyTokenSplitter() 109 | elif splitter_type == 'mecab': 110 | self.splitter = MecabKoTokenSplitter() 111 | elif splitter_type == 'jieba': 112 | self.splitter = JiebaTokenSplitter() 113 | elif splitter_type == 'hanlp': 114 | self.splitter = HanLPTokenSplitter() 115 | else: 116 | raise ValueError(f"{splitter_type} is not implemented, choose between 'whitespace', 'spacy', 'jieba', 'hanlp' and 'mecab'") 117 | 118 | def __call__(self, text): 119 | for token in self.splitter(text): 120 | yield token -------------------------------------------------------------------------------- /gliner/data_processing/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def pad_2d_tensor(key_data): 4 | """ 5 | Pad a list of 2D tensors to have the same size along both dimensions. 6 | 7 | :param key_data: List of 2D tensors to pad. 8 | :return: Tensor of padded tensors stacked along a new batch dimension. 
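
    Example:
        >>> a, b = torch.zeros(2, 3), torch.zeros(1, 4)
        >>> pad_2d_tensor([a, b]).shape
        torch.Size([2, 2, 4])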
9 | """ 10 | if not key_data: 11 | raise ValueError("The input list 'key_data' should not be empty.") 12 | 13 | # Determine the maximum size along both dimensions 14 | max_rows = max(tensor.shape[0] for tensor in key_data) 15 | max_cols = max(tensor.shape[1] for tensor in key_data) 16 | 17 | tensors = [] 18 | 19 | for tensor in key_data: 20 | rows, cols = tensor.shape 21 | row_padding = max_rows - rows 22 | col_padding = max_cols - cols 23 | 24 | # Pad the tensor along both dimensions 25 | padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding), 26 | mode='constant', value=0) 27 | tensors.append(padded_tensor) 28 | 29 | # Stack the tensors into a single tensor along a new batch dimension 30 | padded_tensors = torch.stack(tensors) 31 | 32 | return padded_tensors -------------------------------------------------------------------------------- /gliner/decoding/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import SpanDecoder, TokenDecoder -------------------------------------------------------------------------------- /gliner/decoding/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from abc import ABC, abstractmethod 3 | from functools import partial 4 | import torch 5 | 6 | from .utils import has_overlapping, has_overlapping_nested 7 | 8 | 9 | class BaseDecoder(ABC): 10 | def __init__(self, config): 11 | self.config = config 12 | 13 | @abstractmethod 14 | def decode(self, *args, **kwargs): 15 | pass 16 | 17 | def greedy_search(self, spans, flat_ner=True, multi_label=False): 18 | if flat_ner: 19 | has_ov = partial(has_overlapping, multi_label=multi_label) 20 | else: 21 | has_ov = partial(has_overlapping_nested, multi_label=multi_label) 22 | 23 | new_list = [] 24 | span_prob = sorted(spans, key=lambda x: -x[-1]) 25 | 26 | for i in range(len(spans)): 27 | b = span_prob[i] 28 | flag = False 29 | for new in new_list: 30 | if has_ov(b[:-1], new): 31 | flag = True 32 | break 33 | if not flag: 34 | new_list.append(b) 35 | 36 | new_list = sorted(new_list, key=lambda x: x[0]) 37 | return new_list 38 | 39 | 40 | class SpanDecoder(BaseDecoder): 41 | def decode(self, tokens, id_to_classes, model_output, flat_ner=False, threshold=0.5, multi_label=False): 42 | probs = torch.sigmoid(model_output) 43 | spans = [] 44 | for i, _ in enumerate(tokens): 45 | probs_i = probs[i] 46 | 47 | # Support for id_to_classes being a list of dictionaries 48 | id_to_class_i = id_to_classes[i] if isinstance(id_to_classes, list) else id_to_classes 49 | 50 | wh_i = [i.tolist() for i in torch.where(probs_i > threshold)] 51 | span_i = [] 52 | for s, k, c in zip(*wh_i): 53 | if s + k < len(tokens[i]): 54 | span_i.append((s, s + k, id_to_class_i[c + 1], probs_i[s, k, c].item())) 55 | 56 | span_i = self.greedy_search(span_i, flat_ner, multi_label=multi_label) 57 | spans.append(span_i) 58 | return spans 59 | 60 | 61 | class TokenDecoder(BaseDecoder): 62 | def get_indices_above_threshold(self, scores, threshold): 63 | scores = torch.sigmoid(scores) 64 | return [k.tolist() for k in torch.where(scores > threshold)] 65 | 66 | def calculate_span_score(self, start_idx, end_idx, scores_inside_i, start_i, end_i, id_to_classes, threshold): 67 | span_i = [] 68 | for st, cls_st in zip(*start_idx): 69 | for ed, cls_ed in zip(*end_idx): 70 | if ed >= st and cls_st == cls_ed: 71 | ins = scores_inside_i[st:ed + 1, cls_st] 72 | if (ins < threshold).any(): 73 | continue 74 | # Get 
the start and end scores for this span 75 | start_score = start_i[st, cls_st] 76 | end_score = end_i[ed, cls_st] 77 | # Concatenate the inside scores with start and end scores 78 | combined = torch.cat([ins, start_score.unsqueeze(0), end_score.unsqueeze(0)]) 79 | # The span score is the minimum value among these scores 80 | spn_score = combined.min().item() 81 | span_i.append((st, ed, id_to_classes[cls_st + 1], spn_score)) 82 | return span_i 83 | 84 | def decode(self, tokens, id_to_classes, model_output, flat_ner=False, threshold=0.5, multi_label=False): 85 | scores_start, scores_end, scores_inside = model_output 86 | spans = [] 87 | for i, _ in enumerate(tokens): 88 | id_to_class_i = id_to_classes[i] if isinstance(id_to_classes, list) else id_to_classes 89 | span_scores = self.calculate_span_score( 90 | self.get_indices_above_threshold(scores_start[i], threshold), 91 | self.get_indices_above_threshold(scores_end[i], threshold), 92 | torch.sigmoid(scores_inside[i]), 93 | torch.sigmoid(scores_start[i]), 94 | torch.sigmoid(scores_end[i]), 95 | id_to_class_i, 96 | threshold 97 | ) 98 | span_i = self.greedy_search(span_scores, flat_ner, multi_label) 99 | spans.append(span_i) 100 | return spans -------------------------------------------------------------------------------- /gliner/decoding/utils.py: -------------------------------------------------------------------------------- 1 | def is_nested(idx1, idx2): 2 | # Return True if idx2 is nested inside idx1 or vice versa 3 | return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1]) 4 | 5 | def has_overlapping(idx1, idx2, multi_label=False): 6 | # Check for any overlap between two spans 7 | if idx1[:2] == idx2[:2]: # Exact same boundaries can be considered as overlapping 8 | return not multi_label 9 | if idx1[0] > idx2[1] or idx2[0] > idx1[1]: 10 | return False 11 | return True 12 | 13 | 14 | def has_overlapping_nested(idx1, idx2, multi_label=False): 15 | # Return True if idx1 and idx2 overlap, but neither is nested inside the other 16 | if idx1[:2] == idx2[:2]: # Exact same boundaries, not considering labels here 17 | return not multi_label 18 | if (idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2): 19 | return False 20 | return True 21 | -------------------------------------------------------------------------------- /gliner/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import Evaluator 2 | from .evaluate import get_for_all_path, get_for_one_path -------------------------------------------------------------------------------- /gliner/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import os 5 | import numpy as np 6 | import argparse 7 | import torch 8 | from tqdm import tqdm 9 | import random 10 | 11 | def open_content(path): 12 | paths = glob.glob(os.path.join(path, "*.json")) 13 | train, dev, test, labels = None, None, None, None 14 | for p in paths: 15 | if "train" in p: 16 | with open(p, "r") as f: 17 | train = json.load(f) 18 | elif "dev" in p: 19 | with open(p, "r") as f: 20 | dev = json.load(f) 21 | elif "test" in p: 22 | with open(p, "r") as f: 23 | test = json.load(f) 24 | elif "labels" in p: 25 | with open(p, "r") as f: 26 | labels = json.load(f) 27 | return train, dev, test, labels 28 | 29 | 30 | def process(data): 31 | words = data['sentence'].split() 32 | entities = [] # List of entities 
(start, end, type) 33 | 34 | for entity in data['entities']: 35 | start_char, end_char = entity['pos'] 36 | 37 | # Initialize variables to keep track of word positions 38 | start_word = None 39 | end_word = None 40 | 41 | # Iterate through words and find the word positions 42 | char_count = 0 43 | for i, word in enumerate(words): 44 | word_length = len(word) 45 | if char_count == start_char: 46 | start_word = i 47 | if char_count + word_length == end_char: 48 | end_word = i 49 | break 50 | char_count += word_length + 1 # Add 1 for the space 51 | 52 | # Append the word positions to the list 53 | entities.append((start_word, end_word, entity['type'].lower())) 54 | 55 | # Create a list of word positions for each entity 56 | sample = { 57 | "tokenized_text": words, 58 | "ner": entities 59 | } 60 | 61 | return sample 62 | 63 | 64 | # create dataset 65 | def create_dataset(path): 66 | train, dev, test, labels = open_content(path) 67 | train_dataset = [] 68 | dev_dataset = [] 69 | test_dataset = [] 70 | for data in train: 71 | train_dataset.append(process(data)) 72 | for data in dev: 73 | dev_dataset.append(process(data)) 74 | for data in test: 75 | test_dataset.append(process(data)) 76 | labels = [label.lower() for label in labels] 77 | return train_dataset, dev_dataset, test_dataset, labels 78 | 79 | 80 | @torch.no_grad() 81 | def get_for_one_path(path, model): 82 | # load the dataset 83 | _, _, test_dataset, entity_types = create_dataset(path) 84 | 85 | data_name = path.split("/")[-1] # get the name of the dataset 86 | 87 | # check if the dataset is flat_ner 88 | flat_ner = True 89 | if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]): 90 | flat_ner = False 91 | 92 | # evaluate the model 93 | results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12, 94 | entity_types=entity_types) 95 | return data_name, results, f1 96 | 97 | 98 | def get_for_all_path(model, steps, log_dir, data_paths): 99 | all_paths = glob.glob(f"{data_paths}/*") 100 | 101 | all_paths = sorted(all_paths) 102 | 103 | # move the model to the device 104 | device = next(model.parameters()).device 105 | model.to(device) 106 | # set the model to eval mode 107 | model.eval() 108 | 109 | # log the results 110 | save_path = os.path.join(log_dir, "results.txt") 111 | 112 | with open(save_path, "a") as f: 113 | f.write("##############################################\n") 114 | # write step 115 | f.write("step: " + str(steps) + "\n") 116 | 117 | zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music", 118 | "CrossNER_politics", "CrossNER_science"] 119 | 120 | zero_shot_benc_results = {} 121 | all_results = {} # without crossNER 122 | 123 | for p in tqdm(all_paths): 124 | if "sample_" not in p: 125 | data_name, results, f1 = get_for_one_path(p, model) 126 | # write to file 127 | with open(save_path, "a") as f: 128 | f.write(data_name + "\n") 129 | f.write(str(results) + "\n") 130 | 131 | if data_name in zero_shot_benc: 132 | zero_shot_benc_results[data_name] = f1 133 | else: 134 | all_results[data_name] = f1 135 | 136 | avg_all = sum(all_results.values()) / len(all_results) 137 | avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results) 138 | 139 | save_path_table = os.path.join(log_dir, "tables.txt") 140 | 141 | # results for all datasets except crossNER 142 | table_bench_all = "" 143 | for k, v in all_results.items(): 144 | table_bench_all += f"{k:20}: {v:.1%}\n" 145 | # (20 size aswell for average i.e. 
:20) 146 | table_bench_all += f"{'Average':20}: {avg_all:.1%}" 147 | 148 | # results for zero-shot benchmark 149 | table_bench_zeroshot = "" 150 | for k, v in zero_shot_benc_results.items(): 151 | table_bench_zeroshot += f"{k:20}: {v:.1%}\n" 152 | table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}" 153 | 154 | # write to file 155 | with open(save_path_table, "a") as f: 156 | f.write("##############################################\n") 157 | f.write("step: " + str(steps) + "\n") 158 | f.write("Table for all datasets except crossNER\n") 159 | f.write(table_bench_all + "\n\n") 160 | f.write("Table for zero-shot benchmark\n") 161 | f.write(table_bench_zeroshot + "\n") 162 | f.write("##############################################\n\n") 163 | 164 | 165 | def sample_train_data(data_paths, sample_size=10000): 166 | all_paths = glob.glob(f"{data_paths}/*") 167 | 168 | all_paths = sorted(all_paths) 169 | 170 | # to exclude the zero-shot benchmark datasets 171 | zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music", 172 | "CrossNER_politics", "CrossNER_science", "ACE 2004"] 173 | 174 | new_train = [] 175 | # take 10k samples from each dataset 176 | for p in tqdm(all_paths): 177 | if any([i in p for i in zero_shot_benc]): 178 | continue 179 | train, dev, test, labels = create_dataset(p) 180 | 181 | # add label key to the train data 182 | for i in range(len(train)): 183 | train[i]["label"] = labels 184 | 185 | random.shuffle(train) 186 | train = train[:sample_size] 187 | new_train.extend(train) 188 | 189 | return new_train 190 | -------------------------------------------------------------------------------- /gliner/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import defaultdict 3 | from typing import Union, List, Literal 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class UndefinedMetricWarning(UserWarning): 10 | pass 11 | 12 | 13 | def _prf_divide( 14 | numerator: np.ndarray, 15 | denominator: np.ndarray, 16 | metric: Literal["precision", "recall", "f-score"], 17 | modifier: str, 18 | average: str, 19 | warn_for: List[str], 20 | zero_division: Union[str, int] = "warn", 21 | ) -> np.ndarray: 22 | """Performs division and handles divide-by-zero with warnings.""" 23 | with np.errstate(divide="ignore", invalid="ignore"): 24 | result = np.true_divide(numerator, denominator) 25 | result[denominator == 0] = 0.0 if zero_division in ["warn", 0] else 1.0 26 | 27 | if denominator == 0 and zero_division == "warn" and metric in warn_for: 28 | msg_start = f"{metric.title()}" 29 | if "f-score" in warn_for: 30 | msg_start += " and F-score" if metric in warn_for else "F-score" 31 | msg_start += " are" if "f-score" in warn_for else " is" 32 | _warn_prf( 33 | average=average, 34 | modifier=modifier, 35 | msg_start=msg_start, 36 | result_size=len(result), 37 | ) 38 | 39 | return result 40 | 41 | 42 | def _warn_prf(average: str, modifier: str, msg_start: str, result_size: int): 43 | axis0, axis1 = ("label", "sample") if average == "samples" else ("sample", "label") 44 | if result_size == 1: 45 | msg = f"{msg_start} ill-defined and being set to 0.0 due to no {modifier} {axis0}." # noqa: E501 46 | else: 47 | msg = f"{msg_start} ill-defined and being set to 0.0 in {axis1}s with no {modifier} {axis0}s." # noqa: E501 48 | msg += " Use `zero_division` parameter to control this behavior." 
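    # stacklevel=3 attributes the warning to the metric's public caller
    # rather than to this internal helper.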
49 | warnings.warn(msg, UndefinedMetricWarning, stacklevel=3) 50 | 51 | 52 | def extract_tp_actual_correct(y_true, y_pred): 53 | entities_true = defaultdict(set) 54 | entities_pred = defaultdict(set) 55 | 56 | for type_name, (start, end), idx in y_true: 57 | entities_true[type_name].add((start, end, idx)) 58 | for type_name, (start, end), idx in y_pred: 59 | entities_pred[type_name].add((start, end, idx)) 60 | 61 | target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys())) 62 | 63 | tp_sum = np.array([], dtype=np.int32) 64 | pred_sum = np.array([], dtype=np.int32) 65 | true_sum = np.array([], dtype=np.int32) 66 | for type_name in target_names: 67 | entities_true_type = entities_true.get(type_name, set()) 68 | entities_pred_type = entities_pred.get(type_name, set()) 69 | tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type)) 70 | pred_sum = np.append(pred_sum, len(entities_pred_type)) 71 | true_sum = np.append(true_sum, len(entities_true_type)) 72 | 73 | return pred_sum, tp_sum, true_sum, target_names 74 | 75 | 76 | def flatten_for_eval(y_true, y_pred): 77 | all_true = [] 78 | all_pred = [] 79 | 80 | for i, (true, pred) in enumerate(zip(y_true, y_pred)): 81 | all_true.extend([t + [i] for t in true]) 82 | all_pred.extend([p + [i] for p in pred]) 83 | 84 | return all_true, all_pred 85 | 86 | 87 | def compute_prf(y_true, y_pred, average="micro"): 88 | y_true, y_pred = flatten_for_eval(y_true, y_pred) 89 | 90 | pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred) 91 | 92 | if average == "micro": 93 | tp_sum = np.array([tp_sum.sum()]) 94 | pred_sum = np.array([pred_sum.sum()]) 95 | true_sum = np.array([true_sum.sum()]) 96 | 97 | precision = _prf_divide( 98 | numerator=tp_sum, 99 | denominator=pred_sum, 100 | metric="precision", 101 | modifier="predicted", 102 | average=average, 103 | warn_for=["precision", "recall", "f-score"], 104 | zero_division="warn", 105 | ) 106 | 107 | recall = _prf_divide( 108 | numerator=tp_sum, 109 | denominator=true_sum, 110 | metric="recall", 111 | modifier="true", 112 | average=average, 113 | warn_for=["precision", "recall", "f-score"], 114 | zero_division="warn", 115 | ) 116 | 117 | denominator = precision + recall 118 | denominator[denominator == 0.0] = 1 119 | f_score = 2 * (precision * recall) / denominator 120 | 121 | return {"precision": precision[0], "recall": recall[0], "f_score": f_score[0]} 122 | 123 | 124 | class Evaluator: 125 | def __init__(self, all_true, all_outs): 126 | self.all_true = all_true 127 | self.all_outs = all_outs 128 | 129 | def get_entities_fr(self, ents): 130 | all_ents = [] 131 | for s, e, lab in ents: 132 | all_ents.append([lab, (s, e)]) 133 | return all_ents 134 | 135 | def get_entities_pr(self, ents): 136 | all_ents = [] 137 | for s, e, lab, _ in ents: 138 | all_ents.append([lab, (s, e)]) 139 | return all_ents 140 | 141 | def transform_data(self): 142 | all_true_ent = [] 143 | all_outs_ent = [] 144 | for i, j in zip(self.all_true, self.all_outs): 145 | e = self.get_entities_fr(i) 146 | all_true_ent.append(e) 147 | e = self.get_entities_pr(j) 148 | all_outs_ent.append(e) 149 | return all_true_ent, all_outs_ent 150 | 151 | @torch.no_grad() 152 | def evaluate(self): 153 | all_true_typed, all_outs_typed = self.transform_data() 154 | precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values() 155 | output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n" 156 | return output_str, f1 157 | 158 | 159 | def is_nested(idx1, idx2): 160 | # Return 
True if idx2 is nested inside idx1 or vice versa 161 | return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or ( 162 | idx2[0] <= idx1[0] and idx2[1] >= idx1[1] 163 | ) 164 | 165 | 166 | def has_overlapping(idx1, idx2, multi_label=False): 167 | # Check for any overlap between two spans 168 | if idx1[:2] == idx2[:2]: # Exact same boundaries can be considered as overlapping 169 | return not multi_label 170 | if idx1[0] > idx2[1] or idx2[0] > idx1[1]: 171 | return False 172 | return True 173 | 174 | 175 | def has_overlapping_nested(idx1, idx2, multi_label=False): 176 | # Return True if idx1 and idx2 overlap, but neither is nested inside the other 177 | if idx1[:2] == idx2[:2]: # Exact same boundaries, not considering labels here 178 | return not multi_label 179 | if (idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2): 180 | return False 181 | return True 182 | 183 | 184 | from functools import partial 185 | 186 | 187 | def greedy_search(spans, flat_ner=True, multi_label=False): # start, end, class, score 188 | if flat_ner: 189 | has_ov = partial(has_overlapping, multi_label=multi_label) 190 | else: 191 | has_ov = partial(has_overlapping_nested, multi_label=multi_label) 192 | 193 | new_list = [] 194 | span_prob = sorted(spans, key=lambda x: -x[-1]) 195 | 196 | for i in range(len(spans)): 197 | b = span_prob[i] 198 | flag = False 199 | for new in new_list: 200 | if has_ov(b[:-1], new): 201 | flag = True 202 | break 203 | if not flag: 204 | new_list.append(b) 205 | 206 | new_list = sorted(new_list, key=lambda x: x[0]) 207 | return new_list 208 | -------------------------------------------------------------------------------- /gliner/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/gliner/modeling/__init__.py -------------------------------------------------------------------------------- /gliner/modeling/encoder.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import nn 6 | from transformers import AutoModel, AutoConfig 7 | 8 | from .layers import LayersFuser 9 | from ..utils import is_module_available, MissedPackageException 10 | from typing import Optional, Union 11 | 12 | IS_LLM2VEC = is_module_available('llm2vec') 13 | IS_PEFT = is_module_available('peft') 14 | IS_TURBOT5 = is_module_available('turbot5') 15 | IS_FLASHDEBERTA = is_module_available('flashdeberta') 16 | 17 | if IS_LLM2VEC: 18 | from llm2vec.models import MistralBiModel, LlamaBiModel, GemmaBiModel, Qwen2BiModel 19 | DECODER_MODEL_MAPPING = { 20 | "MistralConfig": MistralBiModel, 21 | "LlamaConfig": LlamaBiModel, 22 | "GemmaConfig": GemmaBiModel, 23 | "Qwen2Config": Qwen2BiModel 24 | } 25 | else: 26 | DECODER_MODEL_MAPPING = {} 27 | 28 | if IS_TURBOT5: 29 | from turbot5.model.modeling import T5EncoderModel 30 | else: 31 | from transformers import T5EncoderModel 32 | 33 | if IS_FLASHDEBERTA: 34 | from flashdeberta import FlashDebertaV2Model as DebertaV2Model 35 | else: 36 | from transformers import DebertaV2Model 37 | 38 | if IS_PEFT: 39 | from peft import LoraConfig, get_peft_model 40 | 41 | class Transformer(nn.Module): 42 | def __init__( 43 | self, 44 | model_name, 45 | config, 46 | from_pretrained=False, 47 | labels_encoder = False, 48 | cache_dir:Optional[Union[str, Path]] = None 49 | ): 50 | super().__init__() 51 | if labels_encoder: 52 | encoder_config 
= config.labels_encoder_config 53 | else: 54 | encoder_config = config.encoder_config 55 | if encoder_config is None: 56 | encoder_config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) 57 | if config.vocab_size!=-1: 58 | encoder_config.vocab_size = config.vocab_size 59 | 60 | if config._attn_implementation is not None and not labels_encoder: 61 | encoder_config._attn_implementation = config._attn_implementation 62 | 63 | config_name = encoder_config.__class__.__name__ 64 | 65 | kwargs = {} 66 | if config_name in DECODER_MODEL_MAPPING: 67 | if not IS_LLM2VEC: 68 | raise MissedPackageException(f"The llm2vec package must be installed to use this decoder model: {config_name}") 69 | else: 70 | print('Loading decoder model using LLM2Vec...') 71 | ModelClass = DECODER_MODEL_MAPPING[config_name] 72 | custom = True 73 | elif config_name in {'T5Config', 'MT5Config'}: 74 | custom = True 75 | ModelClass = T5EncoderModel 76 | if IS_TURBOT5: 77 | kwargs = {"attention_type": 'flash'} 78 | elif config_name in {'DebertaV2Config'}: 79 | custom = True 80 | ModelClass = DebertaV2Model 81 | else: 82 | custom = False 83 | ModelClass = AutoModel 84 | 85 | if from_pretrained: 86 | self.model = ModelClass.from_pretrained(model_name, trust_remote_code=True) 87 | else: 88 | if not custom: 89 | self.model = ModelClass.from_config(encoder_config, trust_remote_code=True) 90 | else: 91 | self.model = ModelClass(encoder_config, **kwargs) 92 | 93 | adapter_config_file = Path(model_name) / "adapter_config.json" 94 | 95 | if adapter_config_file.exists(): 96 | if not IS_PEFT: 97 | warnings.warn(f"Adapter configs were detected, if you want to apply them you need to install peft package.") 98 | else: 99 | adapter_config = LoraConfig.from_pretrained(model_name) 100 | self.model = get_peft_model(self.model, adapter_config) 101 | 102 | if config.fuse_layers: 103 | self.layers_fuser = LayersFuser(encoder_config.num_hidden_layers, 104 | encoder_config.hidden_size) 105 | 106 | if labels_encoder: 107 | config.labels_encoder_config = encoder_config 108 | else: 109 | config.encoder_config = encoder_config 110 | 111 | self.config = config 112 | 113 | def forward(self, *args, **kwargs): 114 | if self.config.fuse_layers: 115 | output_hidden_states = True 116 | else: 117 | output_hidden_states = False 118 | output = self.model(*args, output_hidden_states = output_hidden_states, 119 | return_dict = True, **kwargs) 120 | if self.config.fuse_layers: 121 | encoder_layer = self.layers_fuser(output.hidden_states) 122 | else: 123 | encoder_layer = output[0] 124 | 125 | return encoder_layer 126 | 127 | class Encoder(nn.Module): 128 | def __init__(self, config, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]]= None): 129 | super().__init__() 130 | 131 | self.bert_layer = Transformer( #transformer_model 132 | config.model_name, config, from_pretrained, cache_dir = cache_dir 133 | ) 134 | 135 | bert_hidden_size = self.bert_layer.model.config.hidden_size 136 | 137 | if config.hidden_size != bert_hidden_size: 138 | self.projection = nn.Linear(bert_hidden_size, config.hidden_size) 139 | 140 | def resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): 141 | return self.bert_layer.model.resize_token_embeddings(new_num_tokens, 142 | pad_to_multiple_of) 143 | 144 | def get_input_embeddings(self): 145 | return self.bert_layer.model.get_input_embeddings() 146 | 147 | def encode_text(self, input_ids, attention_mask, *args, **kwargs): 148 | token_embeddings = self.bert_layer(input_ids, attention_mask, *args, 
**kwargs) 149 | if hasattr(self, "projection"): 150 | token_embeddings = self.projection(token_embeddings) 151 | return token_embeddings 152 | 153 | def forward(self, *args, **kwargs) -> torch.Tensor: 154 | token_embeddings = self.encode_text(*args, **kwargs) 155 | return token_embeddings 156 | 157 | class BiEncoder(Encoder): 158 | def __init__(self, config, from_pretrained: bool = False, cache_dir:Optional[Union[str, Path]] = None): 159 | super().__init__(config, from_pretrained) 160 | if config.labels_encoder is not None: 161 | self.labels_encoder = Transformer( #transformer_model 162 | config.labels_encoder, config, from_pretrained, True, cache_dir=cache_dir 163 | ) 164 | le_hidden_size = self.labels_encoder.model.config.hidden_size 165 | 166 | if config.hidden_size != le_hidden_size: 167 | self.labels_projection = nn.Linear(le_hidden_size, config.hidden_size) 168 | 169 | def mean_pooling(self, token_embeddings, attention_mask): 170 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 171 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 172 | 173 | def encode_labels(self, input_ids, attention_mask, *args, **kwargs): 174 | labels_embeddings = self.labels_encoder(input_ids, attention_mask, *args, **kwargs) 175 | if hasattr(self, "labels_projection"): 176 | labels_embeddings = self.labels_projection(labels_embeddings) 177 | labels_embeddings = self.mean_pooling(labels_embeddings, attention_mask) 178 | return labels_embeddings 179 | 180 | def forward(self, input_ids, attention_mask, 181 | labels_input_ids = None, labels_attention_mask=None, 182 | *args, **kwargs) -> torch.Tensor: 183 | token_embeddings = self.encode_text(input_ids, attention_mask, *args, **kwargs) 184 | 185 | labels_embeddings = self.encode_labels(labels_input_ids, labels_attention_mask, *args, **kwargs) 186 | return token_embeddings, labels_embeddings -------------------------------------------------------------------------------- /gliner/modeling/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence 5 | 6 | class LstmSeq2SeqEncoder(nn.Module): 7 | def __init__(self, config, num_layers=1, dropout=0., bidirectional=True): 8 | super(LstmSeq2SeqEncoder, self).__init__() 9 | self.lstm = nn.LSTM(input_size=config.hidden_size, 10 | hidden_size=config.hidden_size//2, 11 | num_layers=num_layers, 12 | dropout=dropout, 13 | bidirectional=bidirectional, 14 | batch_first=True) 15 | 16 | def forward(self, x, mask, hidden=None): 17 | # Packing the input sequence 18 | lengths = mask.sum(dim=1).cpu() 19 | packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) 20 | 21 | # Passing packed sequence through LSTM 22 | packed_output, hidden = self.lstm(packed_x, hidden) 23 | 24 | # Unpacking the output sequence 25 | output, _ = pad_packed_sequence(packed_output, batch_first=True) 26 | 27 | return output 28 | 29 | 30 | def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential: 31 | """ 32 | Creates a projection layer with specified configurations. 
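    The block expands the input to ``out_dim * 4`` features with ReLU and
    dropout before projecting down to ``out_dim`` (which defaults to
    ``hidden_size``). A minimal usage sketch, with illustrative sizes:

        proj = create_projection_layer(hidden_size=768, dropout=0.1)
        out = proj(torch.randn(2, 10, 768))  # -> shape (2, 10, 768)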
33 |     """
34 |     if out_dim is None:
35 |         out_dim = hidden_size
36 | 
37 |     return nn.Sequential(
38 |         nn.Linear(hidden_size, out_dim * 4),
39 |         nn.ReLU(),
40 |         nn.Dropout(dropout),
41 |         nn.Linear(out_dim * 4, out_dim)
42 |     )
43 | 
44 | class MultiheadAttention(nn.Module):
45 |     def __init__(self, hidden_size, num_heads, dropout) -> None:
46 |         super().__init__()
47 |         self.hidden_size = hidden_size
48 |         self.num_heads = num_heads
49 |         self.attention_head_size = hidden_size // num_heads
50 |         self.attention_probs_dropout_prob = dropout
51 |         self.query_layer = nn.Linear(hidden_size, hidden_size)
52 |         self.key_layer = nn.Linear(hidden_size, hidden_size)
53 |         self.value_layer = nn.Linear(hidden_size, hidden_size)
54 | 
55 |     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
56 |         new_x_shape = x.size()[:-1] + (self.num_heads, self.attention_head_size)
57 |         x = x.view(new_x_shape)
58 |         return x.permute(0, 2, 1, 3)
59 | 
60 |     def forward(self, query, key=None, value=None, head_mask=None, attn_mask=None):
61 |         # Fall back to self-attention on the raw inputs before projecting.
62 |         if key is None:
63 |             key = query
64 |         if value is None:
65 |             value = key
66 |         query = self.transpose_for_scores(self.query_layer(query))
67 |         key = self.transpose_for_scores(self.key_layer(key))
68 |         value = self.transpose_for_scores(self.value_layer(value))
69 | 
70 |         # SDPA adds float masks to the scores, so convert 0/1 masks to bool
71 |         # (True = attend) and broadcast them over the head dimension.
72 |         mask = attn_mask if attn_mask is not None else head_mask
73 |         if mask is not None:
74 |             if mask.dtype != torch.bool:
75 |                 mask = mask.bool()
76 |             mask = mask.unsqueeze(1)
77 |         context_layer = torch.nn.functional.scaled_dot_product_attention(
78 |             query,
79 |             key,
80 |             value,
81 |             mask,
82 |             self.attention_probs_dropout_prob if self.training else 0.0,
83 |             is_causal=False,
84 |             scale=None,
85 |         )
86 |         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
87 |         return context_layer.view(*context_layer.size()[:2], self.hidden_size), None
88 | 
89 | class SelfAttentionBlock(nn.Module):
90 |     def __init__(self, d_model, num_heads, dropout=0.1):
91 |         super().__init__()
92 |         self.self_attn = MultiheadAttention(d_model, num_heads, dropout=dropout)
93 |         self.pre_norm = nn.LayerNorm(d_model)
94 |         self.post_norm = nn.LayerNorm(d_model)
95 |         self.dropout = nn.Dropout(dropout)
96 |         self.q_proj = nn.Linear(d_model, d_model)
97 |         self.k_proj = nn.Linear(d_model, d_model)
98 |         self.v_proj = nn.Linear(d_model, d_model)
99 | 
100 |     def forward(self, x, mask=None):
101 |         x = self.pre_norm(x)
102 |         q = self.q_proj(x)
103 |         k = self.k_proj(x)
104 |         v = self.v_proj(x)
105 |         attn_output, _ = self.self_attn(q, k, v, attn_mask=mask)
106 |         output = x + self.dropout(attn_output)
107 |         return self.post_norm(output)
108 | 
109 | class CrossAttentionBlock(nn.Module):
110 |     def __init__(self, d_model, num_heads, dropout=0.1):
111 |         super().__init__()
112 |         self.cross_attn = MultiheadAttention(d_model, num_heads, dropout=dropout)
113 |         self.pre_norm = nn.LayerNorm(d_model)
114 |         self.post_norm = nn.LayerNorm(d_model)
115 |         self.dropout = nn.Dropout(dropout)
116 |         self.v_proj = nn.Linear(d_model, d_model)
117 | 
118 |     def forward(self, query, key, value=None, mask=None):
119 |         query = self.pre_norm(query)
120 |         if value is None:
121 |             value = self.v_proj(key)
122 |         attn_output, _ = self.cross_attn(query, key, value, attn_mask=mask)
123 |         output = query + self.dropout(attn_output)
124 |         return self.post_norm(output)
125 | 
126 | class CrossFuser(nn.Module):
127 |     def __init__(self, d_model, query_dim,
num_heads=8, num_layers=1, dropout=0.1, schema='l2l-l2t'): 128 | super().__init__() 129 | self.d_model = d_model 130 | self.schema = schema.split('-') 131 | layers = [] 132 | for _ in range(num_layers): 133 | layer = [] 134 | for attn_type in self.schema: 135 | if attn_type in {'l2l', 't2t'}: 136 | layer.append(SelfAttentionBlock(d_model, num_heads, dropout)) 137 | else: 138 | layer.append(CrossAttentionBlock(d_model, num_heads, dropout)) 139 | layer = nn.ModuleList(layer) 140 | layers.append(layer) 141 | 142 | self.layers = nn.ModuleList(layers) 143 | # self.dense_i = nn.Linear(query_dim, d_model) 144 | # self.dense_o = nn.Linear(d_model, query_dim) 145 | 146 | def forward(self, query, key, query_mask=None, key_mask=None): 147 | # query = self.dense_i(query) 148 | for sublayers in self.layers: 149 | for id, layer in enumerate(sublayers): 150 | if self.schema[id] == 'l2l': 151 | if query_mask is not None: 152 | self_attn_mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2) 153 | else: 154 | self_attn_mask = None 155 | query = layer(query, mask=self_attn_mask) 156 | elif self.schema[id] == 't2t': 157 | if key_mask is not None: 158 | self_attn_mask = key_mask.unsqueeze(1) * key_mask.unsqueeze(2) 159 | else: 160 | self_attn_mask = None 161 | key = layer(key, mask=self_attn_mask) 162 | elif self.schema[id] == 'l2t': 163 | if query_mask is not None and key_mask is not None: 164 | cross_attn_mask = query_mask.unsqueeze(-1) * key_mask.unsqueeze(1) 165 | else: 166 | cross_attn_mask = None 167 | query = layer(query, key, mask=cross_attn_mask) 168 | elif self.schema[id] == 't2l': 169 | if query_mask is not None and key_mask is not None: 170 | cross_attn_mask = key_mask.unsqueeze(-1) * query_mask.unsqueeze(1) 171 | else: 172 | cross_attn_mask = None 173 | key = layer(key, query, mask=cross_attn_mask) 174 | # query=self.dense_o(query) 175 | return query, key 176 | 177 | class LayersFuser(nn.Module): 178 | def __init__(self, num_layers, hidden_size, output_size=None): 179 | super().__init__() 180 | self.num_layers = num_layers 181 | self.hidden_size = hidden_size 182 | self.output_size = output_size if output_size is not None else hidden_size 183 | 184 | # Squeeze operation 185 | self.squeeze = nn.Linear(hidden_size, 1) 186 | 187 | # Excitation operation 188 | self.W1 = nn.Linear(num_layers, num_layers // 2) 189 | self.W2 = nn.Linear(num_layers // 2, num_layers) 190 | 191 | # Final projection 192 | self.output_projection = nn.Linear(self.hidden_size, self.output_size) 193 | 194 | def forward(self, encoder_outputs): 195 | # encoder_outputs is a list of tensors, each of shape [B, L, D] 196 | B, L, D = encoder_outputs[0].shape 197 | 198 | # Concatenate all layers 199 | U = torch.stack(encoder_outputs[1:], dim=1) # [B, K, L, D] 200 | 201 | # Squeeze operation 202 | Z = self.squeeze(U).squeeze(-1) # [B, K, L] 203 | Z = Z.mean(dim=2) # [B, K] 204 | 205 | # Excitation operation 206 | s = self.W2(F.relu(self.W1(Z))) # [B, K] 207 | s = torch.sigmoid(s) # [B, K] 208 | 209 | # Apply attention weights 210 | U_weighted = U * s.unsqueeze(-1).unsqueeze(-1) # [B, K, L, D] 211 | 212 | # Sum across layers 213 | U_sum = U_weighted.sum(dim=1) # [B, L, D] 214 | 215 | # final projection 216 | output = self.output_projection(U_sum) # [B, L, output_size] 217 | 218 | return output -------------------------------------------------------------------------------- /gliner/modeling/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 
as F 3 | 4 | 5 | def focal_loss_with_logits( 6 | inputs: torch.Tensor, 7 | targets: torch.Tensor, 8 | alpha: float = 0.25, 9 | gamma: float = 2, 10 | reduction: str = "none", 11 | label_smoothing: float = 0.0, 12 | ignore_index: int = -100 # default value for ignored index 13 | ) -> torch.Tensor: 14 | """ 15 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 16 | 17 | Args: 18 | inputs (Tensor): A float tensor of arbitrary shape. 19 | The predictions for each example. 20 | targets (Tensor): A float tensor with the same shape as inputs. Stores the binary 21 | classification label for each element in inputs 22 | (0 for the negative class and 1 for the positive class). 23 | alpha (float): Weighting factor in range (0,1) to balance 24 | positive vs negative examples or -1 for ignore. Default: ``0.25``. 25 | gamma (float): Exponent of the modulating factor (1 - p_t) to 26 | balance easy vs hard examples. Default: ``2``. 27 | reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` 28 | ``'none'``: No reduction will be applied to the output. 29 | ``'mean'``: The output will be averaged. 30 | ``'sum'``: The output will be summed. Default: ``'none'``. 31 | label_smoothing (float): Specifies the amount of smoothing when computing the loss, 32 | where 0.0 means no smoothing. 33 | ignore_index (int): Specifies a target value that is ignored and does not contribute 34 | to the input gradient. Default: ``-100``. 35 | Returns: 36 | Loss tensor with the reduction option applied. 37 | """ 38 | # Create a mask to ignore specified index 39 | valid_mask = targets != ignore_index 40 | 41 | # Apply label smoothing if needed 42 | if label_smoothing != 0: 43 | with torch.no_grad(): 44 | targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing 45 | 46 | # Apply sigmoid activation to inputs 47 | p = torch.sigmoid(inputs) 48 | 49 | # Compute the binary cross-entropy loss without reduction 50 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 51 | 52 | # Apply the valid mask to the loss 53 | loss = loss * valid_mask 54 | 55 | # Apply focal loss modulation if gamma is greater than 0 56 | if gamma > 0: 57 | p_t = p * targets + (1 - p) * (1 - targets) 58 | loss = loss * ((1 - p_t) ** gamma) 59 | 60 | # Apply alpha weighting if alpha is specified 61 | if alpha >= 0: 62 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 63 | loss = alpha_t * loss 64 | 65 | # Apply reduction method 66 | if reduction == "none": 67 | return loss 68 | elif reduction == "mean": 69 | return loss.sum() / valid_mask.sum() # Normalize by the number of valid (non-ignored) elements 70 | elif reduction == "sum": 71 | return loss.sum() 72 | else: 73 | raise ValueError( 74 | f"Invalid value for argument 'reduction': '{reduction}'. 
" 75 | f"Supported reduction modes: 'none', 'mean', 'sum'" 76 | ) -------------------------------------------------------------------------------- /gliner/modeling/scorers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class Scorer(nn.Module): 5 | def __init__(self, hidden_size, dropout=0.1): 6 | super().__init__() 7 | 8 | self.proj_token = nn.Linear(hidden_size, hidden_size * 2) 9 | self.proj_label = nn.Linear(hidden_size, hidden_size * 2) 10 | 11 | self.out_mlp = nn.Sequential( 12 | nn.Linear(hidden_size * 3, hidden_size * 4), 13 | nn.Dropout(dropout), 14 | nn.ReLU(), 15 | nn.Linear(hidden_size * 4, 3) # start, end, score 16 | ) 17 | 18 | def forward(self, token_rep, label_rep): 19 | batch_size, seq_len, hidden_size = token_rep.shape 20 | num_classes = label_rep.shape[1] 21 | 22 | # (batch_size, seq_len, 3, hidden_size) 23 | token_rep = self.proj_token(token_rep).view(batch_size, seq_len, 1, 2, hidden_size) 24 | label_rep = self.proj_label(label_rep).view(batch_size, 1, num_classes, 2, hidden_size) 25 | 26 | # (2, batch_size, seq_len, num_classes, hidden_size) 27 | token_rep = token_rep.expand(-1, -1, num_classes, -1, -1).permute(3, 0, 1, 2, 4) 28 | label_rep = label_rep.expand(-1, seq_len, -1, -1, -1).permute(3, 0, 1, 2, 4) 29 | 30 | # (batch_size, seq_len, num_classes, hidden_size * 3) 31 | cat = torch.cat([token_rep[0], label_rep[0], token_rep[1] * label_rep[1]], dim=-1) 32 | 33 | # (batch_size, seq_len, num_classes, 3) 34 | scores = self.out_mlp(cat).permute(3, 0, 1, 2) 35 | 36 | return scores 37 | -------------------------------------------------------------------------------- /gliner/modeling/span_rep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from .layers import create_projection_layer 6 | 7 | class SpanQuery(nn.Module): 8 | 9 | def __init__(self, hidden_size, max_width, trainable=True): 10 | super().__init__() 11 | 12 | self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width)) 13 | 14 | nn.init.uniform_(self.query_seg, a=-1, b=1) 15 | 16 | if not trainable: 17 | self.query_seg.requires_grad = False 18 | 19 | self.project = nn.Sequential( 20 | nn.Linear(hidden_size, hidden_size), 21 | nn.ReLU() 22 | ) 23 | 24 | def forward(self, h, *args): 25 | # h of shape [B, L, D] 26 | # query_seg of shape [D, max_width] 27 | 28 | span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg) 29 | 30 | return self.project(span_rep) 31 | 32 | 33 | class SpanMLP(nn.Module): 34 | 35 | def __init__(self, hidden_size, max_width): 36 | super().__init__() 37 | 38 | self.mlp = nn.Linear(hidden_size, hidden_size * max_width) 39 | 40 | def forward(self, h, *args): 41 | # h of shape [B, L, D] 42 | # query_seg of shape [D, max_width] 43 | 44 | B, L, D = h.size() 45 | 46 | span_rep = self.mlp(h) 47 | 48 | span_rep = span_rep.view(B, L, -1, D) 49 | 50 | return span_rep.relu() 51 | 52 | 53 | class SpanCAT(nn.Module): 54 | 55 | def __init__(self, hidden_size, max_width): 56 | super().__init__() 57 | 58 | self.max_width = max_width 59 | 60 | self.query_seg = nn.Parameter(torch.randn(128, max_width)) 61 | 62 | self.project = nn.Sequential( 63 | nn.Linear(hidden_size + 128, hidden_size), 64 | nn.ReLU() 65 | ) 66 | 67 | def forward(self, h, *args): 68 | # h of shape [B, L, D] 69 | # query_seg of shape [D, max_width] 70 | 71 | B, L, D = h.size() 72 | 73 | h = h.view(B, L, 1, D).repeat(1, 1, 
self.max_width, 1) 74 | 75 | q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1) 76 | 77 | span_rep = torch.cat([h, q], dim=-1) 78 | 79 | span_rep = self.project(span_rep) 80 | 81 | return span_rep 82 | 83 | 84 | class SpanConvBlock(nn.Module): 85 | def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'): 86 | super().__init__() 87 | 88 | if span_mode == 'conv_conv': 89 | self.conv = nn.Conv1d(hidden_size, hidden_size, 90 | kernel_size=kernel_size) 91 | 92 | # initialize the weights 93 | nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu') 94 | 95 | elif span_mode == 'conv_max': 96 | self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1) 97 | elif span_mode == 'conv_mean' or span_mode == 'conv_sum': 98 | self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1) 99 | 100 | self.span_mode = span_mode 101 | 102 | self.pad = kernel_size - 1 103 | 104 | def forward(self, x): 105 | 106 | x = torch.einsum('bld->bdl', x) 107 | 108 | if self.pad > 0: 109 | x = F.pad(x, (0, self.pad), "constant", 0) 110 | 111 | x = self.conv(x) 112 | 113 | if self.span_mode == "conv_sum": 114 | x = x * (self.pad + 1) 115 | 116 | return torch.einsum('bdl->bld', x) 117 | 118 | 119 | class SpanConv(nn.Module): 120 | def __init__(self, hidden_size, max_width, span_mode): 121 | super().__init__() 122 | 123 | kernels = [i + 2 for i in range(max_width - 1)] 124 | 125 | self.convs = nn.ModuleList() 126 | 127 | for kernel in kernels: 128 | self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode)) 129 | 130 | self.project = nn.Sequential( 131 | nn.ReLU(), 132 | nn.Linear(hidden_size, hidden_size) 133 | ) 134 | 135 | def forward(self, x, *args): 136 | 137 | span_reps = [x] 138 | 139 | for conv in self.convs: 140 | h = conv(x) 141 | span_reps.append(h) 142 | 143 | span_reps = torch.stack(span_reps, dim=-2) 144 | 145 | return self.project(span_reps) 146 | 147 | 148 | class SpanEndpointsBlock(nn.Module): 149 | def __init__(self, kernel_size): 150 | super().__init__() 151 | 152 | self.kernel_size = kernel_size 153 | 154 | def forward(self, x): 155 | B, L, D = x.size() 156 | 157 | span_idx = torch.LongTensor( 158 | [[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device) 159 | 160 | x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0) 161 | 162 | # endrep 163 | start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1)) 164 | 165 | start_end_rep = start_end_rep.view(B, L, 2, D) 166 | 167 | return start_end_rep 168 | 169 | 170 | class ConvShare(nn.Module): 171 | def __init__(self, hidden_size, max_width): 172 | super().__init__() 173 | 174 | self.max_width = max_width 175 | 176 | self.conv_weigth = nn.Parameter( 177 | torch.randn(hidden_size, hidden_size, max_width)) 178 | 179 | nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu') 180 | 181 | self.project = nn.Sequential( 182 | nn.ReLU(), 183 | nn.Linear(hidden_size, hidden_size) 184 | ) 185 | 186 | def forward(self, x, *args): 187 | span_reps = [] 188 | 189 | x = torch.einsum('bld->bdl', x) 190 | 191 | for i in range(self.max_width): 192 | pad = i 193 | x_i = F.pad(x, (0, pad), "constant", 0) 194 | conv_w = self.conv_weigth[:, :, :i + 1] 195 | out_i = F.conv1d(x_i, conv_w) 196 | span_reps.append(out_i.transpose(-1, -2)) 197 | 198 | out = torch.stack(span_reps, dim=-2) 199 | 200 | return self.project(out) 201 | 202 | 203 | def extract_elements(sequence, indices): 204 | B, L, D = sequence.shape 205 | K = indices.shape[1] 206 | 207 | # Expand indices to [B, K, D] 208 | expanded_indices = 
indices.unsqueeze(2).expand(-1, -1, D) 209 | 210 | # Gather the elements 211 | extracted_elements = torch.gather(sequence, 1, expanded_indices) 212 | 213 | return extracted_elements 214 | 215 | 216 | class SpanMarker(nn.Module): 217 | 218 | def __init__(self, hidden_size, max_width, dropout=0.4): 219 | super().__init__() 220 | 221 | self.max_width = max_width 222 | 223 | self.project_start = nn.Sequential( 224 | nn.Linear(hidden_size, hidden_size * 2, bias=True), 225 | nn.ReLU(), 226 | nn.Dropout(dropout), 227 | nn.Linear(hidden_size * 2, hidden_size, bias=True), 228 | ) 229 | 230 | self.project_end = nn.Sequential( 231 | nn.Linear(hidden_size, hidden_size * 2, bias=True), 232 | nn.ReLU(), 233 | nn.Dropout(dropout), 234 | nn.Linear(hidden_size * 2, hidden_size, bias=True), 235 | ) 236 | 237 | self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True) 238 | 239 | def forward(self, h, span_idx): 240 | # h of shape [B, L, D] 241 | # query_seg of shape [D, max_width] 242 | 243 | B, L, D = h.size() 244 | 245 | # project start and end 246 | start_rep = self.project_start(h) 247 | end_rep = self.project_end(h) 248 | 249 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0]) 250 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1]) 251 | 252 | # concat start and end 253 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu() 254 | 255 | # project 256 | cat = self.out_project(cat) 257 | 258 | # reshape 259 | return cat.view(B, L, self.max_width, D) 260 | 261 | 262 | class SpanMarkerV0(nn.Module): 263 | """ 264 | Marks and projects span endpoints using an MLP. 265 | """ 266 | 267 | def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4): 268 | super().__init__() 269 | self.max_width = max_width 270 | self.project_start = create_projection_layer(hidden_size, dropout) 271 | self.project_end = create_projection_layer(hidden_size, dropout) 272 | 273 | self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size) 274 | 275 | def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor: 276 | B, L, D = h.size() 277 | 278 | start_rep = self.project_start(h) 279 | end_rep = self.project_end(h) 280 | 281 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0]) 282 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1]) 283 | 284 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu() 285 | 286 | return self.out_project(cat).view(B, L, self.max_width, D) 287 | 288 | 289 | class ConvShareV2(nn.Module): 290 | def __init__(self, hidden_size, max_width): 291 | super().__init__() 292 | 293 | self.max_width = max_width 294 | 295 | self.conv_weigth = nn.Parameter( 296 | torch.randn(hidden_size, hidden_size, max_width) 297 | ) 298 | 299 | nn.init.xavier_normal_(self.conv_weigth) 300 | 301 | def forward(self, x, *args): 302 | span_reps = [] 303 | 304 | x = torch.einsum('bld->bdl', x) 305 | 306 | for i in range(self.max_width): 307 | pad = i 308 | x_i = F.pad(x, (0, pad), "constant", 0) 309 | conv_w = self.conv_weigth[:, :, :i + 1] 310 | out_i = F.conv1d(x_i, conv_w) 311 | span_reps.append(out_i.transpose(-1, -2)) 312 | 313 | out = torch.stack(span_reps, dim=-2) 314 | 315 | return out 316 | 317 | 318 | class SpanRepLayer(nn.Module): 319 | """ 320 | Various span representation approaches 321 | """ 322 | 323 | def __init__(self, hidden_size, max_width, span_mode, **kwargs): 324 | super().__init__() 325 | 326 | if span_mode == 'marker': 327 | self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs) 
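        # 'markerV0' below mirrors 'marker' but builds its projections with
        # create_projection_layer, sharing the same endpoint-marking idea.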
328 | elif span_mode == 'markerV0': 329 | self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs) 330 | elif span_mode == 'query': 331 | self.span_rep_layer = SpanQuery( 332 | hidden_size, max_width, trainable=True) 333 | elif span_mode == 'mlp': 334 | self.span_rep_layer = SpanMLP(hidden_size, max_width) 335 | elif span_mode == 'cat': 336 | self.span_rep_layer = SpanCAT(hidden_size, max_width) 337 | elif span_mode == 'conv_conv': 338 | self.span_rep_layer = SpanConv( 339 | hidden_size, max_width, span_mode='conv_conv') 340 | elif span_mode == 'conv_max': 341 | self.span_rep_layer = SpanConv( 342 | hidden_size, max_width, span_mode='conv_max') 343 | elif span_mode == 'conv_mean': 344 | self.span_rep_layer = SpanConv( 345 | hidden_size, max_width, span_mode='conv_mean') 346 | elif span_mode == 'conv_sum': 347 | self.span_rep_layer = SpanConv( 348 | hidden_size, max_width, span_mode='conv_sum') 349 | elif span_mode == 'conv_share': 350 | self.span_rep_layer = ConvShare(hidden_size, max_width) 351 | else: 352 | raise ValueError(f'Unknown span mode {span_mode}') 353 | 354 | def forward(self, x, *args): 355 | 356 | return self.span_rep_layer(x, *args) 357 | -------------------------------------------------------------------------------- /gliner/multitask/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import GLiNERClassifier 2 | from .question_answering import GLiNERQuestionAnswerer, GLiNERSquadEvaluator 3 | from .open_extraction import GLiNEROpenExtractor 4 | from .relation_extraction import GLiNERRelationExtractor, GLiNERDocREDEvaluator 5 | from .summarization import GLiNERSummarizer -------------------------------------------------------------------------------- /gliner/multitask/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Union, Optional 3 | import torch 4 | import warnings 5 | 6 | from ..model import GLiNER 7 | 8 | class GLiNERBasePipeline(ABC): 9 | """ 10 | Base class for GLiNER pipelines. Provides an interface for preparing texts, 11 | processing predictions, and evaluating the model. 12 | 13 | Args: 14 | model_id (str): Identifier for the model to be loaded. 15 | prompt (str, optional): Prompt template for text preparation. Defaults to None. 16 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'. 17 | 18 | Attributes: 19 | model (GLiNER): The loaded GLiNER model. 20 | device (str): The device being used for computation. 21 | prompt (str): The prompt template for text preparation. 22 | """ 23 | 24 | def __init__(self, model_id: str = None, model: GLiNER = None, prompt=None, device='cuda:0'): 25 | """ 26 | Initializes the GLiNERBasePipeline. 27 | 28 | Args: 29 | model_id (str): Identifier for the model to be loaded. 30 | prompt (str, optional): Prompt template for text preparation. Defaults to None. 31 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'. 
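            model (GLiNER, optional): A preloaded GLiNER instance; when given, it is
                used directly and `model_id` is ignored. Defaults to None.

        Raises:
            ValueError: If neither `model_id` nor `model` is provided.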
32 | """ 33 | if 'cuda' in device and not torch.cuda.is_available(): 34 | warnings.warn(f"{device} is not available, setting device as 'cpu'.") 35 | device = 'cpu' 36 | self.device = device 37 | 38 | if model is not None: 39 | self.model = model.to(self.device) 40 | elif model_id is not None: 41 | self.model = GLiNER.from_pretrained(model_id).to(self.device) 42 | else: 43 | raise ValueError("Either 'model_id' or 'model' must be provided to initialize the pipeline.") 44 | 45 | self.prompt = prompt 46 | 47 | @abstractmethod 48 | def prepare_texts(self, texts: List[str], *args, **kwargs): 49 | """ 50 | Prepares texts for input to the model. 51 | 52 | Args: 53 | texts (List[str]): List of input texts. 54 | *args: Additional positional arguments. 55 | **kwargs: Additional keyword arguments. 56 | 57 | Returns: 58 | Any: The processed texts ready for model input. 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def process_predictions(self, predictions: List[dict]): 64 | """ 65 | Processes model predictions into the desired format. 66 | 67 | Args: 68 | predictions (List[dict]): Raw predictions from the model. 69 | 70 | Returns: 71 | Any: Processed predictions in the desired format. 72 | """ 73 | pass 74 | 75 | @abstractmethod 76 | def evaluate(self, dataset_id: str, labels: Optional[List[str]] = None, threshold: float = 0.5): 77 | """ 78 | Evaluates the model on a given dataset. 79 | 80 | Args: 81 | dataset_id (str): Identifier for the evaluation dataset. 82 | labels (Optional[List[str]]): List of labels to evaluate. Defaults to None. 83 | threshold (float): Threshold for prediction confidence. Defaults to 0.5. 84 | 85 | Returns: 86 | Any: Evaluation results. 87 | """ 88 | pass 89 | 90 | def __call__(self, texts: Union[str, List[str]], labels: List[str] = ['match'], 91 | threshold: float = 0.5, batch_size: int = 8, **kwargs): 92 | """ 93 | Runs the model on the provided texts and returns processed results. 94 | 95 | Args: 96 | texts (Union[str, List[str]]): Single or list of input texts. 97 | labels (Optional[List[str]]): List of class labels for text preparation. Defaults to None. 98 | threshold (float): Threshold for prediction confidence. Defaults to 0.5. 99 | batch_size (int): Batch size for processing. Defaults to 8. 100 | 101 | Returns: 102 | Any: Processed results from the model. 103 | """ 104 | if isinstance(texts, str): 105 | texts = [texts] 106 | 107 | prompts = self.prepare_texts(texts, **kwargs) 108 | 109 | predictions = self.model.run(prompts, labels, threshold=threshold, batch_size=batch_size) 110 | 111 | results = self.process_predictions(predictions, **kwargs) 112 | 113 | return results -------------------------------------------------------------------------------- /gliner/multitask/classification.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | import os 3 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 4 | import torch 5 | from datasets import load_dataset, Dataset 6 | from sklearn.metrics import f1_score 7 | from gliner import GLiNER 8 | 9 | from .base import GLiNERBasePipeline 10 | 11 | class GLiNERClassifier(GLiNERBasePipeline): 12 | """ 13 | A class to evaluate the GLiNER model for classification tasks using F1 scores. 14 | 15 | Attributes: 16 | device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'. 17 | model (GLiNER): Loaded GLiNER model instance. 18 | prompt (str): Template prompt for text classification. 
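    Example (a minimal sketch; the model id is illustrative):

        classifier = GLiNERClassifier(model_id="knowledgator/gliner-multitask-large-v0.5")
        preds = classifier(["I love this movie!"], classes=["positive", "negative"])
        # -> [[{'label': 'positive', 'score': ...}]] (one best label per text)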
19 | 20 | Methods: 21 | compute_f_score(predicts, true_labels): 22 | Computes micro, macro, and weighted F1 scores. 23 | prepare_dataset(dataset, classes=None, text_column='text', label_column='label', split=None, max_examples=-1): 24 | Prepares texts and true labels from the given dataset. 25 | process_predictions(predictions): 26 | Processes model predictions to extract the most likely labels. 27 | prepare_texts(texts, labels): 28 | Creates classification prompts for each input text. 29 | __call__(texts, labels, threshold=0.5): 30 | Runs the model on the given texts and returns predicted labels. 31 | evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1): 32 | Evaluates the model on a dataset and computes F1 scores. 33 | """ 34 | 35 | prompt = "Classify text into the following classes: {}" 36 | 37 | def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None): 38 | """ 39 | Initializes the GLiNERClassifier. 40 | 41 | Args: 42 | model_id (str, optional): Identifier for the model to be loaded. Defaults to None. 43 | model (GLiNER, optional): Preloaded GLiNER model. Defaults to None. 44 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'. 45 | prompt (str, optional): Template prompt for text classification. Defaults to the class-level prompt. 46 | """ 47 | # Use the provided prompt or default to the class-level prompt 48 | prompt = prompt if prompt is not None else self.prompt 49 | super().__init__(model_id=model_id, model=model, prompt=prompt, device=device) 50 | 51 | 52 | def compute_f_score(self, predicts, true_labels): 53 | """ 54 | Computes the micro, macro, and weighted F1 scores. 55 | 56 | Args: 57 | predicts (list): List of predicted labels. 58 | true_labels (list): List of true labels. 59 | 60 | Returns: 61 | dict: Dictionary with micro, macro, and weighted F1 scores. 62 | """ 63 | micro = f1_score(true_labels, predicts, average="micro") 64 | macro = f1_score(true_labels, predicts, average="macro") 65 | weighted = f1_score(true_labels, predicts, average="weighted") 66 | return {"micro": micro, "macro": macro, "weighted": weighted} 67 | 68 | def prepare_dataset(self, dataset: Dataset, classes=None, text_column='text', label_column="label", split=None, max_examples=-1): 69 | """ 70 | Prepares the dataset by extracting texts and true labels. 71 | 72 | Args: 73 | dataset (Dataset or dict): The dataset to prepare. 74 | classes (list, optional): List of class labels. Defaults to None. 75 | text_column (str): Name of the text column. Defaults to 'text'. 76 | label_column (str): Name of the label column. Defaults to 'label'. 77 | split (str, optional): Delimiter for splitting class names. Defaults to None. 78 | max_examples (int): Maximum number of examples to use. Defaults to -1 (use all). 79 | 80 | Returns: 81 | tuple: Texts, classes, and true labels. 
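        Note:
            Integer labels are mapped back to class names via the dataset's label
            feature, so the returned `true_labels` are always strings.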
82 | """ 83 | if 'test' in dataset: 84 | test_dataset = dataset['test'] 85 | elif isinstance(dataset, Dataset): 86 | test_dataset = dataset 87 | else: 88 | test_dataset = dataset['train'] 89 | 90 | if classes is None: 91 | classes = test_dataset.features[label_column].names 92 | if split is not None: 93 | classes = [' '.join(class_.split(split)) for class_ in classes] 94 | 95 | texts = test_dataset[text_column] 96 | true_labels = test_dataset[label_column] 97 | 98 | if isinstance(test_dataset[label_column][0], int): 99 | true_labels = [classes[label] for label in true_labels] 100 | 101 | if max_examples > 0: 102 | texts = texts[:max_examples] 103 | true_labels = true_labels[:max_examples] 104 | 105 | return texts, classes, true_labels 106 | 107 | def process_predictions(self, predictions, multi_label=False, **kwargs): 108 | """ 109 | Processes predictions to extract the highest-scoring label(s). 110 | 111 | Args: 112 | predictions (list): List of predictions with scores. 113 | multi_label (bool): Whether to allow multiple labels per input. Defaults to False. 114 | 115 | Returns: 116 | list: List of predicted labels for each input. 117 | """ 118 | batch_predicted_labels = [] 119 | 120 | for prediction in predictions: 121 | # Sort predictions by score in descending order 122 | sorted_predictions = sorted(prediction, key=lambda entity: entity["score"], reverse=True) 123 | 124 | if not sorted_predictions: 125 | # Default prediction if no valid predictions are found 126 | batch_predicted_labels.append([{'label': 'other', 'score': 1.0}]) 127 | continue 128 | 129 | if not multi_label: 130 | # Single-label mode: select the top prediction and compute softmax score 131 | scores = [item['score'] for item in sorted_predictions] 132 | softmax_scores = torch.softmax(torch.tensor(scores), dim=0).tolist() 133 | top_prediction = {'label': sorted_predictions[0]['text'], 'score': softmax_scores[0]} 134 | batch_predicted_labels.append([top_prediction]) 135 | else: 136 | # Multi-label mode: retain all predictions with original scores 137 | predicted_labels = [{'label': pred['text'], 'score': pred['score']} for pred in sorted_predictions] 138 | batch_predicted_labels.append(predicted_labels) 139 | 140 | return batch_predicted_labels 141 | 142 | def prepare_texts(self, texts, classes, **kwargs): 143 | """ 144 | Prepares prompts for classification by appending labels to texts. 145 | 146 | Args: 147 | texts (list): List of input texts. 148 | classes (list): List of classification labels. 149 | 150 | Returns: 151 | list: List of formatted prompts. 152 | """ 153 | prompts = [] 154 | labels_ = ', '.join(classes) 155 | for text in texts: 156 | prompt = f"{self.prompt.format(labels_)} \n {text}" 157 | prompts.append(prompt) 158 | return prompts 159 | 160 | def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None, 161 | labels: Optional[List[str]]=None, threshold: float =0.5, max_examples: float =-1): 162 | """ 163 | Evaluates the model on a specified dataset and computes evaluation metrics. 164 | 165 | Args: 166 | dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets). 167 | dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored. 168 | labels (list, optional): List of target labels to consider for classification. Defaults to None (use all). 169 | threshold (float): Confidence threshold for predictions. Defaults to 0.5. 170 | max_examples (int): Maximum number of examples to evaluate. 
Defaults to -1 (use all available examples).
171 | 
172 |         Returns:
173 |             dict: A dictionary containing evaluation metrics such as F1 scores (micro, macro, and weighted).
174 | 
175 |         Raises:
176 |             ValueError: If neither `dataset_id` nor `dataset` is provided.
177 |         """
178 |         if dataset is None and dataset_id is not None:
179 |             dataset = load_dataset(dataset_id)
180 |         elif dataset is None and dataset_id is None:
181 |             raise ValueError("Either 'dataset_id' or 'dataset' must be provided to start evaluation.")
182 |         # A pre-loaded `dataset` takes precedence; `dataset_id` is ignored
183 |         # when both are given, as documented above.
184 | 
185 |         test_texts, classes, true_labels = self.prepare_dataset(dataset, labels, max_examples=max_examples)
186 | 
187 |         predictions = self.__call__(test_texts, classes=classes, threshold=threshold)
188 |         predicted_labels = [pred[0]['label'] for pred in predictions]
189 | 
190 |         return self.compute_f_score(predicted_labels, true_labels)
191 | 
-------------------------------------------------------------------------------- /gliner/multitask/open_extraction.py: --------------------------------------------------------------------------------
1 | from typing import Optional, List, Union
2 | import os
3 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
4 | import torch
5 | from datasets import load_dataset, Dataset
6 | from gliner import GLiNER
7 | 
8 | from .base import GLiNERBasePipeline
9 | 
10 | class GLiNEROpenExtractor(GLiNERBasePipeline):
11 |     """
12 |     A class to use GLiNER for open information extraction inference and evaluation.
13 | 
14 |     Attributes:
15 |         device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
16 |         model (GLiNER): Loaded GLiNER model instance.
17 |         prompt (str): Template prompt for open information extraction.
18 | 
19 |     Methods:
20 |         process_predictions(predictions):
21 |             Passes model predictions through unchanged.
22 |         prepare_texts(texts):
23 |             Creates open information extraction prompts for each input text.
24 |         __call__(texts, labels, threshold=0.5):
25 |             Runs the model on the given texts and returns predicted labels.
26 |         evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
27 |             Not implemented yet; raises NotImplementedError.
28 |     """
29 | 
30 |     prompt = ""
31 | 
32 |     def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None):
33 |         """
34 |         Initializes the GLiNEROpenExtractor.
35 | 
36 |         Args:
37 |             model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
38 |             model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
39 |             device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
40 |             prompt (str, optional): Template prompt for open information extraction.
41 |         """
42 |         # Use the provided prompt or default to the class-level prompt
43 |         prompt = prompt if prompt is not None else self.prompt
44 |         super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
45 | 
46 | 
47 |     def process_predictions(self, predictions, **kwargs):
48 |         """
49 |         Returns predictions unchanged; open extraction keeps the raw extracted spans.
50 | 
51 |         Args:
52 |             predictions (list): List of predictions with scores.
53 | 
54 |         Returns:
55 |             list: The predictions, unchanged.
56 |         """
57 |         return predictions
58 | 
59 |     def prepare_texts(self, texts: List[str], **kwargs):
60 |         """
61 |         Prepares prompts for open-information extraction.
62 | 
63 |         Args:
64 |             texts (list): List of input texts.
65 | 
66 |         Returns:
67 |             list: List of formatted prompts.
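        Note:
            With the default empty prompt, each entry reduces to " \n <text>";
            a prompt passed at initialization is prepended instead.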
68 |         """
69 |         prompts = []
70 | 
71 |         for text in texts:
72 |             prompt = f"{self.prompt} \n {text}"
73 |             prompts.append(prompt)
74 |         return prompts
75 | 
76 | 
77 |     def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None, 
78 |                  labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
79 |         """
80 |         Evaluates the model on a specified dataset and computes evaluation metrics.
81 | 
82 |         Args:
83 |             dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
84 |             dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
85 |             labels (list, optional): List of target labels to consider for extraction. Defaults to None (use all).
86 |             threshold (float): Confidence threshold for predictions. Defaults to 0.5.
87 |             max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
88 | 
89 |         Returns:
90 |             dict: A dictionary containing evaluation metrics.
91 | 
92 |         Raises:
93 |             ValueError: If neither `dataset_id` nor `dataset` is provided.
94 |         """
95 |         raise NotImplementedError("The `evaluate` method is not implemented yet.")
-------------------------------------------------------------------------------- /gliner/multitask/question_answering.py: --------------------------------------------------------------------------------
1 | from typing import Optional, List, Union
2 | import os
3 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
4 | import torch
5 | from datasets import load_dataset, Dataset
6 | from gliner import GLiNER
7 | 
8 | from .base import GLiNERBasePipeline
9 | 
10 | class GLiNERQuestionAnswerer(GLiNERBasePipeline):
11 |     """
12 |     A class to use GLiNER for question-answering inference and evaluation.
13 | 
14 |     Attributes:
15 |         device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
16 |         model (GLiNER): Loaded GLiNER model instance.
17 |         prompt (str): Template prompt for text question-answering.
18 | 
19 |     Methods:
20 |         process_predictions(predictions):
21 |             Processes model predictions to extract the most likely answers.
22 |         prepare_texts(texts, questions):
23 |             Creates Q&A prompts for each input text.
24 |         __call__(texts, questions, threshold=0.5):
25 |             Runs the model on the given texts and returns predicted answers.
26 |         evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
27 |             Evaluates the model on a dataset (see GLiNERSquadEvaluator for SQuAD).
28 |     """
29 | 
30 |     prompt = "Answer the following question: {}"
31 | 
32 |     def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None):
33 |         """
34 |         Initializes the GLiNERQuestionAnswerer.
35 | 
36 |         Args:
37 |             model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
38 |             model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
39 |             device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
40 |             prompt (str, optional): Template prompt for question-answering.
41 |         """
42 |         # Use the provided prompt or default to the class-level prompt
43 |         prompt = prompt if prompt is not None else self.prompt
44 |         super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
45 | 
46 | 
47 |     def process_predictions(self, predictions, **kwargs):
48 |         """
49 |         Processes predictions to extract the highest-scoring answer(s).
50 | 
51 |         Args:
52 |             predictions (list): List of predictions with scores.
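                Each prediction is expected to be a dict with at least
                'text' and 'score' keys.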
53 | 
54 |         Returns:
55 |             list: List of predicted answers for each input.
56 |         """
57 |         batch_predicted_labels = []
58 | 
59 |         for prediction in predictions:
60 |             # Sort predictions by score in descending order
61 |             sorted_predictions = sorted(prediction, key=lambda entity: entity["score"], reverse=True)
62 | 
63 |             predicted_labels = [{'answer': pred['text'], 'score': pred['score']} for pred in sorted_predictions]
64 |             batch_predicted_labels.append(predicted_labels)
65 | 
66 |         return batch_predicted_labels
67 | 
68 |     def prepare_texts(self, texts: List[str], questions: Union[List[str], str], **kwargs):
69 |         """
70 |         Prepares prompts for question-answering by appending questions to texts.
71 | 
72 |         Args:
73 |             texts (list): List of input texts.
74 |             questions (list|str): A single question applied to all texts, or one question per text.
75 | 
76 |         Returns:
77 |             list: List of formatted prompts.
78 |         """
79 |         prompts = []
80 | 
81 |         for id, text in enumerate(texts):
82 |             if isinstance(questions, str):
83 |                 question = questions
84 |             else:
85 |                 question = questions[id]  # pair each text with its own question
86 |             prompt = f"{self.prompt.format(question)} \n {text}"
87 |             prompts.append(prompt)
88 |         return prompts
89 | 
90 |     def __call__(self, texts: Union[str, List[str]], questions: Union[str, List[str]], 
91 |                  labels: List[str] = ['answer'], threshold: float = 0.5, 
92 |                  batch_size: int = 8, **kwargs):
93 |         return super().__call__(texts, labels, threshold, batch_size, questions=questions)
94 | 
95 |     def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None, 
96 |                  labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
97 |         """
98 |         Evaluates the model on a specified dataset and computes evaluation metrics.
99 | 
100 |         Args:
101 |             dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
102 |             dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
103 |             labels (list, optional): List of target labels to consider for answer extraction. Defaults to None (use all).
104 |             threshold (float): Confidence threshold for predictions. Defaults to 0.5.
105 |             max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
106 | 
107 |         Returns:
108 |             dict: A dictionary containing evaluation metrics such as F1 scores.
109 | 
110 |         Raises:
111 |             ValueError: If neither `dataset_id` nor `dataset` is provided.
112 |         """
113 |         raise NotImplementedError("The `evaluate` method is not implemented here; use GLiNERSquadEvaluator.")
114 | 
115 | class GLiNERSquadEvaluator(GLiNERQuestionAnswerer):
116 |     def evaluate(self, dataset_id: str = 'rajpurkar/squad_v2', dataset: Optional[Dataset] = None, 
117 |                  labels: Optional[List[str]] = ['answer'], threshold: float = 0.5, max_examples: int = -1):
118 |         """
119 |         Evaluates the model on a specified dataset and computes evaluation metrics.
120 | 
121 |         Args:
122 |             dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
123 |             dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
124 |             labels (list, optional): List of target labels to consider for answer extraction. Defaults to ['answer'].
125 |             threshold (float): Confidence threshold for predictions. Defaults to 0.5.
126 |             max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
127 | 
128 |         Returns:
129 |             dict: A dictionary containing evaluation metrics such as F1 scores.
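            Note: the metric is loaded as "squad_v2" when the dataset id contains
                "squad_v2" (so unanswerable questions are scored), and as "squad" otherwise.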
130 | 131 | Raises: 132 | ValueError: If neither `dataset_id` nor `dataset` is provided. 133 | """ 134 | from evaluate import load 135 | 136 | # Validate input 137 | if not dataset and not dataset_id: 138 | raise ValueError("Either `dataset` or `dataset_id` must be provided.") 139 | 140 | # Load the dataset if not provided 141 | if not dataset: 142 | dataset = load_dataset(dataset_id, split="validation") 143 | 144 | if not isinstance(dataset, Dataset): 145 | dataset = dataset['validation'] 146 | 147 | # Truncate dataset if max_examples is specified 148 | if max_examples > 0: 149 | dataset = dataset.shuffle().select(range(min(len(dataset), max_examples))) 150 | 151 | # Load evaluation metric for SQuAD 152 | squad_metric = load("squad_v2" if "squad_v2" in dataset_id else "squad") 153 | 154 | # Prepare predictions and references 155 | contexts = dataset['context'] 156 | questions = dataset['question'] 157 | 158 | raw_predictions = self(contexts, questions, labels=labels, threshold=threshold) 159 | 160 | predictions = [] 161 | references = [] 162 | for id, prediction in enumerate(raw_predictions): 163 | example = dataset[id] 164 | 165 | if len(prediction): 166 | predicted_answer = prediction[0]["answer"] 167 | no_answer_probability=0.0 168 | else: 169 | predicted_answer = "" 170 | no_answer_probability=1.0 171 | 172 | # Append to predictions and references 173 | predictions.append({ 174 | "id": example["id"], 175 | "prediction_text": predicted_answer, 176 | "no_answer_probability": no_answer_probability 177 | }) 178 | 179 | references.append({ 180 | "id": example["id"], 181 | "answers": {"text": example["answers"]["text"], "answer_start": example["answers"]["answer_start"]} 182 | }) 183 | 184 | # Compute metrics 185 | results = squad_metric.compute(predictions=predictions, references=references) 186 | return results -------------------------------------------------------------------------------- /gliner/multitask/summarization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union 2 | import os 3 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 4 | import torch 5 | from datasets import load_dataset, Dataset 6 | from gliner import GLiNER 7 | 8 | from .base import GLiNERBasePipeline 9 | 10 | class GLiNERSummarizer(GLiNERBasePipeline): 11 | """ 12 | A class to use GLiNER for summarization inference and evaluation. 13 | 14 | Attributes: 15 | device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'. 16 | model (GLiNER): Loaded GLiNER model instance. 17 | prompt (str): Template prompt for text summarization. 18 | 19 | Methods: 20 | process_predictions(predictions): 21 | Processes model predictions to extract the most likely labels. 22 | prepare_texts(texts, labels): 23 | Creates summarization prompts for each input text. 24 | __call__(texts, labels, threshold=0.5): 25 | Runs the model on the given texts and returns predicted labels. 26 | evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1): 27 | Evaluates the model on a dataset and computes F1 scores. 28 | """ 29 | 30 | prompt = "Summarize the following text highlighting the most important information:" 31 | 32 | def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None): 33 | """ 34 | Initializes the GLiNERSummarizer. 35 | 36 | Args: 37 | model_id (str, optional): Identifier for the model to be loaded. Defaults to None. 38 | model (GLiNER, optional): Preloaded GLiNER model. Defaults to None. 
39 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'. 40 | prompt (str, optional): Template prompt for summarization. 41 | """ 42 | # Use the provided prompt or default to the class-level prompt 43 | prompt = prompt if prompt is not None else self.prompt 44 | super().__init__(model_id=model_id, model=model, prompt=prompt, device=device) 45 | 46 | 47 | def process_predictions(self, predictions, **kwargs): 48 | """ 49 | Processes predictions by concatenating the predicted text chunks in their original order. 50 | 51 | Args: 52 | predictions (list): List of predictions with scores. 53 | 54 | Returns: 55 | list: Concatenated summary string for each input. 56 | """ 57 | batch_predicted_labels = [] 58 | 59 | for prediction in predictions: 60 | # Sort predicted chunks by start offset to preserve the original text order 61 | sorted_predictions = sorted(prediction, key=lambda entity: entity["start"]) 62 | 63 | extracted_text = [pred['text'] for pred in sorted_predictions] 64 | batch_predicted_labels.append(' '.join(extracted_text)) 65 | 66 | return batch_predicted_labels 67 | 68 | def prepare_texts(self, texts: List[str], **kwargs): 69 | """ 70 | Prepares prompts for summarization by prepending the prompt to each text. 71 | 72 | Args: 73 | texts (list): List of input texts. 74 | 75 | Returns: 76 | list: List of formatted prompts. 77 | """ 78 | prompts = [] 79 | 80 | for text in texts: 81 | prompt = f"{self.prompt} \n {text}" 82 | prompts.append(prompt) 83 | return prompts 84 | 85 | def __call__(self, texts: Union[str, List[str]], labels: List[str] = ['summary'], 86 | threshold: float = 0.25, batch_size: int = 8, **kwargs): 87 | return super().__call__(texts, labels, threshold, batch_size) 88 | 89 | def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None, 90 | labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1): 91 | """ 92 | Evaluates the model on a specified dataset and computes evaluation metrics. 93 | 94 | Args: 95 | dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets). 96 | dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored. 97 | labels (list, optional): List of target labels to consider for summarization. Defaults to None (use all). 98 | threshold (float): Confidence threshold for predictions. Defaults to 0.5. 99 | max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples). 100 | 101 | Returns: 102 | dict: A dictionary containing evaluation metrics. 103 | 104 | Raises: 105 | ValueError: If neither `dataset_id` nor `dataset` is provided.
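Example (illustrative; `evaluate` is not implemented yet, so this shows the intended inference entry point instead; the model id below is an assumption):
    >>> summarizer = GLiNERSummarizer(model_id="knowledgator/gliner-multitask-large-v0.5")
    >>> summaries = summarizer(["GLiNER is a generalist model for NER ..."], threshold=0.25)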
106 | """ 107 | raise NotImplementedError("The `evaluate` method is not implemented yet.") -------------------------------------------------------------------------------- /gliner/onnx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/gliner/onnx/__init__.py -------------------------------------------------------------------------------- /gliner/onnx/model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from abc import ABC, abstractmethod 3 | import warnings 4 | import onnxruntime as ort 5 | import numpy as np 6 | import torch 7 | 8 | from ..modeling.base import GLiNERModelOutput 9 | 10 | class BaseORTModel(ABC): 11 | def __init__(self, session: ort.InferenceSession): 12 | self.session = session 13 | self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} 14 | self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} 15 | 16 | def prepare_inputs(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]: 17 | """ 18 | Prepare inputs for ONNX model inference. 19 | 20 | Args: 21 | inputs (Dict[str, torch.Tensor]): Dictionary of input names and tensors. 22 | 23 | Returns: 24 | Dict[str, np.ndarray]: Dictionary of input names and numpy arrays. 25 | """ 26 | if not isinstance(inputs, dict): 27 | raise ValueError("Inputs must be a dictionary of input names and tensors.") 28 | 29 | prepared_inputs = {} 30 | for key, tensor in inputs.items(): 31 | if key not in self.input_names: 32 | warnings.warn(f"Input key '{key}' not found in ONNX model's input names. Ignored.") 33 | continue 34 | prepared_inputs[key] = tensor.cpu().detach().numpy() 35 | return prepared_inputs 36 | 37 | def run_inference(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: 38 | """ 39 | Run the ONNX model inference. 40 | 41 | Args: 42 | inputs (Dict[str, np.ndarray]): Prepared inputs for the model. 43 | 44 | Returns: 45 | Dict[str, np.ndarray]: Model's outputs as numpy arrays. 46 | """ 47 | onnx_outputs = self.session.run(None, inputs) 48 | outputs = {name: onnx_outputs[idx] for name, idx in self.output_names.items()} 49 | return outputs 50 | 51 | @abstractmethod 52 | def forward(self, input_ids, attention_mask, **kwargs) -> Dict[str, Any]: 53 | """ 54 | Abstract method to perform forward pass. Must be implemented by subclasses. 55 | """ 56 | pass 57 | 58 | def __call__(self, *args, **kwargs): 59 | return self.forward(*args, **kwargs) 60 | 61 | class SpanORTModel(BaseORTModel): 62 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, 63 | words_mask: torch.Tensor, text_lengths: torch.Tensor, 64 | span_idx: torch.Tensor, span_mask: torch.Tensor, **kwargs) -> Dict[str, Any]: 65 | """ 66 | Forward pass for span model using ONNX inference. 67 | 68 | Args: 69 | input_ids (torch.Tensor): Input IDs tensor. 70 | attention_mask (torch.Tensor): Attention mask tensor. 71 | span_idx (torch.Tensor): Span indices tensor. 72 | span_mask (torch.Tensor): Span mask tensor. 73 | **kwargs: Additional arguments. 74 | 75 | Returns: 76 | Dict[str, Any]: Model outputs.
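Note: `words_mask` and `text_lengths` (see the signature) are required inputs as well; they are forwarded to the ONNX session together with the arguments listed above.

Example (minimal sketch with dummy shapes; real inputs come from the GLiNER data processor, and `span_model` is an assumed `SpanORTModel` instance wrapping an exported session):
    >>> import torch
    >>> out = span_model(
    ...     input_ids=torch.ones(1, 8, dtype=torch.long),
    ...     attention_mask=torch.ones(1, 8, dtype=torch.long),
    ...     words_mask=torch.ones(1, 8, dtype=torch.long),
    ...     text_lengths=torch.tensor([[8]]),
    ...     span_idx=torch.zeros(1, 4, 2, dtype=torch.long),
    ...     span_mask=torch.ones(1, 4, dtype=torch.bool),
    ... )
    >>> out.logits  # numpy array produced by onnxruntime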
77 | """ 78 | inputs = { 79 | 'input_ids': input_ids, 80 | 'attention_mask': attention_mask, 81 | 'words_mask': words_mask, 82 | 'text_lengths': text_lengths, 83 | 'span_idx': span_idx, 84 | 'span_mask': span_mask 85 | } 86 | prepared_inputs = self.prepare_inputs(inputs) 87 | inference_output = self.run_inference(prepared_inputs) 88 | outputs = GLiNERModelOutput( 89 | logits=inference_output['logits'] 90 | ) 91 | return outputs 92 | 93 | class TokenORTModel(BaseORTModel): 94 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, 95 | words_mask: torch.Tensor, text_lengths: torch.Tensor, 96 | **kwargs) -> Dict[str, Any]: 97 | """ 98 | Forward pass for token model using ONNX inference. 99 | 100 | Args: 101 | input_ids (torch.Tensor): Input IDs tensor. 102 | attention_mask (torch.Tensor): Attention mask tensor. 103 | **kwargs: Additional arguments. 104 | 105 | Returns: 106 | Dict[str, Any]: Model outputs. 107 | """ 108 | inputs = { 109 | 'input_ids': input_ids, 110 | 'attention_mask': attention_mask, 111 | 'words_mask': words_mask, 112 | 'text_lengths': text_lengths, 113 | } 114 | prepared_inputs = self.prepare_inputs(inputs) 115 | inference_output = self.run_inference(prepared_inputs) 116 | outputs = GLiNERModelOutput( 117 | logits=inference_output['logits'] 118 | ) 119 | return outputs -------------------------------------------------------------------------------- /gliner/training/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer, TrainingArguments -------------------------------------------------------------------------------- /gliner/training/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Any, Dict, Tuple, List 2 | from dataclasses import dataclass, field 3 | 4 | import torch 5 | import transformers 6 | 7 | from transformers.training_args import OptimizerNames 8 | from transformers.trainer import ( 9 | is_sagemaker_mp_enabled, 10 | get_parameter_names, 11 | ALL_LAYERNORM_LAYERS, 12 | ) 13 | from transformers.trainer_utils import seed_worker 14 | 15 | if transformers.utils.is_apex_available(): 16 | from apex import amp 17 | 18 | if is_sagemaker_mp_enabled(): 19 | from transformers.trainer_pt_utils import smp_forward_backward 20 | from torch.utils.data import DataLoader, Dataset 21 | 22 | @dataclass 23 | class TrainingArguments(transformers.TrainingArguments): 24 | cache_dir: Optional[str] = field(default=None) 25 | optim: str = field(default="adamw_torch") 26 | others_lr: Optional[float] = None 27 | others_weight_decay: Optional[float] = 0.0 28 | focal_loss_alpha: Optional[float] = -1 29 | focal_loss_gamma: Optional[float] = 0 30 | label_smoothing: Optional[float] = 0 31 | loss_reduction: Optional[str] = 'sum' 32 | negatives: Optional[float] = 1.0 33 | masking: Optional[str] = 'global' 34 | 35 | class Trainer(transformers.Trainer): 36 | def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor: 37 | """ 38 | Perform a training step on a batch of inputs. 39 | 40 | Subclass and override to inject custom behavior. 41 | 42 | Args: 43 | model (`nn.Module`): 44 | The model to train. 45 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 46 | The inputs and targets of the model. 47 | 48 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 49 | argument `labels`.
Check your model's documentation for all accepted arguments. 50 | 51 | Return: 52 | `torch.Tensor`: The tensor with training loss on this batch. 53 | """ 54 | model.train() 55 | try: 56 | inputs = self._prepare_inputs(inputs) 57 | if is_sagemaker_mp_enabled(): 58 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) 59 | return loss_mb.reduce_mean().detach().to(self.args.device) 60 | 61 | with self.compute_loss_context_manager(): 62 | loss = self.compute_loss(model, inputs) 63 | 64 | del inputs 65 | torch.cuda.empty_cache() 66 | 67 | kwargs = {} 68 | 69 | # For LOMO optimizers you need to explicitly use the learning rate 70 | # if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: 71 | # kwargs["learning_rate"] = self._get_learning_rate() 72 | 73 | if self.args.n_gpu > 1: 74 | loss = loss.mean() # mean() to average on multi-gpu parallel training 75 | 76 | if self.use_apex: 77 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 78 | scaled_loss.backward() 79 | else: 80 | self.accelerator.backward(loss, **kwargs) 81 | 82 | return loss.detach() / self.args.gradient_accumulation_steps 83 | except Exception as e: 84 | print(f"Skipping iteration due to error: {e}") 85 | model.zero_grad(set_to_none=True) 86 | torch.cuda.empty_cache() 87 | return torch.tensor(0.0, requires_grad=True).to(model.device) 88 | 89 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): 90 | self.model.save_pretrained(output_dir) 91 | 92 | def compute_loss(self, model, inputs): 93 | """ 94 | Override compute_loss to use a custom loss function. 95 | """ 96 | # Forward pass 97 | outputs = model(alpha=self.args.focal_loss_alpha, 98 | gamma=self.args.focal_loss_gamma, 99 | label_smoothing=self.args.label_smoothing, 100 | reduction=self.args.loss_reduction, 101 | negatives=self.args.negatives, 102 | masking=self.args.masking, 103 | **inputs) 104 | loss = outputs.loss 105 | return loss 106 | 107 | def create_optimizer(self): 108 | """ 109 | Setup the optimizer. 110 | 111 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 112 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
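Example (illustrative; shows how `others_lr` splits the parameter groups):
    >>> args = TrainingArguments(output_dir="out", learning_rate=1e-5, others_lr=5e-5)
    >>> # Encoder parameters (names containing "token_rep_layer") keep the
    >>> # default `learning_rate`; all remaining parameters get `others_lr`.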
113 | """ 114 | if is_sagemaker_mp_enabled(): 115 | return super().create_optimizer() 116 | 117 | opt_model = self.model 118 | 119 | if self.optimizer is None: 120 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 121 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 122 | if self.args.others_lr is not None: 123 | encoder_parameters = [name for name, _ in opt_model.named_parameters() if "token_rep_layer" in name] 124 | optimizer_grouped_parameters = [ 125 | { 126 | "params": [ 127 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in encoder_parameters and p.requires_grad) 128 | ], 129 | "weight_decay": self.args.others_weight_decay, 130 | "lr": self.args.others_lr, 131 | }, 132 | { 133 | "params": [ 134 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in encoder_parameters and p.requires_grad) 135 | ], 136 | "weight_decay": 0.0, 137 | "lr": self.args.others_lr, 138 | }, 139 | { 140 | "params": [ 141 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in encoder_parameters and p.requires_grad) 142 | ], 143 | "weight_decay": self.args.weight_decay, 144 | }, 145 | { 146 | "params": [ 147 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in encoder_parameters and p.requires_grad) 148 | ], 149 | "weight_decay": 0.0, 150 | }, 151 | ] 152 | else: 153 | optimizer_grouped_parameters = [ 154 | { 155 | "params": [ 156 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 157 | ], 158 | "weight_decay": self.args.weight_decay, 159 | }, 160 | { 161 | "params": [ 162 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 163 | ], 164 | "weight_decay": 0.0, 165 | }, 166 | ] 167 | 168 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 169 | 170 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 171 | 172 | return self.optimizer 173 | 174 | def prediction_step( 175 | self, 176 | model: torch.nn.Module, 177 | inputs: Dict[str, Union[torch.Tensor, Any]], 178 | prediction_loss_only: bool, 179 | ignore_keys: Optional[List[str]] = None, 180 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 181 | """ 182 | Perform an evaluation step on model using inputs. 183 | 184 | Subclass and override to inject custom behavior. 185 | 186 | Args: 187 | model (nn.Module): 188 | The model to evaluate. 189 | inputs (Dict[str, Union[torch.Tensor, Any]]): 190 | The inputs and targets of the model. 191 | 192 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 193 | argument labels. Check your model's documentation for all accepted arguments. 194 | prediction_loss_only (bool): 195 | Whether or not to return the loss only. 196 | ignore_keys (List[str], *optional*): 197 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 198 | gathering predictions. 199 | 200 | Return: 201 | Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, 202 | logits and labels (each being optional). 
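Example (sketch; `trainer` and `batch` are assumed to exist, and `batch` must contain a `labels` entry):
    >>> loss, logits, labels = trainer.prediction_step(model, batch, prediction_loss_only=False)
    >>> loss_only, _, _ = trainer.prediction_step(model, batch, prediction_loss_only=True)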
203 | """ 204 | with torch.no_grad(): 205 | loss = None 206 | with self.compute_loss_context_manager(): 207 | outputs = model(**inputs) 208 | loss = outputs.loss 209 | logits = outputs.logits 210 | labels = inputs['labels'] 211 | if prediction_loss_only: 212 | return (loss, None, None) 213 | return (loss, logits, labels) 214 | 215 | 216 | def get_train_dataloader(self) -> DataLoader: 217 | """ 218 | Returns the training [`~torch.utils.data.DataLoader`]. 219 | 220 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 221 | training if necessary) otherwise. 222 | 223 | Subclass and override this method if you want to inject some custom behavior. 224 | """ 225 | if self.train_dataset is None: 226 | raise ValueError("Trainer: training requires a train_dataset.") 227 | 228 | train_dataset = self.train_dataset 229 | data_collator = self.data_collator 230 | 231 | dataloader_params = { 232 | "batch_size": self._train_batch_size, 233 | "collate_fn": data_collator, 234 | "num_workers": self.args.dataloader_num_workers, 235 | "pin_memory": self.args.dataloader_pin_memory, 236 | "persistent_workers": self.args.dataloader_persistent_workers, 237 | } 238 | 239 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 240 | dataloader_params["sampler"] = self._get_train_sampler() 241 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 242 | dataloader_params["worker_init_fn"] = seed_worker 243 | dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor 244 | 245 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 246 | 247 | def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: 248 | """ 249 | Returns the evaluation [`~torch.utils.data.DataLoader`]. 250 | 251 | Subclass and override this method if you want to inject some custom behavior. 252 | 253 | Args: 254 | eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*): 255 | If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. 
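Example (illustrative; `trainer` is an assumed instance, and the string form assumes `self.eval_dataset` is a dict of named datasets):
    >>> eval_loader = trainer.get_eval_dataloader()           # uses self.eval_dataset
    >>> squad_loader = trainer.get_eval_dataloader("squad")   # uses self.eval_dataset["squad"]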
256 | """ 257 | if eval_dataset is None and self.eval_dataset is None: 258 | raise ValueError("Trainer: evaluation requires an eval_dataset.") 259 | 260 | # If we have persistent workers, don't do a fork bomb especially as eval datasets 261 | # don't change during training 262 | dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" 263 | if ( 264 | hasattr(self, "_eval_dataloaders") 265 | and dataloader_key in self._eval_dataloaders 266 | and self.args.dataloader_persistent_workers 267 | ): 268 | return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) 269 | 270 | eval_dataset = ( 271 | self.eval_dataset[eval_dataset] 272 | if isinstance(eval_dataset, str) 273 | else eval_dataset 274 | if eval_dataset is not None 275 | else self.eval_dataset 276 | ) 277 | data_collator = self.data_collator 278 | 279 | dataloader_params = { 280 | "batch_size": self.args.eval_batch_size, 281 | "collate_fn": data_collator, 282 | "num_workers": self.args.dataloader_num_workers, 283 | "pin_memory": self.args.dataloader_pin_memory, 284 | "persistent_workers": self.args.dataloader_persistent_workers, 285 | } 286 | 287 | if not isinstance(eval_dataset, torch.utils.data.IterableDataset): 288 | dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) 289 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 290 | dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor 291 | 292 | # accelerator.free_memory() will destroy the references, so 293 | # we need to store the non-prepared version 294 | eval_dataloader = DataLoader(eval_dataset, **dataloader_params) 295 | if self.args.dataloader_persistent_workers: 296 | if hasattr(self, "_eval_dataloaders"): 297 | self._eval_dataloaders[dataloader_key] = eval_dataloader 298 | else: 299 | self._eval_dataloaders = {dataloader_key: eval_dataloader} 300 | 301 | return self.accelerator.prepare(eval_dataloader) 302 | -------------------------------------------------------------------------------- /gliner/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | import yaml 4 | 5 | def load_config_as_namespace(config_file): 6 | with open(config_file, "r") as f: 7 | config_dict = yaml.safe_load(f) 8 | return argparse.Namespace(**config_dict) 9 | 10 | def is_module_available(module_name): 11 | """ 12 | Checks whether the specified Python module is available. 13 | 14 | Args: 15 | module_name (str): The name of the module to check. 16 | 17 | Returns: 18 | bool: True if the module is available, False otherwise. 
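Example (doctest-style illustration; assumes `torch` is installed, which it is as a core dependency of this package):
    >>> is_module_available("torch")
    True
    >>> is_module_available("definitely_not_a_real_module")
    False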
19 | """ 20 | try: 21 | __import__(module_name) 22 | return True 23 | except ImportError: 24 | return False 25 | 26 | class MissedPackageException(Exception): 27 | """Raised when the requested decoder model is not supported.""" 28 | pass -------------------------------------------------------------------------------- /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/image.png -------------------------------------------------------------------------------- /logo/FI Group.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/logo/FI Group.png -------------------------------------------------------------------------------- /logo/FI_COMPLET_CW.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/logo/FI_COMPLET_CW.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools.packages.find] 6 | include = ["gliner", "gliner.*"] 7 | 8 | [tool.setuptools.dynamic] 9 | version = {attr = "gliner.__version__"} 10 | 11 | [project] 12 | name = "gliner" 13 | description = "Generalist model for NER (Extract any entity types from texts)" 14 | readme = "README.md" 15 | requires-python = ">=3.8" 16 | license = {text = "Apache-2.0"} 17 | keywords = [ 18 | "named-entity-recognition", 19 | "ner", 20 | "data-science", 21 | "natural-language-processing", 22 | "artificial-intelligence", 23 | "nlp", 24 | "machine-learning", 25 | "transformers" 26 | ] 27 | authors = [ 28 | {name = "Urchade Zaratiana"}, 29 | {name = "Nadi Tomeh"}, 30 | {name = "Pierre Holat"}, 31 | {name = "Thierry Charnois"}, 32 | ] 33 | maintainers = [ 34 | {name = "Urchade Zaratiana"}, 35 | ] 36 | 37 | dependencies = [ 38 | "torch>=2.0.0", 39 | "transformers>=4.38.2", 40 | "huggingface_hub>=0.21.4", 41 | "tqdm", 42 | "onnxruntime", 43 | "sentencepiece", 44 | ] 45 | 46 | dynamic = ["version"] 47 | 48 | [project.optional-dependencies] 49 | gpu = ["onnxruntime-gpu"] 50 | 51 | 52 | [project.urls] 53 | Homepage = "https://github.com/urchade/GLiNER" 54 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | transformers>=4.38.2,<=4.45.2 3 | huggingface_hub>=0.21.4 4 | onnxruntime-gpu 5 | sentencepiece 6 | tqdm 7 | -------------------------------------------------------------------------------- /tests/test_features_selection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from transformers import AutoTokenizer 4 | from gliner import GLiNERConfig 5 | from gliner.modeling.base import extract_prompt_features_and_word_embeddings 6 | from gliner.data_processing import SpanProcessor, WordsSplitter 7 | 8 | class TestFeaturesExtractor: 9 | @pytest.fixture(autouse=True) 10 | def setup(self): 11 | self.config = GLiNERConfig() 12 | self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) 13 | 
self.config.class_token_index=len(self.tokenizer) 14 | self.tokenizer.add_tokens([self.config.ent_token, self.config.sep_token]) 15 | self.splitter = WordsSplitter() 16 | self.base_tokens = [['Hello', 'world', '!']] 17 | self.tokens_with_missed = [['Hello', '', 'world', '']] 18 | self.labels = ['world'] 19 | self.processor = SpanProcessor(self.config, self.tokenizer, self.splitter) 20 | 21 | def test_base_extraction(self): 22 | input_x = [{"tokenized_text": tk, "ner": None} for tk in self.base_tokens] 23 | raw_batch = self.processor.collate_raw_batch(input_x, self.labels) 24 | model_input = self.processor.collate_fn(raw_batch, prepare_labels=False) 25 | model_input['text_lengths'] = raw_batch['seq_length'] 26 | token_embeds = torch.rand(model_input['words_mask'].shape + (self.config.hidden_size,)) 27 | 28 | (prompts_embedding, 29 | prompts_embedding_mask, 30 | words_embedding, 31 | mask) = extract_prompt_features_and_word_embeddings(self.config, token_embeds, **model_input) 32 | 33 | assert prompts_embedding_mask.shape == (1, 1) 34 | assert prompts_embedding.shape == (1, 1, self.config.hidden_size) 35 | assert words_embedding.shape == (1, len(self.base_tokens[0]), self.config.hidden_size) 36 | 37 | def test_extraction_with_missed_tokens(self): 38 | input_x = [{"tokenized_text": tk, "ner": None} for tk in self.tokens_with_missed] 39 | raw_batch = self.processor.collate_raw_batch(input_x, self.labels) 40 | model_input = self.processor.collate_fn(raw_batch, prepare_labels=False) 41 | model_input['text_lengths'] = raw_batch['seq_length'] 42 | token_embeds = torch.rand(model_input['words_mask'].shape + (self.config.hidden_size,)) 43 | 44 | (prompts_embedding, 45 | prompts_embedding_mask, 46 | words_embedding, 47 | mask) = extract_prompt_features_and_word_embeddings(self.config, token_embeds, **model_input) 48 | 49 | assert prompts_embedding_mask.shape == (1, 1) 50 | assert prompts_embedding.shape == (1, 1, self.config.hidden_size) 51 | assert words_embedding.shape == (1, len(self.tokens_with_missed[0]), self.config.hidden_size) 52 | 53 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from gliner import GLiNER 2 | 3 | 4 | def test_span_model(): 5 | model = GLiNER.from_pretrained("gliner-community/gliner_small-v2.5") 6 | 7 | text = """ 8 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time. 
9 | """ 10 | 11 | labels = ["person", "award", "date", "competitions", "teams"] 12 | 13 | entities = model.predict_entities(text, labels) 14 | 15 | assert len(entities) > 0 16 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 3 | import argparse 4 | import random 5 | import json 6 | 7 | from transformers import AutoTokenizer 8 | import torch 9 | 10 | from gliner import GLiNERConfig, GLiNER 11 | from gliner.training import Trainer, TrainingArguments 12 | from gliner.data_processing.collator import DataCollatorWithPadding, DataCollator 13 | from gliner.utils import load_config_as_namespace 14 | from gliner.data_processing import WordsSplitter, GLiNERDataset 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--config', type=str, default="configs/config.yaml") 20 | parser.add_argument('--log_dir', type=str, default='models/') 21 | parser.add_argument('--compile_model', action='store_true')  # `type=bool` would parse any non-empty string as True 22 | parser.add_argument('--freeze_language_model', action='store_true') 23 | parser.add_argument('--new_data_schema', action='store_true') 24 | args = parser.parse_args() 25 | 26 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 27 | 28 | config = load_config_as_namespace(args.config) 29 | config.log_dir = args.log_dir 30 | 31 | with open(config.train_data, 'r') as f: 32 | data = json.load(f) 33 | 34 | print('Dataset size:', len(data)) 35 | # shuffle the dataset before splitting 36 | random.shuffle(data) 37 | print('Dataset is shuffled...') 38 | 39 | train_data = data[:int(len(data)*0.9)] 40 | test_data = data[int(len(data)*0.9):] 41 | 42 | print('Dataset is split...') 43 | 44 | 45 | if config.prev_path is not None: 46 | tokenizer = AutoTokenizer.from_pretrained(config.prev_path) 47 | model = GLiNER.from_pretrained(config.prev_path) 48 | model_config = model.config 49 | else: 50 | model_config = GLiNERConfig(**vars(config)) 51 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name) 52 | 53 | words_splitter = WordsSplitter(model_config.words_splitter_type) 54 | 55 | model = GLiNER(model_config, tokenizer=tokenizer, words_splitter=words_splitter) 56 | 57 | if not config.labels_encoder: 58 | model_config.class_token_index = len(tokenizer) 59 | tokenizer.add_tokens([model_config.ent_token, model_config.sep_token], special_tokens=True) 60 | model_config.vocab_size = len(tokenizer) 61 | model.resize_token_embeddings([model_config.ent_token, model_config.sep_token], 62 | set_class_token_index=False, 63 | add_tokens_to_tokenizer=False) 64 | 65 | if args.compile_model: 66 | torch.set_float32_matmul_precision('high') 67 | model.to(device) 68 | model.compile_for_training() 69 | 70 | if args.freeze_language_model: 71 | model.model.token_rep_layer.bert_layer.model.requires_grad_(False) 72 | else: 73 | model.model.token_rep_layer.bert_layer.model.requires_grad_(True) 74 | 75 | if args.new_data_schema: 76 | train_dataset = GLiNERDataset(train_data, model_config, tokenizer, words_splitter) 77 | test_dataset = GLiNERDataset(test_data, model_config, tokenizer, words_splitter) 78 | data_collator = DataCollatorWithPadding(model_config) 79 | else: 80 | train_dataset = train_data 81 | test_dataset = test_data 82 | data_collator = DataCollator(model.config, data_processor=model.data_processor, prepare_labels=True) 83 | 84 |
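# Note: the arguments below map the YAML config onto the extended TrainingArguments
# from gliner.training; `others_lr`/`others_weight_decay` feed the custom
# `create_optimizer` override, and the `focal_loss_*`/`label_smoothing` values are
# passed through to the model loss in `compute_loss`. `bf16=True` assumes
# bf16-capable hardware (e.g., an Ampere or newer GPU).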
training_args = TrainingArguments( 85 | output_dir=config.log_dir, 86 | learning_rate=float(config.lr_encoder), 87 | weight_decay=float(config.weight_decay_encoder), 88 | others_lr=float(config.lr_others), 89 | others_weight_decay=float(config.weight_decay_other), 90 | focal_loss_gamma=config.loss_gamma, 91 | focal_loss_alpha=config.loss_alpha, 92 | lr_scheduler_type=config.scheduler_type, 93 | warmup_ratio=config.warmup_ratio, 94 | per_device_train_batch_size=config.train_batch_size, 95 | per_device_eval_batch_size=config.train_batch_size, 96 | max_grad_norm=config.max_grad_norm, 97 | max_steps=config.num_steps, 98 | evaluation_strategy="epoch", 99 | save_steps=config.eval_every, 100 | save_total_limit=config.save_total_limit, 101 | dataloader_num_workers=8, 102 | use_cpu=False, 103 | report_to="none", 104 | bf16=True, 105 | ) 106 | 107 | trainer = Trainer( 108 | model=model, 109 | args=training_args, 110 | train_dataset=train_dataset, 111 | eval_dataset=test_dataset, 112 | tokenizer=tokenizer, 113 | data_collator=data_collator, 114 | ) 115 | trainer.train() 116 | --------------------------------------------------------------------------------