├── .github └── workflows │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile ├── conf.py ├── grouphug.rst └── index.rst ├── examples ├── from-readme.ipynb ├── glue.ipynb ├── lm.ipynb ├── neox.ipynb ├── sentiment.ipynb └── utils.py ├── grouphug ├── __init__.py ├── collator.py ├── config.py ├── dataset_collection.py ├── dataset_formatter.py ├── heads │ ├── __init__.py │ ├── base.py │ ├── classification.py │ └── lm.py ├── model.py ├── trainer.py └── utils.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── automodel.py ├── conftest.py ├── test_dataset_collection.py ├── test_dataset_formatter.py ├── test_model.py └── test_train.py /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v1 16 | with: 17 | python-version: 3.8 18 | 19 | - name: Install dependencies 20 | run: | 21 | pip install poetry twine 22 | poetry install 23 | 24 | - name: Run tests 25 | run: poetry run pytest -v -s tests 26 | 27 | - name: Build package 28 | run: poetry build 29 | 30 | - name: Release to PyPI 31 | env: 32 | TWINE_USERNAME: ${{ secrets.PYPI_USER }} 33 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 34 | run: twine upload --verbose dist/* 35 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | pull_request: 5 | branches: [ master ] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | python-version: ["3.8", "3.9"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | pip install poetry 27 | poetry install 28 | 29 | - name: Run tests 30 | run: poetry run pytest -v -s tests 31 | 32 | - name: Build package 33 | run: poetry build 34 | 35 | - name: Build docs 36 | run: cd docs && poetry run make html -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # scratch notebooks for development and testing 2 | dev*.ipynb 3 | dev*.py 4 | # huggingface caches, trained model examples and such 5 | output/ 6 | models/ 7 | tests/output 8 | tests/models 9 | tmp_trainer 10 | 11 | # pycharm 12 | .idea 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | *.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | *.py,cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | db.sqlite3-journal 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/floatingpurr/sync_with_poetry 4 | rev: 0.2.1 5 | hooks: 6 | - id: sync_with_poetry 7 | - repo: https://github.com/psf/black 8 | rev: 22.10.0 9 | hooks: 10 | - id: black 11 | - repo: https://github.com/timothycrosley/isort 12 | rev: 5.10.1 13 | hooks: 14 | - id: isort 15 | name: isort (python) 16 | - repo: local 17 | hooks: 18 | - id: jupyter-nb-clear-output 19 | name: jupyter-nb-clear-output 20 | files: examples/.*\.ipynb 21 | stages: [commit] 22 | language: system 23 | entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace 24 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Feature requests and PRs 2 | 3 | If you want to add a complex feature or change which is not already mentioned, please open an issue or discussion topic to discuss. 4 | For simple additions and bug fixes you can open a PR directly. 5 | 6 | ## Formatting and pre-commit hooks 7 | 8 | To ensure your PR is properly formatted, install pre-commit hooks using `pre-commit install` 9 | 10 | This will run black, isort, and clear any output from example notebooks when committing. 11 | 12 | # Notes on Grouphug internals 13 | 14 | This section contains notes on implementation details of huggingface transformers and grouphug. 
15 | 16 | ## Computing metrics 17 | 18 | Compared to the standard `Trainer`, the `compute_metrics` function is passed extra parameters (the dataset name and the heads being evaluated), allowing it to know which data it is computing metrics on. 19 | The `compute_classification_metrics` function in examples/utils.py works as a fairly generic implementation and could be added as a default in future versions. 20 | 21 | # Notes on Huggingface Transformers internals 22 | 23 | These are largely my own notes on the internals of the transformers package and how they interact. 24 | 25 | ## Tokenizers 26 | 27 | * Tokenizers have `.model_input_names` to determine what to pad, e.g. `['input_ids', 'token_type_ids', 'attention_mask']` 28 | * However, these are mostly ignored except for the first, and `_pad` has a hardcoded check for `['input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask']` 29 | * Tokenizers have `model_max_length`, which is often unset and left at its default of LARGE_INTEGER 30 | * Dynamic padding is done by various collators via `Tokenizer.pad`, but this does not truncate. 31 | 32 | ## Trainer 33 | 34 | * Model outputs are ordered dicts 35 | * All keys not named 'loss' are assumed to be logits 36 | * Somehow one of the GPT models still returns two losses. 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This repository includes code copied and/or modified from the HuggingFace Transformers library, which is also licensed under the Apache 2 license. 2 | See details at https://github.com/huggingface/transformers/ 3 | 4 | ----------------------------------------------------------------------------------------- 5 | Aside from the above, the license for all other content in this repository is as follows: 6 | ----------------------------------------------------------------------------------------- 7 | 8 | Copyright 2022 Chatdesk Inc. All rights reserved. 9 | 10 | Apache License 11 | Version 2.0, January 2004 12 | http://www.apache.org/licenses/ 13 | 14 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 15 | 16 | 1. Definitions. 17 | 18 | "License" shall mean the terms and conditions for use, reproduction, 19 | and distribution as defined by Sections 1 through 9 of this document. 20 | 21 | "Licensor" shall mean the copyright owner or entity authorized by 22 | the copyright owner that is granting the License. 23 | 24 | "Legal Entity" shall mean the union of the acting entity and all 25 | other entities that control, are controlled by, or are under common 26 | control with that entity. For the purposes of this definition, 27 | "control" means (i) the power, direct or indirect, to cause the 28 | direction or management of such entity, whether by contract or 29 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 30 | outstanding shares, or (iii) beneficial ownership of such entity. 31 | 32 | "You" (or "Your") shall mean an individual or Legal Entity 33 | exercising permissions granted by this License. 34 | 35 | "Source" form shall mean the preferred form for making modifications, 36 | including but not limited to software source code, documentation 37 | source, and configuration files. 38 | 39 | "Object" form shall mean any form resulting from mechanical 40 | transformation or translation of a Source form, including but 41 | not limited to compiled object code, generated documentation, 42 | and conversions to other media types.
43 | 44 | "Work" shall mean the work of authorship, whether in Source or 45 | Object form, made available under the License, as indicated by a 46 | copyright notice that is included in or attached to the work 47 | (an example is provided in the Appendix below). 48 | 49 | "Derivative Works" shall mean any work, whether in Source or Object 50 | form, that is based on (or derived from) the Work and for which the 51 | editorial revisions, annotations, elaborations, or other modifications 52 | represent, as a whole, an original work of authorship. For the purposes 53 | of this License, Derivative Works shall not include works that remain 54 | separable from, or merely link (or bind by name) to the interfaces of, 55 | the Work and Derivative Works thereof. 56 | 57 | "Contribution" shall mean any work of authorship, including 58 | the original version of the Work and any modifications or additions 59 | to that Work or Derivative Works thereof, that is intentionally 60 | submitted to Licensor for inclusion in the Work by the copyright owner 61 | or by an individual or Legal Entity authorized to submit on behalf of 62 | the copyright owner. For the purposes of this definition, "submitted" 63 | means any form of electronic, verbal, or written communication sent 64 | to the Licensor or its representatives, including but not limited to 65 | communication on electronic mailing lists, source code control systems, 66 | and issue tracking systems that are managed by, or on behalf of, the 67 | Licensor for the purpose of discussing and improving the Work, but 68 | excluding communication that is conspicuously marked or otherwise 69 | designated in writing by the copyright owner as "Not a Contribution." 70 | 71 | "Contributor" shall mean Licensor and any individual or Legal Entity 72 | on behalf of whom a Contribution has been received by Licensor and 73 | subsequently incorporated within the Work. 74 | 75 | 2. Grant of Copyright License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | copyright license to reproduce, prepare Derivative Works of, 79 | publicly display, publicly perform, sublicense, and distribute the 80 | Work and such Derivative Works in Source or Object form. 81 | 82 | 3. Grant of Patent License. Subject to the terms and conditions of 83 | this License, each Contributor hereby grants to You a perpetual, 84 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 85 | (except as stated in this section) patent license to make, have made, 86 | use, offer to sell, sell, import, and otherwise transfer the Work, 87 | where such license applies only to those patent claims licensable 88 | by such Contributor that are necessarily infringed by their 89 | Contribution(s) alone or by combination of their Contribution(s) 90 | with the Work to which such Contribution(s) was submitted. If You 91 | institute patent litigation against any entity (including a 92 | cross-claim or counterclaim in a lawsuit) alleging that the Work 93 | or a Contribution incorporated within the Work constitutes direct 94 | or contributory patent infringement, then any patent licenses 95 | granted to You under this License for that Work shall terminate 96 | as of the date such litigation is filed. 97 | 98 | 4. Redistribution. 
You may reproduce and distribute copies of the 99 | Work or Derivative Works thereof in any medium, with or without 100 | modifications, and in Source or Object form, provided that You 101 | meet the following conditions: 102 | 103 | (a) You must give any other recipients of the Work or 104 | Derivative Works a copy of this License; and 105 | 106 | (b) You must cause any modified files to carry prominent notices 107 | stating that You changed the files; and 108 | 109 | (c) You must retain, in the Source form of any Derivative Works 110 | that You distribute, all copyright, patent, trademark, and 111 | attribution notices from the Source form of the Work, 112 | excluding those notices that do not pertain to any part of 113 | the Derivative Works; and 114 | 115 | (d) If the Work includes a "NOTICE" text file as part of its 116 | distribution, then any Derivative Works that You distribute must 117 | include a readable copy of the attribution notices contained 118 | within such NOTICE file, excluding those notices that do not 119 | pertain to any part of the Derivative Works, in at least one 120 | of the following places: within a NOTICE text file distributed 121 | as part of the Derivative Works; within the Source form or 122 | documentation, if provided along with the Derivative Works; or, 123 | within a display generated by the Derivative Works, if and 124 | wherever such third-party notices normally appear. The contents 125 | of the NOTICE file are for informational purposes only and 126 | do not modify the License. You may add Your own attribution 127 | notices within Derivative Works that You distribute, alongside 128 | or as an addendum to the NOTICE text from the Work, provided 129 | that such additional attribution notices cannot be construed 130 | as modifying the License. 131 | 132 | You may add Your own copyright statement to Your modifications and 133 | may provide additional or different license terms and conditions 134 | for use, reproduction, or distribution of Your modifications, or 135 | for any such Derivative Works as a whole, provided Your use, 136 | reproduction, and distribution of the Work otherwise complies with 137 | the conditions stated in this License. 138 | 139 | 5. Submission of Contributions. Unless You explicitly state otherwise, 140 | any Contribution intentionally submitted for inclusion in the Work 141 | by You to the Licensor shall be under the terms and conditions of 142 | this License, without any additional terms or conditions. 143 | Notwithstanding the above, nothing herein shall supersede or modify 144 | the terms of any separate license agreement you may have executed 145 | with Licensor regarding such Contributions. 146 | 147 | 6. Trademarks. This License does not grant permission to use the trade 148 | names, trademarks, service marks, or product names of the Licensor, 149 | except as required for reasonable and customary use in describing the 150 | origin of the Work and reproducing the content of the NOTICE file. 151 | 152 | 7. Disclaimer of Warranty. Unless required by applicable law or 153 | agreed to in writing, Licensor provides the Work (and each 154 | Contributor provides its Contributions) on an "AS IS" BASIS, 155 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 156 | implied, including, without limitation, any warranties or conditions 157 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 158 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 159 | appropriateness of using or redistributing the Work and assume any 160 | risks associated with Your exercise of permissions under this License. 161 | 162 | 8. Limitation of Liability. In no event and under no legal theory, 163 | whether in tort (including negligence), contract, or otherwise, 164 | unless required by applicable law (such as deliberate and grossly 165 | negligent acts) or agreed to in writing, shall any Contributor be 166 | liable to You for damages, including any direct, indirect, special, 167 | incidental, or consequential damages of any character arising as a 168 | result of this License or out of the use or inability to use the 169 | Work (including but not limited to damages for loss of goodwill, 170 | work stoppage, computer failure or malfunction, or any and all 171 | other commercial damages or losses), even if such Contributor 172 | has been advised of the possibility of such damages. 173 | 174 | 9. Accepting Warranty or Additional Liability. While redistributing 175 | the Work or Derivative Works thereof, You may choose to offer, 176 | and charge a fee for, acceptance of support, warranty, indemnity, 177 | or other liability obligations and/or rights consistent with this 178 | License. However, in accepting such obligations, You may act only 179 | on Your own behalf and on Your sole responsibility, not on behalf 180 | of any other Contributor, and only if You agree to indemnify, 181 | defend, and hold each Contributor harmless for any liability 182 | incurred by, or claims asserted against, such Contributor by reason 183 | of your accepting any such warranty or additional liability. 184 | 185 | END OF TERMS AND CONDITIONS 186 | 187 | APPENDIX: How to apply the Apache License to your work. 188 | 189 | To apply the Apache License to your work, attach the following 190 | boilerplate notice, with the fields enclosed by brackets "[]" 191 | replaced with your own identifying information. (Don't include 192 | the brackets!) The text should be enclosed in the appropriate 193 | comment syntax for the file format. We also recommend that a 194 | file or class name and description of purpose be included on the 195 | same "printed page" as the copyright notice for easier 196 | identification within third-party archives. 197 | 198 | Copyright [yyyy] [name of copyright owner] 199 | 200 | Licensed under the Apache License, Version 2.0 (the "License"); 201 | you may not use this file except in compliance with the License. 202 | You may obtain a copy of the License at 203 | 204 | http://www.apache.org/licenses/LICENSE-2.0 205 | 206 | Unless required by applicable law or agreed to in writing, software 207 | distributed under the License is distributed on an "AS IS" BASIS, 208 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 209 | See the License for the specific language governing permissions and 210 | limitations under the License. 211 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # grouphug 3 | 4 | GroupHug is a library with extensions to 🤗 transformers for multitask language modelling. 5 | In addition, it contains utilities that ease data preparation, training, and inference. 6 | 7 | ## Project Moved 8 | 9 | Grouphug maintenance and future versions have moved to [my personal repository](https://github.com/sanderland/grouphug). 
10 | 11 | ## Overview 12 | 13 | The package is optimized for training a single language model to make quick and robust predictions for a wide variety of related tasks at once, 14 | as well as to investigate the regularizing effect of training on a language modelling task at the same time. 15 | 16 | You can train on multiple datasets, with each dataset containing an arbitrary subset of your tasks. Supported tasks include: 17 | 18 | * A single language modelling task (Masked language modelling, Masked token detection, Causal language modelling). 19 | * The default collator included handles most preprocessing for these heads automatically. 20 | * Any number of classification tasks, including single- and multi-label classification and regression. 21 | * A utility function that automatically creates a classification head from your data. 22 | * Additional options such as hidden layer size, additional input variables, and class weights (see the configuration sketch below). 23 | * You can also define your own model heads. 24 | 25 | ## Quick Start 26 | 27 | The project is based on Python 3.8+ and PyTorch 1.10+. To install it, simply use: 28 | 29 | `pip install grouphug` 30 | 31 | ### Documentation 32 | 33 | Documentation can be generated from docstrings using `make html` in the `docs` directory, but it is not yet available on a hosted site. 34 | 35 | ### Example usage 36 | 37 | ```python 38 | import pandas as pd 39 | from datasets import load_dataset 40 | from transformers import AutoTokenizer 41 | 42 | from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer 43 | 44 | # load some data. 'label' gets renamed in huggingface, so it is better avoided as a feature name. 45 | task_one = load_dataset("tweet_eval", "emoji").rename_column("label", "tweet_label") 46 | both_tasks = pd.DataFrame({"text": ["yay :)", "booo!"], "sentiment": ["pos", "neg"], "tweet_label": [0, 14]}) 47 | 48 | # create a tokenizer 49 | base_model = "prajjwal1/bert-tiny" 50 | tokenizer = AutoTokenizer.from_pretrained(base_model) 51 | 52 | # preprocess your data: tokenization, preparing class variables 53 | formatter = DatasetFormatter().tokenize().encode("sentiment") 54 | # data converted to a DatasetCollection: essentially a dict of DatasetDict 55 | data = formatter.apply({"one": task_one, "both": both_tasks}, tokenizer=tokenizer, test_size=0.05) 56 | 57 | # define which model heads you would like 58 | head_configs = [ 59 | LMHeadConfig(weight=0.1), # default is BERT-style masked language modelling 60 | ClassificationHeadConfig.from_data(data, "sentiment"), # detects dimensions and type 61 | ClassificationHeadConfig.from_data(data, "tweet_label"), # detects dimensions and type 62 | ] 63 | # create the model, optionally saving the tokenizer and formatter along with it 64 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, formatter=formatter, tokenizer=tokenizer) 65 | # create the trainer 66 | trainer = MultiTaskTrainer( 67 | model=model, 68 | tokenizer=tokenizer, 69 | train_data=data[:, "train"], 70 | eval_data=data[["one"], "test"], 71 | eval_heads={"one": ["tweet_label"]}, # limit evaluation to one classification task 72 | ) 73 | trainer.train() 74 | ``` 75 | 76 | ### Tutorials 77 | 78 | See [examples](./examples) for a few notebooks that demonstrate the key features.
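### Head configuration options

The quick-start example above uses the defaults; the bullet on additional options in the Overview refers to keyword arguments demonstrated in the bundled notebooks. The sketch below collects those argument combinations as they appear in the notebooks, reusing `data` and the imports from the example above:

```python
from grouphug import ClassificationHeadConfig, LMHeadConfig

head_configs = [
    # ELECTRA-style masked token detection instead of the default masked language
    # modelling; pass masked_language_modelling=True as well to combine the two.
    LMHeadConfig(masked_token_detection=True, weight=0.2),
    # a smaller classifier hidden layer, with this head's loss weighted 2x
    ClassificationHeadConfig.from_data(data, "sentiment", classifier_hidden_size=32, weight=2),
    # keep the head attached to the base model and skip -1 labels in the loss
    ClassificationHeadConfig.from_data(data, "tweet_label", detached=False, ignore_index=-1),
]
```

All heads share the same base model, and the `weight` arguments balance each head's loss against the others.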
79 | 80 | ## Supported Models 81 | 82 | The package supports the following base models: 83 | 84 | * Bert, DistilBert, Roberta/DistilRoberta, XLM-Roberta 85 | * Deberta/DebertaV2 86 | * Electra 87 | * GPT2, GPT-J, GPT-NeoX, OPT 88 | 89 | Extending it to support other models is possible by simply inheriting from `_BaseMultiTaskModel`, although language modelling head weights may not always load. 90 | 91 | ## Limitations 92 | 93 | * The package only supports PyTorch, and will not work with other frameworks. There are no plans to change this. 94 | * Grouphug was developed and tested with 🤗 transformers 4.19-4.22. We will aim to test and keep compatibility with the latest version, but it is still recommended to pin a known-working version. 95 | 96 | See the [contributing page](CONTRIBUTING.md) if you are interested in contributing. 97 | 98 | ## License 99 | 100 | grouphug was developed by [Chatdesk](http://www.chatdesk.com) and is licensed under the Apache 2 [license](LICENSE). 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import re 15 | import sys 16 | 17 | sys.path.insert(0, os.path.abspath("..")) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = "grouphug" 23 | copyright = "2022, Chatdesk" 24 | author = "Chatdesk" 25 | release = re.search(r'^version\s*=\s*"(.*)"', open("../pyproject.toml").read(), re.M).group(1).strip() 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"] 33 | 34 | 35 | # Add any paths that contain templates here, relative to this directory.
36 | templates_path = ["_templates"] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 41 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | 46 | # The theme to use for HTML and HTML Help pages. See the documentation for 47 | # a list of builtin themes. 48 | # 49 | # html_theme = 'alabaster' 50 | html_theme = "sphinx_rtd_theme" 51 | 52 | # Add any paths that contain custom static files (such as style sheets) here, 53 | # relative to this directory. They are copied after the builtin static files, 54 | # so a file named "default.css" will overwrite the builtin "default.css". 55 | html_static_path = ["_static"] 56 | 57 | pygments_style = "sphinx" 58 | -------------------------------------------------------------------------------- /docs/grouphug.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | See the 'examples' directory on GitHub for some examples that will quickly get you up to speed. 5 | 6 | Model classes 7 | ============= 8 | 9 | AutoMultiTaskModel 10 | ------------------ 11 | 12 | .. autoclass:: grouphug.AutoMultiTaskModel 13 | :members: 14 | :member-order: bysource 15 | 16 | Individual Model Classes and utilities 17 | -------------------------------------- 18 | 19 | Most of the model classes here would typically be initialized using `AutoMultiTaskModel.from_pretrained`. 20 | 21 | .. automodule:: grouphug.model 22 | :members: 23 | :private-members: _BaseMultiTaskModel 24 | :member-order: bysource 25 | 26 | Model heads 27 | =========== 28 | 29 | Classification 30 | -------------- 31 | 32 | .. autoclass:: grouphug.ClassificationHeadConfig 33 | :members: 34 | :member-order: bysource 35 | 36 | 37 | Language modelling 38 | ------------------ 39 | 40 | .. autoclass:: grouphug.LMHeadConfig 41 | :members: 42 | :member-order: bysource 43 | 44 | 45 | DatasetCollection and DatasetFormatter 46 | ====================================== 47 | Typically you would set up a `DatasetFormatter` in training, whose `apply` method returns a `DatasetCollection`. 48 | In stand-alone inference and evaluation you can also pass the same arguments (`data`, `test_size`) directly to the `DatasetCollection` constructor. 49 | 50 | 51 | .. autoclass:: grouphug.DatasetFormatter 52 | :members: 53 | :member-order: bysource 54 | 55 | 56 | .. autoclass:: grouphug.DatasetCollection 57 | :members: 58 | :member-order: bysource 59 | 60 | 61 | MultiTaskTrainer 62 | ================ 63 | 64 | .. autoclass:: grouphug.MultiTaskTrainer 65 | :members: 66 | :member-order: bysource 67 | 68 | AutoCollator 69 | ============ 70 | This is the default collator for MultiTaskTrainer. 71 | 72 | .. autoclass:: grouphug.collator.AutoCollator 73 | :members: 74 | :member-order: bysource 75 | 76 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. grouphug documentation master file, created by 2 | sphinx-quickstart on Wed Jun 8 13:30:02 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to grouphug's documentation! 7 | ==================================== 8 | 9 | Contents 10 | ^^^^^^^^ 11 | ..
toctree:: 12 | grouphug 13 | -------------------------------------------------------------------------------- /examples/from-readme.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a1b50547-7eca-48e4-b73d-bb046ac882aa", 6 | "metadata": {}, 7 | "source": [ 8 | "# This example is mainly to test and run the short example in the README file\n", 9 | "Other examples have a little more explanation. However, the `compute_classification_metrics` function may be worth a look." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "id": "96963a2d-60ce-4326-b1e6-18c4c64cba1b", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import sys\n", 20 | "\n", 21 | "sys.path.append(\"..\") # ensure we can run examples as-is in the package's poetry env" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "8c4a4165-5e43-4a8c-b6ae-ec0518b6c757", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import pandas as pd\n", 32 | "from datasets import load_dataset\n", 33 | "from transformers import AutoTokenizer, TrainingArguments\n", 34 | "\n", 35 | "from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer\n", 36 | "from utils import compute_classification_metrics" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "9e312f42-9c81-4ce0-862d-b52c00ac28aa", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "# load some data. 'label' gets renamed in huggingface, so is better avoided as a feature name.\n", 47 | "task_one = load_dataset(\"tweet_eval\", \"emoji\").rename_column(\"label\", \"tweet_label\")\n", 48 | "both_tasks = pd.DataFrame({\"text\": [\"yay :)\", \"booo!\"], \"sentiment\": [\"pos\", \"neg\"], \"tweet_label\": [0, 14]})\n", 49 | "\n", 50 | "# create a tokenizer\n", 51 | "base_model = \"prajjwal1/bert-tiny\"\n", 52 | "tokenizer = AutoTokenizer.from_pretrained(base_model)\n", 53 | "\n", 54 | "# preprocess your data: tokenization, preparing class variables\n", 55 | "formatter = DatasetFormatter().tokenize().encode(\"sentiment\")\n", 56 | "# data converted to a DatasetCollection: essentially a dict of DatasetDict\n", 57 | "data = formatter.apply({\"one\": task_one, \"both\": both_tasks}, tokenizer=tokenizer, test_size=0.05)\n", 58 | "\n", 59 | "# define which model heads you would like\n", 60 | "head_configs = [\n", 61 | " LMHeadConfig(weight=0.1), # default is BERT-style masked language modelling\n", 62 | " ClassificationHeadConfig.from_data(data, \"sentiment\"), # detects dimensions and type\n", 63 | " ClassificationHeadConfig.from_data(data, \"tweet_label\"), # detects dimensions and type\n", 64 | "]\n", 65 | "# create the model, optionally saving the tokenizer and formatter along with it\n", 66 | "model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, formatter=formatter, tokenizer=tokenizer)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "59f6b3bd-c570-48c1-a3f5-b19cc56ba521", 72 | "metadata": {}, 73 | "source": [ 74 | "## Create the trainer and train the model" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "5c94eca6-980a-40e3-acb6-98eba17eaf8e", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "trainer = MultiTaskTrainer(\n", 85 | " model=model,\n", 86 | " tokenizer=tokenizer,\n", 87 | " train_data=data[:, 
\"train\"],\n", 88 | " eval_data=data[[\"one\"], \"test\"], # using a list as first key to keep this as a dict\n", 89 | " eval_heads={\"one\": [\"tweet_label\"]}, # limit evaluation to one classification task\n", 90 | " compute_metrics=compute_classification_metrics,\n", 91 | " args=TrainingArguments(output_dir=\"../output\", evaluation_strategy=\"epoch\",save_steps=5000),\n", 92 | ")\n", 93 | "trainer.train()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "id": "78022963-3594-46cb-a3f5-e305c42146e9", 99 | "metadata": {}, 100 | "source": [ 101 | "## Example inference" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": "2d7af27a-533b-4c6b-853b-7495a5ccc5e9", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "model.predict({\"text\": \"this is nice\"}) # single sample inference" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "2ee13f85-49e9-4b00-81ee-94f47088db1e", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "model.predict(both_tasks) # dataframe inference" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "c11ebc22-92f7-4b4c-b3a5-1e20680ddf50", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "model.predict(data[\"one\", \"test\"]) # dataset inference" 132 | ] 133 | } 134 | ], 135 | "metadata": { 136 | "kernelspec": { 137 | "display_name": "Python 3 (ipykernel)", 138 | "language": "python", 139 | "name": "python3" 140 | }, 141 | "language_info": { 142 | "codemirror_mode": { 143 | "name": "ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.8.10" 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 5 156 | } 157 | -------------------------------------------------------------------------------- /examples/glue.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "23eedacc-0041-41a6-b9a2-15a6e74a495b", 6 | "metadata": {}, 7 | "source": [ 8 | "# GLUE training\n", 9 | "This notebook shows how to fine-tune a model on *all* glue tasks simultaneously, including evaluation metrics." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "id": "01bce55a-54f1-453f-b50a-06eae814fd5b", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import sys\n", 20 | "\n", 21 | "sys.path.append(\"..\") # ensure we can run examples as-is in the package's poetry env" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "b15ab1f8-4d91-40e1-b6ee-6201ff8d91b0", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import numpy as np\n", 32 | "import torch\n", 33 | "import transformers\n", 34 | "from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset, load_metric\n", 35 | "from transformers import AutoConfig, AutoModel, AutoTokenizer, TrainingArguments\n", 36 | "\n", 37 | "from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer\n", 38 | "from grouphug.config import logger\n", 39 | "\n", 40 | "torch.cuda.is_available()" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "4d2ddc31-8083-4b1f-a79a-afe797192174", 46 | "metadata": {}, 47 | "source": [ 48 | "## Define which model to fine-tune" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "d83e01fe-cd49-4429-8aa5-29c0c0759c91", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# transformers.logging.set_verbosity_info() # uncomment for more logging\n", 59 | "base_model = \"HannahRoseKirk/Hatemoji\" # a deberta model" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "4945ae0f-bdce-4470-a522-72db654eeeda", 65 | "metadata": {}, 66 | "source": [ 67 | "## Load data" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "id": "7010ac53-5173-441f-ac55-8abb1fd0920a", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "task_to_keys = {\n", 78 | " \"cola\": (\"sentence\", None), # is this sentence grammatical?\n", 79 | " \"mnli\": (\"premise\", \"hypothesis\"), # label as neutral, entailment, contradiction\n", 80 | " \"mrpc\": (\"sentence1\", \"sentence2\"), # whether the sentences in the pair are semantically equivalent.\n", 81 | " \"qnli\": (\"question\", \"sentence\"), # whether the context sentence contains the answer to the question\n", 82 | " \"qqp\": (\"question1\", \"question2\"), # determine whether a pair of questions are semantically equivalent.\n", 83 | " \"rte\": (\"sentence1\", \"sentence2\"), # similar to mnli\n", 84 | " \"sst2\": (\"sentence\", None), # sentiment\n", 85 | " \"stsb\": (\"sentence1\", \"sentence2\"), # similarity score from 0 to 5.\n", 86 | " \"wnli\": (\"sentence1\", \"sentence2\"), # entailment\n", 87 | "}\n", 88 | "tasks = list(task_to_keys.keys())\n", 89 | "\n", 90 | "\n", 91 | "def load_and_rename(task, reduce_size_target=None):\n", 92 | " k1, k2 = task_to_keys[task]\n", 93 | " dataset = load_dataset(\"glue\", task).rename_column(\"label\", task)\n", 94 | "\n", 95 | " if k2 is not None:\n", 96 | " dataset = dataset.rename_column(k1, \"text1\").rename_column(k2, \"text2\")\n", 97 | " else:\n", 98 | " dataset = dataset.rename_column(k1, \"text\")\n", 99 | "\n", 100 | " dataset = DatasetDict(\n", 101 | " {\n", 102 | " \"train\": dataset[\"train\"],\n", 103 | " \"validation\": concatenate_datasets([v for k, v in dataset.items() if k.startswith(\"validation\")]),\n", 104 | " \"test\": concatenate_datasets([v for k, v in dataset.items() if k.startswith(\"test\")]),\n", 105 | " }\n", 106 | " )\n", 107 | " test_labels = 
dataset[\"test\"].unique(task)\n", 108 | " if reduce_size_target:\n", 109 | " for k, target_size in reduce_size_target.items():\n", 110 | " dataset[k] = Dataset.from_dict(dataset[k][:target_size])\n", 111 | " logger.debug(f\"Reducing sizes to {len(dataset[k])} for {k}\")\n", 112 | " return dataset" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "35d4a285-bf54-4d1d-8611-8215f3a6e652", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "target_size = {\"train\": 2000, \"validation\": 100} # just to keep it quick\n", 123 | "glue_data = {task: load_and_rename(task, target_size) for task in tasks}" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "bfd65faa-1899-4050-b463-40f6f519a38f", 129 | "metadata": {}, 130 | "source": [ 131 | "## Define tokenizer and preprocess data" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "33ee7e9a-ea2b-45b7-a15f-531964c66c43", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "tokenizer = AutoTokenizer.from_pretrained(base_model)\n", 142 | "fmt = DatasetFormatter().tokenize(max_length=512).tokenize((\"text1\", \"text2\"), max_length=512)\n", 143 | "data = fmt.apply(glue_data, tokenizer=tokenizer, splits=[\"train\", \"validation\"])" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "99009c76-4403-4bb7-8489-eb200f46ee4f", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "head_configs = [ClassificationHeadConfig.from_data(data, task, detached=False, ignore_index=-1) for task in tasks]\n", 154 | "# We fine-tune directly on masked inputs. This works well in practice, but may not work well when single words are very important like Cola.\n", 155 | "head_configs += [LMHeadConfig(weight=0.25)]" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "e77046b9-985d-40fe-96ef-7a0f7b09d915", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, formatter=fmt, tokenizer=tokenizer)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "57802e28-f4ae-4f83-a644-975156b127da", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "output_dir = \"../output/demo\"\n", 176 | "training_args = TrainingArguments(\n", 177 | " output_dir=output_dir,\n", 178 | " num_train_epochs=2,\n", 179 | " per_device_train_batch_size=4,\n", 180 | " per_device_eval_batch_size=4,\n", 181 | " gradient_accumulation_steps=8,\n", 182 | " save_total_limit=1,\n", 183 | " evaluation_strategy=\"epoch\",\n", 184 | ")" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "id": "af19e216-0e85-4341-9274-879268139cf7", 190 | "metadata": {}, 191 | "source": [ 192 | "## Define metrics function\n", 193 | "Note additional arguments" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "id": "5edb475a-0581-4c56-823b-521ac39282dc", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "def compute_metrics(eval_preds, dataset_name, heads):\n", 204 | " metrics_f = load_metric(\"glue\", dataset_name)\n", 205 | " logits, labels = eval_preds\n", 206 | " if dataset_name == \"stsb\":\n", 207 | " return metrics_f.compute(predictions=logits, references=labels)\n", 208 | " predictions = np.argmax(logits, axis=-1)\n", 209 | " return metrics_f.compute(predictions=predictions, references=labels)" 
210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "id": "7a224e23-12c4-4e1b-a9f2-ed8567fc1517", 215 | "metadata": {}, 216 | "source": [ 217 | "## Train the model" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "id": "0da0bcb8-e1d9-4cba-8967-87aadd04584f", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "trainer = MultiTaskTrainer(\n", 228 | " model=model,\n", 229 | " tokenizer=tokenizer,\n", 230 | " args=training_args,\n", 231 | " train_data=data[:, \"train\"],\n", 232 | " eval_data=data[:, \"validation\"],\n", 233 | " eval_heads={t: [t] for t in tasks}, # for dataset [key], run heads [value]\n", 234 | " compute_metrics=compute_metrics,\n", 235 | ")" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "bb6a6b56-bf23-42c9-bb35-d5b93567699d", 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "train_res = trainer.train()" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "id": "15bfa1be-178a-473a-b541-5b46f08c95db", 251 | "metadata": {}, 252 | "source": [ 253 | "## The model predict function takes dicts or entire datasets and preprocesses, infers, and maps back to labels" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "7a23dea5-4597-45e5-bf5d-42096dc1441e", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "model.predict({\"text\": \"The quick brown fox jumped over the lazy dog!\"})[\"cola_predicted_label\"]" 264 | ] 265 | } 266 | ], 267 | "metadata": { 268 | "kernelspec": { 269 | "display_name": "Python 3 (ipykernel)", 270 | "language": "python", 271 | "name": "python3" 272 | }, 273 | "language_info": { 274 | "codemirror_mode": { 275 | "name": "ipython", 276 | "version": 3 277 | }, 278 | "file_extension": ".py", 279 | "mimetype": "text/x-python", 280 | "name": "python", 281 | "nbconvert_exporter": "python", 282 | "pygments_lexer": "ipython3", 283 | "version": "3.8.10" 284 | } 285 | }, 286 | "nbformat": 4, 287 | "nbformat_minor": 5 288 | } 289 | -------------------------------------------------------------------------------- /examples/lm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a1b50547-7eca-48e4-b73d-bb046ac882aa", 6 | "metadata": {}, 7 | "source": [ 8 | "# This demo tests the effect of different language modelling heads" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "96963a2d-60ce-4326-b1e6-18c4c64cba1b", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "\n", 20 | "sys.path.append(\"..\") # ensure we can run examples as-is in the package's poetry env" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "fc4f7a95-1b1d-476d-a7a5-420fa28926c7", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import pandas as pd\n", 31 | "import transformers\n", 32 | "from datasets import load_dataset\n", 33 | "from transformers import AutoTokenizer, TrainingArguments\n", 34 | "import torch\n", 35 | "from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer\n", 36 | "\n", 37 | "from utils import compute_classification_metrics" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "d4dbcdcf-3942-4f6e-a61b-21e8ad428f47", 43 | "metadata": {}, 44 | "source": [ 45 | "## A basic modelling task similar to 
the readme example" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "106748fe-e755-4017-870e-43c168f21c4f", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "tweet_emotion = load_dataset(\"tweet_eval\",\"emotion\").rename_column(\"label\", \"emotion\")\n", 56 | "\n", 57 | "base_model = \"distilbert-base-uncased\"\n", 58 | "tokenizer = AutoTokenizer.from_pretrained(base_model)\n", 59 | "\n", 60 | "formatter = DatasetFormatter().tokenize()\n", 61 | "data = formatter.apply(tweet_emotion, tokenizer=tokenizer)\n", 62 | "\n", 63 | "head_configs = [ClassificationHeadConfig.from_data(data, \"emotion\", classifier_hidden_size=32)]" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "d034c437-2b9c-4c94-83bf-813c56fcb822", 69 | "metadata": {}, 70 | "source": [ 71 | "## Adding different LM heads to a classification task and training" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "bcf39639-bc5c-411b-9a79-364d48b2fd0b", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "test_lm_heads = {\n", 82 | " \"none\": [],\n", 83 | " \"mlm\": [LMHeadConfig(weight=0.2)],\n", 84 | " \"mtd\": [LMHeadConfig(masked_token_detection=True,weight=0.2)],\n", 85 | " \"mlm+mtd\": [LMHeadConfig(masked_language_modelling=True, masked_token_detection=True,weight=0.2)],\n", 86 | "}\n", 87 | "results = {}\n", 88 | "training_args = TrainingArguments(\n", 89 | " output_dir=\"../output\",\n", 90 | " evaluation_strategy=\"epoch\",\n", 91 | " num_train_epochs=10,\n", 92 | " save_strategy=\"no\",\n", 93 | ")\n", 94 | "for test_key, lm_head in test_lm_heads.items():\n", 95 | " model = AutoMultiTaskModel.from_pretrained(\n", 96 | " base_model, head_configs + lm_head, formatter=formatter, tokenizer=tokenizer\n", 97 | " )\n", 98 | " trainer = MultiTaskTrainer(\n", 99 | " model=model,\n", 100 | " tokenizer=tokenizer,\n", 101 | " train_data=data[:, \"train\"],\n", 102 | " eval_data=data[:, \"test\"],\n", 103 | " eval_heads=[\"emotion\"],\n", 104 | " compute_metrics=compute_classification_metrics,\n", 105 | " args=training_args,\n", 106 | " )\n", 107 | " trainer.train()\n", 108 | " results[test_key] = pd.DataFrame(trainer.state.log_history)\n", 109 | " model = None\n", 110 | " trainer = None\n", 111 | " torch.cuda.empty_cache()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "d3c86d24-43dc-43d0-b6da-40c6bdf3e0bd", 117 | "metadata": {}, 118 | "source": [ 119 | "## Inspecting results" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "3ae5de52-1244-4f0a-a6cf-27614e27092e", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "import matplotlib.pyplot as plt\n", 130 | "plt.figure(figsize=(20,5))\n", 131 | "for i, k in enumerate(['loss','eval_loss','eval_emotion_f1','eval_emotion_matthews_correlation']):\n", 132 | " for test_name, df in results.items():\n", 133 | " ax = plt.subplot(1,4,i+1)\n", 134 | " df.dropna(subset=k).plot(x='step',y=k,ax=ax)\n", 135 | " plt.legend(results.keys())\n", 136 | " plt.title(k)" 137 | ] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3 (ipykernel)", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 
| "version": "3.8.10" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 5 161 | } 162 | -------------------------------------------------------------------------------- /examples/neox.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a1b50547-7eca-48e4-b73d-bb046ac882aa", 6 | "metadata": {}, 7 | "source": [ 8 | "# This demo shows how to train GPT-NeoX from scratch" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "96963a2d-60ce-4326-b1e6-18c4c64cba1b", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "\n", 20 | "sys.path.append(\"..\") # ensure we can run examples as-is in the package's poetry env" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "fc4f7a95-1b1d-476d-a7a5-420fa28926c7", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import pandas as pd\n", 31 | "import torch\n", 32 | "import transformers\n", 33 | "from datasets import load_dataset\n", 34 | "from transformers import (\n", 35 | " AutoConfig,\n", 36 | " AutoTokenizer,\n", 37 | " DebertaPreTrainedModel,\n", 38 | " GPTNeoXConfig,\n", 39 | " GPTNeoXPreTrainedModel,\n", 40 | " TrainingArguments,\n", 41 | ")\n", 42 | "from utils import compute_classification_metrics\n", 43 | "\n", 44 | "from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "7fb7da8e-7208-4ae8-95aa-e5c506965ca4", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "config = GPTNeoXConfig(\n", 55 | " hidden_size=768, intermediate_size=3072, num_attention_heads=12, num_hidden_layers=12, is_decoder=True\n", 56 | ")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "149b85fa-2b54-41ba-aaaa-a383a533cce7", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "tweet_emotion = load_dataset(\"tweet_eval\", \"emotion\").rename_column(\"label\", \"emotion\")\n", 67 | "wiki_data = load_dataset(\"wikitext\", \"wikitext-2-v1\")\n", 68 | "\n", 69 | "tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n", 70 | "formatter = DatasetFormatter().tokenize()\n", 71 | "data = formatter.apply(tweet_emotion, tokenizer=tokenizer)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "787e3d8f-11b1-4b73-ba37-aa5f5ece2258", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "head_configs = [\n", 82 | " LMHeadConfig(causal_language_modelling=True),\n", 83 | " ClassificationHeadConfig.from_data(data, \"emotion\", classifier_hidden_size=32),\n", 84 | "]\n", 85 | "model = AutoMultiTaskModel.from_config(config, head_configs=head_configs, tokenizer=tokenizer, formatter=formatter)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "bcf39639-bc5c-411b-9a79-364d48b2fd0b", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "training_args = TrainingArguments(\n", 96 | " output_dir=\"../output\",\n", 97 | " evaluation_strategy=\"epoch\",\n", 98 | " num_train_epochs=10,\n", 99 | " save_strategy=\"no\",\n", 100 | ")\n", 101 | "trainer = MultiTaskTrainer(\n", 102 | " model=model,\n", 103 | " tokenizer=tokenizer,\n", 104 | " train_data=data[:, \"train\"],\n", 105 | " eval_data=data[:, \"test\"],\n", 106 | " eval_heads=[\"emotion\"],\n", 107 
| " compute_metrics=compute_classification_metrics,\n", 108 | " args=training_args,\n", 109 | ")\n", 110 | "trainer.train()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "689b6eba-0d67-4f59-b828-061d46c0118b", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [] 120 | } 121 | ], 122 | "metadata": { 123 | "kernelspec": { 124 | "display_name": "Python 3 (ipykernel)", 125 | "language": "python", 126 | "name": "python3" 127 | }, 128 | "language_info": { 129 | "codemirror_mode": { 130 | "name": "ipython", 131 | "version": 3 132 | }, 133 | "file_extension": ".py", 134 | "mimetype": "text/x-python", 135 | "name": "python", 136 | "nbconvert_exporter": "python", 137 | "pygments_lexer": "ipython3", 138 | "version": "3.8.10" 139 | } 140 | }, 141 | "nbformat": 4, 142 | "nbformat_minor": 5 143 | } 144 | -------------------------------------------------------------------------------- /examples/sentiment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "23eedacc-0041-41a6-b9a2-15a6e74a495b", 6 | "metadata": {}, 7 | "source": [ 8 | "# Sentiment training\n", 9 | "This notebook shows how to fine-tune a model on a few different sentiment datasets." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "id": "a4bddf81-1d84-4874-9374-8fef35fdf10a", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import sys\n", 20 | "\n", 21 | "sys.path.append(\"..\") # ensure we can run examples as-is in the package's poetry env" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "10cdd4dc-1e3d-47d4-872a-973d7606a2bf", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import torch\n", 32 | "import transformers\n", 33 | "from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset, load_metric\n", 34 | "from transformers import AutoConfig, AutoModel, AutoTokenizer, TrainingArguments\n", 35 | "\n", 36 | "from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, DatasetFormatter, LMHeadConfig, MultiTaskTrainer\n", 37 | "from grouphug.config import logger\n", 38 | "\n", 39 | "torch.cuda.is_available()" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "4d2ddc31-8083-4b1f-a79a-afe797192174", 45 | "metadata": {}, 46 | "source": [ 47 | "## Define which model to fine-tune" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "d83e01fe-cd49-4429-8aa5-29c0c0759c91", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# transformers.logging.set_verbosity_info() # uncomment for more logging\n", 58 | "base_model = \"prajjwal1/bert-tiny\"" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "4945ae0f-bdce-4470-a522-72db654eeeda", 64 | "metadata": {}, 65 | "source": [ 66 | "## Load data" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "3c1cea71-fc50-474b-a80d-3e88d0d694fd", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "gp_data = load_dataset(\"IsaacBot/GP-Sentiment\").rename_column(\"content\", \"text\")\n", 77 | "imdb_data = load_dataset(\"imdb\").rename_column(\"label\", \"negpos\")" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "id": "668aa27c-edbd-4d42-85e8-a49c228e0b4f", 83 | "metadata": {}, 84 | "source": [ 85 | "## Define tokenizer and preprocess data" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | 
"execution_count": null, 91 | "id": "8e4a68e4-06ab-4035-9705-7c20e6e0e86b", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "tokenizer = AutoTokenizer.from_pretrained(base_model)\n", 96 | "fmt = DatasetFormatter().tokenize(max_length=512).encode('score')\n", 97 | "data = fmt.apply({\"gp\": gp_data, \"imdb\": imdb_data}, tokenizer=tokenizer, splits=[\"train\", \"test\"])" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "bfd65faa-1899-4050-b463-40f6f519a38f", 103 | "metadata": {}, 104 | "source": [ 105 | "## Define model" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "99009c76-4403-4bb7-8489-eb200f46ee4f", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "head_configs = (\n", 116 | " [ # as labels are different, we create different classifier heads for each task, but the base model is shared\n", 117 | " ClassificationHeadConfig.from_data(data, \"score\", classifier_hidden_size=50),\n", 118 | " ClassificationHeadConfig.from_data(data, \"negpos\", classifier_hidden_size=20, weight=2),\n", 119 | " ]\n", 120 | ")" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "e77046b9-985d-40fe-96ef-7a0f7b09d915", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, formatter=fmt, tokenizer=tokenizer)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "7a224e23-12c4-4e1b-a9f2-ed8567fc1517", 136 | "metadata": {}, 137 | "source": [ 138 | "## Train the model" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "0da0bcb8-e1d9-4cba-8967-87aadd04584f", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "output_dir = \"../output/demo\"\n", 149 | "training_args = TrainingArguments(\n", 150 | " output_dir=output_dir,\n", 151 | " num_train_epochs=2,\n", 152 | " per_device_train_batch_size=4,\n", 153 | " per_device_eval_batch_size=4,\n", 154 | " gradient_accumulation_steps=8,\n", 155 | " save_total_limit=1,\n", 156 | " evaluation_strategy=\"epoch\",\n", 157 | ")\n", 158 | "\n", 159 | "trainer = MultiTaskTrainer(\n", 160 | " model=model,\n", 161 | " tokenizer=tokenizer,\n", 162 | " args=training_args,\n", 163 | " train_data=data[:, \"train\"],\n", 164 | " eval_data=data[:, \"test\"],\n", 165 | ")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "20bc551e-e7f4-4c47-aa18-2e047ac8fc76", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "train_res = trainer.train()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "15bfa1be-178a-473a-b541-5b46f08c95db", 181 | "metadata": {}, 182 | "source": [ 183 | "## The model predict function takes dicts or entire datasets and preprocesses, infers, and maps back to labels" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "id": "7a23dea5-4597-45e5-bf5d-42096dc1441e", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "model.predict({\"text\": \"This will predict both things at once, giving probabilities, labels, and predicted ids. 
Awesome!\"})" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "Python 3 (ipykernel)", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.8.10" 214 | } 215 | }, 216 | "nbformat": 4, 217 | "nbformat_minor": 5 218 | } 219 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | # This is a generic compute_metrics function that will give a range of metrics for classification tasks 2 | 3 | import evaluate 4 | import numpy as np 5 | 6 | from grouphug import ClassificationHead 7 | from grouphug.config import IGNORE_INDEX 8 | 9 | metrics = {k: evaluate.load(k) for k in ["accuracy", "f1", "recall", "precision", "matthews_correlation"]} 10 | 11 | 12 | def compute_classification_metrics(eval_preds, dataset_name, heads): 13 | all_logits, all_labels = eval_preds 14 | if not isinstance(all_logits, tuple): 15 | all_logits = (all_logits,) 16 | all_labels = (all_labels,) 17 | results = {} 18 | 19 | for logits, labels, hc in zip(all_logits, all_labels, heads): 20 | labels_1d = labels.ravel() 21 | mask = labels_1d != hc.ignore_index 22 | labels_1d = labels_1d[mask] 23 | if hc.problem_type == ClassificationHead.MULTI: 24 | predictions = logits > 0 25 | predictions_1d = predictions.ravel()[mask] 26 | exact_match = ((predictions == labels) | (labels == IGNORE_INDEX)).all(axis=-1) 27 | # entire prediction is correct 28 | results[f"{hc.name}_subset_accuracy"] = exact_match.sum() / len(exact_match) 29 | else: 30 | predictions_1d = np.argmax(logits, axis=-1).ravel()[mask] 31 | for k, f in metrics.items(): 32 | try: 33 | kwargs = {"average": "weighted"} if k in ["f1", "recall", "precision"] else {} 34 | for mk, mv in f.compute(predictions=predictions_1d, references=labels_1d, **kwargs).items(): 35 | results[f"{hc.name}_{mk}"] = mv 36 | except Exception as e: 37 | print(f"metric {k} on dataset {dataset_name} head {hc.name} failed: {e}") 38 | return results 39 | -------------------------------------------------------------------------------- /grouphug/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_collection import DatasetCollection 2 | from .dataset_formatter import DatasetFormatter 3 | from .heads import ClassificationHead, ClassificationHeadConfig, LMHeadConfig 4 | from .model import AutoMultiTaskModel 5 | from .trainer import MultiTaskTrainer 6 | -------------------------------------------------------------------------------- /grouphug/collator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from transformers import default_data_collator 5 | 6 | from grouphug.config import ( 7 | CLM_LABELS_VAR, 8 | IGNORE_INDEX, 9 | INPUT_IDS_VAR, 10 | MASKED_INPUT_IDS_VAR, 11 | MLM_LABELS_VAR, 12 | MTD_LABELS_VAR, 13 | MTD_TOKEN_RANDOM, 14 | MTD_TOKEN_SIMILARITY, 15 | logger, 16 | ) 17 | from grouphug.model import _BaseMultiTaskModel 18 | 19 | 20 | class AutoCollator: 21 | def __init__(self, model: _BaseMultiTaskModel, tokenizer): 22 | """Collates inputs and handles masking when needed""" 23 | self.model = model 24 
| self.mlm_head = model.get_mlm_head()
25 | self.mlm_active = self.mlm_head is not None
26 | if self.mlm_active:
27 | if tokenizer is None:
28 | raise ValueError("Pass a tokenizer to MultiTaskTrainer for masked language modelling")
29 | if self.mlm_head.causal_language_modelling:
30 | if not model.config.is_decoder:
31 | logger.warning("Model config does not set is_decoder=True, which is usually expected for causal language modelling")
32 | elif self.mlm_head.mask_probability > 0.0 and tokenizer.mask_token is None:
33 | raise AttributeError("Tokenizer has no mask token, which is necessary when mask_probability > 0.0")
34 | 
35 | self.input_prefixes = model.input_prefixes()
36 | self.model_vars = model.vars()
37 | self.tokenizer = tokenizer
38 | 
39 | def update_mlm_active(self): # cached check
40 | self.mlm_active = self.mlm_head in self.model.get_active_heads()
41 | return self.mlm_active
42 | 
43 | def _maybe_pad(self, columns, return_tensors) -> Dict[str, torch.Tensor]:
44 | """Determines if the inputs still need to be padded and pads if needed"""
45 | if INPUT_IDS_VAR in columns[0] and not all(
46 | len(x[INPUT_IDS_VAR]) == len(columns[0][INPUT_IDS_VAR]) for x in columns
47 | ):
48 | if not self.tokenizer:
49 | raise ValueError(
50 | "Inputs are of different lengths, and no tokenizer passed to trainer to dynamically pad."
51 | )
52 | if self.input_prefixes != {""}: # Too messy to support for now.
53 | raise ValueError(
54 | "Inputs are of different lengths, and multiple text inputs expected. AutoCollator does not support this, pad to max length in formatting instead."
55 | )
56 | if self.tokenizer.pad_token is None:
57 | logger.warning("Setting tokenizer 'pad_token' to 'eos_token' as we really need to pad now.")
58 | self.tokenizer.pad_token = self.tokenizer.eos_token
59 | self.model.config.pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
60 | 
61 | # pad collates as well. DOES NOT TRUNCATE!
62 | return self.tokenizer.pad(columns, return_tensors=return_tensors)
63 | else:
64 | return default_data_collator(columns, return_tensors=return_tensors)
65 | 
66 | def generate_replacement_tokens(self, token_ids: torch.Tensor, replace_indices: torch.Tensor, strategy: str):
67 | """Generates replacement tokens for MLM/MTD.
68 | Default just takes random token ids, as used in BERT.
69 | This appears to be 'too easy' for MTD, so consider overwriting this with a generator based version."""
70 | tokens_to_replace = token_ids[replace_indices]
71 | if strategy == MTD_TOKEN_RANDOM:
72 | return torch.randint(len(self.tokenizer), tokens_to_replace.shape, dtype=torch.long)
73 | elif strategy == MTD_TOKEN_SIMILARITY: # cosine distance ish
74 | unique_ids_to_replace, ixs = torch.unique(tokens_to_replace, return_inverse=True)
75 | similarity = self.model.token_similarity(unique_ids_to_replace).to(token_ids.device) # to cpu
76 | similarity -= torch.mean(similarity, dim=1, keepdim=True)
77 | similarity /= torch.std(similarity, dim=1, keepdim=True)
78 | # sample replacement probabilities from the standardized similarity
79 | similarity[similarity < 3] = 0 # only take top ~0.2% of tokens
80 | for i in range(similarity.size(0)): # TODO: scatter_ ?
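# note: after standardizing each row above, keeping only entries > 3 sigma leaves roughly the
# top 0.2% most similar candidate tokens; the loop below then gives each token's own id a tiny
# weight so that torch.multinomial never sees an all-zero row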
81 | similarity[i, unique_ids_to_replace[i]] = 1e-6 # ensures sum is never 0
82 | replacement_tokens = torch.multinomial(similarity[ixs, :], 1)[:, 0]
83 | return replacement_tokens
84 | else:
85 | raise ValueError(f"Invalid replacement token strategy {strategy}")
86 | 
87 | def torch_mask_tokens(self, original_inputs: torch.Tensor, special_tokens_mask: torch.Tensor):
88 | """
89 | Prepare masked tokens inputs/labels
90 | Default for only masked language modeling: 80% MASK, 10% replaced, 10% original.
91 | Default for masked token detection: 100% replaced.
92 | """
93 | masked_inputs = original_inputs.clone()
94 | labels = original_inputs.clone()
95 | 
96 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_head.mlm_probability`)
97 | probability_matrix = torch.full(labels.shape, self.mlm_head.mlm_probability)
98 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) # mlm_probability in non-special tokens
99 | masked_indices = torch.bernoulli(probability_matrix).bool()
100 | 
101 | # (MLM:80%, MTD:0%) of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
102 | indices_replaced = None
103 | if self.mlm_head.mask_probability > 0.0: # may not even have a mask token, so avoid this
104 | indices_replaced = (
105 | torch.bernoulli(torch.full(labels.shape, self.mlm_head.mask_probability)).bool() & masked_indices
106 | )
107 | masked_inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
108 | 
109 | # (MLM: 10%, MTD: 100%) of the time, we replace masked input tokens with random word
110 | indices_random = masked_indices.clone() # clone, since &= below would otherwise also mutate masked_indices
111 | if self.mlm_head.generated_token_probability < 1.0:
112 | indices_random &= torch.bernoulli(
113 | torch.full(labels.shape, self.mlm_head.generated_token_probability)
114 | ).bool()
115 | if indices_replaced is not None:
116 | indices_random &= ~indices_replaced
117 | masked_inputs[indices_random] = self.generate_replacement_tokens(
118 | masked_inputs, indices_random, self.mlm_head.mtd_strategy
119 | )
120 | 
121 | # The rest of the time (MLM: 10%, MTD: 0%) we keep the masked input tokens unchanged
122 | return masked_inputs, labels, masked_indices
123 | 
124 | def collate_for_mlm(self, batch) -> Dict:
125 | original_inputs = batch[INPUT_IDS_VAR]
126 | special_tokens_mask = batch.pop("special_tokens_mask", None)
127 | if special_tokens_mask is None:
128 | special_tokens_mask = torch.tensor(
129 | [
130 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
131 | for val in original_inputs.tolist()
132 | ],
133 | dtype=torch.bool,
134 | )
135 | else:
136 | special_tokens_mask = special_tokens_mask.bool()
137 | 
138 | masked_inputs, mlm_labels, masked_indices = self.torch_mask_tokens(original_inputs, special_tokens_mask)
139 | 
140 | if self.mlm_head.masked_language_modelling:
141 | if not self.mlm_head.predict_all_tokens:
142 | mlm_labels.masked_fill_(~masked_indices, IGNORE_INDEX) # only compute loss on masked tokens
143 | else:
144 | mlm_labels.masked_fill_(special_tokens_mask, IGNORE_INDEX)
145 | batch[MLM_LABELS_VAR] = mlm_labels
146 | 
147 | if self.mlm_head.masked_token_detection:
148 | batch[MTD_LABELS_VAR] = masked_indices.long()
149 | batch[MTD_LABELS_VAR].masked_fill_(special_tokens_mask, IGNORE_INDEX)
150 | 
151 | if self.mlm_head.separate_embedding:
152 | batch[MASKED_INPUT_IDS_VAR] = masked_inputs
153 | else:
154 | batch[INPUT_IDS_VAR] = masked_inputs
155 | 
156 | return batch
157 | 
158 | def __call__(self, features, return_tensors=None):
159 | return_tensors = return_tensors or
"pt" 160 | features_to_collate = [{k: v for k, v in f.items() if k in self.model_vars} for f in features] 161 | collated_features = self._maybe_pad(features_to_collate, return_tensors) 162 | if self.mlm_active: 163 | if not self.mlm_head.causal_language_modelling: 164 | collated_features = self.collate_for_mlm(collated_features) 165 | else: # basically just signals mlm_active to the head 166 | collated_features[CLM_LABELS_VAR] = collated_features[INPUT_IDS_VAR] 167 | return collated_features 168 | -------------------------------------------------------------------------------- /grouphug/config.py: -------------------------------------------------------------------------------- 1 | from transformers.utils.logging import get_logger, set_verbosity 2 | 3 | logger = get_logger("transformers") 4 | 5 | # conventions on feature names 6 | TEXT_VAR = "text" # text columns are called TEXT_VAR 7 | INPUT_IDS_VAR = "input_ids" # token ids are called INPUT_IDS_VAR 8 | INPUT_EMBEDDING_VAR = "embedding" # embedding vars are called INPUT_EMBEDDING_VAR 9 | MASKED_PREFIX = "masked_" 10 | MASKED_INPUT_IDS_VAR = f"{MASKED_PREFIX}{INPUT_IDS_VAR}" 11 | MLM_LABELS_VAR = "mlm_labels" 12 | MTD_LABELS_VAR = "mtd_labels" 13 | CLM_LABELS_VAR = "clm_labels" 14 | 15 | MTD_TOKEN_SIMILARITY = "token_similarity" 16 | MTD_TOKEN_RANDOM = "random" 17 | 18 | # essentially what _pad cares about 19 | TOKENIZER_VARS = [INPUT_IDS_VAR, "attention_mask", "token_type_ids", "special_tokens_mask"] 20 | 21 | # for labels and losses 22 | IGNORE_INDEX = -100 23 | 24 | # for saving and loading models 25 | HEADS_FILE_NAME = "head_configs.json" 26 | FORMATTER_FILE_NAME = "formatter.json" 27 | 28 | # for random splits 29 | DEFAULT_SEED = 42 30 | -------------------------------------------------------------------------------- /grouphug/dataset_collection.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Dict, Iterable, List, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from datasets import Dataset, DatasetDict 7 | 8 | from grouphug.config import DEFAULT_SEED 9 | 10 | 11 | def is_iterable(arg): 12 | return isinstance(arg, (collections.abc.Iterable, slice)) and not isinstance(arg, (str, bytes)) 13 | 14 | 15 | def dataset_value_counts(datasets: List[Dataset], column): 16 | """Determines number of records and value counts of a column""" 17 | if not all(column in ds.features for ds in datasets): 18 | raise ValueError(f"Column {column} not in all datasets") 19 | if not datasets: 20 | raise ValueError(f"Column {column} not in any dataset") 21 | if is_iterable(datasets[0][0][column]): 22 | rows = sum([ds[column] for ds in datasets], []) # ragged arrays/dim problems 23 | values = np.concatenate(rows) 24 | else: 25 | rows = np.concatenate([ds[column] for ds in datasets]) 26 | values = rows 27 | n = len(rows) 28 | if None in values: 29 | raise ValueError(f"Column {column} contains None values, which are not supported") 30 | unique, counts = np.unique(values, return_counts=True) 31 | return n, dict(zip(unique, counts)) 32 | 33 | 34 | class DatasetCollection(collections.UserDict): 35 | """Represents a map of name -> DatasetDict 36 | 37 | Args: 38 | data: dictionary of dataset name to DatasetDict/Dataset/DataFrame. All are converted to DatasetDict, either split or as 100% training data 39 | test_size: for entries that are not a DatasetDict, how to split into train or test size. Can be specified overall or by key. Missing entries are not split. 
40 | shuffle: whether to shuffle on splitting 41 | """ 42 | 43 | def __init__( 44 | self, 45 | data: Dict[str, Union[pd.DataFrame, Dataset, DatasetDict]], 46 | test_size: Union[float, Dict[str, float]] = 0.05, 47 | shuffle=True, 48 | seed=DEFAULT_SEED, 49 | ): 50 | if not isinstance(test_size, dict): 51 | test_size = {k: test_size for k in data} 52 | data = data.copy() 53 | for k, dataset in data.items(): 54 | if isinstance(dataset, pd.DataFrame): 55 | dataset = Dataset.from_pandas(dataset) 56 | if isinstance(dataset, Dataset): 57 | ds_test_size = test_size.get(k, None) 58 | if ds_test_size: 59 | dataset = dataset.train_test_split(test_size=ds_test_size, shuffle=shuffle, seed=seed) 60 | else: 61 | dataset = DatasetDict({"train": dataset}) 62 | elif isinstance(dataset, (DatasetDict, dict)): 63 | dataset = DatasetDict(dataset) # shallow copy 64 | else: 65 | raise ValueError(f"Unexpected value in key {k}: {dataset.__class__.__name__} not supported.") 66 | data[k] = dataset 67 | super().__init__(data) 68 | 69 | def __getitem__(self, key) -> Union["DatasetCollection", DatasetDict, Dataset]: 70 | """Supports data[key(s), split(s)] with option to use : for all.""" 71 | if isinstance(key, tuple) and len(key) == 2: # dc[:,'train'], dc['imdb','test'] 72 | if not is_iterable(key[0]) and not is_iterable(key[1]): # dc['imdb','test'] -> return Dataset 73 | return self[key[0]][key[1]] 74 | ds_key = [key[0]] if not is_iterable(key[0]) else key[0] # for dc['set',:] 75 | split_key = [key[1]] if not is_iterable(key[1]) else key[1] 76 | selected_dicts = self[ds_key] 77 | filtered_dsc = { 78 | k: DatasetDict({dk: ds for dk, ds in dsd.items() if split_key == slice(None) or dk in split_key}) 79 | for k, dsd in selected_dicts.items() 80 | } 81 | filtered_dsc = {k: v for k, v in filtered_dsc.items() if v} 82 | if not filtered_dsc: 83 | raise KeyError(f"Key {key[1]} is not in any of the {len(selected_dicts)} DatasetDict '{ds_key}'") 84 | return DatasetCollection(filtered_dsc) 85 | elif key == slice(None): # dc[:] = all 86 | return self 87 | elif isinstance(key, list): # dc[ ['imdb','yelp'] ] 88 | if len(set(key)) != len(key): 89 | raise ValueError(f"Key '{key}' can not contain duplicates") 90 | return DatasetCollection({k: self.get(k) for k in key}) 91 | else: 92 | return super().__getitem__(key) # -> dc[k] = DatasetDict 93 | 94 | def entries(self) -> List[Dataset]: 95 | """Returns all datasets""" 96 | return [ds for dsd in self.values() for ds in dsd.values()] 97 | 98 | def gather_column( 99 | self: "DatasetCollection", column: str, splits: Union[str, Iterable[str]] = "all" 100 | ) -> List[Dataset]: 101 | return [ 102 | ds 103 | for ds_dicts in self.values() 104 | for split, ds in ds_dicts.items() 105 | if (splits == "all" or split in splits) and column in ds.features 106 | ] 107 | -------------------------------------------------------------------------------- /grouphug/dataset_formatter.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | from typing import Any, Dict, Iterable, List, Optional, Tuple, Union 4 | 5 | import demoji 6 | import numpy as np 7 | import pandas as pd 8 | import regex 9 | import unidecode 10 | from datasets import ClassLabel, Dataset, DatasetDict 11 | from transformers import AutoConfig 12 | from transformers.tokenization_utils_base import LARGE_INTEGER 13 | 14 | from . 
import DatasetCollection
15 | from .config import DEFAULT_SEED, IGNORE_INDEX, INPUT_IDS_VAR, logger
16 | from .dataset_collection import dataset_value_counts, is_iterable
17 | 
18 | 
19 | def build_regex_or(entries, regexes=False):
20 | if regexes: # regexes, add brackets
21 | strings = ["(?:" + s + ")" for s in entries]
22 | else: # strings
23 | strings = [regex.escape(s) for s in entries]
24 | return "(?:" + "|".join(strings) + ")"
25 | 
26 | 
27 | def multiple_replacer(replace_dict):
28 | replacement_function = lambda match: replace_dict[match.group(0)]
29 | pattern = regex.compile(build_regex_or(replace_dict.keys()), regex.M)
30 | return lambda string: pattern.sub(replacement_function, string)
31 | 
32 | 
33 | # various replacements of common formatting errors
34 | NORMALIZATIONS = {
35 | "&amp;": "&",
36 | "&nbsp;": " ",
37 | "&quot;": '"',
38 | "&lt;": "<",
39 | "&gt;": ">",
40 | "<br>": "\n",
41 | "<br/>": "\n",
42 | "<br />": "\n",
43 | "’": "'",
44 | "`": "'",
45 | "…": "...",
46 | }
47 | character_normalizer = multiple_replacer(NORMALIZATIONS) # efficient replacement
48 | 
49 | # regular expression for @username on instagram/twitter
50 | HANDLES_RE = regex.compile(
51 | r"(?<![\w@])@\w{1,30}"
52 | )
53 | 
54 | 
55 | # regular expression for urls, with or without protocol
56 | URLS_RE = regex.compile(r"(?:https?[: ]//|(?<![\w@])www\.)\S+")
57 | 
58 | 
59 | def tolower(text):
60 | return text.lower()
61 | 
62 | 
63 | def normalize_spaces(text):
64 | return regex.sub(r" {2,}", " ", text) # replaces multiple spaces with one
65 | 
66 | 
67 | def truncate_repeat(text):
68 | """3+ repeats of the same character -> only 3"""
69 | return regex.sub(r"(.)\1{2,}", r"\1\1\1", text)
70 | 
71 | 
72 | def replace_handles(text, replacement="@USER"):
73 | """twitter/instagram handles"""
74 | return HANDLES_RE.sub(replacement, text)
75 | 
76 | 
77 | def remove_handles(text):
78 | return replace_handles(text, " ")
79 | 
80 | 
81 | def replace_urls(text, replacement=" URL "):
82 | """urls"""
83 | return URLS_RE.sub(replacement, text)
84 | 
85 | 
86 | def run_unidecode(text):
87 | return unidecode.unidecode(text)
88 | 
89 | 
90 | demoji.set_emoji_pattern()
91 | 
92 | 
93 | @functools.lru_cache(maxsize=1000)
94 | def _demoji_build_replacement(emoji, start_marker, end_marker, separator):
95 | desc = demoji._CODE_TO_DESC[emoji]
96 | # drop skin tones, and compress to encourage it to be a single token if common
97 | parts = [p.replace(" ", "").replace("-", "").lower() for p in desc.split(":") if "skin tone" not in p]
98 | if not parts:
99 | return "" # orphan skin tone
100 | return start_marker + separator.join(parts) + end_marker
101 | 
102 | 
103 | # regex is faster here, emoji's trie approach is another 10x faster
104 | _EMOJI_PAT = regex.compile(demoji._EMOJI_PAT.pattern)
105 | 
106 | 
107 | def demojize_text(text, start_marker=" :|", end_marker=" ", separator="|"): # somewhat optimized for few sentencepieces
108 | return _EMOJI_PAT.sub(lambda m: _demoji_build_replacement(m[0], start_marker, end_marker, separator), text)
109 | 
110 | 
111 | class DatasetFormatter:
112 | PREPROCESS = "preprocess"
113 | TOKENIZE = "tokenize"
114 | ENCODE = "encode"
115 | BINARIZE = "binarize"
116 | 
117 | PREPROCESSORS = {
118 | "lowercase": tolower,
119 | "normalize": character_normalizer,
120 | "truncate_repeat": truncate_repeat,
121 | "replace_handles": replace_handles,
122 | "remove_handles": remove_handles,
123 | "replace_urls": replace_urls,
124 | "demojize": demojize_text,
125 | "unidecode": run_unidecode,
126 | "normalize_spaces": normalize_spaces,
127 | }
128 | 
129 | @classmethod
130 | def register_preprocessor(cls, key, function):
131 | """Use this to add your own preprocessor function, but make sure to do it at both training and inference"""
132 | cls.PREPROCESSORS[key] = function
133 | 
134 | def __init__(
135 | self, save_oov: bool = False, operations: List[Tuple] = None, drop_columns: Optional[Iterable[str]] = None
136 | ):
137 | """DatasetFormatter is a pipeline for formatting and preparing your DatasetCollection.
138 | save_oov: save texts that resulted in out-of-vocabulary tokens in .oov_texts
139 | Other args: used to load."""
140 | self.operations = copy.deepcopy(operations or [])
141 | for op, column, output_column, args in self.operations:
142 | if "label_names" in args and "label2id" not in args:
143 | args["label2id"] = {k: i for i, k in enumerate(args["label_names"])}
144 | self.drop_columns = set(drop_columns or [])
145 | self.save_oov = save_oov
146 | self.oov_texts = []
147 | 
148 | def preprocess(self, operations: List[str], column="text", output_column=None) -> "DatasetFormatter":
149 | """Adds a preprocessing step to your pipeline.
150 | 
151 | Args:
152 | operations:
153 | * List of strings from `register_preprocessor` or one of the following built-in:
154 | * lowercase: convert text to lower case
155 | * normalize: replace common html entities (&amp;, &lt;, <br>, ...) and unusual quote/ellipsis characters with their plain equivalents
156 | * truncate_repeat: replaces 3+ repeats of the same character with 3
157 | * replace_handles/remove_handles: replace @username with @USER, or remove it entirely
158 | * replace_urls: replace urls with URL
159 | * demojize: replaces emojis with an ascii equivalent, ignoring skin tone modifiers
160 | * unidecode: converts all characters to ascii using unidecode
161 | * normalize_spaces: replaces multiple spaces with one. typically does not affect models, more for inspection."""
162 | for operation in operations:
163 | if operation not in self.PREPROCESSORS:
164 | raise ValueError(f"Unknown operation {operation}")
165 | self.operations.append((self.PREPROCESS, column, output_column or column, operations))
166 | return self
167 | 
168 | def tokenize(
169 | self,
170 | column: Union[str, Tuple[str, str]] = "text",
171 | output_prefix="",
172 | drop=True,
173 | truncation=True,
174 | padding=False,
175 | return_special_tokens_mask=True,
176 | **x_tokenizer_args,
177 | ) -> "DatasetFormatter":
178 | """Adds a tokenizing step to your pipeline. Target is {output_prefix}input_ids and such.
179 | Note that the default collator in MultiTaskTrainer will dynamically pad, but will not truncate, hence the defaults here"""
180 | tokenizer_args = dict(
181 | truncation=truncation,
182 | padding=padding,
183 | return_special_tokens_mask=return_special_tokens_mask,
184 | **x_tokenizer_args,
185 | )
186 | if drop:
187 | self.drop_columns |= {column} if isinstance(column, str) else set(column)
188 | self.operations.append((self.TOKENIZE, column, output_prefix, tokenizer_args))
189 | return self
190 | 
191 | def encode(
192 | self, column="labels", label_names=None, max_labels=None, min_freq=0, output_column=None, drop=True
193 | ) -> "DatasetFormatter":
194 | """Adds a step to encode labels for classification (from strings/numbers to 0..num_labels-1)
195 | 
196 | Args:
197 | column: input column
198 | output_column: output column, default overwrites input column
199 | drop: drop input column if not used as output column
200 | label_names: If label_names is given, only these are encoded and missing labels will get value -100. Default: all labels are used
201 | max_labels/min_freq: If max_labels/min_freq is given, labels are detected and most common ones above threshold used for label_names
202 | """
203 | if drop and output_column not in [None, column]:
204 | self.drop_columns.add(column)
205 | self.operations.append(
206 | (
207 | self.ENCODE,
208 | column,
209 | output_column or column,
210 | dict(label_names=label_names, min_freq=min_freq, max_labels=max_labels),
211 | )
212 | )
213 | return self
214 | 
215 | def binarize(
216 | self,
217 | column="labels",
218 | ignore_column=None,
219 | label_names=None,
220 | max_labels=None,
221 | min_freq=0,
222 | output_column=None,
223 | drop=True,
224 | ):
225 | """Similar to encode, but results in a one-hot encoded or binarized column.
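For example, with label_names ["a", "b", "c"], labels ["a", "c"] become [1, 0, 1], and a single value "b" becomes [0, 1, 0].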
Source can be either lists or single values."""
226 | self.operations.append(
227 | (
228 | self.BINARIZE,
229 | column,
230 | output_column or column,
231 | {
232 | "label_names": label_names,
233 | "ignore_column": ignore_column,
234 | "min_freq": min_freq,
235 | "max_labels": max_labels,
236 | },
237 | )
238 | )
239 | if drop and output_column not in [None, column]:
240 | self.drop_columns.add(column)
241 | if drop and ignore_column not in [None, column]:
242 | self.drop_columns.add(ignore_column)
243 | return self
244 | 
245 | def apply(
246 | self,
247 | data: Union[
248 | DatasetCollection, Dataset, DatasetDict, pd.DataFrame, Dict[str, Union[Dataset, DatasetDict, pd.DataFrame]]
249 | ],
250 | tokenizer=None,
251 | test_size: Union[float, Dict[str, float]] = 0.05,
252 | shuffle: bool = True,
253 | seed=DEFAULT_SEED,
254 | splits=("train", "test", "validation"),
255 | batch_size=100,
256 | **map_args,
257 | ) -> DatasetCollection:
258 | """
259 | Formats your data
260 | 
261 | Args:
262 | data: if not a DatasetCollection, will make it one using test_size/shuffle. If a single entry is given it will be called 'data'.
263 | test_size, shuffle, seed: passed to DatasetCollection when creating one
264 | splits: which splits are mapped and used in determining labels
265 | batch_size, map_args: arguments for map"""
266 | if isinstance(data, (Dataset, DatasetDict, pd.DataFrame)):
267 | data = {"data": data}
268 | if not isinstance(data, DatasetCollection):
269 | data = DatasetCollection(data, test_size=test_size, shuffle=shuffle, seed=seed)
270 | 
271 | # prepare automatic encoder/binarizer
272 | feature_label_names = {}
273 | tokenizer_args = None
274 | for op, column, output_column, args in self.operations:
275 | if op in [self.ENCODE, self.BINARIZE]:
276 | if args["label_names"] is None:
277 | logger.info(f"Automatically determining labels for {op} on {column}")
278 | n, counts = dataset_value_counts(data.gather_column(column, splits=splits), column)
279 | min_freq = args["min_freq"]
280 | max_labels = args["max_labels"]
281 | sorted_counts = sorted(
282 | [
283 | (c, k)
284 | for k, c in counts.items()
285 | if k != IGNORE_INDEX # TODO: option?
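# min_freq >= 1 is an absolute count threshold, min_freq < 1 a fraction of the total number of rows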
286 | and ((min_freq >= 1 and c >= min_freq) or (min_freq < 1 and c / n >= min_freq))
287 | ],
288 | reverse=True,
289 | )
290 | if max_labels:
291 | sorted_counts = sorted_counts[:max_labels]
292 | label_names = sorted([k for c, k in sorted_counts])
293 | freqs = {k: c / n for c, k in sorted_counts}
294 | logger.info(f"Determined labels and frequencies for {column} as {freqs}")
295 | args["label_names"] = label_names
296 | 
297 | feature_label_names[column] = ClassLabel(names=[str(k) for k in args["label_names"]]) # only accepts str
298 | args["label2id"] = {k: i for i, k in enumerate(args["label_names"])}
299 | elif op == self.TOKENIZE:
300 | tokenizer_args = args
301 | 
302 | if tokenizer_args is not None:
303 | if tokenizer is None:
304 | raise ValueError("Should pass a tokenizer if tokenizing a column")
305 | if tokenizer.model_max_length > LARGE_INTEGER and "max_length" not in tokenizer_args:
306 | try:
307 | config = AutoConfig.from_pretrained(tokenizer.name_or_path)
308 | tokenizer.model_max_length = config.max_position_embeddings
309 | logger.warning(
310 | f"Tokenizer has no .model_max_length and no max_length passed to tokenize, setting it to {tokenizer.model_max_length} based on model max_position_embeddings"
311 | )
312 | except Exception as e:
313 | logger.warning(f"Error while trying to set tokenizer max length: {e}")
314 | 
315 | # log and check ops
316 | ops_log = [
317 | f"{col}: {op}({args})" + (f" -> {tocol}" if tocol and tocol != col else "")
318 | for (op, col, tocol, args) in self.operations
319 | ]
320 | logger.info(f"Applying the following operations to all datasets: {ops_log}")
321 | 
322 | for ds_name, dataset_dict in data.items():
323 | for ds_split, dataset in dataset_dict.items():
324 | 
325 | if ds_split in splits:
326 | old_fingerprint = dataset_dict[ds_split]._fingerprint
327 | old_oov_len = len(self.oov_texts)
328 | # new_fingerprint = update_fingerprint(old_fingerprint, op_name, ops_log)
329 | remove_columns = list(self.drop_columns & set(dataset.features))
330 | dataset_dict[ds_split] = dataset_dict[ds_split].map(
331 | self._format_batch,
332 | fn_kwargs={"tokenizer": tokenizer, "operations": self.operations},
333 | batch_size=batch_size,
334 | batched=True,
335 | remove_columns=remove_columns,
336 | **map_args,
337 | )
338 | new_fingerprint = dataset_dict[ds_split]._fingerprint
339 | new_oov = len(self.oov_texts) - old_oov_len
340 | oov_message = f"{new_oov} texts with out-of-vocabulary (<unk>) tokens. " if new_oov else ""
341 | # Fingerprint {new_fingerprint}
342 | logger.info(
343 | f"Formatted dataset {ds_name}[{ds_split}], {len(dataset_dict[ds_split])} samples. Dropping {','.join(remove_columns) or ''}, features = {','.join(dataset_dict[ds_split].features.keys())}. {oov_message}"
344 | )
345 | # add names to feature. TODO: cast?
does not like binarized or some other way to pass id2label to .from_data
346 | for column, classlabel in feature_label_names.items():
347 | if column in dataset_dict[ds_split].features:
348 | dataset_dict[ds_split].features[column] = classlabel
349 | 
350 | if tokenizer_args and self.save_oov:
351 | logger.info(f"{len(self.oov_texts)} texts with out-of-vocabulary (<unk>) tokens stored in .oov_texts")
352 | 
353 | return data
354 | 
355 | def apply_batch(
356 | self,
357 | batch: Dict[str, List[Any]],
358 | tokenizer=None,
359 | ) -> Dict[str, List[Any]]:
360 | """Format single batch, to be used for inference and such"""
361 | if tokenizer is None and any(opargs[0] == self.TOKENIZE for opargs in self.operations):
362 | raise ValueError("Should pass a tokenizer if tokenizing a column")
363 | return {
364 | k: v for k, v in self._format_batch(batch, tokenizer, self.operations).items() if k not in self.drop_columns
365 | }
366 | 
367 | def _format_batch(self, batch, tokenizer, operations: List[Tuple]) -> Dict[str, List]:
368 | """used by apply..."""
369 | 
370 | # TODO: cache?
371 | def binarize_labels(values, ignore_values, label2id: Dict):
372 | encoded_values = []
373 | for labels, ignore_labels in zip(values, ignore_values):
374 | binarized_labels = np.zeros(len(label2id))
375 | if not is_iterable(labels):
376 | labels = [labels] # allow binarizing single values
377 | if ignore_labels: # set ignore first, then potentially overwrite with positive labels
378 | for label in ignore_labels:
379 | ix = label2id.get(label)
380 | if ix is not None:
381 | binarized_labels[ix] = IGNORE_INDEX
382 | for label in labels:
383 | ix = label2id.get(label)
384 | if ix is not None:
385 | binarized_labels[ix] = 1
386 | encoded_values.append(binarized_labels)
387 | return encoded_values
388 | 
389 | def encode_labels(values, label2id: Dict):
390 | return [label2id.get(label, IGNORE_INDEX) for label in values]
391 | 
392 | output = batch
393 | for operation, column, output_col, args in operations:
394 | if operation == DatasetFormatter.TOKENIZE: # potentially multi column with [SEP]
395 | if isinstance(column, str):
396 | if column not in batch:
397 | continue
398 | tokenized = tokenizer(batch[column], **args)
399 | else:
400 | if not all(c in batch for c in column):
401 | continue
402 | pairs = list(zip(*[batch[c] for c in column]))
403 | tokenized = tokenizer(pairs, **args)
404 | if self.save_oov:
405 | self.oov_texts.extend(
406 | [
407 | (text, tokenizer.decode([t for t in tokens if t != tokenizer.pad_token_id]))
408 | for tokens, text in zip(tokenized[INPUT_IDS_VAR], batch[column])
409 | if tokenizer.unk_token_id in tokens
410 | ]
411 | )
412 | 
413 | output.update({output_col + k: v for k, v in tokenized.items()})
414 | else: # all single column ops
415 | if column not in batch:
416 | continue
417 | if operation == DatasetFormatter.PREPROCESS:
418 | texts = batch[column]
419 | for preprocessor_name in args:
420 | f = self.PREPROCESSORS[preprocessor_name]
421 | texts = [f(text) for text in texts]
422 | output[output_col] = texts
423 | 
424 | elif operation == DatasetFormatter.ENCODE:
425 | output[output_col] = encode_labels(batch[column], args["label2id"])
426 | elif operation == DatasetFormatter.BINARIZE:
427 | ignore_col = args.get("ignore_column")
428 | if ignore_col:
429 | if ignore_col not in batch:
430 | raise ValueError(f"Missing ignore_column '{ignore_col}' in binarizing '{column}'")
431 | ignore_values = batch[ignore_col]
432 | else:
433 | ignore_values = [[] for _ in batch[column]]
434 | output[output_col] =
binarize_labels(batch[column], ignore_values, args["label2id"])
435 | else:
436 | raise ValueError(f"Unknown formatting operation {operation}")
437 | return output
438 | 
439 | # for loading/saving
440 | def to_dict(self) -> Dict:
441 | operations = copy.deepcopy(self.operations)
442 | for op, _, _, args in operations:
443 | if op in [self.ENCODE, self.BINARIZE]:
444 | args.pop("label2id", None) # can not encode non-string keys in json
445 | return {"operations": operations, "drop_columns": list(self.drop_columns)}
446 | 
447 | @classmethod
448 | def from_dict(cls, data: Dict) -> "DatasetFormatter":
449 | return cls(**data)
450 | 
--------------------------------------------------------------------------------
/grouphug/heads/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | 
3 | from grouphug.heads.classification import ClassificationHead, ClassificationHeadConfig
4 | from grouphug.heads.lm import LMHeadConfig
5 | 
6 | 
7 | # not a class method to avoid circular import hell
8 | def head_config_from_dict(data: Dict):
9 | hcls = globals().get(data["class"])
10 | if hcls is None:
11 | raise NotImplementedError(f"Could not find head config class {data['class']}")
12 | return hcls(**data["args"])
13 | 
--------------------------------------------------------------------------------
/grouphug/heads/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Any, Dict, Optional, Set
3 | 
4 | import torch
5 | from torch import nn
6 | from transformers.modeling_outputs import ModelOutput
7 | 
8 | from grouphug.config import INPUT_IDS_VAR
9 | 
10 | 
11 | class ModelHead(nn.Module):
12 | _rename_keys = [] # key suffixes to rename when loading
13 | 
14 | 
15 | class HeadConfig(ABC):
16 | 
17 | """
18 | Abstract base config which contains all common config options for multi-task model heads
19 | Args:
20 | input_prefix: [prefix]text -> [prefix]input_ids -> [prefix]embeddings; can be changed for siamese or other multi-text inputs
21 | name: If omitted, a name is generated. Key in the model outputs, among others.
22 | attribute: variable name to store head in on the model, to ensure pre-trained MLM heads load
23 | weight: task weight, the loss of this head is multiplied by this
24 | """ 25 | 26 | def __init__( 27 | self, 28 | input_prefix: str = "", 29 | weight: float = 1.0, 30 | name: Optional[str] = None, 31 | attribute: Optional[str] = None, 32 | **_kwargs, # ignored/for loading 33 | ): 34 | self.weight = weight 35 | self.input_prefix = input_prefix 36 | self.name = name 37 | self.attribute = attribute 38 | 39 | self._head = None 40 | self._required_input_vars = None # cache for _input_vars 41 | 42 | # Task weight, loss is multiplied by this 43 | # Can be changed for siamese or other multi-text inputs 44 | 45 | # For internal use 46 | def input_vars(self) -> Dict[str, Set[str]]: # set of variables required to run train/inference on head 47 | req_vars = {self.input_prefix + INPUT_IDS_VAR} 48 | return {"train": req_vars, "infer": req_vars} 49 | 50 | def input_prefixes(self): # used in model forward, general setup to allow for siamese models 51 | return [self.input_prefix] 52 | 53 | def _name(self): # generates name if not given 54 | return f"{self._head.__class__.__name__}" 55 | 56 | def create_head(self, config): 57 | raise NotImplementedError() 58 | 59 | def output_stats(self, output: ModelOutput) -> Dict[str, Any]: 60 | """Turns head output into a set of richer statistics""" 61 | if getattr(output, "loss", None) is None: 62 | return {} 63 | else: 64 | return {"loss": output.loss.item()} 65 | 66 | def to_dict(self) -> Dict: 67 | args = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} 68 | for k, v in args.items(): 69 | if isinstance(v, torch.Tensor): 70 | args[k] = v.detach().cpu().numpy() # pos_weight 71 | return { 72 | "class": str(self.__class__.__name__), 73 | "args": args, 74 | } 75 | -------------------------------------------------------------------------------- /grouphug/heads/classification.py: -------------------------------------------------------------------------------- 1 | # sequence classification 2 | from typing import Any, Dict, Iterable, List, Optional, Set, Union 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 7 | from transformers import PretrainedConfig 8 | from transformers.modeling_outputs import SequenceClassifierOutput 9 | 10 | from grouphug.config import IGNORE_INDEX, INPUT_EMBEDDING_VAR, INPUT_IDS_VAR 11 | from grouphug.heads.base import HeadConfig, ModelHead 12 | 13 | from .. 
14 | from ..config import logger
15 | from ..dataset_collection import dataset_value_counts, is_iterable
16 | 
17 | 
18 | class ClassificationHead(ModelHead):
19 | """Head for sentence-level classification tasks with:
20 | * Modular setup for pooling, loss, head structure
21 | * Additional options, see ClassificationHeadConfig"""
22 | 
23 | SINGLE = "single_label_classification"
24 | MULTI = "multi_label_classification"
25 | REGRESSION = "regression"
26 | PROBLEM_TYPES = [SINGLE, MULTI, REGRESSION]
27 | 
28 | def __init__(self, config: PretrainedConfig, head_config: "ClassificationHeadConfig"):
29 | super().__init__()
30 | self.head_config = head_config
31 | self.config = config
32 | self.num_labels = head_config.num_labels
33 | if self.num_labels is None:
34 | raise ValueError(f"Must set 'num_labels' for head {self}")
35 | 
36 | self.num_extra_inputs = head_config.num_extra_inputs
37 | self.classifier_hidden_size = head_config.classifier_hidden_size or config.hidden_size
38 | if head_config.dropout is not None:
39 | self.dropout = head_config.dropout
40 | else:
41 | self.dropout = 0
42 | 
43 | self.init_modules(config.hidden_size + self.num_extra_inputs, head_config.num_labels)
44 | # initialize here because of pos_weight
45 | if self.head_config.problem_type == self.MULTI and head_config.pos_weight is not None:
46 | if not isinstance(head_config.pos_weight, Iterable):
47 | head_config.pos_weight = [head_config.pos_weight] * head_config.num_labels
48 | self.pos_weight = torch.nn.Parameter(torch.Tensor(head_config.pos_weight), requires_grad=False) # to device
49 | else:
50 | self.pos_weight = None
51 | 
52 | def init_modules(self, input_dim, output_dim):
53 | """Overwrite this method to change the architecture of the classification head"""
54 | hidden_dim = self.classifier_hidden_size
55 | self.head = nn.Sequential(
56 | nn.Dropout(self.dropout),
57 | nn.Linear(input_dim, hidden_dim),
58 | nn.Tanh(),
59 | nn.Dropout(self.dropout),
60 | nn.Linear(hidden_dim, output_dim),
61 | )
62 | 
63 | def __repr__(self):
64 | return f"{self.__class__.__name__}({self.head_config.labels_var})"
65 | 
66 | def pool_embedding(self, embeddings, **kwargs):
67 | """Overwrite this method to change the pooling in the classification head"""
68 | features = embeddings[0] # up to this point it is still the dict output of roberta etc.
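# features has shape (batch_size, sequence_length, hidden_size); each pooling method below reduces it to (batch_size, hidden_size)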
69 | if self.head_config.detached:
70 | features = features.detach()
71 | if self.head_config.pooling_method in ["cls", "first"]:
72 | return features[:, 0, :] # take first token, typically <s> or [CLS]
73 | if self.head_config.pooling_method == "last": # last non-padded token
74 | input_ids: torch.Tensor = kwargs.get(f"{self.head_config.input_prefix}{INPUT_IDS_VAR}")
75 | if self.config.pad_token_id is None: # take last non-pad token, usually </s>
76 | raise ValueError("No pad token set, can not detect last token to use in classification")
77 | if (input_ids[:, 0] == self.config.pad_token_id).any(): # left padded
78 | sequence_lengths = -1
79 | else: # right padded, as is usual
80 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
81 | return features[torch.arange(features.shape[0], device=features.device), sequence_lengths]
82 | else:
83 | attention_mask_col = f"{self.head_config.input_prefix}attention_mask"
84 | if attention_mask_col not in kwargs:
85 | raise ValueError(f"mean or max pooling requires column {attention_mask_col}")
86 | attention_mask = kwargs[attention_mask_col]
87 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(features.size()).float()
88 | if self.head_config.pooling_method == "mean": # from huggingface docs
89 | return torch.sum(features * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
90 | elif self.head_config.pooling_method == "max":
91 | return torch.max(features * input_mask_expanded, 1).values
92 | else:
93 | raise ValueError(f"Unknown pooling method {self.head_config.pooling_method}")
94 | 
95 | def loss(self, logits: torch.Tensor, labels: torch.Tensor):
96 | """Overwrite this method to change the loss of the classification head"""
97 | if self.head_config.problem_type == self.REGRESSION:
98 | loss_fct = MSELoss()
99 | if self.num_labels == 1:
100 | loss = loss_fct(logits.squeeze(), labels.squeeze())
101 | else:
102 | loss = loss_fct(logits, labels)
103 | else:
104 | label_smoothing = self.head_config.label_smoothing if self.training else 0.0
105 | if self.head_config.problem_type == self.SINGLE:
106 | loss_fct = CrossEntropyLoss(ignore_index=self.head_config.ignore_index, label_smoothing=label_smoothing)
107 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
108 | if torch.isnan(loss) and (labels == self.head_config.ignore_index).all():
109 | loss = 0.0 * logits.sum() # no labels, zero loss better than nan
110 | else:
111 | assert self.head_config.problem_type == self.MULTI
112 | labels: torch.Tensor = labels.float() # BCEWithLogitsLoss does not like ints
113 | ignore_labels_mask: torch.Tensor = labels == self.head_config.ignore_index # save this before smoothing
114 | if label_smoothing:
115 | labels = labels * (1 - label_smoothing) + torch.ones_like(labels) * 0.5 * label_smoothing
116 | if ignore_labels_mask.any(): # ignore entries by setting loss to 0
117 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight, reduction="none")
118 | loss_entries = loss_fct(logits, labels)
119 | loss_entries.masked_fill_(ignore_labels_mask, 0.0)
120 | loss = loss_entries.mean() # could have some option for .sum() / (~masked).sum()
121 | else:
122 | loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
123 | loss = loss_fct(logits, labels)
124 | 
125 | return loss
126 | 
127 | def forward(self, **kwargs):
128 | if self.head_config.pooling_method == "auto":
129 | self.head_config.pooling_method = "last" if self.config.is_decoder else "first"
130 | logger.info(
131 | f"Set pooling method to '{self.head_config.pooling_method}' for {self} based on config.is_decoder = {self.config.is_decoder}"
132 | )
133 | 
134 | input_embedding = kwargs[f"{self.head_config.input_prefix}{INPUT_EMBEDDING_VAR}"]
135 | x = self.pool_embedding(input_embedding, **kwargs)
136 | 
137 | if self.num_extra_inputs:
138 | extra_classification_inputs = torch.cat([kwargs[v] for v in self.head_config.extra_inputs_vars], dim=-1)
139 | if extra_classification_inputs.shape[1] != self.num_extra_inputs:
140 | raise ValueError(
141 | f"Head {self} expected extra_classification_inputs of dimension {self.num_extra_inputs} but found {extra_classification_inputs.shape[1]}"
142 | )
143 | x = torch.cat((x, extra_classification_inputs), dim=-1)
144 | 
145 | logits = self.head(x)
146 | 
147 | labels = kwargs.get(self.head_config.labels_var)
148 | if labels is not None:
149 | loss = self.loss(logits, labels)
150 | else:
151 | loss = None
152 | 
153 | return SequenceClassifierOutput(loss=loss, logits=logits)
154 | 
155 | 
156 | class ClassificationHeadConfig(HeadConfig):
157 | """Config for ClassificationHead
158 | 
159 | Args:
160 | num_labels: Dimension of the labels.
161 | labels_var: Dataset column to use for labels. Also the default name of the head.
162 | id2label: List of label names
163 | problem_type: 'single_label_classification', 'multi_label_classification', 'regression', or None for detecting single/multi
164 | dropout: dropout used both after embeddings and after hidden layer
165 | detached: Stops gradients flowing to the base model, equivalent to freezing the base model if this is the only head
166 | pooling_method: pooling from embeddings to head inputs (first/cls, mean, max, last). 'auto' uses first/last based on config.is_decoder
167 | classifier_hidden_size: size of middle layer in classification head. default is embedding's `hidden_size`
168 | extra_inputs_vars: names of vars in dataset for additional context to classification head
169 | num_extra_inputs: total length of extra_inputs_vars (recommended to have this handled by .from_data)
170 | label_smoothing: label_smoothing during classification, only applied in training
171 | pos_weight: positive class weight(s) for multi label classification's BCEWithLogitsLoss
172 | ignore_index: which index to ignore in loss (typically -100 or -1). Also works with multi-label classification.
173 | other args passed to HeadConfig, most notably 'weight'
174 | """
175 | 
176 | def __init__(
177 | self,
178 | problem_type: str,
179 | num_labels: int,
180 | labels_var: str = "labels",
181 | id2label: Optional[List] = None,
182 | dropout: Optional[float] = 0.1,
183 | detached: bool = False,
184 | pooling_method: str = "auto",
185 | classifier_hidden_size: Optional[int] = None,
186 | extra_inputs_vars: List[str] = None,
187 | num_extra_inputs: Optional[int] = 0,
188 | pos_weight: Union[float, List[float]] = None,
189 | label_smoothing: float = 0.0,
190 | ignore_index=IGNORE_INDEX,
191 | **kwargs,
192 | ):
193 | super().__init__(**kwargs)
194 | if labels_var == "label": # TODO: in some collator or tokenizers?
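# transformers' default collators rename a "label" column to "labels", which would silently bypass this head's labels_var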
195 | logger.warning("It is not recommended to use 'label' as your labels_var, as transformers renames it")
196 | 
197 | self.num_labels = num_labels
198 | self.labels_var = labels_var
199 | self.ignore_index = ignore_index
200 | self.id2label = id2label
201 | if self.id2label is not None and len(self.id2label) != num_labels:
202 | raise ValueError(f"id2label (length {len(self.id2label)}) should have length num_labels = {num_labels}")
203 | self.problem_type = problem_type
204 | if self.problem_type not in ClassificationHead.PROBLEM_TYPES:
205 | raise ValueError(
206 | f"Unknown problem type {self.problem_type}, expecting one of {ClassificationHead.PROBLEM_TYPES}"
207 | )
208 | self.dropout = dropout
209 | self.detached = detached
210 | self.pos_weight = pos_weight
211 | self.label_smoothing = label_smoothing
212 | self.pooling_method = pooling_method
213 | self.classifier_hidden_size = classifier_hidden_size
214 | self.extra_inputs_vars = extra_inputs_vars or []
215 | self.num_extra_inputs = num_extra_inputs
216 | 
217 | def create_head(self, config):
218 | return ClassificationHead(config, self)
219 | 
220 | # For internal use
221 | def input_vars(self) -> Dict[str, Set[str]]: # set of variables required to run train/inference on head
222 | infer_vars = super().input_vars()["infer"] | set(self.extra_inputs_vars)
223 | return {"train": infer_vars | {self.labels_var}, "infer": infer_vars}
224 | 
225 | def _name(self): # generates name if not given
226 | return self.labels_var
227 | 
228 | def __repr__(self):
229 | return f"{self.__class__.__name__}(labels_var = {self.labels_var}, problem_type = {self.problem_type}, num_labels = {self.num_labels})"
230 | 
231 | @classmethod
232 | def from_data(
233 | cls,
234 | data: DatasetCollection,
235 | labels_var,
236 | num_labels=None,
237 | problem_type=None,
238 | extra_inputs_vars: List[str] = None,
239 | num_extra_inputs: Optional[int] = 0,
240 | pooling_method: str = "auto",
241 | classifier_hidden_size: Optional[int] = None,
242 | id2label: List[Any] = None,
243 | ignore_index: int = IGNORE_INDEX,
244 | **kwargs,
245 | ):
246 | """Creates a classification head from a DatasetCollection, automatically determining classification type and number of labels
247 | 
248 | Args:
249 | labels_var: Which dataset column to use for labels
250 | num_labels: Dimension of the labels, automatically determined when omitted
251 | problem_type: determined using get_classification_type if omitted
252 | extra_inputs_vars: Which additional columns should be used as inputs to the head, in addition to the embeddings?
253 | num_extra_inputs: Total dimension of these, automatically determined when omitted
254 | pooling_method, classifier_hidden_size, kwargs: passed to the constructor
255 | """
256 | # TODO: option for class balancing?
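# typical usage (a sketch): ClassificationHeadConfig.from_data(data, "emotion") on an integer
# column with values 0..3 infers problem_type="single_label_classification" and num_labels=4
# from the value counts gathered below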
257 | 
258 | auto_determined = {}
259 | if problem_type is None:
260 | problem_type = get_classification_type(data, labels_var)
261 | auto_determined["problem_type"] = problem_type
262 | label_data = data.gather_column(labels_var, splits=["train", "test"])
263 | 
264 | if num_labels is None:
265 | for ds in label_data:
266 | feature_names = getattr(ds.features[labels_var], "names", None)
267 | if feature_names:
268 | num_labels = len(feature_names)
269 | break
270 | else:
271 | if problem_type == ClassificationHead.SINGLE:
272 | n, counts = dataset_value_counts(label_data, labels_var)
273 | if ignore_index in counts:
274 | del counts[ignore_index]
275 | num_labels = len(counts)
276 | elif problem_type == ClassificationHead.REGRESSION:
277 | if is_iterable(label_data[0][labels_var][0]): # regression to a list of labels
278 | num_labels = len(label_data[0][labels_var][0])
279 | else: # single value regression
280 | num_labels = 1
281 | else: # multi label/onehot
282 | num_labels = len(label_data[0][labels_var][0])
283 | auto_determined["num_labels"] = num_labels
284 | 
285 | if extra_inputs_vars and not num_extra_inputs:
286 | num_extra_inputs = 0
287 | for column in extra_inputs_vars:
288 | col_data = data.gather_column(column, splits=["train"])
289 | if not col_data:
290 | raise ValueError(f"No dataset with feature {column} found in training data")
291 | try:
292 | num_extra_inputs += len(col_data[0][column][0])
293 | except Exception as e:
294 | raise ValueError(
295 | f"Expected a list in {column}, but found {col_data[0][column][0]} as first value. Make sure to binarize and not encode '{column}' in formatting data: {e}"
296 | )
297 | auto_determined["num_extra_inputs"] = num_extra_inputs
298 | 
299 | if id2label is None:
300 | for ds in label_data:
301 | id2label = id2label or getattr(ds.features[labels_var], "names", None)
302 | if id2label is not None:
303 | auto_determined["id2label"] = id2label
304 | 
305 | if auto_determined:
306 | logger.info(
307 | f"Automatically determined parameters for classification on '{labels_var}' as {auto_determined}"
308 | )
309 | 
310 | return cls(
311 | labels_var=labels_var,
312 | num_labels=num_labels,
313 | problem_type=problem_type,
314 | id2label=id2label,
315 | pooling_method=pooling_method,
316 | classifier_hidden_size=classifier_hidden_size,
317 | num_extra_inputs=num_extra_inputs,
318 | extra_inputs_vars=extra_inputs_vars,
319 | **kwargs,
320 | )
321 | 
322 | def output_stats(self, output: SequenceClassifierOutput) -> Dict[str, Any]:
323 | """Turns head output into a set of richer statistics"""
324 | stats = super().output_stats(output)
325 | if output.logits is not None:
326 | logits = output.logits[0].detach().cpu()
327 | np_logits = logits.numpy()
328 | if self.problem_type == ClassificationHead.SINGLE:
329 | stats["probs"] = torch.softmax(logits, dim=0).numpy()
330 | stats["predicted_id"] = np_logits.argmax()
331 | if self.id2label is not None:
332 | stats["predicted_label"] = self.id2label[stats["predicted_id"]]
333 | elif self.problem_type == ClassificationHead.MULTI:
334 | stats["probs"] = torch.sigmoid(logits).numpy()
335 | stats["predicted_ids"] = [i for i, p in enumerate(stats["probs"]) if p > 0.5]
336 | if self.id2label is not None:
337 | stats["predicted_labels"] = [self.id2label[i] for i in stats["predicted_ids"]]
338 | elif self.problem_type == ClassificationHead.REGRESSION:
339 | if self.num_labels == 1:
340 | stats["predicted_value"] = np_logits[0]
341 | else:
342 | stats["predicted_values"] = np_logits
343 | return stats
344 | 
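# The heuristic below: a column of lists -> multi-label, a float column -> regression,
# and anything else (e.g. ints or strings) -> single-label classification.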
345 | 
346 | def get_classification_type(data: DatasetCollection, column: str, splits=("train", "test")) -> str:
347 | """Determines what kind of problem type is likely specified by the column"""
348 | datasets = data.gather_column(column, splits)
349 | 
350 | if not datasets:
351 | raise ValueError(f"Can not determine classification type for column {column}, as no dataset contains it")
352 | if is_iterable(datasets[0][column][0]):
353 | return ClassificationHead.MULTI
354 | elif datasets[0].features[column].dtype in ["float32", "float64"]: # could be SINGLE as well
355 | return ClassificationHead.REGRESSION
356 | else:
357 | return ClassificationHead.SINGLE
358 | 
--------------------------------------------------------------------------------
/grouphug/heads/lm.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, Optional, Set
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.nn import CrossEntropyLoss
6 | from transformers import (
7 | BertConfig,
8 | DebertaConfig,
9 | DebertaV2Config,
10 | DistilBertConfig,
11 | ElectraConfig,
12 | OPTConfig,
13 | PretrainedConfig,
14 | RobertaConfig,
15 | XLMRobertaConfig,
16 | )
17 | from transformers.activations import gelu, get_activation
18 | from transformers.modeling_outputs import MaskedLMOutput
19 | from transformers.models.bert.modeling_bert import BertLMPredictionHead
20 | 
21 | # MLM heads
22 | # Since huggingface is rather inconsistent in how MLM heads are implemented, we pretend to be the 'official' one as well as we can
23 | from transformers.models.deberta.modeling_deberta import DebertaLMPredictionHead
24 | from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2LMPredictionHead
25 | 
26 | from grouphug.config import (
27 | CLM_LABELS_VAR,
28 | IGNORE_INDEX,
29 | INPUT_EMBEDDING_VAR,
30 | MASKED_PREFIX,
31 | MLM_LABELS_VAR,
32 | MTD_LABELS_VAR,
33 | MTD_TOKEN_RANDOM,
34 | MTD_TOKEN_SIMILARITY,
35 | logger,
36 | )
37 | from grouphug.heads.base import HeadConfig, ModelHead
38 | 
39 | 
40 | class MTDPredictions(nn.Module): # from transformers.models.electra.modeling_electra
41 | """Prediction module for the discriminator, made up of two dense layers."""
42 | 
43 | def __init__(self, config):
44 | super().__init__()
45 | 
46 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
47 | self.dense_prediction = nn.Linear(config.hidden_size, 1)
48 | self.config = config
49 | self.activation = get_activation(getattr(config, "hidden_act", "gelu")) # default to gelu if the config does not specify hidden_act
50 | 
51 | def forward(self, discriminator_hidden_states):
52 | hidden_states = self.dense(discriminator_hidden_states)
53 | hidden_states = self.activation(hidden_states)
54 | logits = self.dense_prediction(hidden_states).squeeze(-1)
55 | 
56 | return logits
57 | 
58 | 
59 | class BaseLMHead(ModelHead):
60 | def __init__(self, config: PretrainedConfig, head_config: "LMHeadConfig"):
61 | super().__init__()
62 | self.config = config
63 | self.head_config = head_config
64 | if self.head_config.masked_language_modelling:
65 | self.init_mlm_head()
66 | if self.head_config.masked_token_detection:
67 | self.init_mtd_head()
68 | if self.head_config.causal_language_modelling:
69 | self.init_clm_head()
70 | 
71 | def _init_default_lm_head(self):
72 | """Used as default for MLM and CLM"""
73 | self.dense = nn.Linear(self.config.hidden_size, self.config.hidden_size)
74 | self.layer_norm = nn.LayerNorm(self.config.hidden_size, eps=getattr(self.config, "layer_norm_eps", 1e-12))
75 | self.decoder =
nn.Linear(self.config.hidden_size, self.config.vocab_size) 76 | self.bias = nn.Parameter(torch.zeros(self.config.vocab_size)) 77 | self.decoder.bias = self.bias 78 | 79 | def _default_lm_logits(self, embeddings): 80 | """Used as default for MLM and CLM""" 81 | features = embeddings[0] # up to this point it is still the dict output of roberta etc. 82 | 83 | x = self.dense(features) 84 | x = gelu(x) 85 | x = self.layer_norm(x) 86 | 87 | # project back to size of vocabulary with bias 88 | return self.decoder(x) 89 | 90 | # MLM functions - default is basically roberta 91 | 92 | def init_mlm_head(self): 93 | """Creates a masked language modelling head, that is embeddings -> logits over vocab for masked tokens. 94 | Requires care in naming to load pretrained vars.""" 95 | self._init_default_lm_head() 96 | 97 | def get_mlm_logits(self, embeddings): 98 | return self._default_lm_logits(embeddings) 99 | 100 | def mlm_loss(self, prediction_scores, labels): 101 | loss_fct = CrossEntropyLoss() 102 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 103 | if torch.isnan(masked_lm_loss) and (labels == -100).all(): 104 | masked_lm_loss = prediction_scores.new_zeros([], requires_grad=True) # 0.0 with matching type and device 105 | return masked_lm_loss 106 | 107 | def head_output_embeddings(self): 108 | if self.head_config.masked_language_modelling or self.head_config.causal_language_modelling: 109 | return self.decoder # used in MultiTaskModel#get_output_embeddings 110 | 111 | def _tie_weights(self): 112 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 113 | if self.head_config.masked_language_modelling or self.head_config.causal_language_modelling: 114 | self.bias = self.decoder.bias 115 | 116 | # Masked Token Detection functions, based on Electra 117 | 118 | def init_mtd_head(self): 119 | """Creates a masked token detection head. Not common in pretrained models, so shared across them.""" 120 | self.discriminator_predictions = MTDPredictions(self.config) 121 | 122 | def get_mtd_logits(self, embeddings): 123 | features = embeddings[0] # up to this point it is still the dict output of roberta etc. 124 | return self.discriminator_predictions(features) 125 | 126 | def mtd_loss(self, logits, labels): 127 | predict_mask = labels != IGNORE_INDEX 128 | masked_logits = logits[predict_mask] 129 | if self.head_config.mtd_pos_weight is not None: 130 | pos_weight = torch.full_like(masked_logits, self.head_config.mtd_pos_weight) 131 | else: 132 | pos_weight = None 133 | loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight) 134 | loss = loss_fct(masked_logits, labels[predict_mask].float()) 135 | return loss 136 | 137 | # Causal language modelling functions. Again, defaults taken from Roberta 138 | 139 | def init_clm_head(self): 140 | """Creates a causal language modelling head. 
By default this has the same structure as an MLM head""" 141 | self._init_default_lm_head() 142 | 143 | def get_clm_logits(self, embeddings): 144 | return self._default_lm_logits(embeddings) 145 | 146 | def clm_loss(self, logits, labels): 147 | # we are doing next-token prediction; shift prediction scores and input ids by one 148 | shifted_prediction_scores = logits[:, :-1, :].contiguous() 149 | labels = labels[:, 1:].contiguous() 150 | loss_fct = CrossEntropyLoss() 151 | loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 152 | return loss 153 | 154 | # Head implementation 155 | 156 | def forward(self, **kwargs): 157 | input_embedding = kwargs[f"{self.head_config.input_prefix}{INPUT_EMBEDDING_VAR}"] 158 | all_logits = [] 159 | all_losses = [] 160 | 161 | if self.head_config.causal_language_modelling: 162 | token_labels = kwargs.get(CLM_LABELS_VAR) 163 | clm_logits = self.get_clm_logits(input_embedding) 164 | if token_labels is not None: 165 | clm_loss = self.clm_loss(clm_logits, token_labels) 166 | all_losses.append(clm_loss) 167 | all_logits.append(clm_logits) 168 | else: 169 | if self.head_config.masked_language_modelling: 170 | mlm_labels = kwargs.get(MLM_LABELS_VAR) 171 | if mlm_labels is not None: 172 | prediction_scores = self.get_mlm_logits(input_embedding) 173 | mlm_loss = self.mlm_loss(prediction_scores, mlm_labels) 174 | all_losses.append(mlm_loss) 175 | all_logits.append(prediction_scores) 176 | 177 | if self.head_config.masked_token_detection: 178 | token_labels = kwargs.get(MTD_LABELS_VAR) 179 | token_logits = self.get_mtd_logits(input_embedding) 180 | if token_labels is not None: 181 | mtd_loss = self.mtd_loss(token_logits, token_labels) 182 | all_losses.append(mtd_loss) 183 | all_logits.append(token_logits) 184 | 185 | return MaskedLMOutput( 186 | loss=sum(all_losses) if all_losses else None, 187 | logits=all_logits[0] if len(all_logits) == 1 and all_logits[0] is not None else None, 188 | ) 189 | 190 | 191 | class RobertaLMHead(BaseLMHead): 192 | """Roberta Head for masked language modeling. Used as the default.""" 193 | 194 | pass 195 | 196 | 197 | class BaseBertLMHead(BaseLMHead): 198 | def init_mlm_head(self): 199 | self.predictions = self.HEAD_CLASS(self.config) 200 | 201 | def get_mlm_logits(self, embeddings): 202 | features = embeddings[0] # up to this point it is still the dict output of roberta etc. 
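# self.predictions is the model-specific pretrained prediction head (HEAD_CLASS in the subclasses below, e.g. BertLMPredictionHead), which maps hidden states to vocabulary logits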
203 | return self.predictions(features) 204 | 205 | def head_output_embeddings(self): 206 | if self.head_config.masked_language_modelling: 207 | return self.predictions.decoder # used in MultiTaskModel#get_output_embeddings 208 | else: 209 | return super().head_output_embeddings() 210 | 211 | def _tie_weights(self): 212 | pass 213 | 214 | 215 | class BertLMHead(BaseBertLMHead): 216 | HEAD_CLASS = BertLMPredictionHead 217 | 218 | 219 | class DebertaLMHead(BaseBertLMHead): 220 | HEAD_CLASS = DebertaLMPredictionHead 221 | 222 | 223 | class DebertaV2LMHead(BaseBertLMHead): 224 | _rename_keys = [ # loads microsoft models mlm head 225 | ("lm_predictions.lm_head.LayerNorm.bias", "cls.predictions.transform.LayerNorm.bias"), 226 | ("lm_predictions.lm_head.LayerNorm.weight", "cls.predictions.transform.LayerNorm.weight"), 227 | ("lm_predictions.lm_head.dense.bias", "cls.predictions.transform.dense.bias"), 228 | ("lm_predictions.lm_head.dense.weight", "cls.predictions.transform.dense.weight"), 229 | ("lm_predictions.lm_head.bias", "cls.predictions.bias"), 230 | ] 231 | 232 | HEAD_CLASS = DebertaV2LMPredictionHead 233 | 234 | 235 | class DistilBertLMHead(BaseLMHead): # from DistilBertForMaskedLM, which sticks everything in the base model 236 | _rename_keys = [ 237 | (f"{v}.{wb}", f"mlm_head.{v}.{wb}") 238 | for v in ["vocab_transform", "vocab_layer_norm", "vocab_projector"] 239 | for wb in ["weight", "bias"] 240 | ] 241 | 242 | def init_mlm_head(self): 243 | config = self.config 244 | self.activation = get_activation(config.activation) 245 | self.vocab_transform = nn.Linear(config.dim, config.dim) 246 | self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) 247 | self.vocab_projector = nn.Linear(config.dim, config.vocab_size) 248 | 249 | def get_mlm_logits(self, embeddings): 250 | hidden_states = embeddings[0] # up to this point it is still the dict output of roberta etc. 
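# the same transform -> activation -> layer norm -> vocab projection pipeline as DistilBertForMaskedLM, using the pretrained weights renamed via _rename_keys above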
251 | prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) 252 | prediction_logits = self.activation(prediction_logits) # (bs, seq_length, dim) 253 | prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) 254 | prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size) 255 | return prediction_logits 256 | 257 | def head_output_embeddings(self): 258 | if self.head_config.masked_language_modelling: 259 | return self.vocab_projector 260 | else: 261 | return super().head_output_embeddings() 262 | 263 | def _tie_weights(self): 264 | pass 265 | 266 | 267 | class ElectraLMHead(BaseLMHead): # from ElectraGeneratorPredictions 268 | _rename_keys = [ 269 | ("generator_lm_head.weight", "lm_head.generator_lm_head.weight"), 270 | ("generator_lm_head.bias", "lm_head.generator_lm_head.bias"), 271 | ("generator_predictions.LayerNorm.bias", "lm_head.LayerNorm.bias"), 272 | ("generator_predictions.LayerNorm.weight", "lm_head.LayerNorm.weight"), 273 | ("generator_predictions.dense.bias", "lm_head.dense.bias"), 274 | ("generator_predictions.dense.weight", "lm_head.dense.weight"), 275 | ("discriminator_predictions.dense.bias", "lm_head.discriminator_predictions.dense.bias"), 276 | ("discriminator_predictions.dense.weight", "lm_head.discriminator_predictions.dense.weight"), 277 | ("discriminator_predictions.dense_prediction.bias", "lm_head.discriminator_predictions.dense_prediction.bias"), 278 | ( 279 | "discriminator_predictions.dense_prediction.weight", 280 | "lm_head.discriminator_predictions.dense_prediction.weight", 281 | ), 282 | ] 283 | 284 | def init_mlm_head(self): 285 | config = self.config 286 | self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) 287 | self.dense = nn.Linear(config.hidden_size, config.embedding_size) 288 | self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size) 289 | 290 | def get_mlm_logits(self, embeddings): 291 | generator_hidden_states = embeddings[0] # up to this point it is still the dict output 292 | hidden_states = self.dense(generator_hidden_states) 293 | hidden_states = get_activation("gelu")(hidden_states) 294 | hidden_states = self.LayerNorm(hidden_states) 295 | logits = self.generator_lm_head(hidden_states) 296 | return logits 297 | 298 | def head_output_embeddings(self): 299 | if self.head_config.masked_language_modelling: 300 | return self.generator_lm_head # used in MultiTaskModel#get_output_embeddings 301 | else: 302 | return super().head_output_embeddings() 303 | 304 | def _tie_weights(self): 305 | pass 306 | 307 | 308 | class BaseGPTLMHead(BaseLMHead): # GPT style models just have a simple projection back to the vocabulary 309 | def init_clm_head(self): 310 | self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False) 311 | 312 | def get_clm_logits(self, embeddings): 313 | features = embeddings[0] 314 | return self.lm_head(features) 315 | 316 | def head_output_embeddings(self): # used in MultiTaskModel#get_output_embeddings 317 | if self.head_config.causal_language_modelling: 318 | return self.lm_head 319 | else: 320 | return super().head_output_embeddings() 321 | 322 | def _tie_weights(self): # no bias 323 | pass 324 | 325 | 326 | class LMHeadConfig(HeadConfig): 327 | """ 328 | General Masked Language Modelling / Masked Token Detection / Causal language modelling head. 
329 | In general, this has several implementations that load the most common pre-trained head layouts, and defaults to a newly initialized head otherwise. 330 | Expects text -> input_ids in the dataset, and the default AutoCollator in MultiTaskTrainer will take care of the rest. 331 | 332 | Args: 333 | separate_embedding: 334 | * When False (default), the collator will overwrite input_ids and other heads will use the model embeddings from the masked input. 335 | * When True, the base model will be called separately (using the prefix `lm_`), and other heads will use non-masked results unless explicitly told to use this prefix (slower, but more accurate finetuning). 336 | masked_language_modelling: turn on masked language modelling, where the model predicts the original tokens replaced with [MASK] or another token. This is the default mode. 337 | masked_token_detection: turn on electra-style masked token detection, where the model predicts which tokens were replaced by another token. See AutoCollator.generate_replacement_tokens for notes on how they are replaced. 338 | causal_language_modelling: turn on GPT-style causal language modelling, where the model predicts the next token based on the previous ones. Should be used with model.config.is_decoder = True and a compatible model. 339 | mlm_probability: Fraction of tokens selected for masking/replacement (default 0.15). 340 | mask_probability: Of the tokens chosen with mlm_probability, which fraction should be masked. Default 80% without masked_token_detection, 0% otherwise. 341 | generated_token_probability: Of the tokens chosen with mlm_probability but not masked, which fraction should be replaced. Default 50% without masked_token_detection (corresponding to the default 80/10/10 split in BERT), 100% otherwise. 342 | predict_all_tokens: Calculate the MLM loss over non-masked tokens as well. 343 | mtd_pos_weight: weight for the positive entries in BCEWithLogitsLoss for masked token detection. 344 | mtd_strategy: strategy for token masking or a callable, see collator for details. Defaults to "token_similarity" for masked_token_detection and "random" for masked_language_modelling. 345 | attribute: by default automatically determined from model type to ensure pre-trained weights load. 
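Example (an illustrative sketch; "roberta-base" and the explicit arguments are arbitrary example choices):

    head_configs = [LMHeadConfig(masked_language_modelling=True, mlm_probability=0.15)]
    model = AutoMultiTaskModel.from_pretrained("roberta-base", head_configs)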
346 | """ 347 | 348 | """Defines what model type has what (attribute, class) for the MLM head""" 349 | _CONFIG_TO_HEAD_TYPE = { 350 | ElectraConfig: ("lm_head", ElectraLMHead), 351 | XLMRobertaConfig: ("lm_head", RobertaLMHead), 352 | RobertaConfig: ("lm_head", RobertaLMHead), 353 | DebertaConfig: ("cls", DebertaLMHead), 354 | DebertaV2Config: ("cls", DebertaV2LMHead), 355 | DistilBertConfig: ("mlm_head", DistilBertLMHead), # major renaming, mlm_head from _rename_keys 356 | BertConfig: ("cls", BertLMHead), 357 | OPTConfig: ("lm_head", BaseGPTLMHead), 358 | } 359 | 360 | def __init__( 361 | self, 362 | separate_embedding: bool = False, 363 | masked_language_modelling: bool = None, 364 | masked_token_detection: bool = False, 365 | causal_language_modelling=False, 366 | mlm_probability: float = 0.15, 367 | mask_probability: Optional[float] = None, 368 | generated_token_probability: Optional[float] = None, 369 | predict_all_tokens: bool = False, 370 | mtd_pos_weight: float = None, 371 | mtd_strategy: Optional[Callable] = None, 372 | **kwargs, 373 | ): 374 | # Allow MLM, MTD, CLM and MLM+MTD 375 | if masked_language_modelling is None and not (masked_token_detection or causal_language_modelling): 376 | masked_language_modelling = True 377 | if causal_language_modelling and (masked_language_modelling or masked_token_detection): 378 | raise ValueError("Can not combine causal_language_modelling with other modes.") 379 | if not masked_language_modelling and not masked_token_detection and not causal_language_modelling: 380 | raise ValueError("Can not turn off all modes.") 381 | if causal_language_modelling and separate_embedding: 382 | raise ValueError( 383 | "Since the inputs are unchanged in causal_language_modelling, separate_embedding does not make sense." 
384 | ) 385 | self.separate_embedding = separate_embedding 386 | 387 | self.masked_language_modelling = masked_language_modelling 388 | self.mlm_probability = mlm_probability 389 | self.predict_all_tokens = predict_all_tokens 390 | 391 | self.masked_token_detection = masked_token_detection 392 | self.mtd_pos_weight = mtd_pos_weight 393 | 394 | self.causal_language_modelling = causal_language_modelling 395 | 396 | if mask_probability is None: 397 | mask_probability = 0.0 if self.masked_token_detection else 0.8 398 | self.mask_probability = mask_probability 399 | if generated_token_probability is None: 400 | generated_token_probability = 1.0 if self.masked_token_detection else 0.5 401 | self.generated_token_probability = generated_token_probability 402 | 403 | if mtd_strategy is not None: 404 | self.mtd_strategy = mtd_strategy 405 | elif self.masked_token_detection: # MTD default 406 | self.mtd_strategy = MTD_TOKEN_SIMILARITY 407 | else: # MLM default 408 | self.mtd_strategy = MTD_TOKEN_RANDOM 409 | 410 | kwargs["input_prefix"] = MASKED_PREFIX if separate_embedding else "" 411 | super().__init__(**kwargs) 412 | 413 | def create_head(self, config): 414 | for config_cls, (attr, head_cls) in self._CONFIG_TO_HEAD_TYPE.items(): 415 | if type(config) == config_cls: 416 | self.attribute = self.attribute or attr 417 | return head_cls(config, self) 418 | 419 | logger.warning(f"No language modelling head registered for {config.__class__.__name__}, using default") 420 | self.attribute = self.attribute or "lm_head" 421 | return BaseLMHead(config, self) 422 | 423 | @property 424 | def labels_var(self): # used to check if we can compute metrics in trainer 425 | if self.masked_language_modelling: 426 | return MLM_LABELS_VAR 427 | else: 428 | return MTD_LABELS_VAR 429 | 430 | def input_vars(self) -> Dict[str, Set[str]]: # set of variables required to run train/inference on head 431 | infer_vars = super().input_vars()["infer"] 432 | train_vars = infer_vars.copy() 433 | if self.masked_language_modelling: 434 | train_vars.add(MLM_LABELS_VAR) 435 | if self.masked_token_detection: 436 | train_vars.add(MTD_LABELS_VAR) 437 | return {"train": infer_vars | train_vars, "infer": infer_vars} 438 | 439 | def _name(self): # generates name if not given 440 | return "mlm" 441 | 442 | def __repr__(self): 443 | return f"{self.__class__.__name__}(masked_language_modelling={self.masked_language_modelling}, masked_token_detection={self.masked_token_detection})" 444 | -------------------------------------------------------------------------------- /grouphug/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from abc import ABC 4 | from contextlib import contextmanager 5 | from dataclasses import dataclass 6 | from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from datasets import Dataset, DatasetDict 12 | from torch import nn 13 | from tqdm import tqdm 14 | from transformers import ( 15 | AutoConfig, 16 | AutoModel, 17 | AutoTokenizer, 18 | BertPreTrainedModel, 19 | DebertaPreTrainedModel, 20 | DebertaV2PreTrainedModel, 21 | DistilBertPreTrainedModel, 22 | ElectraPreTrainedModel, 23 | GPT2PreTrainedModel, 24 | GPTJPreTrainedModel, 25 | GPTNeoXPreTrainedModel, 26 | OPTPreTrainedModel, 27 | PreTrainedTokenizerBase, 28 | RobertaPreTrainedModel, 29 | XLMRobertaConfig, 30 | ) 31 | from transformers.configuration_utils import PretrainedConfig 32 | from 
transformers.modeling_outputs import ModelOutput 33 | 34 | from grouphug import DatasetFormatter 35 | from grouphug.config import ( 36 | FORMATTER_FILE_NAME, 37 | HEADS_FILE_NAME, 38 | INPUT_EMBEDDING_VAR, 39 | INPUT_IDS_VAR, 40 | MASKED_PREFIX, 41 | TOKENIZER_VARS, 42 | logger, 43 | ) 44 | from grouphug.heads import LMHeadConfig, head_config_from_dict 45 | from grouphug.heads.base import HeadConfig 46 | 47 | tqdm.pandas() 48 | 49 | 50 | @dataclass 51 | class MultiTaskOutput(ModelOutput): 52 | """ 53 | Output of a multi-task model, combining the loss and outputs of all heads that were run. 54 | 55 | Args: 56 | loss: Total weighted loss across all heads 57 | logits: Logits of the active tasks as tuple if more than one, or just a tensor otherwise 58 | head_outputs: Results of model heads 59 | embeddings: embeddings (final hidden states of the base model) by prefix 60 | """ 61 | 62 | loss: Optional[torch.FloatTensor] = None 63 | logits: Optional[Union[torch.FloatTensor, Tuple[torch.FloatTensor]]] = None # active task logit(s) 64 | head_outputs: Dict[str, ModelOutput] = None 65 | embeddings: Dict[str, torch.Tensor] = None 66 | 67 | 68 | class ModelInferenceError(Exception): 69 | """Raised when model.forward fails, saving the batch and the head that failed for easier interactive debugging""" 70 | 71 | def __init__(self, message, batch=None, cls=None, head=None): 72 | super().__init__(message) 73 | self.head = head 74 | self.orig_cls = cls 75 | self.batch = batch 76 | 77 | 78 | DEFAULT_IGNORE_MISSING = ["lm_head.lm_head", "lm_head.decoder.weight"] 79 | DEFAULT_IGNORE_SAVE = ["lm_head.decoder.weight"] 80 | 81 | 82 | class _BaseMultiTaskModel(ABC): 83 | AUTOMODEL_CLASSES = [] 84 | 85 | def __init_subclass__(cls, register_auto_class=True, **kwargs): 86 | super().__init_subclass__(**kwargs) 87 | if register_auto_class and hasattr(cls, "config_class"): # if not some intermediate helper mixin 88 | cls._keys_to_ignore_on_load_missing = (cls._keys_to_ignore_on_load_missing or []) + DEFAULT_IGNORE_MISSING 89 | cls._keys_to_ignore_on_save = (cls._keys_to_ignore_on_save or []) + DEFAULT_IGNORE_SAVE 90 | _BaseMultiTaskModel.AUTOMODEL_CLASSES.append((cls.config_class, cls)) 91 | 92 | def __init__( 93 | self, 94 | config: PretrainedConfig, 95 | head_configs: List[HeadConfig], 96 | formatter: "DatasetFormatter" = None, 97 | tokenizer=None, 98 | ): 99 | super().__init__(config) 100 | self.config = config 101 | self._tokenizer = tokenizer 102 | 103 | # cached convenience vars 104 | self._vars = None 105 | 106 | # this is self.bert / self.roberta, and such 107 | self._init_base_model(config) 108 | 109 | self.formatter = formatter 110 | self.head_configs = self._create_heads(config, head_configs) 111 | self._active_heads = self.head_configs 112 | # each head is stored in either its specified attribute, or in a ModuleDict 113 | self.other_heads = nn.ModuleDict() # heads without attribute 114 | for hc in self.head_configs: 115 | if hc.attribute: 116 | setattr(self, hc.attribute, hc._head) 117 | else: 118 | self.other_heads[hc.name] = hc._head 119 | 120 | # Initialize weights and apply final processing 121 | self.post_init() 122 | 123 | def _init_base_model(self, config): # does not have pooling layer option 124 | # this is self.bert / self.roberta, and such 125 | setattr(self, self.base_model_prefix, AutoModel.from_config(config)) 126 | 127 | # @property 128 | # def base_model(self): 129 | # return getattr(self, self.base_model_prefix) 130 | 131 | def get_mlm_head(self) -> Optional[LMHeadConfig]: 132 | for hc in self.head_configs: 133 | if 
isinstance(hc, LMHeadConfig): 134 | return hc 135 | 136 | @staticmethod 137 | def _create_heads(config: PretrainedConfig, head_configs: List[HeadConfig]): 138 | """Populates additional fields in the head configs and creates the actual head modules""" 139 | # TODO: could allow other heads' outputs as inputs? 140 | names_used = set() 141 | output_embeds = [] 142 | for hc in head_configs: 143 | hc._required_input_vars = hc.input_vars() 144 | hc._head = hc.create_head(config) # may set attribute as well 145 | hc.name = (hc.name or hc._name()).replace(".", "_") # . invalid in module names 146 | while hc.name in names_used: hc.name += "(duplicate name)" # should be rare, no sensible names for you 147 | names_used.add(hc.name) # register the final name, so later duplicates are actually detected 148 | if hasattr(hc._head, "head_output_embeddings"): 149 | output_embeds.append(hc._head) 150 | if len(output_embeds) > 1: 151 | logger.warning( 152 | "Found multiple heads with output embeddings, you will need to tie weights yourself before training!" 153 | ) 154 | return head_configs 155 | 156 | def set_active_heads(self, heads: List[Union[str, HeadConfig]] = None): 157 | """Limit the model to running only a subset of heads, given by head config name or the config itself""" 158 | if heads is None: 159 | self._active_heads = self.head_configs 160 | else: 161 | self._active_heads = [hc for hc in self.head_configs if any(hc is t or hc.name == t for t in heads)] 162 | 163 | def get_active_heads(self): 164 | return self._active_heads 165 | 166 | @contextmanager 167 | def active_heads(self, heads: List[Union[str, HeadConfig]] = None): 168 | old_heads = self.get_active_heads() 169 | self.set_active_heads(heads) 170 | try: 171 | yield self 172 | finally: 173 | self.set_active_heads(old_heads) 174 | 175 | # tokenizer helper 176 | def tokenizer(self, name_or_path=None, **kwargs): 177 | if self._tokenizer is None: 178 | self._tokenizer = AutoTokenizer.from_pretrained(name_or_path or self.config._name_or_path, **kwargs) 179 | return self._tokenizer 180 | 181 | # input variable helpers 182 | def input_prefixes(self) -> Set[str]: 183 | """Which prefixes for input_ids are potentially used in .forward ?""" 184 | return {p for hc in self.head_configs for p in hc.input_prefixes() if p != MASKED_PREFIX} 185 | 186 | def vars(self) -> Set[str]: 187 | """Which variables are potentially used in .forward ? 188 | Cached and used for inference and collating""" 189 | if not self._vars: 190 | input_prefixes = self.input_prefixes() 191 | tokenizer_vars = {p + tv for p in input_prefixes for tv in TOKENIZER_VARS} # also include masks etc 192 | self._vars = {v for hc in self.head_configs for v in hc.input_vars()["train"]} | tokenizer_vars 193 | return self._vars 194 | 195 | # loading 196 | 197 | def get_output_embeddings(self): # used by ~.from_pretrained via default tie_weights 198 | mlm_hc = self.get_mlm_head() 199 | if mlm_hc is not None: 200 | return mlm_hc._head.head_output_embeddings() 201 | else: 202 | return None 203 | 204 | def set_output_embeddings(self, new_embeddings): 205 | raise NotImplementedError( 206 | "set_output_embeddings is not implemented, use tie weights." 207 | ) # in resizing embeddings? 
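# An illustrative sketch of how this class's API fits together ("roberta-base", the head list, and the
# bare DatasetFormatter().tokenize() formatter are example choices, not library defaults):
#
#   model = AutoMultiTaskModel.from_pretrained(
#       "roberta-base", [LMHeadConfig()], formatter=DatasetFormatter().tokenize()
#   )
#   with model.active_heads(["mlm"]):  # "mlm" is the default name of an LMHeadConfig head
#       output = model.format_forward(text="hello world")  # formats and tokenizes, then runs forward(inference_only=True)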
208 | 209 | # inference methods 210 | 211 | def calculate_model_embeddings(self, prefix="", **kwargs): 212 | encoder = self.base_model 213 | # the masked (mlm) input only provides input_ids, so we take attention_mask etc. from the unprefixed variables; other prefixes keep their own 214 | no_mlm_prefix = "" if prefix == MASKED_PREFIX else prefix 215 | optional_args = dict( 216 | attention_mask=kwargs.get(no_mlm_prefix + "attention_mask"), 217 | token_type_ids=kwargs.get(no_mlm_prefix + "token_type_ids"), 218 | position_ids=kwargs.get(no_mlm_prefix + "position_ids"), 219 | ) 220 | optional_args = {k: v for k, v in optional_args.items() if v is not None} 221 | return encoder(kwargs[prefix + INPUT_IDS_VAR], return_dict=True, **optional_args) # not optional 222 | 223 | def forward(self, inference_only: bool = False, return_loss: bool = True, **kwargs): 224 | r"""Determines which heads can be run, and returns the weighted loss over them along with individual head outputs 225 | 226 | Args: 227 | inference_only: will run heads even when labels are missing, and return full details 228 | (head_outputs and embeddings in the output are only filled in this mode) 229 | return_loss: defaults to True so that the Trainer computes loss during the evaluation stage 230 | """ 231 | # drop inputs that were not given 232 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 233 | # Which heads can we run given the inputs and mode? 234 | rv_type = "infer" if inference_only else "train" 235 | relevant_heads = [hc for hc in self._active_heads if all(v in kwargs for v in hc._required_input_vars[rv_type])] 236 | # Which inputs do we need to encode? Usually just 'input_ids' (so prefixes=[""]) 237 | relevant_input_prefixes = {p for h in relevant_heads for p in h.input_prefixes()} 238 | 239 | # TODO: this is pretty naive, and could be cached (not recalculating non-masked entries etc) 240 | try: 241 | prefix_to_embedding = { 242 | f"{p}{INPUT_EMBEDDING_VAR}": self.calculate_model_embeddings(p, **kwargs) 243 | for p in relevant_input_prefixes 244 | } 245 | except Exception as e: # here we tell people in which head the error was, otherwise it's hard to debug 246 | raise ModelInferenceError(f"Error in calculating embeddings: {e}", cls=e.__class__, batch=kwargs) from e 247 | # run heads and collect losses 248 | outputs = {} 249 | losses = [] 250 | for hc in relevant_heads: 251 | combined_args = {**kwargs, **prefix_to_embedding} 252 | try: 253 | head_output = hc._head(**combined_args) 254 | except Exception as e: # here we tell people in which head the error was, otherwise it's hard to debug 255 | raise ModelInferenceError( 256 | f"Error in head {hc.name}: {e}", cls=e.__class__, head=hc._head, batch=combined_args 257 | ) from e 258 | outputs[hc.name] = head_output 259 | if head_output.loss is not None: 260 | losses.append(head_output.loss * hc.weight) 261 | 262 | loss = None 263 | if losses: 264 | loss = sum(losses) 265 | elif not inference_only: 266 | req_vars = {str(hc): hc._required_input_vars[rv_type] for hc in self._active_heads} 267 | raise ModelInferenceError( 268 | f"No valid heads among {len(self._active_heads)} active heads for batch with arguments {kwargs.keys()} and inference_only=False, required vars = {req_vars}", 269 | batch=kwargs, 270 | ) 271 | 272 | task_logits = tuple(getattr(outputs.get(hc.name), "logits", None) for hc in self._active_heads) 273 | if any(l is None for l in task_logits): 274 | task_logits = None # huggingface doesn't like mixing in None 275 | elif len(task_logits) == 1: 276 | task_logits = task_logits[0] # huggingface will helpfully change the type of labels, so we do this too 277 | 278 | 
# by default return only loss + logits, since huggingface trainer does not like dicts in outputs 279 | return MultiTaskOutput( 280 | loss=loss, 281 | logits=task_logits, 282 | head_outputs=outputs if inference_only else None, 283 | embeddings=prefix_to_embedding if inference_only else None, 284 | ) 285 | 286 | def format_forward(self, tokenizer=None, **kwargs): 287 | """Simply wraps variables in a list and passes to format_forward_batch for convenience""" 288 | return self.format_forward_batch(tokenizer=tokenizer, **{k: [v] for k, v in kwargs.items()}) 289 | 290 | def _tensorize(self, v: Any) -> Union[str, torch.Tensor]: 291 | return v if isinstance(v, str) else torch.tensor(v, device=self.device) 292 | 293 | def format_forward_batch(self, tokenizer=None, **kwargs): 294 | """If any expected kind of input_ids is passed, assumes the data is already formatted; otherwise applies the stored formatter. 295 | kwargs: variables, either single data points or a batch 296 | """ 297 | if not any(p + INPUT_IDS_VAR in kwargs for hc in self.head_configs for p in hc.input_prefixes()): 298 | if not self.formatter: 299 | raise ValueError("Expecting either input_ids, or a formatter present. Pass one to from_pretrained!") 300 | tokenizer = tokenizer or self.tokenizer() 301 | data = self.formatter.apply_batch(kwargs, tokenizer=tokenizer) 302 | else: 303 | data = kwargs 304 | model_vars = self.vars() 305 | data = {k: self._tensorize(np.array(v)) for k, v in data.items() if k in model_vars} # np array avoids warning 306 | with torch.no_grad(): 307 | return self.forward(inference_only=True, **data) 308 | 309 | def predict( 310 | self, 311 | data: Union[Dict, Iterable[Dict], DatasetDict, Dataset, pd.DataFrame], 312 | heads: List = None, 313 | tokenizer=None, 314 | show_progress: bool = False, 315 | ): 316 | """Runs the model on some data and gives predictions. 317 | Uses .format_forward_batch, so input data does not have to be formatted. 
318 | 319 | Args: 320 | data: 321 | * dict -> assumed to be a single sample, returns a dict with results 322 | * Dataset or Iterable -> list of such dicts 323 | * DatasetDict -> dict of such lists 324 | * DataFrame -> DataFrame 325 | heads: Only run these heads (via .active_heads) 326 | tokenizer: passed to format_forward_batch 327 | show_progress: use tqdm to show progress""" 328 | if heads is not None: 329 | with self.active_heads(heads): 330 | return self.predict(data, tokenizer=tokenizer) 331 | 332 | def process_record(record): 333 | result = self.format_forward(tokenizer=tokenizer, **record) 334 | stats = {} 335 | if result.loss is not None: 336 | stats["loss"] = result.loss.item() 337 | for hc in self.head_configs: 338 | if hc.name in result.head_outputs: 339 | head_stats = hc.output_stats(result.head_outputs[hc.name]) 340 | stats.update({hc.name + "_" + k: v for k, v in head_stats.items()}) 341 | return stats 342 | 343 | if isinstance(data, pd.DataFrame): 344 | data_apply_f = data.progress_apply if show_progress else data.apply 345 | return data_apply_f(lambda r: pd.Series(process_record(r)), axis=1) 346 | if isinstance(data, DatasetDict): 347 | return {k: self.predict(v) for k, v in data.items()} 348 | if isinstance(data, dict): 349 | return process_record(data) 350 | if isinstance(data, Iterable): 351 | if show_progress: 352 | data = tqdm(data) 353 | return [process_record(record) for record in data] 354 | raise ValueError(f"Unknown format {data.__class__} for data") 355 | 356 | @classmethod 357 | def _load_pretrained_model(cls, model, state_dict, loaded_keys, *args, **kwargs): 358 | # since some huggingface models stick different parameters in the base model which can be used in a head, we rename some here 359 | # TODO: low mem version? 360 | renamed = {} 361 | for hc in model.head_configs: 362 | for from_key, to_key in hc._head._rename_keys: 363 | if from_key in state_dict: 364 | state_dict[to_key] = state_dict[from_key] 365 | del state_dict[from_key] 366 | renamed[from_key] = to_key # do not modify in iteration 367 | if renamed: 368 | logger.warning(f"Renaming {renamed} in loading pre-trained model") 369 | return super()._load_pretrained_model(model, state_dict, state_dict.keys(), *args, **kwargs) 370 | 371 | @classmethod 372 | def from_pretrained(cls, pretrained_model_name_or_path, head_configs=None, formatter=None, *args, **kwargs): 373 | """Loads model, head configs, data formatter""" 374 | if head_configs is None: 375 | config_file = os.path.join(pretrained_model_name_or_path, HEADS_FILE_NAME) 376 | if not os.path.isfile(config_file): 377 | raise ValueError( 378 | f"Must give either head_configs directly, or have {HEADS_FILE_NAME} in directory {pretrained_model_name_or_path}" 379 | ) 380 | logger.info(f"loading heads config file {config_file}") 381 | with open(config_file, "r") as f: 382 | head_config_json = json.load(f) 383 | head_configs = [head_config_from_dict(hc_dict) for hc_dict in head_config_json] 384 | 385 | if formatter is None: 386 | formatter_file = os.path.join(pretrained_model_name_or_path, FORMATTER_FILE_NAME) 387 | if os.path.isfile(formatter_file): # optional, so no error if missing 388 | logger.info(f"loading formatter from {formatter_file}") 389 | with open(formatter_file, "r") as f: 390 | formatter = DatasetFormatter.from_dict(json.load(f)) 391 | 392 | return super().from_pretrained( 393 | pretrained_model_name_or_path, head_configs=head_configs, formatter=formatter, *args, **kwargs 394 | ) 395 | 396 | def get_word_embeddings(self) -> torch.nn.Embedding: 
397 | return self.base_model.embeddings.word_embeddings 398 | 399 | def token_similarity(self, indices: torch.Tensor): 400 | """Given a tensor of n token indices, returns an n x vocab_size tensor of (cosine) token similarities""" 401 | embeddings = torch.nn.functional.normalize(self.get_word_embeddings().weight.detach()) 402 | similarity = embeddings[indices] @ embeddings.t() # range (-1..1) 403 | similarity[:, self.tokenizer().all_special_ids] = -1.0 # minimum value 404 | return similarity 405 | 406 | 407 | # common to bert/roberta-style models in huggingface: the pooling layer is removed 408 | class _BertModelBase(_BaseMultiTaskModel, register_auto_class=False): 409 | _keys_to_ignore_on_load_unexpected = ["pooler", "cls.seq_relationship"] 410 | 411 | def _init_base_model(self, config): 412 | setattr(self, self.base_model_prefix, AutoModel.from_config(config, add_pooling_layer=False)) 413 | 414 | 415 | class BertMultiTaskModel(_BertModelBase, BertPreTrainedModel): 416 | pass 417 | 418 | 419 | class DistilBertMultiTaskModel(_BaseMultiTaskModel, DistilBertPreTrainedModel): 420 | pass 421 | 422 | 423 | class RobertaMultiTaskModel(_BertModelBase, RobertaPreTrainedModel): 424 | pass 425 | 426 | 427 | class XLMRobertaMultiTaskModel(_BertModelBase, RobertaPreTrainedModel): 428 | config_class = XLMRobertaConfig # sort of the same? 429 | 430 | 431 | class ElectraMultiTaskModel(_BaseMultiTaskModel, ElectraPreTrainedModel): 432 | pass 433 | 434 | 435 | class DebertaMultiTaskModel(_BaseMultiTaskModel, DebertaPreTrainedModel): 436 | pass 437 | 438 | 439 | class DebertaV2MultiTaskModel(_BaseMultiTaskModel, DebertaV2PreTrainedModel): 440 | _keys_to_ignore_on_load_unexpected = ["position_embeddings", "mask_predictions"] # common in pretrained 441 | _keys_to_ignore_on_load_missing = ["position_ids", "cls.predictions.decoder"] 442 | 443 | 444 | class OPTMultiTaskModel(_BaseMultiTaskModel, OPTPreTrainedModel): 445 | pass 446 | 447 | 448 | class GPT2MultiTaskModel(_BaseMultiTaskModel, GPT2PreTrainedModel): 449 | pass 450 | 451 | 452 | class GPTJMultiTaskModel(_BaseMultiTaskModel, GPTJPreTrainedModel): 453 | pass 454 | 455 | 456 | class GPTNeoXMultiTaskModel(_BaseMultiTaskModel, GPTNeoXPreTrainedModel): 457 | pass 458 | 459 | 460 | class AutoMultiTaskModel: 461 | @staticmethod 462 | def _model_class_for_config(config): 463 | automodel_cls = [k for k in _BaseMultiTaskModel.AUTOMODEL_CLASSES if k[0] == config.__class__] 464 | if not automodel_cls: 465 | return None 466 | assert len(automodel_cls) == 1, "Multiple registered auto-classes found, call one of them directly" 467 | return automodel_cls 468 | 469 | @classmethod 470 | def from_pretrained( 471 | cls, 472 | pretrained_model_name_or_path, 473 | head_configs: List[HeadConfig] = None, 474 | formatter: DatasetFormatter = None, 475 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 476 | **kwargs, 477 | ) -> _BaseMultiTaskModel: 478 | """Initializes a model from a pre-trained multi-task or base model 479 | 480 | Args: 481 | head_configs: model head configurations. Will try to load if omitted. 482 | formatter: optional DatasetFormatter which will be saved with the model, and can be used to infer on non-formatted data. Will try to load if omitted. 483 | tokenizer: pass a tokenizer here to avoid it being created by the model when it is missing one. 
484 | kwargs: passed to model init, always empty in current setup, but can be used for your own models 485 | """ 486 | autoconfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) 487 | kwargs = dict(head_configs=head_configs, formatter=formatter, tokenizer=tokenizer, **kwargs) 488 | 489 | automodel_cls = cls._model_class_for_config(autoconfig) 490 | if not automodel_cls: 491 | raise NotImplementedError( 492 | f"{pretrained_model_name_or_path} uses {autoconfig.__class__.__name__} which is not supported, see documentation for which models are supported by AutoMultiTaskModel" 493 | ) 494 | return automodel_cls[0][1].from_pretrained(pretrained_model_name_or_path, **kwargs) 495 | 496 | @classmethod 497 | def from_config( 498 | cls, 499 | config, 500 | head_configs: List[HeadConfig], 501 | formatter: DatasetFormatter = None, 502 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 503 | **kwargs, 504 | ) -> _BaseMultiTaskModel: 505 | """See from_pretrained, but taking a config object instead""" 506 | automodel_cls = cls._model_class_for_config(config) 507 | return automodel_cls[0][1]._from_config( 508 | config, head_configs=head_configs, formatter=formatter, tokenizer=tokenizer, **kwargs 509 | ) 510 | -------------------------------------------------------------------------------- /grouphug/trainer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import os 4 | import time 5 | from typing import Dict, List, Optional, Union 6 | 7 | import numpy as np 8 | from datasets import Dataset 9 | from torch.utils.data import DataLoader 10 | from transformers import PreTrainedTokenizerBase, Trainer, TrainingArguments 11 | from transformers.trainer_utils import speed_metrics 12 | 13 | from grouphug import DatasetCollection 14 | from grouphug.collator import AutoCollator 15 | from grouphug.config import FORMATTER_FILE_NAME, HEADS_FILE_NAME, MLM_LABELS_VAR, logger 16 | from grouphug.model import _BaseMultiTaskModel 17 | from grouphug.utils import np_json_dump 18 | 19 | 20 | class MultiTaskDataLoader: 21 | """Data loader-like object that combines and samples from multiple single-task data loaders.""" 22 | 23 | def __init__(self, dataloaders: List[DataLoader], shuffler=None): 24 | self.dataloaders = dataloaders 25 | self.shuffler = shuffler # reproducible training 26 | 27 | def __len__(self): 28 | return sum(len(dl) for dl in self.dataloaders) 29 | 30 | def __iter__(self): 31 | """For each batch, sample a task, and yield a batch from the respective task Dataloader.""" 32 | task_choice_list = [] 33 | for i, dl in enumerate(self.dataloaders): 34 | task_choice_list += [i] * len(dl) 35 | task_choice_list = np.array(task_choice_list) 36 | if self.shuffler is not None: 37 | self.shuffler.shuffle(task_choice_list) 38 | dataloader_iters = [iter(dl) for dl in self.dataloaders] 39 | for i in task_choice_list: 40 | yield next(dataloader_iters[i]) 41 | 42 | 43 | class MultiTaskTrainer(Trainer): 44 | """Multi task trainer, taking a list of datasets. 45 | 46 | Args: 47 | train_data: DatasetCollection or List of Datasets. Often data[:,'train'] or [data[k,'train'] for k in some_list_with_duplicates] 48 | eval_data: DatasetCollection or Dataset. Often data[:,'test'] or data['keytask','test'] 49 | eval_heads: which tasks are active in evaluation, can be task (=DatasetCollection key) dependent passed as dict 50 | data_collator: the default uses AutoCollator, which handles most things. Inherit from it for e.g. 
masked token detection generation.""" 51 | 52 | def __init__( 53 | self, 54 | model: _BaseMultiTaskModel, 55 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 56 | train_data: Union[DatasetCollection, List[Dataset]] = None, 57 | eval_data: DatasetCollection = None, 58 | eval_heads: Union[Dict[str, List], List] = None, 59 | data_collator=None, 60 | args: TrainingArguments = None, 61 | *xargs, 62 | **kwargs, 63 | ): 64 | 65 | if "train_dataset" in kwargs or "eval_dataset" in kwargs: 66 | raise ValueError("MultiTaskTrainer: Use train_data and eval_data instead of train/eval_dataset") 67 | data_collator = data_collator or AutoCollator(model, tokenizer) 68 | 69 | super().__init__(model=model, tokenizer=tokenizer, args=args, data_collator=data_collator, *xargs, **kwargs) 70 | # fix training args 71 | self.args.label_names = [] # default is ['labels'], prediction_step checks this, and we don't have a fixed set 72 | self.args.remove_unused_columns = False # always required since our .forward has **kwargs 73 | 74 | # store datasets for use in get_*_dataloader 75 | self.train_dataset = [] # this is checked in .train for some reason 76 | self.train_data = train_data 77 | self.eval_data = eval_data 78 | if isinstance(eval_heads, str): # single task, not quite right but we can fix 79 | eval_heads = [eval_heads] 80 | self.eval_heads = eval_heads 81 | # Do some checks that could otherwise take a long time to appear 82 | if isinstance(eval_heads, dict) or isinstance(self.compute_metrics, dict): 83 | if not isinstance(eval_data, (dict, DatasetCollection)): 84 | raise ValueError("When passing eval_heads as a dict, eval_data cannot be a single entry.") 85 | missing_keys = eval_heads.keys() - eval_data.keys() 86 | if missing_keys: 87 | raise ValueError(f"eval_heads contains keys not present in eval_data: {missing_keys}") 88 | 89 | def num_examples(self, dataloader: Union[DataLoader, MultiTaskDataLoader]) -> int: 90 | if isinstance(dataloader, MultiTaskDataLoader): 91 | return sum(len(dl.dataset) for dl in dataloader.dataloaders) 92 | else: 93 | return len(dataloader.dataset) 94 | 95 | def get_train_dataloader(self): 96 | """Returns a MultiTaskDataLoader, which is not actually a DataLoader but just defers to a list of underlying ones""" 97 | # get_train_dataloader only uses self.train_dataset, so we pass it that way. This 
avoids copy-pasting the method 98 | # fun fact: super() does not work in list comprehensions 99 | train_datasets = self.train_data 100 | if isinstance(train_datasets, DatasetCollection): 101 | train_datasets = train_datasets.entries() 102 | dataloaders = [super(MultiTaskTrainer, self).get_train_dataloader() for self.train_dataset in train_datasets] 103 | self._update_collator_heads() 104 | return MultiTaskDataLoader(dataloaders, shuffler=np.random.RandomState(self.args.seed)) 105 | 106 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): 107 | if not self.is_world_process_zero(): 108 | return 109 | super().save_model(output_dir, _internal_call=_internal_call) 110 | if self.args.should_save: 111 | # save head configs 112 | output_dir = output_dir or self.args.output_dir 113 | json_path = os.path.join(output_dir, HEADS_FILE_NAME) 114 | head_configs = copy.deepcopy(self.model.head_configs) 115 | for hc in head_configs: 116 | hc._head = None # don't save heads 117 | 118 | with open(json_path, "w", encoding="utf-8") as f: 119 | np_json_dump( 120 | [hc.to_dict() for hc in head_configs], 121 | f, 122 | indent=2, 123 | ) 124 | logger.info(f"Model head configs saved in {json_path}") 125 | 126 | if self.model.formatter is not None: 127 | fmt_json_path = os.path.join(output_dir, FORMATTER_FILE_NAME) 128 | with open(fmt_json_path, "w", encoding="utf-8") as f: 129 | np_json_dump(self.model.formatter.to_dict(), f, indent=2) 130 | logger.info(f"Model formatter saved in {fmt_json_path}") 131 | 132 | def _update_collator_heads(self): # allows dynamically disabling MLM in evaluation. TODO: try? 133 | f = getattr(self.data_collator, "update_mlm_active", None) # custom collators may not implement this 134 | if f: 135 | f() 136 | 137 | def evaluate( 138 | self, 139 | eval_dataset: Optional[Union[DatasetCollection, Dataset]] = None, 140 | ignore_keys: Optional[List[str]] = None, 141 | metric_key_prefix: str = "eval", 142 | heads: Union[Dict[str, List], List] = None, 143 | ) -> Dict[str, float]: 144 | """A reimplementation of the evaluate method allowing for multiple heads and datasets. 145 | 146 | Args: 147 | * metric_key_prefix: automatically suffixed with dataset name and split key if multiple entries exist 148 | """ 149 | 150 | self._memory_tracker.start() 151 | start_time = time.time() 152 | 153 | heads = heads or self.eval_heads 154 | eval_data = eval_dataset or self.eval_data 155 | compute_metrics_fn = self.compute_metrics 156 | if not isinstance(eval_data, (DatasetCollection, dict)): 157 | eval_data = {"": {"": eval_data}} 158 | if not isinstance(heads, dict): 159 | heads = {k: heads for k in eval_data} 160 | 161 | weighted_loss = 0.0 162 | sample_count = 0 163 | combined_metrics = {} 164 | try: 165 | for dataset_name, dsd in eval_data.items(): 166 | for ds_key, dataset in dsd.items(): 167 | task_heads = heads[dataset_name] 168 | with self.model.active_heads(task_heads): # default is None = all 169 | self._update_collator_heads() 170 | active_heads = self.model.get_active_heads() 171 | logger.info( 172 | f"Set active heads to {[hc.name for hc in active_heads]} for evaluation on {dataset_name}" 173 | ) 174 | if task_heads and len(active_heads) != len(task_heads): 175 | raise ValueError( 176 | f"Invalid heads spec {task_heads} for task {dataset_name}: it does not match the model's heads. Did you forget to name a head?" 
177 | ) 178 | if compute_metrics_fn: 179 | self.compute_metrics = lambda *args: compute_metrics_fn( 180 | *args, dataset_name=dataset_name, heads=active_heads 181 | ) 182 | labels_vars = [getattr(h, "labels_var", None) for h in active_heads] 183 | for labels_var, h in zip(labels_vars, active_heads): 184 | if not labels_var: 185 | raise ValueError( 186 | f"Cannot compute metrics for {dataset_name} as head {h.name} has no labels_var" 187 | ) 188 | if labels_var != MLM_LABELS_VAR and labels_var not in dataset.features: 189 | raise ValueError( 190 | f"Cannot compute metrics as dataset {dataset_name} is missing '{labels_var}'" 191 | ) # MINOR: collator could add more things? 192 | self.label_names = labels_vars 193 | prefix = metric_key_prefix 194 | if len(eval_data) > 1: 195 | prefix += f"_{dataset_name}" 196 | if len(dsd) > 1: 197 | prefix += f"_{ds_key}" 198 | eval_dataloader = self.get_eval_dataloader(dataset) 199 | output = self.evaluation_loop( 200 | eval_dataloader, 201 | description=f"Evaluation on {dataset_name}", 202 | prediction_loss_only=True if self.compute_metrics is None else None, 203 | ignore_keys=ignore_keys, 204 | metric_key_prefix=prefix, 205 | ) 206 | num_samples = self.num_examples(eval_dataloader) 207 | weighted_loss += num_samples * output.metrics[f"{prefix}_loss"] # TODO: .get? 208 | sample_count += num_samples 209 | combined_metrics.update(output.metrics) 210 | except Exception: 211 | raise 212 | finally: 213 | self.compute_metrics = compute_metrics_fn 214 | self.label_names = [] # during training, do not look for labels 215 | self._update_collator_heads() # activate MLM again if needed 216 | 217 | # copied from transformers 218 | total_batch_size = self.args.eval_batch_size * self.args.world_size 219 | 220 | loss_key = ( 221 | f"{metric_key_prefix}_loss" # the notebook tracker derives metric_key_prefix back from the LAST loss key! 
222 | ) 223 | _ = combined_metrics.pop(loss_key, None) 224 | combined_metrics[loss_key] = weighted_loss / sample_count 225 | combined_metrics.update( 226 | speed_metrics( 227 | metric_key_prefix, 228 | start_time, 229 | num_samples=num_samples, 230 | num_steps=math.ceil(num_samples / total_batch_size), 231 | ) 232 | ) 233 | self.log(combined_metrics) 234 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, combined_metrics) 235 | self._memory_tracker.stop_and_update_metrics(combined_metrics) 236 | return combined_metrics 237 | -------------------------------------------------------------------------------- /grouphug/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | 5 | 6 | class NumpyAwareJsonEncoder(json.JSONEncoder): 7 | """Handles numpy datatypes and such in json encoding""" 8 | 9 | def default(self, obj): 10 | if isinstance(obj, np.integer): 11 | return int(obj) 12 | elif isinstance(obj, np.floating): 13 | return float(obj) 14 | elif isinstance(obj, np.ndarray): 15 | return obj.tolist() 16 | elif isinstance(obj, set): 17 | return list(obj) 18 | else: 19 | return super().default(obj) 20 | 21 | 22 | def np_json_dumps(data, **kwargs): 23 | return json.dumps(data, cls=NumpyAwareJsonEncoder, **kwargs) 24 | 25 | 26 | def np_json_dump(data, f, **kwargs): 27 | return json.dump(data, f, cls=NumpyAwareJsonEncoder, **kwargs) 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "chatdesk-grouphug" 3 | version = "0.8.1" 4 | description = "GroupHug is a library with extensions to 🤗 transformers for multitask language modelling." 
5 | authors = ["Sander Land"] 6 | license = "Apache2" 7 | readme = "README.md" 8 | homepage = "https://github.com/sanderland/grouphug" 9 | keywords = ["transformers","language modelling","machine learning","classification"] 10 | packages = [ 11 | {include = "grouphug"} 12 | ] 13 | 14 | [tool.isort] 15 | profile = "black" 16 | line_length = 120 17 | 18 | [tool.black] 19 | line-length = 120 20 | target_version = ['py38'] 21 | include = '\.py$' 22 | 23 | [tool.poetry.dependencies] 24 | python = "^3.8,<4.0" 25 | transformers = "^4.20.0" 26 | datasets = "^2.0.0" 27 | evaluate = "^0.3.0" 28 | torch = "^1.10.0" 29 | numpy = "^1.21" 30 | regex = "^2022.3.15" 31 | Unidecode = "^1.3.4" 32 | sentencepiece = "^0.1.96" 33 | demoji = "^1.1.0" 34 | 35 | [tool.poetry.dev-dependencies] 36 | pytest = "^7.1.1" 37 | pytest-subtests = "^0.7.0" 38 | jupyterlab = "^3.3.2" 39 | jupyterlab-code-formatter = "^1.4.10" 40 | isort = "^5.10.1" 41 | black = "^22.3.0" 42 | ipywidgets = "^7.6.3" 43 | pre-commit = "^2.13.0" 44 | matplotlib = "^3.4.2" 45 | Werkzeug = "<2.1" 46 | scikit-learn = "^1.0.2" 47 | Sphinx = "^5.0.1" 48 | sphinx-rtd-theme = "^1.0.0" 49 | pytest-pycharm = "^0.7.0" 50 | 51 | [build-system] 52 | requires = ["poetry-core>=1.0.0"] 53 | build-backend = "poetry.core.masonry.api" 54 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chatdesk/grouphug/e001e4be359230c3572feef195f5b2815eed4555/tests/__init__.py -------------------------------------------------------------------------------- /tests/automodel.py: -------------------------------------------------------------------------------- 1 | from grouphug import AutoMultiTaskModel, ClassificationHeadConfig, LMHeadConfig 2 | 3 | 4 | def test_automodel(): 5 | base_model = "vinai/bertweet-base" 6 | head_configs = [ 7 | LMHeadConfig(), 8 | ClassificationHeadConfig(problem_type="multi_label_classification", name="topics", num_labels=10), 9 | ClassificationHeadConfig( 10 | problem_type="single_label_classification", 11 | num_labels=3, 12 | labels_var="action_label", 13 | ), 14 | ] 15 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs) 16 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | from datasets import Dataset, load_dataset 5 | from transformers import AutoTokenizer, TrainingArguments 6 | 7 | from grouphug import DatasetFormatter 8 | from grouphug.config import logger 9 | 10 | logger.setLevel(logging.INFO) 11 | 12 | SMALL_MODEL = "prajjwal1/bert-tiny" 13 | 14 | 15 | @pytest.fixture 16 | def dataset_action_label(): 17 | ds = Dataset.from_dict({"text": ["abc", "def"], "action_label": [2, 1]}) 18 | return ds 19 | 20 | 21 | @pytest.fixture 22 | def dataset_multiclass_topics_star(): 23 | ds = Dataset.from_dict( 24 | { 25 | "text": ["xyz", "this is another little dataset", "this one has ignores in index"], 26 | "topics": [[0] * 7, [1, 0, 1, 0, 1, 0, 1], [1, -100, -100, -100, -100, -100, -100]], 27 | "star": [23, 7, 1], 28 | "string_label": ["a", "a", "b"], 29 | } 30 | ) 31 | return ds 32 | 33 | 34 | @pytest.fixture 35 | def dataset_regress(): 36 | ds = Dataset.from_dict({"text": [" ".join(f"token{i}" for i in range(100)), "x y z"], "y": [3.14, 42]}) 37 | return ds 38 | 39 | 40 | @pytest.fixture 41 | def 
dataset_text_only(): 42 | ds = Dataset.from_dict({"text": [f"the lazy dog jumps over the quick fox {i} times" for i in range(10)]}) 43 | return ds 44 | 45 | 46 | @pytest.fixture 47 | def dataset_demo_review_star(): 48 | return load_dataset("lhoestq/demo1").rename_column("review", "text") 49 | 50 | 51 | @pytest.fixture 52 | def multiple_datasets_with_cls(dataset_action_label, dataset_multiclass_topics_star, dataset_demo_review_star): 53 | return { 54 | "action": dataset_action_label, 55 | "topicsstar": dataset_multiclass_topics_star, 56 | "reviews": dataset_demo_review_star, 57 | } 58 | 59 | 60 | @pytest.fixture 61 | def multiple_datasets(multiple_datasets_with_cls, dataset_text_only, dataset_regress): 62 | return {**multiple_datasets_with_cls, "onlytext": dataset_text_only, "regress": dataset_regress} 63 | 64 | 65 | @pytest.fixture 66 | def multiple_formatter(): 67 | return DatasetFormatter().tokenize().encode("action_label").encode("star") 68 | 69 | 70 | @pytest.fixture 71 | def training_args(): 72 | return TrainingArguments( 73 | output_dir="output/test", 74 | do_train=True, 75 | num_train_epochs=1, 76 | per_device_train_batch_size=2, 77 | per_device_eval_batch_size=1, 78 | gradient_accumulation_steps=2, 79 | seed=42, 80 | logging_steps=1, 81 | evaluation_strategy="epoch", 82 | ) 83 | 84 | 85 | @pytest.fixture 86 | def tiny_tokenizer(): 87 | return AutoTokenizer.from_pretrained(SMALL_MODEL) 88 | 89 | 90 | def losses_not_nan(trainer): 91 | for r in trainer.state.log_history: 92 | losses = [v for k, v in r.items() if "loss" in k] 93 | assert losses, f"Record {r} in state history has missing loss" 94 | assert all(0 <= l <= 100 for l in losses), f"Record {r} in state history has invalid loss" 95 | -------------------------------------------------------------------------------- /tests/test_dataset_collection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset, DatasetDict 3 | 4 | from grouphug import DatasetCollection 5 | from grouphug.config import DEFAULT_SEED 6 | 7 | 8 | def check_datasetcollection_format(data): 9 | assert isinstance(data, DatasetCollection) 10 | for task, dsd in data.items(): 11 | assert isinstance(task, (str, int)) 12 | assert isinstance(dsd, DatasetDict) 13 | for k, ds in dsd.items(): 14 | assert isinstance(k, str) 15 | assert isinstance(ds, Dataset) 16 | 17 | 18 | def test_dataset_collection(multiple_datasets): 19 | data = DatasetCollection(multiple_datasets, test_size={"reviews": 0.2}) 20 | check_datasetcollection_format(data[:, "train"]) 21 | check_datasetcollection_format(data[["action"], "train"]) 22 | check_datasetcollection_format(data[["action", "onlytext"], "train"]) 23 | check_datasetcollection_format(data["action", :]) 24 | assert len(data[:, "test"]) == 1 25 | assert isinstance(data["action", "train"], Dataset) 26 | assert isinstance(data.entries(), list) 27 | assert len(data.entries()) == 6 28 | with pytest.raises(KeyError): 29 | _ = data[:, "foo"] 30 | 31 | 32 | def test_dataset_collection_split(multiple_datasets): 33 | data1 = DatasetCollection(multiple_datasets, test_size={"onlytext": 0.2}) 34 | data2 = DatasetCollection(multiple_datasets, test_size={"onlytext": 0.2}) 35 | data3 = DatasetCollection(multiple_datasets, test_size=0.2) 36 | data4 = DatasetCollection(multiple_datasets, test_size=0.2, seed=DEFAULT_SEED + 1) # somewhat fragile 37 | 38 | assert data1["onlytext", "test"]["text"] == data2["onlytext", "test"]["text"] 39 | assert data1["onlytext", "test"]["text"] == 
data3["onlytext", "test"]["text"] 40 | assert data1["onlytext", "test"]["text"] != data4["onlytext", "test"]["text"] 41 | 42 | 43 | def test_dataset_collection_num_key(dataset_text_only): 44 | data = DatasetCollection({1: dataset_text_only, 2: dataset_text_only}, test_size={1: 0.2}) 45 | assert isinstance(data[1, "train"], Dataset) 46 | check_datasetcollection_format(data[:, "train"]) 47 | check_datasetcollection_format(data[[2, 1], "train"]) 48 | assert len(data[:, "test"]) == 1 49 | -------------------------------------------------------------------------------- /tests/test_dataset_formatter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset, DatasetDict, load_dataset 3 | from transformers import AutoTokenizer 4 | 5 | from grouphug import DatasetCollection, DatasetFormatter 6 | from grouphug.dataset_formatter import remove_handles, replace_handles, replace_urls, tolower, truncate_repeat 7 | from tests.conftest import SMALL_MODEL 8 | 9 | 10 | def test_formatter(multiple_datasets): 11 | ds = Dataset.from_dict( 12 | { 13 | "text": ["the lazy dog jumps over the quick fox", "another sentence", "a third sentence"], 14 | "labels": [2, 1, -100], 15 | "multi_label": [[1, 2, 3], [4, 5, 1], []], 16 | "ignore_entry": [[4, 5, 1], [], [1, 2, 3, 4]], 17 | } 18 | ) 19 | data = ( 20 | DatasetFormatter().binarize("multi_label", ignore_column="ignore_entry").encode("labels").apply(ds, test_size=0) 21 | ) 22 | assert "data" in data 23 | dt = data["data"]["train"] 24 | assert dt["labels"] == [1, 0, -100] 25 | assert dt["multi_label"] == [[1, 1, 1, -100, -100], [1, 0, 0, 1, 1], [-100, -100, -100, -100, 0]] 26 | 27 | data = DatasetFormatter().binarize("multi_label").binarize("labels").apply(ds, test_size=0) 28 | dt2 = data["data"]["train"] 29 | 30 | assert dt["labels"] == [1, 0, -100] # shallow copy working 31 | assert dt2["labels"] == [[0, 1], [1, 0], [0, 0]] # one-hot encoding working 32 | 33 | 34 | def test_truncate_repeat(): 35 | assert truncate_repeat("yay!!!!! aaaawesome!") == "yay!!! aaawesome!" 36 | assert truncate_repeat("That's hilarious! 😂😂😂😂😂😂😂😂😂😂😂😂") == "That's hilarious! 😂😂😂" 37 | 38 | 39 | def test_lower(): 40 | assert tolower("AbC ABC") == "abc abc" 41 | 42 | 43 | def test_handles(): 44 | assert remove_handles("@user is this ok @me") == " is this ok " 45 | assert replace_handles("@user is this ok @me") == "@USER is this ok @USER" 46 | 47 | 48 | def test_replace_urls(): 49 | assert replace_urls("Go to www.blah.com/asdasdasdasd?b=c or to https://a.com") == "Go to URL or to URL " 50 | assert replace_urls("Go to www.blah.com/asdasdasdasd?b=c or to https://a.com") == "Go to URL or to URL " 51 | 52 | 53 | def test_preprocess_formatter(): 54 | ds = Dataset.from_dict( 55 | { 56 | "text": ["한국어....

That`s amazing!!!!!

57 | } 58 | ) 59 | 60 | data = ( 61 | DatasetFormatter() 62 | .preprocess( 63 | [ 64 | "normalize", 65 | "replace_handles", 66 | "truncate_repeat", 67 | "demojize", 68 | "unidecode", 69 | "normalize_spaces", 70 | ] 71 | ) 72 | .apply(ds, test_size=0) 73 | ) 74 | assert data["data"]["train"]["text"] == [ 75 | "hangugeo... That's amazing!!! :|facewithtearsofjoy :|facewithtearsofjoy :|facewithtearsofjoy ..." 76 | ] 77 | 78 | 79 | def test_tokenize_pairs(dataset_text_only): 80 | ds = Dataset.from_dict( 81 | { 82 | "sentence1": ["abc"], 83 | "sentence2": ["def"], 84 | } 85 | ) 86 | tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny") 87 | data = ( 88 | DatasetFormatter() 89 | .tokenize() 90 | .tokenize(("sentence1", "sentence2")) 91 | .apply({"text": dataset_text_only, "ds": ds}, tokenizer=tokenizer, test_size=0) 92 | ) 93 | assert "input_ids" in data["ds"]["train"][0].keys() 94 | 95 | 96 | def test_single_dataset(): 97 | text_data = load_dataset("text", data_files=__file__)["train"] 98 | assert isinstance(text_data, Dataset) 99 | tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL) 100 | data = DatasetFormatter().tokenize().apply(text_data, tokenizer=tokenizer, test_size=0.1) 101 | assert "data" in data 102 | assert "input_ids" in data["data"]["train"].features 103 | assert "input_ids" in data["data"]["test"].features 104 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoTokenizer 3 | 4 | from grouphug import AutoMultiTaskModel, LMHeadConfig 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "base_model", 9 | [ 10 | "prajjwal1/bert-tiny", 11 | "vinai/bertweet-base", 12 | "google/electra-small-generator", 13 | "distilbert-base-uncased", 14 | "microsoft/deberta-v3-small", 15 | ], 16 | ) 17 | def test_model_init_all_expected(base_model): 18 | model, info = AutoMultiTaskModel.from_pretrained(base_model, [LMHeadConfig()], output_loading_info=True) 19 | assert info["missing_keys"] == [] 20 | assert info["unexpected_keys"] == [] 21 | assert info["mismatched_keys"] == [] 22 | assert info["error_msgs"] == [] 23 | 24 | 25 | @pytest.mark.parametrize("base_model", ["sentence-transformers/paraphrase-MiniLM-L3-v2"]) 26 | def test_model_init_mlm_new(base_model): 27 | model, info = AutoMultiTaskModel.from_pretrained(base_model, [LMHeadConfig()], output_loading_info=True) 28 | prefix = model.get_mlm_head().attribute + "."
29 | assert len(info["missing_keys"]) > 0 30 | assert all(k.startswith(prefix) for k in info["missing_keys"]) 31 | assert info["unexpected_keys"] == [] 32 | assert info["mismatched_keys"] == [] 33 | assert info["error_msgs"] == [] 34 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from grouphug.config import IGNORE_INDEX 7 | from grouphug.model import ModelInferenceError 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "" # TODO: remove 10 | 11 | import pandas as pd 12 | import pytest 13 | from datasets import Dataset, load_metric 14 | from transformers import AutoTokenizer, TrainingArguments 15 | 16 | from grouphug import ( 17 | AutoMultiTaskModel, 18 | ClassificationHead, 19 | ClassificationHeadConfig, 20 | DatasetFormatter, 21 | LMHeadConfig, 22 | MultiTaskTrainer, 23 | ) 24 | from grouphug.utils import np_json_dumps 25 | from tests.conftest import SMALL_MODEL, losses_not_nan 26 | 27 | 28 | def test_train_save_load(multiple_datasets, multiple_formatter, training_args): 29 | base_model = SMALL_MODEL 30 | tokenizer = AutoTokenizer.from_pretrained(base_model) 31 | data = multiple_formatter.apply(multiple_datasets, tokenizer=tokenizer, test_size=0) 32 | 33 | head_configs = [ 34 | LMHeadConfig(), 35 | ClassificationHeadConfig.from_data( 36 | data, "topics", name="topic_cls", label_smoothing=0.1, pos_weight=[1, 2, 3, 4, 5, 6, 7] 37 | ), 38 | ClassificationHeadConfig.from_data(data, "action_label", pooling_type="mean", label_smoothing=0.1), 39 | ClassificationHeadConfig.from_data(data, "star", weight=2, pooling_type="max"), 40 | ClassificationHeadConfig.from_data(data, "y"), 41 | ] 42 | assert head_configs[1].num_labels == 7 43 | assert head_configs[2].num_labels == 2 44 | assert head_configs[4].num_labels == 1 45 | 46 | assert head_configs[1].problem_type == ClassificationHead.MULTI 47 | assert head_configs[2].problem_type == ClassificationHead.SINGLE 48 | assert head_configs[3].problem_type == ClassificationHead.SINGLE 49 | assert head_configs[4].problem_type == ClassificationHead.REGRESSION 50 | 51 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, formatter=multiple_formatter) 52 | training_args.save_steps = 1 53 | training_args.save_total_limit = 1 54 | trainer = MultiTaskTrainer( 55 | model=model, 56 | tokenizer=tokenizer, 57 | args=training_args, 58 | train_data=data[:, "train"], 59 | eval_data=data["topicsstar", "train"], 60 | eval_heads=["topic_cls"], 61 | ) 62 | _train_res = trainer.train() 63 | trainer.save_model() 64 | loaded_model = AutoMultiTaskModel.from_pretrained(training_args.output_dir) 65 | assert type(loaded_model) == type(model) 66 | for lmhc, ohc in zip(loaded_model.head_configs, model.head_configs): 67 | for k, v in ohc.__dict__.items(): 68 | if k not in ["_head", "auto_attribute"]: # expected not matching 69 | eq = lmhc.__dict__[k] == v 70 | if isinstance(eq, (list, np.ndarray, torch.Tensor)): 71 | eq = all(eq) 72 | assert eq, f"Key {k} for head {ohc} does not match" 73 | assert loaded_model.formatter is not None 74 | assert np_json_dumps(loaded_model.formatter.to_dict()) == np_json_dumps(model.formatter.to_dict()) # tuple vs list 75 | 76 | 77 | # test multiple models, really shows edge cases 78 | @pytest.mark.parametrize("mlm", [False, True]) 79 | def test_multi_train(multiple_datasets_with_cls, multiple_formatter, training_args, mlm, subtests): 80 | fingerprints = []
81 | for base_model in ["vinai/bertweet-base", "sentence-transformers/paraphrase-MiniLM-L3-v2", SMALL_MODEL]: 82 | with subtests.test(msg=f"Testing model {base_model} with mlm = {mlm}", base_model=base_model): 83 | tokenizer = AutoTokenizer.from_pretrained(base_model) 84 | data = multiple_formatter.apply(multiple_datasets_with_cls, tokenizer=tokenizer, test_size=0) 85 | fingerprints.append([v._fingerprint for d in data.values() for k, v in d.items()]) # record dataset cache fingerprints for this tokenizer 86 | 87 | head_configs = [ 88 | ClassificationHeadConfig.from_data(data, "topics", name="topic_cls"), 89 | ClassificationHeadConfig.from_data(data, "star", weight=2, pooling_type="max"), 90 | ClassificationHeadConfig.from_data(data, "action_label", name="cls_action", pooling_type="mean"), 91 | ] 92 | if mlm: 93 | head_configs.append(LMHeadConfig()) 94 | 95 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs) 96 | word_embeddings = model.base_model.embeddings.word_embeddings 97 | 98 | for ds in data.gather_column("input_ids"): 99 | for input_ids in ds["input_ids"]: 100 | assert max(input_ids) < word_embeddings.num_embeddings # all token ids must fit the embedding matrix 101 | 102 | trainer = MultiTaskTrainer( 103 | model=model, 104 | tokenizer=tokenizer, 105 | args=training_args, 106 | train_data=data[:, "train"], 107 | eval_data=data["action", :], 108 | eval_heads=["cls_action"], 109 | ) 110 | _train_res = trainer.train() 111 | 112 | # ensure model and such is taken into account for data formatting 113 | flattened_fps = [f for fp_row in fingerprints for f in fp_row] 114 | assert len(flattened_fps) == len(set(flattened_fps)) 115 | 116 | 117 | def test_train_alt_input(training_args): 118 | base_model = SMALL_MODEL 119 | 120 | ds = Dataset.from_dict( 121 | {"reply_text": ["the lazy dog jumps over the quick fox", "another sentence"], "labels": [0, 1]} 122 | ) 123 | 124 | tokenizer = AutoTokenizer.from_pretrained(base_model) 125 | fmt = DatasetFormatter().tokenize("reply_text", output_prefix="reply_") 126 | data = fmt.apply(ds, tokenizer=tokenizer) 127 | head_configs = [ 128 | ClassificationHeadConfig.from_data( 129 | data, labels_var="labels", input_prefix="reply_", pooling_method="max", name="ch" 130 | ), 131 | ] 132 | assert head_configs[0].num_labels == 2 133 | assert head_configs[0].input_prefixes() == ["reply_"] 134 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs=head_configs, formatter=fmt) 135 | trainer = MultiTaskTrainer( 136 | model=model, 137 | args=training_args, 138 | train_data=data[:, "train"], 139 | eval_data=data["data", "train"], 140 | ) 141 | 142 | expected_vars = { 143 | "reply_input_ids", 144 | "reply_token_type_ids", 145 | "reply_attention_mask", 146 | "reply_special_tokens_mask", 147 | "labels", 148 | } 149 | assert set(data["data"]["train"].features.keys()) == expected_vars 150 | assert trainer.data_collator.model_vars == expected_vars 151 | trainer.train() 152 | losses_not_nan(trainer) 153 | 154 | # test infer methods 155 | result = model.format_forward(reply_text="abc") 156 | assert len(result.head_outputs) == 1 157 | assert len(result.head_outputs["ch"].logits[0]) == 2 158 | 159 | assert head_configs[0].output_stats(result.head_outputs["ch"]).keys() == { 160 | "probs", 161 | "predicted_id", 162 | } 163 | 164 | result = model.format_forward(reply_text="abc", labels=0) 165 | assert result.head_outputs["ch"].loss is not None 166 | assert head_configs[0].output_stats(result.head_outputs["ch"]).keys() == { 167 | "loss", 168 | "probs", 169 | "predicted_id", 170 | } 171 | 172 | assert isinstance(model.predict({"reply_text": "abc", "labels": 0}), dict)
173 | assert isinstance(model.predict([{"reply_text": "abc"}]), list) 174 | df = pd.DataFrame([{"reply_text": "abc"}]) 175 | dataset = Dataset.from_pandas(df) 176 | assert isinstance(model.predict(df), pd.DataFrame) 177 | assert isinstance(model.predict(dataset), list) 178 | 179 | 180 | def test_train_eval_metrics(multiple_datasets, multiple_formatter): 181 | base_model = SMALL_MODEL 182 | tokenizer = AutoTokenizer.from_pretrained(base_model) 183 | data = multiple_formatter.apply(multiple_datasets, tokenizer=tokenizer, test_size=0) 184 | 185 | head_configs = [ 186 | ClassificationHeadConfig.from_data(data, "topics", name="topic_cls", id2label=[str(i) for i in range(7)]), 187 | ClassificationHeadConfig.from_data(data, "star", weight=2, pooling_type="max"), 188 | ClassificationHeadConfig.from_data(data, "action_label", pooling_type="mean"), 189 | LMHeadConfig(), 190 | ] 191 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, formatter=multiple_formatter) 192 | training_args = TrainingArguments( 193 | output_dir="output/test", 194 | do_train=True, 195 | max_steps=5, 196 | per_device_train_batch_size=1, 197 | per_device_eval_batch_size=1, 198 | gradient_accumulation_steps=1, 199 | seed=42, 200 | logging_steps=1, 201 | evaluation_strategy="steps", 202 | eval_steps=2, 203 | ) 204 | 205 | def compute_metrics(eval_preds, dataset_name, heads): 206 | all_logits, all_labels = eval_preds 207 | if not isinstance(all_logits, tuple): 208 | all_logits = (all_logits,) 209 | all_labels = (all_labels,) 210 | metrics = {} 211 | accuracy_f = load_metric("accuracy") 212 | for logits, labels, hc in zip(all_logits, all_labels, heads): 213 | labels = labels.ravel() 214 | mask = labels != IGNORE_INDEX # skip padded/unlabelled positions 215 | if getattr(hc, "problem_type", None) == ClassificationHead.MULTI: 216 | predictions = (logits > 0).ravel()[mask] 217 | else: 218 | predictions = np.argmax(logits, axis=-1).ravel()[mask] 219 | acc = accuracy_f.compute(predictions=predictions, references=labels[mask]) 220 | metrics[f"{hc.name}_accuracy"] = acc["accuracy"] 221 | return metrics 222 | 223 | trainer = MultiTaskTrainer( 224 | model=model, 225 | tokenizer=tokenizer, 226 | args=training_args, 227 | train_data=data[:, "train"], 228 | eval_data=data[["topicsstar", "action"], "train"], 229 | eval_heads=["lm", "star", "action_label"], 230 | compute_metrics=compute_metrics, 231 | ) 232 | with pytest.raises(ValueError): # action does not have star labels 233 | _train_res = trainer.train() 234 | trainer.eval_heads = {"topicsstar": ["star", "topic_cls", "mlm"], "action": ["action_label"]} 235 | _train_res = trainer.train() 236 | losses_not_nan(trainer) 237 | all_keys = {k for l in trainer.state.log_history for k in l.keys()} 238 | assert "eval_action_action_label_accuracy" in all_keys 239 | 240 | predictions = model.predict(dict(text="abc")) 241 | assert "star_predicted_label" in predictions.keys() 242 | assert "action_label_predicted_label" in predictions.keys() 243 | assert "topic_cls_predicted_labels" in predictions.keys() 244 | 245 | 246 | def test_train_mlm_mtd(dataset_regress, training_args): 247 | base_model = SMALL_MODEL 248 | tokenizer = AutoTokenizer.from_pretrained(base_model) 249 | fmt = DatasetFormatter().tokenize() 250 | data = fmt.apply(dataset_regress, tokenizer=tokenizer, test_size=0) 251 | training_args.evaluation_strategy = None 252 | 253 | for sep in [True, False]: 254 | for mlm, mtd, mtd_strategy in [(True, False, None), (False, True, "random"), (True, True, "token_similarity")]:
True, "token_similarity")]: 255 | head_configs = [ 256 | LMHeadConfig( 257 | masked_language_modelling=mlm, 258 | masked_token_detection=mtd, 259 | mtd_pos_weight=2.0, 260 | separate_embedding=sep, 261 | ), 262 | ClassificationHeadConfig.from_data(data, labels_var="y"), 263 | ] 264 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, tokenizer=tokenizer, formatter=fmt) 265 | trainer = MultiTaskTrainer( 266 | model=model, 267 | tokenizer=tokenizer, 268 | args=training_args, 269 | train_data=data[:, "train"], 270 | ) 271 | trainer.train() 272 | result = model.predict(dict(text="blabla")) 273 | assert result.keys() == {"y_predicted_value"} 274 | 275 | 276 | @pytest.mark.parametrize("base_model", [SMALL_MODEL, "facebook/opt-125m"]) 277 | @pytest.mark.parametrize("padding_side", ["left", "right"]) 278 | def test_train_clm(dataset_regress, training_args, base_model, padding_side): 279 | tokenizer = AutoTokenizer.from_pretrained(base_model) 280 | tokenizer.padding_side = padding_side 281 | fmt = DatasetFormatter().tokenize() 282 | data = fmt.apply(dataset_regress, tokenizer=tokenizer, test_size=0) 283 | training_args.evaluation_strategy = None 284 | 285 | head_configs = [ 286 | LMHeadConfig(causal_language_modelling=True), 287 | ClassificationHeadConfig.from_data(data, labels_var="y", pooling_method="last", classifier_hidden_size=3), 288 | ClassificationHeadConfig.from_data(data, name="yauto", labels_var="y", classifier_hidden_size=3), 289 | ] 290 | model = AutoMultiTaskModel.from_pretrained(base_model, head_configs, tokenizer=tokenizer, formatter=fmt) 291 | trainer = MultiTaskTrainer( 292 | model=model, 293 | tokenizer=tokenizer, 294 | args=training_args, 295 | train_data=data[:, "train"], 296 | ) 297 | trainer.train() 298 | result = model.predict(dict(text="blabla")) 299 | assert result.keys() == {"y_predicted_value", "yauto_predicted_value"} 300 | 301 | 302 | def test_train_forgot_encode(dataset_multiclass_topics_star, training_args): 303 | tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL) 304 | fmt = DatasetFormatter().tokenize() # .encode("string_label") 305 | data = fmt.apply(dataset_multiclass_topics_star, tokenizer=tokenizer, test_size=0) 306 | 307 | head_configs = [ 308 | ClassificationHeadConfig.from_data(data, labels_var="string_label", classifier_hidden_size=1), 309 | ] 310 | model = AutoMultiTaskModel.from_pretrained(SMALL_MODEL, head_configs, tokenizer=tokenizer, formatter=fmt) 311 | training_args.evaluation_strategy = None 312 | trainer = MultiTaskTrainer( 313 | model=model, 314 | tokenizer=tokenizer, 315 | args=training_args, 316 | train_data=data[:, "train"], 317 | ) 318 | with pytest.raises(ModelInferenceError): 319 | trainer.train() 320 | --------------------------------------------------------------------------------