├── .gitignore
├── LICENSE
├── README.md
├── README_Extended.md
├── RELEASE.md
├── configs
│   ├── config.yaml
│   ├── config_biencoder.yaml
│   ├── config_span.yaml
│   └── config_token.yaml
├── convert_to_onnx.py
├── custom_train.py
├── data
│   ├── process_nuner.py
│   └── process_pilener.py
├── demo.jpg
├── demo.py
├── eval.py
├── examples
│   ├── convert_to_onnx.ipynb
│   ├── exal_example_conll.ipynb
│   ├── finetune.ipynb
│   ├── gliner_spacy_demo.ipynb
│   ├── load_local_model.ipynb
│   ├── quickstart.ipynb
│   ├── sample_data.json
│   └── synthetic_data_generation.ipynb
├── gliner
│   ├── __init__.py
│   ├── config.py
│   ├── data_processing
│   │   ├── __init__.py
│   │   ├── collator.py
│   │   ├── dataset.py
│   │   ├── processor.py
│   │   ├── tokenizer.py
│   │   └── utils.py
│   ├── decoding
│   │   ├── __init__.py
│   │   ├── decoder.py
│   │   └── utils.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── evaluate.py
│   │   └── evaluator.py
│   ├── model.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── encoder.py
│   │   ├── layers.py
│   │   ├── loss_functions.py
│   │   ├── scorers.py
│   │   └── span_rep.py
│   ├── multitask
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── classification.py
│   │   ├── open_extraction.py
│   │   ├── question_answering.py
│   │   ├── relation_extraction.py
│   │   └── summarization.py
│   ├── onnx
│   │   ├── __init__.py
│   │   └── model.py
│   ├── training
│   │   ├── __init__.py
│   │   └── trainer.py
│   └── utils.py
├── image.png
├── logo
│   ├── FI Group.png
│   └── FI_COMPLET_CW.png
├── pyproject.toml
├── requirements.txt
├── tests
│   ├── test_features_selection.py
│   └── test_models.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | ### Python ###
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | #data
12 | data.json
13 |
14 | #logs
15 | logs/
16 | models/
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | share/python-wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .nox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | *.py,cover
59 | .hypothesis/
60 | .pytest_cache/
61 | cover/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | .pybuilder/
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | # For a library or package, you might want to ignore these files since the code is
96 | # intended to run in multiple environments; otherwise, check them in:
97 | # .python-version
98 |
99 | # pipenv
100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
103 | # install all needed dependencies.
104 | #Pipfile.lock
105 |
106 | # poetry
107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108 | # This is especially recommended for binary packages to ensure reproducibility, and is more
109 | # commonly ignored for libraries.
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111 | #poetry.lock
112 |
113 | # pdm
114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115 | #pdm.lock
116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
117 | # in version control.
118 | # https://pdm.fming.dev/#use-with-ide
119 | .pdm.toml
120 |
121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122 | __pypackages__/
123 |
124 | # Celery stuff
125 | celerybeat-schedule
126 | celerybeat.pid
127 |
128 | # SageMath parsed files
129 | *.sage.py
130 |
131 | # Environments
132 | .env
133 | .venv
134 | env/
135 | venv/
136 | ENV/
137 | env.bak/
138 | venv.bak/
139 |
140 | # Spyder project settings
141 | .spyderproject
142 | .spyproject
143 |
144 | # Rope project settings
145 | .ropeproject
146 |
147 | # mkdocs documentation
148 | /site
149 |
150 | # mypy
151 | .mypy_cache/
152 | .dmypy.json
153 | dmypy.json
154 |
155 | # Pyre type checker
156 | .pyre/
157 |
158 | # pytype static type analyzer
159 | .pytype/
160 |
161 | # Cython debug symbols
162 | cython_debug/
163 |
164 | # PyCharm
165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167 | # and can be added to the global gitignore or merged into this file. For a more nuclear
168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169 | #.idea/
170 |
171 | ### Python Patch ###
172 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
173 | poetry.toml
174 |
175 | # ruff
176 | .ruff_cache/
177 |
178 | # LSP config files
179 | pyrightconfig.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 👑 GLiNER: Generalist and Lightweight Model for Named Entity Recognition
2 |
3 | GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios.
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | 📄 Paper
14 | •
15 | 📢 Discord
16 | •
17 | 🤗 Demo
18 | •
19 | 🤗 Available models
20 | •
21 |
22 |
23 |
24 |
25 |
26 | ## Example Notebooks
27 |
28 | Explore various examples including finetuning, ONNX conversion, and synthetic data generation.
29 |
30 | - [Example Notebooks](https://github.com/urchade/GLiNER/tree/main/examples)
31 | - Finetune on Colab: [Open in Colab](https://colab.research.google.com/drive/1HNKd74cmfS9tGvWrKeIjSxBt01QQS7bq?usp=sharing)
32 | ## 🛠 Installation & Usage
33 |
34 | ### Installation
35 | ```bash
36 | pip install gliner
37 | ```
38 |
39 | ### Usage
40 | After installing the GLiNER library, import the `GLiNER` class. You can then load your chosen model with `GLiNER.from_pretrained` and use `predict_entities` to extract entities from your text.
41 |
42 | ```python
43 | from gliner import GLiNER
44 |
45 | # Initialize GLiNER with the base model
46 | model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
47 |
48 | # Sample text for entity prediction
49 | text = """
50 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
51 | """
52 |
53 | # Labels for entity prediction
54 | # Most GLiNER models should work best when entity types are in lower case or title case
55 | labels = ["person", "award", "date", "competitions", "teams"]
56 |
57 | # Perform entity prediction
58 | entities = model.predict_entities(text, labels, threshold=0.5)
59 |
60 | # Display predicted entities and their labels
61 | for entity in entities:
62 | print(entity["text"], "=>", entity["label"])
63 | ```
64 |
65 | #### Expected Output
66 |
67 | ```
68 | Cristiano Ronaldo dos Santos Aveiro => person
69 | 5 February 1985 => date
70 | Al Nassr => teams
71 | Portugal national team => teams
72 | Ballon d'Or => award
73 | UEFA Men's Player of the Year Awards => award
74 | European Golden Shoes => award
75 | UEFA Champions Leagues => competitions
76 | UEFA European Championship => competitions
77 | UEFA Nations League => competitions
78 | European Championship => competitions
79 | ```
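
#### Saving and Loading Locally

A minimal sketch of saving a checkpoint to disk and loading it back, mirroring the calls used in `examples/load_local_model.ipynb` and `demo.py` (the local path is illustrative):

```python
from gliner import GLiNER

# Download a pretrained model and save it to a local directory
model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
model.save_pretrained("my_gliner_model")

# Load the locally saved checkpoint back from disk
local_model = GLiNER.from_pretrained("my_gliner_model", load_tokenizer=True)
```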
80 | ## 🌟 Maintainers
81 | 
82 | - Urchade Zaratiana, PhD Student at LIPN
83 | - Ihor Stepanov, Co-Founder at Knowledgator
84 | 
99 | ## 👨💻 Model Authors
100 | The model authors are:
101 | * [Urchade Zaratiana](https://huggingface.co/urchade)
102 | * Nadi Tomeh
103 | * Pierre Holat
104 | * Thierry Charnois
105 |
106 | ## 📚 Citation
107 |
108 | If you find GLiNER useful in your research, please consider citing our paper:
109 |
110 | ```bibtex
111 | @inproceedings{zaratiana-etal-2024-gliner,
112 | title = "{GL}i{NER}: Generalist Model for Named Entity Recognition using Bidirectional Transformer",
113 | author = "Zaratiana, Urchade and
114 | Tomeh, Nadi and
115 | Holat, Pierre and
116 | Charnois, Thierry",
117 | editor = "Duh, Kevin and
118 | Gomez, Helena and
119 | Bethard, Steven",
120 | booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)",
121 | month = jun,
122 | year = "2024",
123 | address = "Mexico City, Mexico",
124 | publisher = "Association for Computational Linguistics",
125 | url = "https://aclanthology.org/2024.naacl-long.300",
126 | doi = "10.18653/v1/2024.naacl-long.300",
127 | pages = "5364--5376",
128 | abstract = "Named Entity Recognition (NER) is essential in various Natural Language Processing (NLP) applications. Traditional NER models are effective but limited to a set of predefined entity types. In contrast, Large Language Models (LLMs) can extract arbitrary entities through natural language instructions, offering greater flexibility. However, their size and cost, particularly for those accessed via APIs like ChatGPT, make them impractical in resource-limited scenarios. In this paper, we introduce a compact NER model trained to identify any type of entity. Leveraging a bidirectional transformer encoder, our model, GLiNER, facilitates parallel entity extraction, an advantage over the slow sequential token generation of LLMs. Through comprehensive testing, GLiNER demonstrate strong performance, outperforming both ChatGPT and fine-tuned LLMs in zero-shot evaluations on various NER benchmarks.",
129 | }
130 | ```
131 | ## Support and funding
132 |
133 | This project has been supported and funded by **F.initiatives** and **Laboratoire Informatique de Paris Nord**.
134 |
135 | F.initiatives has been an expert in public funding strategies for R&D, Innovation, and Investments (R&D&I) for over 20 years. With a team of more than 200 qualified consultants, F.initiatives guides its clients at every stage of developing their public funding strategy: from structuring their projects to submitting their aid application, while ensuring the translation of their industrial and technological challenges to public funders. Through its continuous commitment to excellence and integrity, F.initiatives relies on the synergy between methods and tools to offer tailored, high-quality, and secure support.
136 |
137 |
138 |
139 |
140 |
141 | We also extend our heartfelt gratitude to the open-source community for their invaluable contributions, which have been instrumental in the success of this project.
142 |
143 |
144 |
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
1 | # A guide to making a release
2 |
3 | This guide collects the steps we do in GLiNER to make a release on PyPI. They result from (variations of) hard-learned lessons and while following this guide is completely optional, it’s strongly recommended to do so. 🙂 This is a truncated version of the [SetFit](https://github.com/huggingface/setfit/blob/main/RELEASE.md) release guide, which is more exhaustive and does some additional steps.
4 |
5 | ### Preparation
6 |
7 | To be able to make a release for a given project, you’ll need an account on [PyPI](https://pypi.org/) and on [Test PyPI](https://test.pypi.org/). If you are making a release for an existing project, your username will need to be added to that project by one of the current maintainers on PyPI. Note that we strongly recommend enabling two-factor authentication on PyPI.
8 |
9 | You will also need to install twine in your Python environment with `pip install twine`.
10 |
11 | Additionally, it can be nice to familiarize yourself with [Semantic Versioning](https://semver.org/). This is a fairly strict document, but it provides a useful summary that library maintainers should follow:
12 |
13 | > Given a version number MAJOR.MINOR.PATCH, increment the:
14 | >
15 | > 1. MAJOR version when you make incompatible API changes
16 | > 2. MINOR version when you add functionality in a backward compatible manner
17 | > 3. PATCH version when you make backward compatible bug fixes
18 | >
19 | > Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format.
20 |
21 | The very first release should be "0.1.0".
22 |
23 | ## Releases
24 |
25 | ### Step 1: Adjust the version of your package
26 |
27 | You should have the current version specified in [`gliner/__init__.py`](gliner/__init__.py). This version should be a dev version (e.g. `0.1.0.dev`) before you release; change it to the version you are releasing:
28 |
29 | ```diff
30 | - __version__ = "0.4.0.dev"
31 | + __version__ = "0.4.0"
32 | ```
33 |
34 | Commit the changes on your release branch and push them:
35 |
36 | ```bash
37 | git add gliner
38 | git commit -m "Release: v{VERSION}"
39 | git push -u origin main
40 | ```
41 |
42 | ### Step 2: (Optional) Make sure all tests pass
43 |
44 | If you add tests, then you should also add CI, e.g. like this [`tests.yaml`](https://github.com/tomaarsen/SpanMarkerNER/blob/main/.github/workflows/tests.yaml) file. This will automatically run the tests whenever you push changes, which can be very useful. Make sure any tests you have pass before proceeding to the next step; a minimal smoke test could look like the sketch below.
45 |
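A minimal sketch of the kind of smoke test such a workflow could run (a hypothetical test file, not the repository's actual test suite; the model name is illustrative):

```python
# tests/test_smoke.py (hypothetical)
from gliner import GLiNER


def test_predict_entities_returns_spans():
    model = GLiNER.from_pretrained("urchade/gliner_small-v2.1")
    entities = model.predict_entities(
        "Kylian Mbappé plays for Real Madrid.", ["person", "teams"], threshold=0.5
    )
    # Each prediction should expose the fields used throughout this repository
    assert all({"text", "label", "start", "end"} <= set(entity) for entity in entities)
```
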
46 | ### Step 3: Add a tag for your release
47 |
48 | A tag will flag the exact commit associated with your release (and be easier to remember than the commit hash!). The tag should be `v{VERSION}`, so for instance `v4.12.0`.
49 |
50 | Here is how you can create and push your tag:
51 |
52 | ```bash
53 | git tag v{VERSION}
54 | git push --tags origin main
55 | ```
56 |
57 | ### Step 4: (Optional) Prepare the release notes
58 |
59 | You can then put your release notes in a Draft Release on GitHub, in [https://github.com/urchade/GLiNER/releases](https://github.com/urchade/GLiNER/releases) and write a small paragraph highlighting each of the new features this release is adding.
60 |
61 | You can use the previously created tag to let GitHub auto-generate some release notes based on recent pull requests.
62 |
63 | ### Step 5: Create the wheels for your release
64 |
65 | This is what you'll upload on PyPI and what everyone will download each time they `pip install` your package.
66 |
67 | Clean previous builds by deleting the `build` and `dist` directories or by running:
68 |
69 | ```
70 | rm -rf build && rm -rf dist
71 | ```
72 |
73 | Then run:
74 |
75 | ```bash
76 | python -m build
77 | ```
78 |
79 | This will create two folders, `build` and `dist`, with the new version of your package. The `dist` folder contains 1) a source distribution and 2) a wheel.
80 |
81 | ### Step 6: Upload your package on PyPI test
82 |
83 | **DO NOT SKIP THIS STEP!**
84 |
85 | This is the most important check before actually releasing your package in the wild. Upload the package on PyPI test and check you can properly install it.
86 |
87 | To upload it:
88 |
89 | ```bash
90 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
91 | ```
92 |
93 | You will be prompted for your username and password. If that doesn't work, you can create an API token for your Test PyPI account and create a `~/.pypirc` file if it doesn't already exist, with:
94 |
95 | ```
96 | [distutils]
97 | index-servers =
98 | gliner_test
99 |
100 | [gliner_test]
101 | repository = https://test.pypi.org/legacy/
102 | username = __token__
103 | password = pypi-...
104 | ```
105 | (some more details on this [here](https://pypi.org/help/#apitoken))
106 |
107 | And then run:
108 | ```bash
109 | twine upload dist/* -r gliner_test
110 | ```
111 |
112 | Once the package has been uploaded, try to install it from the Test PyPI server in a fresh environment containing all the dependencies you need (tip: you can use Google Colab for this!). First install all dependencies, and then your package.
113 |
114 | ```bash
115 | python -m pip install torch transformers huggingface_hub flair tqdm
116 | python -m pip install -i https://test.pypi.org/simple/ gliner
117 | ```
118 |
119 | If everything works, you should be able to run this code:
120 |
121 | ```python
122 | from gliner import GLiNER
123 |
124 | model = GLiNER.from_pretrained("urchade/gliner_base")
125 |
126 | text = """
127 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
128 | """
129 |
130 | labels = ["person", "award", "date", "competitions", "teams"]
131 |
132 | entities = model.predict_entities(text, labels, threshold=0.5)
133 |
134 | for entity in entities:
135 | print(entity["text"], "=>", entity["label"])
136 | ```
137 |
138 | ### Step 7: Publish on PyPI
139 |
140 | This cannot be undone if you messed up, so make sure you have run Step 6!
141 |
142 | Once you’re fully ready, upload your package on PyPI:
143 |
144 | ```bash
145 | twine upload dist/* -r pypi
146 | ```
147 |
148 | You will be prompted for your username and password, unless you're using the recommended [PyPI API token](https://pypi.org/help/#apitoken).
149 |
150 | ### Step 8: (Optional) Publish your release notes
151 |
152 | Go back to the draft you created in Step 4 ([https://github.com/urchade/GLiNER/releases](https://github.com/urchade/GLiNER/releases)) and publish it.
153 |
154 | ### Step 9: Bump the dev version on the main branch
155 |
156 | You’re almost done! Just go back to the `main` branch and change the dev version in [`gliner/__init__.py`](gliner/__init__.py) to the new version you’re developing, for instance `4.13.0.dev` if you just released `4.12.0`.
157 |
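For example, mirroring the diff from Step 1 (version numbers are illustrative):

```diff
- __version__ = "4.12.0"
+ __version__ = "4.13.0.dev"
```
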
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | # Model Configuration
2 | model_name: microsoft/deberta-v3-small # Hugging Face model
3 | labels_encoder: "BAAI/bge-small-en-v1.5"
4 | name: "span level gliner"
5 | max_width: 12
6 | hidden_size: 768
7 | dropout: 0.3
8 | fine_tune: true
9 | subtoken_pooling: first
10 | fuse_layers: false
11 | post_fusion_schema: "l2l-l2t-t2t"
12 | span_mode: markerV0
13 |
14 | # Training Parameters
15 | num_steps: 100000
16 | train_batch_size: 8
17 | eval_every: 5000
18 | warmup_ratio: 0.05
19 | scheduler_type: "cosine"
20 |
21 | # loss function
22 | loss_alpha: 0.75
23 | loss_gamma: 0
24 | label_smoothing: 0
25 | loss_reduction: "sum"
26 |
27 | # Learning Rate and weight decay Configuration
28 | lr_encoder: 1e-5
29 | lr_others: 3e-5
30 | weight_decay_encoder: 0.1
31 | weight_decay_other: 0.01
32 |
33 | max_grad_norm: 10.0
34 |
35 | # Directory Paths
36 | root_dir: gliner_logs
37 | train_data: "data.json" #"data/nuner_train.json" # see https://github.com/urchade/GLiNER/tree/main/data
38 | val_data_dir: "none"
39 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
40 |
41 | # Pretrained Model Path
42 | # Use "none" if no pretrained model is being used
43 | prev_path: null
44 |
45 | save_total_limit: 3 #maximum amount of checkpoints to save
46 |
47 | # Advanced Training Settings
48 | size_sup: -1
49 | max_types: 100
50 | shuffle_types: true
51 | random_drop: true
52 | max_neg_type_ratio: 1
53 | max_len: 512
54 | freeze_token_rep: false
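
# Usage note (assumption: train.py loads this file via a --config argument):
#   python train.py --config configs/config.yaml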
55 |
--------------------------------------------------------------------------------
/configs/config_biencoder.yaml:
--------------------------------------------------------------------------------
1 | # Model Configuration
2 | model_name: microsoft/deberta-v3-small # Hugging Face model
3 | labels_encoder: "microsoft/deberta-v3-small"
4 | name: "span level gliner"
5 | max_width: 12
6 | hidden_size: 768
7 | dropout: 0.4
8 | fine_tune: true
9 | subtoken_pooling: first
10 | fuse_layers: false
11 | post_fusion_schema: ""
12 | span_mode: markerV0
13 |
14 | # Training Parameters
15 | num_steps: 30000
16 | train_batch_size: 8
17 | eval_every: 1000
18 | warmup_ratio: 0.1
19 | scheduler_type: "cosine"
20 |
21 | # loss function
22 | loss_alpha: -1
23 | loss_gamma: 0
24 | label_smoothing: 0
25 | loss_reduction: "sum"
26 |
27 | # Learning Rate and weight decay Configuration
28 | lr_encoder: 1e-5
29 | lr_others: 5e-5
30 | weight_decay_encoder: 0.01
31 | weight_decay_other: 0.01
32 |
33 | max_grad_norm: 10.0
34 |
35 | # Directory Paths
36 | root_dir: gliner_logs
37 | train_data: "data.json" #"data/nuner_train.json" # see https://github.com/urchade/GLiNER/tree/main/data
38 | val_data_dir: "none"
39 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
40 |
41 | # Pretrained Model Path
42 | # Use "none" if no pretrained model is being used
43 | prev_path: null
44 |
45 | save_total_limit: 3 #maximum amount of checkpoints to save
46 |
47 | # Advanced Training Settings
48 | size_sup: -1
49 | max_types: 25
50 | shuffle_types: true
51 | random_drop: true
52 | max_neg_type_ratio: 1
53 | max_len: 386
54 | freeze_token_rep: false
55 |
--------------------------------------------------------------------------------
/configs/config_span.yaml:
--------------------------------------------------------------------------------
1 | # Model Configuration
2 | model_name: microsoft/deberta-v3-small # Hugging Face model
3 | name: "span level gliner"
4 | max_width: 12
5 | hidden_size: 768
6 | dropout: 0.4
7 | fine_tune: true
8 | subtoken_pooling: first
9 | span_mode: markerV0
10 |
11 | # Training Parameters
12 | num_steps: 30000
13 | train_batch_size: 8
14 | eval_every: 5000
15 | warmup_ratio: 0.1
16 | scheduler_type: "cosine"
17 |
18 | # loss function
19 | loss_alpha: -1 # focal loss alpha, if -1, no focal loss
20 | loss_gamma: 0 # focal loss gamma, if 0, no focal loss
21 | label_smoothing: 0
22 | loss_reduction: "sum"
23 |
24 | # Learning Rate and weight decay Configuration
25 | lr_encoder: 1e-5
26 | lr_others: 5e-5
27 | weight_decay_encoder: 0.01
28 | weight_decay_other: 0.01
29 |
30 | max_grad_norm: 1.0
31 |
32 | # Directory Paths
33 | root_dir: span_gliner_logs
34 | train_data: "data.json" # see https://github.com/urchade/GLiNER/tree/main/data
35 | val_data_dir: "none"
36 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
37 |
38 | # Pretrained Model Path
39 | # Use "none" if no pretrained model is being used
40 | prev_path: "none"
41 |
42 | save_total_limit: 10 #maximum amount of checkpoints to save
43 |
44 | # Advanced Training Settings
45 | size_sup: -1
46 | max_types: 25
47 | shuffle_types: true
48 | random_drop: true
49 | max_neg_type_ratio: 1
50 | max_len: 384
51 | freeze_token_rep: false
52 |
--------------------------------------------------------------------------------
/configs/config_token.yaml:
--------------------------------------------------------------------------------
1 | # Model Configuration
2 | model_name: microsoft/deberta-v3-small # Hugging Face model
3 | name: "token level gliner"
4 | max_width: 100
5 | hidden_size: 768
6 | dropout: 0.1
7 | fine_tune: true
8 | subtoken_pooling: first
9 | span_mode: token_level
10 |
11 | # Training Parameters
12 | num_steps: 30000
13 | train_batch_size: 8
14 | eval_every: 5000
15 | warmup_ratio: 0.1
16 | scheduler_type: "cosine"
17 |
18 | # loss function
19 | loss_alpha: -1 # focal loss alpha, if -1, no focal loss
20 | loss_gamma: 0 # focal loss gamma, if 0, no focal loss
21 | label_smoothing: 0
22 | loss_reduction: "sum"
23 |
24 | # Learning Rate and weight decay Configuration
25 | lr_encoder: 1e-5
26 | lr_others: 5e-5
27 | weight_decay_encoder: 0.01
28 | weight_decay_other: 0.01
29 |
30 | max_grad_norm: 1.0
31 |
32 | # Directory Paths
33 | root_dir: gliner_logs
34 | train_data: "train.json" # see https://github.com/urchade/GLiNER/tree/main/data
35 | val_data_dir: "NER_datasets"
36 | # "NER_datasets": val data from the paper can be obtained from "https://drive.google.com/file/d/1T-5IbocGka35I7X3CE6yKe5N_Xg2lVKT/view"
37 |
38 | # Pretrained Model Path
39 | # Use "none" if no pretrained model is being used
40 | prev_path: "none"
41 |
42 | save_total_limit: 10 #maximum amount of checkpoints to save
43 |
44 | # Advanced Training Settings
45 | size_sup: -1
46 | max_types: 25
47 | shuffle_types: true
48 | random_drop: true
49 | max_neg_type_ratio: 1
50 | max_len: 384
51 | freeze_token_rep: false
52 |
53 |
--------------------------------------------------------------------------------
/convert_to_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 |
5 | from gliner import GLiNER
6 |
7 | import torch
8 | from onnxruntime.quantization import quantize_dynamic, QuantType
9 |
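# This script exports a trained GLiNER checkpoint to ONNX and optionally writes a
# dynamically quantized copy alongside it. Example invocation (paths are illustrative):
#   python convert_to_onnx.py --model_path logs/model_12000 --save_path model/
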
10 | if __name__ == "__main__":
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--model_path', type=str, default= "logs/model_12000")
13 | parser.add_argument('--save_path', type=str, default = 'model/')
14 | parser.add_argument('--quantize', type=bool, default = True)
15 | args = parser.parse_args()
16 |
17 | if not os.path.exists(args.save_path):
18 | os.makedirs(args.save_path)
19 |
20 | onnx_save_path = os.path.join(args.save_path, "model.onnx")
21 |
22 | print("Loading a model...")
23 | gliner_model = GLiNER.from_pretrained(args.model_path, load_tokenizer=True)
24 |
25 | text = "ONNX is an open-source format designed to enable the interoperability of AI models across various frameworks and tools."
26 | labels = ['format', 'model', 'tool', 'cat']
27 |
28 | inputs, _ = gliner_model.prepare_model_inputs([text], labels)
29 |
30 | if gliner_model.config.span_mode == 'token_level':
31 | all_inputs = (inputs['input_ids'], inputs['attention_mask'],
32 | inputs['words_mask'], inputs['text_lengths'])
33 | input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths']
34 | dynamic_axes={
35 | "input_ids": {0: "batch_size", 1: "sequence_length"},
36 | "attention_mask": {0: "batch_size", 1: "sequence_length"},
37 | "words_mask": {0: "batch_size", 1: "sequence_length"},
38 | "text_lengths": {0: "batch_size", 1: "value"},
39 | "logits": {0: "position", 1: "batch_size", 2: "sequence_length", 3: "num_classes"},
40 | }
41 | else:
42 | all_inputs = (inputs['input_ids'], inputs['attention_mask'],
43 | inputs['words_mask'], inputs['text_lengths'],
44 | inputs['span_idx'], inputs['span_mask'])
45 | input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths', 'span_idx', 'span_mask']
46 | dynamic_axes={
47 | "input_ids": {0: "batch_size", 1: "sequence_length"},
48 | "attention_mask": {0: "batch_size", 1: "sequence_length"},
49 | "words_mask": {0: "batch_size", 1: "sequence_length"},
50 | "text_lengths": {0: "batch_size", 1: "value"},
51 | "span_idx": {0: "batch_size", 1: "num_spans", 2: "idx"},
52 | "span_mask": {0: "batch_size", 1: "num_spans"},
53 | "logits": {0: "batch_size", 1: "sequence_length", 2: "num_spans", 3: "num_classes"},
54 | }
55 | print('Converting the model...')
56 | torch.onnx.export(
57 | gliner_model.model,
58 | all_inputs,
59 | f=onnx_save_path,
60 | input_names=input_names,
61 | output_names=["logits"],
62 | dynamic_axes=dynamic_axes,
63 | opset_version=14,
64 | )
65 |
66 | if args.quantize:
67 | quantized_save_path = os.path.join(args.save_path, "model_quantized.onnx")
68 | # Quantize the ONNX model
69 | print("Quantizing the model...")
70 | quantize_dynamic(
71 | onnx_save_path, # Input model
72 | quantized_save_path, # Output model
73 | weight_type=QuantType.QUInt8 # Quantize weights to 8-bit integers
74 | )
75 | print("Done!")
--------------------------------------------------------------------------------
/data/process_nuner.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import re
3 | import ast
4 | import json
5 | from tqdm import tqdm
6 |
7 |
8 | def tokenize_text(text):
9 | """Tokenizes the input text into a list of tokens."""
10 | return re.findall(r'\w+(?:[-_]\w+)*|\S', text)
11 |
12 |
13 | def process_entities(dataset):
14 | """Processes entities in the dataset to extract tokenized text and named entity spans."""
15 | all_data = []
16 | for el in tqdm(dataset["entity"]):
17 | try:
18 | tokenized_text = tokenize_text(el["input"])
19 | parsed_output = ast.literal_eval(el["output"])
20 | entity_texts, entity_types = zip(*[i.split(" <> ") for i in parsed_output])
21 |
22 | entity_spans = []
23 | for j, entity_text in enumerate(entity_texts):
24 | entity_tokens = tokenize_text(entity_text)
25 | matches = []
26 | for i in range(len(tokenized_text) - len(entity_tokens) + 1):
27 | if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
28 | matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
29 | if matches:
30 | entity_spans.extend(matches)
31 |
32 | except Exception as e:
33 | continue
34 |
35 | all_data.append({"tokenized_text": tokenized_text, "ner": entity_spans})
36 | return all_data
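# Shape of a single record produced by process_entities (illustrative values, not taken from NuNER):
#   {"tokenized_text": ["Barack", "Obama", "visited", "Paris", "."],
#    "ner": [(0, 1, "person"), (3, 3, "location")]}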
37 |
38 |
39 | def save_data_to_file(data, filepath):
40 | """Saves the processed data to a JSON file."""
41 | with open(filepath, 'w') as f:
42 | json.dump(data, f)
43 |
44 |
45 | if __name__ == "__main__":
46 | dataset = load_dataset("numind/NuNER")
47 | processed_data = process_entities(dataset)
48 |
49 | save_data_to_file(processed_data, 'nuner_train.json')
50 |
51 | print("dataset size:", len(processed_data))
--------------------------------------------------------------------------------
/data/process_pilener.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import ast
4 | from tqdm import tqdm
5 |
6 | def load_data(filepath):
7 | """Loads data from a JSON file."""
8 | with open(filepath, 'r') as f:
9 | data = json.load(f)
10 | return data
11 |
12 | def tokenize_text(text):
13 | """Tokenizes the input text into a list of tokens."""
14 | return re.findall(r'\w+(?:[-_]\w+)*|\S', text)
15 |
16 | def extract_entity_spans(entry):
17 | """Extracts entity spans from an entry."""
18 | len_start = len("What describes ")
19 | len_end = len(" in the text?")
20 | entity_types, entity_texts, negative = [], [], []
21 |
22 | for c in entry['conversations']:
23 | if c['from'] == 'human' and c['value'].startswith('Text: '):
24 | text = c['value'][len('Text: '):]
25 | tokenized_text = tokenize_text(text)
26 | elif c['from'] == 'human' and c['value'].startswith('What describes '):
27 | entity_type = c['value'][len_start:-len_end]
28 | entity_types.append(entity_type)
29 | elif c['from'] == 'gpt' and c['value'].startswith('['):
30 | if c['value'] == '[]':
31 | negative.append(entity_types.pop())
32 | continue
33 | texts_ents = ast.literal_eval(c['value'])
34 | entity_texts.extend(texts_ents)
35 | num_repeat = len(texts_ents) - 1
36 | entity_types.extend([entity_types[-1]] * num_repeat)
37 |
38 | entity_spans = []
39 | for j, entity_text in enumerate(entity_texts):
40 | entity_tokens = tokenize_text(entity_text)
41 | matches = []
42 | for i in range(len(tokenized_text) - len(entity_tokens) + 1):
43 | if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
44 | matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
45 | if matches:
46 | entity_spans.extend(matches)
47 |
48 | return {"tokenized_text": tokenized_text, "ner": entity_spans, "negative": negative}
49 |
50 | def process_data(data):
51 | """Processes a list of data entries to extract entity spans."""
52 | all_data = [extract_entity_spans(entry) for entry in tqdm(data)]
53 | return all_data
54 |
55 | def save_data_to_file(data, filepath):
56 | """Saves the processed data to a JSON file."""
57 | with open(filepath, 'w') as f:
58 | json.dump(data, f)
59 |
60 | if __name__ == "__main__":
61 | # download the pile-ner data: "wget https://huggingface.co/datasets/Universal-NER/Pile-NER-type/blob/main/train.json"
62 | path_pile_ner = 'train.json'
63 | data = load_data(path_pile_ner)
64 | processed_data = process_data(data)
65 | save_data_to_file(processed_data, 'pilener_train.json')
66 |
67 | print("dataset size:", len(processed_data))
--------------------------------------------------------------------------------
/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/demo.jpg
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 | from gliner import GLiNER
3 | import gradio as gr
4 |
5 | model = GLiNER.from_pretrained("model/", load_tokenizer=True)
6 |
7 | examples = [
8 | [
9 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.",
10 | "person, book, location, date, actor, character",
11 | 0.3,
12 | True,
13 | ],
14 | [
15 | """
16 | * Data Scientist, Data Analyst, or Data Engineer with 1+ years of experience.
17 | * Experience with technologies such as Docker, Kubernetes, or Kubeflow
18 | * Machine Learning experience preferred
19 | * Experience with programming languages such as Python, C++, or SQL preferred
20 | * Experience with technologies such as Databricks, Qlik, TensorFlow, PyTorch, Python, Dash, Pandas, or NumPy preferred
21 | * BA or BS degree
22 | * Active Secret OR Active Top Secret or Active TS/SCI clearance
23 | """,
24 | "software package, programing language, software tool, degree, job title",
25 | 0.3,
26 | False,
27 | ],
28 | [
29 | "However, both models lack other frequent DM symptoms including the fibre-type dependent atrophy, myotonia, cataract and male-infertility.",
30 | "disease, symptom",
31 | 0.3,
32 | False,
33 | ],
34 | [
35 | "Synergy between signal transduction pathways is obligatory for expression of c-fos in B and T cell lines: implication for c-fos control via surface immunoglobulin and T cell antigen receptors.",
36 | "DNA, RNA, cell line, cell type, protein",
37 | 0.3,
38 | False,
39 | ],
40 | [
41 | "The choice of the encoder and decoder modules of dnpg can be quite flexible, for instance long short term memory networks (lstm) or convolutional neural network (cnn).",
42 | "short acronym, long acronym",
43 | 0.3,
44 | False,
45 | ],
46 | [
47 | "Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.",
48 | "person, company, location, airplane",
49 | 0.3,
50 | True,
51 | ],
52 | [
53 | "Feldman is a contributor to NBC Sports Boston's ``State of the Revs`` and ``Revolution Postgame Live`` programs as well as to 98.5 the SportsHub, SiriusXM FC's MLS coverage and to other New England and national radio outlets and podcasts.",
54 | "person, company, location",
55 | 0.3,
56 | False,
57 | ],
58 | [
59 | "On 25 July 1948, on the 39th anniversary of Bleriot's crossing of the English Channel, the Type 618 Nene-Viking flew Heathrow to Paris (Villacoublay) in the morning carrying letters to Bleriot's widow and son (secretary of the FAI), who met it at the airport.",
60 | "date, location, person, organization",
61 | 0.3,
62 | False,
63 | ],
64 | [
65 | "Leo & Ian won the 1962 Bathurst Six Hour Classic at Mount Panorama driving a Daimler SP250 sports car, (that year the 500 mile race for touring cars were held at Phillip Island)",
66 | "person, date, location, organization, competition",
67 | 0.3,
68 | False,
69 | ],
70 | [
71 | "The Shore Line route of the CNS & M until 1955 served, from south to north, the Illinois communities of Chicago, Evanston, Wilmette, Kenilworth, Winnetka, Glencoe, Highland Park, Highwood, Fort Sheridan, Lake Forest, Lake Bluff, North Chicago, Waukegan, Zion, and Winthrop Harbor as well as Kenosha, Racine, and Milwaukee (the ``KRM'') in Wisconsin.",
72 | "location, organization, date",
73 | 0.3,
74 | False,
75 | ],
76 | [
77 | "Comet C/2006 M4 (SWAN) is a non-periodic comet discovered in late June 2006 by Robert D. Matson of Irvine, California and Michael Mattiazzo of Adelaide, South Australia in publicly available images of the Solar and Heliospheric Observatory (SOHO).",
78 | "person, organization, date, location",
79 | 0.3,
80 | False,
81 | ],
82 | [
83 | "From November 29, 2011 to March 31, 2012, Karimloo returned to ``Les Misérables`` to play the lead role of Jean Valjean at The Queen's Theatre, London, for which he won the 2013 Theatregoers' Choice Award for Best Takeover in a Role.",
84 | "person, actor, award, date, location",
85 | 0.3,
86 | False,
87 | ],
88 | [
89 | "A Mexicali health clinic supported by former Baja California gubernatorial candidate Enrique Acosta Fregoso (PRI) was closed on June 15 after selling a supposed COVID-19 ``cure'' for between MXN $10,000 and $50,000.",
90 | "location, organization, person, date, currency",
91 | 0.3,
92 | False,
93 | ],
94 | [
95 | "Built in 1793, it was the home of Mary Young Pickersgill when she moved to Baltimore in 1806 and the location where she later sewed the ``Star Spangled Banner'', in 1813, the huge out-sized garrison flag that flew over Fort McHenry at Whetstone Point in Baltimore Harbor in the summer of 1814 during the British Royal Navy attack in the Battle of Baltimore during the War of 1812.",
96 | "date, person, location, organization, event, flag",
97 | 0.3,
98 | False,
99 | ],
100 | ]
101 |
102 |
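# Run GLiNER on the input text and map each prediction to the dict format
# expected by the gr.HighlightedText component.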
103 | def ner(
104 | text, labels: str, threshold: float, nested_ner: bool
105 | ) -> Dict[str, Union[str, int, float]]:
106 | labels = labels.split(",")
107 | return {
108 | "text": text,
109 | "entities": [
110 | {
111 | "entity": entity["label"],
112 | "word": entity["text"],
113 | "start": entity["start"],
114 | "end": entity["end"],
115 | "score": 0,
116 | }
117 | for entity in model.predict_entities(
118 | text, labels, flat_ner=not nested_ner, threshold=threshold
119 | )
120 | ],
121 | }
122 |
123 |
124 | with gr.Blocks(title="GLiNER-M-v2.1") as demo:
125 | gr.Markdown(
126 | """
127 | # GLiNER-base
128 | GLiNER is a Named Entity Recognition (NER) model capable of identifying any entity type using a bidirectional transformer encoder (BERT-like). It provides a practical alternative to traditional NER models, which are limited to predefined entities, and Large Language Models (LLMs) that, despite their flexibility, are costly and large for resource-constrained scenarios.
129 | ## Links
130 | * Model: https://huggingface.co/urchade/gliner_multi-v2.1
131 | * All GLiNER models: https://huggingface.co/models?library=gliner
132 | * Paper: https://arxiv.org/abs/2311.08526
133 | * Repository: https://github.com/urchade/GLiNER
134 | """
135 | )
136 | with gr.Accordion("How to run this model locally", open=False):
137 | gr.Markdown(
138 | """
139 | ## Installation
140 | To use this model, you must install the GLiNER Python library:
141 | ```
142 | !pip install gliner
143 | ```
144 |
145 | ## Usage
146 | Once you've downloaded the GLiNER library, you can import the GLiNER class. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
147 | """
148 | )
149 | gr.Code(
150 | '''
151 | from gliner import GLiNER
152 | model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1")
153 | text = """
154 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
155 | """
156 | labels = ["person", "award", "date", "competitions", "teams"]
157 | entities = model.predict_entities(text, labels)
158 | for entity in entities:
159 | print(entity["text"], "=>", entity["label"])
160 | ''',
161 | language="python",
162 | )
163 | gr.Code(
164 | """
165 | Cristiano Ronaldo dos Santos Aveiro => person
166 | 5 February 1985 => date
167 | Al Nassr => teams
168 | Portugal national team => teams
169 | Ballon d'Or => award
170 | UEFA Men's Player of the Year Awards => award
171 | European Golden Shoes => award
172 | UEFA Champions Leagues => competitions
173 | UEFA European Championship => competitions
174 | UEFA Nations League => competitions
175 | Champions League => competitions
176 | European Championship => competitions
177 | """
178 | )
179 |
180 | input_text = gr.Textbox(
181 | value=examples[0][0], label="Text input", placeholder="Enter your text here"
182 | )
183 | with gr.Row() as row:
184 | labels = gr.Textbox(
185 | value=examples[0][1],
186 | label="Labels",
187 | placeholder="Enter your labels here (comma separated)",
188 | scale=2,
189 | )
190 | threshold = gr.Slider(
191 | 0,
192 | 1,
193 | value=0.3,
194 | step=0.01,
195 | label="Threshold",
196 | info="Lower the threshold to increase how many entities get predicted.",
197 | scale=1,
198 | )
199 | nested_ner = gr.Checkbox(
200 | value=examples[0][3],
201 | label="Nested NER",
202 | info="Allow for nested NER?",
203 | scale=0,
204 | )
205 | output = gr.HighlightedText(label="Predicted Entities")
206 | submit_btn = gr.Button("Submit")
207 | examples = gr.Examples(
208 | examples,
209 | fn=ner,
210 | inputs=[input_text, labels, threshold, nested_ner],
211 | outputs=output,
212 | cache_examples=True,
213 | )
214 |
215 | # Submitting
216 | input_text.submit(
217 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output
218 | )
219 | labels.submit(
220 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output
221 | )
222 | threshold.release(
223 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output
224 | )
225 | submit_btn.click(
226 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output
227 | )
228 | nested_ner.change(
229 | fn=ner, inputs=[input_text, labels, threshold, nested_ner], outputs=output
230 | )
231 |
232 | demo.queue()
233 | demo.launch(debug=True)
234 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from gliner import GLiNER
4 | from gliner.evaluation import get_for_all_path
5 |
6 |
7 | def create_parser():
8 | parser = argparse.ArgumentParser(description="Span-based NER")
9 | parser.add_argument("--model", type=str, default="logs/model_12000", help="Path to model folder")
10 | parser.add_argument("--log_dir", type=str, default="logs", help="Path to model folder")
11 | parser.add_argument('--data', type=str, default='data/ie_data/NER/', help='Path to the eval datasets directory')
12 | return parser
13 |
14 |
15 | if __name__ == "__main__":
16 | parser = create_parser()
17 | args = parser.parse_args()
18 |
19 | model = GLiNER.from_pretrained(args.model, load_tokenizer=True).to("cuda:0")
20 | get_for_all_path(model, -1, args.log_dir, args.data)
--------------------------------------------------------------------------------
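The same evaluation can also be driven from Python rather than the CLI above. A minimal sketch mirroring the argparse defaults; the checkpoint and dataset paths are placeholders and the CUDA device is an assumption:

```python
# Programmatic equivalent of eval.py; paths below are placeholders.
from gliner import GLiNER
from gliner.evaluation import get_for_all_path

model = GLiNER.from_pretrained("logs/model_12000", load_tokenizer=True).to("cuda:0")
# get_for_all_path(model, steps, log_dir, data_paths) appends results.txt and tables.txt under log_dir;
# steps is only used as a label in those files (-1 here, as in eval.py).
get_for_all_path(model, -1, "logs", "data/ie_data/NER/")
```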
/examples/convert_to_onnx.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# !pip install onnx"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import torch\n",
19 | "from gliner import GLiNER"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "model = GLiNER.from_pretrained(\"urchade/gliner_medium\")"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "# save\n",
38 | "\n",
39 | "model.save_pretrained(\"gliner_medium\")"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "gliner_model = GLiNER.from_pretrained(\"gliner_medium\", load_tokenizer=True)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "import os\n",
58 | "\n",
59 | "onnx_save_path = os.path.join(\"gliner_medium\", \"model.onnx\")"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "text = \"ONNX is an open-source format designed to enable the interoperability of AI models across various frameworks and tools.\"\n",
69 | "labels = ['format', 'model', 'tool', 'cat']\n",
70 | "\n",
71 | "inputs, _ = gliner_model.prepare_model_inputs([text], labels)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "if gliner_model.config.span_mode == 'token_level':\n",
81 | " all_inputs = (inputs['input_ids'], inputs['attention_mask'], \n",
82 | " inputs['words_mask'], inputs['text_lengths'])\n",
83 | " input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths']\n",
84 | " dynamic_axes={\n",
85 | " \"input_ids\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
86 | " \"attention_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
87 | " \"words_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
88 | " \"text_lengths\": {0: \"batch_size\", 1: \"value\"},\n",
89 | " \"logits\": {0: \"position\", 1: \"batch_size\", 2: \"sequence_length\", 3: \"num_classes\"},\n",
90 | " }\n",
91 | "else:\n",
92 | " all_inputs = (inputs['input_ids'], inputs['attention_mask'], \n",
93 | " inputs['words_mask'], inputs['text_lengths'],\n",
94 | " inputs['span_idx'], inputs['span_mask'])\n",
95 | " input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths', 'span_idx', 'span_mask']\n",
96 | " dynamic_axes={\n",
97 | " \"input_ids\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
98 | " \"attention_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
99 | " \"words_mask\": {0: \"batch_size\", 1: \"sequence_length\"},\n",
100 | " \"text_lengths\": {0: \"batch_size\", 1: \"value\"},\n",
101 | " \"span_idx\": {0: \"batch_size\", 1: \"num_spans\", 2: \"idx\"},\n",
102 | " \"span_mask\": {0: \"batch_size\", 1: \"num_spans\"},\n",
103 | " \"logits\": {0: \"batch_size\", 1: \"sequence_length\", 2: \"num_spans\", 3: \"num_classes\"},\n",
104 | " }\n",
105 | "print('Converting the model...')\n",
106 | "all_inputs = dict(zip(input_names,all_inputs))\n",
107 | "\n",
108 | "torch.onnx.export(\n",
109 | " gliner_model.model,\n",
110 | " all_inputs,\n",
111 | " f=onnx_save_path,\n",
112 | " input_names=input_names,\n",
113 | " output_names=[\"logits\"],\n",
114 | " dynamic_axes=dynamic_axes,\n",
115 | " opset_version=14,\n",
116 | ")\n"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "#quantize model\n",
126 | "from onnxruntime.quantization import quantize_dynamic, QuantType\n",
127 | "\n",
128 | "quantized_save_path = os.path.join(\"gliner_medium\", \"model_quantized.onnx\")\n",
129 | "# Quantize the ONNX model\n",
130 | "print(\"Quantizing the model...\")\n",
131 | "quantize_dynamic(\n",
132 | " onnx_save_path, # Input model\n",
133 | " quantized_save_path, # Output model\n",
134 | " weight_type=QuantType.QUInt8 # Quantize weights to 8-bit integers\n",
135 | ")"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "# load onnx model\n",
145 | "model = GLiNER.from_pretrained(\"gliner_medium\", load_onnx_model=True, load_tokenizer=True)\n"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "text = \"\"\"\n",
155 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n",
156 | "\"\"\"\n",
157 | "\n",
158 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n",
159 | "\n",
160 | "entities = model.predict_entities(text, labels, threshold=0.4)\n",
161 | "\n",
162 | "for entity in entities:\n",
163 | " print(entity[\"text\"], \"=>\", entity[\"label\"])"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "# load quantized model\n",
173 | "model = GLiNER.from_pretrained(\"gliner_medium\", load_onnx_model=True, load_tokenizer=True, onnx_model_file=\"model_quantized.onnx\")\n"
174 | ]
175 | },
176 | {
177 | "cell_type": "code",
178 | "execution_count": null,
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "text = \"\"\"\n",
183 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n",
184 | "\"\"\"\n",
185 | "\n",
186 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n",
187 | "\n",
188 | "entities = model.predict_entities(text, labels, threshold=0.4)\n",
189 | "\n",
190 | "for entity in entities:\n",
191 | " print(entity[\"text\"], \"=>\", entity[\"label\"])"
192 | ]
193 | }
194 | ],
195 | "metadata": {
196 | "kernelspec": {
197 | "display_name": "base",
198 | "language": "python",
199 | "name": "python3"
200 | },
201 | "language_info": {
202 | "codemirror_mode": {
203 | "name": "ipython",
204 | "version": 3
205 | },
206 | "file_extension": ".py",
207 | "mimetype": "text/x-python",
208 | "name": "python",
209 | "nbconvert_exporter": "python",
210 | "pygments_lexer": "ipython3",
211 | "version": "3.10.10"
212 | },
213 | "orig_nbformat": 4
214 | },
215 | "nbformat": 4,
216 | "nbformat_minor": 2
217 | }
218 |
--------------------------------------------------------------------------------
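After exporting, it can be worth validating the ONNX graph before quantizing or shipping it. A small sketch using the `onnx` package, assuming the notebook above was run so that `gliner_medium/model.onnx` exists:

```python
import onnx

# Load and structurally validate the exported graph; check_model raises on a malformed file.
onnx_model = onnx.load("gliner_medium/model.onnx")
onnx.checker.check_model(onnx_model)

# The input/output names should match those passed to torch.onnx.export above.
print([i.name for i in onnx_model.graph.input])    # input_ids, attention_mask, words_mask, ...
print([o.name for o in onnx_model.graph.output])   # logits
```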
/examples/exal_example_conll.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "f2087a56",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "!pip install datasets"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 10,
16 | "id": "24a58336-3b16-491b-8646-2f54e93a8964",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from datasets import load_dataset"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 11,
26 | "id": "df670cb5-24cb-4683-849e-2e27769dd762",
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "def ner_tags_to_spans(samples, tag_to_id):\n",
31 | " \"\"\"\n",
32 | " Converts NER tags in the dataset samples to spans (start, end, entity type).\n",
33 | " \n",
34 | " Args:\n",
35 | " samples (dict): A dictionary containing the tokens and NER tags.\n",
36 | " tag_to_id (dict): A dictionary mapping NER tags to IDs.\n",
37 | " \n",
38 | " Returns:\n",
39 | " dict: A dictionary containing tokenized text and corresponding NER spans.\n",
40 | " \"\"\"\n",
41 | " ner_tags = samples[\"ner_tags\"]\n",
42 | " id_to_tag = {v: k for k, v in tag_to_id.items()}\n",
43 | " spans = []\n",
44 | " start_pos = None\n",
45 | " entity_name = None\n",
46 | "\n",
47 | " for i, tag in enumerate(ner_tags):\n",
48 | " if tag == 0: # 'O' tag\n",
49 | " if entity_name is not None:\n",
50 | " spans.append((start_pos, i - 1, entity_name))\n",
51 | " entity_name = None\n",
52 | " start_pos = None\n",
53 | " else:\n",
54 | " tag_name = id_to_tag[tag]\n",
55 | " if tag_name.startswith('B-'):\n",
56 | " if entity_name is not None:\n",
57 | " spans.append((start_pos, i - 1, entity_name))\n",
58 | " entity_name = tag_name[2:]\n",
59 | " start_pos = i\n",
60 | " elif tag_name.startswith('I-'):\n",
61 | " continue\n",
62 | "\n",
63 | " # Handle the last entity if the sentence ends with an entity\n",
64 | " if entity_name is not None:\n",
65 | " spans.append((start_pos, len(samples[\"tokens\"]) - 1, entity_name))\n",
66 | " \n",
67 | " return {\"tokenized_text\": samples[\"tokens\"], \"ner\": spans}"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "id": "971f92b9-ece2-460d-99f8-73277b5d3081",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "# step 1: load data\n",
78 | "dataset = load_dataset(\"eriktks/conll2003\")"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 13,
84 | "id": "67a18f87-1571-4e8c-8253-6e0305bfa0cb",
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "# Step 2: Define NER tag-to-ID mapping\n",
89 | "tag_to_id = {\n",
90 | " 'O': 0, 'B-person': 1, 'I-person': 2, 'B-organization': 3, 'I-organization': 4,\n",
91 | " 'B-location': 5, 'I-location': 6, 'B-others': 7, 'I-others': 8\n",
92 | "}"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 14,
98 | "id": "354aae86-2e5f-4a82-821b-6baba9438532",
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "# Convert NER tags to spans for the training data\n",
103 | "gliner_data_conll = [ner_tags_to_spans(i, tag_to_id) for i in dataset['train']]"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 15,
109 | "id": "7c717148-7c98-4998-90fc-c244e11d7b67",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "# Load the pre-trained GLiNER model\n",
114 | "from gliner import GLiNER\n",
115 | "import torch\n",
116 | "\n",
117 | "model = GLiNER.from_pretrained(\"urchade/gliner_small\", load_tokenizer=True) #true if a model was trained from scratch with new code base\n",
118 | "\n",
119 | "if torch.cuda.is_available():\n",
120 | " device = \"cuda\"\n",
121 | "else:\n",
122 | " device = \"cpu\"\n",
123 | "\n",
124 | "model = model.to(device)"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 17,
130 | "id": "601c2e03-2fe7-481f-b769-c8c874bee9c6",
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "# Evaluate the model on the first 100 samples\n",
135 | "evaluation_results = model.evaluate(\n",
136 | " gliner_data_conll[:100], flat_ner=True, entity_types=[\"person\", \"organization\", \"location\", \"others\"]\n",
137 | ")"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 18,
143 | "id": "273d79be-2f0f-4191-bc8e-641854ffa540",
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "('P: 63.13%\\tR: 71.43%\\tF1: 67.02%\\n', 0.6702412868632708)\n"
151 | ]
152 | }
153 | ],
154 | "source": [
155 | "print(evaluation_results)"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "id": "e8c6190d-6a63-43c0-9010-27d141970877",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": []
165 | }
166 | ],
167 | "metadata": {
168 | "kernelspec": {
169 | "display_name": "Python 3 (ipykernel)",
170 | "language": "python",
171 | "name": "python3"
172 | },
173 | "language_info": {
174 | "codemirror_mode": {
175 | "name": "ipython",
176 | "version": 3
177 | },
178 | "file_extension": ".py",
179 | "mimetype": "text/x-python",
180 | "name": "python",
181 | "nbconvert_exporter": "python",
182 | "pygments_lexer": "ipython3",
183 | "version": "3.8.18"
184 | }
185 | },
186 | "nbformat": 4,
187 | "nbformat_minor": 5
188 | }
189 |
--------------------------------------------------------------------------------
/examples/gliner_spacy_demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "/Applications/anaconda3/envs/gliner-spacy/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13 | " from .autonotebook import tqdm as notebook_tqdm\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "import spacy\n",
19 | "from gliner_spacy.pipeline import GlinerSpacy"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stderr",
29 | "output_type": "stream",
30 | "text": [
31 | "/Applications/anaconda3/envs/gliner-spacy/lib/python3.10/site-packages/transformers/convert_slow_tokenizer.py:550: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
32 | " warnings.warn(\n"
33 | ]
34 | },
35 | {
36 | "data": {
37 | "text/plain": [
38 | ""
39 | ]
40 | },
41 | "execution_count": 2,
42 | "metadata": {},
43 | "output_type": "execute_result"
44 | }
45 | ],
46 | "source": [
47 | "nlp = spacy.load(\"en_core_web_sm\")\n",
48 | "nlp.add_pipe(\"gliner_spacy\")"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 3,
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "name": "stdout",
58 | "output_type": "stream",
59 | "text": [
60 | "ent 250\n"
61 | ]
62 | }
63 | ],
64 | "source": [
65 | "text = \"This is a text about Bill Gates and Microsoft.\"\n",
66 | "doc = nlp(text)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 4,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "from spacy import displacy"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 5,
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "data": {
85 | "text/html": [
86 | "This is a text about \n",
87 | "\n",
88 | " Bill Gates\n",
89 | " person\n",
90 | "\n",
91 | " and \n",
92 | "\n",
93 | " Microsoft\n",
94 | " organization\n",
95 | "\n",
96 | ".
"
97 | ],
98 | "text/plain": [
99 | ""
100 | ]
101 | },
102 | "metadata": {},
103 | "output_type": "display_data"
104 | }
105 | ],
106 | "source": [
107 | "displacy.render(doc, style=\"ent\")"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 6,
113 | "metadata": {},
114 | "outputs": [
115 | {
116 | "name": "stdout",
117 | "output_type": "stream",
118 | "text": [
119 | "Bill Gates person\n",
120 | "Microsoft organization\n"
121 | ]
122 | }
123 | ],
124 | "source": [
125 | "for ent in doc.ents:\n",
126 | " print(ent.text, ent.label_)"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {},
133 | "outputs": [],
134 | "source": []
135 | }
136 | ],
137 | "metadata": {
138 | "kernelspec": {
139 | "display_name": "gliner-spacy",
140 | "language": "python",
141 | "name": "python3"
142 | },
143 | "language_info": {
144 | "codemirror_mode": {
145 | "name": "ipython",
146 | "version": 3
147 | },
148 | "file_extension": ".py",
149 | "mimetype": "text/x-python",
150 | "name": "python",
151 | "nbconvert_exporter": "python",
152 | "pygments_lexer": "ipython3",
153 | "version": "3.10.13"
154 | }
155 | },
156 | "nbformat": 4,
157 | "nbformat_minor": 2
158 | }
159 |
--------------------------------------------------------------------------------
/examples/load_local_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "f2d5e279-1cbc-4291-985f-9e23af3a6ecc",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch\n",
11 | "from gliner import GLiNER"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "3baf9b40-daba-4638-b4a2-cc82c8a9ed99",
18 | "metadata": {
19 | "scrolled": true
20 | },
21 | "outputs": [],
22 | "source": [
23 | "# first load your model\n",
24 | "\n",
25 | "model = GLiNER.from_pretrained(\"gliner-community/gliner_medium-v2.5\")"
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "id": "5351bc8d-1182-4398-8be7-de61e6b24936",
31 | "metadata": {},
32 | "source": [
33 | "## Option 1"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "id": "c25fbf7b-b10c-4808-995e-2431f6c0356f",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "# save\n",
44 | "\n",
45 | "model.save_pretrained(\"gliner_Med\")"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "id": "f46d16c1-ab18-4300-b21f-e78c1da81df3",
52 | "metadata": {
53 | "scrolled": true
54 | },
55 | "outputs": [],
56 | "source": [
57 | "# load\n",
58 | "\n",
59 | "loaded_model = GLiNER.from_pretrained(\"gliner_Med\", load_tokenizer = True, local_files_only=True)"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "4041910a-ee0e-470d-b718-bc151a2666eb",
65 | "metadata": {},
66 | "source": [
67 | "## Option 2"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "id": "fe7e3b71-1d15-4739-9b41-fcc279046950",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "def save_model(current_model, path):\n",
78 | " config = current_model.config\n",
79 | " dict_save = {\"model_weights\": current_model.state_dict(), \"config\": config}\n",
80 | " torch.save(dict_save, path)\n",
81 | "\n",
82 | "\n",
83 | "def load_model(path, model_name=None):\n",
84 | " \n",
85 | " dict_load = torch.load(path, map_location=torch.device('cpu'))\n",
86 | " config = dict_load[\"config\"]\n",
87 | "\n",
88 | " print(f\"'{config.model_name}' should be available for local processing\")\n",
89 | "\n",
90 | " if model_name is not None:\n",
91 | " config.model_name = model_name\n",
92 | "\n",
93 | " loaded_model = GLiNER(config)\n",
94 | " loaded_model.load_state_dict(dict_load[\"model_weights\"])\n",
95 | " return loaded_model"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "e513be85-3178-449c-adec-1a609e38b580",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "# save the model weight\n",
106 | "\n",
107 | "save_model(model, \"model_weight.pt\")"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "id": "628eb872-ff3d-4c59-ac20-9b229797090f",
114 | "metadata": {
115 | "scrolled": true
116 | },
117 | "outputs": [],
118 | "source": [
119 | "# load model weight\n",
120 | "\n",
121 | "loaded_model = load_model(\"model_weight.pt\")\n",
122 | "print(\"success !!\")"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "id": "e057d7ec-1756-4c97-a1d9-e5fdcb60e20a",
128 | "metadata": {},
129 | "source": [
130 | "## Testing"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "id": "2827009e-bdb8-44b2-92b5-e6bdcc17f08e",
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "text = \"\"\"\n",
141 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n",
142 | "\"\"\"\n",
143 | "\n",
144 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n",
145 | "\n",
146 | "entities = loaded_model.predict_entities(text, labels, threshold=0.4)\n",
147 | "\n",
148 | "for entity in entities:\n",
149 | " print(entity[\"text\"], \"=>\", entity[\"label\"])"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "id": "839336f8-e5a0-471d-ace9-ef7f7e1c5c97",
156 | "metadata": {},
157 | "outputs": [],
158 | "source": []
159 | }
160 | ],
161 | "metadata": {
162 | "kernelspec": {
163 | "display_name": "Python 3 (ipykernel)",
164 | "language": "python",
165 | "name": "python3"
166 | },
167 | "language_info": {
168 | "codemirror_mode": {
169 | "name": "ipython",
170 | "version": 3
171 | },
172 | "file_extension": ".py",
173 | "mimetype": "text/x-python",
174 | "name": "python",
175 | "nbconvert_exporter": "python",
176 | "pygments_lexer": "ipython3",
177 | "version": "3.8.10"
178 | }
179 | },
180 | "nbformat": 4,
181 | "nbformat_minor": 5
182 | }
183 |
--------------------------------------------------------------------------------
/examples/quickstart.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "7037f111-e8eb-4270-8e69-b013d075b751",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "from gliner import GLiNER"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "id": "2c86b85f-ab71-4918-95c5-909a79ba7158",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "# available models: https://huggingface.co/urchade\n",
23 | "\n",
24 | "model = GLiNER.from_pretrained(\"urchade/gliner_medium\")\n",
25 | "model.eval()\n",
26 | "print(\"ok\")"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "id": "f823dbf3-2462-4a67-8c4b-9a45ec580c1d",
33 | "metadata": {
34 | "tags": []
35 | },
36 | "outputs": [],
37 | "source": [
38 | "text = \"\"\"\n",
39 | "Libretto by Marius Petipa, based on the 1822 novella ``Trilby, ou Le Lutin d'Argail`` by Charles Nodier, first presented by the Ballet of the Moscow Imperial Bolshoi Theatre on January 25/February 6 (Julian/Gregorian calendar dates), 1870, in Moscow with Polina Karpakova as Trilby and Ludiia Geiten as Miranda and restaged by Petipa for the Imperial Ballet at the Imperial Bolshoi Kamenny Theatre on January 17–29, 1871 in St. Petersburg with Adèle Grantzow as Trilby and Lev Ivanov as Count Leopold.\n",
40 | "\"\"\"\n",
41 | "\n",
42 | "labels = [\"person\", \"book\", \"location\", \"date\", \"actor\", \"character\"]\n",
43 | "\n",
44 | "entities = model.predict_entities(text, labels, threshold=0.4)\n",
45 | "\n",
46 | "for entity in entities:\n",
47 | " print(entity[\"text\"], \"=>\", entity[\"label\"])"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "5f5d1377-f073-485f-bd4f-35b5750ba020",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": []
57 | }
58 | ],
59 | "metadata": {
60 | "kernelspec": {
61 | "display_name": "Python 3 (ipykernel)",
62 | "language": "python",
63 | "name": "python3"
64 | },
65 | "language_info": {
66 | "codemirror_mode": {
67 | "name": "ipython",
68 | "version": 3
69 | },
70 | "file_extension": ".py",
71 | "mimetype": "text/x-python",
72 | "name": "python",
73 | "nbconvert_exporter": "python",
74 | "pygments_lexer": "ipython3",
75 | "version": "3.10.10"
76 | }
77 | },
78 | "nbformat": 4,
79 | "nbformat_minor": 5
80 | }
81 |
--------------------------------------------------------------------------------
/gliner/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.20"
2 |
3 | from .model import GLiNER
4 | from .config import GLiNERConfig
5 | # from .multitask import (GLiNERClassifier, GLiNERQuestionAnswerer, GLiNEROpenExtractor,
6 | # GLiNERRelationExtractor, GLiNERSummarizer, GLiNERSquadEvaluator,
7 | # GLiNERDocREDEvaluator)
8 |
9 | __all__ = ["GLiNER"]
10 |
--------------------------------------------------------------------------------
/gliner/config.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from transformers import PretrainedConfig
3 | from transformers.models.auto import CONFIG_MAPPING
4 |
5 | class GLiNERConfig(PretrainedConfig):
6 | model_type = "gliner"
7 | is_composition = True
8 | def __init__(self,
9 | model_name: str = "microsoft/deberta-v3-small",
10 | labels_encoder: str = None,
11 | name: str = "span level gliner",
12 | max_width: int = 12,
13 | hidden_size: int = 512,
14 | dropout: float = 0.4,
15 | fine_tune: bool = True,
16 | subtoken_pooling: str = "first",
17 | span_mode: str = "markerV0",
18 | post_fusion_schema: str = '', #l2l-l2t-t2t
19 | num_post_fusion_layers: int = 1,
20 | vocab_size: int = -1,
21 | max_neg_type_ratio: int = 1,
22 | max_types: int = 25,
23 | max_len: int = 384,
24 | words_splitter_type: str = "whitespace",
25 | has_rnn: bool = True,
26 | fuse_layers: bool = False,
27 | embed_ent_token: bool = True,
28 | class_token_index: int = -1,
29 | encoder_config: Optional[dict] = None,
30 | labels_encoder_config: Optional[dict] = None,
31 | ent_token = "<>",
32 | sep_token = "<>",
33 | _attn_implementation = None,
34 | **kwargs):
35 | super().__init__(**kwargs)
36 | if isinstance(encoder_config, dict):
37 | encoder_config["model_type"] = (encoder_config["model_type"]
38 | if "model_type" in encoder_config
39 | else "deberta-v2")
40 | encoder_config = CONFIG_MAPPING[encoder_config["model_type"]](**encoder_config)
41 | self.encoder_config = encoder_config
42 |
43 | if isinstance(labels_encoder_config, dict):
44 | labels_encoder_config["model_type"] = (labels_encoder_config["model_type"]
45 | if "model_type" in labels_encoder_config
46 | else "deberta-v2")
47 | labels_encoder_config = CONFIG_MAPPING[labels_encoder_config["model_type"]](**labels_encoder_config)
48 | self.labels_encoder_config = labels_encoder_config
49 |
50 | self.model_name = model_name
51 | self.labels_encoder = labels_encoder
52 | self.name = name
53 | self.max_width = max_width
54 | self.hidden_size = hidden_size
55 | self.dropout = dropout
56 | self.fine_tune = fine_tune
57 | self.subtoken_pooling = subtoken_pooling
58 | self.span_mode = span_mode
59 | self.post_fusion_schema = post_fusion_schema
60 | self.num_post_fusion_layers = num_post_fusion_layers
61 | self.vocab_size = vocab_size
62 | self.max_neg_type_ratio = max_neg_type_ratio
63 | self.max_types = max_types
64 | self.max_len = max_len
65 | self.words_splitter_type = words_splitter_type
66 | self.has_rnn = has_rnn
67 | self.fuse_layers = fuse_layers
68 | self.class_token_index = class_token_index
69 | self.embed_ent_token = embed_ent_token
70 | self.ent_token = ent_token
71 | self.sep_token = sep_token
72 | self._attn_implementation = _attn_implementation
73 |
74 | # Register the configuration
75 | from transformers import CONFIG_MAPPING
76 | CONFIG_MAPPING.update({"gliner": GLiNERConfig})
--------------------------------------------------------------------------------
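As a quick orientation, a config can be instantiated directly; the values below simply echo the defaults shown in `__init__` and are not tuned recommendations:

```python
from gliner import GLiNERConfig

config = GLiNERConfig(
    model_name="microsoft/deberta-v3-small",  # backbone transformer encoder
    max_width=12,          # largest span width scored in span mode
    span_mode="markerV0",
    max_types=25,          # max entity types passed in a single prompt
    max_len=384,           # max words per input chunk
)
print(config.model_type)                  # "gliner"
print(config.ent_token, config.sep_token)
```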
/gliner/data_processing/__init__.py:
--------------------------------------------------------------------------------
1 | from .processor import SpanProcessor, SpanBiEncoderProcessor, TokenProcessor, TokenBiEncoderProcessor
2 | from .collator import DataCollator
3 | from .tokenizer import WordsSplitter
4 | from .dataset import GLiNERDataset
--------------------------------------------------------------------------------
/gliner/data_processing/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.utils.rnn import pad_sequence
3 | import torch.nn.functional as F
4 | from .processor import SpanProcessor, TokenProcessor
5 | from .utils import pad_2d_tensor
6 |
7 | class DataCollator:
8 | def __init__(self, config, tokenizer=None, words_splitter=None, data_processor=None,
9 | return_tokens: bool = False,
10 | return_id_to_classes: bool = False,
11 | return_entities: bool = False,
12 | prepare_labels: bool = False,
13 | entity_types = None):
14 | self.config=config
15 | if data_processor is None:
16 | if config.span_mode == "token_level":
17 | self.data_processor = TokenProcessor(config, tokenizer, words_splitter)
18 | else:
19 | self.data_processor = SpanProcessor(config, tokenizer, words_splitter)
20 | else:
21 | self.data_processor = data_processor
22 | self.prepare_labels = prepare_labels
23 | self.return_tokens = return_tokens
24 | self.return_id_to_classes = return_id_to_classes
25 | self.return_entities = return_entities
26 | self.entity_types = entity_types
27 |
28 | def __call__(self, input_x):
29 | raw_batch = self.data_processor.collate_raw_batch(input_x, entity_types = self.entity_types)
30 |
31 | model_input = self.data_processor.collate_fn(raw_batch, prepare_labels=self.prepare_labels)
32 | model_input.update({"span_idx": raw_batch['span_idx'] if 'span_idx' in raw_batch else None,
33 | "span_mask": raw_batch["span_mask"] if 'span_mask' in raw_batch else None,
34 | "text_lengths": raw_batch['seq_length']})
35 | if self.return_tokens:
36 | model_input['tokens'] = raw_batch['tokens']
37 | if self.return_id_to_classes:
38 | model_input['id_to_classes'] = raw_batch['id_to_classes']
39 | if self.return_entities:
40 | model_input['entities'] = raw_batch['entities']
41 | model_input = {k:v for k, v in model_input.items() if v is not None}
42 | return model_input
43 |
44 | class DataCollatorWithPadding:
45 | def __init__(self, config=None):
46 | """
47 | Initialize the DataCollator with configs.
48 | """
49 | self.config = config
50 |
51 | def __call__(self, batch):
52 | if not batch:
53 | raise ValueError("Batch cannot be empty")
54 | batch = [item for item in batch if item is not None]
55 | # Extract all keys from the first item
56 | keys = batch[0].keys()
57 |
58 | # Create a dictionary to hold padded data
59 | padded_batch = {key: [] for key in keys}
60 |
61 | for key in keys:
62 | if key in {'tokens', 'id_to_classes', 'entities'}:
63 | padded_batch[key] = [item[key] for item in batch]
64 | continue
65 | # Collect data for the current key
66 | key_data = [item[key].squeeze(0) for item in batch]
67 |
68 | if isinstance(key_data[0], torch.Tensor):
69 | if key_data[0].dim() == 1:
70 | # For 1D tensors, use pad_sequence
71 | if key == 'span_label':
72 | span_label = pad_sequence(key_data, batch_first=True, padding_value=-1)
73 | span_mask = span_label != -1
74 | padded_batch[key] = span_mask
75 | else:
76 | padded_batch[key] = pad_sequence(key_data, batch_first=True)
77 | elif key_data[0].dim() == 2: # span_idx case
78 | padded_batch[key] = self._pad_2d_tensor(key_data)
79 | elif key == 'labels' and self.config.span_mode == 'token_level':
80 | padded_batch[key] = self.pad_token_labels(key_data)
81 | else:
82 | raise TypeError(f"Unsuported amount of dimension for key '{key}'")
83 | elif isinstance(key_data[0], list):
84 | # Pad list-like data
85 | max_length = max(len(seq) for seq in key_data)
86 | padded_batch[key] = torch.tensor(
87 | [seq + [0] * (max_length - len(seq)) for seq in key_data],
88 | dtype=torch.float32
89 | )  # no device attribute is defined on this collator, so tensors are kept on CPU
90 | elif isinstance(key_data[0], (int, float)):
91 | # Directly convert numeric data to tensors
92 | padded_batch[key] = torch.tensor(key_data, dtype=torch.float32)
93 | else:
94 | raise TypeError(f"Unsupported data type for key '{key}': {type(key_data[0])}")
95 | padded_batch = {k:v for k,v in padded_batch.items() if v is not None}
96 | return padded_batch
97 |
98 | def _pad_2d_tensor(self, key_data):
99 | padded_tensors = pad_2d_tensor(key_data)
100 | return padded_tensors
101 |
102 | def pad_token_labels(self, key_data):
103 | if not key_data:
104 | raise ValueError("The input list 'key_data' should not be empty.")
105 |
106 | # Determine the maximum sequence length and number of classes
107 | max_seq_len = max(tensor.shape[2] for tensor in key_data)
108 | max_num_classes = max(tensor.shape[3] for tensor in key_data)
109 |
110 | padded_tensors = []
111 |
112 | for tensor in key_data:
113 | current_seq_len = tensor.shape[2]
114 | current_num_classes = tensor.shape[3]
115 |
116 | seq_padding = max_seq_len - current_seq_len
117 | class_padding = max_num_classes - current_num_classes
118 |
119 | # Pad tensor to the maximum sequence length and number of classes
120 | padded_tensor = F.pad(tensor, (0, class_padding, 0, seq_padding), mode='constant', value=0)
121 | padded_tensors.append(padded_tensor)
122 |
123 | # Concatenate the tensors along the batch dimension
124 | concatenated_labels = torch.cat(padded_tensors, dim=1)
125 |
126 | return concatenated_labels
--------------------------------------------------------------------------------
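For intuition, `DataCollatorWithPadding` can be exercised on hand-made items; each item is a dict of tensors carrying a leading batch dimension of 1, as produced per example by `GLiNERDataset` later in the repo. The keys and shapes below are illustrative only:

```python
import torch
from gliner.data_processing.collator import DataCollatorWithPadding

items = [
    {"input_ids": torch.tensor([[101, 7, 8, 102]]), "text_lengths": torch.tensor([[4]])},
    {"input_ids": torch.tensor([[101, 9, 102]]),    "text_lengths": torch.tensor([[3]])},
]

collator = DataCollatorWithPadding()
batch = collator(items)
print(batch["input_ids"].shape)     # torch.Size([2, 4]) -- shorter sequence right-padded with 0
print(batch["text_lengths"].shape)  # torch.Size([2, 1])
```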
/gliner/data_processing/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from tqdm import tqdm
3 | from typing import Optional, List
4 | from torch.utils.data import Dataset
5 | from transformers import AutoTokenizer
6 |
7 | from . import TokenProcessor, SpanProcessor, WordsSplitter
8 | from ..config import GLiNERConfig
9 |
10 | class GLiNERDataset(Dataset):
11 | def __init__(self, examples,
12 | config: Optional[GLiNERConfig],
13 | tokenizer: Optional[AutoTokenizer] = None,
14 | words_splitter: Optional[WordsSplitter] = None,
15 | data_processor = None,
16 | entities = None,
17 | get_negatives:bool=True):
18 | self._data = examples
19 | self.config=config
20 | if data_processor is not None:
21 | self.data_processor = data_processor
22 | else:
23 | if config.span_mode == "token_level":
24 | self.data_processor = TokenProcessor(config, tokenizer, words_splitter, preprocess_text=True)
25 | else:
26 | self.data_processor = SpanProcessor(config, tokenizer, words_splitter, preprocess_text=True)
27 |
28 | self.max_neg_type_ratio = int(self.config.max_neg_type_ratio)
29 | self.get_negatives = get_negatives
30 | if not entities:
31 | self.all_entities = self._collect_all_entities()
32 | else:
33 | self.all_entities = entities
34 | self.max_negatives = min(50, len(self.all_entities))
35 |
36 | def _get_entities_from_example(self, example):
37 | entities = {ner[-1] for ner in example['ner']}
38 | return entities
39 |
40 | def _collect_all_entities(self):
41 | print("Collecting all entities...")
42 | all_entities = set()
43 | for example in tqdm(self._data):
44 | curr_entities = self._get_entities_from_example(example)
45 | all_entities.update(curr_entities)
46 | print('Total number of entity classes: ', len(all_entities))
47 | return list(all_entities)
48 |
49 | def _get_negatives(self):
50 | negatives = random.sample(self.all_entities, k=self.max_negatives)
51 | random.shuffle(negatives)
52 | return negatives
53 |
54 | def __len__(self):
55 | return len(self._data)
56 |
57 | def __getitem__(self, idx):
58 | try:
59 | example = self._data[idx]
60 | if self.get_negatives:
61 | curr_negatives = self._get_negatives()
62 | else:
63 | curr_negatives = None
64 |
65 | raw_batch = self.data_processor.collate_raw_batch([example], negatives = curr_negatives)
66 |
67 | model_input = self.data_processor.collate_fn(raw_batch, prepare_labels=True)
68 | if 'span_idx' in raw_batch:
69 | model_input['span_idx'] = raw_batch['span_idx']
70 | if 'span_mask' in raw_batch:
71 | model_input['span_mask'] = raw_batch['span_mask']
72 | if 'seq_length' in raw_batch:
73 | model_input['text_lengths'] = raw_batch['seq_length']
74 | return model_input
75 | except Exception as e:
76 | print(f"Skipping getting item due to error: {e}")
77 | return None
--------------------------------------------------------------------------------
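A sketch of how the dataset is typically wired up for training. The raw example format (`tokenized_text` plus `(start, end, label)` triples) matches the CoNLL conversion notebook earlier in the repo; the checkpoint name is a placeholder, and reusing the loaded model's processor via `model.data_processor` is an assumption:

```python
from torch.utils.data import DataLoader
from gliner import GLiNER
from gliner.data_processing import GLiNERDataset
from gliner.data_processing.collator import DataCollatorWithPadding

examples = [
    {"tokenized_text": ["Barack", "Obama", "visited", "Paris"],
     "ner": [(0, 1, "person"), (3, 3, "location")]},
]

model = GLiNER.from_pretrained("urchade/gliner_small")  # placeholder checkpoint
dataset = GLiNERDataset(examples, model.config, data_processor=model.data_processor)
loader = DataLoader(dataset, batch_size=1, shuffle=True,
                    collate_fn=DataCollatorWithPadding(model.config))
```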
/gliner/data_processing/tokenizer.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 |
4 | class TokenSplitterBase():
5 | def __init__(self):
6 | pass
7 |
8 | def __call__(self, text) -> (str, int, int):
9 | pass
10 |
11 |
12 | class WhitespaceTokenSplitter(TokenSplitterBase):
13 | def __init__(self):
14 | self.whitespace_pattern = re.compile(r'\w+(?:[-_]\w+)*|\S')
15 |
16 | def __call__(self, text):
17 | for match in self.whitespace_pattern.finditer(text):
18 | yield match.group(), match.start(), match.end()
19 |
20 |
21 | class SpaCyTokenSplitter(TokenSplitterBase):
22 | def __init__(self, lang=None):
23 | try:
24 | import spacy # noqa
25 | except ModuleNotFoundError as error:
26 | raise error.__class__(
27 | "Please install spacy with: `pip install spacy`"
28 | )
29 | if lang is None:
30 | lang = 'en' # Default to English if no language is specified
31 | self.nlp = spacy.blank(lang)
32 |
33 | def __call__(self, text):
34 | doc = self.nlp(text)
35 | for token in doc:
36 | yield token.text, token.idx, token.idx + len(token.text)
37 |
38 |
39 | class MecabKoTokenSplitter(TokenSplitterBase):
40 | def __init__(self):
41 | try:
42 | import mecab # noqa
43 | except ModuleNotFoundError as error:
44 | raise error.__class__(
45 | "Please install python-mecab-ko with: `pip install python-mecab-ko`"
46 | )
47 | self.tagger = mecab.MeCab()
48 |
49 | def __call__(self, text):
50 | tokens = self.tagger.morphs(text)
51 |
52 | last_idx = 0
53 | for morph in tokens:
54 | start_idx = text.find(morph, last_idx)
55 | end_idx = start_idx + len(morph)
56 | last_idx = end_idx
57 | yield morph, start_idx, end_idx
58 |
59 | class JiebaTokenSplitter(TokenSplitterBase):
60 | def __init__(self):
61 | try:
62 | import jieba # noqa
63 | except ModuleNotFoundError as error:
64 | raise error.__class__(
65 | "Please install jieba with: `pip install jieba`"
66 | )
67 | self.tagger = jieba
68 |
69 | def __call__(self, text):
70 | tokens = self.tagger.cut(text)
71 | last_idx = 0
72 | for token in tokens:
73 | start_idx = text.find(token, last_idx)
74 | end_idx = start_idx + len(token)
75 | last_idx = end_idx
76 | yield token, start_idx, end_idx
77 |
78 | class HanLPTokenSplitter(TokenSplitterBase):
79 | def __init__(self, model_name="FINE_ELECTRA_SMALL_ZH"):
80 | try:
81 | import hanlp # noqa
82 | import hanlp.pretrained
83 | except ModuleNotFoundError as error:
84 | raise error.__class__(
85 | "Please install hanlp with: `pip install hanlp`"
86 | )
87 |
88 | models = hanlp.pretrained.tok.ALL
89 | if model_name not in models:
90 | raise ValueError(f"HanLP: {model_name} is not available, choose between {models.keys()}")
91 | url = models[model_name]
92 | self.tagger = hanlp.load(url)
93 |
94 | def __call__(self, text):
95 | tokens = self.tagger(text)
96 | last_idx = 0
97 | for token in tokens:
98 | start_idx = text.find(token, last_idx)
99 | end_idx = start_idx + len(token)
100 | last_idx = end_idx
101 | yield token, start_idx, end_idx
102 |
103 | class WordsSplitter(TokenSplitterBase):
104 | def __init__(self, splitter_type='whitespace'):
105 | if splitter_type=='whitespace':
106 | self.splitter = WhitespaceTokenSplitter()
107 | elif splitter_type == 'spacy':
108 | self.splitter = SpaCyTokenSplitter()
109 | elif splitter_type == 'mecab':
110 | self.splitter = MecabKoTokenSplitter()
111 | elif splitter_type == 'jieba':
112 | self.splitter = JiebaTokenSplitter()
113 | elif splitter_type == 'hanlp':
114 | self.splitter = HanLPTokenSplitter()
115 | else:
116 | raise ValueError(f"{splitter_type} is not implemented, choose between 'whitespace', 'spacy', 'jieba', 'hanlp' and 'mecab'")
117 |
118 | def __call__(self, text):
119 | for token in self.splitter(text):
120 | yield token
--------------------------------------------------------------------------------
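The splitters all yield `(token, start_char, end_char)` triples. A quick check with the default whitespace splitter:

```python
from gliner.data_processing import WordsSplitter

splitter = WordsSplitter("whitespace")
# Hyphenated words such as "state-of-the-art" are kept as a single token by the regex above.
for token, start, end in splitter("GLiNER splits state-of-the-art text by words."):
    print(token, start, end)
```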
/gliner/data_processing/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def pad_2d_tensor(key_data):
4 | """
5 | Pad a list of 2D tensors to have the same size along both dimensions.
6 |
7 | :param key_data: List of 2D tensors to pad.
8 | :return: Tensor of padded tensors stacked along a new batch dimension.
9 | """
10 | if not key_data:
11 | raise ValueError("The input list 'key_data' should not be empty.")
12 |
13 | # Determine the maximum size along both dimensions
14 | max_rows = max(tensor.shape[0] for tensor in key_data)
15 | max_cols = max(tensor.shape[1] for tensor in key_data)
16 |
17 | tensors = []
18 |
19 | for tensor in key_data:
20 | rows, cols = tensor.shape
21 | row_padding = max_rows - rows
22 | col_padding = max_cols - cols
23 |
24 | # Pad the tensor along both dimensions
25 | padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding),
26 | mode='constant', value=0)
27 | tensors.append(padded_tensor)
28 |
29 | # Stack the tensors into a single tensor along a new batch dimension
30 | padded_tensors = torch.stack(tensors)
31 |
32 | return padded_tensors
--------------------------------------------------------------------------------
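A toy check of the padding helper: two 2-D tensors of different shapes are zero-padded to a common `(max_rows, max_cols)` and stacked along a new batch dimension.

```python
import torch
from gliner.data_processing.utils import pad_2d_tensor

a = torch.ones(2, 3)
b = torch.ones(4, 2)
print(pad_2d_tensor([a, b]).shape)  # torch.Size([2, 4, 3])
```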
/gliner/decoding/__init__.py:
--------------------------------------------------------------------------------
1 | from .decoder import SpanDecoder, TokenDecoder
--------------------------------------------------------------------------------
/gliner/decoding/decoder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from abc import ABC, abstractmethod
3 | from functools import partial
4 | import torch
5 |
6 | from .utils import has_overlapping, has_overlapping_nested
7 |
8 |
9 | class BaseDecoder(ABC):
10 | def __init__(self, config):
11 | self.config = config
12 |
13 | @abstractmethod
14 | def decode(self, *args, **kwargs):
15 | pass
16 |
17 | def greedy_search(self, spans, flat_ner=True, multi_label=False):
18 | if flat_ner:
19 | has_ov = partial(has_overlapping, multi_label=multi_label)
20 | else:
21 | has_ov = partial(has_overlapping_nested, multi_label=multi_label)
22 |
23 | new_list = []
24 | span_prob = sorted(spans, key=lambda x: -x[-1])
25 |
26 | for i in range(len(spans)):
27 | b = span_prob[i]
28 | flag = False
29 | for new in new_list:
30 | if has_ov(b[:-1], new):
31 | flag = True
32 | break
33 | if not flag:
34 | new_list.append(b)
35 |
36 | new_list = sorted(new_list, key=lambda x: x[0])
37 | return new_list
38 |
39 |
40 | class SpanDecoder(BaseDecoder):
41 | def decode(self, tokens, id_to_classes, model_output, flat_ner=False, threshold=0.5, multi_label=False):
42 | probs = torch.sigmoid(model_output)
43 | spans = []
44 | for i, _ in enumerate(tokens):
45 | probs_i = probs[i]
46 |
47 | # Support for id_to_classes being a list of dictionaries
48 | id_to_class_i = id_to_classes[i] if isinstance(id_to_classes, list) else id_to_classes
49 |
50 | wh_i = [i.tolist() for i in torch.where(probs_i > threshold)]
51 | span_i = []
52 | for s, k, c in zip(*wh_i):
53 | if s + k < len(tokens[i]):
54 | span_i.append((s, s + k, id_to_class_i[c + 1], probs_i[s, k, c].item()))
55 |
56 | span_i = self.greedy_search(span_i, flat_ner, multi_label=multi_label)
57 | spans.append(span_i)
58 | return spans
59 |
60 |
61 | class TokenDecoder(BaseDecoder):
62 | def get_indices_above_threshold(self, scores, threshold):
63 | scores = torch.sigmoid(scores)
64 | return [k.tolist() for k in torch.where(scores > threshold)]
65 |
66 | def calculate_span_score(self, start_idx, end_idx, scores_inside_i, start_i, end_i, id_to_classes, threshold):
67 | span_i = []
68 | for st, cls_st in zip(*start_idx):
69 | for ed, cls_ed in zip(*end_idx):
70 | if ed >= st and cls_st == cls_ed:
71 | ins = scores_inside_i[st:ed + 1, cls_st]
72 | if (ins < threshold).any():
73 | continue
74 | # Get the start and end scores for this span
75 | start_score = start_i[st, cls_st]
76 | end_score = end_i[ed, cls_st]
77 | # Concatenate the inside scores with start and end scores
78 | combined = torch.cat([ins, start_score.unsqueeze(0), end_score.unsqueeze(0)])
79 | # The span score is the minimum value among these scores
80 | spn_score = combined.min().item()
81 | span_i.append((st, ed, id_to_classes[cls_st + 1], spn_score))
82 | return span_i
83 |
84 | def decode(self, tokens, id_to_classes, model_output, flat_ner=False, threshold=0.5, multi_label=False):
85 | scores_start, scores_end, scores_inside = model_output
86 | spans = []
87 | for i, _ in enumerate(tokens):
88 | id_to_class_i = id_to_classes[i] if isinstance(id_to_classes, list) else id_to_classes
89 | span_scores = self.calculate_span_score(
90 | self.get_indices_above_threshold(scores_start[i], threshold),
91 | self.get_indices_above_threshold(scores_end[i], threshold),
92 | torch.sigmoid(scores_inside[i]),
93 | torch.sigmoid(scores_start[i]),
94 | torch.sigmoid(scores_end[i]),
95 | id_to_class_i,
96 | threshold
97 | )
98 | span_i = self.greedy_search(span_scores, flat_ner, multi_label)
99 | spans.append(span_i)
100 | return spans
--------------------------------------------------------------------------------
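Both decoders share the same greedy, score-sorted span selection. A toy illustration with hand-made `(start, end, label, score)` tuples; `config` is unused by `greedy_search`, so `None` is passed here purely for the sketch:

```python
from gliner.decoding import SpanDecoder

decoder = SpanDecoder(config=None)
spans = [(0, 2, "person", 0.9), (1, 3, "person", 0.8), (5, 6, "location", 0.7)]

# flat_ner=True keeps only non-overlapping spans, preferring higher scores.
print(decoder.greedy_search(spans, flat_ner=True))
# [(0, 2, 'person', 0.9), (5, 6, 'location', 0.7)]
```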
/gliner/decoding/utils.py:
--------------------------------------------------------------------------------
1 | def is_nested(idx1, idx2):
2 | # Return True if idx2 is nested inside idx1 or vice versa
3 | return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
4 |
5 | def has_overlapping(idx1, idx2, multi_label=False):
6 | # Check for any overlap between two spans
7 | if idx1[:2] == idx2[:2]: # Exact same boundaries can be considered as overlapping
8 | return not multi_label
9 | if idx1[0] > idx2[1] or idx2[0] > idx1[1]:
10 | return False
11 | return True
12 |
13 |
14 | def has_overlapping_nested(idx1, idx2, multi_label=False):
15 | # Return True if idx1 and idx2 overlap, but neither is nested inside the other
16 | if idx1[:2] == idx2[:2]: # Exact same boundaries, not considering labels here
17 | return not multi_label
18 | if (idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2):
19 | return False
20 | return True
21 |
--------------------------------------------------------------------------------
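The two checks differ only in how containment is treated: a span fully nested inside another is rejected by the flat check but allowed by the nested variant.

```python
from gliner.decoding.utils import has_overlapping, has_overlapping_nested

# (1, 2) is fully contained in (0, 4)
print(has_overlapping((0, 4), (1, 2)))         # True  -> one of the two is pruned in flat NER
print(has_overlapping_nested((0, 4), (1, 2)))  # False -> both can be kept when nesting is allowed
```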
/gliner/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluator import Evaluator
2 | from .evaluate import get_for_all_path, get_for_one_path
--------------------------------------------------------------------------------
/gliner/evaluation/evaluate.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 |
5 | import numpy as np
6 | import argparse
7 | import torch
8 | from tqdm import tqdm
9 | import random
10 |
11 | def open_content(path):
12 | paths = glob.glob(os.path.join(path, "*.json"))
13 | train, dev, test, labels = None, None, None, None
14 | for p in paths:
15 | if "train" in p:
16 | with open(p, "r") as f:
17 | train = json.load(f)
18 | elif "dev" in p:
19 | with open(p, "r") as f:
20 | dev = json.load(f)
21 | elif "test" in p:
22 | with open(p, "r") as f:
23 | test = json.load(f)
24 | elif "labels" in p:
25 | with open(p, "r") as f:
26 | labels = json.load(f)
27 | return train, dev, test, labels
28 |
29 |
30 | def process(data):
31 | words = data['sentence'].split()
32 | entities = [] # List of entities (start, end, type)
33 |
34 | for entity in data['entities']:
35 | start_char, end_char = entity['pos']
36 |
37 | # Initialize variables to keep track of word positions
38 | start_word = None
39 | end_word = None
40 |
41 | # Iterate through words and find the word positions
42 | char_count = 0
43 | for i, word in enumerate(words):
44 | word_length = len(word)
45 | if char_count == start_char:
46 | start_word = i
47 | if char_count + word_length == end_char:
48 | end_word = i
49 | break
50 | char_count += word_length + 1 # Add 1 for the space
51 |
52 | # Append the word positions to the list
53 | entities.append((start_word, end_word, entity['type'].lower()))
54 |
55 | # Create a list of word positions for each entity
56 | sample = {
57 | "tokenized_text": words,
58 | "ner": entities
59 | }
60 |
61 | return sample
62 |
63 |
64 | # create dataset
65 | def create_dataset(path):
66 | train, dev, test, labels = open_content(path)
67 | train_dataset = []
68 | dev_dataset = []
69 | test_dataset = []
70 | for data in train:
71 | train_dataset.append(process(data))
72 | for data in dev:
73 | dev_dataset.append(process(data))
74 | for data in test:
75 | test_dataset.append(process(data))
76 | labels = [label.lower() for label in labels]
77 | return train_dataset, dev_dataset, test_dataset, labels
78 |
79 |
80 | @torch.no_grad()
81 | def get_for_one_path(path, model):
82 | # load the dataset
83 | _, _, test_dataset, entity_types = create_dataset(path)
84 |
85 | data_name = path.split("/")[-1] # get the name of the dataset
86 |
87 | # check if the dataset is flat_ner
88 | flat_ner = True
89 | if any([i in data_name for i in ["ACE", "GENIA", "Corpus"]]):
90 | flat_ner = False
91 |
92 | # evaluate the model
93 | results, f1 = model.evaluate(test_dataset, flat_ner=flat_ner, threshold=0.5, batch_size=12,
94 | entity_types=entity_types)
95 | return data_name, results, f1
96 |
97 |
98 | def get_for_all_path(model, steps, log_dir, data_paths):
99 | all_paths = glob.glob(f"{data_paths}/*")
100 |
101 | all_paths = sorted(all_paths)
102 |
103 | # move the model to the device
104 | device = next(model.parameters()).device
105 | model.to(device)
106 | # set the model to eval mode
107 | model.eval()
108 |
109 | # log the results
110 | save_path = os.path.join(log_dir, "results.txt")
111 |
112 | with open(save_path, "a") as f:
113 | f.write("##############################################\n")
114 | # write step
115 | f.write("step: " + str(steps) + "\n")
116 |
117 | zero_shot_benc = ["mit-movie", "mit-restaurant", "CrossNER_AI", "CrossNER_literature", "CrossNER_music",
118 | "CrossNER_politics", "CrossNER_science"]
119 |
120 | zero_shot_benc_results = {}
121 | all_results = {} # without crossNER
122 |
123 | for p in tqdm(all_paths):
124 | if "sample_" not in p:
125 | data_name, results, f1 = get_for_one_path(p, model)
126 | # write to file
127 | with open(save_path, "a") as f:
128 | f.write(data_name + "\n")
129 | f.write(str(results) + "\n")
130 |
131 | if data_name in zero_shot_benc:
132 | zero_shot_benc_results[data_name] = f1
133 | else:
134 | all_results[data_name] = f1
135 |
136 | avg_all = sum(all_results.values()) / len(all_results)
137 | avg_zs = sum(zero_shot_benc_results.values()) / len(zero_shot_benc_results)
138 |
139 | save_path_table = os.path.join(log_dir, "tables.txt")
140 |
141 | # results for all datasets except crossNER
142 | table_bench_all = ""
143 | for k, v in all_results.items():
144 | table_bench_all += f"{k:20}: {v:.1%}\n"
145 | # (use a 20-character field for the average row as well, i.e. :20)
146 | table_bench_all += f"{'Average':20}: {avg_all:.1%}"
147 |
148 | # results for zero-shot benchmark
149 | table_bench_zeroshot = ""
150 | for k, v in zero_shot_benc_results.items():
151 | table_bench_zeroshot += f"{k:20}: {v:.1%}\n"
152 | table_bench_zeroshot += f"{'Average':20}: {avg_zs:.1%}"
153 |
154 | # write to file
155 | with open(save_path_table, "a") as f:
156 | f.write("##############################################\n")
157 | f.write("step: " + str(steps) + "\n")
158 | f.write("Table for all datasets except crossNER\n")
159 | f.write(table_bench_all + "\n\n")
160 | f.write("Table for zero-shot benchmark\n")
161 | f.write(table_bench_zeroshot + "\n")
162 | f.write("##############################################\n\n")
163 |
164 |
165 | def sample_train_data(data_paths, sample_size=10000):
166 | all_paths = glob.glob(f"{data_paths}/*")
167 |
168 | all_paths = sorted(all_paths)
169 |
170 | # to exclude the zero-shot benchmark datasets
171 | zero_shot_benc = ["CrossNER_AI", "CrossNER_literature", "CrossNER_music",
172 | "CrossNER_politics", "CrossNER_science", "ACE 2004"]
173 |
174 | new_train = []
175 | # take 10k samples from each dataset
176 | for p in tqdm(all_paths):
177 | if any([i in p for i in zero_shot_benc]):
178 | continue
179 | train, dev, test, labels = create_dataset(p)
180 |
181 | # add label key to the train data
182 | for i in range(len(train)):
183 | train[i]["label"] = labels
184 |
185 | random.shuffle(train)
186 | train = train[:sample_size]
187 | new_train.extend(train)
188 |
189 | return new_train
190 |
--------------------------------------------------------------------------------
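For reference, `process()` expects character-offset annotations over a space-joined sentence and returns the word-level format used throughout the repo. A toy record (the sentence and offsets are invented):

```python
from gliner.evaluation.evaluate import process

record = {
    "sentence": "Steve Jobs founded Apple",
    "entities": [
        {"pos": [0, 10], "type": "Person"},         # "Steve Jobs"
        {"pos": [19, 24], "type": "Organization"},  # "Apple"
    ],
}
print(process(record))
# {'tokenized_text': ['Steve', 'Jobs', 'founded', 'Apple'],
#  'ner': [(0, 1, 'person'), (3, 3, 'organization')]}
```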
/gliner/evaluation/evaluator.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from collections import defaultdict
3 | from typing import Union, List, Literal
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | class UndefinedMetricWarning(UserWarning):
10 | pass
11 |
12 |
13 | def _prf_divide(
14 | numerator: np.ndarray,
15 | denominator: np.ndarray,
16 | metric: Literal["precision", "recall", "f-score"],
17 | modifier: str,
18 | average: str,
19 | warn_for: List[str],
20 | zero_division: Union[str, int] = "warn",
21 | ) -> np.ndarray:
22 | """Performs division and handles divide-by-zero with warnings."""
23 | with np.errstate(divide="ignore", invalid="ignore"):
24 | result = np.true_divide(numerator, denominator)
25 | result[denominator == 0] = 0.0 if zero_division in ["warn", 0] else 1.0
26 |
27 | if denominator == 0 and zero_division == "warn" and metric in warn_for:
28 | msg_start = f"{metric.title()}"
29 | if "f-score" in warn_for:
30 | msg_start += " and F-score" if metric in warn_for else "F-score"
31 | msg_start += " are" if "f-score" in warn_for else " is"
32 | _warn_prf(
33 | average=average,
34 | modifier=modifier,
35 | msg_start=msg_start,
36 | result_size=len(result),
37 | )
38 |
39 | return result
40 |
41 |
42 | def _warn_prf(average: str, modifier: str, msg_start: str, result_size: int):
43 | axis0, axis1 = ("label", "sample") if average == "samples" else ("sample", "label")
44 | if result_size == 1:
45 | msg = f"{msg_start} ill-defined and being set to 0.0 due to no {modifier} {axis0}." # noqa: E501
46 | else:
47 | msg = f"{msg_start} ill-defined and being set to 0.0 in {axis1}s with no {modifier} {axis0}s." # noqa: E501
48 | msg += " Use `zero_division` parameter to control this behavior."
49 | warnings.warn(msg, UndefinedMetricWarning, stacklevel=3)
50 |
51 |
52 | def extract_tp_actual_correct(y_true, y_pred):
53 | entities_true = defaultdict(set)
54 | entities_pred = defaultdict(set)
55 |
56 | for type_name, (start, end), idx in y_true:
57 | entities_true[type_name].add((start, end, idx))
58 | for type_name, (start, end), idx in y_pred:
59 | entities_pred[type_name].add((start, end, idx))
60 |
61 | target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
62 |
63 | tp_sum = np.array([], dtype=np.int32)
64 | pred_sum = np.array([], dtype=np.int32)
65 | true_sum = np.array([], dtype=np.int32)
66 | for type_name in target_names:
67 | entities_true_type = entities_true.get(type_name, set())
68 | entities_pred_type = entities_pred.get(type_name, set())
69 | tp_sum = np.append(tp_sum, len(entities_true_type & entities_pred_type))
70 | pred_sum = np.append(pred_sum, len(entities_pred_type))
71 | true_sum = np.append(true_sum, len(entities_true_type))
72 |
73 | return pred_sum, tp_sum, true_sum, target_names
74 |
75 |
76 | def flatten_for_eval(y_true, y_pred):
77 | all_true = []
78 | all_pred = []
79 |
80 | for i, (true, pred) in enumerate(zip(y_true, y_pred)):
81 | all_true.extend([t + [i] for t in true])
82 | all_pred.extend([p + [i] for p in pred])
83 |
84 | return all_true, all_pred
85 |
86 |
87 | def compute_prf(y_true, y_pred, average="micro"):
88 | y_true, y_pred = flatten_for_eval(y_true, y_pred)
89 |
90 | pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)
91 |
92 | if average == "micro":
93 | tp_sum = np.array([tp_sum.sum()])
94 | pred_sum = np.array([pred_sum.sum()])
95 | true_sum = np.array([true_sum.sum()])
96 |
97 | precision = _prf_divide(
98 | numerator=tp_sum,
99 | denominator=pred_sum,
100 | metric="precision",
101 | modifier="predicted",
102 | average=average,
103 | warn_for=["precision", "recall", "f-score"],
104 | zero_division="warn",
105 | )
106 |
107 | recall = _prf_divide(
108 | numerator=tp_sum,
109 | denominator=true_sum,
110 | metric="recall",
111 | modifier="true",
112 | average=average,
113 | warn_for=["precision", "recall", "f-score"],
114 | zero_division="warn",
115 | )
116 |
117 | denominator = precision + recall
118 | denominator[denominator == 0.0] = 1
119 | f_score = 2 * (precision * recall) / denominator
120 |
121 | return {"precision": precision[0], "recall": recall[0], "f_score": f_score[0]}
122 |
123 |
124 | class Evaluator:
125 | def __init__(self, all_true, all_outs):
126 | self.all_true = all_true
127 | self.all_outs = all_outs
128 |
129 | def get_entities_fr(self, ents):
130 | all_ents = []
131 | for s, e, lab in ents:
132 | all_ents.append([lab, (s, e)])
133 | return all_ents
134 |
135 | def get_entities_pr(self, ents):
136 | all_ents = []
137 | for s, e, lab, _ in ents:
138 | all_ents.append([lab, (s, e)])
139 | return all_ents
140 |
141 | def transform_data(self):
142 | all_true_ent = []
143 | all_outs_ent = []
144 | for i, j in zip(self.all_true, self.all_outs):
145 | e = self.get_entities_fr(i)
146 | all_true_ent.append(e)
147 | e = self.get_entities_pr(j)
148 | all_outs_ent.append(e)
149 | return all_true_ent, all_outs_ent
150 |
151 | @torch.no_grad()
152 | def evaluate(self):
153 | all_true_typed, all_outs_typed = self.transform_data()
154 | precision, recall, f1 = compute_prf(all_true_typed, all_outs_typed).values()
155 | output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
156 | return output_str, f1
157 |
158 |
159 | def is_nested(idx1, idx2):
160 | # Return True if idx2 is nested inside idx1 or vice versa
161 | return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (
162 | idx2[0] <= idx1[0] and idx2[1] >= idx1[1]
163 | )
164 |
165 |
166 | def has_overlapping(idx1, idx2, multi_label=False):
167 | # Check for any overlap between two spans
168 | if idx1[:2] == idx2[:2]: # Exact same boundaries can be considered as overlapping
169 | return not multi_label
170 | if idx1[0] > idx2[1] or idx2[0] > idx1[1]:
171 | return False
172 | return True
173 |
174 |
175 | def has_overlapping_nested(idx1, idx2, multi_label=False):
176 | # Return True if idx1 and idx2 overlap, but neither is nested inside the other
177 | if idx1[:2] == idx2[:2]: # Exact same boundaries, not considering labels here
178 | return not multi_label
179 | if (idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2):
180 | return False
181 | return True
182 |
183 |
184 | from functools import partial
185 |
186 |
187 | def greedy_search(spans, flat_ner=True, multi_label=False): # start, end, class, score
188 | if flat_ner:
189 | has_ov = partial(has_overlapping, multi_label=multi_label)
190 | else:
191 | has_ov = partial(has_overlapping_nested, multi_label=multi_label)
192 |
193 | new_list = []
194 | span_prob = sorted(spans, key=lambda x: -x[-1])
195 |
196 | for i in range(len(spans)):
197 | b = span_prob[i]
198 | flag = False
199 | for new in new_list:
200 | if has_ov(b[:-1], new):
201 | flag = True
202 | break
203 | if not flag:
204 | new_list.append(b)
205 |
206 | new_list = sorted(new_list, key=lambda x: x[0])
207 | return new_list
208 |
--------------------------------------------------------------------------------
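A minimal usage sketch for the evaluator above, assuming the gliner package in this repository is importable; the spans and labels are toy values.

from gliner.evaluation.evaluator import Evaluator, greedy_search

# Candidate spans as (start, end, label, score); with flat_ner=True the
# greedy search keeps the highest-scoring non-overlapping spans.
spans = [(0, 2, "person", 0.95), (1, 3, "person", 0.60), (5, 6, "city", 0.80)]
print(greedy_search(spans, flat_ner=True))
# [(0, 2, 'person', 0.95), (5, 6, 'city', 0.80)]

# Gold entities are (start, end, label) triples; predictions carry an extra score.
all_true = [[(0, 2, "person"), (5, 6, "city")]]
all_outs = [[(0, 2, "person", 0.95), (5, 6, "location", 0.80)]]
output_str, f1 = Evaluator(all_true, all_outs).evaluate()
print(output_str)  # micro-averaged P / R / F1 over all samples
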
/gliner/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/gliner/modeling/__init__.py
--------------------------------------------------------------------------------
/gliner/modeling/encoder.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 |
4 | import torch
5 | from torch import nn
6 | from transformers import AutoModel, AutoConfig
7 |
8 | from .layers import LayersFuser
9 | from ..utils import is_module_available, MissedPackageException
10 | from typing import Optional, Union
11 |
12 | IS_LLM2VEC = is_module_available('llm2vec')
13 | IS_PEFT = is_module_available('peft')
14 | IS_TURBOT5 = is_module_available('turbot5')
15 | IS_FLASHDEBERTA = is_module_available('flashdeberta')
16 |
17 | if IS_LLM2VEC:
18 | from llm2vec.models import MistralBiModel, LlamaBiModel, GemmaBiModel, Qwen2BiModel
19 | DECODER_MODEL_MAPPING = {
20 | "MistralConfig": MistralBiModel,
21 | "LlamaConfig": LlamaBiModel,
22 | "GemmaConfig": GemmaBiModel,
23 | "Qwen2Config": Qwen2BiModel
24 | }
25 | else:
26 | DECODER_MODEL_MAPPING = {}
27 |
28 | if IS_TURBOT5:
29 | from turbot5.model.modeling import T5EncoderModel
30 | else:
31 | from transformers import T5EncoderModel
32 |
33 | if IS_FLASHDEBERTA:
34 | from flashdeberta import FlashDebertaV2Model as DebertaV2Model
35 | else:
36 | from transformers import DebertaV2Model
37 |
38 | if IS_PEFT:
39 | from peft import LoraConfig, get_peft_model
40 |
41 | class Transformer(nn.Module):
42 | def __init__(
43 | self,
44 | model_name,
45 | config,
46 | from_pretrained=False,
47 | labels_encoder = False,
48 | cache_dir:Optional[Union[str, Path]] = None
49 | ):
50 | super().__init__()
51 | if labels_encoder:
52 | encoder_config = config.labels_encoder_config
53 | else:
54 | encoder_config = config.encoder_config
55 | if encoder_config is None:
56 | encoder_config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
57 | if config.vocab_size!=-1:
58 | encoder_config.vocab_size = config.vocab_size
59 |
60 | if config._attn_implementation is not None and not labels_encoder:
61 | encoder_config._attn_implementation = config._attn_implementation
62 |
63 | config_name = encoder_config.__class__.__name__
64 |
65 | kwargs = {}
66 | if config_name in DECODER_MODEL_MAPPING:
67 | if not IS_LLM2VEC:
68 | raise MissedPackageException(f"The llm2vec package must be installed to use this decoder model: {config_name}")
69 | else:
70 | print('Loading decoder model using LLM2Vec...')
71 | ModelClass = DECODER_MODEL_MAPPING[config_name]
72 | custom = True
73 | elif config_name in {'T5Config', 'MT5Config'}:
74 | custom = True
75 | ModelClass = T5EncoderModel
76 | if IS_TURBOT5:
77 | kwargs = {"attention_type": 'flash'}
78 | elif config_name in {'DebertaV2Config'}:
79 | custom = True
80 | ModelClass = DebertaV2Model
81 | else:
82 | custom = False
83 | ModelClass = AutoModel
84 |
85 | if from_pretrained:
86 | self.model = ModelClass.from_pretrained(model_name, trust_remote_code=True)
87 | else:
88 | if not custom:
89 | self.model = ModelClass.from_config(encoder_config, trust_remote_code=True)
90 | else:
91 | self.model = ModelClass(encoder_config, **kwargs)
92 |
93 | adapter_config_file = Path(model_name) / "adapter_config.json"
94 |
95 | if adapter_config_file.exists():
96 | if not IS_PEFT:
97 |                 warnings.warn("Adapter configs were detected; install the `peft` package to apply them.")
98 | else:
99 | adapter_config = LoraConfig.from_pretrained(model_name)
100 | self.model = get_peft_model(self.model, adapter_config)
101 |
102 | if config.fuse_layers:
103 | self.layers_fuser = LayersFuser(encoder_config.num_hidden_layers,
104 | encoder_config.hidden_size)
105 |
106 | if labels_encoder:
107 | config.labels_encoder_config = encoder_config
108 | else:
109 | config.encoder_config = encoder_config
110 |
111 | self.config = config
112 |
113 | def forward(self, *args, **kwargs):
114 | if self.config.fuse_layers:
115 | output_hidden_states = True
116 | else:
117 | output_hidden_states = False
118 | output = self.model(*args, output_hidden_states = output_hidden_states,
119 | return_dict = True, **kwargs)
120 | if self.config.fuse_layers:
121 | encoder_layer = self.layers_fuser(output.hidden_states)
122 | else:
123 | encoder_layer = output[0]
124 |
125 | return encoder_layer
126 |
127 | class Encoder(nn.Module):
128 | def __init__(self, config, from_pretrained: bool = False, cache_dir: Optional[Union[str, Path]]= None):
129 | super().__init__()
130 |
131 | self.bert_layer = Transformer( #transformer_model
132 | config.model_name, config, from_pretrained, cache_dir = cache_dir
133 | )
134 |
135 | bert_hidden_size = self.bert_layer.model.config.hidden_size
136 |
137 | if config.hidden_size != bert_hidden_size:
138 | self.projection = nn.Linear(bert_hidden_size, config.hidden_size)
139 |
140 | def resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
141 | return self.bert_layer.model.resize_token_embeddings(new_num_tokens,
142 | pad_to_multiple_of)
143 |
144 | def get_input_embeddings(self):
145 | return self.bert_layer.model.get_input_embeddings()
146 |
147 | def encode_text(self, input_ids, attention_mask, *args, **kwargs):
148 | token_embeddings = self.bert_layer(input_ids, attention_mask, *args, **kwargs)
149 | if hasattr(self, "projection"):
150 | token_embeddings = self.projection(token_embeddings)
151 | return token_embeddings
152 |
153 | def forward(self, *args, **kwargs) -> torch.Tensor:
154 | token_embeddings = self.encode_text(*args, **kwargs)
155 | return token_embeddings
156 |
157 | class BiEncoder(Encoder):
158 | def __init__(self, config, from_pretrained: bool = False, cache_dir:Optional[Union[str, Path]] = None):
159 |         super().__init__(config, from_pretrained, cache_dir=cache_dir)
160 | if config.labels_encoder is not None:
161 | self.labels_encoder = Transformer( #transformer_model
162 | config.labels_encoder, config, from_pretrained, True, cache_dir=cache_dir
163 | )
164 | le_hidden_size = self.labels_encoder.model.config.hidden_size
165 |
166 | if config.hidden_size != le_hidden_size:
167 | self.labels_projection = nn.Linear(le_hidden_size, config.hidden_size)
168 |
169 | def mean_pooling(self, token_embeddings, attention_mask):
170 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
171 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
172 |
173 | def encode_labels(self, input_ids, attention_mask, *args, **kwargs):
174 | labels_embeddings = self.labels_encoder(input_ids, attention_mask, *args, **kwargs)
175 | if hasattr(self, "labels_projection"):
176 | labels_embeddings = self.labels_projection(labels_embeddings)
177 | labels_embeddings = self.mean_pooling(labels_embeddings, attention_mask)
178 | return labels_embeddings
179 |
180 | def forward(self, input_ids, attention_mask,
181 | labels_input_ids = None, labels_attention_mask=None,
182 | *args, **kwargs) -> torch.Tensor:
183 | token_embeddings = self.encode_text(input_ids, attention_mask, *args, **kwargs)
184 |
185 | labels_embeddings = self.encode_labels(labels_input_ids, labels_attention_mask, *args, **kwargs)
186 | return token_embeddings, labels_embeddings
--------------------------------------------------------------------------------
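A minimal, self-contained sketch of the mean pooling used in BiEncoder.encode_labels; the tensors are toy values.

import torch

token_embeddings = torch.randn(2, 4, 8)        # [batch, seq_len, hidden]
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])   # 1 = real token, 0 = padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
print(pooled.shape)  # torch.Size([2, 8]) -> one pooled embedding per sequence
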
/gliner/modeling/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
5 |
6 | class LstmSeq2SeqEncoder(nn.Module):
7 | def __init__(self, config, num_layers=1, dropout=0., bidirectional=True):
8 | super(LstmSeq2SeqEncoder, self).__init__()
9 | self.lstm = nn.LSTM(input_size=config.hidden_size,
10 | hidden_size=config.hidden_size//2,
11 | num_layers=num_layers,
12 | dropout=dropout,
13 | bidirectional=bidirectional,
14 | batch_first=True)
15 |
16 | def forward(self, x, mask, hidden=None):
17 | # Packing the input sequence
18 | lengths = mask.sum(dim=1).cpu()
19 | packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
20 |
21 | # Passing packed sequence through LSTM
22 | packed_output, hidden = self.lstm(packed_x, hidden)
23 |
24 | # Unpacking the output sequence
25 | output, _ = pad_packed_sequence(packed_output, batch_first=True)
26 |
27 | return output
28 |
29 |
30 | def create_projection_layer(hidden_size: int, dropout: float, out_dim: int = None) -> nn.Sequential:
31 | """
32 | Creates a projection layer with specified configurations.
33 | """
34 | if out_dim is None:
35 | out_dim = hidden_size
36 |
37 | return nn.Sequential(
38 | nn.Linear(hidden_size, out_dim * 4),
39 | nn.ReLU(),
40 | nn.Dropout(dropout),
41 | nn.Linear(out_dim * 4, out_dim)
42 | )
43 |
44 | class MultiheadAttention(nn.Module):
45 | def __init__(self, hidden_size, num_heads, dropout) -> None:
46 | super().__init__()
47 | self.hidden_size=hidden_size
48 | self.num_heads=num_heads
49 | self.attention_head_size=hidden_size//num_heads
50 | self.attention_probs_dropout_prob=dropout
51 | self.query_layer = nn.Linear(hidden_size, hidden_size)
52 | self.key_layer = nn.Linear(hidden_size, hidden_size)
53 | self.value_layer = nn.Linear(hidden_size, hidden_size)
54 |
55 | def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
56 | new_x_shape = x.size()[:-1] + (self.num_heads, self.attention_head_size)
57 | x = x.view(new_x_shape)
58 | return x.permute(0, 2, 1, 3)
59 |
60 |     def forward(self, query, key=None, value=None, head_mask=None, attn_mask=None):
61 |         # Fall back to the query (or the key) when key/value are not provided,
62 |         # resolving the defaults before projection so the shapes stay consistent.
63 |         if key is None:
64 |             key = query
65 |         if value is None:
66 |             value = key
67 | 
68 |         query = self.transpose_for_scores(self.query_layer(query))
69 |         key = self.transpose_for_scores(self.key_layer(key))
70 |         value = self.transpose_for_scores(self.value_layer(value))
71 | 
72 |
73 | context_layer = torch.nn.functional.scaled_dot_product_attention(
74 | query,
75 | key,
76 | value,
77 | head_mask,
78 | self.attention_probs_dropout_prob if self.training else 0.0,
79 | is_causal=False,
80 | scale=None,
81 | )
82 |
83 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
84 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
85 | context_layer = context_layer.view(new_context_layer_shape)
86 |
87 | return context_layer, None
88 |
89 | class SelfAttentionBlock(nn.Module):
90 | def __init__(self, d_model, num_heads, dropout=0.1):
91 | super().__init__()
92 | self.self_attn = MultiheadAttention(d_model, num_heads, dropout=dropout)
93 | self.pre_norm = nn.LayerNorm(d_model)
94 | self.post_norm = nn.LayerNorm(d_model)
95 | self.dropout = nn.Dropout(dropout)
96 | self.q_proj = nn.Linear(d_model, d_model)
97 | self.k_proj = nn.Linear(d_model, d_model)
98 | self.v_proj = nn.Linear(d_model, d_model)
99 |
100 | def forward(self, x, mask=None):
101 | x = self.pre_norm(x)
102 | q = self.q_proj(x)
103 | k = self.k_proj(x)
104 | v = self.v_proj(x)
105 | attn_output, _ = self.self_attn(q, k, v, attn_mask=mask)
106 | output = x + self.dropout(attn_output)
107 | return self.post_norm(output)
108 |
109 | class CrossAttentionBlock(nn.Module):
110 | def __init__(self, d_model, num_heads, dropout=0.1):
111 | super().__init__()
112 | self.cross_attn = MultiheadAttention(d_model, num_heads, dropout=dropout)
113 | self.pre_norm = nn.LayerNorm(d_model)
114 | self.post_norm = nn.LayerNorm(d_model)
115 | self.dropout = nn.Dropout(dropout)
116 | self.v_proj = nn.Linear(d_model, d_model)
117 |
118 | def forward(self, query, key, value=None, mask=None):
119 | query = self.pre_norm(query)
120 | if value is None:
121 | value = self.v_proj(key)
122 | attn_output, _ = self.cross_attn(query, key, value, attn_mask=mask)
123 | output = query + self.dropout(attn_output)
124 | return self.post_norm(output)
125 |
126 | class CrossFuser(nn.Module):
127 | def __init__(self, d_model, query_dim, num_heads=8, num_layers=1, dropout=0.1, schema='l2l-l2t'):
128 | super().__init__()
129 | self.d_model = d_model
130 | self.schema = schema.split('-')
131 | layers = []
132 | for _ in range(num_layers):
133 | layer = []
134 | for attn_type in self.schema:
135 | if attn_type in {'l2l', 't2t'}:
136 | layer.append(SelfAttentionBlock(d_model, num_heads, dropout))
137 | else:
138 | layer.append(CrossAttentionBlock(d_model, num_heads, dropout))
139 | layer = nn.ModuleList(layer)
140 | layers.append(layer)
141 |
142 | self.layers = nn.ModuleList(layers)
143 | # self.dense_i = nn.Linear(query_dim, d_model)
144 | # self.dense_o = nn.Linear(d_model, query_dim)
145 |
146 | def forward(self, query, key, query_mask=None, key_mask=None):
147 | # query = self.dense_i(query)
148 | for sublayers in self.layers:
149 | for id, layer in enumerate(sublayers):
150 | if self.schema[id] == 'l2l':
151 | if query_mask is not None:
152 | self_attn_mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2)
153 | else:
154 | self_attn_mask = None
155 | query = layer(query, mask=self_attn_mask)
156 | elif self.schema[id] == 't2t':
157 | if key_mask is not None:
158 | self_attn_mask = key_mask.unsqueeze(1) * key_mask.unsqueeze(2)
159 | else:
160 | self_attn_mask = None
161 | key = layer(key, mask=self_attn_mask)
162 | elif self.schema[id] == 'l2t':
163 | if query_mask is not None and key_mask is not None:
164 | cross_attn_mask = query_mask.unsqueeze(-1) * key_mask.unsqueeze(1)
165 | else:
166 | cross_attn_mask = None
167 | query = layer(query, key, mask=cross_attn_mask)
168 | elif self.schema[id] == 't2l':
169 | if query_mask is not None and key_mask is not None:
170 | cross_attn_mask = key_mask.unsqueeze(-1) * query_mask.unsqueeze(1)
171 | else:
172 | cross_attn_mask = None
173 | key = layer(key, query, mask=cross_attn_mask)
174 | # query=self.dense_o(query)
175 | return query, key
176 |
177 | class LayersFuser(nn.Module):
178 | def __init__(self, num_layers, hidden_size, output_size=None):
179 | super().__init__()
180 | self.num_layers = num_layers
181 | self.hidden_size = hidden_size
182 | self.output_size = output_size if output_size is not None else hidden_size
183 |
184 | # Squeeze operation
185 | self.squeeze = nn.Linear(hidden_size, 1)
186 |
187 | # Excitation operation
188 | self.W1 = nn.Linear(num_layers, num_layers // 2)
189 | self.W2 = nn.Linear(num_layers // 2, num_layers)
190 |
191 | # Final projection
192 | self.output_projection = nn.Linear(self.hidden_size, self.output_size)
193 |
194 | def forward(self, encoder_outputs):
195 | # encoder_outputs is a list of tensors, each of shape [B, L, D]
196 | B, L, D = encoder_outputs[0].shape
197 |
198 | # Concatenate all layers
199 | U = torch.stack(encoder_outputs[1:], dim=1) # [B, K, L, D]
200 |
201 | # Squeeze operation
202 | Z = self.squeeze(U).squeeze(-1) # [B, K, L]
203 | Z = Z.mean(dim=2) # [B, K]
204 |
205 | # Excitation operation
206 | s = self.W2(F.relu(self.W1(Z))) # [B, K]
207 | s = torch.sigmoid(s) # [B, K]
208 |
209 | # Apply attention weights
210 | U_weighted = U * s.unsqueeze(-1).unsqueeze(-1) # [B, K, L, D]
211 |
212 | # Sum across layers
213 | U_sum = U_weighted.sum(dim=1) # [B, L, D]
214 |
215 | # final projection
216 | output = self.output_projection(U_sum) # [B, L, output_size]
217 |
218 | return output
--------------------------------------------------------------------------------
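A minimal shape-check sketch for LayersFuser, assuming the gliner package is importable and a toy 4-layer encoder whose hidden states are passed as a list (embeddings output first, as returned with output_hidden_states=True).

import torch
from gliner.modeling.layers import LayersFuser

hidden_states = [torch.randn(2, 10, 16) for _ in range(5)]  # embeddings + 4 layers, each [B, L, D]
fuser = LayersFuser(num_layers=4, hidden_size=16)
fused = fuser(hidden_states)
print(fused.shape)  # torch.Size([2, 10, 16])
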
/gliner/modeling/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def focal_loss_with_logits(
6 | inputs: torch.Tensor,
7 | targets: torch.Tensor,
8 | alpha: float = 0.25,
9 | gamma: float = 2,
10 | reduction: str = "none",
11 | label_smoothing: float = 0.0,
12 | ignore_index: int = -100 # default value for ignored index
13 | ) -> torch.Tensor:
14 | """
15 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
16 |
17 | Args:
18 | inputs (Tensor): A float tensor of arbitrary shape.
19 | The predictions for each example.
20 | targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
21 | classification label for each element in inputs
22 | (0 for the negative class and 1 for the positive class).
23 | alpha (float): Weighting factor in range (0,1) to balance
24 | positive vs negative examples or -1 for ignore. Default: ``0.25``.
25 | gamma (float): Exponent of the modulating factor (1 - p_t) to
26 | balance easy vs hard examples. Default: ``2``.
27 | reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
28 | ``'none'``: No reduction will be applied to the output.
29 | ``'mean'``: The output will be averaged.
30 | ``'sum'``: The output will be summed. Default: ``'none'``.
31 | label_smoothing (float): Specifies the amount of smoothing when computing the loss,
32 | where 0.0 means no smoothing.
33 | ignore_index (int): Specifies a target value that is ignored and does not contribute
34 | to the input gradient. Default: ``-100``.
35 | Returns:
36 | Loss tensor with the reduction option applied.
37 | """
38 | # Create a mask to ignore specified index
39 | valid_mask = targets != ignore_index
40 |
41 | # Apply label smoothing if needed
42 | if label_smoothing != 0:
43 | with torch.no_grad():
44 | targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing
45 |
46 | # Apply sigmoid activation to inputs
47 | p = torch.sigmoid(inputs)
48 |
49 | # Compute the binary cross-entropy loss without reduction
50 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
51 |
52 | # Apply the valid mask to the loss
53 | loss = loss * valid_mask
54 |
55 | # Apply focal loss modulation if gamma is greater than 0
56 | if gamma > 0:
57 | p_t = p * targets + (1 - p) * (1 - targets)
58 | loss = loss * ((1 - p_t) ** gamma)
59 |
60 | # Apply alpha weighting if alpha is specified
61 | if alpha >= 0:
62 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
63 | loss = alpha_t * loss
64 |
65 | # Apply reduction method
66 | if reduction == "none":
67 | return loss
68 | elif reduction == "mean":
69 | return loss.sum() / valid_mask.sum() # Normalize by the number of valid (non-ignored) elements
70 | elif reduction == "sum":
71 | return loss.sum()
72 | else:
73 | raise ValueError(
74 | f"Invalid value for argument 'reduction': '{reduction}'. "
75 | f"Supported reduction modes: 'none', 'mean', 'sum'"
76 | )
--------------------------------------------------------------------------------
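A minimal sketch of focal_loss_with_logits on toy scores, assuming the gliner package is importable; the -100 target is ignored when averaging.

import torch
from gliner.modeling.loss_functions import focal_loss_with_logits

logits = torch.tensor([[2.0, -1.0, 0.5]])
targets = torch.tensor([[1.0, 0.0, -100.0]])   # last position is ignored
loss = focal_loss_with_logits(logits, targets, alpha=0.25, gamma=2, reduction="mean")
print(loss)  # scalar loss averaged over the two valid positions
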
/gliner/modeling/scorers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class Scorer(nn.Module):
5 | def __init__(self, hidden_size, dropout=0.1):
6 | super().__init__()
7 |
8 | self.proj_token = nn.Linear(hidden_size, hidden_size * 2)
9 | self.proj_label = nn.Linear(hidden_size, hidden_size * 2)
10 |
11 | self.out_mlp = nn.Sequential(
12 | nn.Linear(hidden_size * 3, hidden_size * 4),
13 | nn.Dropout(dropout),
14 | nn.ReLU(),
15 | nn.Linear(hidden_size * 4, 3) # start, end, score
16 | )
17 |
18 | def forward(self, token_rep, label_rep):
19 | batch_size, seq_len, hidden_size = token_rep.shape
20 | num_classes = label_rep.shape[1]
21 |
22 |         # (batch_size, seq_len, 1, 2, hidden_size) and (batch_size, 1, num_classes, 2, hidden_size)
23 | token_rep = self.proj_token(token_rep).view(batch_size, seq_len, 1, 2, hidden_size)
24 | label_rep = self.proj_label(label_rep).view(batch_size, 1, num_classes, 2, hidden_size)
25 |
26 | # (2, batch_size, seq_len, num_classes, hidden_size)
27 | token_rep = token_rep.expand(-1, -1, num_classes, -1, -1).permute(3, 0, 1, 2, 4)
28 | label_rep = label_rep.expand(-1, seq_len, -1, -1, -1).permute(3, 0, 1, 2, 4)
29 |
30 | # (batch_size, seq_len, num_classes, hidden_size * 3)
31 | cat = torch.cat([token_rep[0], label_rep[0], token_rep[1] * label_rep[1]], dim=-1)
32 |
33 |         # out_mlp -> (batch_size, seq_len, num_classes, 3); permute -> (3, batch_size, seq_len, num_classes)
34 | scores = self.out_mlp(cat).permute(3, 0, 1, 2)
35 |
36 | return scores
37 |
--------------------------------------------------------------------------------
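A minimal shape-check sketch for the Scorer head, assuming the gliner package is importable; the first output dimension stacks the three score maps (start, end, score, per the comment above).

import torch
from gliner.modeling.scorers import Scorer

scorer = Scorer(hidden_size=16)
token_rep = torch.randn(2, 10, 16)   # [batch, seq_len, hidden]
label_rep = torch.randn(2, 3, 16)    # [batch, num_classes, hidden]
scores = scorer(token_rep, label_rep)
print(scores.shape)  # torch.Size([3, 2, 10, 3]) -> (start/end/score, batch, seq_len, num_classes)
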
/gliner/modeling/span_rep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | from .layers import create_projection_layer
6 |
7 | class SpanQuery(nn.Module):
8 |
9 | def __init__(self, hidden_size, max_width, trainable=True):
10 | super().__init__()
11 |
12 | self.query_seg = nn.Parameter(torch.randn(hidden_size, max_width))
13 |
14 | nn.init.uniform_(self.query_seg, a=-1, b=1)
15 |
16 | if not trainable:
17 | self.query_seg.requires_grad = False
18 |
19 | self.project = nn.Sequential(
20 | nn.Linear(hidden_size, hidden_size),
21 | nn.ReLU()
22 | )
23 |
24 | def forward(self, h, *args):
25 | # h of shape [B, L, D]
26 | # query_seg of shape [D, max_width]
27 |
28 | span_rep = torch.einsum('bld, ds->blsd', h, self.query_seg)
29 |
30 | return self.project(span_rep)
31 |
32 |
33 | class SpanMLP(nn.Module):
34 |
35 | def __init__(self, hidden_size, max_width):
36 | super().__init__()
37 |
38 | self.mlp = nn.Linear(hidden_size, hidden_size * max_width)
39 |
40 | def forward(self, h, *args):
41 | # h of shape [B, L, D]
42 |         # output span_rep of shape [B, L, max_width, D]
43 |
44 | B, L, D = h.size()
45 |
46 | span_rep = self.mlp(h)
47 |
48 | span_rep = span_rep.view(B, L, -1, D)
49 |
50 | return span_rep.relu()
51 |
52 |
53 | class SpanCAT(nn.Module):
54 |
55 | def __init__(self, hidden_size, max_width):
56 | super().__init__()
57 |
58 | self.max_width = max_width
59 |
60 | self.query_seg = nn.Parameter(torch.randn(128, max_width))
61 |
62 | self.project = nn.Sequential(
63 | nn.Linear(hidden_size + 128, hidden_size),
64 | nn.ReLU()
65 | )
66 |
67 | def forward(self, h, *args):
68 | # h of shape [B, L, D]
69 |         # query_seg of shape [128, max_width]
70 |
71 | B, L, D = h.size()
72 |
73 | h = h.view(B, L, 1, D).repeat(1, 1, self.max_width, 1)
74 |
75 | q = self.query_seg.view(1, 1, self.max_width, -1).repeat(B, L, 1, 1)
76 |
77 | span_rep = torch.cat([h, q], dim=-1)
78 |
79 | span_rep = self.project(span_rep)
80 |
81 | return span_rep
82 |
83 |
84 | class SpanConvBlock(nn.Module):
85 | def __init__(self, hidden_size, kernel_size, span_mode='conv_normal'):
86 | super().__init__()
87 |
88 | if span_mode == 'conv_conv':
89 | self.conv = nn.Conv1d(hidden_size, hidden_size,
90 | kernel_size=kernel_size)
91 |
92 | # initialize the weights
93 | nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='relu')
94 |
95 | elif span_mode == 'conv_max':
96 | self.conv = nn.MaxPool1d(kernel_size=kernel_size, stride=1)
97 | elif span_mode == 'conv_mean' or span_mode == 'conv_sum':
98 | self.conv = nn.AvgPool1d(kernel_size=kernel_size, stride=1)
99 |
100 | self.span_mode = span_mode
101 |
102 | self.pad = kernel_size - 1
103 |
104 | def forward(self, x):
105 |
106 | x = torch.einsum('bld->bdl', x)
107 |
108 | if self.pad > 0:
109 | x = F.pad(x, (0, self.pad), "constant", 0)
110 |
111 | x = self.conv(x)
112 |
113 | if self.span_mode == "conv_sum":
114 | x = x * (self.pad + 1)
115 |
116 | return torch.einsum('bdl->bld', x)
117 |
118 |
119 | class SpanConv(nn.Module):
120 | def __init__(self, hidden_size, max_width, span_mode):
121 | super().__init__()
122 |
123 | kernels = [i + 2 for i in range(max_width - 1)]
124 |
125 | self.convs = nn.ModuleList()
126 |
127 | for kernel in kernels:
128 | self.convs.append(SpanConvBlock(hidden_size, kernel, span_mode))
129 |
130 | self.project = nn.Sequential(
131 | nn.ReLU(),
132 | nn.Linear(hidden_size, hidden_size)
133 | )
134 |
135 | def forward(self, x, *args):
136 |
137 | span_reps = [x]
138 |
139 | for conv in self.convs:
140 | h = conv(x)
141 | span_reps.append(h)
142 |
143 | span_reps = torch.stack(span_reps, dim=-2)
144 |
145 | return self.project(span_reps)
146 |
147 |
148 | class SpanEndpointsBlock(nn.Module):
149 | def __init__(self, kernel_size):
150 | super().__init__()
151 |
152 | self.kernel_size = kernel_size
153 |
154 | def forward(self, x):
155 | B, L, D = x.size()
156 |
157 | span_idx = torch.LongTensor(
158 | [[i, i + self.kernel_size - 1] for i in range(L)]).to(x.device)
159 |
160 | x = F.pad(x, (0, 0, 0, self.kernel_size - 1), "constant", 0)
161 |
162 | # endrep
163 | start_end_rep = torch.index_select(x, dim=1, index=span_idx.view(-1))
164 |
165 | start_end_rep = start_end_rep.view(B, L, 2, D)
166 |
167 | return start_end_rep
168 |
169 |
170 | class ConvShare(nn.Module):
171 | def __init__(self, hidden_size, max_width):
172 | super().__init__()
173 |
174 | self.max_width = max_width
175 |
176 | self.conv_weigth = nn.Parameter(
177 | torch.randn(hidden_size, hidden_size, max_width))
178 |
179 | nn.init.kaiming_uniform_(self.conv_weigth, nonlinearity='relu')
180 |
181 | self.project = nn.Sequential(
182 | nn.ReLU(),
183 | nn.Linear(hidden_size, hidden_size)
184 | )
185 |
186 | def forward(self, x, *args):
187 | span_reps = []
188 |
189 | x = torch.einsum('bld->bdl', x)
190 |
191 | for i in range(self.max_width):
192 | pad = i
193 | x_i = F.pad(x, (0, pad), "constant", 0)
194 | conv_w = self.conv_weigth[:, :, :i + 1]
195 | out_i = F.conv1d(x_i, conv_w)
196 | span_reps.append(out_i.transpose(-1, -2))
197 |
198 | out = torch.stack(span_reps, dim=-2)
199 |
200 | return self.project(out)
201 |
202 |
203 | def extract_elements(sequence, indices):
204 | B, L, D = sequence.shape
205 | K = indices.shape[1]
206 |
207 | # Expand indices to [B, K, D]
208 | expanded_indices = indices.unsqueeze(2).expand(-1, -1, D)
209 |
210 | # Gather the elements
211 | extracted_elements = torch.gather(sequence, 1, expanded_indices)
212 |
213 | return extracted_elements
214 |
215 |
216 | class SpanMarker(nn.Module):
217 |
218 | def __init__(self, hidden_size, max_width, dropout=0.4):
219 | super().__init__()
220 |
221 | self.max_width = max_width
222 |
223 | self.project_start = nn.Sequential(
224 | nn.Linear(hidden_size, hidden_size * 2, bias=True),
225 | nn.ReLU(),
226 | nn.Dropout(dropout),
227 | nn.Linear(hidden_size * 2, hidden_size, bias=True),
228 | )
229 |
230 | self.project_end = nn.Sequential(
231 | nn.Linear(hidden_size, hidden_size * 2, bias=True),
232 | nn.ReLU(),
233 | nn.Dropout(dropout),
234 | nn.Linear(hidden_size * 2, hidden_size, bias=True),
235 | )
236 |
237 | self.out_project = nn.Linear(hidden_size * 2, hidden_size, bias=True)
238 |
239 | def forward(self, h, span_idx):
240 | # h of shape [B, L, D]
241 | # query_seg of shape [D, max_width]
242 |         # span_idx of shape [B, L * max_width, 2] (start and end indices)
243 | B, L, D = h.size()
244 |
245 | # project start and end
246 | start_rep = self.project_start(h)
247 | end_rep = self.project_end(h)
248 |
249 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
250 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
251 |
252 | # concat start and end
253 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
254 |
255 | # project
256 | cat = self.out_project(cat)
257 |
258 | # reshape
259 | return cat.view(B, L, self.max_width, D)
260 |
261 |
262 | class SpanMarkerV0(nn.Module):
263 | """
264 | Marks and projects span endpoints using an MLP.
265 | """
266 |
267 | def __init__(self, hidden_size: int, max_width: int, dropout: float = 0.4):
268 | super().__init__()
269 | self.max_width = max_width
270 | self.project_start = create_projection_layer(hidden_size, dropout)
271 | self.project_end = create_projection_layer(hidden_size, dropout)
272 |
273 | self.out_project = create_projection_layer(hidden_size * 2, dropout, hidden_size)
274 |
275 | def forward(self, h: torch.Tensor, span_idx: torch.Tensor) -> torch.Tensor:
276 | B, L, D = h.size()
277 |
278 | start_rep = self.project_start(h)
279 | end_rep = self.project_end(h)
280 |
281 | start_span_rep = extract_elements(start_rep, span_idx[:, :, 0])
282 | end_span_rep = extract_elements(end_rep, span_idx[:, :, 1])
283 |
284 | cat = torch.cat([start_span_rep, end_span_rep], dim=-1).relu()
285 |
286 | return self.out_project(cat).view(B, L, self.max_width, D)
287 |
288 |
289 | class ConvShareV2(nn.Module):
290 | def __init__(self, hidden_size, max_width):
291 | super().__init__()
292 |
293 | self.max_width = max_width
294 |
295 | self.conv_weigth = nn.Parameter(
296 | torch.randn(hidden_size, hidden_size, max_width)
297 | )
298 |
299 | nn.init.xavier_normal_(self.conv_weigth)
300 |
301 | def forward(self, x, *args):
302 | span_reps = []
303 |
304 | x = torch.einsum('bld->bdl', x)
305 |
306 | for i in range(self.max_width):
307 | pad = i
308 | x_i = F.pad(x, (0, pad), "constant", 0)
309 | conv_w = self.conv_weigth[:, :, :i + 1]
310 | out_i = F.conv1d(x_i, conv_w)
311 | span_reps.append(out_i.transpose(-1, -2))
312 |
313 | out = torch.stack(span_reps, dim=-2)
314 |
315 | return out
316 |
317 |
318 | class SpanRepLayer(nn.Module):
319 | """
320 | Various span representation approaches
321 | """
322 |
323 | def __init__(self, hidden_size, max_width, span_mode, **kwargs):
324 | super().__init__()
325 |
326 | if span_mode == 'marker':
327 | self.span_rep_layer = SpanMarker(hidden_size, max_width, **kwargs)
328 | elif span_mode == 'markerV0':
329 | self.span_rep_layer = SpanMarkerV0(hidden_size, max_width, **kwargs)
330 | elif span_mode == 'query':
331 | self.span_rep_layer = SpanQuery(
332 | hidden_size, max_width, trainable=True)
333 | elif span_mode == 'mlp':
334 | self.span_rep_layer = SpanMLP(hidden_size, max_width)
335 | elif span_mode == 'cat':
336 | self.span_rep_layer = SpanCAT(hidden_size, max_width)
337 | elif span_mode == 'conv_conv':
338 | self.span_rep_layer = SpanConv(
339 | hidden_size, max_width, span_mode='conv_conv')
340 | elif span_mode == 'conv_max':
341 | self.span_rep_layer = SpanConv(
342 | hidden_size, max_width, span_mode='conv_max')
343 | elif span_mode == 'conv_mean':
344 | self.span_rep_layer = SpanConv(
345 | hidden_size, max_width, span_mode='conv_mean')
346 | elif span_mode == 'conv_sum':
347 | self.span_rep_layer = SpanConv(
348 | hidden_size, max_width, span_mode='conv_sum')
349 | elif span_mode == 'conv_share':
350 | self.span_rep_layer = ConvShare(hidden_size, max_width)
351 | else:
352 | raise ValueError(f'Unknown span mode {span_mode}')
353 |
354 | def forward(self, x, *args):
355 |
356 | return self.span_rep_layer(x, *args)
357 |
--------------------------------------------------------------------------------
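A minimal shape-check sketch for SpanRepLayer in 'markerV0' mode, assuming the gliner package is importable; span_idx enumerates (start, end) word indices for every position and width.

import torch
from gliner.modeling.span_rep import SpanRepLayer

B, L, D, max_width = 2, 6, 16, 3
span_rep = SpanRepLayer(hidden_size=D, max_width=max_width, span_mode='markerV0', dropout=0.4)

h = torch.randn(B, L, D)
# one (start, end) pair per position/width, flattened to L * max_width spans
span_idx = torch.tensor([[(i, min(i + w, L - 1)) for i in range(L) for w in range(max_width)]] * B)
out = span_rep(h, span_idx)
print(out.shape)  # torch.Size([2, 6, 3, 16])
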
/gliner/multitask/__init__.py:
--------------------------------------------------------------------------------
1 | from .classification import GLiNERClassifier
2 | from .question_answering import GLiNERQuestionAnswerer, GLiNERSquadEvaluator
3 | from .open_extraction import GLiNEROpenExtractor
4 | from .relation_extraction import GLiNERRelationExtractor, GLiNERDocREDEvaluator
5 | from .summarization import GLiNERSummarizer
--------------------------------------------------------------------------------
/gliner/multitask/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Union, Optional
3 | import torch
4 | import warnings
5 |
6 | from ..model import GLiNER
7 |
8 | class GLiNERBasePipeline(ABC):
9 | """
10 | Base class for GLiNER pipelines. Provides an interface for preparing texts,
11 | processing predictions, and evaluating the model.
12 |
13 | Args:
14 | model_id (str): Identifier for the model to be loaded.
15 | prompt (str, optional): Prompt template for text preparation. Defaults to None.
16 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
17 |
18 | Attributes:
19 | model (GLiNER): The loaded GLiNER model.
20 | device (str): The device being used for computation.
21 | prompt (str): The prompt template for text preparation.
22 | """
23 |
24 | def __init__(self, model_id: str = None, model: GLiNER = None, prompt=None, device='cuda:0'):
25 | """
26 | Initializes the GLiNERBasePipeline.
27 |
28 | Args:
29 | model_id (str): Identifier for the model to be loaded.
30 | prompt (str, optional): Prompt template for text preparation. Defaults to None.
31 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
32 | """
33 | if 'cuda' in device and not torch.cuda.is_available():
34 | warnings.warn(f"{device} is not available, setting device as 'cpu'.")
35 | device = 'cpu'
36 | self.device = device
37 |
38 | if model is not None:
39 | self.model = model.to(self.device)
40 | elif model_id is not None:
41 | self.model = GLiNER.from_pretrained(model_id).to(self.device)
42 | else:
43 | raise ValueError("Either 'model_id' or 'model' must be provided to initialize the pipeline.")
44 |
45 | self.prompt = prompt
46 |
47 | @abstractmethod
48 | def prepare_texts(self, texts: List[str], *args, **kwargs):
49 | """
50 | Prepares texts for input to the model.
51 |
52 | Args:
53 | texts (List[str]): List of input texts.
54 | *args: Additional positional arguments.
55 | **kwargs: Additional keyword arguments.
56 |
57 | Returns:
58 | Any: The processed texts ready for model input.
59 | """
60 | pass
61 |
62 | @abstractmethod
63 | def process_predictions(self, predictions: List[dict]):
64 | """
65 | Processes model predictions into the desired format.
66 |
67 | Args:
68 | predictions (List[dict]): Raw predictions from the model.
69 |
70 | Returns:
71 | Any: Processed predictions in the desired format.
72 | """
73 | pass
74 |
75 | @abstractmethod
76 | def evaluate(self, dataset_id: str, labels: Optional[List[str]] = None, threshold: float = 0.5):
77 | """
78 | Evaluates the model on a given dataset.
79 |
80 | Args:
81 | dataset_id (str): Identifier for the evaluation dataset.
82 | labels (Optional[List[str]]): List of labels to evaluate. Defaults to None.
83 | threshold (float): Threshold for prediction confidence. Defaults to 0.5.
84 |
85 | Returns:
86 | Any: Evaluation results.
87 | """
88 | pass
89 |
90 | def __call__(self, texts: Union[str, List[str]], labels: List[str] = ['match'],
91 | threshold: float = 0.5, batch_size: int = 8, **kwargs):
92 | """
93 | Runs the model on the provided texts and returns processed results.
94 |
95 | Args:
96 | texts (Union[str, List[str]]): Single or list of input texts.
97 |             labels (List[str]): Entity labels passed to the model. Defaults to ['match'].
98 | threshold (float): Threshold for prediction confidence. Defaults to 0.5.
99 | batch_size (int): Batch size for processing. Defaults to 8.
100 |
101 | Returns:
102 | Any: Processed results from the model.
103 | """
104 | if isinstance(texts, str):
105 | texts = [texts]
106 |
107 | prompts = self.prepare_texts(texts, **kwargs)
108 |
109 | predictions = self.model.run(prompts, labels, threshold=threshold, batch_size=batch_size)
110 |
111 | results = self.process_predictions(predictions, **kwargs)
112 |
113 | return results
--------------------------------------------------------------------------------
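A minimal sketch of a custom pipeline built on GLiNERBasePipeline; the KeyphrasePipeline class and the checkpoint id are illustrative assumptions, not part of the library.

from typing import List
from gliner.multitask.base import GLiNERBasePipeline

class KeyphrasePipeline(GLiNERBasePipeline):
    prompt = "Extract key phrases from the text:"

    def prepare_texts(self, texts: List[str], **kwargs):
        return [f"{self.prompt} \n {text}" for text in texts]

    def process_predictions(self, predictions, **kwargs):
        # keep only the predicted surface strings
        return [[entity["text"] for entity in prediction] for prediction in predictions]

    def evaluate(self, dataset_id=None, labels=None, threshold=0.5):
        raise NotImplementedError

# pipe = KeyphrasePipeline(model_id="urchade/gliner_small-v2.1", device="cpu")
# print(pipe(["GLiNER matches entity types against spans."], labels=["key phrase"]))
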
/gliner/multitask/classification.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 | import os
3 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
4 | import torch
5 | from datasets import load_dataset, Dataset
6 | from sklearn.metrics import f1_score
7 | from gliner import GLiNER
8 |
9 | from .base import GLiNERBasePipeline
10 |
11 | class GLiNERClassifier(GLiNERBasePipeline):
12 | """
13 | A class to evaluate the GLiNER model for classification tasks using F1 scores.
14 |
15 | Attributes:
16 | device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
17 | model (GLiNER): Loaded GLiNER model instance.
18 | prompt (str): Template prompt for text classification.
19 |
20 | Methods:
21 | compute_f_score(predicts, true_labels):
22 | Computes micro, macro, and weighted F1 scores.
23 | prepare_dataset(dataset, classes=None, text_column='text', label_column='label', split=None, max_examples=-1):
24 | Prepares texts and true labels from the given dataset.
25 | process_predictions(predictions):
26 | Processes model predictions to extract the most likely labels.
27 | prepare_texts(texts, labels):
28 | Creates classification prompts for each input text.
29 | __call__(texts, labels, threshold=0.5):
30 | Runs the model on the given texts and returns predicted labels.
31 | evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
32 | Evaluates the model on a dataset and computes F1 scores.
33 | """
34 |
35 | prompt = "Classify text into the following classes: {}"
36 |
37 | def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None):
38 | """
39 | Initializes the GLiNERClassifier.
40 |
41 | Args:
42 | model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
43 | model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
44 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
45 | prompt (str, optional): Template prompt for text classification. Defaults to the class-level prompt.
46 | """
47 | # Use the provided prompt or default to the class-level prompt
48 | prompt = prompt if prompt is not None else self.prompt
49 | super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
50 |
51 |
52 | def compute_f_score(self, predicts, true_labels):
53 | """
54 | Computes the micro, macro, and weighted F1 scores.
55 |
56 | Args:
57 | predicts (list): List of predicted labels.
58 | true_labels (list): List of true labels.
59 |
60 | Returns:
61 | dict: Dictionary with micro, macro, and weighted F1 scores.
62 | """
63 | micro = f1_score(true_labels, predicts, average="micro")
64 | macro = f1_score(true_labels, predicts, average="macro")
65 | weighted = f1_score(true_labels, predicts, average="weighted")
66 | return {"micro": micro, "macro": macro, "weighted": weighted}
67 |
68 | def prepare_dataset(self, dataset: Dataset, classes=None, text_column='text', label_column="label", split=None, max_examples=-1):
69 | """
70 | Prepares the dataset by extracting texts and true labels.
71 |
72 | Args:
73 | dataset (Dataset or dict): The dataset to prepare.
74 | classes (list, optional): List of class labels. Defaults to None.
75 | text_column (str): Name of the text column. Defaults to 'text'.
76 | label_column (str): Name of the label column. Defaults to 'label'.
77 | split (str, optional): Delimiter for splitting class names. Defaults to None.
78 | max_examples (int): Maximum number of examples to use. Defaults to -1 (use all).
79 |
80 | Returns:
81 | tuple: Texts, classes, and true labels.
82 | """
83 | if 'test' in dataset:
84 | test_dataset = dataset['test']
85 | elif isinstance(dataset, Dataset):
86 | test_dataset = dataset
87 | else:
88 | test_dataset = dataset['train']
89 |
90 | if classes is None:
91 | classes = test_dataset.features[label_column].names
92 | if split is not None:
93 | classes = [' '.join(class_.split(split)) for class_ in classes]
94 |
95 | texts = test_dataset[text_column]
96 | true_labels = test_dataset[label_column]
97 |
98 | if isinstance(test_dataset[label_column][0], int):
99 | true_labels = [classes[label] for label in true_labels]
100 |
101 | if max_examples > 0:
102 | texts = texts[:max_examples]
103 | true_labels = true_labels[:max_examples]
104 |
105 | return texts, classes, true_labels
106 |
107 | def process_predictions(self, predictions, multi_label=False, **kwargs):
108 | """
109 | Processes predictions to extract the highest-scoring label(s).
110 |
111 | Args:
112 | predictions (list): List of predictions with scores.
113 | multi_label (bool): Whether to allow multiple labels per input. Defaults to False.
114 |
115 | Returns:
116 | list: List of predicted labels for each input.
117 | """
118 | batch_predicted_labels = []
119 |
120 | for prediction in predictions:
121 | # Sort predictions by score in descending order
122 | sorted_predictions = sorted(prediction, key=lambda entity: entity["score"], reverse=True)
123 |
124 | if not sorted_predictions:
125 | # Default prediction if no valid predictions are found
126 | batch_predicted_labels.append([{'label': 'other', 'score': 1.0}])
127 | continue
128 |
129 | if not multi_label:
130 | # Single-label mode: select the top prediction and compute softmax score
131 | scores = [item['score'] for item in sorted_predictions]
132 | softmax_scores = torch.softmax(torch.tensor(scores), dim=0).tolist()
133 | top_prediction = {'label': sorted_predictions[0]['text'], 'score': softmax_scores[0]}
134 | batch_predicted_labels.append([top_prediction])
135 | else:
136 | # Multi-label mode: retain all predictions with original scores
137 | predicted_labels = [{'label': pred['text'], 'score': pred['score']} for pred in sorted_predictions]
138 | batch_predicted_labels.append(predicted_labels)
139 |
140 | return batch_predicted_labels
141 |
142 | def prepare_texts(self, texts, classes, **kwargs):
143 | """
144 | Prepares prompts for classification by appending labels to texts.
145 |
146 | Args:
147 | texts (list): List of input texts.
148 | classes (list): List of classification labels.
149 |
150 | Returns:
151 | list: List of formatted prompts.
152 | """
153 | prompts = []
154 | labels_ = ', '.join(classes)
155 | for text in texts:
156 | prompt = f"{self.prompt.format(labels_)} \n {text}"
157 | prompts.append(prompt)
158 | return prompts
159 |
160 | def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
161 | labels: Optional[List[str]]=None, threshold: float =0.5, max_examples: float =-1):
162 | """
163 | Evaluates the model on a specified dataset and computes evaluation metrics.
164 |
165 | Args:
166 | dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
167 | dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
168 | labels (list, optional): List of target labels to consider for classification. Defaults to None (use all).
169 | threshold (float): Confidence threshold for predictions. Defaults to 0.5.
170 | max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
171 |
172 | Returns:
173 | dict: A dictionary containing evaluation metrics such as F1 scores (micro, macro, and weighted).
174 |
175 | Raises:
176 | ValueError: If neither `dataset_id` nor `dataset` is provided.
177 | """
178 |         if dataset is None and dataset_id is not None:
179 |             dataset = load_dataset(dataset_id)
180 |         elif dataset is None and dataset_id is None:
181 |             raise ValueError("Either 'dataset_id' or 'dataset' must be provided to start evaluation.")
182 |         # if a pre-loaded dataset is given, it takes precedence and
183 |         # `dataset_id` is ignored, as documented in the docstring
184 |
185 | test_texts, classes, true_labels = self.prepare_dataset(dataset, labels, max_examples=max_examples)
186 |
187 | predictions = self.__call__(test_texts, classes=classes, threshold=threshold)
188 | predicted_labels = [pred[0]['label'] for pred in predictions]
189 |
190 | return self.compute_f_score(predicted_labels, true_labels)
191 |
--------------------------------------------------------------------------------
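A minimal usage sketch for GLiNERClassifier; the checkpoint id is an assumption (any GLiNER checkpoint should work) and the weights are downloaded at runtime.

from gliner.multitask import GLiNERClassifier

classifier = GLiNERClassifier(model_id="knowledgator/gliner-multitask-large-v0.5", device="cpu")
predictions = classifier(["The new GPU doubles training throughput."],
                         classes=["hardware", "sports", "politics"])
print(predictions)  # e.g. [[{'label': 'hardware', 'score': ...}]]

# metrics = classifier.evaluate(dataset_id="ag_news", max_examples=100)
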
/gliner/multitask/open_extraction.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Union
2 | import os
3 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
4 | import torch
5 | from datasets import load_dataset, Dataset
6 | from gliner import GLiNER
7 |
8 | from .base import GLiNERBasePipeline
9 |
10 | class GLiNEROpenExtractor(GLiNERBasePipeline):
11 | """
12 | A class to use GLiNER for open information extraction inference and evaluation.
13 |
14 | Attributes:
15 | device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
16 | model (GLiNER): Loaded GLiNER model instance.
17 | prompt (str): Template prompt for open information extraction.
18 |
19 | Methods:
20 | process_predictions(predictions):
21 | Processes model predictions to extract the most likely labels.
22 | prepare_texts(texts, labels):
23 | Creates open information extraction prompts for each input text.
24 | __call__(texts, labels, threshold=0.5):
25 | Runs the model on the given texts and returns predicted labels.
26 | evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
27 | Evaluates the model on a dataset and computes F1 scores.
28 | """
29 |
30 | prompt = ""
31 |
32 | def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None):
33 | """
34 | Initializes the GLiNEROpenExtractor.
35 |
36 | Args:
37 | model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
38 | model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
39 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
40 | prompt (str, optional): Template prompt for open information extraction.
41 | """
42 | # Use the provided prompt or default to the class-level prompt
43 | prompt = prompt if prompt is not None else self.prompt
44 | super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
45 |
46 |
47 | def process_predictions(self, predictions, **kwargs):
48 | """
49 | Processes predictions to extract the highest-scoring label(s).
50 |
51 | Args:
52 | predictions (list): List of predictions with scores.
53 |
54 | Returns:
55 | list: List of predicted labels for each input.
56 | """
57 | return predictions
58 |
59 | def prepare_texts(self, texts: List[str], **kwargs):
60 | """
61 | Prepares prompts for open-information extraction.
62 |
63 | Args:
64 | texts (list): List of input texts.
65 |
66 | Returns:
67 | list: List of formatted prompts.
68 | """
69 | prompts = []
70 |
71 | for id, text in enumerate(texts):
72 | prompt = f"{self.prompt} \n {text}"
73 | prompts.append(prompt)
74 | return prompts
75 |
76 |
77 | def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
78 | labels: Optional[List[str]]=None, threshold: float =0.5, max_examples: float =-1):
79 | """
80 | Evaluates the model on a specified dataset and computes evaluation metrics.
81 |
82 | Args:
83 | dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
84 | dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
85 | labels (list, optional): List of target labels to consider for extraction. Defaults to None (use all).
86 | threshold (float): Confidence threshold for predictions. Defaults to 0.5.
87 | max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
88 |
89 | Returns:
90 | dict: A dictionary containing evaluation metrics.
91 |
92 | Raises:
93 | ValueError: If neither `dataset_id` nor `dataset` is provided.
94 | """
95 |         raise NotImplementedError("The `evaluate` method is not implemented yet.")
--------------------------------------------------------------------------------
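A minimal usage sketch for GLiNEROpenExtractor; since the default prompt is empty, it simply forwards the texts and user-chosen labels to the underlying model. The checkpoint id is an assumption and the weights are downloaded at runtime.

from gliner.multitask import GLiNEROpenExtractor

extractor = GLiNEROpenExtractor(model_id="knowledgator/gliner-multitask-large-v0.5", device="cpu")
extractions = extractor(["Marie Curie won the Nobel Prize in Physics in 1903."],
                        labels=["person", "award", "date"], threshold=0.3)
print(extractions)  # raw GLiNER predictions, returned unchanged by process_predictions
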
/gliner/multitask/question_answering.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Union
2 | import os
3 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
4 | import torch
5 | from datasets import load_dataset, Dataset
6 | from gliner import GLiNER
7 |
8 | from .base import GLiNERBasePipeline
9 |
10 | class GLiNERQuestionAnswerer(GLiNERBasePipeline):
11 | """
12 | A class to use GLiNER for question-answering inference and evaluation.
13 |
14 | Attributes:
15 | device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
16 | model (GLiNER): Loaded GLiNER model instance.
17 |         prompt (str): Template prompt for text question-answering.
18 |
19 | Methods:
20 | process_predictions(predictions):
21 | Processes model predictions to extract the most likely labels.
22 | prepare_texts(texts, labels):
23 | Creates Q&A prompts for each input text.
24 | __call__(texts, labels, threshold=0.5):
25 | Runs the model on the given texts and returns predicted labels.
26 | evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
27 | Evaluates the model on a dataset and computes F1 scores.
28 | """
29 |
30 | prompt = "Answer the following question: {}"
31 |
32 | def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None):
33 | """
34 | Initializes the GLiNERQuestionAnswerer.
35 |
36 | Args:
37 | model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
38 | model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
39 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
40 | prompt (str, optional): Template prompt for question-answering.
41 | """
42 | # Use the provided prompt or default to the class-level prompt
43 | prompt = prompt if prompt is not None else self.prompt
44 | super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
45 |
46 |
47 | def process_predictions(self, predictions, **kwargs):
48 | """
49 | Processes predictions to extract the highest-scoring answer(s).
50 |
51 | Args:
52 | predictions (list): List of predictions with scores.
53 |
54 | Returns:
55 | list: List of predicted labels for each input.
56 | """
57 | batch_predicted_labels = []
58 |
59 | for prediction in predictions:
60 | # Sort predictions by score in descending order
61 | sorted_predictions = sorted(prediction, key=lambda entity: entity["score"], reverse=True)
62 |
63 | predicted_labels = [{'answer': pred['text'], 'score': pred['score']} for pred in sorted_predictions]
64 | batch_predicted_labels.append(predicted_labels)
65 |
66 | return batch_predicted_labels
67 |
68 | def prepare_texts(self, texts: List[str], questions: Union[List[str], str], **kwargs):
69 | """
70 | Prepares prompts for question-answering by appending questions to texts.
71 |
72 | Args:
73 | texts (list): List of input texts.
74 | questions (list|str): Question or list of questions.
75 |
76 | Returns:
77 | list: List of formatted prompts.
78 | """
79 | prompts = []
80 |
81 |         for id, text in enumerate(texts):
82 |             if isinstance(questions, str):
83 |                 question = questions
84 |             else:
85 |                 question = questions[id]
86 | prompt = f"{self.prompt.format(question)} \n {text}"
87 | prompts.append(prompt)
88 | return prompts
89 |
90 | def __call__(self, texts: Union[str, List[str]], questions: Union[str, List[str]],
91 | labels: List[str] = ['answer'], threshold: float = 0.5,
92 | batch_size: int = 8, **kwargs):
93 | return super().__call__(texts, labels, threshold, batch_size, questions=questions)
94 |
95 | def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
96 | labels: Optional[List[str]]=None, threshold: float =0.5, max_examples: float =-1):
97 | """
98 | Evaluates the model on a specified dataset and computes evaluation metrics.
99 |
100 | Args:
101 | dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
102 | dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
103 | labels (list, optional): List of target labels to consider for classification. Defaults to None (use all).
104 | threshold (float): Confidence threshold for predictions. Defaults to 0.5.
105 | max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
106 |
107 | Returns:
108 | dict: A dictionary containing evaluation metrics such as F1 scores.
109 |
110 | Raises:
111 | NotImplementedError: This method is not implemented for the base question-answering pipeline.
112 | """
113 | raise NotImplementedError("The `evaluate` method is not implemented yet.")
114 |
115 | class GLiNERSquadEvaluator(GLiNERQuestionAnswerer):
116 | def evaluate(self, dataset_id: str = 'rajpurkar/squad_v2', dataset: Optional[Dataset] = None,
117 | labels: Optional[List[str]] = ['answer'], threshold: float = 0.5, max_examples: int = -1):
118 | """
119 | Evaluates the model on a specified dataset and computes evaluation metrics.
120 |
121 | Args:
122 | dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
123 | dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
124 | labels (list, optional): List of target labels to consider for classification. Defaults to ['answer'].
125 | threshold (float): Confidence threshold for predictions. Defaults to 0.5.
126 | max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
127 |
128 | Returns:
129 | dict: A dictionary containing evaluation metrics such as F1 Scores.
130 |
131 | Raises:
132 | ValueError: If neither `dataset_id` nor `dataset` is provided.
133 | """
134 | from evaluate import load
135 |
136 | # Validate input
137 | if not dataset and not dataset_id:
138 | raise ValueError("Either `dataset` or `dataset_id` must be provided.")
139 |
140 | # Load the dataset if not provided
141 | if not dataset:
142 | dataset = load_dataset(dataset_id, split="validation")
143 |
144 | if not isinstance(dataset, Dataset):
145 | dataset = dataset['validation']
146 |
147 | # Truncate dataset if max_examples is specified
148 | if max_examples > 0:
149 | dataset = dataset.shuffle().select(range(min(len(dataset), max_examples)))
150 |
151 | # Load evaluation metric for SQuAD
152 | squad_metric = load("squad_v2" if "squad_v2" in dataset_id else "squad")
153 |
154 | # Prepare predictions and references
155 | contexts = dataset['context']
156 | questions = dataset['question']
157 |
158 | raw_predictions = self(contexts, questions, labels=labels, threshold=threshold)
159 |
160 | predictions = []
161 | references = []
162 | for id, prediction in enumerate(raw_predictions):
163 | example = dataset[id]
164 |
165 | if len(prediction):
166 | predicted_answer = prediction[0]["answer"]
167 | no_answer_probability=0.0
168 | else:
169 | predicted_answer = ""
170 | no_answer_probability=1.0
171 |
172 | # Append to predictions and references
173 | predictions.append({
174 | "id": example["id"],
175 | "prediction_text": predicted_answer,
176 | "no_answer_probability": no_answer_probability
177 | })
178 |
179 | references.append({
180 | "id": example["id"],
181 | "answers": {"text": example["answers"]["text"], "answer_start": example["answers"]["answer_start"]}
182 | })
183 |
184 | # Compute metrics
185 | results = squad_metric.compute(predictions=predictions, references=references)
186 | return results
--------------------------------------------------------------------------------
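A minimal usage sketch for the question-answering pipeline above (not part of the repository; the checkpoint id is an assumed example, and any GLiNER checkpoint trained for prompt-style multitask inference should work):

    from gliner.multitask.question_answering import GLiNERQuestionAnswerer

    # The model id below is an assumption for illustration, not something this file defines.
    answerer = GLiNERQuestionAnswerer(model_id="knowledgator/gliner-multitask-large-v0.5", device="cpu")
    texts = ["GLiNER is a generalist model that extracts arbitrary entity types from plain text."]
    question = "What does GLiNER extract?"
    answers = answerer(texts, question, threshold=0.3)
    # answers[0] is a list of {'answer': ..., 'score': ...} dicts sorted by score, highest first.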
/gliner/multitask/summarization.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Union
2 | import os
3 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
4 | import torch
5 | from datasets import load_dataset, Dataset
6 | from gliner import GLiNER
7 |
8 | from .base import GLiNERBasePipeline
9 |
10 | class GLiNERSummarizer(GLiNERBasePipeline):
11 | """
12 | A class to use GLiNER for summarization inference and evaluation.
13 |
14 | Attributes:
15 | device (str): Device to run the model on, e.g., 'cuda:0' or 'cpu'.
16 | model (GLiNER): Loaded GLiNER model instance.
17 | prompt (str): Template prompt for text summarization.
18 |
19 | Methods:
20 | process_predictions(predictions):
21 | Processes model predictions into extractive summaries.
22 | prepare_texts(texts):
23 | Creates summarization prompts for each input text.
24 | __call__(texts, labels=['summary'], threshold=0.25):
25 | Runs the model on the given texts and returns extractive summaries.
26 | evaluate(dataset_id, labels=None, threshold=0.5, max_examples=-1):
27 | Evaluates the model on a dataset and computes F1 scores.
28 | """
29 |
30 | prompt = "Summarize the following text highlighting the most important information:"
31 |
32 | def __init__(self, model_id: str = None, model: GLiNER = None, device: str = 'cuda:0', prompt: Optional[str] = None):
33 | """
34 | Initializes the GLiNERSummarizer.
35 |
36 | Args:
37 | model_id (str, optional): Identifier for the model to be loaded. Defaults to None.
38 | model (GLiNER, optional): Preloaded GLiNER model. Defaults to None.
39 | device (str, optional): Device to run the model on ('cpu' or 'cuda:X'). Defaults to 'cuda:0'.
40 | prompt (str, optional): Template prompt for summarization.
41 | """
42 | # Use the provided prompt or default to the class-level prompt
43 | prompt = prompt if prompt is not None else self.prompt
44 | super().__init__(model_id=model_id, model=model, prompt=prompt, device=device)
45 |
46 |
47 | def process_predictions(self, predictions, **kwargs):
48 | """
49 | Processes predictions to assemble an extractive summary from the predicted text chunks.
50 |
51 | Args:
52 | predictions (list): List of predictions with scores.
53 |
54 | Returns:
55 | list: One extractive summary string per input text.
56 | """
57 | batch_predicted_labels = []
58 |
59 | for prediction in predictions:
60 | # Sort predictions by start offset so the extracted chunks keep their original order in the text
61 | sorted_predictions = sorted(prediction, key=lambda entity: entity["start"], reverse=False)
62 |
63 | extracted_text = [pred['text'] for pred in sorted_predictions]
64 | batch_predicted_labels.append(' '.join(extracted_text))
65 |
66 | return batch_predicted_labels
67 |
68 | def prepare_texts(self, texts: List[str], **kwargs):
69 | """
70 | Prepares prompts for summarization by appending prompt to texts.
71 |
72 | Args:
73 | texts (list): List of input texts.
74 |
75 | Returns:
76 | list: List of formatted prompts.
77 | """
78 | prompts = []
79 |
80 | for text in texts:
81 | prompt = f"{self.prompt} \n {text}"
82 | prompts.append(prompt)
83 | return prompts
84 |
85 | def __call__(self, texts: Union[str, List[str]], labels: List[str] = ['summary'],
86 | threshold: float = 0.25, batch_size: int = 8, **kwargs):
87 | return super().__call__(texts, labels, threshold, batch_size)
88 |
89 | def evaluate(self, dataset_id: Optional[str] = None, dataset: Optional[Dataset] = None,
90 | labels: Optional[List[str]] = None, threshold: float = 0.5, max_examples: int = -1):
91 | """
92 | Evaluates the model on a specified dataset and computes evaluation metrics.
93 |
94 | Args:
95 | dataset_id (str, optional): Identifier for the dataset to load (e.g., from Hugging Face datasets).
96 | dataset (Dataset, optional): A pre-loaded dataset to evaluate. If provided, `dataset_id` is ignored.
97 | labels (list, optional): List of target labels to consider for summarization. Defaults to None (use all).
98 | threshold (float): Confidence threshold for predictions. Defaults to 0.5.
99 | max_examples (int): Maximum number of examples to evaluate. Defaults to -1 (use all available examples).
100 |
101 | Returns:
102 | dict: A dictionary containing evaluation metrics.
103 |
104 | Raises:
105 | NotImplementedError: This method is not implemented yet.
106 | """
107 | raise NotImplementedError("The `evaluate` method is not implemented yet.")
--------------------------------------------------------------------------------
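A minimal sketch of running GLiNERSummarizer (the checkpoint id is an assumption; any multitask GLiNER model should work):

    from gliner.multitask.summarization import GLiNERSummarizer

    summarizer = GLiNERSummarizer(model_id="knowledgator/gliner-multitask-large-v0.5", device="cpu")
    document = "..."  # any long text to summarize
    summaries = summarizer([document], threshold=0.25)
    # Each summary is extractive: process_predictions re-orders the predicted chunks by
    # start offset and joins them with spaces, so summaries[0] is a single string.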
/gliner/onnx/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/gliner/onnx/__init__.py
--------------------------------------------------------------------------------
/gliner/onnx/model.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, Any
2 | from abc import ABC, abstractmethod
3 | import warnings
4 | import onnxruntime as ort
5 | import numpy as np
6 | import torch
7 |
8 | from ..modeling.base import GLiNERModelOutput
9 |
10 | class BaseORTModel(ABC):
11 | def __init__(self, session: ort.InferenceSession):
12 | self.session = session
13 | self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
14 | self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
15 |
16 | def prepare_inputs(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
17 | """
18 | Prepare inputs for ONNX model inference.
19 |
20 | Args:
21 | inputs (Dict[str, torch.Tensor]): Dictionary of input names and tensors.
22 |
23 | Returns:
24 | Dict[str, np.ndarray]: Dictionary of input names and numpy arrays.
25 | """
26 | if not isinstance(inputs, dict):
27 | raise ValueError("Inputs must be a dictionary of input names and tensors.")
28 |
29 | prepared_inputs = {}
30 | for key, tensor in inputs.items():
31 | if key not in self.input_names:
32 | warnings.warn(f"Input key '{key}' not found in ONNX model's input names. Ignored.")
33 | continue
34 | prepared_inputs[key] = tensor.cpu().detach().numpy()
35 | return prepared_inputs
36 |
37 | def run_inference(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
38 | """
39 | Run the ONNX model inference.
40 |
41 | Args:
42 | inputs (Dict[str, np.ndarray]): Prepared inputs for the model.
43 |
44 | Returns:
45 | Dict[str, np.ndarray]: Model's outputs as numpy arrays.
46 | """
47 | onnx_outputs = self.session.run(None, inputs)
48 | outputs = {name: onnx_outputs[idx] for name, idx in self.output_names.items()}
49 | return outputs
50 |
51 | @abstractmethod
52 | def forward(self, input_ids, attention_mask, **kwargs) -> Dict[str, Any]:
53 | """
54 | Abstract method to perform forward pass. Must be implemented by subclasses.
55 | """
56 | pass
57 |
58 | def __call__(self, *args, **kwargs):
59 | return self.forward(*args, **kwargs)
60 |
61 | class SpanORTModel(BaseORTModel):
62 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
63 | words_mask: torch.Tensor, text_lengths: torch.Tensor,
64 | span_idx: torch.Tensor, span_mask: torch.Tensor, **kwargs) -> Dict[str, Any]:
65 | """
66 | Forward pass for span model using ONNX inference.
67 |
68 | Args:
69 | input_ids (torch.Tensor): Input IDs tensor.
70 | attention_mask (torch.Tensor): Attention mask tensor.
71 | span_idx (torch.Tensor): Span indices tensor.
72 | span_mask (torch.Tensor): Span mask tensor.
73 | **kwargs: Additional arguments.
74 |
75 | Returns:
76 | Dict[str, Any]: Model outputs.
77 | """
78 | inputs = {
79 | 'input_ids': input_ids,
80 | 'attention_mask': attention_mask,
81 | 'words_mask': words_mask,
82 | 'text_lengths': text_lengths,
83 | 'span_idx': span_idx,
84 | 'span_mask': span_mask
85 | }
86 | prepared_inputs = self.prepare_inputs(inputs)
87 | inference_output = self.run_inference(prepared_inputs)
88 | outputs = GLiNERModelOutput(
89 | logits=inference_output['logits']
90 | )
91 | return outputs
92 |
93 | class TokenORTModel(BaseORTModel):
94 | def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
95 | words_mask: torch.Tensor, text_lengths: torch.Tensor,
96 | **kwargs) -> Dict[str, Any]:
97 | """
98 | Forward pass for token model using ONNX inference.
99 |
100 | Args:
101 | input_ids (torch.Tensor): Input IDs tensor.
102 | attention_mask (torch.Tensor): Attention mask tensor.
103 | **kwargs: Additional arguments.
104 |
105 | Returns:
106 | Dict[str, Any]: Model outputs.
107 | """
108 | inputs = {
109 | 'input_ids': input_ids,
110 | 'attention_mask': attention_mask,
111 | 'words_mask': words_mask,
112 | 'text_lengths': text_lengths,
113 | }
114 | prepared_inputs = self.prepare_inputs(inputs)
115 | inference_output = self.run_inference(prepared_inputs)
116 | outputs = GLiNERModelOutput(
117 | logits=inference_output['logits']
118 | )
119 | return outputs
--------------------------------------------------------------------------------
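A minimal sketch of wrapping an exported ONNX session with SpanORTModel; the model path is a placeholder and `batch` is assumed to be a dict of torch tensors produced by the span data processor:

    import onnxruntime as ort
    from gliner.onnx.model import SpanORTModel

    session = ort.InferenceSession("gliner_span_model.onnx", providers=["CPUExecutionProvider"])
    ort_model = SpanORTModel(session)
    outputs = ort_model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        words_mask=batch["words_mask"],
        text_lengths=batch["text_lengths"],
        span_idx=batch["span_idx"],
        span_mask=batch["span_mask"],
    )
    # outputs.logits holds the raw numpy logits returned by onnxruntime.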
/gliner/training/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer, TrainingArguments
--------------------------------------------------------------------------------
/gliner/training/trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Any, Dict, Tuple, List
2 | from dataclasses import dataclass, field
3 |
4 | import torch
5 | import transformers
6 |
7 | from transformers.training_args import OptimizerNames
8 | from transformers.trainer import (
9 | is_sagemaker_mp_enabled,
10 | get_parameter_names,
11 | ALL_LAYERNORM_LAYERS,
12 | )
13 | from transformers.trainer_utils import seed_worker
14 |
15 | if transformers.utils.is_apex_available():
16 | from apex import amp
17 |
18 | if is_sagemaker_mp_enabled():
19 | from transformers.trainer_pt_utils import smp_forward_backward
20 | from torch.utils.data import DataLoader, Dataset
21 |
22 | @dataclass
23 | class TrainingArguments(transformers.TrainingArguments):
24 | cache_dir: Optional[str] = field(default=None)
25 | optim: str = field(default="adamw_torch")
26 | others_lr: Optional[float] = None
27 | others_weight_decay: Optional[float] = 0.0
28 | focal_loss_alpha: Optional[float] = -1
29 | focal_loss_gamma: Optional[float] = 0
30 | label_smoothing: Optional[float] = 0
31 | loss_reduction: Optional[str] = 'sum'
32 | negatives: Optional[float] = 1.0
33 | masking: Optional[str] = 'global'
34 |
35 | class Trainer(transformers.Trainer):
36 | def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
37 | """
38 | Perform a training step on a batch of inputs.
39 |
40 | Subclass and override to inject custom behavior.
41 |
42 | Args:
43 | model (`nn.Module`):
44 | The model to train.
45 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
46 | The inputs and targets of the model.
47 |
48 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
49 | argument `labels`. Check your model's documentation for all accepted arguments.
50 |
51 | Return:
52 | `torch.Tensor`: The tensor with training loss on this batch.
53 | """
54 | model.train()
55 | try:
56 | inputs = self._prepare_inputs(inputs)
57 | if is_sagemaker_mp_enabled():
58 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
59 | return loss_mb.reduce_mean().detach().to(self.args.device)
60 |
61 | with self.compute_loss_context_manager():
62 | loss = self.compute_loss(model, inputs)
63 |
64 | del inputs
65 | torch.cuda.empty_cache()
66 |
67 | kwargs = {}
68 |
69 | # For LOMO optimizers you need to explicitly use the learning rate
70 | # if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
71 | # kwargs["learning_rate"] = self._get_learning_rate()
72 |
73 | if self.args.n_gpu > 1:
74 | loss = loss.mean() # mean() to average on multi-gpu parallel training
75 |
76 | if self.use_apex:
77 | with amp.scale_loss(loss, self.optimizer) as scaled_loss:
78 | scaled_loss.backward()
79 | else:
80 | self.accelerator.backward(loss, **kwargs)
81 |
82 | return loss.detach() / self.args.gradient_accumulation_steps
83 | except Exception as e:
84 | print(f"Skipping iteration due to error: {e}")
85 | model.zero_grad(set_to_none=True)
86 | torch.cuda.empty_cache()
87 | return torch.tensor(0.0, requires_grad=True).to(model.device)
88 |
89 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
90 | self.model.save_pretrained(output_dir if output_dir is not None else self.args.output_dir)
91 |
92 | def compute_loss(self, model, inputs):
93 | """
94 | Override compute_loss to use a custom loss function.
95 | """
96 | # Forward pass
97 | outputs = model(alpha=self.args.focal_loss_alpha,
98 | gamma=self.args.focal_loss_gamma,
99 | label_smoothing=self.args.label_smoothing,
100 | reduction=self.args.loss_reduction,
101 | negatives=self.args.negatives,
102 | masking=self.args.masking,
103 | **inputs)
104 | loss = outputs.loss
105 | return loss
106 |
107 | def create_optimizer(self):
108 | """
109 | Setup the optimizer.
110 |
111 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
112 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
113 | """
114 | if is_sagemaker_mp_enabled():
115 | return super().create_optimizer()
116 |
117 | opt_model = self.model
118 |
119 | if self.optimizer is None:
120 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
121 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
122 | if self.args.others_lr is not None:
123 | encoder_parameters = [name for name, _ in opt_model.named_parameters() if "token_rep_layer" in name]
124 | optimizer_grouped_parameters = [
125 | {
126 | "params": [
127 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in encoder_parameters and p.requires_grad)
128 | ],
129 | "weight_decay": self.args.others_weight_decay,
130 | "lr": self.args.others_lr,
131 | },
132 | {
133 | "params": [
134 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in encoder_parameters and p.requires_grad)
135 | ],
136 | "weight_decay": 0.0,
137 | "lr": self.args.others_lr,
138 | },
139 | {
140 | "params": [
141 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in encoder_parameters and p.requires_grad)
142 | ],
143 | "weight_decay": self.args.weight_decay,
144 | },
145 | {
146 | "params": [
147 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in encoder_parameters and p.requires_grad)
148 | ],
149 | "weight_decay": 0.0,
150 | },
151 | ]
152 | else:
153 | optimizer_grouped_parameters = [
154 | {
155 | "params": [
156 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
157 | ],
158 | "weight_decay": self.args.weight_decay,
159 | },
160 | {
161 | "params": [
162 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
163 | ],
164 | "weight_decay": 0.0,
165 | },
166 | ]
167 |
168 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
169 |
170 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
171 |
172 | return self.optimizer
173 |
174 | def prediction_step(
175 | self,
176 | model: torch.nn.Module,
177 | inputs: Dict[str, Union[torch.Tensor, Any]],
178 | prediction_loss_only: bool,
179 | ignore_keys: Optional[List[str]] = None,
180 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
181 | """
182 | Perform an evaluation step on model using inputs.
183 |
184 | Subclass and override to inject custom behavior.
185 |
186 | Args:
187 | model (nn.Module):
188 | The model to evaluate.
189 | inputs (Dict[str, Union[torch.Tensor, Any]]):
190 | The inputs and targets of the model.
191 |
192 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
193 | argument labels. Check your model's documentation for all accepted arguments.
194 | prediction_loss_only (bool):
195 | Whether or not to return the loss only.
196 | ignore_keys (List[str], *optional*):
197 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
198 | gathering predictions.
199 |
200 | Return:
201 | Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
202 | logits and labels (each being optional).
203 | """
204 | with torch.no_grad():
205 | loss = None
206 | with self.compute_loss_context_manager():
207 | outputs = model(**inputs)
208 | loss = outputs.loss
209 | logits = outputs.logits
210 | labels = inputs['labels']
211 | if prediction_loss_only:
212 | return (loss, None, None)
213 | return (loss, logits, labels)
214 |
215 |
216 | def get_train_dataloader(self) -> DataLoader:
217 | """
218 | Returns the training [`~torch.utils.data.DataLoader`].
219 |
220 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
221 | training if necessary) otherwise.
222 |
223 | Subclass and override this method if you want to inject some custom behavior.
224 | """
225 | if self.train_dataset is None:
226 | raise ValueError("Trainer: training requires a train_dataset.")
227 |
228 | train_dataset = self.train_dataset
229 | data_collator = self.data_collator
230 |
231 | dataloader_params = {
232 | "batch_size": self._train_batch_size,
233 | "collate_fn": data_collator,
234 | "num_workers": self.args.dataloader_num_workers,
235 | "pin_memory": self.args.dataloader_pin_memory,
236 | "persistent_workers": self.args.dataloader_persistent_workers,
237 | }
238 |
239 | if not isinstance(train_dataset, torch.utils.data.IterableDataset):
240 | dataloader_params["sampler"] = self._get_train_sampler()
241 | dataloader_params["drop_last"] = self.args.dataloader_drop_last
242 | dataloader_params["worker_init_fn"] = seed_worker
243 | dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
244 |
245 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
246 |
247 | def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
248 | """
249 | Returns the evaluation [`~torch.utils.data.DataLoader`].
250 |
251 | Subclass and override this method if you want to inject some custom behavior.
252 |
253 | Args:
254 | eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
255 | If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
256 | """
257 | if eval_dataset is None and self.eval_dataset is None:
258 | raise ValueError("Trainer: evaluation requires an eval_dataset.")
259 |
260 | # If we have persistent workers, don't do a fork bomb especially as eval datasets
261 | # don't change during training
262 | dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
263 | if (
264 | hasattr(self, "_eval_dataloaders")
265 | and dataloader_key in self._eval_dataloaders
266 | and self.args.dataloader_persistent_workers
267 | ):
268 | return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
269 |
270 | eval_dataset = (
271 | self.eval_dataset[eval_dataset]
272 | if isinstance(eval_dataset, str)
273 | else eval_dataset
274 | if eval_dataset is not None
275 | else self.eval_dataset
276 | )
277 | data_collator = self.data_collator
278 |
279 | dataloader_params = {
280 | "batch_size": self.args.eval_batch_size,
281 | "collate_fn": data_collator,
282 | "num_workers": self.args.dataloader_num_workers,
283 | "pin_memory": self.args.dataloader_pin_memory,
284 | "persistent_workers": self.args.dataloader_persistent_workers,
285 | }
286 |
287 | if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
288 | dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
289 | dataloader_params["drop_last"] = self.args.dataloader_drop_last
290 | dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
291 |
292 | # accelerator.free_memory() will destroy the references, so
293 | # we need to store the non-prepared version
294 | eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
295 | if self.args.dataloader_persistent_workers:
296 | if hasattr(self, "_eval_dataloaders"):
297 | self._eval_dataloaders[dataloader_key] = eval_dataloader
298 | else:
299 | self._eval_dataloaders = {dataloader_key: eval_dataloader}
300 |
301 | return self.accelerator.prepare(eval_dataloader)
302 |
--------------------------------------------------------------------------------
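When others_lr is set, create_optimizer above splits the parameters into two groups: names containing "token_rep_layer" (the text encoder) keep the regular learning_rate/weight_decay, while every other module is trained with others_lr/others_weight_decay. A configuration sketch with illustrative values only:

    from gliner.training import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir="models/",
        learning_rate=1e-5,        # used for "token_rep_layer" (encoder) parameters
        weight_decay=0.01,
        others_lr=5e-5,            # used for all remaining parameters
        others_weight_decay=0.01,
        report_to="none",
    )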
/gliner/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import warnings
3 | import yaml
4 |
5 | def load_config_as_namespace(config_file):
6 | with open(config_file, "r") as f:
7 | config_dict = yaml.safe_load(f)
8 | return argparse.Namespace(**config_dict)
9 |
10 | def is_module_available(module_name):
11 | """
12 | Checks whether the specified Python module is available.
13 |
14 | Args:
15 | module_name (str): The name of the module to check.
16 |
17 | Returns:
18 | bool: True if the module is available, False otherwise.
19 | """
20 | try:
21 | __import__(module_name)
22 | return True
23 | except ImportError:
24 | return False
25 |
26 | class MissedPackageException(Exception):
27 | """Raised when the requested decoder model is not supported."""
28 | pass
--------------------------------------------------------------------------------
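A small sketch of the helpers above; configs/config.yaml exists in this repository, but which keys the resulting namespace carries depends on the chosen YAML file:

    from gliner.utils import load_config_as_namespace, is_module_available

    config = load_config_as_namespace("configs/config.yaml")
    print(type(config))                  # argparse.Namespace, one attribute per top-level YAML key
    print(is_module_available("torch"))  # True if the module can be imported, False otherwise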
/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/image.png
--------------------------------------------------------------------------------
/logo/FI Group.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/logo/FI Group.png
--------------------------------------------------------------------------------
/logo/FI_COMPLET_CW.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urchade/GLiNER/efbfa38211136657895372d33d4ee2fe11b6f11b/logo/FI_COMPLET_CW.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.setuptools.packages.find]
6 | include = ["gliner", "gliner.*"]
7 |
8 | [tool.setuptools.dynamic]
9 | version = {attr = "gliner.__version__"}
10 |
11 | [project]
12 | name = "gliner"
13 | description = "Generalist model for NER (Extract any entity types from texts)"
14 | readme = "README.md"
15 | requires-python = ">=3.8"
16 | license = {text = "Apache-2.0"}
17 | keywords = [
18 | "named-entity-recognition",
19 | "ner",
20 | "data-science",
21 | "natural-language-processing",
22 | "artificial-intelligence",
23 | "nlp",
24 | "machine-learning",
25 | "transformers"
26 | ]
27 | authors = [
28 | {name = "Urchade Zaratiana"},
29 | {name = "Nadi Tomeh"},
30 | {name = "Pierre Holat"},
31 | {name = "Thierry Charnois"},
32 | ]
33 | maintainers = [
34 | {name = "Urchade Zaratiana"},
35 | ]
36 |
37 | dependencies = [
38 | "torch>=2.0.0",
39 | "transformers>=4.38.2",
40 | "huggingface_hub>=0.21.4",
41 | "tqdm",
42 | "onnxruntime",
43 | "sentencepiece",
44 | ]
45 |
46 | dynamic = ["version"]
47 |
48 | [project.optional-dependencies]
49 | gpu = ["onnxruntime-gpu"]
50 |
51 |
52 | [project.urls]
53 | Homepage = "https://github.com/urchade/GLiNER"
54 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
2 | transformers>=4.38.2,<=4.45.2
3 | huggingface_hub>=0.21.4
4 | onnxruntime-gpu
5 | sentencepiece
6 | tqdm
7 |
--------------------------------------------------------------------------------
/tests/test_features_selection.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from transformers import AutoTokenizer
4 | from gliner import GLiNERConfig
5 | from gliner.modeling.base import extract_prompt_features_and_word_embeddings
6 | from gliner.data_processing import SpanProcessor, WordsSplitter
7 |
8 | class TestFeaturesExtractor:
9 | @pytest.fixture(autouse=True)
10 | def setup(self):
11 | self.config = GLiNERConfig()
12 | self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
13 | self.config.class_token_index=len(self.tokenizer)
14 | self.tokenizer.add_tokens([self.config.ent_token, self.config.sep_token])
15 | self.splitter = WordsSplitter()
16 | self.base_tokens = [['Hello', 'world', '!']]
17 | self.tokens_with_missed = [['Hello', '', 'world', '']]
18 | self.labels = ['world']
19 | self.processor = SpanProcessor(self.config, self.tokenizer, self.splitter)
20 |
21 | def test_base_extraction(self):
22 | input_x = [{"tokenized_text": tk, "ner": None} for tk in self.base_tokens]
23 | raw_batch = self.processor.collate_raw_batch(input_x, self.labels)
24 | model_input = self.processor.collate_fn(raw_batch, prepare_labels=False)
25 | model_input['text_lengths'] = raw_batch['seq_length']
26 | token_embeds = torch.rand(model_input['words_mask'].shape + (self.config.hidden_size,))
27 |
28 | (prompts_embedding,
29 | prompts_embedding_mask,
30 | words_embedding,
31 | mask) = extract_prompt_features_and_word_embeddings(self.config, token_embeds, **model_input)
32 |
33 | assert prompts_embedding_mask.shape == (1, 1)
34 | assert prompts_embedding.shape == (1, 1, self.config.hidden_size)
35 | assert words_embedding.shape == (1, len(self.base_tokens[0]), self.config.hidden_size)
36 |
37 | def test_extraction_with_missed_tokens(self):
38 | input_x = [{"tokenized_text": tk, "ner": None} for tk in self.tokens_with_missed]
39 | raw_batch = self.processor.collate_raw_batch(input_x, self.labels)
40 | model_input = self.processor.collate_fn(raw_batch, prepare_labels=False)
41 | model_input['text_lengths'] = raw_batch['seq_length']
42 | token_embeds = torch.rand(model_input['words_mask'].shape + (self.config.hidden_size,))
43 |
44 | (prompts_embedding,
45 | prompts_embedding_mask,
46 | words_embedding,
47 | mask) = extract_prompt_features_and_word_embeddings(self.config, token_embeds, **model_input)
48 |
49 | assert prompts_embedding_mask.shape == (1, 1)
50 | assert prompts_embedding.shape == (1, 1, self.config.hidden_size)
51 | assert words_embedding.shape == (1, len(self.tokens_with_missed[0]), self.config.hidden_size)
52 |
53 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | from gliner import GLiNER
2 |
3 |
4 | def test_span_model():
5 | model = GLiNER.from_pretrained("gliner-community/gliner_small-v2.5")
6 |
7 | text = """
8 | Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
9 | """
10 |
11 | labels = ["person", "award", "date", "competitions", "teams", "person"]
12 |
13 | entities = model.predict_entities(text, labels)
14 |
15 | assert len(entities) > 0
16 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
3 | import argparse
4 | import random
5 | import json
6 |
7 | from transformers import AutoTokenizer
8 | import torch
9 |
10 | from gliner import GLiNERConfig, GLiNER
11 | from gliner.training import Trainer, TrainingArguments
12 | from gliner.data_processing.collator import DataCollatorWithPadding, DataCollator
13 | from gliner.utils import load_config_as_namespace
14 | from gliner.data_processing import WordsSplitter, GLiNERDataset
15 |
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--config', type=str, default= "configs/config.yaml")
20 | parser.add_argument('--log_dir', type=str, default = 'models/')
21 | parser.add_argument('--compile_model', type=bool, default = False)
22 | parser.add_argument('--freeze_language_model', type=bool, default = False)
23 | parser.add_argument('--new_data_schema', type=bool, default = False)
24 | args = parser.parse_args()
25 |
26 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
27 |
28 | config = load_config_as_namespace(args.config)
29 | config.log_dir = args.log_dir
30 |
31 | with open(config.train_data, 'r') as f:
32 | data = json.load(f)
33 |
34 | print('Dataset size:', len(data))
35 | #shuffle
36 | random.shuffle(data)
37 | print('Dataset is shuffled...')
38 |
39 | train_data = data[:int(len(data)*0.9)]
40 | test_data = data[int(len(data)*0.9):]
41 |
42 | print('Dataset is split...')
43 |
44 |
45 | if config.prev_path is not None:
46 | tokenizer = AutoTokenizer.from_pretrained(config.prev_path)
47 | model = GLiNER.from_pretrained(config.prev_path)
48 | model_config = model.config
49 | else:
50 | model_config = GLiNERConfig(**vars(config))
51 | tokenizer = AutoTokenizer.from_pretrained(model_config.model_name)
52 |
53 | words_splitter = WordsSplitter(model_config.words_splitter_type)
54 |
55 | model = GLiNER(model_config, tokenizer=tokenizer, words_splitter=words_splitter)
56 |
57 | if not config.labels_encoder:
58 | model_config.class_token_index=len(tokenizer)
59 | tokenizer.add_tokens([model_config.ent_token, model_config.sep_token], special_tokens=True)
60 | model_config.vocab_size = len(tokenizer)
61 | model.resize_token_embeddings([model_config.ent_token, model_config.sep_token],
62 | set_class_token_index = False,
63 | add_tokens_to_tokenizer=False)
64 |
65 | if args.compile_model:
66 | torch.set_float32_matmul_precision('high')
67 | model.to(device)
68 | model.compile_for_training()
69 |
70 | if args.freeze_language_model:
71 | model.model.token_rep_layer.bert_layer.model.requires_grad_(False)
72 | else:
73 | model.model.token_rep_layer.bert_layer.model.requires_grad_(True)
74 |
75 | if args.new_data_schema:
76 | train_dataset = GLiNERDataset(train_data, model_config, tokenizer, words_splitter)
77 | test_dataset = GLiNERDataset(test_data, model_config, tokenizer, words_splitter)
78 | data_collator = DataCollatorWithPadding(model_config)
79 | else:
80 | train_dataset = train_data
81 | test_dataset = test_data
82 | data_collator = DataCollator(model.config, data_processor=model.data_processor, prepare_labels=True)
83 |
84 | training_args = TrainingArguments(
85 | output_dir=config.log_dir,
86 | learning_rate=float(config.lr_encoder),
87 | weight_decay=float(config.weight_decay_encoder),
88 | others_lr=float(config.lr_others),
89 | others_weight_decay=float(config.weight_decay_other),
90 | focal_loss_gamma=config.loss_gamma,
91 | focal_loss_alpha=config.loss_alpha,
92 | lr_scheduler_type=config.scheduler_type,
93 | warmup_ratio=config.warmup_ratio,
94 | per_device_train_batch_size=config.train_batch_size,
95 | per_device_eval_batch_size=config.train_batch_size,
96 | max_grad_norm=config.max_grad_norm,
97 | max_steps=config.num_steps,
98 | evaluation_strategy="epoch",
99 | save_steps = config.eval_every,
100 | save_total_limit=config.save_total_limit,
101 | dataloader_num_workers = 8,
102 | use_cpu = False,
103 | report_to="none",
104 | bf16=True,
105 | )
106 |
107 | trainer = Trainer(
108 | model=model,
109 | args=training_args,
110 | train_dataset=train_dataset,
111 | eval_dataset=test_dataset,
112 | tokenizer=tokenizer,
113 | data_collator=data_collator,
114 | )
115 | trainer.train()
116 |
--------------------------------------------------------------------------------
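A typical invocation of the training script above, using the repository defaults. Note that --compile_model, --freeze_language_model and --new_data_schema are declared with argparse's type=bool, so any non-empty value (even "False") parses as True; omit a flag entirely to keep it disabled:

    python train.py --config configs/config.yaml --log_dir models/
    python train.py --config configs/config.yaml --compile_model true --freeze_language_model true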