├── .gitignore
├── LICENSE
├── README.md
├── demo.py
├── gliclass
│   ├── __init__.py
│   ├── config.py
│   ├── data_processing.py
│   ├── layers.py
│   ├── loss_functions.py
│   ├── model.py
│   ├── pipeline.py
│   ├── poolings.py
│   ├── scorers.py
│   ├── training.py
│   └── utils.py
├── notebooks
│   └── finetuning.ipynb
├── pyproject.toml
├── test_gliclass.py
├── train.py
└── train_rl.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | #custom
10 | models/
11 | wandb/
12 | gradio_cached_examples/
13 | test.ipynb
14 | demo1.py
15 | .gradio/
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 |
112 | # pdm
113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114 | #pdm.lock
115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116 | # in version control.
117 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
118 | .pdm.toml
119 | .pdm-python
120 | .pdm-build/
121 |
122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123 | __pypackages__/
124 |
125 | # Celery stuff
126 | celerybeat-schedule
127 | celerybeat.pid
128 |
129 | # SageMath parsed files
130 | *.sage.py
131 |
132 | # Environments
133 | .env
134 | .venv
135 | env/
136 | venv/
137 | ENV/
138 | env.bak/
139 | venv.bak/
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 |
156 | # Pyre type checker
157 | .pyre/
158 |
159 | # pytype static type analyzer
160 | .pytype/
161 |
162 | # Cython debug symbols
163 | cython_debug/
164 |
165 | # PyCharm
166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168 | # and can be added to the global gitignore or merged into this file. For a more nuclear
169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170 | #.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ⭐ GLiClass: Generalist and Lightweight Model for Sequence Classification
2 |
3 | **GLiClass** is an efficient, zero-shot sequence classification model inspired by the [GLiNER](https://github.com/urchade/GLiNER/tree/main) framework. It achieves performance comparable to traditional cross-encoder models while being significantly more computationally efficient: by performing classification in a single forward pass, it delivers results approximately **10 times faster**.
4 |
5 |
6 | 📄 Blog
7 | •
8 | 📢 Discord
9 | •
10 | 📺 Demo
11 | •
12 | 🤗 Available models
13 |
14 |
15 |
16 |
17 |
18 |
19 | ### 🚀 Quick Start
20 |
21 | Install GLiClass easily using pip:
22 |
23 | ```bash
24 | pip install gliclass
25 | ```
26 |
27 | #### Install from Source
28 |
29 | Clone and install directly from GitHub:
30 |
31 | ```bash
32 | git clone https://github.com/Knowledgator/GLiClass
33 | cd GLiClass
34 |
35 | python -m venv venv
36 | source venv/bin/activate # Windows: venv\Scripts\activate
37 |
38 | # dependencies are declared in pyproject.toml
39 | pip install .
40 | ```
41 |
42 | Verify your installation:
43 |
44 | ```python
45 | from importlib.metadata import version
46 | print(version("gliclass"))
47 | ```
48 |
49 | ### 🧑‍💻 Usage Example
50 |
51 | ```python
52 | from gliclass import GLiClassModel, ZeroShotClassificationPipeline
53 | from transformers import AutoTokenizer
54 |
55 | model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
56 | tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")
57 |
58 | pipeline = ZeroShotClassificationPipeline(
59 | model, tokenizer, classification_type='multi-label', device='cuda:0'
60 | )
61 |
62 | text = "One day I will see the world!"
63 | labels = ["travel", "dreams", "sport", "science", "politics"]
64 | results = pipeline(text, labels, threshold=0.5)[0]
65 |
66 | for result in results:
67 | print(f"{result['label']} => {result['score']:.3f}")
68 | ```
69 |
70 | ### 🌟 Retrieval-Augmented Classification (RAC)
71 |
72 | With new models trained with retrieval-augmented classification, such as [this model](https://huggingface.co/knowledgator/gliclass-base-v2.0-rac-init), you can provide examples to improve classification accuracy:
73 |
74 | ```python
75 | example = {
76 | "text": "A new machine learning platform automates complex data workflows but faces integration issues.",
77 | "all_labels": ["AI", "automation", "data_analysis", "usability", "integration"],
78 | "true_labels": ["AI", "integration", "automation"]
79 | }
80 |
81 | text = "The new AI-powered tool streamlines data analysis but has limited integration capabilities."
82 | labels = ["AI", "automation", "data_analysis", "usability", "integration"]
83 |
84 | results = pipeline(text, labels, threshold=0.1, rac_examples=[example])[0]
85 |
86 | for result in results:
87 |     print(f"{result['label']} => {result['score']:.3f}")
88 | ```
89 |
90 | ### 🎯 Key Use Cases
91 |
92 | - **Sentiment Analysis:** Rapidly classify texts as positive, negative, or neutral.
93 | - **Document Classification:** Efficiently organize and categorize large document collections.
94 | - **Search Results Re-ranking:** Improve relevance and precision by reranking search outputs.
95 | - **News Categorization:** Automatically tag and organize news articles into predefined categories.
96 | - **Fact Checking:** Quickly validate and categorize statements based on factual accuracy.
97 |
98 | ### 🛠️ How to Train
99 |
100 | Prepare your training data as follows:
101 |
102 | ```json
103 | [
104 | {"text": "Sample text.", "all_labels": ["sports", "science", "business"], "true_labels": ["sports"]},
105 | ...
106 | ]
107 | ```
108 |
109 | Optionally, specify confidence scores explicitly:
110 |
111 | ```json
112 | [
113 | {"text": "Sample text.", "all_labels": ["sports", "science"], "true_labels": {"sports": 0.9}},
114 | ...
115 | ]
116 | ```
117 |
118 | Please refer to the `train.py` script to set up training from scratch or to fine-tune existing models; a data-loading sketch follows below.
119 |
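120 | The snippet below is a minimal sketch of loading such a JSON file and preparing padded batches with the `GLiClassDataset` and `DataCollatorWithPadding` utilities from `gliclass.data_processing` (the file name and model choice are illustrative); see `train.py` for the complete training loop.
121 |
122 | ```python
123 | import json
124 | from transformers import AutoTokenizer
125 | from gliclass import GLiClassModel
126 | from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding
127 |
128 | model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
129 | tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")
130 |
131 | with open("train_data.json") as f:  # your prepared dataset file
132 |     examples = json.load(f)
133 |
134 | dataset = GLiClassDataset(
135 |     examples,
136 |     tokenizer,
137 |     problem_type="multi_label_classification",
138 |     architecture_type=model.config.architecture_type,
139 |     prompt_first=model.config.prompt_first,
140 | )
141 | collator = DataCollatorWithPadding(device="cpu")
142 | batch = collator([dataset[i] for i in range(2)])  # a padded toy batch
143 | ```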
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union
2 | import gradio as gr
3 | import torch
4 | from transformers import AutoTokenizer
5 |
6 | from gliclass import GLiClassModel, ZeroShotClassificationPipeline
7 |
8 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9 |
10 | model_path = "models/checkpoint-1000"
11 | model = GLiClassModel.from_pretrained(model_path)
12 | tokenizer = AutoTokenizer.from_pretrained(model_path)
13 |
14 |
15 | pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device=device)
16 |
17 | text1 = """
18 | "I recently purchased the Sony WH-1000XM4 Wireless Noise-Canceling Headphones from Amazon and I must say, I'm thoroughly impressed. The package arrived in New York within 2 days, thanks to Amazon Prime's expedited shipping.
19 |
20 | The headphones themselves are remarkable. The noise-canceling feature works like a charm in the bustling city environment, and the 30-hour battery life means I don't have to charge them every day. Connecting them to my Samsung Galaxy S21 was a breeze, and the sound quality is second to none.
21 |
22 | I also appreciated the customer service from Amazon when I had a question about the warranty. They responded within an hour and provided all the information I needed.
23 |
24 | However, the headphones did not come with a hard case, which was listed in the product description. I contacted Amazon, and they offered a 10% discount on my next purchase as an apology.
25 |
26 | Overall, I'd give these headphones a 4.5/5 rating and highly recommend them to anyone looking for top-notch quality in both product and service."""
27 |
28 | text2 = """
29 | Apple Inc. is an American multinational technology company headquartered in Cupertino, California. Apple is the world's largest technology company by revenue, with US$394.3 billion in 2022 revenue. As of March 2023, Apple is the world's biggest company by market capitalization. As of June 2022, Apple is the fourth-largest personal computer vendor by unit sales and the second-largest mobile phone manufacturer in the world. It is considered one of the Big Five American information technology companies, alongside Alphabet (parent company of Google), Amazon, Meta Platforms, and Microsoft.
30 | Microsoft was founded by Bill Gates and Paul Allen on April 4, 1975 to develop and sell BASIC interpreters for the Altair 8800. During his career at Microsoft, Gates held the positions of chairman, chief executive officer, president and chief software architect, while also being the largest individual shareholder until May 2014.
31 | Apple was founded as Apple Computer Company on April 1, 1976, by Steve Wozniak, Steve Jobs (1955–2011) and Ronald Wayne to develop and sell Wozniak's Apple I personal computer. It was incorporated by Jobs and Wozniak as Apple Computer, Inc. in 1977. The company's second computer, the Apple II, became a best seller and one of the first mass-produced microcomputers. Apple went public in 1980 to instant financial success. The company developed computers featuring innovative graphical user interfaces, including the 1984 original Macintosh, announced that year in a critically acclaimed advertisement called "1984". By 1985, the high cost of its products, and power struggles between executives, caused problems. Wozniak stepped back from Apple and pursued other ventures, while Jobs resigned and founded NeXT, taking some Apple employees with him.
32 | """
33 |
34 | text3 = """
35 | Several studies have reported its pharmacological activities, including anti-inflammatory, antimicrobial, and antitumoral effects.
36 | The effect of E-anethole was studied in the osteosarcoma MG-63 cell line, and the antiproliferative activity was evaluated by an MTT assay.
37 | It showed a GI50 value of 60.25 μM with apoptosis induction through the mitochondrial-mediated pathway. Additionally, it induced cell cycle arrest at the G0/G1 phase, up-regulated the expression of p53, caspase-3, and caspase-9, and down-regulated Bcl-xL expression.
38 | Moreover, the antitumoral activity of anethole was assessed against oral tumor Ca9-22 cells, and the cytotoxic effects were evaluated by MTT and LDH assays.
39 | It demonstrated a LD50 value of 8 μM, and cellular proliferation was 42.7% and 5.2% at anethole concentrations of 3 μM and 30 μM, respectively.
40 | It was reported that it could selectively and in a dose-dependent manner decrease cell proliferation and induce apoptosis, as well as induce autophagy, decrease ROS production, and increase glutathione activity. The cytotoxic effect was mediated through NF-kB, MAP kinases, Wnt, caspase-3 and -9, and PARP1 pathways. Additionally, treatment with anethole inhibited cyclin D1 oncogene expression, increased cyclin-dependent kinase inhibitor p21WAF1, up-regulated p53 expression, and inhibited the EMT markers.
41 | """
42 | examples = [
43 | [
44 | text1,
45 |         "product review, sport, competition, electronics, positive feedback, negative feedback",
46 | 0.5,
47 | False
48 | ],
49 | [
50 | text2,
51 | "business, computers, sport, politics, science",
52 | 0.5,
53 | False
54 | ],
55 | [
56 | text3,
57 | "business, biology, science, politics, positive review",
58 | 0.5,
59 | False
60 | ],
61 | ]
62 |
63 | def classification(
64 | text, labels: str, threshold: float, multi_label: bool = False
65 | ) -> str:
66 | labels = labels.split(",")
67 | if multi_label:
68 | pipeline.pipe.classification_type = 'multi-label'
69 | else:
70 | pipeline.pipe.classification_type = 'single-label'
71 |
72 | results = pipeline(text, labels, threshold=threshold)[0] #because we have one text
73 |
74 | predicts = {result['label']:float(result['score']) for result in results}
75 | # predicts = '\n'.join([f"{result['label']} => {result['score']}" for result in results])
76 | return predicts
77 |
78 |
79 | with gr.Blocks(title="GLiClass-small-v1.0") as demo:
80 | with gr.Accordion("How to run this model locally", open=False):
81 | gr.Markdown(
82 | """
83 | ## Installation
84 | To use this model, you must install the GLiClass Python library:
85 | ```
86 | !pip install gliclass
87 | ```
88 |
89 | ## Usage
90 | Once you've downloaded the GLiClass library, you can import the GLiClassModel and ZeroShotClassificationPipeline classes.
91 | """
92 | )
93 | gr.Code(
94 | '''
95 | from gliclass import GLiClassModel, ZeroShotClassificationPipeline
96 | from transformers import AutoTokenizer
97 |
98 | model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1")
99 | tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1")
100 |
101 | pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
102 |
103 | text = "One day I will see the world!"
104 | labels = ["travel", "dreams", "sport", "science", "politics"]
105 | results = pipeline(text, labels, threshold=0.5)[0] #because we have one text
106 |
107 | for result in results:
108 | print(result["label"], "=>", result["score"])
109 | ''',
110 | language="python",
111 | )
112 |
113 | input_text = gr.Textbox(
114 | value=examples[0][0], label="Text input", placeholder="Enter your text here"
115 | )
116 | with gr.Row() as row:
117 | labels = gr.Textbox(
118 | value=examples[0][1],
119 | label="Labels",
120 | placeholder="Enter your labels here (comma separated)",
121 | scale=2,
122 | )
123 | threshold = gr.Slider(
124 | 0,
125 | 1,
126 | value=0.3,
127 | step=0.01,
128 | label="Threshold",
129 |         info="Lower the threshold to increase how many labels get predicted.",
130 | scale=1,
131 | )
132 | multi_label = gr.Checkbox(
133 |         value=examples[0][3],
134 | label="Multi-label classification",
135 | info="Allow for multi-label classification?",
136 | scale=0,
137 | )
138 | output = gr.Label(label="Output", color="#4b5563")
139 | submit_btn = gr.Button("Submit")
140 | examples = gr.Examples(
141 | examples,
142 | fn=classification,
143 | inputs=[input_text, labels, threshold, multi_label],
144 | outputs=output,
145 | cache_examples=True,
146 | )
147 |
148 | # Submitting
149 | input_text.submit(
150 | fn=classification, inputs=[input_text, labels, threshold, multi_label], outputs=output
151 | )
152 | labels.submit(
153 | fn=classification, inputs=[input_text, labels, threshold, multi_label], outputs=output
154 | )
155 | threshold.release(
156 | fn=classification, inputs=[input_text, labels, threshold, multi_label], outputs=output
157 | )
158 | submit_btn.click(
159 | fn=classification, inputs=[input_text, labels, threshold, multi_label], outputs=output
160 | )
161 | multi_label.change(
162 | fn=classification, inputs=[input_text, labels, threshold, multi_label], outputs=output
163 | )
164 |
165 | demo.queue()
166 | demo.launch(debug=True, share=True)
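167 |
168 | # To try the demo locally, run `python demo.py`; it assumes a trained
169 | # checkpoint exists at models/checkpoint-1000 and falls back to CPU if
170 | # no CUDA device is available.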
--------------------------------------------------------------------------------
/gliclass/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import GLiClassModel, GLiClassBiEncoder, GLiClassUniEncoder
2 | from .config import GLiClassModelConfig
3 | from .pipeline import ZeroShotClassificationPipeline, BiEncoderZeroShotClassificationPipeline, ZeroShotClassificationWithLabelsChunkingPipeline
--------------------------------------------------------------------------------
/gliclass/config.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 | from transformers.configuration_utils import PretrainedConfig
3 | from transformers.utils import logging
4 | from transformers.models.auto import CONFIG_MAPPING
5 | logger = logging.get_logger(__name__)
6 |
7 |
8 | class GLiClassModelConfig(PretrainedConfig):
9 | model_type = "GLiClass"
10 | is_composition = True
11 |
12 | def __init__(
13 | self,
14 | encoder_config = None,
15 | encoder_model=None,
16 | label_model_config=None,
17 | label_model_name=None,
18 | class_token_index = -1,
19 | text_token_index = -1,
20 | ignore_index=-100,
21 | hidden_size=None,
22 | projector_hidden_act="gelu",
23 | vocab_size=None,
24 | problem_type='single_label_classification',
25 | max_num_classes=25,
26 | use_lstm=False,
27 | initializer_range=0.03,
28 | scorer_type='simple',
29 | pooling_strategy='first',
30 | focal_loss_alpha=0.5,
31 | focal_loss_gamma=2,
32 | logit_scale_init_value=2.6592,
33 | normalize_features=False,
34 | extract_text_features=False,
35 | contrastive_loss_coef=0,
36 | architecture_type = 'uni-encoder',
37 | prompt_first = False,
38 | squeeze_layers = False,
39 | embed_class_token = True,
40 | **kwargs,
41 | ):
42 | if isinstance(encoder_config, dict):
43 | encoder_config["model_type"] = (encoder_config["model_type"]
44 | if "model_type" in encoder_config
45 | else "deberta-v2")
46 | encoder_config = CONFIG_MAPPING[encoder_config["model_type"]](**encoder_config)
47 | elif encoder_config is None:
48 | encoder_config = CONFIG_MAPPING["deberta-v2"]()
49 |
50 | self.encoder_config = encoder_config
51 | self.encoder_model_name = encoder_model
52 |
53 | if label_model_name is not None:
54 | if isinstance(label_model_config, dict):
55 | label_model_config["model_type"] = (label_model_config["model_type"]
56 | if "model_type" in label_model_config
57 | else "deberta-v2")
58 | label_model_config = CONFIG_MAPPING[label_model_config["model_type"]](**label_model_config)
59 | elif label_model_config is None:
60 | label_model_config = CONFIG_MAPPING["deberta-v2"]()
61 |
62 | self.label_model_config = label_model_config
63 | else:
64 | self.label_model_config = None
65 | self.label_model_name = label_model_name
66 |
67 | if hidden_size is None:
68 | self.hidden_size = self.encoder_config.hidden_size
69 | else:
70 | self.hidden_size = hidden_size
71 |
72 | if vocab_size is None:
73 | self.vocab_size = self.encoder_config.vocab_size
74 | else:
75 | self.vocab_size = vocab_size
76 |
77 | if class_token_index == -1:
78 | self.class_token_index = self.vocab_size
79 | else:
80 | self.class_token_index = class_token_index
81 |
82 | if text_token_index == -1:
83 | self.text_token_index = self.vocab_size+1
84 | else:
85 | self.text_token_index = text_token_index
86 |
87 | self.ignore_index = ignore_index
88 | self.projector_hidden_act = projector_hidden_act
89 | self.problem_type = problem_type
90 | self.max_num_classes = max_num_classes
91 | self.initializer_range=initializer_range
92 | self.scorer_type = scorer_type
93 | self.pooling_strategy=pooling_strategy
94 | self.use_lstm = use_lstm
95 | self.focal_loss_alpha=focal_loss_alpha
96 | self.focal_loss_gamma=focal_loss_gamma
97 | self.contrastive_loss_coef=contrastive_loss_coef
98 | self.logit_scale_init_value = logit_scale_init_value
99 | self.normalize_features=normalize_features
100 | self.extract_text_features = extract_text_features
101 | self.architecture_type = architecture_type
102 | self.prompt_first = prompt_first
103 | self.squeeze_layers = squeeze_layers
104 | self.embed_class_token = embed_class_token
105 | super().__init__(**kwargs)
106 |
107 |
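108 | # Illustrative default resolution (a sketch, not part of the config API):
109 | # with no arguments the config falls back to a DeBERTa-v2 encoder config
110 | # and places the prompt tokens right after the original vocabulary:
111 | #
112 | #     config = GLiClassModelConfig()
113 | #     assert config.hidden_size == config.encoder_config.hidden_size
114 | #     assert config.class_token_index == config.vocab_size
115 | #     assert config.text_token_index == config.vocab_size + 1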
--------------------------------------------------------------------------------
/gliclass/data_processing.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.nn.utils.rnn import pad_sequence
4 | from torch.utils.data import Dataset
5 |
6 |
7 | class GLiClassDataset(Dataset):
8 | def __init__(self, examples, tokenizer, max_length=512,
9 | problem_type='multi_label_classification',
10 | architecture_type = 'uni-encoder',
11 | prompt_first=False,
12 | get_negatives = False,
13 | max_labels = 50,
14 | labels_tokenizer=None,
15 | shuffle_labels = True):
16 | self.tokenizer = tokenizer
17 | self.labels_tokenizer = labels_tokenizer
18 | self.max_length = max_length
19 | self._data = examples
20 | self.problem_type = problem_type
21 | self.architecture_type = architecture_type
22 | self.prompt_first = prompt_first
23 | self.dataset_labels = self.collect_dataset_labels()
24 | self.get_negatives = get_negatives
25 | self.max_labels = max_labels
26 | self.shuffle_labels = shuffle_labels
27 | print('Total labels: ', len(self.dataset_labels))
28 |
29 | def collect_dataset_labels(self):
30 | dataset_labels = set()
31 | for example in self._data:
32 | dataset_labels.update(set(example['all_labels']))
33 | return dataset_labels
34 |
35 | def prepare_labels(self, example, label2idx, problem_type):
36 | if problem_type == 'single_label_classification':
37 | labels = label2idx[example['true_labels'][0]]
38 | elif problem_type == 'multi_label_classification':
39 | if isinstance(example['true_labels'], dict):
40 | labels = [example['true_labels'][label] if label in example['true_labels'] else 0. for label in example['all_labels']]
41 | else:
42 | labels = [1. if label in example['true_labels'] else 0. for label in example['all_labels']]
43 | else:
44 | raise NotImplementedError(f"{problem_type} is not implemented.")
45 | return torch.tensor(labels)
46 |
47 | def prepare_prompt(self, example):
48 | prompt_texts = []
49 | for label in example['all_labels']:
50 |             label_tag = f"<<LABEL>>{str(label)}"
51 |             prompt_texts.append(label_tag)
52 |         prompt_texts.append('<<SEP>>')
53 | return prompt_texts
54 |
55 | def tokenize(self, texts):
56 | tokenized_inputs = self.tokenizer(texts, truncation=True, max_length=self.max_length, padding="longest")
57 | return tokenized_inputs
58 |
59 | def tokenize_labels(self, labels):
60 | tokenized_inputs = self.labels_tokenizer(labels, truncation=True, max_length=self.max_length, padding="longest")
61 | return tokenized_inputs
62 |
63 | def tokenize_and_prepare_labels_for_uniencoder(self, example):
64 | if self.shuffle_labels:
65 | random.shuffle(example['all_labels'])
66 | input_text = self.prepare_prompt(example)
67 | if self.prompt_first:
68 | input_text = ''.join(input_text)+str(example['text'])
69 | else:
70 | input_text = str(example['text'])+''.join(input_text)
71 | label2idx = {label: idx for idx, label in enumerate(example['all_labels'])}
72 |
73 | tokenized_inputs = self.tokenize(input_text)
74 | tokenized_inputs['labels'] = self.prepare_labels(example, label2idx, self.problem_type)
75 | tokenized_inputs['labels_text'] = example['all_labels']
76 | tokenized_inputs['input_texts'] = example['text']
77 | return tokenized_inputs
78 |
79 | def tokenize_and_prepare_labels_for_encoder_decoder(self, example):
80 | if self.shuffle_labels:
81 | random.shuffle(example['all_labels'])
82 | class_texts = self.prepare_prompt(example)
83 | class_texts = ''.join(class_texts)
84 |
85 | label2idx = {label: idx for idx, label in enumerate(example['all_labels'])}
86 |
87 | tokenized_inputs = self.tokenize(example['text'])
88 | tokenized_classes = self.tokenize(class_texts)
89 | tokenized_inputs["class_input_ids"] = tokenized_classes["input_ids"]
90 | tokenized_inputs["class_attention_mask"] = tokenized_classes["attention_mask"]
91 | tokenized_inputs['labels'] = self.prepare_labels(example, label2idx, self.problem_type)
92 | return tokenized_inputs
93 |
94 | def tokenize_and_prepare_labels_for_biencoder(self, example):
95 | if self.shuffle_labels:
96 | random.shuffle(example['all_labels'])
97 | def prepare_prompt(labels):
98 | prompt_texts = []
99 | for label in labels:
100 |                 label_tag = "<<LABEL>>"
101 |                 prompt_texts.append(label_tag)
102 |             prompt_texts.append('<<SEP>>')
103 | return ''.join(prompt_texts)
104 |
105 | input_text = example['text']
106 | class_texts = example['all_labels']
107 |
108 | if self.architecture_type == 'bi-encoder-fused':
109 | prompt = prepare_prompt(class_texts)
110 | if self.prompt_first:
111 | input_text = f"{prompt} {input_text}"
112 | else:
113 | input_text = f"{input_text} {prompt}"
114 |
115 | tokenized_inputs = self.tokenize(input_text)
116 | tokenized_classes = self.tokenize_labels(class_texts)
117 |
118 | tokenized_inputs["class_input_ids"] = torch.tensor(tokenized_classes["input_ids"])
119 | tokenized_inputs["class_attention_mask"] = torch.tensor(tokenized_classes["attention_mask"])
120 |
121 | label2idx = {label: idx for idx, label in enumerate(example['all_labels'])}
122 |
123 | tokenized_inputs['labels_mask'] = torch.ones(len(class_texts))
124 | tokenized_inputs['labels'] = self.prepare_labels(example, label2idx, self.problem_type)
125 | return tokenized_inputs
126 |
127 | def __len__(self):
128 | return len(self._data)
129 |
130 | def __getitem__(self, idx):
131 | example = self._data[idx]
132 |
133 | if self.get_negatives and random.randint(0, 1):
134 | max_negatives = max(self.max_labels-len(example['all_labels']), 1)
135 |             new_negatives = random.sample(list(self.dataset_labels), k=random.randint(1, max_negatives))  # random.sample requires a sequence, not a set
136 | example['all_labels'].extend(new_negatives)
137 |
138 | if self.architecture_type == 'uni-encoder':
139 | model_inputs = self.tokenize_and_prepare_labels_for_uniencoder(example)
140 | elif self.architecture_type == 'encoder-decoder':
141 | model_inputs = self.tokenize_and_prepare_labels_for_encoder_decoder(example)
142 | elif self.architecture_type in {'bi-encoder', 'bi-encoder-fused'}:
143 | model_inputs = self.tokenize_and_prepare_labels_for_biencoder(example)
144 | else:
145 | raise NotImplementedError('This architecture type is not implemented.')
146 | return model_inputs
147 |
148 |
149 | def pad_2d_tensor(key_data):
150 | """
151 | Pad a list of 2D tensors to have the same size along both dimensions.
152 |
153 | :param key_data: List of 2D tensors to pad.
154 | :return: Tensor of padded tensors stacked along a new batch dimension.
155 | """
156 | if not key_data:
157 | raise ValueError("The input list 'key_data' should not be empty.")
158 |
159 | # Determine the maximum size along both dimensions
160 | max_rows = max(tensor.shape[0] for tensor in key_data)
161 | max_cols = max(tensor.shape[1] for tensor in key_data)
162 |
163 | tensors = []
164 |
165 | for tensor in key_data:
166 | rows, cols = tensor.shape
167 | row_padding = max_rows - rows
168 | col_padding = max_cols - cols
169 | # Pad the tensor along both dimensions
170 | padded_tensor = torch.nn.functional.pad(tensor, (0, col_padding, 0, row_padding),
171 | mode='constant', value=0)
172 | tensors.append(padded_tensor)
173 |
174 | # Stack the tensors into a single tensor along a new batch dimension
175 | padded_tensors = torch.stack(tensors)
176 |
177 | return padded_tensors
178 |
179 | class DataCollatorWithPadding:
180 | def __init__(self, device = 'cuda:0'):
181 | self.device = device
182 |
183 | def __call__(self, batch):
184 | keys = batch[0].keys()
185 | padded_batch = {key: [] for key in keys}
186 |
187 | for key in keys:
188 | key_data = [item[key] for item in batch]
189 | if isinstance(key_data[0], torch.Tensor):
190 | if key_data[0].dim() == 1:
191 | padded_batch[key] = pad_sequence(key_data, batch_first=True)
192 | elif key_data[0].dim() == 2:
193 | padded_batch[key] = pad_2d_tensor(key_data)
194 | elif isinstance(key_data[0], list):
195 | data_el = "string"
196 | if len(key_data[0]):
197 | data_el = key_data[0][0]
198 | if isinstance(data_el, str):
199 | padded_batch[key] = key_data
200 | else:
201 | max_length = max(len(seq) for seq in key_data)
202 | padded_batch[key] = torch.tensor([seq + [0] * (max_length - len(seq))
203 | for seq in key_data])
204 | elif type(key_data[0]) in {int, float}:
205 | padded_batch[key] = torch.tensor(key_data)
206 | elif isinstance(key_data[0], str):
207 | padded_batch[key] = key_data
208 | else:
209 | raise TypeError(f"Unsupported data type: {type(key_data[0])}")
210 |
211 | return padded_batch
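212 |
213 | # Usage sketch (illustrative): build a dataset from examples of the form
214 | # {"text": ..., "all_labels": [...], "true_labels": [...]} and collate
215 | # them into padded batches for the model:
216 | #
217 | #     dataset = GLiClassDataset(examples, tokenizer,
218 | #                               problem_type='multi_label_classification',
219 | #                               architecture_type='uni-encoder')
220 | #     collator = DataCollatorWithPadding(device='cpu')
221 | #     loader = torch.utils.data.DataLoader(dataset, batch_size=8,
222 | #                                          collate_fn=collator)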
--------------------------------------------------------------------------------
/gliclass/layers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Microsoft and the Hugging Face Inc. team and Knowledgator.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from typing import Union
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
20 | from transformers.activations import ACT2FN
21 |
22 | from .config import GLiClassModelConfig
23 |
24 | class LstmSeq2SeqEncoder(nn.Module):
25 | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
26 | super(LstmSeq2SeqEncoder, self).__init__()
27 | self.lstm = nn.LSTM(input_size=input_size,
28 | hidden_size=hidden_size,
29 | num_layers=num_layers,
30 | dropout=dropout,
31 | bidirectional=bidirectional,
32 | batch_first=True)
33 |
34 | def forward(self, x, mask, hidden=None):
35 | # Packing the input sequence
36 | lengths = mask.sum(dim=1).cpu()
37 | packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
38 |
39 | # Passing packed sequence through LSTM
40 | packed_output, hidden = self.lstm(packed_x, hidden)
41 |
42 | # Unpacking the output sequence
43 | output, _ = pad_packed_sequence(packed_output, batch_first=True)
44 |
45 | return output
46 |
47 | class FeaturesProjector(nn.Module):
48 | def __init__(self, config: GLiClassModelConfig):
49 | super().__init__()
50 |
51 | self.linear_1 = nn.Linear(config.encoder_config.hidden_size, config.hidden_size, bias=True)
52 | self.act = ACT2FN[config.projector_hidden_act]
53 |         self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)
54 |
55 | def forward(self, features):
56 | hidden_states = self.linear_1(features)
57 | hidden_states = self.act(hidden_states)
58 | hidden_states = self.linear_2(hidden_states)
59 | return hidden_states
60 |
61 | class BiEncoderProjector(nn.Module):
62 | def __init__(self, config: GLiClassModelConfig):
63 | super().__init__()
64 |
65 | self.linear_1 = nn.Linear(config.label_model_config.hidden_size, config.hidden_size, bias=True)
66 | self.act = ACT2FN[config.projector_hidden_act]
67 | self.linear_2 = nn.Linear(config.hidden_size, config.encoder_config.hidden_size, bias=True)
68 |
69 | def forward(self, features):
70 | hidden_states = self.linear_1(features)
71 | hidden_states = self.act(hidden_states)
72 | hidden_states = self.linear_2(hidden_states)
73 | return hidden_states
74 |
75 | # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
76 | class DropoutContext(object):
77 | def __init__(self):
78 | self.dropout = 0
79 | self.mask = None
80 | self.scale = 1
81 | self.reuse_mask = True
82 |
83 |
84 | # Copied from transformers.models.deberta.modeling_deberta.get_mask
85 | def get_mask(input, local_context):
86 | if not isinstance(local_context, DropoutContext):
87 | dropout = local_context
88 | mask = None
89 | else:
90 | dropout = local_context.dropout
91 | dropout *= local_context.scale
92 | mask = local_context.mask if local_context.reuse_mask else None
93 |
94 | if dropout > 0 and mask is None:
95 | mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
96 |
97 | if isinstance(local_context, DropoutContext):
98 | if local_context.mask is None:
99 | local_context.mask = mask
100 |
101 | return mask, dropout
102 |
103 |
104 | # Copied from transformers.models.deberta.modeling_deberta.XDropout
105 | class XDropout(torch.autograd.Function):
106 | """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
107 |
108 | @staticmethod
109 | def forward(ctx, input, local_ctx):
110 | mask, dropout = get_mask(input, local_ctx)
111 | ctx.scale = 1.0 / (1 - dropout)
112 | if dropout > 0:
113 | ctx.save_for_backward(mask)
114 | return input.masked_fill(mask, 0) * ctx.scale
115 | else:
116 | return input
117 |
118 | @staticmethod
119 | def backward(ctx, grad_output):
120 | if ctx.scale > 1:
121 | (mask,) = ctx.saved_tensors
122 | return grad_output.masked_fill(mask, 0) * ctx.scale, None
123 | else:
124 | return grad_output, None
125 |
126 | @staticmethod
127 | def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
128 | from torch.onnx import symbolic_opset12
129 |
130 | dropout_p = local_ctx
131 | if isinstance(local_ctx, DropoutContext):
132 | dropout_p = local_ctx.dropout
133 | # StableDropout only calls this function when training.
134 | train = True
135 | # TODO: We should check if the opset_version being used to export
136 | # is > 12 here, but there's no good way to do that. As-is, if the
137 | # opset_version < 12, export will fail with a CheckerError.
138 | # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
139 | # if opset_version < 12:
140 | # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
141 | return symbolic_opset12.dropout(g, input, dropout_p, train)
142 |
143 | # Copied from transformers.models.deberta.modeling_deberta.StableDropout
144 | class StableDropout(nn.Module):
145 | """
146 | Optimized dropout module for stabilizing the training
147 |
148 | Args:
149 | drop_prob (float): the dropout probabilities
150 | """
151 |
152 | def __init__(self, drop_prob):
153 | super().__init__()
154 | self.drop_prob = drop_prob
155 | self.count = 0
156 | self.context_stack = None
157 |
158 | def forward(self, x):
159 | """
160 | Call the module
161 |
162 | Args:
163 | x (`torch.tensor`): The input tensor to apply dropout
164 | """
165 | if self.training and self.drop_prob > 0:
166 | return XDropout.apply(x, self.get_context())
167 | return x
168 |
169 | def clear_context(self):
170 | self.count = 0
171 | self.context_stack = None
172 |
173 | def init_context(self, reuse_mask=True, scale=1):
174 | if self.context_stack is None:
175 | self.context_stack = []
176 | self.count = 0
177 | for c in self.context_stack:
178 | c.reuse_mask = reuse_mask
179 | c.scale = scale
180 |
181 | def get_context(self):
182 | if self.context_stack is not None:
183 | if self.count >= len(self.context_stack):
184 | self.context_stack.append(DropoutContext())
185 | ctx = self.context_stack[self.count]
186 | ctx.dropout = self.drop_prob
187 | self.count += 1
188 | return ctx
189 | else:
190 | return self.drop_prob
191 |
192 | class SelfAttentionBlock(nn.Module):
193 | def __init__(self, d_model, num_heads, dropout=0.1):
194 | super().__init__()
195 | self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
196 | self.norm = nn.LayerNorm(d_model)
197 | self.dropout = nn.Dropout(dropout)
198 |
199 | def forward(self, x, mask=None):
200 | attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
201 | return self.norm(x + self.dropout(attn_output))
202 |
203 | class CrossAttentionBlock(nn.Module):
204 | def __init__(self, d_model, num_heads, dropout=0.1):
205 | super().__init__()
206 | self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
207 | self.norm = nn.LayerNorm(d_model)
208 | self.dropout = nn.Dropout(dropout)
209 |
210 | def forward(self, query, key, value, mask=None):
211 | attn_output, _ = self.cross_attn(query, key, value, attn_mask=mask)
212 | return self.norm(query + self.dropout(attn_output))
213 |
214 | class Fuser(nn.Module):
215 | def __init__(self, d_model, num_heads, num_layers, dropout=0.1):
216 | super().__init__()
217 | self.d_model = d_model
218 | self.layers = nn.ModuleList([
219 | nn.ModuleList([
220 | SelfAttentionBlock(d_model, num_heads, dropout),
221 | CrossAttentionBlock(d_model, num_heads, dropout)
222 | ])
223 | for _ in range(num_layers)
224 | ])
225 | self.fc = nn.Linear(d_model, d_model)
226 |
227 | def forward(self, query, key, query_mask=None, key_mask=None):
228 | if query_mask is not None and key_mask is not None:
229 | self_attn_mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2)
230 | cross_attn_mask = query_mask.unsqueeze(-1) * key_mask.unsqueeze(1)
231 | else:
232 | self_attn_mask = None
233 | cross_attn_mask = None
234 |
235 | value = self.fc(key)
236 |
237 | for self_attn, cross_attn in self.layers:
238 | query = self_attn(query, mask=self_attn_mask)
239 | query = cross_attn(query, key, value, mask=cross_attn_mask)
240 |
241 | return query
242 |
243 | class LayerwiseAttention(nn.Module):
244 | def __init__(self, num_layers, hidden_size, output_size=None):
245 | super().__init__()
246 | self.num_layers = num_layers
247 | self.hidden_size = hidden_size
248 | self.output_size = output_size if output_size is not None else hidden_size
249 |
250 | # Squeeze operation
251 | self.squeeze = nn.Linear(hidden_size, 1)
252 |
253 | # Excitation operation
254 | self.W1 = nn.Linear(num_layers, num_layers // 2)
255 | self.W2 = nn.Linear(num_layers // 2, num_layers)
256 |
257 | # Final projection
258 | self.output_projection = nn.Linear(self.hidden_size, self.output_size)
259 |
260 | def forward(self, encoder_outputs):
261 | # encoder_outputs is a list of tensors, each of shape [B, L, D]
262 | B, L, D = encoder_outputs[0].shape
263 |
264 | # Concatenate all layers
265 | U = torch.stack(encoder_outputs, dim=1) # [B, K, L, D]
266 |
267 | # Squeeze operation
268 | Z = self.squeeze(U).squeeze(-1) # [B, K, L]
269 | Z = Z.mean(dim=2) # [B, K]
270 |
271 | # Excitation operation
272 | s = self.W2(F.relu(self.W1(Z))) # [B, K]
273 | s = torch.sigmoid(s) # [B, K]
274 |
275 | # Apply attention weights
276 | U_weighted = U * s.unsqueeze(-1).unsqueeze(-1) # [B, K, L, D]
277 |
278 | # Sum across layers
279 | U_sum = U_weighted.sum(dim=1) # [B, L, D]
280 |
281 | # Final projection
282 | output = self.output_projection(U_sum) # [B, L, output_size]
283 |
284 | return output
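285 |
286 | # Shape sketch (illustrative): given K encoder layer outputs of shape
287 | # [B, L, D], LayerwiseAttention learns one gate per layer (squeeze ->
288 | # excitation), re-weights the stacked layers, and projects the sum:
289 | #
290 | #     outputs = [torch.randn(2, 16, 768) for _ in range(12)]
291 | #     fuser = LayerwiseAttention(num_layers=12, hidden_size=768)
292 | #     fuser(outputs).shape  # torch.Size([2, 16, 768])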
--------------------------------------------------------------------------------
/gliclass/loss_functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def sequence_contrastive_loss(embeddings, mask):
6 | # embeddings shape: (B, L, D)
7 | # mask shape: (B, L)
8 | B, L, D = embeddings.shape
9 |
10 | # Normalize embeddings
11 | embeddings = F.normalize(embeddings, p=2, dim=-1)
12 |
13 | # Compute similarity matrix
14 | sim_matrix = torch.matmul(embeddings, embeddings.transpose(1, 2)) #/ self.temperature
15 |
16 | # Create labels for cross entropy (diagonal indices)
17 | labels = torch.arange(L, device=embeddings.device).unsqueeze(0).expand(B, -1)
18 |
19 | # Compute loss for each element in the batch
20 | loss = F.cross_entropy(sim_matrix.reshape(B*L, L), labels.reshape(-1), reduction='none')
21 |
22 | # Apply mask to loss
23 | loss = loss.view(B, L) * mask
24 |
25 | # Compute mean loss over non-padded elements
26 | loss = loss.sum() / mask.sum()
27 |
28 | return loss
29 |
30 |
31 | def focal_loss_with_logits(
32 | inputs: torch.Tensor,
33 | targets: torch.Tensor,
34 | alpha: float = 0.25,
35 | gamma: float = 2,
36 | reduction: str = "none",
37 | label_smoothing: float = 0.0,
38 | ignore_index: int = -100 # default value for ignored index
39 | ) -> torch.Tensor:
40 | """
41 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
42 |
43 | Args:
44 | inputs (Tensor): A float tensor of arbitrary shape.
45 | The predictions for each example.
46 | targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
47 | classification label for each element in inputs
48 | (0 for the negative class and 1 for the positive class).
49 | alpha (float): Weighting factor in range (0,1) to balance
50 | positive vs negative examples or -1 for ignore. Default: ``0.25``.
51 | gamma (float): Exponent of the modulating factor (1 - p_t) to
52 | balance easy vs hard examples. Default: ``2``.
53 | reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
54 | ``'none'``: No reduction will be applied to the output.
55 | ``'mean'``: The output will be averaged.
56 | ``'sum'``: The output will be summed. Default: ``'none'``.
57 | label_smoothing (float): Specifies the amount of smoothing when computing the loss,
58 | where 0.0 means no smoothing.
59 | ignore_index (int): Specifies a target value that is ignored and does not contribute
60 | to the input gradient. Default: ``-100``.
61 | Returns:
62 | Loss tensor with the reduction option applied.
63 | """
64 | # Create a mask to ignore specified index
65 | valid_mask = targets != ignore_index
66 |
67 | # Apply label smoothing if needed
68 | if label_smoothing != 0:
69 | with torch.no_grad():
70 | targets = targets * (1 - label_smoothing) + 0.5 * label_smoothing
71 |
72 | # Apply sigmoid activation to inputs
73 | p = torch.sigmoid(inputs)
74 |
75 | # Compute the binary cross-entropy loss without reduction
76 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
77 |
78 | # Apply the valid mask to the loss
79 | loss = loss * valid_mask
80 |
81 | # Apply focal loss modulation if gamma is greater than 0
82 | if gamma > 0:
83 | p_t = p * targets + (1 - p) * (1 - targets)
84 | loss = loss * ((1 - p_t) ** gamma)
85 |
86 | # Apply alpha weighting if alpha is specified
87 | if alpha >= 0:
88 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
89 | loss = alpha_t * loss
90 |
91 | # Apply reduction method
92 | if reduction == "none":
93 | return loss
94 | elif reduction == "mean":
95 | return loss.sum() / valid_mask.sum() # Normalize by the number of valid (non-ignored) elements
96 | elif reduction == "sum":
97 | return loss.sum()
98 | else:
99 | raise ValueError(
100 | f"Invalid value for argument 'reduction': '{reduction}'. "
101 | f"Supported reduction modes: 'none', 'mean', 'sum'"
102 | )
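103 |
104 | # Usage sketch (illustrative): binary multi-label targets, mean-reduced:
105 | #
106 | #     logits = torch.randn(4, 25)                     # raw scores per label
107 | #     targets = torch.randint(0, 2, (4, 25)).float()  # 0/1 labels
108 | #     loss = focal_loss_with_logits(logits, targets,
109 | #                                   alpha=0.5, gamma=2, reduction="mean")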
--------------------------------------------------------------------------------
/gliclass/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | from pathlib import Path
4 | from dataclasses import dataclass
5 | from typing import List, Optional, Tuple, Union
6 |
7 | import torch
8 | import torch.utils.checkpoint
9 | from torch import nn
10 |
11 | from transformers import PreTrainedModel, AutoConfig, AutoModel
12 | from transformers.activations import ACT2FN
13 | from transformers.modeling_outputs import SequenceClassifierOutput
14 | from transformers.utils import logging
15 | from transformers.models.auto import AutoModel
16 | from .config import GLiClassModelConfig
17 | from .layers import FeaturesProjector, LstmSeq2SeqEncoder, BiEncoderProjector, LayerwiseAttention
18 | from .poolings import POOLING2OBJECT
19 | from .scorers import SCORER2OBJECT
20 | from .loss_functions import focal_loss_with_logits, sequence_contrastive_loss
21 | from .utils import is_module_available, MissedPackageException
22 |
23 | IS_LLM2VEC = is_module_available('llm2vec')
24 | IS_PEFT = is_module_available('peft')
25 | IS_TURBOT5 = is_module_available('turbot5')
26 | IS_FLASHDEBERTA = is_module_available('flashdeberta')
27 |
28 | logger = logging.get_logger(__name__)
29 |
30 | if IS_LLM2VEC:
31 | from llm2vec.models import MistralBiModel, LlamaBiModel, GemmaBiModel, Qwen2BiModel
32 | DECODER_MODEL_MAPPING = {
33 | "MistralConfig": MistralBiModel,
34 | "LlamaConfig": LlamaBiModel,
35 | "GemmaConfig": GemmaBiModel,
36 | "Qwen2Config": Qwen2BiModel
37 | }
38 | else:
39 | DECODER_MODEL_MAPPING = {}
40 |
41 | if IS_TURBOT5:
42 | from turbot5.model.modeling import T5EncoderModel
43 | else:
44 | from transformers import T5EncoderModel
45 |
46 | if IS_FLASHDEBERTA:
47 | from flashdeberta import FlashDebertaV2Model as DebertaV2Model
48 | else:
49 | from transformers import DebertaV2Model
50 |
51 | if IS_PEFT:
52 | from peft import LoraConfig, get_peft_model
53 |
54 | @dataclass
55 | class GLiClassOutput(SequenceClassifierOutput):
56 | text_embeddings: Optional[torch.Tensor] = None
57 | class_embeddings: Optional[torch.Tensor] = None
58 |
59 | class GLiClassPreTrainedModel(PreTrainedModel):
60 | config_class = GLiClassModelConfig
61 | base_model_prefix = "model"
62 | supports_gradient_checkpointing = True
63 | _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
64 |
65 | def _init_weights(self, module):
66 | std = (
67 | self.config.initializer_range
68 | if hasattr(self.config, "initializer_range")
69 | else self.config.encoder_config.initializer_range
70 | )
71 |
72 | if hasattr(module, "class_embedding"):
73 | module.class_embedding.data.normal_(mean=0.0, std=std)
74 |
75 | if isinstance(module, (nn.Linear, nn.Conv2d)):
76 | module.weight.data.normal_(mean=0.0, std=std)
77 | if module.bias is not None:
78 | module.bias.data.zero_()
79 | elif isinstance(module, nn.Embedding):
80 | module.weight.data.normal_(mean=0.0, std=std)
81 | if module.padding_idx is not None:
82 | module.weight.data[module.padding_idx].zero_()
83 | elif isinstance(module, nn.LSTM):
84 | for name, param in module.named_parameters():
85 | if 'weight_ih' in name or 'weight_hh' in name:
86 | nn.init.normal_(param.data, mean=0.0, std=std)
87 | elif 'bias' in name:
88 | param.data.zero_()
89 | @property
90 | def _supports_sdpa(self):
91 | """
92 | Retrieve language_model's attribute to check whether the model supports
93 | SDPA or not.
94 | """
95 | return self.language_model._supports_sdpa
96 |
97 |
98 | class GLiClassBaseModel(nn.Module):
99 | def __init__(self, config: GLiClassModelConfig, device='cpu', **kwargs):
100 | super().__init__()
101 | self.config = config
102 | self.text_projector = FeaturesProjector(config)
103 | self.classes_projector = FeaturesProjector(config)
104 |
105 |         if config.pooling_strategy not in POOLING2OBJECT:
106 |             raise NotImplementedError(f"{config.pooling_strategy} is not an implemented pooling type.")
107 |         else:
108 |             self.pooler = POOLING2OBJECT[config.pooling_strategy]()
109 |
110 |         if config.scorer_type not in SCORER2OBJECT:
111 |             raise NotImplementedError(f"{config.scorer_type} is not an implemented scorer type. Choose one of: 'simple', 'weighted-dot', 'mlp', 'hopfield'.")
112 | else:
113 | self.scorer = SCORER2OBJECT[config.scorer_type](config.hidden_size)
114 |
115 | if config.use_lstm:
116 | self.lstm = LstmSeq2SeqEncoder(config.hidden_size, config.hidden_size//2, bidirectional=True)
117 |
118 | if config.squeeze_layers:
119 | self.layer_wise_attention = LayerwiseAttention(config.encoder_config.num_hidden_layers,
120 | config.encoder_config.hidden_size)
121 |
122 | drop_out = getattr(config.encoder_config, "cls_dropout", None)
123 | if drop_out is None:
124 | if hasattr(self.config.encoder_config, 'hidden_dropout_prob'):
125 | drop_out = self.config.encoder_config.hidden_dropout_prob
126 | elif hasattr(self.config.encoder_config, 'dropout_rate'):
127 | drop_out = self.config.encoder_config.dropout_rate
128 | else:
129 | drop_out = 0.15
130 | # self.dropout = StableDropout(drop_out)
131 | self.dropout = nn.Dropout(drop_out)
132 |
133 |
134 | self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
135 |
136 | self.epsilon = 1e-8
137 | self.vocab_size = config.vocab_size
138 | self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
139 | self.num_labels = -1
140 |
141 | self.device = torch.device(device)
142 |
143 | def _extract_class_features(self, token_embeds, input_ids, attention_mask):
144 | batch_size, sequence_length, embed_dim = token_embeds.shape
145 |
146 | class_token_mask = input_ids == self.config.class_token_index
147 | num_class_tokens = torch.sum(class_token_mask, dim=-1, keepdim=True)
148 |
149 | max_embed_dim = num_class_tokens.max()
150 |
151 | aranged_class_idx = torch.arange(max_embed_dim,
152 | dtype=attention_mask.dtype,
153 | device=token_embeds.device).expand(batch_size, -1)
154 |
155 | batch_indices, target_class_idx = torch.where(aranged_class_idx < num_class_tokens)
[... lines 156-209 of gliclass/model.py were lost in extraction: the remainder of _extract_class_features (gathering class- and text-token embeddings into padded tensors) and the opening of get_loss are missing here ...]
210 | label_index = (labels >= 0).nonzero()
211 | labels = labels.long()
212 | if label_index.size(0) > 0:
213 | labeled_logits = torch.gather(
214 | logits, 0, label_index.expand(label_index.size(0), logits.size(1))
215 | )
216 | labels = torch.gather(labels, 0, label_index.view(-1))
217 | loss_fct = nn.CrossEntropyLoss()
218 | loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
219 | else:
220 | loss = torch.tensor(0).to(logits)
221 | else:
222 | log_softmax = nn.LogSoftmax(-1)
223 | loss = -((log_softmax(logits) * labels).sum(-1)).mean()
224 | elif self.config.problem_type == "regression":
225 | loss_fct = nn.MSELoss()
226 | if self.num_labels == 1:
227 | loss = loss_fct(logits.squeeze(), labels.squeeze())
228 | else:
229 | loss = loss_fct(logits, labels)
230 | elif self.config.problem_type == "single_label_classification":
231 | loss_fct = nn.CrossEntropyLoss()
232 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
233 | elif self.config.problem_type == "multi_label_classification":
234 | all_losses = focal_loss_with_logits(logits, labels,
235 | self.config.focal_loss_alpha, self.config.focal_loss_gamma)
236 | if classes_embedding_mask is not None:
237 | all_losses = all_losses * classes_embedding_mask.float()
238 | loss = all_losses.mean()
239 |
240 | if self.config.contrastive_loss_coef>0 and classes_embedding is not None:
241 | contrastive_loss = sequence_contrastive_loss(classes_embedding, classes_embedding_mask)
242 | loss = loss+contrastive_loss*self.config.contrastive_loss_coef
243 | return loss
244 |
245 | class GLiClassUniEncoder(GLiClassBaseModel):
246 | def __init__(self, config: GLiClassModelConfig, from_pretrained = False):
247 | super().__init__(config)
248 | if config.encoder_config is None:
249 | if config.encoder_model_name is None:
250 | raise ValueError("You need to specify encoder model name to use it as a backbone.")
251 | config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
252 |
253 | config_name = config.encoder_config.__class__.__name__
254 |
255 | if config_name in DECODER_MODEL_MAPPING:
256 | if not IS_LLM2VEC:
257 | raise MissedPackageException(f"The llm2vec package must be installed to use this decoder model: {config_name}")
258 | else:
259 | print('Loading decoder model using LLM2Vec...')
260 | ModelClass = DECODER_MODEL_MAPPING[config_name]
261 | decoder = True
262 | elif config_name in {'T5Config', 'MT5Config'}:
263 | decoder = False
264 | ModelClass = T5EncoderModel
265 | elif config_name in {'DebertaV2Config'}:
266 | decoder = False
267 | ModelClass = DebertaV2Model
268 | else:
269 | decoder = False
270 | ModelClass = AutoModel
271 |
272 | if from_pretrained:
273 | self.encoder_model = ModelClass.from_pretrained(
274 | config.encoder_model_name
275 | )
276 | else:
277 | if decoder:
278 | self.encoder_model = ModelClass(config.encoder_config)
279 | else:
280 | if config_name in {'T5Config', 'MT5Config', 'DebertaV2Config'}:
281 | self.encoder_model = ModelClass._from_config(
282 | config.encoder_config
283 | )
284 | else:
285 | self.encoder_model = ModelClass.from_config(
286 | config.encoder_config
287 | )
288 |
289 | adapter_config_file = Path(config.encoder_model_name) / "adapter_config.json"
290 |
291 | if adapter_config_file.exists():
292 | if not IS_PEFT:
293 |                 warnings.warn("An adapter config was detected; install the peft package if you want to apply it.")
294 | else:
295 | adapter_config = LoraConfig.from_pretrained(config.encoder_model_name)
296 | self.encoder_model = get_peft_model(self.encoder_model, adapter_config)
297 |
298 | def forward(
299 | self,
300 | input_ids: Optional[torch.Tensor] = None,
301 | attention_mask: Optional[torch.Tensor] = None,
302 | inputs_embeds: Optional[torch.Tensor] = None,
303 | labels: Optional[torch.Tensor] = None,
304 | output_attentions: Optional[bool] = None,
305 | output_hidden_states: Optional[bool] = None,
306 | output_text_embeddings: Optional[bool] = None,
307 | output_class_embeddings: Optional[bool] = None,
308 | return_dict: Optional[bool] = None,
309 | **kwargs
310 | ) -> Union[Tuple, GLiClassOutput]:
311 | r"""
312 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
313 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
314 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
315 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
316 | """
317 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
318 |
319 | if self.config.squeeze_layers:
320 | output_hidden_states = True
321 | return_dict = True
322 |
323 | outputs = self.encoder_model(
324 | input_ids,
325 | attention_mask=attention_mask,
326 | # inputs_embeds=inputs_embeds,
327 | output_attentions=output_attentions,
328 | output_hidden_states=output_hidden_states,
329 | return_dict=return_dict,
330 | **kwargs
331 | )
332 |
333 | if self.config.squeeze_layers:
334 | encoder_layer = self.layer_wise_attention(outputs.hidden_states)
335 | else:
336 | encoder_layer = outputs[0]
337 |
338 | classes_embedding, classes_embedding_mask, text_token_embeddings, text_mask = self._extract_class_features(encoder_layer,
339 | input_ids, attention_mask)
340 | if self.config.use_lstm:
341 | text_token_embeddings = self.lstm(text_token_embeddings, text_mask)
342 |
343 | pooled_output = self.pooler(text_token_embeddings)
344 | pooled_output = self.text_projector(pooled_output)
345 | pooled_output = self.dropout(pooled_output)
346 | if self.config.normalize_features:
347 | pooled_output = pooled_output / (pooled_output.norm(p=2, dim=-1, keepdim=True)+self.epsilon)
348 |
349 | classes_embedding = self.classes_projector(classes_embedding)
350 | if self.config.normalize_features:
351 | classes_embedding = classes_embedding / (classes_embedding.norm(p=2, dim=-1, keepdim=True)+self.epsilon)
352 |
353 | logits = self.scorer(pooled_output, classes_embedding)
354 |
355 | if self.config.normalize_features:
356 | logits = logits*self.logit_scale.to(classes_embedding.device)
357 |
358 | loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)
359 |
360 | if not return_dict:
361 | output = (logits,) + outputs[1:]
362 | return ((loss,) + output) if loss is not None else output
363 |
364 | return GLiClassOutput(
365 | loss=loss, logits=logits,
366 | hidden_states=outputs.hidden_states,
367 | attentions=outputs.attentions,
368 | text_embeddings= pooled_output if output_text_embeddings else None,
369 | class_embeddings= classes_embedding if output_class_embeddings else None,
370 | )
371 |
372 |
373 | class GLiClassEncoderDecoder(GLiClassBaseModel):
374 | def __init__(self, config: GLiClassModelConfig, from_pretrained = False):
375 | super().__init__(config)
376 | if config.encoder_config is None:
377 | if config.encoder_model_name is None:
378 | raise ValueError("You need to specify encoder model name to use it as a backbone.")
379 | config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
380 |
381 | if not config.encoder_config.is_encoder_decoder:
382 | raise ValueError("You need to choose encoder-decoder model as a backbone.")
383 |
384 | if from_pretrained:
385 | self.encoder_decoder_model = AutoModel.from_pretrained(
386 | config.encoder_model_name
387 | )
388 | else:
389 | self.encoder_decoder_model = AutoModel.from_config(
390 | config.encoder_config
391 | )
392 |
393 | def forward(
394 | self,
395 | input_ids: Optional[torch.Tensor] = None,
396 | attention_mask: Optional[torch.Tensor] = None,
397 | class_input_ids: Optional[torch.Tensor] = None,
398 | class_attention_mask: Optional[torch.Tensor] = None,
399 | inputs_embeds: Optional[torch.Tensor] = None,
400 | labels: Optional[torch.Tensor] = None,
401 | output_attentions: Optional[bool] = None,
402 | output_hidden_states: Optional[bool] = None,
403 | output_text_embeddings: Optional[bool] = None,
404 | output_class_embeddings: Optional[bool] = None,
405 | return_dict: Optional[bool] = True,
406 | **kwargs
407 | ) -> Union[Tuple, SequenceClassifierOutput]:
408 | r"""
409 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
410 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
411 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
412 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
413 | """
414 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
415 |
416 | outputs = self.encoder_decoder_model(
417 | input_ids=input_ids,
418 | attention_mask=attention_mask,
419 | decoder_input_ids=class_input_ids,
420 | decoder_attention_mask=class_attention_mask,
421 | inputs_embeds=inputs_embeds,
422 | output_attentions=output_attentions,
423 | output_hidden_states=output_hidden_states,
424 | return_dict=return_dict,
425 | **kwargs
426 | )
427 | text_token_embeddings = outputs.encoder_last_hidden_state
428 | decoder_token_embeddings = outputs.last_hidden_state
429 | classes_embedding, classes_embedding_mask, _, _ = self._extract_class_features(decoder_token_embeddings,
430 | class_input_ids, class_attention_mask)
431 |
432 | if self.config.use_lstm:
433 | text_token_embeddings = self.lstm(text_token_embeddings, attention_mask)
434 |
435 | pooled_output = self.pooler(text_token_embeddings)
436 | pooled_output = self.text_projector(pooled_output)
437 | pooled_output = self.dropout(pooled_output)
438 | if self.config.normalize_features:
439 | pooled_output = nn.functional.normalize(pooled_output, p=2, dim=-1, eps=self.epsilon)
440 |
441 | classes_embedding = self.classes_projector(classes_embedding)
442 | if self.config.normalize_features:
443 | classes_embedding = nn.functional.normalize(classes_embedding, p=2, dim=-1, eps=self.epsilon)
444 |
445 | logits = self.scorer(pooled_output, classes_embedding)
446 |
447 | if self.config.normalize_features:
448 | logits = logits*self.logit_scale.to(classes_embedding.device)
449 |
450 | loss = self.get_loss(logits, labels, classes_embedding, classes_embedding_mask)
451 |
452 | if not return_dict:
453 | output = (logits,) + outputs[1:]
454 | return ((loss,) + output) if loss is not None else output
455 |
456 | return GLiClassOutput(
457 | loss=loss, logits=logits,
458 | hidden_states = outputs.hidden_states,
459 | attentions = outputs.attentions,
460 | text_embeddings = pooled_output if output_text_embeddings else None,
461 | class_embeddings = classes_embedding if output_class_embeddings else None,
462 | )
463 |
464 | class GLiClassBiEncoder(GLiClassBaseModel):
465 | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
466 | super().__init__(config)
467 | if config.encoder_config is None:
468 | if config.encoder_model_name is None:
469 | raise ValueError("You need to specify encoder model name to use it as a backbone.")
470 | config.encoder_config = AutoConfig.from_pretrained(config.encoder_model_name)
471 |
472 | if config.label_model_config is None:
473 | if config.label_model_name is None:
474 | raise ValueError("You need to specify label model name to use it as a backbone.")
475 | config.label_model_config = AutoConfig.from_pretrained(config.label_model_name)
476 |
477 | def initialize_encoder(configs, model_name, from_pretrained):
478 | if from_pretrained:
479 | return AutoModel.from_pretrained(model_name)
480 | else:
481 | return AutoModel.from_config(configs)
482 | self.encoder_model = initialize_encoder(config.encoder_config, config.encoder_model_name, from_pretrained)
483 | self.label_encoder = initialize_encoder(config.label_model_config, config.label_model_name, from_pretrained)
484 | self.biencoder_projector = BiEncoderProjector(config)
485 |
486 | def pool_outputs(self, encoder_outputs):
487 | text_embeddings = self.pooler(encoder_outputs[0])
488 | text_embeddings = self.text_projector(text_embeddings)
489 | text_embeddings = self.dropout(text_embeddings)
490 | if self.config.normalize_features:
491 | text_embeddings = nn.functional.normalize(text_embeddings, p=2, dim=-1, eps=self.epsilon)
492 | return text_embeddings
493 |
494 | def encode_text(self, input_ids, attention_mask):
495 | outputs = self.encoder_model(input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))
496 | text_embeddings = self.pool_outputs(outputs)
497 | return text_embeddings
498 |
499 | def encode_classes(self, class_input_ids, class_attention_mask, labels_mask=None):
500 | batch_size = class_input_ids.shape[0]
501 | num_classes = class_input_ids.shape[1]
502 | if labels_mask is not None:
503 | batch_indices, indices = torch.where(labels_mask==1)
504 | selected_input_ids = class_input_ids[batch_indices, indices]
505 | selected_attention_mask = class_attention_mask[batch_indices, indices]
506 |
507 | outputs = self.label_encoder(selected_input_ids, attention_mask=selected_attention_mask)
508 | class_embeddings_filtered = self.pooler(outputs[0])
509 |
510 | class_embeddings = torch.zeros(batch_size, num_classes, class_embeddings_filtered.shape[-1],
511 | dtype=class_embeddings_filtered.dtype,
512 | device=class_embeddings_filtered.device)
513 |
514 | class_embeddings[batch_indices, indices] = class_embeddings_filtered
515 | else:
516 | class_input_ids = class_input_ids.view(-1, class_input_ids.shape[-1])
517 | class_attention_mask = class_attention_mask.view(-1, class_input_ids.shape[-1])
518 | outputs = self.label_encoder(class_input_ids, attention_mask=class_attention_mask)
519 | class_embeddings = self.pooler(outputs[0])
520 | class_embeddings = class_embeddings.reshape(batch_size, num_classes, -1)
521 | class_embeddings = self.biencoder_projector(class_embeddings)
522 | class_embeddings = self.classes_projector(class_embeddings)
523 | if self.config.normalize_features:
524 | class_embeddings = nn.functional.normalize(class_embeddings, p=2, dim=-1, eps=self.epsilon)
525 | return class_embeddings
526 |
527 | def forward(
528 | self,
529 | input_ids: Optional[torch.Tensor] = None,
530 | attention_mask: Optional[torch.Tensor] = None,
531 | class_input_ids: Optional[torch.Tensor] = None,
532 | class_attention_mask: Optional[torch.Tensor] = None,
533 | labels_mask: Optional[torch.Tensor] = None,
534 | labels: Optional[torch.Tensor] = None,
535 | output_text_embeddings: Optional[bool] = None,
536 | output_class_embeddings: Optional[bool] = None,
537 | return_dict: Optional[bool] = None,
538 | **kwargs
539 | ) -> Union[Tuple, SequenceClassifierOutput]:
540 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
541 |
542 | text_embeddings = self.encode_text(input_ids, attention_mask)
543 | class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)
544 | logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)
545 |
546 | if labels_mask is not None:
547 | logits = torch.where(labels_mask == 0, -1e3, logits)
548 |
549 | loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)
550 |
551 | if not return_dict:
552 | output = (logits,)
553 | return ((loss,) + output) if loss is not None else output
554 |
555 | return GLiClassOutput(
556 | loss=loss, logits=logits,
557 | text_embeddings = text_embeddings if output_text_embeddings else None,
558 | class_embeddings = class_embeddings if output_class_embeddings else None,
559 | )
560 |
561 |
562 | class GLiClassBiEncoderFused(GLiClassBiEncoder):
563 | def __init__(self, config: GLiClassModelConfig, from_pretrained=False):
564 | super().__init__(config, from_pretrained)
565 |
566 | def encode_text(self, input_ids, attention_mask, class_embeddings, labels_mask):
567 | embedding_layer = self.encoder_model.get_input_embeddings()
568 | inputs_embeds = embedding_layer(input_ids)
569 |
570 | class_token_mask = input_ids==self.config.class_token_index
571 | batch_indices, class_token_indices = torch.where(class_token_mask)
572 |
573 | labels_batch_indices, labels_indices = torch.where(labels_mask==1)
574 |
575 | selected_class_embeddings = class_embeddings[labels_batch_indices, labels_indices]
576 |
577 | inputs_embeds[batch_indices, class_token_indices] = selected_class_embeddings
578 | encoder_outputs = self.encoder_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask.squeeze(1))
579 |
580 | post_class_embeddings = torch.zeros_like(class_embeddings)
581 | post_class_embeddings[labels_batch_indices, labels_indices] = encoder_outputs[0][batch_indices, class_token_indices]
582 | return encoder_outputs, post_class_embeddings
583 |
584 | def forward(
585 | self,
586 | input_ids: Optional[torch.Tensor] = None,
587 | attention_mask: Optional[torch.Tensor] = None,
588 | class_input_ids: Optional[torch.Tensor] = None,
589 | class_attention_mask: Optional[torch.Tensor] = None,
590 | labels_mask: Optional[torch.Tensor] = None,
591 | labels: Optional[torch.Tensor] = None,
592 | output_text_embeddings: Optional[bool] = None,
593 | output_class_embeddings: Optional[bool] = None,
594 | return_dict: Optional[bool] = None,
595 | **kwargs
596 | ) -> Union[Tuple, SequenceClassifierOutput]:
597 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
598 |
599 | raw_class_embeddings = self.encode_classes(class_input_ids, class_attention_mask, labels_mask)
600 |
601 | encoder_outputs, class_embeddings = self.encode_text(input_ids, attention_mask, raw_class_embeddings, labels_mask)
602 |
603 | text_embeddings = self.pool_outputs(encoder_outputs)
604 |
605 | logits = self.scorer(text_embeddings, class_embeddings) * self.logit_scale.to(class_embeddings.device)
606 |
607 | if labels_mask is not None:
608 | logits = torch.where(labels_mask == 0, -1e3, logits)
609 |
610 | loss = self.get_loss(logits, labels, classes_embedding_mask=labels_mask)
611 |
612 | if not return_dict:
613 | output = (logits,)
614 | return ((loss,) + output) if loss is not None else output
615 |
616 | return GLiClassOutput(
617 | loss=loss, logits=logits,
618 | text_embeddings = text_embeddings if output_text_embeddings else None,
619 | class_embeddings = class_embeddings if output_class_embeddings else None,
620 | )
621 |
622 |
623 | class GLiClassModel(GLiClassPreTrainedModel):
624 | def __init__(self, config, from_pretrained=False):
625 | super().__init__(config)
626 | if config.architecture_type == 'uni-encoder':
627 | self.model = GLiClassUniEncoder(config, from_pretrained)
628 | elif config.architecture_type == 'bi-encoder':
629 | self.model = GLiClassBiEncoder(config, from_pretrained)
630 | elif config.architecture_type == 'bi-encoder-fused':
631 | self.model = GLiClassBiEncoderFused(config, from_pretrained)
632 | elif config.architecture_type == 'encoder-decoder':
633 | self.model = GLiClassEncoderDecoder(config, from_pretrained)
634 | self.post_init()
635 |
636 | def get_input_embeddings(self):
637 | if self.config.architecture_type in {'uni-encoder'}:
638 | return self.model.encoder_model.get_input_embeddings()
639 | elif self.config.architecture_type == 'encoder-decoder':
640 | return self.model.encoder_decoder_model.get_input_embeddings()
641 | else:
642 | raise NotImplementedError('Getting input embeddings is not implemented for bi-encoder architecture')
643 |
644 | def set_input_embeddings(self, value):
645 | if self.config.architecture_type in {'uni-encoder'}:
646 | self.model.encoder_model.set_input_embeddings(value)
647 | return None
648 | elif self.config.architecture_type == 'encoder-decoder':
649 | self.model.encoder_decoder_model.set_input_embeddings(value)
650 | elif self.config.architecture_type in {'bi-encoder', 'bi-encoder-fused'}:
651 | self.model.encoder_model.set_input_embeddings(value)
652 | else:
653 |             raise NotImplementedError(f'Setting input embeddings is not implemented for {self.config.architecture_type} architecture')
654 |
655 | def tie_weights(self):
656 | if self.config.architecture_type in {'uni-encoder'}:
657 | return self.model.encoder_model.tie_weights()
658 | elif self.config.architecture_type == 'encoder-decoder':
659 | return self.model.encoder_decoder_model.tie_weights()
660 | elif self.config.architecture_type in {'bi-encoder', 'bi-encoder-fused'}:
661 | return self.model.encoder_model.tie_weights()
662 | else:
663 |             raise NotImplementedError(f'tie_weights is not implemented for {self.config.architecture_type} architecture')
664 |
665 | def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
666 | if self.config.architecture_type in {'uni-encoder'}:
667 | model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
668 | elif self.config.architecture_type == 'encoder-decoder':
669 | model_embeds = self.model.encoder_decoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
670 | elif self.config.architecture_type in {'bi-encoder-fused'}:
671 | model_embeds = self.model.encoder_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
672 | else:
673 | raise NotImplementedError('Resizing is not implemented for bi-encoder architecture')
674 | self.config.encoder_config.vocab_size = model_embeds.num_embeddings
675 | self.config.vocab_size = model_embeds.num_embeddings
676 | self.vocab_size = model_embeds.num_embeddings
677 | return model_embeds
678 |
679 | def forward(self, *args, **kwargs):
680 | outputs = self.model(*args, **kwargs)
681 | return outputs
--------------------------------------------------------------------------------
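
A short usage sketch for the uni-encoder model. The checkpoint name is illustrative (any GLiClass uni-encoder checkpoint follows the same pattern), and the <<LABEL>>/<<SEP>> prompt format mirrors the pipeline code below:

    import torch
    from transformers import AutoTokenizer
    from gliclass import GLiClassModel

    model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
    tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")

    text = "<<LABEL>>positive<<LABEL>>negative<<SEP>>I really enjoyed this movie!"
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
    probs = torch.sigmoid(output.logits)  # one score per <<LABEL>> token
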
/gliclass/pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | from typing import List, Dict, Union
4 | from transformers import AutoTokenizer
5 | from abc import ABC, abstractmethod
6 | from .model import GLiClassModel
7 | from .utils import retrieval_augmented_text
8 |
9 | class BaseZeroShotClassificationPipeline(ABC):
10 | def __init__(self, model, tokenizer, max_classes=25, max_length=1024,
11 | classification_type='multi-label', device='cuda:0', progress_bar=True):
12 | self.model = model
13 | if isinstance(tokenizer, str):
14 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
15 | else:
16 | self.tokenizer = tokenizer
17 | self.max_classes = max_classes
18 | self.classification_type = classification_type
19 | self.max_length = max_length
20 | self.progress_bar = progress_bar
21 |
22 | if not isinstance(device, torch.device):
23 | if torch.cuda.is_available() and 'cuda' in device:
24 | self.device = torch.device(device)
25 | else:
26 | self.device = torch.device('cpu')
27 | else:
28 | self.device = device
29 |
30 | if self.model.device != self.device:
31 | self.model.to(self.device)
32 |
33 | @abstractmethod
34 | def prepare_inputs(self, texts, labels, same_labels = False):
35 | pass
36 |
37 | @torch.no_grad()
38 | def get_embeddings(self, texts, labels, batch_size=8):
39 | if isinstance(texts, str):
40 | texts = [texts]
41 | if isinstance(labels[0], str):
42 | same_labels = True
43 | else:
44 | same_labels = False
45 |
46 | results = []
47 |
48 | iterable = range(0, len(texts), batch_size)
49 | if self.progress_bar:
50 | iterable = tqdm(iterable)
51 |
52 | for idx in iterable:
53 | batch_texts = texts[idx:idx+batch_size]
54 |             tokenized_inputs = self.prepare_inputs(batch_texts, labels if same_labels else labels[idx:idx+batch_size], same_labels)
55 | model_output = self.model(**tokenized_inputs, output_text_embeddings=True,
56 | output_class_embeddings=True)
57 | logits = model_output.logits
58 | text_embeddings = model_output.text_embeddings
59 | class_embeddings = model_output.class_embeddings
60 | batch_size = logits.shape[0]
61 |
62 | for i in range(batch_size):
63 | result = {
64 | 'logits': logits[i].cpu().numpy(),
65 | 'text_embedding': text_embeddings[i].cpu().numpy(),
66 | 'class_embeddings': class_embeddings[i].cpu().numpy()
67 | }
68 | results.append(result)
69 |
70 | return results
71 |
72 | @torch.no_grad()
73 | def __call__(self, texts, labels, threshold = 0.5, batch_size=8, rac_examples=None):
74 | if isinstance(texts, str):
75 | if rac_examples:
76 | texts = retrieval_augmented_text(texts, rac_examples)
77 | texts = [texts]
78 | else:
79 | if rac_examples:
80 | texts = [retrieval_augmented_text(text, examples) for text, examples in zip(texts, rac_examples)]
81 | if isinstance(labels[0], str):
82 | same_labels = True
83 | else:
84 | same_labels = False
85 |
86 | results = []
87 | iterable = range(0, len(texts), batch_size)
88 | if self.progress_bar:
89 | iterable = tqdm(iterable)
90 |
91 | for idx in iterable:
92 | batch_texts = texts[idx:idx+batch_size]
93 | if not same_labels:
94 | batch_labels = labels[idx:idx+batch_size]
95 | else:
96 | batch_labels = labels
97 | tokenized_inputs = self.prepare_inputs(batch_texts, batch_labels, same_labels)
98 | model_output = self.model(**tokenized_inputs)
99 | logits = model_output.logits
100 | if self.classification_type == 'single-label':
101 | for i in range(len(batch_texts)):
102 | score = torch.softmax(logits[i], dim=-1)
103 | if same_labels:
104 | curr_labels = batch_labels
105 | else:
106 | curr_labels = batch_labels[i]
107 | pred_label = curr_labels[torch.argmax(score).item()]
108 | results.append([{'label': pred_label, 'score': score.max().item()}])
109 | elif self.classification_type == 'multi-label':
110 | sigmoid = torch.nn.Sigmoid()
111 | probs = sigmoid(logits)
112 | for i in range(len(batch_texts)):
113 | text_results = []
114 | if same_labels:
115 | curr_labels = batch_labels
116 | else:
117 | curr_labels = batch_labels[i]
118 | for j, prob in enumerate(probs[i][:len(curr_labels)]):
119 | score = prob.item()
120 | if score>=threshold and len(curr_labels):
121 | text_results.append({'label': curr_labels[j], 'score': score})
122 | results.append(text_results)
123 | else:
124 | raise ValueError("Unsupported classification type: choose 'single-label' or 'multi-label'")
125 | return results
126 |
127 | class UniEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
128 | def __init__(self, model, tokenizer, max_classes=25, max_length=1024,
129 | classification_type='multi-label', device='cuda:0', progress_bar=True):
130 | super().__init__(model, tokenizer, max_classes, max_length, classification_type, device, progress_bar)
131 |
132 | def prepare_input(self, text, labels):
133 | input_text = []
134 | for label in labels:
135 |             label_tag = f"<<LABEL>>{label.lower()}"
136 |             input_text.append(label_tag)
137 |         input_text.append('<<SEP>>')
138 | if self.model.config.prompt_first:
139 | input_text = ''.join(input_text)+text
140 | else:
141 | input_text = text+''.join(input_text)
142 | return input_text
143 |
144 | def prepare_inputs(self, texts, labels, same_labels = False):
145 | inputs = []
146 |
147 | if same_labels:
148 | for text in texts:
149 | inputs.append(self.prepare_input(text, labels))
150 | else:
151 | for text, labels_ in zip(texts, labels):
152 | inputs.append(self.prepare_input(text, labels_))
153 |
154 | tokenized_inputs = self.tokenizer(inputs, truncation=True,
155 | max_length=self.max_length,
156 | padding="longest", return_tensors="pt").to(self.device)
157 |
158 | return tokenized_inputs
159 |
160 | class EncoderDecoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
161 | def __init__(self, model, tokenizer, max_classes=25, max_length=1024,
162 | classification_type='multi-label', device='cuda:0', progress_bar=True):
163 | super().__init__(model, tokenizer, max_classes, max_length, classification_type, device, progress_bar)
164 |
165 | def prepare_labels_prompt(self, labels):
166 | input_text = []
167 | for label in labels:
168 |             label_tag = f"<<LABEL>>{label.lower()}"
169 |             input_text.append(label_tag)
170 |         input_text.append('<<SEP>>')
171 | input_text = ''.join(input_text)
172 | return input_text
173 |
174 | def prepare_inputs(self, texts, labels, same_labels = False):
175 | prompts = []
176 |
177 | if same_labels:
178 | for _ in texts:
179 | prompts.append(self.prepare_labels_prompt(labels))
180 | else:
181 | for labels_ in labels:
182 | prompts.append(self.prepare_labels_prompt(labels_))
183 |
184 | tokenized_inputs = self.tokenizer(texts, truncation=True,
185 | max_length=self.max_length,
186 | padding="longest", return_tensors="pt").to(self.device)
187 |
188 | tokenized_classes = self.tokenizer(prompts, max_length=self.max_length,
189 | truncation=True, padding="longest", return_tensors='pt').to(self.device)
190 | tokenized_inputs["class_input_ids"] = tokenized_classes["input_ids"]
191 | tokenized_inputs["class_attention_mask"] = tokenized_classes["attention_mask"]
192 |
193 | return tokenized_inputs
194 |
195 | class BiEncoderZeroShotClassificationPipeline(BaseZeroShotClassificationPipeline):
196 | def __init__(self, model, tokenizer, max_classes=25, max_length=1024,
197 | classification_type='multi-label', device='cuda:0', progress_bar=True):
198 | super().__init__(model, tokenizer, max_classes, max_length, classification_type, device, progress_bar)
199 | self.labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)
200 |
201 | def prepare_input(self, text, labels):
202 | input_text = []
203 | for label in labels:
204 |             label_tag = "<<LABEL>>"
205 |             input_text.append(label_tag)
206 |         input_text.append('<<SEP>>')
207 | if self.model.config.prompt_first:
208 | input_text = ''.join(input_text)+text
209 | else:
210 | input_text = text+''.join(input_text)
211 | return input_text
212 |
213 | def prepare_inputs(self, texts, labels, same_labels=False):
214 | if self.model.config.architecture_type == 'bi-encoder-fused':
215 | inputs = []
216 | if same_labels:
217 | for text in texts:
218 | inputs.append(self.prepare_input(text, labels))
219 | else:
220 | for text, labels_ in zip(texts, labels):
221 | inputs.append(self.prepare_input(text, labels_))
222 | else:
223 | inputs = texts
224 | if same_labels:
225 | # If all texts use the same labels
226 | tokenized_inputs = self.tokenizer(inputs, truncation=True,
227 | max_length=self.max_length,
228 | padding="longest", return_tensors="pt").to(self.device)
229 |
230 | tokenized_labels = self.labels_tokenizer(labels, truncation=True,
231 | max_length=self.max_length,
232 | padding="longest", return_tensors="pt").to(self.device)
233 | tokenized_inputs['class_input_ids'] = tokenized_labels['input_ids'].expand(len(texts), -1, -1)
234 | tokenized_inputs['class_attention_mask'] = tokenized_labels['attention_mask'].expand(len(texts), -1, -1)
235 |
236 | labels_mask = [[1 for i in range(len(labels))] for j in range(len(texts))]
237 | tokenized_inputs["labels_mask"] = torch.tensor(labels_mask).to(self.device)
238 | else:
239 | # If each text has its own set of labels
240 | tokenized_inputs = self.tokenizer(inputs, truncation=True,
241 | max_length=self.max_length,
242 | padding="longest", return_tensors="pt").to(self.device)
243 |
244 | class_input_ids = []
245 | class_attention_mask = []
246 |
247 | for labels_set in labels:
248 | tokenized_labels = self.labels_tokenizer(labels_set, truncation=True,
249 | max_length=self.max_length,
250 | padding="max_length",
251 | return_tensors="pt").to(self.device)
252 | class_input_ids.append(tokenized_labels["input_ids"])
253 | class_attention_mask.append(tokenized_labels["attention_mask"])
254 |
255 | tokenized_inputs["class_input_ids"] = torch.stack(class_input_ids)
256 | tokenized_inputs["class_attention_mask"] = torch.stack(class_attention_mask)
257 |
258 | labels_mask = [[1 for i in range(len(labels[j]))] for j in range(len(texts))]
259 | tokenized_inputs["labels_mask"] = torch.tensor(labels_mask).to(self.device)
260 | return tokenized_inputs
261 |
262 | class ZeroShotClassificationPipeline:
263 | def __init__(self, model, tokenizer, max_classes=25, max_length=1024,
264 | classification_type='multi-label', device='cuda:0', progress_bar=True):
265 | if isinstance(model, str):
266 |             model = GLiClassModel.from_pretrained(model)
267 | if model.config.architecture_type == 'uni-encoder':
268 | self.pipe = UniEncoderZeroShotClassificationPipeline(model, tokenizer, max_classes,
269 | max_length, classification_type, device, progress_bar)
270 | elif model.config.architecture_type in {'encoder-decoder'}:
271 | self.pipe = EncoderDecoderZeroShotClassificationPipeline(model, tokenizer, max_classes,
272 | max_length, classification_type, device, progress_bar)
273 | elif model.config.architecture_type in {'bi-encoder', 'bi-encoder-fused'}:
274 | self.pipe = BiEncoderZeroShotClassificationPipeline(model, tokenizer, max_classes,
275 | max_length, classification_type, device, progress_bar)
276 | else:
277 |             raise NotImplementedError("This architecture is not implemented")
278 |
279 | def get_embeddings(self, *args, **kwargs):
280 | results = self.pipe.get_embeddings(*args, **kwargs)
281 | return results
282 |
283 | def __call__(self, texts, labels, threshold = 0.5, batch_size=8, rac_examples=None):
284 | results = self.pipe(texts, labels, threshold = threshold, batch_size=batch_size, rac_examples=rac_examples)
285 | return results
286 |
287 | class ZeroShotClassificationWithLabelsChunkingPipeline(BaseZeroShotClassificationPipeline):
288 | def __init__(self, model, tokenizer, max_classes=25, max_length=1024,
289 | classification_type='multi-label', device='cuda:0'):
290 | super().__init__(model, tokenizer, max_classes, max_length, classification_type, device)
291 | if isinstance(model, str):
292 | self.model = GLiClassModel.from_pretrained(model)
293 | else:
294 | self.model = model
295 |
296 | if self.model.device != self.device:
297 | self.model.to(self.device)
298 |
299 | def prepare_input(self, text, labels):
300 | input_text = []
301 | for label in labels:
302 |             label_tag = f"<<LABEL>>{label.lower()}"
303 |             input_text.append(label_tag)
304 |         input_text.append('<<SEP>>')
305 | input_text = ''.join(input_text)+text
306 | return input_text
307 |
308 | def prepare_inputs(self, texts, labels):
309 | inputs = []
310 |
311 | for text in texts:
312 | inputs.append(self.prepare_input(text, labels))
313 |
314 | tokenized_inputs = self.tokenizer(inputs, truncation=True,
315 | max_length=self.max_length,
316 | padding="longest", return_tensors="pt").to(self.device)
317 | return tokenized_inputs
318 |
319 | @torch.no_grad()
320 | def __call__(self, texts, labels, threshold = 0.5, batch_size=8, labels_chunk_size=4): #labels - List[str]
321 | results = []
322 |
323 | iterable = range(0, len(texts), batch_size)
324 | if self.progress_bar:
325 | iterable = tqdm(iterable)
326 |
327 | for idx in iterable:
328 | batch_texts = texts[idx:idx+batch_size]
329 |
330 | batch_results = []
331 | for labels_batch in range(0, len(labels), labels_chunk_size):
332 | curr_labels = labels[labels_batch:labels_batch+labels_chunk_size]
333 | tokenized_inputs = self.prepare_inputs(batch_texts, curr_labels)
334 | model_output = self.model(**tokenized_inputs)
335 | logits = model_output.logits
336 | curr_results = []
337 | if self.classification_type == 'single-label':
338 | for i in range(len(batch_texts)):
339 | score = logits[i]
340 | pred_label = curr_labels[torch.argmax(score).item()]
341 | curr_results.append([{'label': pred_label, 'score': score.max().item()}])
342 | elif self.classification_type == 'multi-label':
343 | sigmoid = torch.nn.Sigmoid()
344 | probs = sigmoid(logits)
345 | for i in range(len(batch_texts)):
346 | text_results = []
347 | for j, prob in enumerate(probs[i]):
348 | score = prob.item()
349 | if score>threshold:
350 | text_results.append({'label': curr_labels[j], 'score': score})
351 | curr_results.append(text_results)
352 | else:
353 | raise ValueError("Unsupported classification type: choose 'single-label' or 'multi-label'")
354 | batch_results.append(curr_results)
355 |
356 | # Merge results from different label chunks
357 | merged_batch_results = []
358 | for i in range(len(batch_texts)):
359 | text_results = []
360 | for chunk_result in batch_results:
361 | text_results.extend(chunk_result[i])
362 |
363 | if self.classification_type == 'single-label':
364 | # Keep only the highest scoring label
365 | merged_batch_results.append([max(text_results, key=lambda x: x['score'])])
366 | else:
367 | # Sort multi-label results by score in descending order
368 | merged_batch_results.append(sorted(text_results, key=lambda x: x['score'], reverse=True))
369 |
370 | results.extend(merged_batch_results)
371 |
372 | return results
373 |
--------------------------------------------------------------------------------
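End-to-end usage goes through the dispatching ZeroShotClassificationPipeline, which picks the right sub-pipeline from config.architecture_type; the checkpoint name is illustrative:

    from transformers import AutoTokenizer
    from gliclass import GLiClassModel, ZeroShotClassificationPipeline

    model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
    tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")
    pipeline = ZeroShotClassificationPipeline(
        model, tokenizer, classification_type='multi-label', device='cpu'
    )

    results = pipeline("One day I will see the world!", ["travel", "dreams", "sport"], threshold=0.5)
    for prediction in results[0]:  # one result list per input text
        print(prediction["label"], "=>", prediction["score"])
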
/gliclass/poolings.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class GlobalMaxPooling1D(nn.Module):
8 | """Applies Global Max Pooling on the timesteps dimension."""
9 |
10 | def forward(self, x: torch.Tensor):
11 | return x.amax(dim=1)
12 |
13 |
14 | class FirstTokenPooling1D(nn.Module):
15 | """Takes the first token's embedding."""
16 |
17 | def forward(self, x: torch.Tensor):
18 | return x[:, 0, :]
19 |
20 |
21 | class LastTokenPooling1D(nn.Module):
22 | """Takes the last token's embedding."""
23 |
24 | def forward(self, x: torch.Tensor):
25 | return x[:, -1, :]
26 |
27 |
28 | class GlobalAvgPooling1D(nn.Module):
29 | """Applies Global Average Pooling on the timesteps dimension."""
30 |
31 | def forward(
32 | self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
33 | ):
34 | if attention_mask is not None:
35 | attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(
36 | dtype=x.dtype
37 | )
38 | x = x * attention_mask
39 | return x.sum(1) / attention_mask.sum(1)
40 | else:
41 | return x.mean(dim=1)
42 |
43 |
44 | class GlobalSumPooling1D(nn.Module):
45 | """Applies Global Sum Pooling on the timesteps dimension."""
46 |
47 | def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
48 | if attention_mask is not None:
49 | x = x * attention_mask
50 | return x.sum(dim=1)
51 |
52 |
53 | class GlobalRMSPooling1D(nn.Module):
54 | """Applies Global RMS Pooling on the timesteps dimension."""
55 |
56 | def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
57 | if attention_mask is not None:
58 | attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(
59 | dtype=x.dtype
60 | )
61 | x = x * attention_mask
62 | return (x.pow(2).sum(dim=1) / attention_mask.sum(1)).sqrt()
63 | else:
64 | return x.pow(2).mean(dim=1).sqrt()
65 |
66 |
67 | class GlobalAbsMaxPooling1D(nn.Module):
68 | """Applies Global Max Pooling of absolute values on the timesteps dimension."""
69 |
70 | def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
71 | if attention_mask is not None:
72 | attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(
73 | dtype=x.dtype
74 | )
75 | x = x * attention_mask
76 | return x.abs().amax(dim=1)
77 |
78 |
79 | class GlobalAbsAvgPooling1D(nn.Module):
80 | """Applies Global Average Pooling of absolute values on the timesteps dimension."""
81 |
82 | def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
83 | if attention_mask is not None:
84 | attention_mask = attention_mask.repeat((1, 1, x.shape[-1])).to(
85 | dtype=x.dtype
86 | )
87 | x = (x * attention_mask).abs()
88 | return x.sum(dim=1) / attention_mask.sum(1)
89 | else:
90 | return x.abs().mean(dim=1)
91 |
92 | POOLING2OBJECT = {
93 | 'max': GlobalMaxPooling1D,
94 | 'first': FirstTokenPooling1D,
95 | 'last': LastTokenPooling1D,
96 | 'avg': GlobalAvgPooling1D,
97 | 'sum': GlobalSumPooling1D,
98 | 'rms': GlobalRMSPooling1D,
99 | 'abs_max': GlobalAbsMaxPooling1D,
100 | 'abs_avg': GlobalAbsAvgPooling1D
101 | }
--------------------------------------------------------------------------------
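
A minimal sketch of resolving a pooling strategy from the registry. Note that the masked poolers tile the mask across the hidden dimension with repeat(), so they assume an attention mask with a trailing singleton dimension:

    import torch
    from gliclass.poolings import POOLING2OBJECT

    pooler = POOLING2OBJECT['avg']()  # selected via config.pooling_strategy
    hidden = torch.randn(2, 5, 8)     # (batch, seq_len, hidden)
    mask = torch.ones(2, 5, 1)        # trailing singleton dim so repeat() tiles it over hidden
    mask[1, 3:] = 0                   # last two positions of the second sequence are padding
    pooled = pooler(hidden, mask)     # (batch, hidden); padded steps excluded from the mean
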
/gliclass/scorers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class ScorerWeightedDot(nn.Module):
5 | def __init__(self, hidden_size, dropout=0.1):
6 | super().__init__()
7 |
8 | self.proj_text = nn.Linear(hidden_size, hidden_size * 2)
9 | self.proj_label = nn.Linear(hidden_size, hidden_size * 2)
10 |
11 | self.out_mlp = nn.Sequential(
12 | nn.Linear(hidden_size * 3, hidden_size * 4),
13 | nn.Dropout(dropout),
14 | nn.ReLU(),
15 |             nn.Linear(hidden_size * 4, 1)  # one relevance score per (text, class) pair
16 | )
17 |
18 | def forward(self, text_rep, label_rep):
19 | batch_size, hidden_size = text_rep.shape
20 | num_classes = label_rep.shape[1]
21 |
22 |         # (batch_size, 1, 1, 2, hidden_size)
23 | text_rep = self.proj_text(text_rep).view(batch_size, 1, 1, 2, hidden_size)
24 | label_rep = self.proj_label(label_rep).view(batch_size, 1, num_classes, 2, hidden_size)
25 |
26 | # (2, batch_size, 1, num_classes, hidden_size)
27 | text_rep = text_rep.expand(-1, -1, num_classes, -1, -1).permute(3, 0, 1, 2, 4)
28 | label_rep = label_rep.expand(-1, 1, -1, -1, -1).permute(3, 0, 1, 2, 4)
29 |
30 | # (batch_size, 1, num_classes, hidden_size * 3)
31 | cat = torch.cat([text_rep[0], label_rep[0], text_rep[1] * label_rep[1]], dim=-1)
32 |
33 | # (batch_size, num_classes)
34 | scores = self.out_mlp(cat).view(batch_size, num_classes)
35 |
36 | return scores
37 |
38 | class ScorerDot(nn.Module):
39 | def __init__(self, *args):
40 | super().__init__()
41 | pass
42 |
43 | def forward(self, text_rep, label_rep):
44 | # dot product with einsum
45 | scores = torch.einsum('BD,BCD->BC', text_rep, label_rep)
46 | return scores
47 |
48 | class MLPScorer(nn.Module):
49 | def __init__(self, hidden_size, mlp_hidden_size=256):
50 | super().__init__()
51 |
52 | # Calculate the input size for the MLP
53 | total_input_size = hidden_size*2
54 |
55 | # Define the MLP
56 | self.mlp = nn.Sequential(
57 | nn.Linear(total_input_size, mlp_hidden_size),
58 | nn.ReLU(),
59 | nn.Linear(mlp_hidden_size, mlp_hidden_size // 2),
60 | nn.ReLU(),
61 | nn.Linear(mlp_hidden_size // 2, 1)
62 | )
63 |
64 | def forward(self, text_rep, label_rep):
65 | # Concatenate text and label representations
66 | batch_size, num_labels, dim = label_rep.shape
67 | text_rep = text_rep.unsqueeze(1).expand(batch_size, num_labels, dim)
68 | combined_rep = torch.cat([text_rep, label_rep], dim=-1)
69 |
70 | # Pass through MLP
71 | scores = self.mlp(combined_rep).squeeze(-1)
72 |
73 | return scores
74 |
75 | class HopfieldScorer(nn.Module):
76 | def __init__(self, hidden_size, mlp_hidden_size=256, beta=4, num_iteration=1):
77 | super().__init__()
78 |
79 | self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
80 | self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
81 | self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
82 |
83 | # Define the MLP
84 | self.mlp = nn.Sequential(
85 | nn.Linear(hidden_size, mlp_hidden_size),
86 | nn.ReLU(),
87 | nn.Linear(mlp_hidden_size, mlp_hidden_size // 2),
88 | nn.ReLU(),
89 | nn.Linear(mlp_hidden_size // 2, 1)
90 | )
91 |
92 | self.beta = beta
93 | self.num_iteration = num_iteration
94 |
95 | def forward(self, text_rep, label_rep):
96 | """
97 | text_rep: [batch_size, hidden_size]
98 | label_rep: [batch_size, num_labels, hidden_size]
99 | """
100 | for i in range(self.num_iteration):
101 |             # Expand text_rep to match label_rep's batch shape
102 | text_rep_expanded = text_rep.unsqueeze(1) # [batch_size, 1, dim]
103 |
104 | # Compute Q, K, V
105 | query = self.q_proj(label_rep) # [batch_size, num_labels, dim]
106 | key = self.k_proj(text_rep_expanded) # [batch_size, 1, dim]
107 | value = self.v_proj(text_rep_expanded) # [batch_size, 1, dim]
108 |
109 |
110 | attn = torch.bmm(query, key.transpose(1, 2)) # [b, num_labels, 1]
111 | attn = attn * self.beta # optional beta scaling
112 | attn = torch.nn.functional.softmax(attn, dim=1) # softmax over labels
113 |
114 | context = attn * value # [b, num_labels, dim]
115 |
116 | label_rep = label_rep + context
117 |
118 | scores = self.mlp(label_rep).squeeze(-1) # [b, num_labels]
119 |
120 | return scores
121 |
122 |
123 | # Registry mapping config scorer names to scorer classes
124 | SCORER2OBJECT = {
125 |     "weighted-dot": ScorerWeightedDot,
126 |     "simple": ScorerDot,
127 |     "mlp": MLPScorer,
128 |     "hopfield": HopfieldScorer
129 | }
--------------------------------------------------------------------------------
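
All scorers share one contract: a pooled text representation of shape (batch, hidden) is scored against class embeddings of shape (batch, num_classes, hidden). A minimal sketch with the dot-product scorer:

    import torch
    from gliclass.scorers import SCORER2OBJECT

    hidden_size = 8
    scorer = SCORER2OBJECT['simple'](hidden_size)  # ScorerDot ignores its constructor argument
    text_rep = torch.randn(4, hidden_size)         # (batch, hidden)
    label_rep = torch.randn(4, 3, hidden_size)     # (batch, num_classes, hidden)
    scores = scorer(text_rep, label_rep)           # (batch, num_classes) dot-product logits
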
/gliclass/training.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Dict, List, Union, Any, Callable
2 | from tqdm import tqdm
3 | import numpy as np
4 | import os
5 | from dataclasses import dataclass, field
6 | import torch
7 | from transformers.trainer import (
8 | is_sagemaker_mp_enabled,
9 | get_parameter_names,
10 | ALL_LAYERNORM_LAYERS,
11 | )
12 | import transformers
13 | from transformers import ZeroShotClassificationPipeline as TransformersClassificationPipeline
14 | from .utils import default_f1_reward
15 | from .pipeline import ZeroShotClassificationPipeline
16 |
17 | @dataclass
18 | class TrainingArguments(transformers.TrainingArguments):
19 | cache_dir: Optional[str] = field(default=None)
20 | optim: str = field(default="adamw_torch")
21 | others_lr: Optional[float] = None
22 | others_weight_decay: Optional[float] = 0.0
23 |
24 | class Trainer(transformers.Trainer):
25 | def training_step(self, model, inputs, *args, **kwargs) -> torch.Tensor:
26 | """
27 | Perform a training step on a batch of inputs.
28 |
29 | Subclass and override to inject custom behavior.
30 |
31 | Args:
32 | model (`nn.Module`):
33 | The model to train.
34 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
35 | The inputs and targets of the model.
36 |
37 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
38 | argument `labels`. Check your model's documentation for all accepted arguments.
39 |
40 | Return:
41 | `torch.Tensor`: The tensor with training loss on this batch.
42 | """
43 | model.train()
44 | try:
45 | if "labels_text" in inputs:
46 | labels_text = inputs.pop('labels_text')
47 | if "input_texts" in inputs:
48 | input_texts = inputs.pop('input_texts')
49 | inputs = self._prepare_inputs(inputs)
50 | if is_sagemaker_mp_enabled():
51 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
52 | return loss_mb.reduce_mean().detach().to(self.args.device)
53 |
54 | with self.compute_loss_context_manager():
55 | loss = self.compute_loss(model, inputs)
56 |
57 | del inputs
58 | torch.cuda.empty_cache()
59 |
60 | kwargs = {}
61 |
62 |                 # For LOMO optimizers you need to explicitly use the learning rate
63 | # if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
64 | # kwargs["learning_rate"] = self._get_learning_rate()
65 |
66 | if self.args.n_gpu > 1:
67 | loss = loss.mean() # mean() to average on multi-gpu parallel training
68 |
69 | if self.use_apex:
70 | with amp.scale_loss(loss, self.optimizer) as scaled_loss:
71 | scaled_loss.backward()
72 | else:
73 | self.accelerator.backward(loss, **kwargs)
74 |
75 | return loss.detach() / self.args.gradient_accumulation_steps
76 | except Exception as e:
77 | print(f"Skipping iteration due to error: {e}")
78 | model.zero_grad(set_to_none=True)
79 | torch.cuda.empty_cache()
80 | return torch.tensor(0.0, requires_grad=True).to(model.device)
81 |
82 | def prediction_step(
83 | self,
84 | model: torch.nn.Module,
85 | inputs: Dict[str, Union[torch.Tensor, Any]],
86 | prediction_loss_only: bool,
87 | ignore_keys: Optional[List[str]] = None,
88 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
89 | """
90 | Perform an evaluation step on model using inputs.
91 | Subclass and override to inject custom behavior.
92 | Args:
93 | model (nn.Module):
94 | The model to evaluate.
95 | inputs (Dict[str, Union[torch.Tensor, Any]]):
96 | The inputs and targets of the model.
97 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
98 | argument labels. Check your model's documentation for all accepted arguments.
99 | prediction_loss_only (bool):
100 | Whether or not to return the loss only.
101 | ignore_keys (List[str], *optional*):
102 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
103 | gathering predictions.
104 | Return:
105 | Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
106 | logits and labels (each being optional).
107 | """
108 | try:
109 | with torch.no_grad():
110 | if "labels_text" in inputs:
111 | labels_text = inputs.pop('labels_text')
112 | if "input_texts" in inputs:
113 | input_texts = inputs.pop('input_texts')
114 | loss = None
115 | with self.compute_loss_context_manager():
116 | try:
117 | outputs = model(**inputs)
118 | except Exception as e:
119 | raise RuntimeError(f"Error during model forward pass: {str(e)}")
120 |
121 | if not hasattr(outputs, 'loss'):
122 | raise AttributeError("Model output does not contain 'loss' attribute")
123 | loss = outputs.loss
124 |
125 | if not hasattr(outputs, 'logits'):
126 | raise AttributeError("Model output does not contain 'logits' attribute")
127 | logits = outputs.logits
128 |
129 | if 'labels' not in inputs:
130 | raise KeyError("'labels' not found in input dictionary")
131 | labels = inputs['labels']
132 |
133 | if prediction_loss_only:
134 | return (loss, None, None)
135 | return (loss, logits, labels)
136 |
137 | except Exception as e:
138 | print(f"An error occurred during prediction step: {str(e)}")
139 | return (None, None, None)
140 |
141 | def create_optimizer(self):
142 | """
143 | Setup the optimizer.
144 |
145 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
146 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
147 | """
148 | if is_sagemaker_mp_enabled():
149 | return super().create_optimizer()
150 |
151 | opt_model = self.model
152 |
153 | if self.optimizer is None:
154 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
155 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
156 | if self.args.others_lr is not None:
157 | encoder_parameters = [name for name, _ in opt_model.named_parameters() if "encoder" in name]
158 | optimizer_grouped_parameters = [
159 | {
160 | "params": [
161 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in encoder_parameters and p.requires_grad)
162 | ],
163 | "weight_decay": self.args.others_weight_decay,
164 | "lr": self.args.others_lr,
165 | },
166 | {
167 | "params": [
168 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in encoder_parameters and p.requires_grad)
169 | ],
170 | "weight_decay": 0.0,
171 | "lr": self.args.others_lr,
172 | },
173 | {
174 | "params": [
175 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in encoder_parameters and p.requires_grad)
176 | ],
177 | "weight_decay": self.args.weight_decay,
178 | },
179 | {
180 | "params": [
181 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in encoder_parameters and p.requires_grad)
182 | ],
183 | "weight_decay": 0.0,
184 | },
185 | ]
186 | else:
187 | optimizer_grouped_parameters = [
188 | {
189 | "params": [
190 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
191 | ],
192 | "weight_decay": self.args.weight_decay,
193 | },
194 | {
195 | "params": [
196 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
197 | ],
198 | "weight_decay": 0.0,
199 | },
200 | ]
201 |
202 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
203 |
204 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
205 |
206 | return self.optimizer
207 |
208 | @dataclass
209 | class RLTrainerConfig(TrainingArguments):
210 | cliprange: float = field(
211 | default=0.2,
212 | metadata={"help": "Clip range."},
213 | )
214 | num_rl_iters: int = field(
215 | default=3,
216 | metadata={"help": "Number of RL iterations."},
217 | )
218 | gamma: float = field(
219 | default=-1,
220 | metadata={"help": "Focal loss gamma."},
221 | )
222 | alpha: float = field(
223 | default=-1,
224 | metadata={"help": "Focal loss alpha."},
225 | )
226 | labels_smoothing: float = field(
227 | default=-1,
228 |         metadata={"help": "Label smoothing factor."}
229 | )
230 | entropy_beta: float = field(
231 | default=-1,
232 |         metadata={"help": "Coefficient of the entropy term."}
233 | )
234 | kl_beta: float = field(
235 | default=-1,
236 |         metadata={"help": "Coefficient of the KL-divergence term."}
237 | )
238 | get_actions: str = field(
239 | default="bernoulli",
240 |         metadata={"help": "How actions are obtained from model probabilities: `bernoulli` (default) or `threshold`."},
241 | )
242 | threshold: float = field(
243 | default=0.5,
244 | metadata={"help": "Threshold value for predictions."},
245 | )
246 |
247 | class RLTrainer(Trainer):
248 | def __init__(
249 | self,
250 | value_model: Optional[torch.nn.Module] = None,
251 |         reference_model: Optional[Union[ZeroShotClassificationPipeline, TransformersClassificationPipeline]] = None,
252 |         reward_components: Optional[Dict[str, Callable]] = None,
253 | *args,
254 | **kwargs
255 | ):
256 | super().__init__(*args, **kwargs)
257 |         self.value_model = value_model.to(self.model.device) if value_model is not None else None
258 |         self.reference_model = reference_model
259 |         if reward_components is None:
260 |             # default: a mapping from reward name to callable, matching the .items() calls below
261 |             reward_components = {'f1': default_f1_reward}
262 | self.reward_components = reward_components
263 | self._init_metrics()
264 |
265 | def _init_metrics(self):
266 | self.metrics = {
267 | 'total_loss': [],
268 | 'advantages': [],
269 | }
270 | # Initialize metrics for each reward component
271 | for name, _ in self.reward_components.items():
272 | self.metrics[f'reward_{name}'] = []
273 |
274 | def compute_rewards(
275 | self,
276 | probs: torch.Tensor,
277 | actions: torch.Tensor,
278 | original_targets: torch.Tensor,
279 | valid_mask: torch.Tensor
280 | ) -> Dict[str, torch.Tensor]:
281 | rewards = {}
282 | total_reward = 0.0
283 | for name, reward_fn in self.reward_components.items():
284 | component = reward_fn(probs, actions, original_targets, valid_mask)
285 | rewards[name] = component
286 | total_reward += component
287 | rewards['total_reward'] = total_reward
288 | return rewards
289 |
290 | def get_reference_scores(self, input_texts, labels_text):
291 | if input_texts is None or labels_text is None:
292 | return None
293 | all_scores = []
294 | with torch.no_grad():
295 | if isinstance(self.reference_model, ZeroShotClassificationPipeline):
296 | results = self.reference_model(input_texts, labels_text, threshold=0.)
297 | for id, result in enumerate(results):
298 | label2score = {item['label']: item['score'] for item in result}
299 | label_scores = [label2score[label] for label in labels_text[id]]
300 | all_scores.append(label_scores)
301 | elif isinstance(self.reference_model, TransformersClassificationPipeline):
302 | for text, labels in zip(input_texts, labels_text):
303 | result = self.reference_model(text, labels)
304 | label2score = {label:score for label, score in zip(result['labels'], result['scores'])}
305 |                     label_scores = [label2score[label] for label in labels]
306 | all_scores.append(label_scores)
307 | else:
308 |                 raise NotImplementedError("This classification pipeline is not supported as a reference model.")
309 | max_length = max(len(seq) for seq in all_scores)
310 | all_scores = torch.FloatTensor([seq + [0] * (max_length - len(seq))
311 | for seq in all_scores]).to(self.model.device)
312 | return all_scores
313 |
314 | def compute_loss(
315 | self,
316 | inputs: torch.Tensor,
317 | targets: torch.Tensor,
318 | log_prob_prev: Optional[torch.Tensor] = None,
319 | value_outputs: Optional[torch.Tensor] = None,
320 | reference_probs: Optional[torch.Tensor] = None,
321 | **kwargs
322 | ) -> Tuple[torch.Tensor, torch.Tensor]:
323 | valid_mask = targets != -100
324 | original_targets = targets.clone()
325 |
326 | probs = torch.sigmoid(inputs)
327 |
328 | if self.args.get_actions == 'bernoulli':
329 | actions = torch.bernoulli(probs).detach()
330 | else:
331 | actions = (probs > self.args.threshold).float().detach()
332 |
333 | with torch.no_grad():
334 | metrics = self.compute_rewards(probs, actions, original_targets, valid_mask)
335 |
336 | reward = metrics['total_reward']
337 |
338 | if value_outputs is not None:
339 | state_values = value_outputs.logits[:, 0].unsqueeze(-1) # Using first token logits as value prediction
340 | value_loss = torch.nn.functional.mse_loss(state_values, reward.detach())
341 | else:
342 | state_values = reward.mean()
343 | value_loss = torch.tensor(0.0).to(inputs.device)
344 |
345 | advantages = (reward - state_values).detach()
346 | self.metrics['advantages'].append(advantages.mean().item())
347 |
348 | for name, _ in self.reward_components.items():
349 | key = f'reward_{name}'
350 | self.metrics[key].append(metrics[name].mean().item())
351 |
352 | if self.args.label_smoothing_factor > 0:
353 | smoothed_actions = actions * (1 - self.args.label_smoothing_factor) + 0.5 * self.args.label_smoothing_factor
354 | log_prob_current = (
355 | smoothed_actions * torch.log(probs + 1e-8) +
356 | (1 - smoothed_actions) * torch.log(1 - probs + 1e-8)
357 | )
358 | else:
359 | log_prob_current = (
360 | actions * torch.log(probs + 1e-8) +
361 | (1 - actions) * torch.log(1 - probs + 1e-8)
362 | )
363 |
364 | if log_prob_prev is None:
365 | log_prob_prev = log_prob_current.detach()
366 |
367 | log_probs_diff = log_prob_current - log_prob_prev
368 | ratio = torch.exp(log_probs_diff)
369 |
370 | cliprange = self.args.cliprange
371 | per_label_loss1 = ratio * advantages
372 | per_label_loss2 = torch.clamp(ratio, 1 - cliprange, 1 + cliprange) * advantages
373 | loss_elements = -torch.min(per_label_loss1, per_label_loss2)
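        # e.g. with cliprange = 0.2 and a positive advantage A, a probability
        # ratio of 1.5 is clipped to 1.2, so -min(1.5*A, 1.2*A) = -1.2*A and the
        # update gains nothing from pushing the ratio past 1 + cliprange; with a
        # negative advantage the unclipped term is the smaller one, so large
        # harmful updates remain fully penalized (standard PPO clipping).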
374 |
375 | loss_elements = loss_elements * valid_mask
376 | self.metrics['total_loss'].append(loss_elements.mean().item())
377 |
378 | if self.args.gamma > 0:
379 | p_t = probs * original_targets + (1 - probs) * (1 - original_targets)
380 | loss_elements = loss_elements * (p_t ** self.args.gamma)
381 |
382 | if self.args.alpha >= 0:
383 | alpha_t = self.args.alpha * original_targets + (1 - self.args.alpha) * (1 - original_targets)
384 | loss_elements = alpha_t * loss_elements
385 |
386 | loss = loss_elements.sum() / valid_mask.shape[0] + value_loss
387 |
388 |         if reference_probs is not None and self.args.kl_beta > 0:
389 | ref_per_token_logps = torch.log(reference_probs + 1e-8)
390 | per_token_logps = log_prob_current
391 | per_label_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
392 | per_label_kl = per_label_kl * valid_mask
393 | kl_loss = self.args.kl_beta * per_label_kl.mean()
394 | loss = loss + kl_loss
395 |
396 |         if self.args.entropy_beta > 0:
397 | entropy = - (probs * torch.log(probs + 1e-8) +
398 | (1 - probs) * torch.log(1 - probs + 1e-8))
399 | loss = loss + self.args.entropy_beta * entropy.mean()
400 |
401 | return loss, log_prob_current
402 |
403 |
404 | def _inner_training_loop(self, *args, **kwargs):
405 | self.create_optimizer()
406 | if self.value_model is not None:
407 | value_optimizer = torch.optim.Adam(self.value_model.parameters(), lr=self.args.learning_rate)
408 | args = self.args
409 | accelerator = self.accelerator
410 | optimizer = self.optimizer
411 | model = self.model
412 | dataloader = self.get_train_dataloader()
413 | device = accelerator.device
414 |
415 | num_local_steps = len(dataloader)
416 |         num_iters = int(args.num_train_epochs) * num_local_steps
417 | pbar = tqdm(total=num_iters, desc="Training iterations")
418 | self._init_metrics()
419 |
420 |         for epoch in range(int(args.num_train_epochs)):
421 | self._init_metrics()
422 | model.train()
423 | if self.value_model is not None:
424 | self.value_model.train()
425 |
426 | for step, inputs in enumerate(dataloader):
427 | global_step = step+epoch*num_local_steps
428 |
429 | inputs = self._prepare_inputs(inputs)
430 | labels = inputs.pop('labels').to(device)
431 | if "labels_text" in inputs:
432 | labels_text = inputs.pop('labels_text')
433 | else:
434 | labels_text = None
435 | if "input_texts" in inputs:
436 | input_texts = inputs.pop('input_texts')
437 | else:
438 | input_texts = None
439 | prev_logps = None
440 |             for _ in range(args.num_rl_iters):
441 | try:
442 | outputs = model(**inputs)
443 | logits = outputs.logits
444 | if self.value_model is not None:
445 | value_outputs = self.value_model(**inputs)
446 | else:
447 | value_outputs = None
448 | if self.reference_model is not None:
449 | reference_probs = self.get_reference_scores(input_texts, labels_text)
450 | else:
451 | reference_probs = None
452 | loss, current_logps = self.compute_loss(logits, labels, log_prob_prev=prev_logps,
453 | value_outputs=value_outputs,
454 | reference_probs=reference_probs)
455 | except Exception as e:
456 | print(f"An error occurred during training step: {str(e)}")
457 | del inputs
458 | model.zero_grad(set_to_none=True)
459 | torch.cuda.empty_cache()
460 | break
461 |
462 | accelerator.backward(loss)
463 | if self.args.max_grad_norm is not None:
464 | torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
465 | if self.value_model is not None:
466 | torch.nn.utils.clip_grad_norm_(self.value_model.parameters(), self.args.max_grad_norm)
467 |
468 | optimizer.step()
469 | optimizer.zero_grad()
470 | if self.value_model is not None:
471 | value_optimizer.step()
472 | value_optimizer.zero_grad()
473 |
474 | prev_logps = current_logps.detach()
475 |
476 | if global_step % args.logging_steps == 0:
477 | self.log_metrics()
478 |
479 | if args.save_steps is not None and global_step % args.save_steps == 0:
480 | self._save_checkpoint(model, step=global_step)
481 |
482 | pbar.set_postfix(epoch=epoch, step=step)
483 | pbar.update(1)
484 |
485 |             if args.evaluation_strategy == "epoch":
486 |                 self.evaluate()
487 |
488 | def log_metrics(self):
489 | logged_metrics = {
490 | 'loss': np.mean(self.metrics['total_loss']),
491 | 'advantages': np.mean(self.metrics['advantages']),
492 | }
493 | # Add user reward components
494 | for name, _ in self.reward_components.items():
495 | key = f'reward_{name}'
496 | logged_metrics[key] = np.mean(self.metrics[key])
497 | self.log(logged_metrics)
498 | self._init_metrics()
499 |
500 | def _save_checkpoint(self, model, step=None):
501 | checkpoint_dir = f"checkpoint-{step}" if step else "final_model"
502 | output_dir = os.path.join(self.args.output_dir, checkpoint_dir)
503 | os.makedirs(output_dir, exist_ok=True)
504 | model.save_pretrained(output_dir)
505 | if self.tokenizer is not None:
506 | self.tokenizer.save_pretrained(output_dir)
507 | print(f"Checkpoint saved to {output_dir}")
--------------------------------------------------------------------------------
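Note on extending RLTrainer: `reward_components` maps a reward name to a callable with the same `(probs, actions, original_targets, valid_mask) -> (N, 1)` contract as `default_f1_reward` in gliclass/utils.py. A minimal sketch of a custom precision reward, with a toy check (all tensors and names below are illustrative, not part of the library):

import torch

def precision_reward(probs, actions, original_targets, valid_mask):
    # Per-row precision over valid label slots: TP / (TP + FP).
    valid_preds = actions * valid_mask
    valid_targets = original_targets * valid_mask
    tp = torch.sum(valid_preds * valid_targets, dim=-1)
    fp = torch.sum(valid_preds * (1 - valid_targets), dim=-1)
    return (tp / (tp + fp + 1e-8)).detach().unsqueeze(1)

# Toy check: one row, three candidate labels -> one true positive, one false positive.
probs = torch.tensor([[0.9, 0.2, 0.7]])
actions = (probs > 0.5).float()
targets = torch.tensor([[1., 0., 0.]])
mask = torch.ones_like(targets)
print(precision_reward(probs, actions, targets, mask))  # tensor([[0.5000]])

# Wired into the trainer alongside the default reward:
#   RLTrainer(..., reward_components={'f1': default_f1_reward,
#                                     'precision': precision_reward})
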
/gliclass/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def is_module_available(module_name):
4 | """
5 | Checks whether the specified Python module is available.
6 |
7 | Args:
8 | module_name (str): The name of the module to check.
9 |
10 | Returns:
11 | bool: True if the module is available, False otherwise.
12 | """
13 | try:
14 | __import__(module_name)
15 | return True
16 | except ImportError:
17 | return False
18 |
19 | class MissedPackageException(Exception):
20 | """Raised when the requested decoder model is not supported."""
21 | pass
22 |
23 |
24 | def retrieval_augmented_text(text: str, examples: list) -> str:
25 | """
26 | Constructs a new text by appending relevant retrieved examples to the input text.
27 |
28 | Args:
29 | text (str): The main input text.
30 | examples (list): A list of examples in the format
31 | {"text": str, "true_labels": List[str], "all_labels": List[str]}.
32 |
33 | Returns:
34 | str: The modified text with relevant examples appended.
35 | """
36 | if not examples:
37 | return text
38 |
39 | retrieved_examples = []
40 | all_labels = set(label for example in examples for label in example.get("true_labels", []))
41 | relevant_examples = [ex for ex in examples if set(ex.get("true_labels", [])) & all_labels]
42 |
43 | for example in relevant_examples:
44 | example_text = example["text"]
45 | true_labels = example.get("true_labels", [])
46 | all_labels = example.get("all_labels", [])
47 |
48 | false_labels = list(set(all_labels) - set(true_labels))
49 |
50 |         true_labels_str = " ".join([f"<<TRUE_LABEL>> {label}" for label in true_labels])
51 |         false_labels_str = " ".join([f"<<FALSE_LABEL>> {label}" for label in false_labels])
52 | 
53 |         retrieved_example_str = f"<<EXAMPLE>> {example_text} {true_labels_str} {false_labels_str} <</EXAMPLE>>"
54 | retrieved_examples.append(retrieved_example_str)
55 |
56 | augmented_text = f"{text} {' '.join(retrieved_examples)}" if retrieved_examples else text
57 |
58 | return augmented_text
59 |
60 | def default_f1_reward(
61 | probs: torch.Tensor,
62 | actions: torch.Tensor,
63 | original_targets: torch.Tensor,
64 | valid_mask: torch.Tensor
65 | ) -> torch.Tensor:
66 | """
67 | A variant that extracts list-of-indices sets and then calculates
68 | the F1 score in a classical manner. Returns shape (N, 1).
69 |
70 | Args:
71 | probs: (N, T) Tensor of probabilities (not used here but left for interface consistency).
72 | actions: (N, T) Tensor of predicted labels in {0, 1}.
73 | original_targets: (N, T) Tensor of ground-truth labels in {0, 1}.
74 | valid_mask: (N, T) Tensor indicating which positions are valid (1) vs. invalid (0).
75 |
76 | Returns:
77 | f1_scores: (N, 1) Tensor containing the F1 score for each row.
78 | """
79 | N = actions.shape[0]
80 | f1_scores = []
81 |
82 | for i in range(N):
83 | # Filter valid positions
84 | valid_preds_i = actions[i] * valid_mask[i]
85 | valid_targets_i = original_targets[i] * valid_mask[i]
86 |
87 | # Get the set of indices where we predicted 1
88 | predicted_set = set((valid_preds_i == 1).nonzero(as_tuple=True)[0].tolist())
89 | # Get the set of indices where the ground truth is 1
90 | target_set = set((valid_targets_i == 1).nonzero(as_tuple=True)[0].tolist())
91 |
92 | # Compute intersection
93 | intersection = predicted_set.intersection(target_set)
94 |
95 | # Precision
96 | if len(predicted_set) > 0:
97 | precision = len(intersection) / len(predicted_set)
98 | else:
99 | precision = 0.0
100 |
101 | # Recall
102 | if len(target_set) > 0:
103 | recall = len(intersection) / len(target_set)
104 | else:
105 | recall = 0.0
106 |
107 | # F1 score
108 | if (precision + recall) > 0:
109 | f1 = 2 * precision * recall / (precision + recall)
110 | else:
111 | f1 = 0.0
112 |
113 | f1_scores.append(f1)
114 |
115 | # Convert list to tensor shape (N, 1)
116 | f1_scores = torch.tensor(f1_scores, dtype=torch.float).unsqueeze(-1)
117 | return f1_scores.detach().to(probs.device)
--------------------------------------------------------------------------------
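A quick sanity check of `default_f1_reward` on toy tensors, showing how `valid_mask` drops padded label slots from the per-row F1 (shapes as in the docstring above; the numbers are illustrative):

import torch
from gliclass.utils import default_f1_reward

# Two rows, three label slots; the last slot of the second row is padding,
# masked out via valid_mask.
actions = torch.tensor([[1., 0., 1.],
                        [1., 1., 0.]])
targets = torch.tensor([[1., 0., 0.],
                        [1., 0., 0.]])
mask = torch.tensor([[1., 1., 1.],
                     [1., 1., 0.]])
probs = torch.rand(2, 3)  # unused by this reward; kept for the shared interface

print(default_f1_reward(probs, actions, targets, mask))
# Each row: precision 1/2, recall 1/1 -> F1 = 2/3, i.e. tensor([[0.6667], [0.6667]])
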
/notebooks/finetuning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n",
11 | "\n",
12 | "from datasets import load_dataset, Dataset, DatasetDict\n",
13 | "\n",
14 | "from sklearn.metrics import classification_report, f1_score, precision_recall_fscore_support, accuracy_score\n",
15 | "import numpy as np\n",
16 | "import random\n",
17 | "\n",
18 | "from transformers import AutoTokenizer\n",
19 | "import torch\n",
20 | "\n",
21 | "from gliclass import GLiClassModel, ZeroShotClassificationPipeline\n",
22 | "from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding\n",
23 | "from gliclass.training import TrainingArguments, Trainer\n",
24 | "\n",
25 |     "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "def get_gliclass_predictions(pipeline, test_texts, classes, batch_size=8):\n",
35 | " results = pipeline(test_texts, classes, batch_size=batch_size)#, labels_chunk_size=1)\n",
36 | " predicts = [result[0]['label'] for result in results]\n",
37 | " return predicts\n",
38 | "\n",
39 | "def evaluate(predicts, true_labels):\n",
40 | " micro = f1_score(true_labels, predicts, average=\"micro\")\n",
41 | " macro = f1_score(true_labels, predicts, average=\"macro\")\n",
42 | " weighted = f1_score(true_labels, predicts, average=\"weighted\")\n",
43 | " return {\"micro\": micro, \"macro\": macro, \"weighted\": weighted}\n",
44 | "\n",
45 | "def get_train_dataset(dataset, N, label_column='label'):\n",
46 | " ids = []\n",
47 | " label2count = {}\n",
48 | " train_dataset = dataset.shuffle(seed=41)\n",
49 | " for id, example in enumerate(train_dataset):\n",
50 | " if example[label_column] not in label2count:\n",
51 | " label2count[example[label_column]]=1\n",
52 | " elif label2count[example[label_column]]>=N:\n",
53 | " continue\n",
54 | " else:\n",
55 | " label2count[example[label_column]]+=1\n",
56 | " ids.append(id)\n",
57 | " return train_dataset.select(ids)\n",
58 | "\n",
59 | "def prepare_dataset(dataset, classes = None, text_column = 'text', label_column = \"label\", split=None):\n",
60 | " if 'test' in dataset:\n",
61 | " test_dataset = dataset['test']\n",
62 | " elif isinstance(dataset, Dataset):\n",
63 | " test_dataset = dataset\n",
64 | " else:\n",
65 | " test_dataset = dataset['train']\n",
66 | " \n",
67 | " if classes is None:\n",
68 | " classes = test_dataset.features[label_column].names\n",
69 | " if split is not None:\n",
70 | " classes = [' '.join(class_.split(split)) for class_ in classes]\n",
71 | "\n",
72 | " texts = test_dataset[text_column]\n",
73 | "\n",
74 | " true_labels = test_dataset[label_column]\n",
75 | "\n",
76 | " print(classes)\n",
77 | " if type(test_dataset[label_column][0]) == int:\n",
78 | " true_labels = [classes[label] for label in true_labels]\n",
79 | "\n",
80 | " return texts, classes, true_labels\n",
81 | "\n",
82 | "\n",
83 | "def prepare_dataset_for_training(train_dataset, classes, text_column='text', label_column='label'):\n",
84 | " id2class = {id: class_ for id, class_ in enumerate(classes)}\n",
85 | " dataset = []\n",
86 | " for example in train_dataset:\n",
87 | " label = example[label_column]\n",
88 | " if type(label)==int:\n",
89 | " label = id2class[label]\n",
90 | " item = {'text': example[text_column], 'all_labels': classes, 'true_labels': [label]}\n",
91 | " dataset.append(item)\n",
92 | " random.shuffle(dataset)\n",
93 | " return dataset\n"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "emotions = load_dataset('dair-ai/emotion')\n",
103 | "\n",
104 | "train_data = get_train_dataset(emotions['train'], N=64)\n",
105 | "\n",
106 | "test_texts, classes, true_labels = prepare_dataset(emotions)\n",
107 | "\n",
108 | "train_data = prepare_dataset_for_training(train_data, classes)\n"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "ag_news = load_dataset('ag_news')\n",
118 | "\n",
119 | "train_data = get_train_dataset(ag_news['train'], N=64)\n",
120 | "\n",
121 | "test_texts, classes, true_labels = prepare_dataset(ag_news)\n",
122 | "\n",
123 | "train_data = prepare_dataset_for_training(train_data, classes)\n"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "sst5 = load_dataset('SetFit/sst5')\n",
133 | "\n",
134 | "train_data = get_train_dataset(sst5['train'], N=64)\n",
135 | "\n",
136 | "classes = ['very negative', 'negative', 'neutral', 'positive', 'very positive']\n",
137 | "\n",
138 | "test_texts, classes, true_labels = prepare_dataset(sst5, classes=classes)\n",
139 | "\n",
140 | "train_data = prepare_dataset_for_training(train_data, classes)\n"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "banking = load_dataset('PolyAI/banking77')\n",
150 | "\n",
151 | "train_data = get_train_dataset(banking['train'], N=32)\n",
152 | "\n",
153 | "test_texts, classes, true_labels = prepare_dataset(banking)\n",
154 | "\n",
155 | "train_data = prepare_dataset_for_training(train_data, classes)\n"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "massive = load_dataset(\"AmazonScience/massive\", \"en-US\")\n",
165 | "\n",
166 | "train_data = get_train_dataset(massive['train'], N=32, label_column='intent')\n",
167 | "\n",
168 | "test_texts, classes, true_labels = prepare_dataset(massive, text_column='utt', label_column='intent')\n",
169 | "\n",
170 | "train_data = prepare_dataset_for_training(train_data, classes, text_column='utt', label_column='intent')"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "model_name = 'knowledgator/gliclass-base-v1.0'\n",
180 | "\n",
181 | "model = GLiClassModel.from_pretrained(model_name).to(device)\n",
182 | "tokenizer = AutoTokenizer.from_pretrained(model_name)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "max_length = 1024\n",
192 | "problem_type = \"multi_label_classification\"\n",
193 | "architecture_type = model.config.architecture_type\n",
194 | "prompt_first = model.config.prompt_first\n",
195 | "\n",
196 | "train_dataset = GLiClassDataset(train_data, tokenizer, max_length, problem_type, architecture_type, prompt_first)\n",
197 | "test_dataset = GLiClassDataset(train_data[:int(len(train_data)*0.1)], tokenizer, max_length, problem_type, architecture_type, prompt_first)\n",
198 | "\n",
199 | "data_collator = DataCollatorWithPadding(device=device)\n",
200 | "\n",
201 | "training_args = TrainingArguments(\n",
202 | " output_dir='models/test',\n",
203 | " learning_rate=1e-5,\n",
204 | " weight_decay=0.01,\n",
205 | " others_lr=1e-5,\n",
206 | " others_weight_decay=0.01,\n",
207 | " lr_scheduler_type='linear',\n",
208 | " warmup_ratio=0.0,\n",
209 | " per_device_train_batch_size=8,\n",
210 | " per_device_eval_batch_size=8,\n",
211 | " num_train_epochs=8,\n",
212 | " evaluation_strategy=\"epoch\",\n",
213 | " save_steps = 1000,\n",
214 | " save_total_limit=10,\n",
215 | " dataloader_num_workers=8,\n",
216 | " logging_steps=10,\n",
217 | " use_cpu = False,\n",
218 | " report_to=\"none\",\n",
219 | " fp16=False,\n",
220 | " )"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "trainer = Trainer(\n",
230 | " model=model,\n",
231 | " args=training_args,\n",
232 | " train_dataset=train_dataset,\n",
233 | " eval_dataset=test_dataset,\n",
234 | " tokenizer=tokenizer,\n",
235 | " data_collator=data_collator,\n",
236 | ")\n",
237 | "trainer.train()"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='single-label', device='cuda:0')\n",
247 | "\n",
248 | "predicts = get_gliclass_predictions(pipeline, test_texts, classes, batch_size=8)\n",
249 | "\n",
250 | "results = evaluate(predicts, true_labels)\n",
251 | "print(results)"
252 | ]
253 | }
254 | ],
255 | "metadata": {
256 | "kernelspec": {
257 | "display_name": "Python 3",
258 | "language": "python",
259 | "name": "python3"
260 | },
261 | "language_info": {
262 | "codemirror_mode": {
263 | "name": "ipython",
264 | "version": 3
265 | },
266 | "file_extension": ".py",
267 | "mimetype": "text/x-python",
268 | "name": "python",
269 | "nbconvert_exporter": "python",
270 | "pygments_lexer": "ipython3",
271 | "version": "3.8.10"
272 | },
273 | "orig_nbformat": 4
274 | },
275 | "nbformat": 4,
276 | "nbformat_minor": 2
277 | }
278 |
--------------------------------------------------------------------------------
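After `trainer.train()` finishes, the finetuned weights can be reloaded from the notebook's `output_dir` and served through the same pipeline. A minimal sketch; the checkpoint subdirectory (`checkpoint-1000` here is an assumption) depends on `save_steps` and the number of steps actually run:

import torch
from transformers import AutoTokenizer
from gliclass import GLiClassModel, ZeroShotClassificationPipeline

checkpoint = 'models/test/checkpoint-1000'  # assumed name; use a directory that exists
model = GLiClassModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
pipeline = ZeroShotClassificationPipeline(model, tokenizer,
                                          classification_type='single-label',
                                          device=device)
print(pipeline(["I am so happy today!"], ['sadness', 'joy', 'anger'])[0])
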
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "gliclass"
3 | version = "0.1.11"
4 | description = "Generalist and Lightweight Model for Text Classification"
5 | authors = ["knowledgator.com"]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = ">=3.9,<4.0"
10 | torch = "^2.0.0"
11 | transformers = ">=4.37.2,<=4.48.2"
12 | scikit-learn = "^1.0.0"
13 | numpy = "^1.26.4"
14 |
15 |
16 | [build-system]
17 | requires = ["poetry-core"]
18 | build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/test_gliclass.py:
--------------------------------------------------------------------------------
1 | from gliclass import GLiClassModel, ZeroShotClassificationPipeline
2 | from transformers import AutoTokenizer
3 | from datasets import load_dataset, Dataset
4 | from datasets import ClassLabel
5 | from sklearn.metrics import f1_score
6 | import numpy as np
8 | import torch
9 | import argparse
10 |
11 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
12 |
13 |
14 | class TestModel:
15 |
16 | def __init__(self, model, token):
17 | self.model_name = model
18 | self.model = None
19 |         self.tokenizer = None
20 | self.token=token
21 | self.datasets = ["SetFit/CR", "SetFit/sst2", "SetFit/sst5", 'stanfordnlp/imdb',
22 | "SetFit/20_newsgroups", "SetFit/enron_spam", "AmazonScience/massive",
23 | 'PolyAI/banking77', 'takala/financial_phrasebank','ag_news', 'dair-ai/emotion',
24 | "MoritzLaurer/cap_sotu", 'cornell-movie-review-data/rotten_tomatoes']
25 | self.pipeline = None
26 |
27 | self.macro_scores = []
28 | def load_model(self):
29 | self.model = GLiClassModel.from_pretrained(self.model_name, token=self.token).to(dtype=torch.float16)
30 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.token, add_prefix_space=True)
31 | self.pipeline = ZeroShotClassificationPipeline(self.model, self.tokenizer, classification_type='single-label',
32 | device='cuda:0')
33 |
34 | def prepare_dataset(self, dataset, classes=None, text_column='text', label_column="label_text", split=None):
35 |
36 | if 'test' in dataset:
37 | test_dataset = dataset['test']
38 | elif isinstance(dataset, Dataset):
39 | test_dataset = dataset
40 | else:
41 | test_dataset = dataset['train']
42 | if classes is None:
43 | classes = test_dataset[label_column]
44 | classes = list(set(classes))
45 | if split is not None:
46 | classes = [' '.join(class_.split(split)) for class_ in classes]
47 | texts = test_dataset[text_column]
48 | true_labels = test_dataset[label_column]
49 | print(true_labels[:5])
50 | print(classes)
51 | if type(test_dataset[label_column][0]) == int:
52 | true_labels = [classes[label] for label in true_labels]
53 | return texts, classes, true_labels
54 |
55 | def prepare_nomapping(self, dataset, classes=None, text_column='text', label_column='label_text', split=None):
56 | if 'test' in dataset:
57 | test_dataset = dataset['test']
58 | elif isinstance(dataset, Dataset):
59 | test_dataset = dataset
60 | else:
61 | test_dataset = dataset['train']
62 | if classes is None:
63 | if isinstance(test_dataset.features[label_column], ClassLabel):
64 | classes = test_dataset.features[label_column].names
65 | else:
66 | classes = test_dataset[label_column]
67 | classes = list(set(classes))
68 | if split is not None:
69 | classes = [' '.join(class_.split(split)) for class_ in classes]
70 | texts = test_dataset[text_column]
71 | true_labels = test_dataset[label_column]
72 | # if isinstance(test_dataset.features[label_column], ClassLabel):
73 | # true_labels = [test_dataset.features[label_column].int2str(label) for label in true_labels]
74 | if type(true_labels[0]) == int:
75 | true_labels = [classes[label] for label in true_labels]
76 |
77 | return texts, classes, true_labels
78 |
79 | def get_gliclass_predictions(self, test_texts, classes, batch_size=8):
80 | results = self.pipeline(test_texts, classes, batch_size=batch_size)
81 | predicts = [result[0]['label'] for result in results]
82 | return predicts
83 |
84 | def evaluate(self, predicts, true_labels):
85 | micro = f1_score(true_labels, predicts, average="micro")
86 | macro = f1_score(true_labels, predicts, average="macro")
87 | weighted = f1_score(true_labels, predicts, average="weighted")
88 | return {"micro": micro, "macro": macro, "weighted": weighted}
89 |
90 | def process(self):
91 | self.load_model()
92 | for dataset in self.datasets:
93 | classes = None
94 | print(dataset)
95 | if dataset == 'SetFit/sst5':
96 | classes = ['very negative', 'negative', 'neutral', 'positive', 'very positive']
97 | ds = load_dataset(dataset, trust_remote_code=True)
98 | test_texts, classes, true_labels = self.prepare_nomapping(ds, classes=classes)
99 | elif dataset == 'PolyAI/banking77':
100 | ds = load_dataset(dataset, trust_remote_code=True)
101 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')
102 | elif dataset == 'takala/financial_phrasebank':
103 | ds = load_dataset('takala/financial_phrasebank', 'sentences_allagree', trust_remote_code=True)
104 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='sentence',
105 | label_column="label")
106 | elif dataset == "AmazonScience/massive":
107 | ds = load_dataset(dataset,"en-US")
108 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='utt',
109 | label_column="intent")
110 | elif dataset == 'stanfordnlp/imdb':
111 | ds = load_dataset(dataset, trust_remote_code=True)
112 | classes = ['negative', 'positive']
113 | test_texts, classes, true_labels = self.prepare_nomapping(ds, classes=classes, text_column='text', label_column='label')
114 | print(true_labels[0], classes)
115 | elif dataset == 'ag_news':
116 | ds = load_dataset(dataset, trust_remote_code=True)
117 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')
118 | elif dataset == 'dair-ai/emotion':
119 | ds = load_dataset(dataset, trust_remote_code=True)
120 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')
121 | elif dataset == 'MoritzLaurer/cap_sotu':
122 | ds = load_dataset(dataset, trust_remote_code=True)
123 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='labels')
124 | elif dataset == 'cornell-movie-review-data/rotten_tomatoes':
125 | ds = load_dataset(dataset, trust_remote_code=True)
126 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='text', label_column='label')
127 | elif dataset == 'massive':
128 | ds = load_dataset("AmazonScience/massive", "en-US", trust_remote_code=True)
129 | test_texts, classes, true_labels = self.prepare_nomapping(ds, text_column='utt', label_column='intent')
130 | else:
131 | ds = load_dataset(dataset, trust_remote_code=True)
132 | test_texts, classes, true_labels = self.prepare_nomapping(ds)
133 | predicts = self.get_gliclass_predictions(test_texts, classes, batch_size=8)
134 | results = self.evaluate(predicts, true_labels)
135 | self.macro_scores.append(results['macro'])
136 | print(results)
137 | print('Average Score:', np.mean(self.macro_scores))
138 |
139 | if __name__ == "__main__":
140 | parser = argparse.ArgumentParser(description="Run TestModel with arguments")
141 | parser.add_argument("--model", type=str, required=True, help="Model name to use")
142 | parser.add_argument("--api_key", type=str, required=False, default = None, help="API key for authentication")
143 |
144 | args = parser.parse_args()
145 |
146 | gliclasstest = TestModel(args.model, args.api_key)
147 | gliclasstest.process()
--------------------------------------------------------------------------------
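The benchmark runs from the command line; `--api_key` is only needed for private or gated checkpoints. For example, with the checkpoint used in notebooks/finetuning.ipynb:

    python test_gliclass.py --model knowledgator/gliclass-base-v1.0
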
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
3 | import numpy as np
4 | import argparse
5 | import json
6 |
7 | from sklearn.metrics import precision_recall_fscore_support, accuracy_score
8 | from transformers import AutoTokenizer, AutoConfig
9 |
10 | import random
11 | import torch
12 |
13 | from gliclass import GLiClassModelConfig, GLiClassModel
14 | from gliclass.training import TrainingArguments, Trainer
15 | from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset
16 |
17 | def compute_metrics(p):
18 | predictions, labels = p
19 | labels = labels.reshape(-1)
20 | if args.problem_type == 'single_label_classification':
21 | preds = np.argmax(predictions, axis=1)
22 | precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
23 | accuracy = accuracy_score(labels, preds)
24 | return {
25 | 'accuracy': accuracy,
26 | 'precision': precision,
27 | 'recall': recall,
28 | 'f1': f1,
29 | }
30 |
31 | elif args.problem_type == 'multi_label_classification':
32 | predictions = predictions.reshape(-1)
33 | preds = (predictions > 0.5).astype(int)
34 | labels = np.where(labels>0.5, 1, 0)
35 | precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
36 | accuracy = accuracy_score(labels, preds)
37 | return {
38 | 'accuracy': accuracy,
39 | 'precision': precision,
40 | 'recall': recall,
41 | 'f1': f1,
42 | }
43 | else:
44 | raise NotImplementedError(f"{args.problem_type} is not implemented.")
45 |
46 | def main(args):
47 |     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
48 |
49 | if args.model_name is not None:
50 | model = GLiClassModel.from_pretrained(args.model_name, focal_loss_alpha=args.focal_loss_alpha,
51 | focal_loss_gamma=args.focal_loss_gamma)
52 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
53 | else:
54 | tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
55 | encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)
56 |
57 | if args.label_model_name is not None:
58 | label_model_config = AutoConfig.from_pretrained(args.label_model_name)
59 |
60 |         gliclass_config = GLiClassModelConfig(
61 | encoder_config=encoder_config,
62 | encoder_model=args.encoder_model_name,
63 | label_model_name=args.label_model_name,
64 | label_model_config=label_model_config,
65 | class_token_index=len(tokenizer),
66 | text_token_index=len(tokenizer)+1,
67 | pooling_strategy=args.pooler_type,
68 | scorer_type=args.scorer_type,
69 | use_lstm=args.use_lstm,
70 | focal_loss_alpha=args.focal_loss_alpha,
71 | focal_loss_gamma=args.focal_loss_gamma,
72 | contrastive_loss_coef=args.contrastive_loss_coef,
73 | normalize_features=args.normalize_features,
74 | extract_text_features=args.extract_text_features,
75 | architecture_type=args.architecture_type,
76 | prompt_first=args.prompt_first,
77 | squeeze_layers=args.squeeze_layers,
78 | shuffle_labels=args.shuffle_labels
79 | )
80 |
81 |         model = GLiClassModel(gliclass_config, from_pretrained=True)
82 |
83 | if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:
84 |         new_words = ["<<LABEL>>", "<<SEP>>"]
85 | tokenizer.add_tokens(new_words, special_tokens=True)
86 | model.resize_token_embeddings(len(tokenizer))
87 |
88 | model.to(device)
89 |
90 | if model.config.label_model_name is not None:
91 | labels_tokenizer = AutoTokenizer.from_pretrained(model.config.label_model_name)
92 | else:
93 | labels_tokenizer = None
94 |
95 | model.config.problem_type = args.problem_type
96 |
97 | with open(args.data_path, 'r') as f:
98 | data = json.load(f)
99 |
100 | print('Dataset size:', len(data))
101 | random.shuffle(data)
102 | print('Dataset is shuffled...')
103 |
104 | train_data = data[:int(len(data)*0.9)]
105 | test_data = data[int(len(data)*0.9):]
106 |
107 |     print('Dataset is split...')
108 |
109 | train_dataset = GLiClassDataset(train_data, tokenizer, args.max_length,
110 | args.problem_type, args.architecture_type,
111 | args.prompt_first, labels_tokenizer=labels_tokenizer)
112 | test_dataset = GLiClassDataset(test_data, tokenizer, args.max_length, args.problem_type,
113 | args.architecture_type, args.prompt_first,
114 | labels_tokenizer = labels_tokenizer)
115 |
116 | data_collator = DataCollatorWithPadding(device=device)
117 |
118 | training_args = TrainingArguments(
119 | output_dir=args.save_path,
120 | learning_rate=args.encoder_lr,
121 | weight_decay=args.encoder_weight_decay,
122 | others_lr=args.others_lr,
123 | others_weight_decay=args.others_weight_decay,
124 | lr_scheduler_type=args.lr_scheduler_type,
125 | warmup_ratio=args.warmup_ratio,
126 | per_device_train_batch_size=args.batch_size,
127 | per_device_eval_batch_size=args.batch_size,
128 | num_train_epochs=args.num_epochs,
129 | evaluation_strategy="epoch",
130 | save_steps = args.save_steps,
131 | save_total_limit=args.save_total_limit,
132 | dataloader_num_workers = args.num_workers,
133 | logging_steps=100,
134 | use_cpu = False,
135 | report_to="none",
136 | fp16=args.fp16,
137 | )
138 |
139 | trainer = Trainer(
140 | model=model,
141 | args=training_args,
142 | train_dataset=train_dataset,
143 | eval_dataset=test_dataset,
144 | tokenizer=tokenizer,
145 | data_collator=data_collator,
146 | compute_metrics=compute_metrics,
147 | )
148 | trainer.train()
149 |
150 | if __name__ == '__main__':
151 | parser = argparse.ArgumentParser()
152 | parser.add_argument('--model_name', type=str, default= None)
153 | parser.add_argument('--encoder_model_name', type=str, default = 'microsoft/deberta-v3-small')
154 | parser.add_argument('--label_model_name', type=str, default = "BAAI/bge-small-en-v1.5")
155 | parser.add_argument('--save_path', type=str, default = 'models/')
156 | parser.add_argument('--data_path', type=str, default = 'data/zero-cats.json')
157 | parser.add_argument('--problem_type', type=str, default='multi_label_classification')
158 | parser.add_argument('--pooler_type', type=str, default='avg')
159 | parser.add_argument('--scorer_type', type=str, default='simple')
160 | parser.add_argument('--architecture_type', type=str, default='uni-encoder')
161 |     parser.add_argument('--normalize_features', action=argparse.BooleanOptionalAction, default=False)
162 |     parser.add_argument('--extract_text_features', action=argparse.BooleanOptionalAction, default=False)
163 |     parser.add_argument('--prompt_first', action=argparse.BooleanOptionalAction, default=True)
164 |     parser.add_argument('--use_lstm', action=argparse.BooleanOptionalAction, default=False)
165 |     parser.add_argument('--squeeze_layers', action=argparse.BooleanOptionalAction, default=False)
166 |     parser.add_argument('--shuffle_labels', action=argparse.BooleanOptionalAction, default=True)
167 | parser.add_argument('--num_epochs', type=int, default=3)
168 | parser.add_argument('--batch_size', type=int, default=8)
169 | parser.add_argument('--encoder_lr', type=float, default=1e-5)
170 | parser.add_argument('--others_lr', type=float, default=3e-5)
171 | parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
172 | parser.add_argument('--others_weight_decay', type=float, default=0.01)
173 | parser.add_argument('--warmup_ratio', type=float, default=0.05)
174 | parser.add_argument('--lr_scheduler_type', type=str, default='linear')
175 | parser.add_argument('--focal_loss_alpha', type=float, default=-1)
176 | parser.add_argument('--focal_loss_gamma', type=float, default=-1)
177 | parser.add_argument('--contrastive_loss_coef', type=float, default=0.)
178 | parser.add_argument('--max_length', type=int, default=1024)
179 | parser.add_argument('--save_steps', type=int, default=1000)
180 | parser.add_argument('--save_total_limit', type=int, default=3)
181 | parser.add_argument('--num_workers', type=int, default=12)
182 |     parser.add_argument('--fp16', action=argparse.BooleanOptionalAction, default=False)
183 | args = parser.parse_args()
184 |
185 | main(args)
186 |
--------------------------------------------------------------------------------
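GLiClassDataset expects `--data_path` to point at a JSON list of records with `text`, `all_labels` and `true_labels` keys, the same format `prepare_dataset_for_training` produces in the finetuning notebook. A minimal sketch of creating such a file (the two records are illustrative only):

import json

# Each record pairs a text with its candidate label set and the subset of labels
# that actually apply (multi-label is allowed).
data = [
    {"text": "The team won the championship last night.",
     "all_labels": ["sports", "politics", "science"],
     "true_labels": ["sports"]},
    {"text": "A new particle was observed at the collider.",
     "all_labels": ["sports", "politics", "science"],
     "true_labels": ["science"]},
]

with open("data/zero-cats.json", "w") as f:  # the script's default --data_path
    json.dump(data, f)

Then: python train.py --data_path data/zero-cats.json --num_epochs 3
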
/train_rl.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
3 | import numpy as np
4 | import argparse
5 | import json
6 |
7 | from sklearn.metrics import precision_recall_fscore_support, accuracy_score
8 | from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
9 |
10 | import random
11 | import torch
12 |
13 | from gliclass import GLiClassModelConfig, GLiClassModel, ZeroShotClassificationPipeline
14 | from gliclass.training import TrainingArguments, Trainer, RLTrainerConfig, RLTrainer
15 | from gliclass.data_processing import DataCollatorWithPadding, GLiClassDataset
16 | from gliclass.utils import default_f1_reward
17 |
18 | def accuracy_reward(probs, actions, targets, valid_mask):
19 | probs = probs * valid_mask
20 | predicts = torch.argmax(probs, dim=-1)
21 | true_labels = torch.argmax(targets, dim=-1)
22 | correct = (predicts == true_labels).float().unsqueeze(1)
23 | return correct
24 |
25 | def recall_reward(
26 | probs: torch.Tensor,
27 | actions: torch.Tensor,
28 | original_targets: torch.Tensor,
29 | valid_mask: torch.Tensor
30 | ) -> torch.Tensor:
31 | valid_preds = actions * valid_mask
32 | valid_targets = original_targets * valid_mask
33 |
34 | TP = torch.sum((valid_preds * valid_targets), dim=-1)
35 | FN = torch.sum(((1 - valid_preds) * valid_targets), dim=-1)
36 |
37 | eps = 1e-8
38 | recall = TP / (TP + FN + eps)
39 | return recall.detach().unsqueeze(1)
40 |
41 | def compute_metrics(p):
42 | predictions, labels = p
43 | labels = labels.reshape(-1)
44 | if args.problem_type == 'single_label_classification':
45 | preds = np.argmax(predictions, axis=1)
46 | precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
47 | accuracy = accuracy_score(labels, preds)
48 | return {
49 | 'accuracy': accuracy,
50 | 'precision': precision,
51 | 'recall': recall,
52 | 'f1': f1,
53 | }
54 |
55 | elif args.problem_type == 'multi_label_classification':
56 | predictions = predictions.reshape(-1)
57 | preds = (predictions > 0.5).astype(int)
58 | precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
59 | accuracy = accuracy_score(labels, preds)
60 | return {
61 | 'accuracy': accuracy,
62 | 'precision': precision,
63 | 'recall': recall,
64 | 'f1': f1,
65 | }
66 | else:
67 | raise NotImplementedError(f"{args.problem_type} is not implemented.")
68 |
69 | def main(args):
70 |     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
71 |
72 | if args.model_name is not None:
73 | model = GLiClassModel.from_pretrained(args.model_name, focal_loss_alpha=args.focal_loss_alpha,
74 | focal_loss_gamma=args.focal_loss_gamma)
75 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
76 | else:
77 | tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_name)
78 | encoder_config = AutoConfig.from_pretrained(args.encoder_model_name)
79 |
80 | if args.label_model_name is not None:
81 | label_model_config = AutoConfig.from_pretrained(args.label_model_name)
82 |
83 |         gliclass_config = GLiClassModelConfig(
84 | encoder_config=encoder_config,
85 | encoder_model=args.encoder_model_name,
86 | label_model_name=args.label_model_name,
87 | label_model_config=label_model_config,
88 | class_token_index=len(tokenizer),
89 | text_token_index=len(tokenizer)+1,
90 | pooling_strategy=args.pooler_type,
91 | scorer_type=args.scorer_type,
92 | use_lstm=args.use_lstm,
93 | focal_loss_alpha=args.focal_loss_alpha,
94 | focal_loss_gamma=args.focal_loss_gamma,
95 | labels_smoothing=args.labels_smoothing,
96 | entropy_beta=args.entropy_beta,
97 | kl_beta=args.kl_beta,
98 | contrastive_loss_coef=args.contrastive_loss_coef,
99 | normalize_features=args.normalize_features,
100 | extract_text_features=args.extract_text_features,
101 | architecture_type=args.architecture_type,
102 | prompt_first=args.prompt_first,
103 | squeeze_layers=args.squeeze_layers
104 | )
105 |
106 |         gliclass_config.problem_type = args.problem_type
107 | 
108 |         model = GLiClassModel(gliclass_config, from_pretrained=True)
109 |
110 | if args.architecture_type in {'uni-encoder', 'bi-encoder-fused', 'encoder-decoder'}:
111 |         new_words = ["<<LABEL>>", "<<SEP>>"]
112 | tokenizer.add_tokens(new_words, special_tokens=True)
113 | model.resize_token_embeddings(len(tokenizer))
114 |
115 | if args.set_value_model:
116 | value_model = AutoModelForSequenceClassification.from_pretrained(model.config.encoder_model_name, num_labels=1)
117 | value_model.resize_token_embeddings(len(tokenizer))
118 | else:
119 | value_model = None
120 |
121 | if args.reference_model is not None:
122 |         reference_model = GLiClassModel.from_pretrained(args.reference_model)
123 |         reference_tokenizer = AutoTokenizer.from_pretrained(args.reference_model)
124 |         reference_pipe = ZeroShotClassificationPipeline(reference_model, reference_tokenizer,
125 | classification_type='multi-label',
126 | progress_bar=False, device=device)
127 | else:
128 | reference_pipe = None
129 |
130 | if args.label_model_name is not None:
131 | labels_tokenizer = AutoTokenizer.from_pretrained(args.label_model_name)
132 | else:
133 | labels_tokenizer = None
134 |
135 | model.to(device)
136 |
137 | with open(args.data_path, 'r') as f:
138 |         data = json.load(f)
140 |
141 | print('Dataset size:', len(data))
142 | random.shuffle(data)
143 | print('Dataset is shuffled...')
144 |
145 | train_data = data[:int(len(data)*0.9)]
146 | test_data = data[int(len(data)*0.9):]
147 |
148 |     print('Dataset is split...')
149 |
150 | train_dataset = GLiClassDataset(train_data, tokenizer, args.max_length,
151 | args.problem_type, args.architecture_type,
152 | args.prompt_first, labels_tokenizer=labels_tokenizer)
153 | test_dataset = GLiClassDataset(test_data, tokenizer, args.max_length, args.problem_type,
154 | args.architecture_type, args.prompt_first,
155 | labels_tokenizer = labels_tokenizer)
156 |
157 | data_collator = DataCollatorWithPadding(device=device)
158 |
159 | training_args = RLTrainerConfig(
160 | output_dir=args.save_path,
161 | learning_rate=args.encoder_lr,
162 | weight_decay=args.encoder_weight_decay,
163 | others_lr=args.others_lr,
164 | others_weight_decay=args.others_weight_decay,
165 | lr_scheduler_type=args.lr_scheduler_type,
166 | warmup_ratio=args.warmup_ratio,
167 | per_device_train_batch_size=args.batch_size,
168 | per_device_eval_batch_size=args.batch_size,
169 | num_train_epochs=args.num_epochs,
170 | evaluation_strategy="epoch",
171 | save_steps = args.save_steps,
172 | save_total_limit=args.save_total_limit,
173 | dataloader_num_workers = args.num_workers,
174 | logging_steps=100,
175 | use_cpu = False,
176 | report_to="none",
177 | fp16=args.fp16,
178 | cliprange=args.clip_range,
179 |         num_rl_iters=args.num_rl_iters,
180 |         kl_beta=args.kl_beta,
181 |         entropy_beta=args.entropy_beta,
182 |     )
181 |
182 | trainer = RLTrainer(
183 | model=model,
184 | value_model=value_model,
185 | reference_model=reference_pipe,
186 | args=training_args,
187 | train_dataset=train_dataset,
188 | eval_dataset=test_dataset,
189 | tokenizer=tokenizer,
190 | data_collator=data_collator,
191 | reward_components={
192 | 'micro_f1': default_f1_reward,
193 | },
194 | )
195 | trainer.train()
196 |
197 | if __name__ == '__main__':
198 | parser = argparse.ArgumentParser()
199 | parser.add_argument('--model_name', type=str, default= "knowledgator/gliclass-modern-base-v2.0-init")
200 | parser.add_argument('--encoder_model_name', type=str, default = 'microsoft/deberta-v3-small')
201 | parser.add_argument('--label_model_name', type=str, default = "BAAI/bge-small-en-v1.5")
202 | parser.add_argument('--reference_model', type=str, default = None)
203 |     parser.add_argument('--set_value_model', action=argparse.BooleanOptionalAction, default=True)
204 | parser.add_argument('--save_path', type=str, default = 'models/')
205 | parser.add_argument('--data_path', type=str, default = 'data/zero-cats.json')
206 | parser.add_argument('--problem_type', type=str, default='multi_label_classification')
207 | parser.add_argument('--pooler_type', type=str, default='avg')
208 | parser.add_argument('--scorer_type', type=str, default='simple')
209 | parser.add_argument('--architecture_type', type=str, default='uni-encoder')
210 |     parser.add_argument('--normalize_features', action=argparse.BooleanOptionalAction, default=False)
211 |     parser.add_argument('--extract_text_features', action=argparse.BooleanOptionalAction, default=False)
212 |     parser.add_argument('--prompt_first', action=argparse.BooleanOptionalAction, default=True)
213 |     parser.add_argument('--use_lstm', action=argparse.BooleanOptionalAction, default=False)
214 |     parser.add_argument('--squeeze_layers', action=argparse.BooleanOptionalAction, default=False)
215 | parser.add_argument('--num_epochs', type=int, default=1)
216 | parser.add_argument('--batch_size', type=int, default=32)
217 | parser.add_argument('--encoder_lr', type=float, default=2e-6)
218 | parser.add_argument('--others_lr', type=float, default=3e-6)
219 | parser.add_argument('--encoder_weight_decay', type=float, default=0.01)
220 | parser.add_argument('--others_weight_decay', type=float, default=0.01)
221 | parser.add_argument('--warmup_ratio', type=float, default=0.05)
222 | parser.add_argument('--lr_scheduler_type', type=str, default='linear')
223 | parser.add_argument('--focal_loss_alpha', type=float, default=-1)
224 | parser.add_argument('--focal_loss_gamma', type=float, default=-1)
225 | parser.add_argument('--labels_smoothing', type=float, default=-1)
226 | parser.add_argument('--entropy_beta', type=float, default=-1)
227 | parser.add_argument('--kl_beta', type=float, default=0.1)
228 | parser.add_argument('--clip_range', type=float, default=0.2)
229 | parser.add_argument('--num_rl_iters', type=int, default=2)
230 | parser.add_argument('--contrastive_loss_coef', type=float, default=0.)
231 | parser.add_argument('--max_length', type=int, default=2048)
232 | parser.add_argument('--save_steps', type=int, default=300)
233 | parser.add_argument('--save_total_limit', type=int, default=3)
234 | parser.add_argument('--num_workers', type=int, default=12)
235 |     parser.add_argument('--fp16', action=argparse.BooleanOptionalAction, default=False)
236 | args = parser.parse_args()
237 |
238 | main(args)
239 |
--------------------------------------------------------------------------------
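Example invocation: RL-tuning the default checkpoint against a frozen copy of itself as the KL reference (all flags are defined above; when `--reference_model` is omitted the KL term is skipped):

    python train_rl.py \
        --model_name knowledgator/gliclass-modern-base-v2.0-init \
        --reference_model knowledgator/gliclass-modern-base-v2.0-init \
        --kl_beta 0.1 --num_rl_iters 2
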