├── remove_head.py
├── setup.py
├── t5_encoder
│   ├── __init__.py
│   └── modeling_t5.py
├── README.md
├── .gitignore
├── LICENSE
├── run_ner.py
└── run_glue.py
/remove_head.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from transformers import T5EncoderModel, AutoTokenizer
3 | 
4 | parser = ArgumentParser('Head remover')
5 | 
6 | parser.add_argument("-i", '--input_model', dest='input_model', type=str, help="Input model.")
7 | parser.add_argument("-o", '--output_model', dest='output_model', type=str, help="Output model.")
8 | 
9 | if __name__ == "__main__":
10 |     args = parser.parse_args()
11 |     # Save model
12 |     model = T5EncoderModel.from_pretrained(args.input_model)
13 |     model.save_pretrained(args.output_model)
14 |     # Save tokenizer
15 |     tokenizer = AutoTokenizer.from_pretrained(args.input_model)
16 |     tokenizer.save_pretrained(args.output_model)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | import t5_encoder
3 | 
4 | with open("README.md", "r") as fh:
5 |     long_description = fh.read()
6 | 
7 | setuptools.setup(
8 |     name="t5_encoder",
9 |     author="Oscar Sainz",
10 |     version=t5_encoder.__version__,
11 |     author_email="oscar.sainz@ehu.eus",
12 |     description="An extension of the Transformers library to include a T5ForSequenceClassification class.",
13 |     long_description=long_description,
14 |     long_description_content_type="text/markdown",
15 |     url="https://github.com/osainz59/t5-encoder",
16 |     packages=setuptools.find_packages(),
17 |     classifiers=[
18 |         "Programming Language :: Python :: 3",
19 |         "License :: OSI Approved :: Apache Software License",
20 |         "Operating System :: OS Independent",
21 |     ],
22 |     python_requires=">=3.6",
23 |     install_requires=["transformers", "torch"],
24 | )
--------------------------------------------------------------------------------
/t5_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers.models.auto.modeling_auto import (
2 |     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
3 |     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
4 |     _LazyAutoMapping
5 | )
6 | from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
7 | 
8 | from .modeling_t5 import T5ForTokenClassification, T5ForSequenceClassification
9 | 
10 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES["t5"] = "T5ForTokenClassification"
11 | MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES["t5"] = "T5ForSequenceClassification"
12 | 
13 | MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
14 |     CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
15 | )
16 | 
17 | MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
18 |     CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
19 | )
20 | 
21 | from transformers.models import t5
22 | 
23 | setattr(t5, "T5ForTokenClassification", T5ForTokenClassification)
24 | setattr(t5, "T5ForSequenceClassification", T5ForSequenceClassification)
25 | 
26 | __version__ = "0.1"
27 | 
28 | __all__ = [
29 |     "T5ForTokenClassification",
30 |     "T5ForSequenceClassification"
31 | ]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # T5 Encoder-only extension for Transformers
2 | This repository contains the implementation of
`T5ForSequenceClassification` and `T5ForTokenClassification`, fully compatible with the [Transformers](https://github.com/huggingface/transformers) library. While this could be a feature of the library itself, it is not implemented there yet, so this repository contains the code for preliminary experiments before it is actually included in the library.
3 | 
4 | This implementation is inspired by [EncT5: A Framework for Fine-tuning T5 as Non-autoregressive Models](https://arxiv.org/abs/2110.08426) and [A Universal Discriminator for Zero-Shot Generalization](https://arxiv.org/pdf/2211.08099.pdf), both of which make use of the T5 encoder only.
5 | 
6 | ## Installation and use
7 | You can install this library by running the following command:
8 | ```bash
9 | python -m pip install git+https://github.com/osainz59/t5-encoder
10 | ```
11 | To use the implemented classes, simply import `t5_encoder` along with `transformers`. Example:
12 | ```python
13 | import t5_encoder
14 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
15 | 
16 | tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-base")
17 | model = AutoModelForSequenceClassification.from_pretrained("google/t5-v1_1-base")
18 | 
19 | outputs = model(**tokenizer("This is a sentence to classify.", return_tensors="pt"))
20 | print(outputs.logits)
21 | >>> tensor([[ 0.0512, -0.0594]], grad_fn=<AddmmBackward0>)
22 | ```
23 | 
24 | ## GLUE results
25 | 
26 | | Model | CoLA | SST2 | MRPC | STSb | QQP | MNLI | QNLI | RTE | WNLI |
27 | |:------|:------|:-----|:-----|:-----|:----|:-----|:-----|:----|:-----|
28 | | RoBERTa<sub>large</sub> | **68.0** | **96.4** | 90.9 | 92.4 | **92.2** | 90.2/90.2 | 94.7 | 86.6 | **91.3** |
29 | | T5<sub>large</sub> | 61.2 | 96.3 | 92.4 | 89.9 | 89.9 | 89.9/89.6 | **94.8** | 87.2 | 85.6 |
30 | | T5-Enc<sub>large</sub> | 55.0 | 96.1 | **93.3** | **92.7** | 91.4 | **90.5/90.4** | 94.7 | **88.8** | 47.9 |
31 | 
32 | ## NER results
33 | | Model | CoNLL-2003 (F1) |
34 | |:------|:------|
35 | | [RoBERTa](https://huggingface.co/Gladiator/roberta-large_ner_conll2003)<sub>large</sub> | 96.57 |
36 | | T5 | - |
37 | | T5-Enc<sub>large</sub> | 95.49 |
38 | 
39 | **Important:** These results were obtained from a single run; for datasets with very few examples, performance might change drastically.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore bash scripts
2 | *.sh
3 | .slurm/*
4 | 
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 | 
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /t5_encoder/modeling_t5.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, Tuple, Union 3 | import torch 4 | from torch import nn 5 | from torch.utils import checkpoint 6 | 7 | from transformers.models.t5.modeling_t5 import ( 8 | T5PreTrainedModel, 9 | T5Config, 10 | T5Stack, 11 | T5_START_DOCSTRING, 12 | T5_ENCODER_INPUTS_DOCSTRING, 13 | PARALLELIZE_DOCSTRING, 14 | DEPARALLELIZE_DOCSTRING, 15 | _CONFIG_FOR_DOC 16 | ) 17 | from transformers.modeling_outputs import ( 18 | SequenceClassifierOutput, 19 | TokenClassifierOutput 20 | ) 21 | from transformers.utils import ( 22 | add_start_docstrings, 23 | add_start_docstrings_to_model_forward, 24 | replace_return_docstrings 25 | ) 26 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 27 | from transformers.utils.logging import get_logger 28 | logger = get_logger("transformers") 29 | 30 | @add_start_docstrings( 31 | """T5 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for 32 | Named-Entity-Recognition (NER) tasks. 33 | """, 34 | T5_START_DOCSTRING 35 | ) 36 | class T5ForTokenClassification(T5PreTrainedModel): 37 | _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"] 38 | 39 | def __init__(self, config: T5Config): 40 | super().__init__(config) 41 | self.model_dim = config.d_model 42 | 43 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 44 | 45 | encoder_config = copy.deepcopy(config) 46 | encoder_config.is_decoder = False 47 | encoder_config.is_encoder_decoder = False 48 | encoder_config.use_cache = False 49 | self.encoder = T5Stack(encoder_config, self.shared) 50 | 51 | classifier_dropout = ( 52 | config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate 53 | ) 54 | self.dropout = nn.Dropout(classifier_dropout) 55 | self.classifier = nn.Linear(config.d_model, config.num_labels) 56 | 57 | # Initialize weights and apply final processing 58 | self.post_init() 59 | 60 | # Model parallel 61 | self.model_parallel = False 62 | self.device_map = None 63 | 64 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 65 | def parallelize(self, device_map=None): 66 | self.device_map = ( 67 | get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) 68 | if device_map is None 69 | else device_map 70 | ) 71 | assert_device_map(self.device_map, len(self.encoder.block)) 72 | self.encoder.parallelize(self.device_map) 73 | self.classifier.to(self.encoder.first_device) 74 | self.model_parallel = True 75 | 76 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 77 | def deparallelize(self): 78 | self.encoder.deparallelize() 79 | self.encoder = self.encoder.to("cpu") 80 | self.classifier = self.classifier.to("cpu") 81 | self.model_parallel = False 82 | self.device_map = None 83 | torch.cuda.empty_cache() 84 | 85 | def get_input_embeddings(self): 86 | return self.shared 87 | 88 | def set_input_embeddings(self, new_embeddings): 89 | self.shared = new_embeddings 90 | self.encoder.set_input_embeddings(new_embeddings) 91 | 92 | def get_encoder(self): 93 | return self.encoder 94 | 95 | def _prune_heads(self, heads_to_prune): 96 | """ 97 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 98 | class PreTrainedModel 99 | """ 100 | for layer, heads in heads_to_prune.items(): 101 | self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) 102 | 103 | @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) 104 | @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC) 105 | def forward( 106 | self, 107 | input_ids: Optional[torch.LongTensor] = None, 108 | attention_mask: Optional[torch.FloatTensor] = None, 109 | head_mask: Optional[torch.FloatTensor] = None, 110 | inputs_embeds: Optional[torch.FloatTensor] = None, 111 | labels: Optional[torch.LongTensor] = None, 112 | output_attentions: Optional[bool] = None, 113 | output_hidden_states: Optional[bool] = None, 114 | return_dict: Optional[bool] = None, 115 | ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]: 116 | r""" 117 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 118 | Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 119 | 120 | Returns: 121 | 122 | """ 123 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 124 | 125 | outputs = self.encoder( 126 | input_ids=input_ids, 127 | attention_mask=attention_mask, 128 | inputs_embeds=inputs_embeds, 129 | head_mask=head_mask, 130 | output_attentions=output_attentions, 131 | output_hidden_states=output_hidden_states, 132 | return_dict=return_dict, 133 | ) 134 | 135 | sequence_output = outputs[0] 136 | 137 | sequence_output = self.dropout(sequence_output) 138 | logits = self.classifier(sequence_output) 139 | 140 | loss = None 141 | if labels is not None: 142 | loss_fct = nn.CrossEntropyLoss() 143 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) 144 | 145 | if not return_dict: 146 | output = (logits,) + outputs[2:] 147 | return ((loss,) + output) if loss is not None else output 148 | 149 | return TokenClassifierOutput( 150 | loss=loss, 151 | logits=logits, 152 | hidden_states=outputs.hidden_states, 153 | attentions=outputs.attentions 154 | ) 155 | 156 | @add_start_docstrings( 157 | """T5 Model with a sequence classification head on top (a linear layer on top of the token). 
158 |     """,
159 |     T5_START_DOCSTRING
160 | )
161 | class T5ForSequenceClassification(T5PreTrainedModel):
162 |     _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
163 | 
164 |     def __init__(self, config: T5Config):
165 |         super().__init__(config)
166 |         self.model_dim = config.d_model
167 |         self.config.problem_type = None
168 |         self.config.is_encoder_decoder = False
169 | 
170 |         self.shared = nn.Embedding(config.vocab_size, config.d_model)
171 | 
172 |         encoder_config = copy.deepcopy(config)
173 |         encoder_config.is_decoder = False
174 |         encoder_config.is_encoder_decoder = False
175 |         encoder_config.use_cache = False
176 |         self.encoder = T5Stack(encoder_config, self.shared)
177 | 
178 |         classifier_dropout = (
179 |             config.classifier_dropout if hasattr(config, 'classifier_dropout') else config.dropout_rate
180 |         )
181 |         self.dropout = nn.Dropout(classifier_dropout)
182 |         self.classifier = nn.Linear(config.d_model, config.num_labels)
183 | 
184 |         # Initialize weights and apply final processing
185 |         self.post_init()
186 | 
187 |         # Model parallel
188 |         self.model_parallel = False
189 |         self.device_map = None
190 | 
191 |     @add_start_docstrings(PARALLELIZE_DOCSTRING)
192 |     def parallelize(self, device_map=None):
193 |         self.device_map = (
194 |             get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
195 |             if device_map is None
196 |             else device_map
197 |         )
198 |         assert_device_map(self.device_map, len(self.encoder.block))
199 |         self.encoder.parallelize(self.device_map)
200 |         self.classifier.to(self.encoder.first_device)
201 |         self.model_parallel = True
202 | 
203 |     @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
204 |     def deparallelize(self):
205 |         self.encoder.deparallelize()
206 |         self.encoder = self.encoder.to("cpu")
207 |         self.classifier = self.classifier.to("cpu")
208 |         self.model_parallel = False
209 |         self.device_map = None
210 |         torch.cuda.empty_cache()
211 | 
212 |     def get_input_embeddings(self):
213 |         return self.shared
214 | 
215 |     def set_input_embeddings(self, new_embeddings):
216 |         self.shared = new_embeddings
217 |         self.encoder.set_input_embeddings(new_embeddings)
218 | 
219 |     def get_encoder(self):
220 |         return self.encoder
221 | 
222 |     def _prune_heads(self, heads_to_prune):
223 |         """
224 |         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
225 |         class PreTrainedModel
226 |         """
227 |         for layer, heads in heads_to_prune.items():
228 |             self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
229 | 
230 |     @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
231 |     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
232 |     def forward(
233 |         self,
234 |         input_ids: Optional[torch.LongTensor] = None,
235 |         attention_mask: Optional[torch.FloatTensor] = None,
236 |         head_mask: Optional[torch.FloatTensor] = None,
237 |         inputs_embeds: Optional[torch.FloatTensor] = None,
238 |         labels: Optional[torch.LongTensor] = None,
239 |         output_attentions: Optional[bool] = None,
240 |         output_hidden_states: Optional[bool] = None,
241 |         return_dict: Optional[bool] = None,
242 |     ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
243 |         r"""
244 |         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
245 |             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`.
246 | 247 | Returns: 248 | 249 | """ 250 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 251 | 252 | outputs = self.encoder( 253 | input_ids=input_ids, 254 | attention_mask=attention_mask, 255 | inputs_embeds=inputs_embeds, 256 | head_mask=head_mask, 257 | output_attentions=output_attentions, 258 | output_hidden_states=output_hidden_states, 259 | return_dict=return_dict, 260 | ) 261 | 262 | # Get last hidden indices 263 | # (batch_size) -> (batch_size, 1) -> (batch_size, hidden_size) -> (batch_size, 1, hidden_size) 264 | # Calculate the sum of non-padding tokens and subtract 1 265 | sums = (input_ids != self.config.pad_token_id).sum(dim=-1) - 1 266 | # Replace negative indices with 0 267 | mask = sums < 0 268 | if mask.any(): 269 | logger.warning("Found a sequence of input_ids == all pad_token_ids. Make sure that you've correctly tokenized the input text.") 270 | sums = torch.where(sums < 0, torch.zeros_like(sums), sums) 271 | 272 | last_hidden_indices = sums.unsqueeze(dim=-1).repeat(1, outputs[0].size(-1)).unsqueeze(1) 273 | sequence_output = outputs[0].gather(dim=1, index=last_hidden_indices).squeeze(1) 274 | 275 | sequence_output = self.dropout(sequence_output) 276 | logits = self.classifier(sequence_output) 277 | 278 | loss = None 279 | if labels is not None: 280 | if self.config.problem_type is None: 281 | if self.config.num_labels == 1: 282 | self.config.problem_type = "regression" 283 | elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 284 | self.config.problem_type = "single_label_classification" 285 | else: 286 | self.config.problem_type = "multi_label_classification" 287 | 288 | if self.config.problem_type == "regression": 289 | loss_fct = nn.MSELoss() 290 | if self.config.num_labels == 1: 291 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 292 | else: 293 | loss = loss_fct(logits, labels) 294 | elif self.config.problem_type == "single_label_classification": 295 | loss_fct = nn.CrossEntropyLoss() 296 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) 297 | elif self.config.problem_type == "multi_label_classification": 298 | loss_fct = nn.BCEWithLogitsLoss() 299 | loss = loss_fct(logits, labels) 300 | 301 | if not return_dict: 302 | output = (logits,) + outputs[2:] 303 | return ((loss,) + output) if loss is not None else output 304 | 305 | return SequenceClassifierOutput( 306 | loss=loss, 307 | logits=logits, 308 | hidden_states=outputs.hidden_states, 309 | attentions=outputs.attentions 310 | ) 311 | -------------------------------------------------------------------------------- /run_ner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for token classification. 
18 | """ 19 | # You can also adapt this script on your own token classification task and datasets. Pointers for this are left as 20 | # comments. 21 | 22 | import logging 23 | import os 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import datasets 29 | import numpy as np 30 | from datasets import ClassLabel, load_dataset, load_metric 31 | 32 | import t5_encoder 33 | import transformers 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForTokenClassification, 37 | AutoTokenizer, 38 | DataCollatorForTokenClassification, 39 | HfArgumentParser, 40 | PretrainedConfig, 41 | PreTrainedTokenizerFast, 42 | Trainer, 43 | TrainingArguments, 44 | set_seed, 45 | ) 46 | from transformers.trainer_utils import get_last_checkpoint 47 | from transformers.utils import check_min_version 48 | from transformers.utils.versions import require_version 49 | 50 | 51 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 52 | check_min_version("4.18.0") 53 | 54 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") 55 | 56 | logger = logging.getLogger(__name__) 57 | 58 | 59 | @dataclass 60 | class ModelArguments: 61 | """ 62 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 63 | """ 64 | 65 | model_name_or_path: str = field( 66 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 67 | ) 68 | config_name: Optional[str] = field( 69 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 70 | ) 71 | tokenizer_name: Optional[str] = field( 72 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 73 | ) 74 | cache_dir: Optional[str] = field( 75 | default=None, 76 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 77 | ) 78 | model_revision: str = field( 79 | default="main", 80 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 81 | ) 82 | use_auth_token: bool = field( 83 | default=False, 84 | metadata={ 85 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 86 | "with private models)." 87 | }, 88 | ) 89 | 90 | 91 | @dataclass 92 | class DataTrainingArguments: 93 | """ 94 | Arguments pertaining to what data we are going to input our model for training and eval. 
95 | """ 96 | 97 | task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) 98 | dataset_name: Optional[str] = field( 99 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 100 | ) 101 | dataset_config_name: Optional[str] = field( 102 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 103 | ) 104 | train_file: Optional[str] = field( 105 | default=None, metadata={"help": "The input training data file (a csv or JSON file)."} 106 | ) 107 | validation_file: Optional[str] = field( 108 | default=None, 109 | metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."}, 110 | ) 111 | test_file: Optional[str] = field( 112 | default=None, 113 | metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."}, 114 | ) 115 | text_column_name: Optional[str] = field( 116 | default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} 117 | ) 118 | label_column_name: Optional[str] = field( 119 | default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} 120 | ) 121 | overwrite_cache: bool = field( 122 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 123 | ) 124 | preprocessing_num_workers: Optional[int] = field( 125 | default=None, 126 | metadata={"help": "The number of processes to use for the preprocessing."}, 127 | ) 128 | max_seq_length: int = field( 129 | default=None, 130 | metadata={ 131 | "help": "The maximum total input sequence length after tokenization. If set, sequences longer " 132 | "than this will be truncated, sequences shorter will be padded." 133 | }, 134 | ) 135 | pad_to_max_length: bool = field( 136 | default=False, 137 | metadata={ 138 | "help": "Whether to pad all samples to model maximum sentence length. " 139 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 140 | "efficient on GPU but very bad for TPU." 141 | }, 142 | ) 143 | max_train_samples: Optional[int] = field( 144 | default=None, 145 | metadata={ 146 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 147 | "value if set." 148 | }, 149 | ) 150 | max_eval_samples: Optional[int] = field( 151 | default=None, 152 | metadata={ 153 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 154 | "value if set." 155 | }, 156 | ) 157 | max_predict_samples: Optional[int] = field( 158 | default=None, 159 | metadata={ 160 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 161 | "value if set." 162 | }, 163 | ) 164 | label_all_tokens: bool = field( 165 | default=False, 166 | metadata={ 167 | "help": "Whether to put the label for one word on all tokens of generated by that word or just on the " 168 | "one (in which case the other tokens will have a padding index)." 
169 | }, 170 | ) 171 | return_entity_level_metrics: bool = field( 172 | default=False, 173 | metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, 174 | ) 175 | 176 | def __post_init__(self): 177 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 178 | raise ValueError("Need either a dataset name or a training/validation file.") 179 | else: 180 | if self.train_file is not None: 181 | extension = self.train_file.split(".")[-1] 182 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 183 | if self.validation_file is not None: 184 | extension = self.validation_file.split(".")[-1] 185 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 186 | self.task_name = self.task_name.lower() 187 | 188 | 189 | def main(): 190 | # See all possible arguments in src/transformers/training_args.py 191 | # or by passing the --help flag to this script. 192 | # We now keep distinct sets of args, for a cleaner separation of concerns. 193 | 194 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 195 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 196 | # If we pass only one argument to the script and it's the path to a json file, 197 | # let's parse it to get our arguments. 198 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 199 | else: 200 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 201 | 202 | # Setup logging 203 | logging.basicConfig( 204 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 205 | datefmt="%m/%d/%Y %H:%M:%S", 206 | handlers=[logging.StreamHandler(sys.stdout)], 207 | ) 208 | 209 | log_level = training_args.get_process_log_level() 210 | logger.setLevel(log_level) 211 | datasets.utils.logging.set_verbosity(log_level) 212 | transformers.utils.logging.set_verbosity(log_level) 213 | transformers.utils.logging.enable_default_handler() 214 | transformers.utils.logging.enable_explicit_format() 215 | 216 | # Log on each process the small summary: 217 | logger.warning( 218 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 219 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 220 | ) 221 | logger.info(f"Training/evaluation parameters {training_args}") 222 | 223 | # Detecting last checkpoint. 224 | last_checkpoint = None 225 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 226 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 227 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 228 | raise ValueError( 229 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 230 | "Use --overwrite_output_dir to overcome." 231 | ) 232 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 233 | logger.info( 234 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 235 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 236 | ) 237 | 238 | # Set seed before initializing model. 
239 | set_seed(training_args.seed) 240 | 241 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 242 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 243 | # (the dataset will be downloaded automatically from the datasets Hub). 244 | # 245 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 246 | # 'text' is found. You can easily tweak this behavior (see below). 247 | # 248 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 249 | # download the dataset. 250 | if data_args.dataset_name is not None: 251 | # Downloading and loading a dataset from the hub. 252 | raw_datasets = load_dataset( 253 | data_args.dataset_name, 254 | data_args.dataset_config_name, 255 | cache_dir=model_args.cache_dir, 256 | use_auth_token=True if model_args.use_auth_token else None, 257 | ) 258 | else: 259 | data_files = {} 260 | if data_args.train_file is not None: 261 | data_files["train"] = data_args.train_file 262 | if data_args.validation_file is not None: 263 | data_files["validation"] = data_args.validation_file 264 | if data_args.test_file is not None: 265 | data_files["test"] = data_args.test_file 266 | extension = data_args.train_file.split(".")[-1] 267 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 268 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 269 | # https://huggingface.co/docs/datasets/loading_datasets.html. 270 | 271 | if training_args.do_train: 272 | column_names = raw_datasets["train"].column_names 273 | features = raw_datasets["train"].features 274 | else: 275 | column_names = raw_datasets["validation"].column_names 276 | features = raw_datasets["validation"].features 277 | 278 | if data_args.text_column_name is not None: 279 | text_column_name = data_args.text_column_name 280 | elif "tokens" in column_names: 281 | text_column_name = "tokens" 282 | else: 283 | text_column_name = column_names[0] 284 | 285 | if data_args.label_column_name is not None: 286 | label_column_name = data_args.label_column_name 287 | elif f"{data_args.task_name}_tags" in column_names: 288 | label_column_name = f"{data_args.task_name}_tags" 289 | else: 290 | label_column_name = column_names[1] 291 | 292 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 293 | # unique labels. 294 | def get_label_list(labels): 295 | unique_labels = set() 296 | for label in labels: 297 | unique_labels = unique_labels | set(label) 298 | label_list = list(unique_labels) 299 | label_list.sort() 300 | return label_list 301 | 302 | # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere. 303 | # Otherwise, we have to get the list of labels manually. 
304 | labels_are_int = isinstance(features[label_column_name].feature, ClassLabel) 305 | if labels_are_int: 306 | label_list = features[label_column_name].feature.names 307 | label_to_id = {i: i for i in range(len(label_list))} 308 | else: 309 | label_list = get_label_list(raw_datasets["train"][label_column_name]) 310 | label_to_id = {l: i for i, l in enumerate(label_list)} 311 | 312 | num_labels = len(label_list) 313 | 314 | # Load pretrained model and tokenizer 315 | # 316 | # Distributed training: 317 | # The .from_pretrained methods guarantee that only one local process can concurrently 318 | # download model & vocab. 319 | config = AutoConfig.from_pretrained( 320 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 321 | num_labels=num_labels, 322 | finetuning_task=data_args.task_name, 323 | cache_dir=model_args.cache_dir, 324 | revision=model_args.model_revision, 325 | use_auth_token=True if model_args.use_auth_token else None, 326 | ) 327 | 328 | tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path 329 | if config.model_type in {"gpt2", "roberta"}: 330 | tokenizer = AutoTokenizer.from_pretrained( 331 | tokenizer_name_or_path, 332 | cache_dir=model_args.cache_dir, 333 | use_fast=True, 334 | revision=model_args.model_revision, 335 | use_auth_token=True if model_args.use_auth_token else None, 336 | add_prefix_space=True, 337 | ) 338 | else: 339 | tokenizer = AutoTokenizer.from_pretrained( 340 | tokenizer_name_or_path, 341 | cache_dir=model_args.cache_dir, 342 | use_fast=True, 343 | revision=model_args.model_revision, 344 | use_auth_token=True if model_args.use_auth_token else None, 345 | ) 346 | 347 | model = AutoModelForTokenClassification.from_pretrained( 348 | model_args.model_name_or_path, 349 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 350 | config=config, 351 | cache_dir=model_args.cache_dir, 352 | revision=model_args.model_revision, 353 | use_auth_token=True if model_args.use_auth_token else None, 354 | ) 355 | 356 | # Tokenizer check: this script requires a fast tokenizer. 357 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 358 | raise ValueError( 359 | "This example script only works for models that have a fast tokenizer. Checkout the big table of models " 360 | "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this " 361 | "requirement" 362 | ) 363 | 364 | # Model has labels -> use them. 365 | if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id: 366 | if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)): 367 | # Reorganize `label_list` to match the ordering of the model. 368 | if labels_are_int: 369 | label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)} 370 | label_list = [model.config.id2label[i] for i in range(num_labels)] 371 | else: 372 | label_list = [model.config.id2label[i] for i in range(num_labels)] 373 | label_to_id = {l: i for i, l in enumerate(label_list)} 374 | else: 375 | logger.warning( 376 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 377 | f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}." 
378 | "\nIgnoring the model labels as a result.", 379 | ) 380 | 381 | # Set the correspondences label/ID inside the model config 382 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 383 | model.config.id2label = {i: l for i, l in enumerate(label_list)} 384 | 385 | # Map that sends B-Xxx label to its I-Xxx counterpart 386 | b_to_i_label = [] 387 | for idx, label in enumerate(label_list): 388 | if label.startswith("B-") and label.replace("B-", "I-") in label_list: 389 | b_to_i_label.append(label_list.index(label.replace("B-", "I-"))) 390 | else: 391 | b_to_i_label.append(idx) 392 | 393 | # Preprocessing the dataset 394 | # Padding strategy 395 | padding = "max_length" if data_args.pad_to_max_length else False 396 | 397 | # Tokenize all texts and align the labels with them. 398 | def tokenize_and_align_labels(examples): 399 | tokenized_inputs = tokenizer( 400 | examples[text_column_name], 401 | padding=padding, 402 | truncation=True, 403 | max_length=data_args.max_seq_length, 404 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 405 | is_split_into_words=True, 406 | ) 407 | labels = [] 408 | for i, label in enumerate(examples[label_column_name]): 409 | word_ids = tokenized_inputs.word_ids(batch_index=i) 410 | previous_word_idx = None 411 | label_ids = [] 412 | for word_idx in word_ids: 413 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 414 | # ignored in the loss function. 415 | if word_idx is None: 416 | label_ids.append(-100) 417 | # We set the label for the first token of each word. 418 | elif word_idx != previous_word_idx: 419 | label_ids.append(label_to_id[label[word_idx]]) 420 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 421 | # the label_all_tokens flag. 
422 | else: 423 | if data_args.label_all_tokens: 424 | label_ids.append(b_to_i_label[label_to_id[label[word_idx]]]) 425 | else: 426 | label_ids.append(-100) 427 | previous_word_idx = word_idx 428 | 429 | labels.append(label_ids) 430 | tokenized_inputs["labels"] = labels 431 | return tokenized_inputs 432 | 433 | if training_args.do_train: 434 | if "train" not in raw_datasets: 435 | raise ValueError("--do_train requires a train dataset") 436 | train_dataset = raw_datasets["train"] 437 | if data_args.max_train_samples is not None: 438 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 439 | train_dataset = train_dataset.select(range(max_train_samples)) 440 | with training_args.main_process_first(desc="train dataset map pre-processing"): 441 | train_dataset = train_dataset.map( 442 | tokenize_and_align_labels, 443 | batched=True, 444 | num_proc=data_args.preprocessing_num_workers, 445 | load_from_cache_file=not data_args.overwrite_cache, 446 | desc="Running tokenizer on train dataset", 447 | ) 448 | 449 | if training_args.do_eval: 450 | if "validation" not in raw_datasets: 451 | raise ValueError("--do_eval requires a validation dataset") 452 | eval_dataset = raw_datasets["validation"] 453 | if data_args.max_eval_samples is not None: 454 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 455 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 456 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 457 | eval_dataset = eval_dataset.map( 458 | tokenize_and_align_labels, 459 | batched=True, 460 | num_proc=data_args.preprocessing_num_workers, 461 | load_from_cache_file=not data_args.overwrite_cache, 462 | desc="Running tokenizer on validation dataset", 463 | ) 464 | 465 | if training_args.do_predict: 466 | if "test" not in raw_datasets: 467 | raise ValueError("--do_predict requires a test dataset") 468 | predict_dataset = raw_datasets["test"] 469 | if data_args.max_predict_samples is not None: 470 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 471 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 472 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 473 | predict_dataset = predict_dataset.map( 474 | tokenize_and_align_labels, 475 | batched=True, 476 | num_proc=data_args.preprocessing_num_workers, 477 | load_from_cache_file=not data_args.overwrite_cache, 478 | desc="Running tokenizer on prediction dataset", 479 | ) 480 | 481 | # Data collator 482 | data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) 483 | 484 | # Metrics 485 | metric = load_metric("seqeval") 486 | 487 | def compute_metrics(p): 488 | predictions, labels = p 489 | predictions = np.argmax(predictions, axis=2) 490 | 491 | # Remove ignored index (special tokens) 492 | true_predictions = [ 493 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 494 | for prediction, label in zip(predictions, labels) 495 | ] 496 | true_labels = [ 497 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 498 | for prediction, label in zip(predictions, labels) 499 | ] 500 | 501 | results = metric.compute(predictions=true_predictions, references=true_labels) 502 | if data_args.return_entity_level_metrics: 503 | # Unpack nested dictionaries 504 | final_results = {} 505 | for key, value in results.items(): 506 | if isinstance(value, dict): 507 | for n, v in value.items(): 508 | 
final_results[f"{key}_{n}"] = v 509 | else: 510 | final_results[key] = value 511 | return final_results 512 | else: 513 | return { 514 | "precision": results["overall_precision"], 515 | "recall": results["overall_recall"], 516 | "f1": results["overall_f1"], 517 | "accuracy": results["overall_accuracy"], 518 | } 519 | 520 | # Initialize our Trainer 521 | trainer = Trainer( 522 | model=model, 523 | args=training_args, 524 | train_dataset=train_dataset if training_args.do_train else None, 525 | eval_dataset=eval_dataset if training_args.do_eval else None, 526 | tokenizer=tokenizer, 527 | data_collator=data_collator, 528 | compute_metrics=compute_metrics, 529 | ) 530 | 531 | # Training 532 | if training_args.do_train: 533 | checkpoint = None 534 | if training_args.resume_from_checkpoint is not None: 535 | checkpoint = training_args.resume_from_checkpoint 536 | elif last_checkpoint is not None: 537 | checkpoint = last_checkpoint 538 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 539 | metrics = train_result.metrics 540 | trainer.save_model() # Saves the tokenizer too for easy upload 541 | 542 | max_train_samples = ( 543 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 544 | ) 545 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 546 | 547 | trainer.log_metrics("train", metrics) 548 | trainer.save_metrics("train", metrics) 549 | trainer.save_state() 550 | 551 | # Evaluation 552 | if training_args.do_eval: 553 | logger.info("*** Evaluate ***") 554 | 555 | metrics = trainer.evaluate() 556 | 557 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 558 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 559 | 560 | trainer.log_metrics("eval", metrics) 561 | trainer.save_metrics("eval", metrics) 562 | 563 | # Predict 564 | if training_args.do_predict: 565 | logger.info("*** Predict ***") 566 | 567 | predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict") 568 | predictions = np.argmax(predictions, axis=2) 569 | 570 | # Remove ignored index (special tokens) 571 | true_predictions = [ 572 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 573 | for prediction, label in zip(predictions, labels) 574 | ] 575 | 576 | trainer.log_metrics("predict", metrics) 577 | trainer.save_metrics("predict", metrics) 578 | 579 | # Save predictions 580 | output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt") 581 | if trainer.is_world_process_zero(): 582 | with open(output_predictions_file, "w") as writer: 583 | for prediction in true_predictions: 584 | writer.write(" ".join(prediction) + "\n") 585 | 586 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"} 587 | if data_args.dataset_name is not None: 588 | kwargs["dataset_tags"] = data_args.dataset_name 589 | if data_args.dataset_config_name is not None: 590 | kwargs["dataset_args"] = data_args.dataset_config_name 591 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 592 | else: 593 | kwargs["dataset"] = data_args.dataset_name 594 | 595 | if training_args.push_to_hub: 596 | trainer.push_to_hub(**kwargs) 597 | else: 598 | trainer.create_model_card(**kwargs) 599 | 600 | 601 | def _mp_fn(index): 602 | # For xla_spawn (TPUs) 603 | main() 604 | 605 | 606 | if __name__ == "__main__": 607 | main() 608 | 
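Since `run_ner.py` is driven entirely by `HfArgumentParser`, it is launched from the command line. The sketch below shows one way it might be invoked for the CoNLL-2003 setting reported in the README; the flags correspond to the `ModelArguments`, `DataTrainingArguments`, and `TrainingArguments` fields parsed above, while the checkpoint name, output directory, and batch size are illustrative placeholders, not the exact settings behind the reported results.

```bash
# Illustrative invocation only: flag names come from the dataclasses defined in
# run_ner.py; checkpoint, output path, and batch size are placeholder choices.
python run_ner.py \
    --model_name_or_path google/t5-v1_1-large \
    --dataset_name conll2003 \
    --output_dir /tmp/t5-enc-conll2003 \
    --per_device_train_batch_size 8 \
    --do_train \
    --do_eval \
    --do_predict
```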
-------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | import logging 20 | import os 21 | import random 22 | import sys 23 | from dataclasses import dataclass, field 24 | from typing import Optional 25 | 26 | import datasets 27 | import numpy as np 28 | from datasets import load_dataset, load_metric 29 | 30 | import t5_encoder 31 | import transformers 32 | from transformers import ( 33 | AutoConfig, 34 | AutoModelForSequenceClassification, 35 | AutoTokenizer, 36 | DataCollatorWithPadding, 37 | EvalPrediction, 38 | HfArgumentParser, 39 | PretrainedConfig, 40 | Trainer, 41 | TrainingArguments, 42 | default_data_collator, 43 | set_seed, 44 | ) 45 | from transformers.trainer_utils import get_last_checkpoint 46 | from transformers.utils import check_min_version 47 | from transformers.utils.versions import require_version 48 | 49 | 50 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 51 | check_min_version("4.18.0") 52 | 53 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 54 | 55 | task_to_keys = { 56 | "cola": ("sentence", None), 57 | "mnli": ("premise", "hypothesis"), 58 | "mrpc": ("sentence1", "sentence2"), 59 | "qnli": ("question", "sentence"), 60 | "qqp": ("question1", "question2"), 61 | "rte": ("sentence1", "sentence2"), 62 | "sst2": ("sentence", None), 63 | "stsb": ("sentence1", "sentence2"), 64 | "wnli": ("sentence1", "sentence2"), 65 | } 66 | 67 | logger = logging.getLogger(__name__) 68 | 69 | 70 | @dataclass 71 | class DataTrainingArguments: 72 | """ 73 | Arguments pertaining to what data we are going to input our model for training and eval. 74 | 75 | Using `HfArgumentParser` we can turn this class 76 | into argparse arguments to be able to specify them on 77 | the command line. 78 | """ 79 | 80 | task_name: Optional[str] = field( 81 | default=None, 82 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 83 | ) 84 | dataset_name: Optional[str] = field( 85 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 86 | ) 87 | dataset_config_name: Optional[str] = field( 88 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 89 | ) 90 | max_seq_length: int = field( 91 | default=128, 92 | metadata={ 93 | "help": "The maximum total input sequence length after tokenization. 
Sequences longer " 94 | "than this will be truncated, sequences shorter will be padded." 95 | }, 96 | ) 97 | overwrite_cache: bool = field( 98 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 99 | ) 100 | pad_to_max_length: bool = field( 101 | default=True, 102 | metadata={ 103 | "help": "Whether to pad all samples to `max_seq_length`. " 104 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 105 | }, 106 | ) 107 | max_train_samples: Optional[int] = field( 108 | default=None, 109 | metadata={ 110 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 111 | "value if set." 112 | }, 113 | ) 114 | max_eval_samples: Optional[int] = field( 115 | default=None, 116 | metadata={ 117 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 118 | "value if set." 119 | }, 120 | ) 121 | max_predict_samples: Optional[int] = field( 122 | default=None, 123 | metadata={ 124 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 125 | "value if set." 126 | }, 127 | ) 128 | train_file: Optional[str] = field( 129 | default=None, metadata={"help": "A csv or a json file containing the training data."} 130 | ) 131 | validation_file: Optional[str] = field( 132 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 133 | ) 134 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 135 | 136 | def __post_init__(self): 137 | if self.task_name is not None: 138 | self.task_name = self.task_name.lower() 139 | if self.task_name not in task_to_keys.keys(): 140 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 141 | elif self.dataset_name is not None: 142 | pass 143 | elif self.train_file is None or self.validation_file is None: 144 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 145 | else: 146 | train_extension = self.train_file.split(".")[-1] 147 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 148 | validation_extension = self.validation_file.split(".")[-1] 149 | assert ( 150 | validation_extension == train_extension 151 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 152 | 153 | 154 | @dataclass 155 | class ModelArguments: 156 | """ 157 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
158 | """ 159 | 160 | model_name_or_path: str = field( 161 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 162 | ) 163 | config_name: Optional[str] = field( 164 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 165 | ) 166 | tokenizer_name: Optional[str] = field( 167 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 168 | ) 169 | cache_dir: Optional[str] = field( 170 | default=None, 171 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 172 | ) 173 | use_fast_tokenizer: bool = field( 174 | default=True, 175 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 176 | ) 177 | model_revision: str = field( 178 | default="main", 179 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 180 | ) 181 | use_auth_token: bool = field( 182 | default=False, 183 | metadata={ 184 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 185 | "with private models)." 186 | }, 187 | ) 188 | 189 | 190 | def main(): 191 | # See all possible arguments in src/transformers/training_args.py 192 | # or by passing the --help flag to this script. 193 | # We now keep distinct sets of args, for a cleaner separation of concerns. 194 | 195 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 196 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 197 | # If we pass only one argument to the script and it's the path to a json file, 198 | # let's parse it to get our arguments. 199 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 200 | else: 201 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 202 | 203 | # Setup logging 204 | logging.basicConfig( 205 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 206 | datefmt="%m/%d/%Y %H:%M:%S", 207 | handlers=[logging.StreamHandler(sys.stdout)], 208 | ) 209 | 210 | log_level = training_args.get_process_log_level() 211 | logger.setLevel(log_level) 212 | datasets.utils.logging.set_verbosity(log_level) 213 | transformers.utils.logging.set_verbosity(log_level) 214 | transformers.utils.logging.enable_default_handler() 215 | transformers.utils.logging.enable_explicit_format() 216 | 217 | # Log on each process the small summary: 218 | logger.warning( 219 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 220 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 221 | ) 222 | logger.info(f"Training/evaluation parameters {training_args}") 223 | 224 | # Detecting last checkpoint. 225 | last_checkpoint = None 226 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 227 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 228 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 229 | raise ValueError( 230 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 231 | "Use --overwrite_output_dir to overcome." 
232 | ) 233 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 234 | logger.info( 235 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 236 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 237 | ) 238 | 239 | # Set seed before initializing model. 240 | set_seed(training_args.seed) 241 | 242 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 243 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 244 | # 245 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 246 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 247 | # label if at least two columns are provided. 248 | # 249 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 250 | # single column. You can easily tweak this behavior (see below) 251 | # 252 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 253 | # download the dataset. 254 | if data_args.task_name is not None: 255 | # Downloading and loading a dataset from the hub. 256 | raw_datasets = load_dataset( 257 | "glue", 258 | data_args.task_name, 259 | cache_dir=model_args.cache_dir, 260 | use_auth_token=True if model_args.use_auth_token else None, 261 | ) 262 | elif data_args.dataset_name is not None: 263 | # Downloading and loading a dataset from the hub. 264 | raw_datasets = load_dataset( 265 | data_args.dataset_name, 266 | data_args.dataset_config_name, 267 | cache_dir=model_args.cache_dir, 268 | use_auth_token=True if model_args.use_auth_token else None, 269 | ) 270 | else: 271 | # Loading a dataset from your local files. 272 | # CSV/JSON training and evaluation files are needed. 273 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 274 | 275 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 276 | # when you use `do_predict` without specifying a GLUE benchmark task. 277 | if training_args.do_predict: 278 | if data_args.test_file is not None: 279 | train_extension = data_args.train_file.split(".")[-1] 280 | test_extension = data_args.test_file.split(".")[-1] 281 | assert ( 282 | test_extension == train_extension 283 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 284 | data_files["test"] = data_args.test_file 285 | else: 286 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 287 | 288 | for key in data_files.keys(): 289 | logger.info(f"load a local file for {key}: {data_files[key]}") 290 | 291 | if data_args.train_file.endswith(".csv"): 292 | # Loading a dataset from local csv files 293 | raw_datasets = load_dataset( 294 | "csv", 295 | data_files=data_files, 296 | cache_dir=model_args.cache_dir, 297 | use_auth_token=True if model_args.use_auth_token else None, 298 | ) 299 | else: 300 | # Loading a dataset from local json files 301 | raw_datasets = load_dataset( 302 | "json", 303 | data_files=data_files, 304 | cache_dir=model_args.cache_dir, 305 | use_auth_token=True if model_args.use_auth_token else None, 306 | ) 307 | # See more about loading any type of standard or custom dataset at 308 | # https://huggingface.co/docs/datasets/loading_datasets.html. 
309 | 310 | # Labels 311 | if data_args.task_name is not None: 312 | is_regression = data_args.task_name == "stsb" 313 | if not is_regression: 314 | label_list = raw_datasets["train"].features["label"].names 315 | num_labels = len(label_list) 316 | else: 317 | num_labels = 1 318 | else: 319 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 320 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 321 | if is_regression: 322 | num_labels = 1 323 | else: 324 | # A useful fast method: 325 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 326 | label_list = raw_datasets["train"].unique("label") 327 | label_list.sort() # Let's sort it for determinism 328 | num_labels = len(label_list) 329 | 330 | # Load pretrained model and tokenizer 331 | # 332 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 333 | # download model & vocab. 334 | config = AutoConfig.from_pretrained( 335 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 336 | num_labels=num_labels, 337 | finetuning_task=data_args.task_name, 338 | cache_dir=model_args.cache_dir, 339 | revision=model_args.model_revision, 340 | use_auth_token=True if model_args.use_auth_token else None, 341 | ) 342 | tokenizer = AutoTokenizer.from_pretrained( 343 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 344 | cache_dir=model_args.cache_dir, 345 | use_fast=model_args.use_fast_tokenizer, 346 | revision=model_args.model_revision, 347 | use_auth_token=True if model_args.use_auth_token else None, 348 | ) 349 | model = AutoModelForSequenceClassification.from_pretrained( 350 | model_args.model_name_or_path, 351 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 352 | config=config, 353 | cache_dir=model_args.cache_dir, 354 | revision=model_args.model_revision, 355 | use_auth_token=True if model_args.use_auth_token else None, 356 | ) 357 | 358 | # Preprocessing the raw_datasets 359 | if data_args.task_name is not None: 360 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 361 | else: 362 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 363 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 364 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 365 | sentence1_key, sentence2_key = "sentence1", "sentence2" 366 | else: 367 | if len(non_label_column_names) >= 2: 368 | sentence1_key, sentence2_key = non_label_column_names[:2] 369 | else: 370 | sentence1_key, sentence2_key = non_label_column_names[0], None 371 | 372 | # Padding strategy 373 | if data_args.pad_to_max_length: 374 | padding = "max_length" 375 | else: 376 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 377 | padding = False 378 | 379 | # Some models have set the order of the labels to use, so let's make sure we do use it. 380 | label_to_id = None 381 | if ( 382 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 383 | and data_args.task_name is not None 384 | and not is_regression 385 | ): 386 | # Some have all caps in their config, some don't. 
387 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 388 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 389 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 390 | else: 391 | logger.warning( 392 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 393 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 394 | "\nIgnoring the model labels as a result.", 395 | ) 396 | elif data_args.task_name is None and not is_regression: 397 | label_to_id = {v: i for i, v in enumerate(label_list)} 398 | 399 | if label_to_id is not None: 400 | model.config.label2id = label_to_id 401 | model.config.id2label = {id: label for label, id in config.label2id.items()} 402 | elif data_args.task_name is not None and not is_regression: 403 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 404 | model.config.id2label = {id: label for label, id in config.label2id.items()} 405 | 406 | if data_args.max_seq_length > tokenizer.model_max_length: 407 | logger.warning( 408 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 409 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 410 | ) 411 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 412 | 413 | def preprocess_function(examples): 414 | # Tokenize the texts 415 | args = ( 416 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 417 | ) 418 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 419 | 420 | # Map labels to IDs (not necessary for GLUE tasks) 421 | if label_to_id is not None and "label" in examples: 422 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 423 | return result 424 | 425 | with training_args.main_process_first(desc="dataset map pre-processing"): 426 | raw_datasets = raw_datasets.map( 427 | preprocess_function, 428 | batched=True, 429 | load_from_cache_file=not data_args.overwrite_cache, 430 | desc="Running tokenizer on dataset", 431 | ) 432 | if training_args.do_train: 433 | if "train" not in raw_datasets: 434 | raise ValueError("--do_train requires a train dataset") 435 | train_dataset = raw_datasets["train"] 436 | if data_args.max_train_samples is not None: 437 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 438 | train_dataset = train_dataset.select(range(max_train_samples)) 439 | 440 | if training_args.do_eval: 441 | if "validation" not in raw_datasets and "validation_matched" not in raw_datasets: 442 | raise ValueError("--do_eval requires a validation dataset") 443 | eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] 444 | if data_args.max_eval_samples is not None: 445 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 446 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 447 | 448 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 449 | if "test" not in raw_datasets and "test_matched" not in raw_datasets: 450 | raise ValueError("--do_predict requires a test dataset") 451 | predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"] 452 | if data_args.max_predict_samples is not None: 453 | max_predict_samples = 
min(len(predict_dataset), data_args.max_predict_samples) 454 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 455 | 456 | # Log a few random samples from the training set: 457 | if training_args.do_train: 458 | for index in random.sample(range(len(train_dataset)), 3): 459 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 460 | 461 | # Get the metric function 462 | if data_args.task_name is not None: 463 | metric = load_metric("glue", data_args.task_name) 464 | else: 465 | metric = load_metric("accuracy") 466 | 467 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 468 | # predictions and label_ids field) and has to return a dictionary string to float. 469 | def compute_metrics(p: EvalPrediction): 470 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 471 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 472 | if data_args.task_name is not None: 473 | result = metric.compute(predictions=preds, references=p.label_ids) 474 | if len(result) > 1: 475 | result["combined_score"] = np.mean(list(result.values())).item() 476 | return result 477 | elif is_regression: 478 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 479 | else: 480 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 481 | 482 | # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if 483 | # we already did the padding. 484 | if data_args.pad_to_max_length: 485 | data_collator = default_data_collator 486 | elif training_args.fp16: 487 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 488 | else: 489 | data_collator = None 490 | 491 | # Initialize our Trainer 492 | trainer = Trainer( 493 | model=model, 494 | args=training_args, 495 | train_dataset=train_dataset if training_args.do_train else None, 496 | eval_dataset=eval_dataset if training_args.do_eval else None, 497 | compute_metrics=compute_metrics, 498 | tokenizer=tokenizer, 499 | data_collator=data_collator, 500 | ) 501 | 502 | # Training 503 | if training_args.do_train: 504 | checkpoint = None 505 | if training_args.resume_from_checkpoint is not None: 506 | checkpoint = training_args.resume_from_checkpoint 507 | elif last_checkpoint is not None: 508 | checkpoint = last_checkpoint 509 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 510 | metrics = train_result.metrics 511 | max_train_samples = ( 512 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 513 | ) 514 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 515 | 516 | trainer.save_model() # Saves the tokenizer too for easy upload 517 | 518 | trainer.log_metrics("train", metrics) 519 | trainer.save_metrics("train", metrics) 520 | trainer.save_state() 521 | 522 | # Evaluation 523 | if training_args.do_eval: 524 | logger.info("*** Evaluate ***") 525 | 526 | # Loop to handle MNLI double evaluation (matched, mis-matched) 527 | tasks = [data_args.task_name] 528 | eval_datasets = [eval_dataset] 529 | if data_args.task_name == "mnli": 530 | tasks.append("mnli-mm") 531 | eval_datasets.append(raw_datasets["validation_mismatched"]) 532 | combined = {} 533 | 534 | for eval_dataset, task in zip(eval_datasets, tasks): 535 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 536 | 537 | max_eval_samples = ( 538 | data_args.max_eval_samples if 
data_args.max_eval_samples is not None else len(eval_dataset) 539 | ) 540 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 541 | 542 | if task == "mnli-mm": 543 | metrics = {k + "_mm": v for k, v in metrics.items()} 544 | if task is not None and "mnli" in task: 545 | combined.update(metrics) 546 | 547 | trainer.log_metrics("eval", metrics) 548 | trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics) 549 | 550 | if training_args.do_predict: 551 | logger.info("*** Predict ***") 552 | 553 | # Loop to handle MNLI double evaluation (matched, mis-matched) 554 | tasks = [data_args.task_name] 555 | predict_datasets = [predict_dataset] 556 | if data_args.task_name == "mnli": 557 | tasks.append("mnli-mm") 558 | predict_datasets.append(raw_datasets["test_mismatched"]) 559 | 560 | for predict_dataset, task in zip(predict_datasets, tasks): 561 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 562 | predict_dataset = predict_dataset.remove_columns("label") 563 | predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions 564 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 565 | 566 | output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") 567 | if trainer.is_world_process_zero(): 568 | with open(output_predict_file, "w") as writer: 569 | logger.info(f"***** Predict results {task} *****") 570 | writer.write("index\tprediction\n") 571 | for index, item in enumerate(predictions): 572 | if is_regression: 573 | writer.write(f"{index}\t{item:3.3f}\n") 574 | else: 575 | item = label_list[item] 576 | writer.write(f"{index}\t{item}\n") 577 | 578 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"} 579 | if data_args.task_name is not None: 580 | kwargs["language"] = "en" 581 | kwargs["dataset_tags"] = "glue" 582 | kwargs["dataset_args"] = data_args.task_name 583 | kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" 584 | 585 | if training_args.push_to_hub: 586 | trainer.push_to_hub(**kwargs) 587 | else: 588 | trainer.create_model_card(**kwargs) 589 | 590 | 591 | def _mp_fn(index): 592 | # For xla_spawn (TPUs) 593 | main() 594 | 595 | 596 | if __name__ == "__main__": 597 | main() 598 | --------------------------------------------------------------------------------
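Once `run_glue.py` has finished, the saved checkpoint can be loaded back through the Auto classes for inference. The snippet below is a minimal sketch assuming an MRPC fine-tune saved to a hypothetical `./output/mrpc` directory; the path and the sentence pair are placeholders, not outputs of the script.

```python
import t5_encoder  # registers T5ForSequenceClassification with the Auto* mappings
from transformers import AutoTokenizer, AutoModelForSequenceClassification

checkpoint = "./output/mrpc"  # hypothetical --output_dir used when running run_glue.py
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

# MRPC-style sentence pair; single-sentence tasks pass only the first argument.
inputs = tokenizer(
    "The company reported strong quarterly earnings.",
    "Quarterly earnings at the company were strong.",
    return_tensors="pt",
)
pred_id = model(**inputs).logits.argmax(dim=-1).item()
print(model.config.id2label[pred_id])  # e.g. "equivalent" / "not_equivalent" for MRPC
```

Because `run_glue.py` writes the task's label mapping into the model config before training, the predicted id can be mapped back to a label name directly from `model.config.id2label`, without re-loading the dataset.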