├── .github └── workflows │ ├── publish-package.yml │ ├── pylint.yml │ └── pytest.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── pyproject.toml ├── requirements.txt ├── resources ├── generation_workflow.png ├── logo_fabricator.drawio_dark.png └── logo_fabricator.drawio_white.png ├── setup.py ├── src └── fabricator │ ├── __init__.py │ ├── dataset_generator.py │ ├── dataset_transformations │ ├── __init__.py │ ├── question_answering.py │ ├── text_classification.py │ └── token_classification.py │ ├── prompts │ ├── __init__.py │ ├── base.py │ └── utils.py │ ├── samplers │ ├── __init__.py │ └── samplers.py │ └── utils.py ├── tests ├── test_dataset_generator.py ├── test_dataset_sampler.py ├── test_dataset_transformations.py └── test_prompts.py └── tutorials ├── TUTORIAL-1_OVERVIEW.md ├── TUTORIAL-2_GENERATION_WORKFLOWS.md └── TUTORIAL-3_ADVANCED-GENERATION.md /.github/workflows/publish-package.yml: -------------------------------------------------------------------------------- 1 | name: Publish Release to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | lint: 9 | name: Linting Checks 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.10" 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | python -m pip install -r requirements.txt 21 | pip install pylint 22 | - name: Analysing the code with pylint 23 | run: | 24 | pylint $(git ls-files '*.py') --fail-under 9.3 25 | 26 | build: 27 | name: Build distribution 28 | runs-on: ubuntu-latest 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: Set up Python 33 | uses: actions/setup-python@v4 34 | with: 35 | python-version: "3.10" 36 | - name: Install pypa/build 37 | run: >- 38 | python3 -m 39 | pip install 40 | build 41 | --user 42 | - name: Build binary wheel and source tarball 43 | run: python3 -m build 44 | - name: Store distribution packages 45 | uses: actions/upload-artifact@v3 46 | with: 47 | name: python-package-distributions 48 | path: dist/ 49 | 50 | publish-to-testpypi: 51 | name: Publish distribution to TestPyPI 52 | needs: 53 | - lint 54 | - build 55 | runs-on: ubuntu-latest 56 | 57 | environment: 58 | name: testpypi 59 | url: https://test.pypi.org/p/fabricator-ai 60 | 61 | permissions: 62 | contents: read 63 | id-token: write 64 | 65 | steps: 66 | - name: Download all the dists 67 | uses: actions/download-artifact@v3 68 | with: 69 | name: python-package-distributions 70 | path: dist/ 71 | - name: Publish distribution to TestPyPI 72 | uses: pypa/gh-action-pypi-publish@release/v1 73 | with: 74 | repository-url: https://test.pypi.org/legacy/ 75 | 76 | publish-to-pypi: 77 | name: >- 78 | Publish Release to PyPI 79 | if: startsWith(github.ref, 'refs/tags/') 80 | needs: 81 | - publish-to-testpypi 82 | runs-on: ubuntu-latest 83 | environment: 84 | name: pypi 85 | url: https://pypi.org/p/fabricator-ai 86 | permissions: 87 | contents: read 88 | id-token: write 89 | steps: 90 | - name: Download all the dists 91 | uses: actions/download-artifact@v3 92 | with: 93 | name: python-package-distributions 94 | path: dist/ 95 | - name: Publish distribution to PyPI 96 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | 
build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.9", "3.10"] 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | python -m pip install -r requirements.txt 21 | pip install pylint 22 | - name: Analysing the code with pylint 23 | run: | 24 | pylint $(git ls-files '*.py') --fail-under 9.3 25 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.8", "3.9", "3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install pytest pytest-sugar 31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 32 | - name: Test with pytest 33 | run: | 34 | python -m pytest 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Logs 156 | logs/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Fabricator Logo](resources/logo_fabricator.drawio_dark.png#gh-dark-mode-only) 2 | ![Fabricator Logo](resources/logo_fabricator.drawio_white.png#gh-light-mode-only) 3 | 4 |

A flexible open-source framework to generate datasets with large language models.

5 |


10 |
11 |
12 | 13 | [Installation](#installation) | [Basic Concepts](#basic-concepts) | [Examples](#examples) | [Tutorials](tutorials/TUTORIAL-1_OVERVIEW.md) | 14 | [Paper](https://arxiv.org/abs/2309.09582) | [Citation](#citation) 15 | 16 |
17 |
18 | 19 | ## News 20 | 21 | - **[10/23]** We released the first version of this repository on PyPI. You can install it via `pip install fabricator-ai`. 22 | - **[10/23]** Our paper was accepted at EMNLP 2023. You can find the preprint [here](https://arxiv.org/abs/2309.09582) and the experimental scripts under release v0.1.0. 23 | - **[09/23]** Support for `gpt-3.5-turbo-instruct` added in the new [Haystack](https://github.com/deepset-ai/haystack) release! 24 | - **[08/23]** Added several experimental scripts to investigate the generation and annotation ability of `gpt-3.5-turbo` on various downstream tasks, as well as the influence of few-shot examples on performance across different downstream tasks. 25 | - **[07/23]** Refactored major classes - you can now simply use our BasePrompt class to create your own customized prompts for every downstream task! 26 | - **[07/23]** Added dataset transformations for token classification to prompt LLMs with textual spans rather than with lists of tags. 27 | - **[06/23]** Initial version of fabricator supporting text classification and question answering tasks. 28 | 29 | ## Overview 30 | 31 | This repository: 32 | 33 | - is an easy-to-use open-source library to generate datasets with large language models. If you want to train 34 | a model on a specific domain / label distribution / downstream task, you can use this framework to generate 35 | a dataset for it. 36 | - builds on top of deepset's haystack and huggingface's datasets libraries. Thus, we support a wide range 37 | of language models, and you can load and use the generated datasets for model training just as you would 38 | any other dataset from the Datasets library. 39 | - is highly flexible and offers various adaptation options such as 40 | prompt customization, integration and sampling of fewshot examples, or annotation of unlabeled datasets. 41 | 42 | ## Installation 43 | Using conda: 44 | ``` 45 | git clone git@github.com:flairNLP/fabricator.git 46 | cd fabricator 47 | conda create -y -n fabricator python=3.10 48 | conda activate fabricator 49 | pip install fabricator-ai 50 | ``` 51 | 52 | If you want to install in editable mode, you can use the following command: 53 | ``` 54 | pip install -e . 55 | ``` 56 | 57 | ## Basic Concepts 58 | 59 | This framework is based on the idea of using large language models to generate datasets for specific tasks. To do so, 60 | we need four basic modules: a dataset, a prompt, a language model and a generator: 61 | - Dataset: We use [huggingface's datasets library](https://github.com/huggingface/datasets) to load fewshot or 62 | unlabeled datasets and store the generated or annotated datasets with their `Dataset` class. Once 63 | created, you can share the dataset with others via the hub or use it for your model training. 64 | - Prompt: A prompt is the instruction given to the language model. It can be a simple sentence or a more complex 65 | template with placeholders. We provide an easy interface for custom dataset generation prompts in which you can specify 66 | label options for the LLM to choose from, provide fewshot examples to support the prompt, or annotate an unlabeled 67 | dataset in a specific way. 68 | - LLM: We use [deepset's haystack library](https://github.com/deepset-ai/haystack) as our LLM interface. deepset 69 | supports a wide range of LLMs, including OpenAI models, all models from the HuggingFace model hub, and many more. 70 | - Generator: The generator is the core of this framework.
It takes a dataset, a prompt and a LLM and generates a 71 | dataset based on your specifications. 72 | 73 | ## Examples 74 | 75 | With our library, you can generate datasets for any task you want. You can start as simple 76 | as that: 77 | 78 | ### Generate a dataset from scratch 79 | 80 | ```python 81 | import os 82 | from haystack.nodes import PromptNode 83 | from fabricator import DatasetGenerator 84 | from fabricator.prompts import BasePrompt 85 | 86 | prompt = BasePrompt( 87 | task_description="Generate a short movie review.", 88 | ) 89 | 90 | prompt_node = PromptNode( 91 | model_name_or_path="gpt-3.5-turbo", 92 | api_key=os.environ.get("OPENAI_API_KEY"), 93 | max_length=100, 94 | ) 95 | 96 | generator = DatasetGenerator(prompt_node) 97 | generated_dataset = generator.generate( 98 | prompt_template=prompt, 99 | max_prompt_calls=10, 100 | ) 101 | 102 | generated_dataset.push_to_hub("your-first-generated-dataset") 103 | ``` 104 | 105 | In our tutorial, we introduce how to create classification datasets with label options to choose from, how to include 106 | fewshot examples or how to annotate unlabeled data into predefined categories. 107 | 108 | ## Citation 109 | 110 | If you find this repository useful, please cite our work. 111 | 112 | ``` 113 | @inproceedings{golde2023fabricator, 114 | title = "Fabricator: An Open Source Toolkit for Generating Labeled Training Data with Teacher {LLM}s", 115 | author = "Golde, Jonas and Haller, Patrick and Hamborg, Felix and Risch, Julian and Akbik, Alan", 116 | booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing: System Demonstrations", 117 | month = dec, 118 | year = "2023", 119 | address = "Singapore", 120 | publisher = "Association for Computational Linguistics", 121 | url = "https://aclanthology.org/2023.emnlp-demo.1", 122 | pages = "1--11", 123 | } 124 | ``` 125 | 126 | 127 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | 4 | [tool.pylint.format] 5 | max-line-length = 119 6 | 7 | [tool.pylint.MASTER] 8 | ignore-patterns = '''test_.*\.py, 9 | setup.py 10 | ''' 11 | ignore-paths = [ "^paper_experiments/.*$"] 12 | 13 | [tool.pylint."MESSAGE CONTROL"] 14 | disable = '''logging-fstring-interpolation, 15 | missing-module-docstring, 16 | ''' 17 | 18 | [tool.pytest.ini_options] 19 | pythonpath = [ 20 | "src" 21 | ] 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | farm-haystack>=1.18.0 3 | loguru 4 | -------------------------------------------------------------------------------- /resources/generation_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flairNLP/fabricator/8dc23321a67ad03bef371a560452f0e9b6ff5afa/resources/generation_workflow.png -------------------------------------------------------------------------------- /resources/logo_fabricator.drawio_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flairNLP/fabricator/8dc23321a67ad03bef371a560452f0e9b6ff5afa/resources/logo_fabricator.drawio_dark.png -------------------------------------------------------------------------------- /resources/logo_fabricator.drawio_white.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/flairNLP/fabricator/8dc23321a67ad03bef371a560452f0e9b6ff5afa/resources/logo_fabricator.drawio_white.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup, find_packages 3 | 4 | 5 | def requirements(): 6 | with open("requirements.txt", "r") as f: 7 | return f.read().splitlines() 8 | 9 | 10 | setup( 11 | name='fabricator-ai', 12 | version='0.2.0', 13 | author='Humboldt University Berlin, deepset GmbH', 14 | author_email="goldejon@informatik.hu-berlin.de", 15 | description='Conveniently generating datasets with large language models.', 16 | long_description=Path("README.md").read_text(encoding="utf-8"), 17 | long_description_content_type="text/markdown", 18 | package_dir={"": "src"}, 19 | packages=find_packages("src"), 20 | license="Apache 2.0", 21 | python_requires=">=3.8", 22 | install_requires=requirements(), 23 | ) 24 | -------------------------------------------------------------------------------- /src/fabricator/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | 3 | from .prompts import ( 4 | BasePrompt, 5 | infer_prompt_from_dataset, 6 | infer_prompt_from_task_template 7 | ) 8 | from .dataset_transformations import * 9 | from .samplers import * 10 | from .dataset_generator import DatasetGenerator 11 | -------------------------------------------------------------------------------- /src/fabricator/dataset_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | 4 | from pathlib import Path 5 | from collections import defaultdict 6 | from typing import Any, Callable, Dict, Optional, Union, Tuple, List 7 | from tqdm import tqdm 8 | from loguru import logger 9 | 10 | from datasets import Dataset 11 | from numpy.random import choice 12 | from haystack.nodes import PromptNode 13 | from haystack.nodes import PromptTemplate as HaystackPromptTemplate 14 | 15 | from .prompts import BasePrompt 16 | from .samplers import single_label_stratified_sample 17 | from .utils import log_dir, create_timestamp_path 18 | 19 | 20 | class DatasetGenerator: 21 | """The DatasetGenerator class is the main class of the fabricator package. 22 | It generates datasets based on a prompt template. The main function is generate().""" 23 | 24 | def __init__(self, prompt_node: PromptNode, max_tries: int = 10): 25 | """Initialize the DatasetGenerator with a prompt node. 26 | 27 | Args: 28 | prompt_node (PromptNode): Prompt node / LLM from haystack. 29 | """ 30 | self.prompt_node = prompt_node 31 | self._base_log_dir = log_dir() 32 | self._max_tries = max_tries 33 | 34 | def _setup_log(self, prompt_template: BasePrompt) -> Path: 35 | """For every generation run create a new log file. 36 | Current format: _.jsonl 37 | 38 | Args: 39 | prompt_template (BasePrompt): Prompt template to generate the dataset with. 40 | 41 | Returns: 42 | Path: Path to the log file. 
43 | 44 | """ 45 | timestamp_path = create_timestamp_path(self._base_log_dir) 46 | log_file = Path(f"{timestamp_path}_{prompt_template.__class__.__name__}.jsonl") 47 | log_file.parent.mkdir(parents=True, exist_ok=True) 48 | log_file.touch() 49 | return log_file 50 | 51 | def generate( 52 | self, 53 | prompt_template: BasePrompt, 54 | fewshot_dataset: Optional[Dataset] = None, 55 | fewshot_sampling_strategy: Optional[str] = None, 56 | fewshot_examples_per_class: int = None, 57 | fewshot_sampling_column: Optional[str] = None, 58 | unlabeled_dataset: Optional[Dataset] = None, 59 | return_unlabeled_dataset: bool = False, 60 | max_prompt_calls: int = 10, 61 | num_samples_to_generate: int = 10, 62 | timeout_per_prompt: Optional[int] = None, 63 | log_every_n_api_calls: int = 25, 64 | dummy_response: Optional[Union[str, Callable]] = None 65 | ) -> Union[Dataset, Tuple[Dataset, Dataset]]: 66 | """Generate a dataset based on a prompt template and support examples. 67 | Optionally, unlabeled examples can be provided to annotate unlabeled data. 68 | 69 | Args: 70 | prompt_template (BasePrompt): Prompt template to generate the dataset with. 71 | fewshot_dataset (Dataset): Support examples to generate the dataset from. Defaults to None. 72 | fewshot_sampling_strategy (str, optional): Sampling strategy for support examples. 73 | Defaults to None and means all fewshot examples are used or limited by number of 74 | fewshot_examples_per_class. 75 | "uniform" sampling strategy means that fewshot examples for a uniformly sampled label are used. 76 | "stratified" sampling strategy means that fewshot examples uniformly selected from each label. 77 | fewshot_examples_per_class (int, optional): Number of support examples for a certain class per prompt. 78 | Defaults to None. 79 | fewshot_sampling_column (str, optional): Column to sample from. Defaults to None and function will try 80 | to sample from the generate_data_for_column attribute of the prompt template. 81 | unlabeled_dataset (Optional[Dataset], optional): Unlabeled examples to annotate. Defaults to None. 82 | return_unlabeled_dataset (bool, optional): Whether to return the original dataset. Defaults to False. 83 | max_prompt_calls (int, optional): Maximum number of prompt calls. Defaults to 10. 84 | num_samples_to_generate (int, optional): Number of samples to generate. Defaults to 10. 85 | timeout_per_prompt (Optional[int], optional): Timeout per prompt call. Defaults to None. 86 | log_every_n_api_calls (int, optional): Log every n api calls. Defaults to 25. 87 | dummy_response (Optional[Union[str, Callable]], optional): Dummy response for dry runs. Defaults to None. 88 | 89 | Returns: 90 | Union[Dataset, Tuple[Dataset, Dataset]]: Generated dataset or tuple of generated dataset and original 91 | dataset. 
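Example (illustrative sketch; variable names such as prompt, prompt_node and unlabeled are assumptions, see the README for a full setup):
    generator = DatasetGenerator(prompt_node)
    generated = generator.generate(prompt_template=prompt, max_prompt_calls=10)
    # annotate an existing dataset and also return the original examples
    generated, original = generator.generate(prompt_template=prompt, unlabeled_dataset=unlabeled, return_unlabeled_dataset=True)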
92 | """ 93 | if fewshot_dataset: 94 | self._assert_fewshot_dataset_matches_prompt(prompt_template, fewshot_dataset) 95 | 96 | assert fewshot_sampling_strategy in [None, "uniform", "stratified"], \ 97 | "Sampling strategy must be 'uniform' or 'stratified'" 98 | 99 | if fewshot_dataset and not fewshot_sampling_column: 100 | fewshot_sampling_column = prompt_template.generate_data_for_column[0] 101 | 102 | generated_dataset, original_dataset = self._inner_generate_loop( 103 | prompt_template, 104 | fewshot_dataset, 105 | fewshot_examples_per_class, 106 | fewshot_sampling_strategy, 107 | fewshot_sampling_column, 108 | unlabeled_dataset, 109 | return_unlabeled_dataset, 110 | max_prompt_calls, 111 | num_samples_to_generate, 112 | timeout_per_prompt, 113 | log_every_n_api_calls, 114 | dummy_response 115 | ) 116 | 117 | if return_unlabeled_dataset: 118 | return generated_dataset, original_dataset 119 | 120 | return generated_dataset 121 | 122 | def _try_generate( 123 | self, prompt_text: str, invocation_context: Dict, dummy_response: Optional[Union[str, Callable]] 124 | ) -> Optional[str]: 125 | """Tries to generate a single example. Restrict the time spent on this. 126 | 127 | Args: 128 | prompt_text: Prompt text to generate an example for. 129 | invocation_context: Invocation context to generate an example for. 130 | dry_run: Whether to actually generate the example or just return a dummy example. 131 | 132 | Returns: 133 | Generated example 134 | """ 135 | 136 | if dummy_response: 137 | 138 | if isinstance(dummy_response, str): 139 | logger.info(f"Returning dummy response: {dummy_response}") 140 | return dummy_response 141 | 142 | if callable(dummy_response): 143 | dummy_value = dummy_response(prompt_text) 144 | logger.info(f"Returning dummy response: {dummy_response}") 145 | return dummy_value 146 | 147 | raise ValueError("Dummy response must be a string or a callable") 148 | 149 | # Haystack internally uses timeouts and retries, so we dont have to do it 150 | # We dont catch authentification errors here, because we want to fail fast 151 | try: 152 | prediction = self.prompt_node.run( 153 | prompt_template=HaystackPromptTemplate(prompt=prompt_text), 154 | invocation_context=invocation_context, 155 | )[0]["results"] 156 | except Exception as error: 157 | logger.error(f"Error while generating example: {error}") 158 | return None 159 | 160 | return prediction 161 | 162 | def _inner_generate_loop( 163 | self, 164 | prompt_template: BasePrompt, 165 | fewshot_dataset: Dataset, 166 | fewshot_examples_per_class: int, 167 | fewshot_sampling_strategy: str, 168 | fewshot_sampling_column: str, 169 | unlabeled_dataset: Dataset, 170 | return_unlabeled_dataset: bool, 171 | max_prompt_calls: int, 172 | num_samples_to_generate: int, 173 | timeout_per_prompt: Optional[int], 174 | log_every_n_api_calls: int = 25, 175 | dummy_response: Optional[Union[str, Callable]] = None 176 | ): 177 | current_tries_left = self._max_tries 178 | current_log_file = self._setup_log(prompt_template) 179 | 180 | generated_dataset = defaultdict(list) 181 | original_dataset = defaultdict(list) 182 | 183 | if unlabeled_dataset: 184 | api_calls = range(min(max_prompt_calls, len(unlabeled_dataset))) 185 | else: 186 | api_calls = range(min(max_prompt_calls, num_samples_to_generate)) 187 | 188 | for prompt_call_idx, unlabeled_example_idx in tqdm( 189 | enumerate(api_calls, start=1), desc="Generating dataset", total=len(api_calls) 190 | ): 191 | fewshot_examples = None 192 | unlabeled_example = None 193 | invocation_context = None 194 | 
prompt_labels = None 195 | 196 | if prompt_template.label_options: 197 | # At some point: how can we do label-conditioned generation without fewshot examples? Currently it 198 | # require a second parameter for sample from label options and not from fewshot examples 199 | prompt_labels = prompt_template.label_options 200 | 201 | if fewshot_dataset: 202 | prompt_labels, fewshot_examples = self._sample_fewshot_examples( 203 | prompt_template, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class, 204 | fewshot_sampling_column 205 | ) 206 | 207 | prompt_text = prompt_template.get_prompt_text(prompt_labels, fewshot_examples) 208 | 209 | if unlabeled_dataset: 210 | unlabeled_example = unlabeled_dataset[unlabeled_example_idx] 211 | invocation_context = prompt_template.filter_example_by_columns( 212 | unlabeled_example, prompt_template.fewshot_example_columns 213 | ) 214 | 215 | if log_every_n_api_calls > 0: 216 | if prompt_call_idx % log_every_n_api_calls == 0: 217 | logger.info( 218 | f"Current prompt call: {prompt_call_idx}: \n" 219 | f"Prompt: {prompt_text} \n" 220 | f"Invocation context: {invocation_context} \n" 221 | ) 222 | 223 | prediction = self._try_generate(prompt_text, invocation_context, dummy_response) 224 | 225 | if prediction is None: 226 | current_tries_left -= 1 227 | logger.warning(f"Could not generate example for prompt {prompt_text}.") 228 | if current_tries_left == 0: 229 | logger.warning( 230 | f"Max tries ({self._max_tries}) exceeded. Returning generated dataset with" 231 | f" {len(generated_dataset)} examples." 232 | ) 233 | break 234 | 235 | if len(prediction) == 1: 236 | prediction = prediction[0] 237 | 238 | # If we have a target variable, we re-use the relevant columns of the input example 239 | # and add the prediction to the generated dataset 240 | if prompt_template.generate_data_for_column and unlabeled_example: 241 | generated_sample = prompt_template.filter_example_by_columns( 242 | unlabeled_example, prompt_template.fewshot_example_columns 243 | ) 244 | 245 | for key, value in generated_sample.items(): 246 | generated_dataset[key].append(value) 247 | 248 | # Try to safely convert the prediction to the type of the target variable 249 | if not prompt_template.generate_data_for_column[0] in unlabeled_example: 250 | prediction = self._convert_prediction( 251 | prediction, type(prompt_template.generate_data_for_column[0]) 252 | ) 253 | 254 | generated_dataset[prompt_template.generate_data_for_column[0]].append(prediction) 255 | 256 | else: 257 | generated_dataset[prompt_template.DEFAULT_TEXT_COLUMN[0]].append(prediction) 258 | if prompt_labels and isinstance(prompt_labels, str): 259 | generated_dataset[prompt_template.DEFAULT_LABEL_COLUMN[0]].append(prompt_labels) 260 | 261 | log_entry = { 262 | "prompt": prompt_text, 263 | "invocation_context": invocation_context, 264 | "prediction": prediction, 265 | "target": prompt_template.generate_data_for_column[0] 266 | if prompt_template.generate_data_for_column 267 | else prompt_template.DEFAULT_TEXT_COLUMN[0], 268 | } 269 | with open(current_log_file, "a", encoding="utf-8") as log_file: 270 | log_file.write(f"{json.dumps(log_entry)}\n") 271 | 272 | if return_unlabeled_dataset: 273 | for key, value in unlabeled_example.items(): 274 | original_dataset[key].append(value) 275 | 276 | if prompt_call_idx >= max_prompt_calls: 277 | logger.info("Reached maximum number of prompt calls ({}).", max_prompt_calls) 278 | break 279 | 280 | if len(generated_dataset) >= num_samples_to_generate: 281 | 
logger.info("Generated {} samples.", num_samples_to_generate) 282 | break 283 | 284 | if timeout_per_prompt is not None: 285 | time.sleep(timeout_per_prompt) 286 | 287 | generated_dataset = Dataset.from_dict(generated_dataset) 288 | 289 | if return_unlabeled_dataset: 290 | original_dataset = Dataset.from_dict(original_dataset) 291 | return generated_dataset, original_dataset 292 | 293 | return generated_dataset, None 294 | 295 | def _convert_prediction(self, prediction: str, target_type: type) -> Any: 296 | """Converts a prediction to the target type. 297 | 298 | Args: 299 | prediction: Prediction to convert. 300 | target_type: Type to convert the prediction to. 301 | 302 | Returns: 303 | Converted prediction. 304 | """ 305 | 306 | if isinstance(prediction, target_type): 307 | return prediction 308 | 309 | try: 310 | return target_type(prediction) 311 | except ValueError: 312 | logger.warning( 313 | "Could not convert prediction {} to type {}. " 314 | "Returning original prediction.", repr(prediction), target_type 315 | ) 316 | return prediction 317 | 318 | @staticmethod 319 | def _sample_fewshot_examples( 320 | prompt_template: BasePrompt, 321 | fewshot_dataset: Dataset, 322 | fewshot_sampling_strategy: str, 323 | fewshot_examples_per_class: int, 324 | fewshot_sampling_column: str 325 | ) -> Tuple[Union[List[str], str], Dataset]: 326 | 327 | if fewshot_sampling_strategy == "uniform": 328 | prompt_labels = choice(prompt_template.label_options, 1)[0] 329 | fewshot_examples = fewshot_dataset.filter( 330 | lambda example: example[fewshot_sampling_column] == prompt_labels 331 | ).shuffle().select(range(fewshot_examples_per_class)) 332 | 333 | elif fewshot_sampling_strategy == "stratified": 334 | prompt_labels = prompt_template.label_options 335 | fewshot_examples = single_label_stratified_sample( 336 | fewshot_dataset, 337 | fewshot_sampling_column, 338 | fewshot_examples_per_class 339 | ) 340 | 341 | else: 342 | prompt_labels = prompt_template.label_options if prompt_template.label_options else None 343 | if fewshot_examples_per_class: 344 | fewshot_examples = fewshot_dataset.shuffle().select(range(fewshot_examples_per_class)) 345 | else: 346 | fewshot_examples = fewshot_dataset.shuffle() 347 | 348 | assert len(fewshot_examples) > 0, f"Could not find any fewshot examples for label(s) {prompt_labels}." \ 349 | f"Ensure that labels of fewshot examples match the label_options " \ 350 | f"from the prompt." 351 | 352 | return prompt_labels, fewshot_examples 353 | 354 | @staticmethod 355 | def _assert_fewshot_dataset_matches_prompt(prompt_template: BasePrompt, fewshot_dataset: Dataset) -> None: 356 | """Asserts that the prompt template is valid and all columns are present in the fewshot dataset.""" 357 | assert all( 358 | field in fewshot_dataset.column_names for field in prompt_template.relevant_columns_for_fewshot_examples 359 | ), "Not all required variables of the prompt template occur in the support examples." 
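# Illustrative dry-run sketch (not part of the library; variable names such as prompt and prompt_node are assumptions):
# passing dummy_response avoids real API calls by returning a fixed string (or the result of a callable that
# receives the prompt text) instead of querying the LLM, e.g.:
#     generator = DatasetGenerator(prompt_node)
#     generated = generator.generate(prompt_template=prompt, max_prompt_calls=5, num_samples_to_generate=5, dummy_response="A dummy movie review.")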
360 | -------------------------------------------------------------------------------- /src/fabricator/dataset_transformations/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "preprocess_squad_format", 3 | "postprocess_squad_format", 4 | "calculate_answer_start", 5 | "convert_label_ids_to_texts", 6 | "get_labels_from_dataset", 7 | "replace_class_labels", 8 | "convert_token_labels_to_spans", 9 | "convert_spans_to_token_labels", 10 | ] 11 | 12 | from .question_answering import preprocess_squad_format, postprocess_squad_format, calculate_answer_start 13 | from .text_classification import convert_label_ids_to_texts, get_labels_from_dataset, replace_class_labels 14 | from .token_classification import convert_token_labels_to_spans, convert_spans_to_token_labels 15 | -------------------------------------------------------------------------------- /src/fabricator/dataset_transformations/question_answering.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset 2 | from loguru import logger 3 | 4 | 5 | def preprocess_squad_format(dataset: Dataset) -> Dataset: 6 | """Preprocesses a dataset in SQuAD format (nested answers) to a dataset in SQuAD format that has flat answers. 7 | {"answer": {"text": "answer", "start": 0}} -> {"text": "answer"} 8 | 9 | Args: 10 | dataset (Dataset): A huggingface dataset in SQuAD format. 11 | 12 | Returns: 13 | Dataset: A huggingface dataset in SQuAD format with flat answers. 14 | """ 15 | 16 | def preprocess(example): 17 | if example["answers"]: 18 | example["answers"] = example["answers"].pop() 19 | else: 20 | example["answers"] = "" 21 | return example 22 | 23 | dataset = dataset.flatten().rename_column("answers.text", "answers").map(preprocess) 24 | return dataset 25 | 26 | 27 | def postprocess_squad_format(dataset: Dataset, add_answer_start: bool = True) -> Dataset: 28 | """Postprocesses a dataset in SQuAD format (flat answers) to a dataset in SQuAD format that has nested answers. 29 | {"text": "answer"} -> {"answer": {"text": "answer", "start": 0}} 30 | 31 | Args: 32 | dataset (Dataset): A huggingface dataset in SQuAD format. 33 | add_answer_start (bool, optional): Whether to add the answer start index to the dataset. Defaults to True. 34 | 35 | Returns: 36 | Dataset: A huggingface dataset in SQuAD format with nested answers. 37 | """ 38 | # remove punctuation and whitespace from the start and end of the answer 39 | def remove_punctuation(example): 40 | example["answers"] = example["answers"].strip(".,;!? ") 41 | return example 42 | 43 | dataset = dataset.map(remove_punctuation) 44 | 45 | if add_answer_start: 46 | dataset = dataset.map(calculate_answer_start) 47 | 48 | def unify_answers(example): 49 | is_answerable = "answer_start" in example 50 | if is_answerable: 51 | example["answers"] = {"text": [example["answers"]], "answer_start": [example["answer_start"]]} 52 | else: 53 | example["answers"] = {"text": [], "answer_start": []} 54 | return example 55 | 56 | dataset = dataset.map(unify_answers) 57 | if "answer_start" in dataset.column_names: 58 | dataset = dataset.remove_columns("answer_start") 59 | return dataset 60 | 61 | 62 | def calculate_answer_start(example): 63 | """Calculates the answer start index for a SQuAD example. 64 | 65 | Args: 66 | example (Dict): A SQuAD example. 67 | 68 | Returns: 69 | Dict: The SQuAD example with the answer start index added. 
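Example (illustrative): for {"context": "Paris is the capital of France.", "answers": "Paris"}, the
function adds "answer_start": 0; if the answer is missing from the context or occurs more than once,
"answer_start" is set to -1.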
70 | """ 71 | answer_start = example["context"].lower().find(example["answers"].lower()) 72 | if answer_start < 0: 73 | logger.info( 74 | 'Could not calculate the answer start because the context "{}" ' 'does not contain the answer "{}".', 75 | example["context"], 76 | example["answers"], 77 | ) 78 | answer_start = -1 79 | else: 80 | # check that the answer doesn't occur more than once in the context 81 | second_answer_start = example["context"].lower().find(example["answers"].lower(), answer_start + 1) 82 | if second_answer_start >= 0: 83 | logger.info("Could not calculate the answer start because the context contains the answer more than once.") 84 | answer_start = -1 85 | else: 86 | # correct potential wrong capitalization of the answer compared to the context 87 | example["answers"] = example["context"][answer_start : answer_start + len(example["answers"])] 88 | example["answer_start"] = answer_start 89 | return example 90 | -------------------------------------------------------------------------------- /src/fabricator/dataset_transformations/text_classification.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Union 2 | 3 | from datasets import Dataset, DatasetDict, ClassLabel, Sequence 4 | 5 | 6 | def get_labels_from_dataset(dataset: Union[Dataset, DatasetDict], label_column: str) -> List[str]: 7 | """Gets the list of labels from a huggingface Dataset. 8 | 9 | Args: 10 | dataset (Union[Dataset, DatasetDict]): huggingface Dataset 11 | label_column (str): name of the column with the labels 12 | 13 | Returns: 14 | List[str]: list of labels 15 | """ 16 | if isinstance(dataset, DatasetDict): 17 | tmp_ref_dataset = dataset["train"] 18 | else: 19 | tmp_ref_dataset = dataset 20 | 21 | if isinstance(tmp_ref_dataset.features[label_column], ClassLabel): 22 | features = tmp_ref_dataset.features[label_column] 23 | elif isinstance(tmp_ref_dataset.features[label_column], Sequence): 24 | features = tmp_ref_dataset.features[label_column].feature 25 | else: 26 | raise ValueError(f"Label column {label_column} is not of type ClassLabel or Sequence.") 27 | return features.names 28 | 29 | 30 | def replace_class_labels(id2label: Union[Dict[str, str], Dict[int, str]], expanded_labels: Dict) -> Dict: 31 | """Replaces class labels with expanded labels, i.e. label LOC should be expanded to LOCATION. 32 | Values of id2label need to match keys of expanded_labels. 
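Example (illustrative): id2label {0: "LOC", 1: "PER"} with expanded_labels {"LOC": "LOCATION", "PER": "PERSON"}
returns {0: "LOCATION", 1: "PERSON"}.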
33 | 34 | Args: 35 | id2label (Dict): mapping from label ids to label names 36 | expanded_labels (Dict): mapping from label names or label ids to expanded label names 37 | 38 | Returns: 39 | Dict: mapping from label ids to label names with expanded labels 40 | """ 41 | if all(isinstance(k, int) for k in expanded_labels.keys()): 42 | type_keys = "int" 43 | elif all(isinstance(k, str) for k in expanded_labels.keys()): 44 | type_keys = "str" 45 | else: 46 | raise ValueError("Keys of expanded_labels must be either all ints or all strings.") 47 | 48 | if not all(isinstance(v, str) for v in expanded_labels.values()): 49 | raise ValueError("Values of expanded_labels must be strings.") 50 | 51 | if type_keys == "str": 52 | replaced_id2label = {} 53 | for idx, tag in id2label.items(): 54 | if tag in expanded_labels: 55 | replaced_id2label[idx] = expanded_labels[tag] 56 | else: 57 | replaced_id2label[idx] = tag 58 | else: 59 | replaced_id2label = expanded_labels 60 | return replaced_id2label 61 | 62 | 63 | def convert_label_ids_to_texts( 64 | dataset: Union[Dataset, DatasetDict], 65 | label_column: str, 66 | expanded_label_mapping: Dict = None, 67 | return_label_options: bool = False, 68 | ) -> Union[Dataset, DatasetDict, Tuple[Union[Dataset, DatasetDict], List[str]]]: 69 | """Converts label IDs to natural language labels for any classification problem with a single label such as text 70 | classification. Note that if the function is not applied to a Dataset, the label column will contain the IDs. 71 | If the function is applied, the label column will contain the natural language labels. 72 | 73 | Args: 74 | dataset (Dataset): huggingface Dataset with label ids. 75 | label_column (str): name of the label column. 76 | expanded_label_mapping (Dict, optional): dictionary mapping label ids to natural language labels. 77 | Defaults to None. 78 | return_label_options (bool, optional): whether to return the list of possible labels. Defaults to False. 79 | 80 | Returns: 81 | Tuple[Dataset, List[str]]: huggingface Dataset with natural language labels and list of natural language 82 | labels. 83 | """ 84 | labels = get_labels_from_dataset(dataset, label_column) 85 | id2label = dict(enumerate(labels)) 86 | 87 | if expanded_label_mapping is not None: 88 | id2label = replace_class_labels(id2label, expanded_label_mapping) 89 | 90 | new_label_column = f"{label_column}_natural_language" 91 | label_options = list(id2label.values()) 92 | 93 | def labels_to_natural_language(examples): 94 | examples[new_label_column] = id2label[examples[label_column]] 95 | return examples 96 | 97 | dataset = ( 98 | dataset.map(labels_to_natural_language) 99 | .remove_columns(label_column) 100 | .rename_column(new_label_column, label_column) 101 | ) 102 | 103 | if return_label_options: 104 | return dataset, label_options 105 | 106 | return dataset 107 | -------------------------------------------------------------------------------- /src/fabricator/dataset_transformations/token_classification.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, List, Tuple, Union 3 | from datasets import Dataset, Sequence 4 | 5 | from loguru import logger 6 | 7 | # These are fixed for encoding the prompt and decoding the output of the LLM 8 | SPAN_ANNOTATION_TEMPLATE = "{entity} is {label} entity." 9 | SPAN_ANNOTATION_REGEX = r'(.+) is (.+) entity\.' 
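# Illustrative example: SPAN_ANNOTATION_TEMPLATE.format(entity="Berlin", label="LOCATION") produces
# "Berlin is LOCATION entity.", and re.match(SPAN_ANNOTATION_REGEX, ...) on such a line recovers the
# groups ("Berlin", "LOCATION") when decoding the LLM output.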
10 | 11 | 12 | def convert_token_labels_to_spans( 13 | dataset: Dataset, 14 | token_column: str, 15 | label_column: str, 16 | expanded_label_mapping: Dict = None, 17 | return_label_options: bool = False 18 | ) -> Union[Dataset, Tuple[Dataset, List[str]]]: 19 | """Converts token level labels to spans. Useful for NER tasks to prompt the LLM with natural language labels. 20 | 21 | Args: 22 | dataset (Dataset): huggingface Dataset with token level labels 23 | token_column (str): name of the column with the tokens 24 | label_column (str): name of the column with the token level labels 25 | expanded_label_mapping (Dict): mapping from label ids to label names. Defaults to None. 26 | return_label_options (bool): whether to return a list of all possible annotations of the provided dataset 27 | 28 | Returns: 29 | Tuple[Dataset, List[str]]: huggingface Dataset with span labels and list of possible labels for the prompt 30 | """ 31 | if expanded_label_mapping: 32 | if not len(expanded_label_mapping) == len(dataset.features[label_column].feature.names): 33 | raise ValueError( 34 | f"Length of expanded label mapping and original number of labels in dataset do not match.\n" 35 | f"Original labels: {dataset.features[label_column].feature.names}" 36 | f"Expanded labels: {list(expanded_label_mapping.values())}" 37 | ) 38 | id2label = expanded_label_mapping 39 | elif isinstance(dataset.features[label_column], Sequence): 40 | id2label = dict(enumerate(dataset.features[label_column].feature.names)) 41 | else: 42 | raise ValueError("Labels must be a Sequence feature or expanded_label_mapping must be provided.") 43 | 44 | span_column = "span_annotations" 45 | 46 | def labels_to_spans(example): 47 | span_annotations = [id2label.get(label).replace("B-", "").replace("I-", "") for label in example[label_column]] 48 | 49 | annotations_for_prompt = "" 50 | 51 | current_entity = None 52 | current_entity_type = None 53 | for idx, span_annotation in enumerate(span_annotations): 54 | if span_annotation == "O": 55 | if current_entity is not None: 56 | annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity, 57 | label=current_entity_type) + "\n" 58 | current_entity = None 59 | current_entity_type = None 60 | continue 61 | if current_entity is None: 62 | current_entity = example[token_column][idx] 63 | current_entity_type = span_annotation 64 | continue 65 | if current_entity_type == span_annotation: 66 | current_entity += " " + example[token_column][idx] 67 | else: 68 | annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity, 69 | label=current_entity_type) + "\n" 70 | current_entity = example[token_column][idx] 71 | current_entity_type = span_annotation 72 | 73 | if current_entity is not None: 74 | annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity, 75 | label=current_entity_type) + "\n" 76 | 77 | example[token_column] = " ".join(example[token_column]) 78 | example[span_column] = annotations_for_prompt.rstrip("\n") 79 | return example 80 | 81 | dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(span_column, label_column) 82 | 83 | if return_label_options: 84 | # Spans have implicit BIO format, so sequences come in BIO format, we can ignore it 85 | label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()}) 86 | 87 | # Ignore "outside" tokens 88 | if "O" in label_options: 89 | label_options.remove("O") 90 | 91 | return dataset, label_options 92 | 93 | return dataset 94 | 95 | 96 | 
def convert_spans_to_token_labels( 97 | dataset: Dataset, 98 | token_column: str, 99 | label_column: str, 100 | id2label: Dict, 101 | annotate_identical_words: bool = False 102 | ) -> Dataset: 103 | """Converts span level labels to token level labels. 104 | First, the function extracts all entities with its annotated types. 105 | Second, if annotations are present, the function converts them to a tag sequence in BIO format. 106 | If not present, simply return tag sequence of O-tokens. 107 | This is useful for NER tasks to decode the output of the LLM. 108 | 109 | Args: 110 | dataset (Dataset): huggingface Dataset with span level labels 111 | token_column (str): name of the column with the tokens 112 | label_column (str): name of the column with the span level labels 113 | id2label (Dict): mapping from label ids to label names 114 | annotate_identical_words (bool): whether to annotate all identical words in a sentence with a found entity 115 | type 116 | 117 | Returns: 118 | Dataset: huggingface Dataset with token level labels in BIO format 119 | """ 120 | new_label_column = "sequence_tags" 121 | lower_label2id = {label.lower(): idx for idx, label in id2label.items()} 122 | 123 | def labels_to_spans(example): 124 | span_annotations = example[label_column].split("\n") 125 | 126 | ner_tag_tuples = [] 127 | 128 | for span_annotation in span_annotations: 129 | matches = re.match(SPAN_ANNOTATION_REGEX, span_annotation) 130 | if matches: 131 | matched_entity = matches.group(1) 132 | matched_label = matches.group(2) 133 | 134 | span_tokens = matched_entity.split(" ") 135 | span_labels = ["B-" + matched_label if idx == 0 else "B-" + matched_label.lower() 136 | for idx, _ in enumerate(span_tokens)] 137 | 138 | for token, label in zip(span_tokens, span_labels): 139 | label_id = lower_label2id.get(label.lower()) 140 | if label_id is None: 141 | logger.info(f"Entity {token} with label {label} is not in id2label: {id2label}.") 142 | else: 143 | ner_tag_tuples.append((token, label_id)) 144 | else: 145 | pass 146 | 147 | if ner_tag_tuples: 148 | lower_tokens = example[token_column].lower().split(" ") 149 | # initialize all tokens with O type 150 | ner_tags = [0] * len(lower_tokens) 151 | for reference_token, entity_type_id in ner_tag_tuples: 152 | if lower_tokens.count(reference_token.lower()) == 0: 153 | logger.info( 154 | f"Entity {reference_token} is not found or occurs more than once: {lower_tokens}. " 155 | f"Thus, setting label to O." 156 | ) 157 | elif lower_tokens.count(reference_token.lower()) > 1: 158 | if annotate_identical_words: 159 | insert_at_idxs = [index for index, value in enumerate(lower_tokens) 160 | if value == reference_token.lower()] 161 | for insert_at_idx in insert_at_idxs: 162 | ner_tags[insert_at_idx] = entity_type_id 163 | else: 164 | logger.info( 165 | f"Entity {reference_token} occurs more than once: {lower_tokens}. " 166 | f"Thus, setting label to O." 
167 | ) 168 | else: 169 | insert_at_idx = lower_tokens.index(reference_token.lower()) 170 | ner_tags[insert_at_idx] = entity_type_id 171 | else: 172 | ner_tags = [0] * len(example[token_column].split(" ")) 173 | 174 | example[token_column] = example[token_column].split(" ") 175 | example[new_label_column] = ner_tags 176 | 177 | return example 178 | 179 | dataset = ( 180 | dataset.map(labels_to_spans) 181 | .remove_columns(label_column) 182 | .rename_column(new_label_column, label_column) 183 | ) 184 | 185 | return dataset 186 | 187 | 188 | def replace_token_labels(id2label: Dict, expanded_labels: Dict) -> Dict: 189 | """Replaces token level labels with expanded labels, i.e. label PER should be expanded to PERSON. 190 | Values of id2label need to match keys of expanded_labels. 191 | 192 | Args: 193 | id2label (Dict): mapping from label ids to label names 194 | expanded_labels (Dict): mapping from label names to expanded label names 195 | 196 | Returns: 197 | Dict: mapping from label ids to label names with expanded labels 198 | """ 199 | replaced_id2label = {} 200 | for idx, tag in id2label.items(): 201 | if tag.startswith("B-") or tag.startswith("I-"): 202 | prefix, label = tag.split("-", 1) 203 | if label in expanded_labels: 204 | new_label = expanded_labels[label] 205 | new_label_bio = f"{prefix}-{new_label}" 206 | replaced_id2label[idx] = new_label_bio 207 | else: 208 | replaced_id2label[idx] = tag 209 | else: 210 | replaced_id2label[idx] = tag 211 | return replaced_id2label 212 | -------------------------------------------------------------------------------- /src/fabricator/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "BasePrompt", 3 | "infer_prompt_from_dataset", 4 | "infer_prompt_from_task_template" 5 | ] 6 | 7 | from .base import BasePrompt 8 | from .utils import infer_prompt_from_dataset, infer_prompt_from_task_template 9 | -------------------------------------------------------------------------------- /src/fabricator/prompts/base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union, Optional 2 | 3 | from datasets import Dataset 4 | from loguru import logger 5 | 6 | 7 | class BasePrompt: 8 | """Base class for prompt generation. This class formats the prompt for the fewshot / support set examples 9 | and the target variable such that the dataset generator can simply put in the invocation context.""" 10 | 11 | DEFAULT_TEXT_COLUMN = ["text"] 12 | DEFAULT_LABEL_COLUMN = ["label"] 13 | 14 | def __init__( 15 | self, 16 | task_description: str, 17 | generate_data_for_column: Optional[str] = None, 18 | fewshot_example_columns: Optional[Union[List[str], str]] = None, 19 | label_options: Optional[List[str]] = None, 20 | fewshot_formatting_template: Optional[str] = None, 21 | target_formatting_template: Optional[str] = None, 22 | fewshot_example_separator: str = "\n\n", 23 | inner_fewshot_example_separator: str = "\n", 24 | ): 25 | """Base class for prompt generation. This class formats the prompt for the fewshot / support set examples. 26 | 27 | Args: 28 | task_description (Optional[str], optional): Task description for the prompt (prefix). 29 | generate_data_for_column (Optional[str], optional): The column name to generate data for. Defaults to None. 30 | fewshot_example_columns (Union[List[str], str]): List of strings or string of column names for the 31 | fewshot / support set examples. Defaults to None. 
32 | label_options (Optional[ClassificationOptions], optional): Label options for the LLM to choose from. 33 | Defaults to None. 34 | fewshot_formatting_template (Optional[str], optional): Template for formatting the fewshot / support set 35 | examples. Defaults to None. 36 | target_formatting_template (Optional[str], optional): Template for formatting the target variable. 37 | Defaults to None. 38 | fewshot_example_separator (str, optional): Separator between the fewshot / support set examples. 39 | Defaults to "\n\n". 40 | inner_fewshot_example_separator (str, optional): Separator in-between a single fewshot examples. 41 | Defaults to "\n". 42 | 43 | Raises: 44 | AttributeError: If label_options is not a dict or list 45 | KeyError: If the task_description cannot be formatted with the variable 'label_options' 46 | ValueError: You need specify either generate_data_for_column or 47 | generate_data_for_column + fewshot_example_columns. Only fewshot_example_columns is not supported. 48 | """ 49 | self.task_description = task_description 50 | 51 | if label_options: 52 | self._assert_task_description_is_formattable(task_description) 53 | self.label_options = label_options 54 | 55 | if isinstance(generate_data_for_column, str) and generate_data_for_column: 56 | generate_data_for_column = [generate_data_for_column] 57 | self.generate_data_for_column = generate_data_for_column 58 | 59 | if isinstance(fewshot_example_columns, str) and fewshot_example_columns: 60 | fewshot_example_columns = [fewshot_example_columns] 61 | self.fewshot_example_columns = fewshot_example_columns 62 | self.fewshot_example_separator = fewshot_example_separator 63 | self.inner_fewshot_example_separator = inner_fewshot_example_separator 64 | 65 | if fewshot_example_columns: 66 | self.relevant_columns_for_fewshot_examples = self.fewshot_example_columns + self.generate_data_for_column 67 | elif generate_data_for_column: 68 | self.relevant_columns_for_fewshot_examples = self.generate_data_for_column 69 | else: 70 | self.relevant_columns_for_fewshot_examples = None 71 | 72 | # Create prompt template for fewshot examples 73 | if self.relevant_columns_for_fewshot_examples: 74 | if fewshot_formatting_template is None: 75 | self.fewshot_prompt = self.inner_fewshot_example_separator.join( 76 | [f"{var}: {{{var}}}" for var in self.relevant_columns_for_fewshot_examples] 77 | ) 78 | else: 79 | self.fewshot_prompt = fewshot_formatting_template 80 | 81 | # Create format template for targets 82 | if target_formatting_template is None: 83 | self.target_formatting_template = self._infer_target_formatting_template() 84 | else: 85 | self.target_formatting_template = target_formatting_template 86 | 87 | logger.info(self._log_prompt()) 88 | 89 | @staticmethod 90 | def _assert_task_description_is_formattable(task_description: str) -> None: 91 | """Checks if task_description is formattable. 92 | 93 | Args: 94 | task_description (str): Task description for the prompt (prefix). 95 | """ 96 | if "testxyz" not in task_description.format("testxyz"): 97 | raise KeyError("If you provide label_options, you need the task_description to be formattable like" 98 | " 'Generate a {} text.'") 99 | 100 | def _infer_target_formatting_template(self) -> str: 101 | """Infer target formatting template from input columns and label column. 
102 | 103 | Returns: 104 | str: Target formatting template 105 | """ 106 | if self.generate_data_for_column and self.fewshot_example_columns: 107 | target_template = self.inner_fewshot_example_separator.join( 108 | [f"{var}: {{{var}}}" for var in self.fewshot_example_columns] + 109 | [f"{self.generate_data_for_column[0]}: "] 110 | ) 111 | 112 | elif self.generate_data_for_column and not self.fewshot_example_columns: 113 | target_template = f"{self.generate_data_for_column[0]}: " 114 | 115 | elif not self.generate_data_for_column and not self.fewshot_example_columns: 116 | target_template = f"{self.DEFAULT_TEXT_COLUMN[0]}: " 117 | 118 | else: 119 | raise ValueError("Either generate_data_for_column or generate_data_for_column + fewshot_example_columns " 120 | "must be provided to infer target template.") 121 | 122 | return target_template 123 | 124 | def _log_prompt(self) -> str: 125 | """Log prompt. 126 | 127 | Returns: 128 | str: Prompt text 129 | """ 130 | label = None 131 | fewshot_examples = None 132 | 133 | if self.label_options: 134 | label = "EXAMPLE LABEL" 135 | 136 | if self.relevant_columns_for_fewshot_examples: 137 | fewshot_examples = {} 138 | for column in self.relevant_columns_for_fewshot_examples: 139 | fewshot_examples[column] = [f"EXAMPLE TEXT FOR COLUMN {column}"] 140 | fewshot_examples = Dataset.from_dict(fewshot_examples) 141 | 142 | return "\nThe prompt to the LLM will be like:\n" + 10*"-" + "\n"\ 143 | + self.get_prompt_text(label, fewshot_examples) + "\n" + 10*"-" 144 | 145 | @staticmethod 146 | def filter_example_by_columns(example: Dict[str, str], columns: List[str]) -> Dict[str, str]: 147 | """Filter single example by columns. 148 | 149 | Args: 150 | example (Dict[str, str]): Example to filter 151 | columns (List[str]): Columns to keep 152 | 153 | Returns: 154 | Dict[str, str]: Filtered example 155 | """ 156 | filtered_example = {key: value for key, value in example.items() if key in columns} 157 | return filtered_example 158 | 159 | def filter_examples_by_columns(self, dataset: Dataset, columns: List[str]) -> List[Dict[str, str]]: 160 | """Filter examples by columns. 161 | 162 | Args: 163 | dataset (Dataset): Dataset to filter 164 | columns (List[str]): Columns to keep 165 | 166 | Returns: 167 | List[Dict[str, str]]: Filtered examples 168 | """ 169 | filtered_inputs = [] 170 | for example in dataset: 171 | filtered_inputs.append(self.filter_example_by_columns(example, columns)) 172 | return filtered_inputs 173 | 174 | def get_prompt_text(self, labels: Union[str, List[str]] = None, examples: Optional[Dataset] = None) -> str: 175 | """Get prompt text for the given examples. 176 | 177 | Args: 178 | labels (Union[str, List[str]], optional): Label(s) to use for the prompt. Defaults to None. 
179 | examples (Dataset): Examples to use for the prompt 180 | 181 | Returns: 182 | str: Prompt text 183 | """ 184 | if isinstance(labels, list): 185 | labels = ", ".join(labels) 186 | 187 | if labels: 188 | task_description = self.task_description.format(labels) 189 | else: 190 | task_description = self.task_description 191 | 192 | if examples: 193 | examples = self.filter_examples_by_columns(examples, self.relevant_columns_for_fewshot_examples) 194 | formatted_examples = [self.fewshot_prompt.format(**example) for example in examples] 195 | prompt_text = self.fewshot_example_separator.join( 196 | [task_description] + formatted_examples + [self.target_formatting_template] 197 | ) 198 | else: 199 | prompt_text = self.fewshot_example_separator.join( 200 | [task_description] + [self.target_formatting_template] 201 | ) 202 | return prompt_text 203 | -------------------------------------------------------------------------------- /src/fabricator/prompts/utils.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset, QuestionAnsweringExtractive, TextClassification, TaskTemplate 2 | from .base import BasePrompt 3 | 4 | DEFAULT_TEXT_CLASSIFICATION = "Classify the following texts exactly into one of the following categories: {" \ 5 | "}." 6 | DEFAULT_QA = "Given a context and a question, generate an answer that occurs exactly and only once in the text." 7 | 8 | 9 | def infer_prompt_from_task_template(task_template: TaskTemplate) -> BasePrompt: 10 | """Infer TextLabelPrompt or ClassLabelPrompt with correct parameters from a task template's metadata. 11 | 12 | Args: 13 | task_template (TaskTemplate): task template from which to infer the prompt. 14 | 15 | Returns: 16 | BasePrompt: with correct parameters. 17 | """ 18 | if isinstance(task_template, QuestionAnsweringExtractive): 19 | return BasePrompt( 20 | task_description=DEFAULT_QA, 21 | generate_data_for_column=task_template.answers_column, # assuming the dataset was preprocessed 22 | # with preprocess_squad_format, otherwise dataset.task_templates[0]["answers_column"] 23 | fewshot_example_columns=[task_template.context_column, task_template.question_column], 24 | ) 25 | 26 | if isinstance(task_template, TextClassification): 27 | return BasePrompt( 28 | task_description=DEFAULT_TEXT_CLASSIFICATION, 29 | generate_data_for_column=task_template.label_column, 30 | fewshot_example_columns=task_template.text_column, 31 | label_options=task_template.label_schema["labels"].names, 32 | ) 33 | 34 | raise ValueError( 35 | f"Automatic prompt is only supported for QuestionAnsweringExtractive and " 36 | f"TextClassification tasks but not for {type(task_template)}. You need to " 37 | f"specify the prompt manually." 38 | ) 39 | 40 | 41 | def infer_prompt_from_dataset(dataset: Dataset): 42 | """Infer TextLabelPrompt or ClassLabelPrompt with correct parameters from a dataset's metadata.""" 43 | if not dataset.task_templates: 44 | raise ValueError( 45 | "Dataset must have exactly one task template but there is none. You need to specify the " 46 | "prompt manually." 47 | ) 48 | 49 | if len(dataset.task_templates) > 1: 50 | raise ValueError( 51 | f"Automatic prompt is only supported for datasets with exactly one task template but yours " 52 | f"has {len(dataset.task_templates)}. You need to specify the prompt manually." 
53 | ) 54 | 55 | return infer_prompt_from_task_template(dataset.task_templates[0]) 56 | -------------------------------------------------------------------------------- /src/fabricator/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "single_label_task_sampler", 3 | "single_label_stratified_sample", 4 | "random_sampler", 5 | "ml_mc_sampler" 6 | ] 7 | 8 | from .samplers import single_label_task_sampler, single_label_stratified_sample, \ 9 | random_sampler, ml_mc_sampler 10 | -------------------------------------------------------------------------------- /src/fabricator/samplers/samplers.py: -------------------------------------------------------------------------------- 1 | """Sampling methods 2 | 3 | NOTE: All methods do not ensure, that all labels are contained in the samples. 4 | """ 5 | import random 6 | from typing import Dict, List, Set, Union, Tuple 7 | from collections import defaultdict, deque 8 | from itertools import cycle 9 | 10 | from datasets import ClassLabel, Dataset, Sequence, Value 11 | from loguru import logger 12 | from tqdm import tqdm 13 | 14 | 15 | def random_sampler(dataset: Dataset, num_examples: int) -> Dataset: 16 | """Random sampler""" 17 | return dataset.select(random.sample(range(len(dataset)), num_examples)) 18 | 19 | 20 | def single_label_task_sampler( 21 | dataset: Dataset, label_column: str, num_examples: int, return_unused_split: bool = False 22 | ) -> Dataset: 23 | """Sampler for single label tasks, like text classification 24 | 25 | Args: 26 | dataset: Dataset 27 | label_column: Name of the label column 28 | num_examples: Number of examples to sample 29 | return_unused_split: Whether to return the unused split 30 | 31 | Approach: 32 | num_examples > len(dataset): Samples all examples 33 | num_examples < len(dataset): Samples at least one example per label 34 | num_examples < len(dataset.features): Samples only num_examples and notify 35 | """ 36 | 37 | if num_examples > len(dataset): 38 | return dataset 39 | 40 | if "train" in dataset: 41 | dataset = dataset["train"] 42 | 43 | pbar = tqdm(total=num_examples, desc="Sampling") 44 | 45 | class_labels = _infer_class_labels(dataset, label_column) 46 | num_classes = len(class_labels) 47 | 48 | unique_classes_sampled = set() 49 | total_examples_sampled = 0 50 | 51 | sampled_indices = [] 52 | 53 | while total_examples_sampled < num_examples: 54 | # Lets try to be as random and possible and sample from the entire dataset 55 | idx = random.sample(range(len(dataset)), 1)[0] 56 | 57 | # Pass already sampled idx 58 | if idx in sampled_indices: 59 | continue 60 | 61 | sample = dataset.select([idx])[0] 62 | label = sample[label_column] 63 | 64 | # First sample at least one example per label 65 | if label not in unique_classes_sampled: 66 | unique_classes_sampled.add(label) 67 | sampled_indices.append(idx) 68 | total_examples_sampled += 1 69 | pbar.update(1) 70 | 71 | # Further sample if we collected at least one example per label 72 | elif len(sampled_indices) < num_examples and len(unique_classes_sampled) == num_classes: 73 | sampled_indices.append(idx) 74 | total_examples_sampled += 1 75 | pbar.update(1) 76 | 77 | if return_unused_split: 78 | unused_indices = list(set(range(len(dataset))) - set(sampled_indices)) 79 | return dataset.select(sampled_indices), dataset.select(unused_indices) 80 | 81 | return dataset.select(sampled_indices) 82 | 83 | 84 | def _alternate_classes(dataset: Dataset, column: str) -> Dataset: 85 | """Alternate the 
occurrence of each class in the dataset. 86 | 87 | Args: 88 | dataset: Dataset 89 | column: Name of the column to alternate the classes of 90 | 91 | Returns: 92 | Dataset with the classes alternated 93 | """ 94 | 95 | # Group the indices of each unique value in 'target' column 96 | targets = defaultdict(deque) 97 | for i, elem in enumerate(dataset[column]): 98 | targets[elem].append(i) 99 | 100 | # Create a cycle iterator from targets 101 | targets_cycle = cycle(targets.keys()) 102 | 103 | # Alternate the occurrence of each class 104 | alternate_indices = [] 105 | for target in targets_cycle: 106 | if not targets[target]: # If this class has no more indices, remove it from the cycle 107 | del targets[target] 108 | else: # Otherwise, add the next index of this class to the list 109 | alternate_indices.append(targets[target].popleft()) 110 | 111 | # If there are no more indices, break the loop 112 | if not targets: 113 | break 114 | 115 | # Create new dataset from the alternate indices 116 | alternate_dataset = dataset.select(alternate_indices) 117 | 118 | return alternate_dataset 119 | 120 | 121 | def single_label_stratified_sample( 122 | dataset: Dataset, 123 | label_column: str, 124 | num_examples_per_class: int, 125 | return_unused_split: bool = False 126 | ) -> Union[Dataset, Tuple[Dataset, Dataset]]: 127 | """Stratified sampling for single label tasks, like text classification. 128 | 129 | Args: 130 | dataset: Dataset 131 | label_column: Name of the label column 132 | num_examples_per_class: Number of examples to sample per class 133 | return_unused_split: If True, return the unused split of the dataset 134 | 135 | Returns: 136 | Dataset: Stratified sample of the dataset 137 | """ 138 | # Ensure the 'k' value is valid 139 | if num_examples_per_class <= 0: 140 | raise ValueError("'num_examples_per_class' should be a positive integer.") 141 | 142 | # Group the indices of each unique value in 'target' column 143 | targets = defaultdict(list) 144 | for i, elem in enumerate(dataset[label_column]): 145 | targets[elem].append(i) 146 | 147 | # Check if k is smaller or equal than the size of the smallest group 148 | if num_examples_per_class > min(len(indices) for indices in targets.values()): 149 | raise ValueError( 150 | "'num_examples_per_class' is greater than the size of the smallest group in the target column." 
151 | ) 152 | 153 | # Stratified sampling 154 | sample_indices = [] 155 | for indices in targets.values(): 156 | sample_indices.extend(random.sample(indices, num_examples_per_class)) 157 | 158 | # Create new dataset from the sample 159 | sample_dataset = dataset.select(sample_indices) 160 | sample_dataset = _alternate_classes(sample_dataset, label_column) 161 | 162 | if return_unused_split: 163 | unused_indices = list(set(range(len(dataset))) - set(sample_indices)) 164 | return sample_dataset, dataset.select(unused_indices) 165 | 166 | return sample_dataset 167 | 168 | 169 | def ml_mc_sampler(dataset: Dataset, labels_column: str, num_examples: int) -> Dataset: 170 | """Multi label multi class sampler 171 | 172 | Args: 173 | dataset: Dataset 174 | label_column: Name of the label column 175 | num_examples: Number of examples to sample, if -1 sample as long as subset does not contain every label 176 | 177 | """ 178 | 179 | if num_examples > len(dataset): 180 | return dataset 181 | 182 | if "train" in dataset: 183 | dataset = dataset["train"] 184 | 185 | total_labels = _infer_class_labels(dataset, labels_column) 186 | num_classes = len(total_labels) 187 | 188 | # Because of random sampling we do not ensure, that we ever sampled all examples 189 | # Nor do we know if all labels are present. We therefore use a max try counter 190 | # So we dont get stuck in infinite while loop 191 | if num_examples == -1: 192 | max_tries = 2 * len(dataset) 193 | else: 194 | max_tries = -1 195 | 196 | tries = 0 197 | 198 | pbar = tqdm(total=num_examples, desc="Sampling") 199 | 200 | unique_classes_sampled: Set[str] = set() 201 | total_examples_sampled = 0 202 | 203 | sampled_indices = [] 204 | 205 | while len(unique_classes_sampled) < len(total_labels): 206 | # Lets try to be as random and possible and sample from the entire dataset 207 | idx = random.sample(range(len(dataset)), 1)[0] 208 | 209 | # Pass already sampled idx 210 | if idx in sampled_indices: 211 | continue 212 | 213 | sample = dataset.select([idx])[0] 214 | 215 | labels = sample[labels_column] 216 | 217 | if not isinstance(labels, list): 218 | labels = [labels] 219 | 220 | labels_found = [label for label in labels if label not in unique_classes_sampled] 221 | 222 | # Check if current sample contains labels not found yet 223 | complements = _relative_complements(labels_found, unique_classes_sampled) 224 | 225 | if len(complements) > 0: 226 | unique_classes_sampled.update(complements) 227 | sampled_indices.append(idx) 228 | total_examples_sampled += 1 229 | pbar.update(1) 230 | 231 | # Further sample if we collected at least one example per label 232 | elif len(sampled_indices) < num_examples and len(unique_classes_sampled) == num_classes: 233 | sampled_indices.append(idx) 234 | total_examples_sampled += 1 235 | pbar.update(1) 236 | 237 | if num_examples != -1 and total_examples_sampled == num_examples: 238 | break 239 | 240 | tries += 1 241 | if tries == max_tries: 242 | logger.info("Stopping sample. Max tries(={}) exceeded.", max_tries) 243 | break 244 | 245 | return dataset.select(sampled_indices) 246 | 247 | 248 | def _infer_class_labels(dataset: Dataset, label_column: str) -> Dict[int, str]: 249 | """Infer the total set of labels""" 250 | features = dataset.features 251 | 252 | if label_column not in features: 253 | raise ValueError(f"Label column {label_column} not found in dataset") 254 | 255 | if isinstance(features[label_column], Value): 256 | logger.info("Label column {} is of type Value. 
Inferring labels from dataset", label_column) 257 | class_labels = dataset.class_encode_column(label_column).features[label_column].names 258 | elif isinstance(features[label_column], ClassLabel): 259 | class_labels = features[label_column].names 260 | elif isinstance(features[label_column], Sequence): 261 | class_labels = features[label_column].feature.names 262 | else: 263 | raise ValueError( 264 | f"Label column {label_column} is of type {type(features[label_column])}. Expected Value, " 265 | f"ClassLabel or Sequence" 266 | ) 267 | 268 | return dict(enumerate(class_labels)) 269 | 270 | 271 | def _relative_complements(list1: List, list2: Union[List, Set]) -> Set: 272 | """a \\ b""" 273 | return set(list1) - set(list2) 274 | -------------------------------------------------------------------------------- /src/fabricator/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | 5 | def log_dir(): 6 | """Returns the log directory. 7 | 8 | Note: 9 | Keep it simple for now 10 | """ 11 | return os.environ.get("LOG_DIR", "./logs") 12 | 13 | 14 | def create_timestamp_path(directory: str): 15 | """Returns a timestamped path for logging.""" 16 | return os.path.join(directory, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) 17 | 18 | 19 | def save_create_directory(path: str): 20 | """Creates a directory if it does not exist.""" 21 | if not os.path.exists(path): 22 | os.makedirs(path) 23 | -------------------------------------------------------------------------------- /tests/test_dataset_generator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from datasets import Dataset, load_dataset 4 | 5 | from fabricator import DatasetGenerator 6 | from fabricator.prompts import BasePrompt 7 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts 8 | 9 | 10 | class TestDatasetGenerator(unittest.TestCase): 11 | """Testcase for Prompts""" 12 | 13 | def setUp(self) -> None: 14 | """Set up test dataset""" 15 | self.text_classification_dataset = Dataset.from_dict({ 16 | "text": ["This movie is great!", "This movie is bad!"], 17 | "label": ["positive", "negative"] 18 | }) 19 | 20 | # We are using dummy respones here, because we are not testing the LLM itself. 21 | self.generator = DatasetGenerator(None) 22 | 23 | def test_simple_generation(self): 24 | """Test simple generation without fewshot examples.""" 25 | prompt = BasePrompt( 26 | task_description="Generate a short movie review.", 27 | ) 28 | 29 | generated_dataset = self.generator.generate( 30 | prompt_template=prompt, 31 | max_prompt_calls=2, 32 | dummy_response="A dummy movie review." 33 | ) 34 | 35 | self.assertEqual(len(generated_dataset), 2) 36 | self.assertEqual(generated_dataset.features["text"].dtype, "string") 37 | self.assertIn("text", generated_dataset.features) 38 | 39 | def test_simple_generation_with_label_options(self): 40 | """Test simple generation without fewshot examples with label options.""" 41 | prompt = BasePrompt( 42 | task_description="Generate a short {} movie review.", 43 | label_options=["positive", "negative"], 44 | ) 45 | 46 | generated_dataset = self.generator.generate( 47 | prompt_template=prompt, 48 | max_prompt_calls=2, 49 | dummy_response="A dummy movie review." 
50 | ) 51 | 52 | self.assertEqual(len(generated_dataset), 2) 53 | self.assertEqual(generated_dataset.features["text"].dtype, "string") 54 | self.assertIn("text", generated_dataset.features) 55 | 56 | def test_generation_with_fewshot_examples(self): 57 | label_options = ["positive", "negative"] 58 | 59 | prompt = BasePrompt( 60 | task_description="Generate a {} movie review.", 61 | label_options=label_options, 62 | generate_data_for_column="text", 63 | ) 64 | 65 | generated_dataset = self.generator.generate( 66 | prompt_template=prompt, 67 | fewshot_dataset=self.text_classification_dataset, 68 | fewshot_examples_per_class=1, 69 | fewshot_sampling_strategy="uniform", 70 | fewshot_sampling_column="label", 71 | max_prompt_calls=2, 72 | dummy_response="A dummy movie review." 73 | ) 74 | 75 | self.assertEqual(len(generated_dataset), 2) 76 | self.assertEqual(generated_dataset.features["text"].dtype, "string") 77 | self.assertIn("text", generated_dataset.features) 78 | 79 | def test_annotation_with_fewshot_and_unlabeled_examples(self): 80 | label_options = ["positive", "negative"] 81 | 82 | unlabeled_dataset = Dataset.from_dict({ 83 | "text": ["This movie was a blast!", "This movie was not bad!"], 84 | }) 85 | 86 | prompt = BasePrompt( 87 | task_description="Annotate movie reviews as either: {}.", 88 | label_options=label_options, 89 | generate_data_for_column="label", 90 | fewshot_example_columns="text", 91 | ) 92 | 93 | generated_dataset = self.generator.generate( 94 | prompt_template=prompt, 95 | fewshot_dataset=self.text_classification_dataset, 96 | fewshot_examples_per_class=1, 97 | fewshot_sampling_strategy="stratified", 98 | unlabeled_dataset=unlabeled_dataset, 99 | max_prompt_calls=2, 100 | dummy_response="A dummy movie review." 101 | ) 102 | 103 | self.assertEqual(len(generated_dataset), 2) 104 | self.assertEqual(generated_dataset.features["text"].dtype, "string") 105 | self.assertEqual(generated_dataset.features["label"].dtype, "string") 106 | self.assertIn("text", generated_dataset.features) 107 | self.assertIn("label", generated_dataset.features) 108 | self.assertEqual(generated_dataset[0]["text"], "This movie was a blast!") 109 | self.assertEqual(generated_dataset[1]["text"], "This movie was not bad!") 110 | 111 | def test_tranlation(self): 112 | fewshot_dataset = Dataset.from_dict({ 113 | "german": ["Der Film ist großartig!", "Der Film ist schlecht!"], 114 | "english": ["This movie is great!", "This movie is bad!"], 115 | }) 116 | 117 | unlabeled_dataset = Dataset.from_dict({ 118 | "english": ["This movie was a blast!", "This movie was not bad!"], 119 | }) 120 | 121 | prompt = BasePrompt( 122 | task_description="Translate to german:", # Since we do not have a label column, 123 | # we can just specify the task description 124 | generate_data_for_column="german", 125 | fewshot_example_columns="english", 126 | ) 127 | 128 | generated_dataset = self.generator.generate( 129 | prompt_template=prompt, 130 | fewshot_dataset=fewshot_dataset, 131 | fewshot_examples_per_class=2, # Take both fewshot examples per prompt 132 | fewshot_sampling_strategy=None, 133 | # Since we do not have a class label column, we can just set this to None 134 | # (default) 135 | unlabeled_dataset=unlabeled_dataset, 136 | max_prompt_calls=2, 137 | dummy_response="A dummy movie review." 
138 | ) 139 | 140 | self.assertEqual(len(generated_dataset), 2) 141 | self.assertEqual(generated_dataset.features["english"].dtype, "string") 142 | self.assertEqual(generated_dataset.features["german"].dtype, "string") 143 | self.assertIn("english", generated_dataset.features) 144 | self.assertIn("german", generated_dataset.features) 145 | 146 | def test_textual_similarity(self): 147 | dataset = load_dataset("glue", "mrpc", split="train") 148 | dataset, label_options = convert_label_ids_to_texts(dataset, "label", return_label_options=True) # convert the 149 | # label ids to text labels and return the label options 150 | 151 | fewshot_dataset = dataset.select(range(10)) 152 | unlabeled_dataset = dataset.select(range(10, 20)) 153 | 154 | prompt = BasePrompt( 155 | task_description="Annotate the sentence pair whether it is: {}", 156 | label_options=label_options, 157 | generate_data_for_column="label", 158 | fewshot_example_columns=["sentence1", "sentence2"], 159 | ) 160 | 161 | generated_dataset, original_dataset = self.generator.generate( 162 | prompt_template=prompt, 163 | fewshot_dataset=fewshot_dataset, 164 | fewshot_examples_per_class=1, 165 | fewshot_sampling_column="label", 166 | fewshot_sampling_strategy="stratified", 167 | unlabeled_dataset=unlabeled_dataset, 168 | max_prompt_calls=2, 169 | return_unlabeled_dataset=True, 170 | dummy_response="A dummy movie review." 171 | ) 172 | 173 | self.assertEqual(len(generated_dataset), 2) 174 | self.assertEqual(generated_dataset.features["sentence1"].dtype, "string") 175 | self.assertEqual(generated_dataset.features["sentence2"].dtype, "string") 176 | self.assertEqual(generated_dataset.features["label"].dtype, "string") 177 | 178 | def test_sampling_all_fewshot_examples(self): 179 | """Test sampling all fewshot examples""" 180 | prompt = BasePrompt( 181 | task_description="Generate a short movie review.", 182 | ) 183 | 184 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples( 185 | prompt_template=prompt, 186 | fewshot_dataset=self.text_classification_dataset, 187 | fewshot_sampling_strategy=None, 188 | fewshot_examples_per_class=None, 189 | fewshot_sampling_column=None, 190 | ) 191 | 192 | self.assertEqual(prompt_labels, None) 193 | self.assertEqual(len(fewshot_examples), 2) 194 | 195 | def test_sampling_all_fewshot_examples_with_label_options(self): 196 | """Test sampling all fewshot examples with label options""" 197 | prompt = BasePrompt( 198 | task_description="Generate a short movie review: {}.", 199 | label_options=["positive", "negative"], 200 | ) 201 | 202 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples( 203 | prompt_template=prompt, 204 | fewshot_dataset=self.text_classification_dataset, 205 | fewshot_sampling_strategy=None, 206 | fewshot_examples_per_class=None, 207 | fewshot_sampling_column=None, 208 | ) 209 | 210 | prompt_text = prompt.task_description.format(prompt_labels) 211 | self.assertIn("positive", prompt_text) 212 | self.assertIn("negative", prompt_text) 213 | 214 | def test_sampling_uniform_fewshot_examples(self): 215 | """Test uniform sampling fewshot examples""" 216 | prompt = BasePrompt( 217 | task_description="Generate a short movie review: {}.", 218 | label_options=["positive", "negative"], 219 | ) 220 | 221 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples( 222 | prompt_template=prompt, 223 | fewshot_dataset=self.text_classification_dataset, 224 | fewshot_sampling_strategy="uniform", 225 | fewshot_examples_per_class=1, 226 | 
fewshot_sampling_column="label", 227 | ) 228 | 229 | self.assertEqual(len(fewshot_examples), 1) 230 | self.assertIn(prompt_labels, ["positive", "negative"]) 231 | 232 | def test_sampling_stratified_fewshot_examples(self): 233 | """Test stratified sampling fewshot examples""" 234 | larger_fewshot_dataset = Dataset.from_dict({ 235 | "text": ["This movie is great!", "This movie is bad!", "This movie is great!", "This movie is bad!"], 236 | "label": ["positive", "negative", "positive", "negative"] 237 | }) 238 | 239 | prompt = BasePrompt( 240 | task_description="Generate a short movie review: {}.", 241 | label_options=["positive", "negative"], 242 | ) 243 | 244 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples( 245 | prompt_template=prompt, 246 | fewshot_dataset=larger_fewshot_dataset, 247 | fewshot_sampling_strategy="stratified", 248 | fewshot_examples_per_class=1, 249 | fewshot_sampling_column="label", 250 | ) 251 | 252 | self.assertEqual(len(fewshot_examples), 2) 253 | self.assertEqual(len(set(fewshot_examples["label"])), 2) 254 | self.assertEqual(len(prompt_labels), 2) 255 | 256 | def test_sampling_uniform_fewshot_examples_without_number_of_examples(self): 257 | """Test failure of uniform sampling fewshot examples if attributes are missing""" 258 | prompt = BasePrompt( 259 | task_description="Generate a short movie review: {}.", 260 | label_options=["positive", "negative"], 261 | ) 262 | 263 | with self.assertRaises(KeyError): 264 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples( 265 | prompt_template=prompt, 266 | fewshot_dataset=self.text_classification_dataset, 267 | fewshot_sampling_strategy="uniform", 268 | fewshot_examples_per_class=None, 269 | fewshot_sampling_column=None, 270 | ) 271 | 272 | def test_sampling_uniform_fewshot_examples_with_number_of_examples_without_sampling_column(self): 273 | """Test failure of uniform sampling fewshot examples if attributes are missing""" 274 | prompt = BasePrompt( 275 | task_description="Generate a short movie review: {}.", 276 | label_options=["positive", "negative"], 277 | ) 278 | 279 | with self.assertRaises(KeyError): 280 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples( 281 | prompt_template=prompt, 282 | fewshot_dataset=self.text_classification_dataset, 283 | fewshot_sampling_strategy="uniform", 284 | fewshot_examples_per_class=1, 285 | fewshot_sampling_column=None, 286 | ) 287 | 288 | def test_dummy_response(self): 289 | 290 | prompt = BasePrompt( 291 | task_description="Generate a short movie review.", 292 | ) 293 | generated_dataset = self.generator.generate( 294 | prompt_template=prompt, 295 | max_prompt_calls=2, 296 | dummy_response=lambda _: "This is a dummy movie review." 297 | ) 298 | 299 | self.assertEqual(len(generated_dataset), 2) 300 | self.assertEqual(generated_dataset.features["text"].dtype, "string") 301 | self.assertIn("text", generated_dataset.features) 302 | self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review.") 303 | self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review.") 304 | 305 | 306 | generated_dataset = self.generator.generate( 307 | prompt_template=prompt, 308 | max_prompt_calls=2, 309 | dummy_response="This is a dummy movie review as a string." 
310 | ) 311 | 312 | self.assertEqual(len(generated_dataset), 2) 313 | self.assertEqual(generated_dataset.features["text"].dtype, "string") 314 | self.assertIn("text", generated_dataset.features) 315 | self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review as a string.") 316 | self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review as a string.") 317 | -------------------------------------------------------------------------------- /tests/test_dataset_sampler.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from collections import Counter 4 | from datasets import load_dataset 5 | 6 | from fabricator.samplers import random_sampler, single_label_task_sampler, ml_mc_sampler, \ 7 | single_label_stratified_sample 8 | 9 | 10 | def _flatten(l): 11 | return [item for sublist in l for item in sublist] 12 | 13 | 14 | class TestDatasetSamplerMethodsSingleLabel(unittest.TestCase): 15 | """Testcase for dataset sampler methods""" 16 | 17 | def setUp(self) -> None: 18 | self.dataset = load_dataset("imdb", split="train") 19 | 20 | def test_random_sampler(self): 21 | """Test random sampler""" 22 | random_sample = random_sampler(self.dataset, num_examples=10) 23 | self.assertEqual(len(random_sample), 10) 24 | 25 | def test_single_label_task_sampler(self): 26 | """Test single label task sampler. We use imdb which has two labels: positive and negative""" 27 | single_label_sample = single_label_task_sampler(self.dataset, label_column="label", num_examples=2) 28 | self.assertEqual(len(single_label_sample), 2) 29 | labels = list(single_label_sample["label"]) 30 | self.assertEqual(len(set(labels)), 2) 31 | 32 | def test_single_label_task_sampler_more_examples_than_ds(self): 33 | """Test single label task sampler with more examples than dataset""" 34 | subset_dataset = self.dataset.select(range(100)) 35 | single_label_sample = single_label_task_sampler(subset_dataset, label_column="label", num_examples=110) 36 | self.assertEqual(len(single_label_sample), 100) 37 | 38 | 39 | class TestDatasetSamplerMethodsMultiLabel(unittest.TestCase): 40 | """Testcase for multilabel dataset sampler methods""" 41 | 42 | def setUp(self) -> None: 43 | """Load dataset""" 44 | self.dataset = load_dataset("conll2003", split="train") 45 | 46 | def test_ml_mc_sampler(self): 47 | """Test multilabel multiclass sampler""" 48 | subset_dataset = ml_mc_sampler(self.dataset, labels_column="pos_tags", num_examples=10) 49 | label_idxs = list(range(len(self.dataset.features["pos_tags"].feature.names))) 50 | self.assertEqual(len(subset_dataset), 10) 51 | 52 | tags = set(_flatten([sample["pos_tags"] for sample in subset_dataset])) 53 | 54 | # We do not guarantee, that all tags are contained in sampled examples 55 | self.assertLessEqual(len(tags), len(label_idxs)) 56 | 57 | 58 | class TestStratifiedSampler(unittest.TestCase): 59 | 60 | def setUp(self) -> None: 61 | """Load dataset""" 62 | self.dataset = load_dataset("trec", split="train") 63 | 64 | def test_stratified_sampler(self): 65 | """Test stratified sampler""" 66 | subset_dataset = single_label_stratified_sample(self.dataset, label_column="coarse_label", 67 | num_examples_per_class=2) 68 | label_idxs = list(range(len(self.dataset.features["coarse_label"].names))) 69 | self.assertEqual(len(subset_dataset), 2 * len(label_idxs)) 70 | 71 | for occurences in Counter(subset_dataset["coarse_label"]).values(): 72 | self.assertEqual(occurences, 2) 73 | 
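The sampler tests above exercise the public sampling helpers; a brief usage sketch (illustrative only, assuming the IMDb and TREC datasets from the Hugging Face hub):

```python
from datasets import load_dataset
from fabricator.samplers import (
    random_sampler,
    single_label_task_sampler,
    single_label_stratified_sample,
)

imdb = load_dataset("imdb", split="train")

# Draw 10 examples uniformly at random
subset = random_sampler(imdb, num_examples=10)

# Draw 10 examples while trying to cover every label at least once
subset = single_label_task_sampler(imdb, label_column="label", num_examples=10)

# Draw exactly 2 examples per class and keep the rest as a separate split
trec = load_dataset("trec", split="train")
fewshot_pool, remaining = single_label_stratified_sample(
    trec,
    label_column="coarse_label",
    num_examples_per_class=2,
    return_unused_split=True,
)
```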
-------------------------------------------------------------------------------- /tests/test_dataset_transformations.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from datasets import load_dataset 4 | 5 | from fabricator.prompts import BasePrompt 6 | from fabricator.dataset_transformations.question_answering import * 7 | from fabricator.dataset_transformations.text_classification import * 8 | from fabricator.dataset_transformations.token_classification import * 9 | 10 | 11 | class TestTransformationsTextClassification(unittest.TestCase): 12 | """Testcase for ClassLabelPrompt""" 13 | 14 | def setUp(self) -> None: 15 | self.dataset = load_dataset("trec", split="train") 16 | 17 | def test_label_ids_to_textual_label(self): 18 | """Test transformation output only""" 19 | dataset, label_options = convert_label_ids_to_texts(self.dataset, "coarse_label", return_label_options=True) 20 | self.assertEqual(len(label_options), 6) 21 | self.assertEqual(set(label_options), set(self.dataset.features["coarse_label"].names)) 22 | self.assertEqual(type(dataset[0]["coarse_label"]), str) 23 | self.assertNotEqual(type(dataset[0]["coarse_label"]), int) 24 | self.assertIn(dataset[0]["coarse_label"], label_options) 25 | 26 | def test_formatting_with_textual_labels(self): 27 | """Test formatting with textual labels""" 28 | dataset, label_options = convert_label_ids_to_texts(self.dataset, "coarse_label", return_label_options=True) 29 | fewshot_examples = dataset.select([1, 2, 3]) 30 | prompt = BasePrompt( 31 | task_description="Annotate the question into following categories: {}.", 32 | generate_data_for_column="coarse_label", 33 | fewshot_example_columns="text", 34 | label_options=label_options, 35 | ) 36 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples) 37 | self.assertIn("text: What films featured the character Popeye Doyle ?\ncoarse_label: ENTY", raw_prompt) 38 | for label in label_options: 39 | self.assertIn(label, raw_prompt) 40 | 41 | def test_expanded_textual_labels(self): 42 | """Test formatting with expanded textual labels""" 43 | extended_mapping = { 44 | "DESC": "Description", 45 | "ENTY": "Entity", 46 | "ABBR": "Abbreviation", 47 | "HUM": "Human", 48 | "NUM": "Number", 49 | "LOC": "Location", 50 | } 51 | dataset, label_options = convert_label_ids_to_texts( 52 | self.dataset, "coarse_label", expanded_label_mapping=extended_mapping, return_label_options=True 53 | ) 54 | self.assertIn("Location", label_options) 55 | self.assertNotIn("LOC", label_options) 56 | fewshot_examples = dataset.select([1, 2, 3]) 57 | prompt = BasePrompt( 58 | task_description="Annotate the question into following categories: {}.", 59 | generate_data_for_column="coarse_label", 60 | fewshot_example_columns="text", 61 | label_options=label_options, 62 | ) 63 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples) 64 | self.assertIn("text: What films featured the character Popeye Doyle ?\ncoarse_label: Entity", raw_prompt) 65 | self.assertNotIn("ENTY", raw_prompt) 66 | for label in label_options: 67 | self.assertIn(label, raw_prompt) 68 | 69 | def test_textual_labels_to_label_ids(self): 70 | """Test conversion back to label ids""" 71 | dataset, label_options = convert_label_ids_to_texts(self.dataset, "coarse_label", return_label_options=True) 72 | self.assertIn(dataset[0]["coarse_label"], label_options) 73 | dataset = dataset.class_encode_column("coarse_label") 74 | self.assertIn(dataset[0]["coarse_label"], range(len(label_options))) 
75 | 76 | def test_false_inputs_raises_error(self): 77 | """Test that false inputs raise errors""" 78 | with self.assertRaises(AttributeError): 79 | dataset, label_options = convert_label_ids_to_texts("coarse_label", "dataset") 80 | 81 | 82 | class TestTransformationsTokenClassification(unittest.TestCase): 83 | """Testcase for TokenLabelTransformations""" 84 | 85 | def setUp(self) -> None: 86 | self.dataset = load_dataset("conll2003", split="train").select(range(150)) 87 | 88 | def test_bio_tokens_to_spans(self): 89 | """Test transformation output only (BIO to spans)""" 90 | dataset, label_options = convert_token_labels_to_spans( 91 | self.dataset, "tokens", "ner_tags", return_label_options=True 92 | ) 93 | self.assertEqual(len(label_options), 4) 94 | self.assertEqual(type(dataset[0]["ner_tags"]), str) 95 | self.assertNotEqual(type(dataset[0]["ner_tags"]), int) 96 | spans = [span for span in dataset[0]["ner_tags"].split("\n")] 97 | for span in spans: 98 | self.assertTrue(any([label in span for label in label_options])) 99 | 100 | def test_formatting_with_span_labels(self): 101 | """Test formatting with span labels""" 102 | dataset, label_options = convert_token_labels_to_spans( 103 | dataset=self.dataset, 104 | token_column="tokens", 105 | label_column="ner_tags", 106 | return_label_options=True 107 | ) 108 | fewshot_examples = dataset.select([1, 2, 3]) 109 | prompt = BasePrompt( 110 | task_description="Annotate each of the following tokens with the following labels: {}.", 111 | generate_data_for_column="ner_tags", 112 | fewshot_example_columns="tokens", 113 | label_options=label_options, 114 | ) 115 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples) 116 | self.assertIn("Peter Blackburn is PER entity.", raw_prompt) 117 | self.assertIn("BRUSSELS is LOC entity.", raw_prompt) 118 | for label in label_options: 119 | self.assertIn(label, raw_prompt) 120 | 121 | def test_expanded_textual_labels(self): 122 | """Test formatting with expanded textual labels""" 123 | expanded_label_mapping = { 124 | 0: "O", 125 | 1: "B-person", 126 | 2: "I-person", 127 | 3: "B-location", 128 | 4: "I-location", 129 | 5: "B-organization", 130 | 6: "I-organization", 131 | 7: "B-miscellaneous", 132 | 8: "I-miscellaneous", 133 | } 134 | 135 | dataset, label_options = convert_token_labels_to_spans( 136 | dataset=self.dataset, 137 | token_column="tokens", 138 | label_column="ner_tags", 139 | expanded_label_mapping=expanded_label_mapping, 140 | return_label_options=True 141 | ) 142 | fewshot_examples = dataset.select([1, 2, 3]) 143 | prompt = BasePrompt( 144 | task_description="Annotate each of the following tokens with the following labels: {}.", 145 | generate_data_for_column="ner_tags", 146 | fewshot_example_columns="tokens", 147 | label_options=label_options, 148 | ) 149 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples) 150 | self.assertIn("Peter Blackburn is person entity.", raw_prompt) 151 | self.assertNotIn("PER", raw_prompt) 152 | for label in label_options: 153 | self.assertIn(label, raw_prompt) 154 | 155 | def test_textual_labels_to_label_ids(self): 156 | """Test conversion back to label ids on token-level""" 157 | dataset, label_options = convert_token_labels_to_spans( 158 | dataset=self.dataset, 159 | token_column="tokens", 160 | label_column="ner_tags", 161 | return_label_options=True 162 | ) 163 | id2label = dict(enumerate(self.dataset.features["ner_tags"].feature.names)) 164 | self.assertEqual(dataset[0]["ner_tags"], "EU is ORG entity.\nGerman is MISC entity.\nBritish 
is MISC entity.") 165 | dataset = dataset.select(range(10)) 166 | dataset = convert_spans_to_token_labels( 167 | dataset=dataset, 168 | token_column="tokens", 169 | label_column="ner_tags", 170 | id2label=id2label 171 | ) 172 | for label in dataset[0]["ner_tags"]: 173 | self.assertIn(label, id2label.keys()) 174 | 175 | def test_false_inputs_raises_error(self): 176 | """Test that false inputs raise errors""" 177 | with self.assertRaises(AttributeError): 178 | dataset, label_options = convert_token_labels_to_spans( 179 | "ner_tags", "tokens", {1: "a", 2: "b", 3: "c"} 180 | ) 181 | 182 | with self.assertRaises(AttributeError): 183 | dataset, label_options = convert_token_labels_to_spans( 184 | {1: "a", 2: "b", 3: "c"}, "tokens", "ner_tags" 185 | ) 186 | 187 | 188 | class TestTransformationsQuestionAnswering(unittest.TestCase): 189 | """Testcase for QA Transformations""" 190 | 191 | def setUp(self) -> None: 192 | self.dataset = load_dataset("squad_v2", split="train") 193 | 194 | def test_squad_preprocessing(self): 195 | """Test transformation from squad fromat into flat structure""" 196 | self.assertEqual(type(self.dataset[0]["answers"]), dict) 197 | dataset = preprocess_squad_format(self.dataset.select(range(30))) 198 | self.assertEqual(type(dataset[0]["answers"]), str) 199 | with self.assertRaises(KeyError): 200 | x = dataset[0]["answer_start"] 201 | 202 | def test_squad_postprocessing(self): 203 | """Test transformation flat structure into squad format""" 204 | dataset = preprocess_squad_format(self.dataset.select(range(50))) 205 | dataset = postprocess_squad_format(dataset) 206 | self.assertEqual(type(dataset[0]["answers"]), dict) 207 | self.assertIn("answer_start", dataset[0]["answers"]) 208 | -------------------------------------------------------------------------------- /tests/test_prompts.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from datasets import load_dataset, Dataset, QuestionAnsweringExtractive, TextClassification, Summarization 4 | 5 | from fabricator.prompts import ( 6 | BasePrompt, 7 | infer_prompt_from_task_template, 8 | ) 9 | 10 | 11 | class TestPrompt(unittest.TestCase): 12 | """Testcase for Prompts""" 13 | 14 | def setUp(self) -> None: 15 | """Set up test dataset""" 16 | self.dataset = Dataset.from_dict({ 17 | "text": ["This movie is great!", "This movie is bad!"], 18 | "label": ["positive", "negative"] 19 | }) 20 | 21 | def test_plain_template(self): 22 | """Test plain prompt template""" 23 | prompt_template = BasePrompt(task_description="Generate movies reviews.") 24 | self.assertEqual(prompt_template.get_prompt_text(), 'Generate movies reviews.\n\ntext: ') 25 | self.assertEqual(prompt_template.target_formatting_template, "text: ") 26 | self.assertEqual(prompt_template.generate_data_for_column, None) 27 | self.assertEqual(prompt_template.fewshot_example_columns, None) 28 | 29 | with self.assertRaises(TypeError): 30 | prompt_template = BasePrompt() 31 | 32 | def test_template_with_label_options(self): 33 | """Test prompt template with label options""" 34 | label_options = ["positive", "negative"] 35 | prompt_template = BasePrompt( 36 | task_description="Generate a {} movie review.", 37 | label_options=label_options, 38 | ) 39 | self.assertIn("positive", prompt_template.get_prompt_text(label_options[0])) 40 | self.assertIn("negative", prompt_template.get_prompt_text(label_options[1])) 41 | self.assertEqual(prompt_template.target_formatting_template, "text: ") 42 | 43 | def 
test_initialization_only_target_column(self): 44 | """Test initialization with only target column""" 45 | prompt_template = BasePrompt( 46 | task_description="Generate similar movie reviews.", 47 | generate_data_for_column="text", 48 | ) 49 | self.assertEqual(prompt_template.relevant_columns_for_fewshot_examples, ["text"]) 50 | self.assertEqual(type(prompt_template.generate_data_for_column), list) 51 | self.assertEqual(len(prompt_template.generate_data_for_column), 1) 52 | self.assertEqual(prompt_template.fewshot_example_columns, None) 53 | 54 | prompt_text = 'Generate similar movie reviews.\n\ntext: This movie is great!\n\ntext: ' \ 55 | 'This movie is bad!\n\ntext: ' 56 | 57 | self.assertEqual(prompt_template.get_prompt_text(None, self.dataset), prompt_text) 58 | 59 | def test_initialization_target_and_fewshot_columns(self): 60 | """Test initialization with target and fewshot columns""" 61 | prompt_template = BasePrompt( 62 | task_description="Generate movie reviews.", 63 | generate_data_for_column="label", 64 | fewshot_example_columns="text" 65 | ) 66 | self.assertEqual(prompt_template.relevant_columns_for_fewshot_examples, ["text", "label"]) 67 | self.assertEqual(type(prompt_template.generate_data_for_column), list) 68 | self.assertEqual(len(prompt_template.generate_data_for_column), 1) 69 | self.assertEqual(type(prompt_template.fewshot_example_columns), list) 70 | 71 | prompt_text = 'Generate movie reviews.\n\ntext: This movie is great!\nlabel: positive\n\n' \ 72 | 'text: This movie is bad!\nlabel: negative\n\ntext: {text}\nlabel: ' 73 | 74 | self.assertEqual(prompt_template.get_prompt_text(None, self.dataset), prompt_text) 75 | 76 | def test_initialization_with_multiple_fewshot_columns(self): 77 | """Test initialization with multiple fewshot columns""" 78 | text_label_prompt = BasePrompt( 79 | task_description="Test two fewshot columns.", 80 | generate_data_for_column="label", 81 | fewshot_example_columns=["fewshot1", "fewshot2"], 82 | ) 83 | self.assertEqual(text_label_prompt.relevant_columns_for_fewshot_examples, ["fewshot1", "fewshot2", "label"]) 84 | self.assertEqual(type(text_label_prompt.fewshot_example_columns), list) 85 | self.assertEqual(len(text_label_prompt.fewshot_example_columns), 2) 86 | 87 | def test_custom_formatting_template(self): 88 | label_options = ["positive", "negative"] 89 | 90 | fewshot_examples = Dataset.from_dict({ 91 | "text": ["This movie is great!", "This movie is bad!"], 92 | "label": label_options 93 | }) 94 | 95 | prompt = BasePrompt( 96 | task_description="Annotate the sentiment of the following movie review whether it is: {}.", 97 | generate_data_for_column="label", 98 | fewshot_example_columns="text", 99 | fewshot_formatting_template="Movie Review: {text}\nSentiment: {label}", 100 | target_formatting_template="Movie Review: {text}\nSentiment: ", 101 | label_options=label_options, 102 | ) 103 | 104 | prompt_text = prompt.get_prompt_text(label_options, fewshot_examples) 105 | 106 | self.assertIn("whether it is: positive, negative.", prompt_text) 107 | self.assertIn("Movie Review: This movie is great!\nSentiment: positive", prompt_text) 108 | self.assertIn("Movie Review: This movie is bad!\nSentiment: negative", prompt_text) 109 | self.assertIn("Movie Review: {text}\nSentiment: ", prompt.target_formatting_template) 110 | 111 | 112 | class TestDownstreamTasks(unittest.TestCase): 113 | """Testcase for downstream tasks""" 114 | 115 | def setUp(self) -> None: 116 | """Set up test datasets""" 117 | 118 | def preprocess_qa(example): 119 | if 
example["answer"]: 120 | example["answer"] = example["answer"].pop() 121 | else: 122 | example["answer"] = "" 123 | return example 124 | 125 | self.text_classification = load_dataset("trec", split="train").select([1, 2, 3]) 126 | self.question_answering = load_dataset("squad", split="train").flatten()\ 127 | .rename_column("answers.text", "answer").map(preprocess_qa).select([1, 2, 3]) 128 | self.ner = load_dataset("conll2003", split="train").select([1, 2, 3]) 129 | self.translation = load_dataset("opus100", "de-nl", split="test").flatten()\ 130 | .rename_columns({"translation.de": "german", "translation.nl": "dutch"}).select([1, 2, 3]) 131 | 132 | def test_translation(self): 133 | prompt = BasePrompt( 134 | task_description="Given a german phrase, translate it into dutch.", 135 | generate_data_for_column="dutch", 136 | fewshot_example_columns="german", 137 | ) 138 | 139 | raw_prompt = prompt.get_prompt_text(None, self.translation) 140 | self.assertIn("Marktorganisation für Wein(1)", raw_prompt) 141 | self.assertIn("Gelet op Verordening (EG) nr. 1493/1999", raw_prompt) 142 | self.assertIn("na z'n partijtje golf", raw_prompt) 143 | self.assertIn("dutch: ", raw_prompt) 144 | 145 | def test_text_classification(self): 146 | label_options = self.text_classification.features["coarse_label"].names 147 | 148 | prompt = BasePrompt( 149 | task_description="Classify the question into one of the following categories: {}", 150 | label_options=label_options, 151 | generate_data_for_column="coarse_label", 152 | fewshot_example_columns="text", 153 | ) 154 | 155 | raw_prompt = prompt.get_prompt_text(label_options, self.text_classification) 156 | self.assertIn(", ".join(label_options), raw_prompt) 157 | self.assertIn("What fowl grabs the spotlight after the Chinese Year of the Monkey ?", raw_prompt) 158 | self.assertIn("How can I find a list of celebrities ' real names ?", raw_prompt) 159 | self.assertIn("What films featured the character Popeye Doyle ?", raw_prompt) 160 | self.assertIn("coarse_label: 2", raw_prompt) 161 | 162 | def test_named_entity_recognition(self): 163 | label_options = self.ner.features["ner_tags"].feature.names 164 | 165 | prompt = BasePrompt( 166 | task_description="Classify each token into one of the following categories: {}", 167 | generate_data_for_column="ner_tags", 168 | fewshot_example_columns="tokens", 169 | label_options=label_options, 170 | ) 171 | 172 | raw_prompt = prompt.get_prompt_text(label_options, self.ner) 173 | self.assertIn(", ".join(label_options), raw_prompt) 174 | self.assertIn("'BRUSSELS', '1996-08-22'", raw_prompt) 175 | self.assertIn("'Peter', 'Blackburn'", raw_prompt) 176 | self.assertIn("3, 4, 0, 0, 0, 0, 0, 0, 7, 0", raw_prompt) 177 | self.assertIn("ner_tags: [1, 2]", raw_prompt) 178 | 179 | def test_question_answering(self): 180 | prompt = BasePrompt( 181 | task_description="Given context and question, answer the question.", 182 | generate_data_for_column="answer", 183 | fewshot_example_columns=["context", "question"], 184 | ) 185 | 186 | raw_prompt = prompt.get_prompt_text(None, self.question_answering) 187 | self.assertIn("answer: the Main Building", raw_prompt) 188 | self.assertIn("context: Architecturally, the school", raw_prompt) 189 | self.assertIn("question: The Basilica", raw_prompt) 190 | self.assertIn("context: {context}", raw_prompt) 191 | 192 | 193 | class TestAutoInference(unittest.TestCase): 194 | """Testcase for AutoInference""" 195 | 196 | def test_auto_infer_text_label_prompt(self): 197 | """Test auto inference of 
QuestionAnsweringExtractive task template"""
198 |         task_template = QuestionAnsweringExtractive()
199 |         prompt = infer_prompt_from_task_template(task_template)
200 |         self.assertIsInstance(prompt, BasePrompt)
201 |         self.assertEqual(prompt.fewshot_example_columns, ["context", "question"])
202 |         self.assertEqual(prompt.generate_data_for_column, ["answers"])
203 | 
204 |     def test_auto_infer_class_label_prompt(self):
205 |         """Test auto inference of TextClassification task template"""
206 |         task_template = TextClassification()
207 |         task_template.label_schema["labels"].names = ["neg", "pos"]
208 |         prompt = infer_prompt_from_task_template(task_template)
209 |         self.assertIsInstance(prompt, BasePrompt)
210 |         self.assertEqual(prompt.fewshot_example_columns, ["text"])
211 |         self.assertEqual(prompt.generate_data_for_column, ["labels"])
212 | 
213 |     def test_auto_infer_fails_for_unsupported_task(self):
214 |         """Test auto inference of prompt fails for unsupported task template Summarization"""
215 |         with self.assertRaises(ValueError):
216 |             infer_prompt_from_task_template(Summarization())
217 | 
--------------------------------------------------------------------------------
/tutorials/TUTORIAL-1_OVERVIEW.md:
--------------------------------------------------------------------------------
1 | # Tutorial 1: Fabricator Introduction
2 | 
3 | ## 1) Dataset Generation
4 | 
5 | ### 1.1) Recipe for Dataset Generation 📚
6 | When starting from scratch, to generate an arbitrary dataset, you need an instance of each of the following (a minimal end-to-end sketch follows the workflow figure below):
7 | 
8 | - **_Datasets_**: For few-shot examples and final storage of (pair-wise) data to train a small PLM.
9 | - **_LLMs_**: To annotate existing, unlabeled datasets or generate completely new ones.
10 | - **_Prompts_**: To format the provided inputs (task description, few-shot examples, etc.) to prompt the LLM.
11 | - **_Orchestrator_**: To align all components and steer the generation process.
12 | 
13 | ### 1.2) Creating a Workflow From Scratch Requires Careful Consideration of Intricate Details 👨‍🍳
14 | The following figure illustrates the typical generation workflow when using large language models as teachers for
15 | smaller pre-trained language models (PLMs) like BERT. Establishing this workflow demands attention to implementation
16 | details and requires boilerplate code. Further, the setup process may vary based on a particular LLM or dataset format.
17 | 
18 | 
19 | ![Generation Workflow](resources/generation_workflow.png)
20 | *The generation workflow when using LLMs as a teacher to smaller PLMs such as BERT.*
21 | 
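To make this recipe concrete, the sketch below shows how the four components map onto fabricator's API; each piece is introduced step by step in the sections that follow. The model name, labels, and sampling settings are purely illustrative, and the snippet assumes an OpenAI API key is available in your environment.

```python
import os

from datasets import Dataset
from haystack.nodes import PromptNode
from fabricator import DatasetGenerator
from fabricator.prompts import BasePrompt

# Datasets: a handful of labeled few-shot examples
fewshot_dataset = Dataset.from_dict({
    "text": ["This movie is great!", "This movie is bad!"],
    "label": ["positive", "negative"],
})

# Prompts: how the task description and few-shot examples are rendered
prompt = BasePrompt(
    task_description="Generate a {} movie review.",
    label_options=["positive", "negative"],
    generate_data_for_column="text",
)

# LLMs: the teacher model, wrapped in haystack's PromptNode
prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# Orchestrator: aligns all components and steers the generation process
generator = DatasetGenerator(prompt_node)
generated_dataset = generator.generate(
    prompt_template=prompt,
    fewshot_dataset=fewshot_dataset,
    fewshot_examples_per_class=1,
    fewshot_sampling_strategy="uniform",
    fewshot_sampling_column="label",
    max_prompt_calls=10,
)
```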
22 | 
23 | ### 1.3) Efficiently Generate Datasets With Fabricator 🍜
24 | 
25 | With fabricator, you simply need to define your generation settings,
26 | e.g. how many few-shot examples to include per prompt, how to sample few-shot instances from a pool of available
27 | examples, or which LLM to use. In addition, everything is built on top of Hugging Face's
28 | [datasets](https://github.com/huggingface/datasets) library so that you can directly
29 | incorporate the generated datasets in your usual training workflows or share them via the Hugging Face hub.
30 | 
31 | ## 2) Fabricator Components
32 | 
33 | ### 2.1) Datasets
34 | 
35 | Datasets are built upon the `Dataset` class of the huggingface datasets library. They are used to store the data in a
36 | tabular format and provide a convenient way to access the data. The generated datasets will always be in that format such
37 | that they can be easily integrated with standard machine learning frameworks or shared with the research community via
38 | the huggingface hub.
39 | 
40 | ```python
41 | from datasets import load_dataset
42 | 
43 | # Load the imdb dataset from the huggingface hub, i.e. for annotation
44 | dataset = load_dataset("imdb")
45 | 
46 | # Load custom dataset from a jsonl file
47 | dataset = load_dataset("json", data_files="path/to/file.jsonl")
48 | 
49 | # Share generated dataset with the huggingface hub
50 | dataset.push_to_hub("my-dataset")
51 | ```
52 | 
53 | ### 2.2) LLMs
54 | We simply use haystack's `PromptNode` as our LLM interface. The PromptNode is a wrapper for multiple LLMs such as the ones
55 | from OpenAI or all available models on the huggingface hub. You can set all generation-related parameters such as
56 | temperature, top_k, maximum generation length via the PromptNode (see also the [documentation](https://docs.haystack.deepset.ai/docs/prompt_node)).
57 | 
58 | ```python
59 | import os
60 | from haystack.nodes import PromptNode
61 | 
62 | # Load a model from huggingface hub
63 | prompt_node = PromptNode("google/flan-t5-base")
64 | 
65 | # Create a PromptNode with the OpenAI API
66 | prompt_node = PromptNode(
67 |     model_name_or_path="text-davinci-003",
68 |     api_key=os.environ.get("OPENAI_API_KEY"),
69 | )
70 | ```
71 | 
72 | ### 2.3) Prompts
73 | 
74 | The `BasePrompt` class is used to format the prompts for the LLMs. This class is highly flexible and thus can be
75 | adapted to various settings:
76 | - define a `task_description` (e.g. "Generate a [_label_] movie review.") to generate data for a certain class, e.g. a movie review for the label "positive".
77 | - include pre-defined `label_options` (e.g. "Annotate the following review with one of the following labels: positive, negative.") when annotating unlabeled datasets.
78 | - customize the format of few-shot examples inside the prompt
79 | 
80 | #### Prompt for generating plain text
81 | 
82 | ```python
83 | from fabricator.prompts import BasePrompt
84 | 
85 | prompt_template = BasePrompt(task_description="Generate a movie review.")
86 | print(prompt_template.get_prompt_text())
87 | ```
88 | Prompt during generation:
89 | ```console
90 | Generate a movie review.
91 | 
92 | text: 
93 | ```
94 | 
95 | #### Prompt for label-conditioned generation with label options
96 | 
97 | ```python
98 | from fabricator.prompts import BasePrompt
99 | 
100 | label_options = ["positive", "negative"]
101 | prompt_template = BasePrompt(
102 |     task_description="Generate a {} movie review.",
103 |     label_options=label_options,
104 | )
105 | ```
106 | 
107 | Label-conditioned prompts during generation:
108 | ```console
109 | Generate a positive movie review.
110 | 
111 | text: 
112 | ---
113 | 
114 | Generate a negative movie review.
115 | 
116 | text: 
117 | ---
118 | ```
119 | 
120 | Note: In the orchestrator class, you can define the desired label distribution, e.g. uniformly
121 | sampling from both labels when constructing the prompts in each iteration.
122 | 
123 | #### Prompt with few-shot examples
124 | 
125 | ```python
126 | from datasets import Dataset
127 | from fabricator.prompts import BasePrompt
128 | 
129 | label_options = ["positive", "negative"]
130 | 
131 | fewshot_examples = Dataset.from_dict({
132 |     "text": ["This movie is great!", "This movie is bad!"],
133 |     "label": label_options
134 | })
135 | 
136 | prompt_template = BasePrompt(
137 |     task_description="Generate a {} movie review.",
138 |     label_options=label_options,
139 |     generate_data_for_column="text",
140 | )
141 | ```
142 | 
143 | Prompts with few-shot examples during generation:
144 | ```console
145 | Generate a positive movie review.
146 | 
147 | text: This movie is great!
148 | 
149 | text: 
150 | ---
151 | 
152 | Generate a negative movie review.
153 | 
154 | text: This movie is bad!
155 | 
156 | text: 
157 | ---
158 | ```
159 | 
160 | Note: The `generate_data_for_column` attribute defines the column of the few-shot dataset for which additional data is generated.
161 | As previously shown, the orchestrator will sample a label and include a matching few-shot example.
162 | 
163 | ### 2.4) DatasetGenerator
164 | 
165 | The `DatasetGenerator` class is fabricator's orchestrator. It takes a `Dataset`, a `PromptNode` and a
166 | `BasePrompt` as inputs and generates the final dataset based on these instances. The `generate` method returns a `Dataset` object that can be used with standard machine learning
167 | frameworks such as [flair](https://github.com/flairNLP/flair), deepset's [haystack](https://github.com/deepset-ai/haystack), or Hugging Face's [transformers](https://github.com/huggingface/transformers).
168 | 
169 | ```python
170 | from haystack.nodes import PromptNode
171 | from fabricator import BasePrompt, DatasetGenerator
172 | prompt = BasePrompt(task_description="Generate a movie review.")
173 | 
174 | prompt_node = PromptNode("google/flan-t5-base")
175 | 
176 | generator = DatasetGenerator(prompt_node)
177 | generated_dataset = generator.generate(
178 |     prompt_template=prompt,
179 |     max_prompt_calls=10,
180 | )
181 | ```
182 | 
183 | In the following [tutorial](TUTORIAL-2_GENERATION_WORKFLOWS.md), we introduce the different generation processes covered by fabricator.
184 | 
--------------------------------------------------------------------------------
/tutorials/TUTORIAL-2_GENERATION_WORKFLOWS.md:
--------------------------------------------------------------------------------
1 | # Tutorial 2: Generation Workflows
2 | 
3 | In this tutorial, you will learn:
4 | 1. how to generate datasets
5 | 2. how to annotate unlabeled datasets
6 | 3.
how to configure hyperparameters for your generation process
7 | 
8 | ## 1) Generating Datasets
9 | 
10 | ### 1.1) Generating Plain Text
11 | 
12 | In this example, we demonstrate how to combine the fabricator components introduced in the previous tutorial to create a
13 | movie review dataset. We don't explicitly direct the large language model (LLM) to generate movie reviews for
14 | specific labels (such as binary sentiment) or offer a few examples to guide the LLM in generating similar content.
15 | Instead, all it requires is a task description. The LLM then produces a dataset containing movie reviews based on the
16 | provided instructions. This dataset can be easily uploaded to the Hugging Face Hub.
17 | 
18 | ```python
19 | import os
20 | from haystack.nodes import PromptNode
21 | from fabricator import DatasetGenerator
22 | from fabricator.prompts import BasePrompt
23 | 
24 | prompt = BasePrompt(
25 |     task_description="Generate a very very short movie review.",
26 | )
27 | 
28 | prompt_node = PromptNode(
29 |     model_name_or_path="gpt-3.5-turbo",
30 |     api_key=os.environ.get("OPENAI_API_KEY"),
31 |     max_length=100,
32 | )
33 | 
34 | generator = DatasetGenerator(prompt_node)
35 | generated_dataset = generator.generate(
36 |     prompt_template=prompt,
37 |     max_prompt_calls=10,
38 | )
39 | 
40 | generated_dataset.push_to_hub("your-first-generated-dataset")
41 | ```
42 | 
43 | ### 1.2) Generate Label-Conditioned Datasets With Label Options and Few-Shot Examples
44 | To create datasets that are conditioned on specific labels and use few-shot examples,
45 | we need a few-shot dataset that is already annotated. The prompt should have the same labels as those in the few-shot
46 | dataset. Additionally, as explained in a previous tutorial, we must set the `generate_data_for_column`
47 | parameter to specify the column in the dataset for which we want to generate text.
48 | 
49 | In the dataset generator, we define certain hyperparameters for the generation process. `fewshot_examples_per_class`
50 | determines how many few-shot examples are incorporated for each class per prompt. `fewshot_sampling_strategy`
51 | can be set to either "uniform" if each label has an equal chance of being sampled,
52 | or "stratified" if the distribution from the few-shot dataset needs to be preserved.
53 | `fewshot_sampling_column` specifies the dataset column representing the classes. `max_prompt_calls`
54 | sets an upper limit on how many prompt calls are made.
55 | 
56 | Crucially, the prompt instance contains all the details about how a single prompt for a specific data point should
57 | be structured. This includes information like which few-shot examples should appear alongside which task instruction.
58 | On the other hand, the dataset generator defines the overall generation process,
59 | such as determining the label distribution, specified by the `fewshot_sampling_column`.
60 | 61 | ```python 62 | import os 63 | from datasets import Dataset 64 | from haystack.nodes import PromptNode 65 | from fabricator import DatasetGenerator 66 | from fabricator.prompts import BasePrompt 67 | 68 | label_options = ["positive", "negative"] 69 | 70 | fewshot_dataset = Dataset.from_dict({ 71 | "text": ["This movie is great!", "This movie is bad!"], 72 | "label": label_options 73 | }) 74 | 75 | prompt = BasePrompt( 76 | task_description="Generate a {} movie review.", 77 | label_options=label_options, 78 | generate_data_for_column="text", 79 | ) 80 | 81 | prompt_node = PromptNode( 82 | model_name_or_path="gpt-3.5-turbo", 83 | api_key=os.environ.get("OPENAI_API_KEY"), 84 | max_length=100, 85 | ) 86 | 87 | generator = DatasetGenerator(prompt_node) 88 | generated_dataset = generator.generate( 89 | prompt_template=prompt, 90 | fewshot_dataset=fewshot_dataset, 91 | fewshot_examples_per_class=1, 92 | fewshot_sampling_strategy="uniform", 93 | fewshot_sampling_column="label", 94 | max_prompt_calls=10, 95 | ) 96 | 97 | generated_dataset.push_to_hub("your-first-generated-dataset") 98 | ``` 99 | 100 | ## 2) Annotate unlabeled data with fewshot examples 101 | 102 | This example demonstrates how to add annotations to unlabeled data using few-shot examples. We have a few-shot dataset containing two columns: `text` and `label`, and an unlabeled dataset with only a `text` column. The goal is to annotate the unlabeled dataset using information from the few-shot dataset. 103 | 104 | To achieve this, we utilize the `DatasetGenerator.generate()` method. To begin, we provide the `unlabeled_dataset` argument, indicating the dataset we want to annotate. Additionally, we specify the `fewshot_examples_per_class` argument, determining how many few-shot examples to use for each class. In this scenario, we choose one example per class. 105 | 106 | The `fewshot_sampling_strategy` argument dictates how the few-shot dataset is sampled. In this case, we employ a stratified sampling strategy. This means that the generator will select precisely one example from each class within the few-shot dataset. 107 | 108 | It's worth noting that there's no need to explicitly specify the `fewshot_sampling_column` argument. By default, the generator uses the column specified in `generate_data_for_column` for this purpose. 
109 | 110 | ```python 111 | import os 112 | from datasets import Dataset 113 | from haystack.nodes import PromptNode 114 | from fabricator import DatasetGenerator 115 | from fabricator.prompts import BasePrompt 116 | 117 | label_options = ["positive", "negative"] 118 | 119 | fewshot_dataset = Dataset.from_dict({ 120 | "text": ["This movie is great!", "This movie is bad!"], 121 | "label": label_options 122 | }) 123 | 124 | unlabeled_dataset = Dataset.from_dict({ 125 | "text": ["This movie was a blast!", "This movie was not bad!"], 126 | }) 127 | 128 | prompt = BasePrompt( 129 | task_description="Annotate movie reviews as either: {}.", 130 | label_options=label_options, 131 | generate_data_for_column="label", 132 | fewshot_example_columns="text", 133 | ) 134 | 135 | prompt_node = PromptNode( 136 | model_name_or_path="gpt-3.5-turbo", 137 | api_key=os.environ.get("OPENAI_API_KEY"), 138 | max_length=100, 139 | ) 140 | 141 | generator = DatasetGenerator(prompt_node) 142 | generated_dataset = generator.generate( 143 | prompt_template=prompt, 144 | fewshot_dataset=fewshot_dataset, 145 | fewshot_examples_per_class=1, 146 | fewshot_sampling_strategy="stratified", 147 | unlabeled_dataset=unlabeled_dataset, 148 | max_prompt_calls=10, 149 | ) 150 | 151 | generated_dataset.push_to_hub("your-first-generated-dataset") 152 | ``` 153 | -------------------------------------------------------------------------------- /tutorials/TUTORIAL-3_ADVANCED-GENERATION.md: -------------------------------------------------------------------------------- 1 | # Tutorial 3: Advanced Dataset Generation 2 | 3 | ## Customizing Prompts 4 | 5 | Sometimes, you want to customize your prompt to your specific needs. For example, you might want to add a custom 6 | formatting template (the default takes the column names of the dataset): 7 | 8 | ```python 9 | from datasets import Dataset 10 | from fabricator.prompts import BasePrompt 11 | 12 | label_options = ["positive", "negative"] 13 | 14 | fewshot_examples = Dataset.from_dict({ 15 | "text": ["This movie is great!", "This movie is bad!"], 16 | "label": label_options 17 | }) 18 | 19 | prompt = BasePrompt( 20 | task_description="Annotate the sentiment of the following movie review whether it is: {}.", 21 | generate_data_for_column="label", 22 | fewshot_example_columns="text", 23 | fewshot_formatting_template="Movie Review: {text}\nSentiment: {label}", 24 | target_formatting_template="Movie Review: {text}\nSentiment: ", 25 | label_options=label_options, 26 | ) 27 | 28 | print(prompt.get_prompt_text(label_options, fewshot_examples)) 29 | ``` 30 | 31 | which yields: 32 | 33 | ```text 34 | Annotate the sentiment of the following movie review whether it is: positive, negative. 35 | 36 | Movie Review: This movie is great! 37 | Sentiment: positive 38 | 39 | Movie Review: This movie is bad! 40 | Sentiment: negative 41 | 42 | Movie Review: {text} 43 | Sentiment: 44 | ``` 45 | 46 | ## Inferring the Prompt from Dataset Info 47 | 48 | Huggingface Dataset objects provide the possibility to infer a prompt from the dataset. This can be achieved by using 49 | the `infer_prompt_from_dataset` function. This function takes a dataset 50 | as input and returns a `BasePrompt` object. The `BasePrompt` object contains the task description, the label options 51 | and the respective columns which can be used to generate a dataset with the `DatasetGenerator` class. 
52 | 
53 | ```python
54 | from datasets import load_dataset
55 | from fabricator.prompts import infer_prompt_from_dataset
56 | 
57 | dataset = load_dataset("imdb", split="train")
58 | prompt = infer_prompt_from_dataset(dataset)
59 | 
60 | print(prompt.get_prompt_text() + "\n---")
61 | 
62 | label_options = dataset.features["label"].names
63 | fewshot_example = dataset.shuffle(seed=42).select([0])
64 | 
65 | print(prompt.get_prompt_text(label_options, fewshot_example))
66 | ```
67 | 
68 | The output of this script is:
69 | 
70 | ```text
71 | Classify the following texts exactly into one of the following categories: {}.
72 | 
73 | text: {text}
74 | label: 
75 | ---
76 | Classify the following texts exactly into one of the following categories: neg, pos.
77 | 
78 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
79 | label: 1
80 | 
81 | text: {text}
82 | label: 
83 | ```
84 | 
85 | This feature is particularly useful if you have nested structures that follow a common format, such as for
86 | extractive question answering:
87 | 
88 | ```python
89 | from datasets import load_dataset
90 | from fabricator.prompts import infer_prompt_from_dataset
91 | 
92 | dataset = load_dataset("squad_v2", split="train")
93 | prompt = infer_prompt_from_dataset(dataset)
94 | 
95 | print(prompt.get_prompt_text() + "\n---")
96 | 
97 | # squad_v2 has no class label feature, so there are no label options to pass
98 | fewshot_example = dataset.shuffle(seed=42).select([0])
99 | 
100 | print(prompt.get_prompt_text(None, fewshot_example))
101 | ```
102 | 
103 | This script outputs:
104 | 
105 | ```text
106 | Given a context and a question, generate an answer that occurs exactly and only once in the text.
107 | 
108 | context: {context}
109 | question: {question}
110 | answers: 
111 | ---
112 | Given a context and a question, generate an answer that occurs exactly and only once in the text.
113 | 
114 | context: The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff:
115 | question: What term characterizes the intersection of the rites with the Roman Catholic Church?
116 | answers: {'text': ['full union'], 'answer_start': [104]}
117 | 
118 | context: {context}
119 | question: {question}
120 | answers: 
121 | ```
122 | 
123 | ## Preprocess datasets
124 | 
125 | In the previous example, we highlighted the simplicity of generating prompts using Hugging Face Datasets information.
126 | However, for optimal utilization of LLMs in generating text, it's essential to incorporate label names instead of IDs
127 | for text classification. Similarly, for question answering tasks, plain substrings are preferred over JSON-formatted
128 | strings. We'll elaborate on these limitations in the following example.
129 | 
130 | ```text
131 | Classify the following texts exactly into one of the following categories: **neg, pos**.
132 | 
133 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
134 | **label: 1**
135 | 
136 | ---
137 | 
138 | Given a context and a question, generate an answer that occurs exactly and only once in the text.
139 | 
140 | context: The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff:
141 | question: What term characterizes the intersection of the rites with the Roman Catholic Church?
142 | answers: **{'text': ['full union'], 'answer_start': [104]}**
143 | ```
144 | 
145 | To overcome this, we provide a range of preprocessing functions for various downstream tasks.
146 | 
147 | ### Text Classification
148 | 
149 | The `convert_label_ids_to_texts` function transforms your text classification dataset with label IDs into textual
150 | labels. By default, these are the label names specified in the dataset's features.
151 | 
152 | ```python
153 | from datasets import load_dataset
154 | from fabricator.prompts import infer_prompt_from_dataset
155 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts
156 | 
157 | dataset = load_dataset("imdb", split="train")
158 | prompt = infer_prompt_from_dataset(dataset)
159 | dataset, label_options = convert_label_ids_to_texts(
160 |     dataset=dataset,
161 |     label_column="label",
162 |     return_label_options=True
163 | )
164 | 
165 | fewshot_example = dataset.shuffle(seed=42).select([0])
166 | print(prompt.get_prompt_text(label_options, fewshot_example))
167 | ```
168 | 
169 | Which yields:
170 | 
171 | ```text
172 | Classify the following texts exactly into one of the following categories: neg, pos.
173 | 
174 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
175 | label: pos
176 | 
177 | text: {text}
178 | label: 
179 | ```
180 | 
181 | If we want to provide more meaningful names, we can do so by specifying `expanded_label_mapping`.
182 | Remember to update the label options accordingly in the `BasePrompt` class.
183 | 
184 | ```python
185 | extended_mapping = {0: "negative", 1: "positive"}
186 | dataset, label_options = convert_label_ids_to_texts(
187 |     dataset=dataset,
188 |     label_column="label",
189 |     expanded_label_mapping=extended_mapping,
190 |     return_label_options=True
191 | )
192 | prompt.label_options = label_options
193 | ```
194 | 
195 | This yields:
196 | 
197 | ```text
198 | Classify the following texts exactly into one of the following categories: positive, negative.
199 | 
200 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
201 | label: positive
202 | 
203 | text: {text}
204 | label: 
205 | ```
206 | 
207 | Once the dataset is generated, one can easily convert the string labels back to label IDs by
208 | using huggingface's `class_encode_column` function.
209 | 
210 | ```python
211 | dataset = dataset.class_encode_column("label")
212 | print("Labels: " + str(dataset["label"][:5]))
213 | print("Features: " + str(dataset.features["label"]))
214 | ```
215 | 
216 | Which yields:
217 | 
218 | ```text
219 | Labels: [0, 1, 1, 0, 0]
220 | Features: ClassLabel(names=['negative', 'positive'], id=None)
221 | ```
222 | 
223 | Note: While generating the dataset, the model is supposed to assign labels based on the specific options
224 | provided. However, we do not filter the data if it doesn't adhere to these predefined labels.
225 | Therefore, it's important to double-check if the annotations match the expected label options.
226 | If they don't, you should make corrections accordingly.
227 | 
228 | ### Question Answering (Extractive)
229 | 
230 | In question answering tasks, we offer two functions to handle dataset processing: preprocessing and postprocessing.
231 | The preprocessing function is responsible for transforming datasets from SQuAD format into flat strings.
232 | On the other hand, the postprocessing function reverses this process by converting flat predictions back into
233 | SQuAD format.
It achieves this by determining the starting point of the answer and checking if the answer cannot be 234 | found in the given context or if it occurs multiple times. 235 | 236 | ```python 237 | from datasets import load_dataset 238 | from fabricator.prompts import infer_prompt_from_dataset 239 | from fabricator.dataset_transformations.question_answering import preprocess_squad_format, postprocess_squad_format 240 | 241 | dataset = load_dataset("squad_v2", split="train") 242 | prompt = infer_prompt_from_dataset(dataset) 243 | 244 | dataset = preprocess_squad_format(dataset) 245 | 246 | print(prompt.get_prompt_text(None, dataset.select([0])) + "\n---") 247 | 248 | dataset = postprocess_squad_format(dataset) 249 | print(dataset[0]["answers"]) 250 | ``` 251 | 252 | Which yields: 253 | 254 | ```text 255 | Given a context and a question, generate an answer that occurs exactly and only once in the text. 256 | 257 | context: Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy". 258 | question: When did Beyonce start becoming popular? 259 | answers: in the late 1990s 260 | 261 | context: {context} 262 | question: {question} 263 | answers: 264 | --- 265 | {'start': 269, 'text': 'in the late 1990s'} 266 | ``` 267 | 268 | ### Named Entity Recognition 269 | 270 | If you attempt to create a dataset for named entity recognition without any preprocessing, the prompt might be 271 | challenging for the language model to understand. 272 | 273 | ```python 274 | from datasets import load_dataset 275 | from fabricator.prompts import BasePrompt 276 | 277 | dataset = load_dataset("conll2003", split="train") 278 | prompt = BasePrompt( 279 | task_description="Annotate each token with its named entity label: {}.", 280 | generate_data_for_column="ner_tags", 281 | fewshot_example_columns=["tokens"], 282 | label_options=dataset.features["ner_tags"].feature.names, 283 | ) 284 | 285 | print(prompt.get_prompt_text(prompt.label_options, dataset.select([0]))) 286 | ``` 287 | 288 | Which outputs: 289 | 290 | ```text 291 | Annotate each token with its named entity label: O, B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC. 292 | 293 | tokens: ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'] 294 | ner_tags: [3, 0, 7, 0, 0, 0, 7, 0, 0] 295 | 296 | tokens: {tokens} 297 | ner_tags: 298 | ``` 299 | 300 | To enhance prompt clarity, we can preprocess the dataset by converting labels into spans. This conversion can be 301 | accomplished using the `convert_token_labels_to_spans` function. 
Additionally, the function will provide the
302 | available label options:
303 | 
304 | ```python
305 | from datasets import load_dataset
306 | from fabricator.prompts import BasePrompt
307 | from fabricator.dataset_transformations import convert_token_labels_to_spans
308 | 
309 | dataset = load_dataset("conll2003", split="train")
310 | dataset, label_options = convert_token_labels_to_spans(dataset, "tokens", "ner_tags", return_label_options=True)
311 | prompt = BasePrompt(
312 |     task_description="Annotate each token with its named entity label: {}.",
313 |     generate_data_for_column="ner_tags",
314 |     fewshot_example_columns=["tokens"],
315 |     label_options=label_options,
316 | )
317 | 
318 | print(prompt.get_prompt_text(prompt.label_options, dataset.select([0])))
319 | ```
320 | Output:
321 | ```text
322 | Annotate each token with its named entity label: MISC, ORG, PER, LOC.
323 | 
324 | tokens: EU rejects German call to boycott British lamb .
325 | ner_tags: EU is ORG entity.
326 | German is MISC entity.
327 | British is MISC entity.
328 | 
329 | tokens: {tokens}
330 | ner_tags: 
331 | ```
332 | 
333 | As in text classification, we can also specify more semantically precise labels with the `expanded_label_mapping`:
334 | 
335 | ```python
336 | expanded_label_mapping = {
337 |     0: "O",
338 |     1: "B-person",
339 |     2: "I-person",
340 |     3: "B-organization",
341 |     4: "I-organization",
342 |     5: "B-location",
343 |     6: "I-location",
344 |     7: "B-miscellaneous",
345 |     8: "I-miscellaneous",
346 | }
347 | 
348 | dataset, label_options = convert_token_labels_to_spans(
349 |     dataset=dataset,
350 |     token_column="tokens",
351 |     label_column="ner_tags",
352 |     expanded_label_mapping=expanded_label_mapping,
353 |     return_label_options=True
354 | )
355 | ```
356 | 
357 | Output:
358 | 
359 | ```text
360 | Annotate each token with its named entity label: organization, person, location, miscellaneous.
361 | 
362 | tokens: EU rejects German call to boycott British lamb .
363 | ner_tags: EU is organization entity.
364 | German is miscellaneous entity.
365 | British is miscellaneous entity.
366 | 
367 | tokens: {tokens}
368 | ner_tags: 
369 | ```
370 | 
371 | Once the dataset is created, we can use the `convert_spans_to_token_labels` function to convert spans back to label
372 | IDs. This function will only add spans that occur exactly once in the text. If a span occurs multiple times, it will be
373 | ignored. Note: this conversion is rather slow since it is currently built on spaCy. We are working on a faster implementation and welcome
374 | contributions.
375 | 
376 | ```python
377 | from fabricator.dataset_transformations import convert_spans_to_token_labels
378 | 
379 | dataset = convert_spans_to_token_labels(
380 |     dataset=dataset.select(range(20)),
381 |     token_column="tokens",
382 |     label_column="ner_tags",
383 |     id2label=expanded_label_mapping
384 | )
385 | ```
386 | 
387 | Outputs:
388 | 
389 | ```text
390 | {'id': '1', 'tokens': ['Peter', 'Blackburn'], 'pos_tags': [22, 22], 'chunk_tags': [11, 12], 'ner_tags': [1, 2]}
391 | ```
392 | 
393 | ## Adapt to arbitrary datasets
394 | 
395 | The `BasePrompt` class is designed to be easily adaptable to arbitrary datasets. Just like in the examples for text
396 | classification, token classification or question answering, you can specify the task description, the column to generate
397 | data for and the fewshot example columns. The only difference is that you have to specify the optional label options
398 | yourself.
399 | 400 | ### Translation 401 | 402 | ```python 403 | import os 404 | from datasets import Dataset 405 | from haystack.nodes import PromptNode 406 | from fabricator import DatasetGenerator 407 | from fabricator.prompts import BasePrompt 408 | 409 | fewshot_dataset = Dataset.from_dict({ 410 | "german": ["Der Film ist großartig!", "Der Film ist schlecht!"], 411 | "english": ["This movie is great!", "This movie is bad!"], 412 | }) 413 | 414 | unlabeled_dataset = Dataset.from_dict({ 415 | "english": ["This movie was a blast!", "This movie was not bad!"], 416 | }) 417 | 418 | prompt = BasePrompt( 419 | task_description="Translate to german:", # Since we do not have a label column, 420 | # we can just specify the task description 421 | generate_data_for_column="german", 422 | fewshot_example_columns="english", 423 | ) 424 | 425 | prompt_node = PromptNode( 426 | model_name_or_path="gpt-3.5-turbo", 427 | api_key=os.environ.get("OPENAI_API_KEY"), 428 | max_length=100, 429 | ) 430 | 431 | generator = DatasetGenerator(prompt_node) 432 | generated_dataset = generator.generate( 433 | prompt_template=prompt, 434 | fewshot_dataset=fewshot_dataset, 435 | fewshot_examples_per_class=2, # Take both fewshot examples per prompt 436 | fewshot_sampling_strategy=None, # Since we do not have a class label column, we can just set this to None 437 | # (default) 438 | unlabeled_dataset=unlabeled_dataset, 439 | max_prompt_calls=2, 440 | ) 441 | 442 | generated_dataset.push_to_hub("your-first-generated-dataset") 443 | ``` 444 | 445 | ### Textual similarity 446 | 447 | ```python 448 | import os 449 | from datasets import load_dataset 450 | from haystack.nodes import PromptNode 451 | from fabricator import DatasetGenerator 452 | from fabricator.prompts import BasePrompt 453 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts 454 | 455 | dataset = load_dataset("glue", "mrpc", split="train") 456 | dataset, label_options = convert_label_ids_to_texts(dataset, "label", return_label_options=True) # convert the 457 | # label ids to text labels and return the label options 458 | 459 | fewshot_dataset = dataset.select(range(10)) 460 | unlabeled_dataset = dataset.select(range(10, 20)) 461 | 462 | prompt = BasePrompt( 463 | task_description="Annotate the sentence pair whether it is: {}", 464 | label_options=label_options, 465 | generate_data_for_column="label", 466 | fewshot_example_columns=["sentence1", "sentence2"], # we can pass an array of columns to use for the fewshot 467 | ) 468 | 469 | prompt_node = PromptNode( 470 | model_name_or_path="gpt-3.5-turbo", 471 | api_key=os.environ.get("OPENAI_API_KEY"), 472 | max_length=100, 473 | ) 474 | 475 | generator = DatasetGenerator(prompt_node) 476 | generated_dataset, original_dataset = generator.generate( 477 | prompt_template=prompt, 478 | fewshot_dataset=fewshot_dataset, 479 | fewshot_examples_per_class=1, # Take 1 fewshot examples per class per prompt 480 | fewshot_sampling_column="label", # We want to sample fewshot examples based on the label column 481 | fewshot_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class 482 | unlabeled_dataset=unlabeled_dataset, 483 | max_prompt_calls=2, 484 | return_unlabeled_dataset=True, # We can return the original unlabelled dataset which might be interesting in this 485 | # case to compare the annotation quality 486 | ) 487 | 488 | generated_dataset = generated_dataset.class_encode_column("label") 489 | 490 | 
generated_dataset.push_to_hub("your-first-generated-dataset") 491 | ``` 492 | 493 | You can also easily switch out columns to be annotated if you want, for example, to generate a second sentence given a 494 | first sentence and a label like: 495 | 496 | ```python 497 | import os 498 | from datasets import load_dataset 499 | from haystack.nodes import PromptNode 500 | from fabricator import DatasetGenerator 501 | from fabricator.prompts import BasePrompt 502 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts 503 | 504 | dataset = load_dataset("glue", "mrpc", split="train") 505 | dataset, label_options = convert_label_ids_to_texts(dataset, "label", return_label_options=True) # convert the 506 | # label ids to text labels and return the label options 507 | 508 | fewshot_dataset = dataset.select(range(10)) 509 | unlabeled_dataset = dataset.select(range(10, 20)) 510 | 511 | prompt = BasePrompt( 512 | task_description="Generate a sentence that is {} to sentence1.", 513 | label_options=label_options, 514 | generate_data_for_column="sentence2", 515 | fewshot_example_columns=["sentence1", "label"], # we can pass an array of columns to use for the fewshot 516 | ) 517 | 518 | prompt_node = PromptNode( 519 | model_name_or_path="gpt-3.5-turbo", 520 | api_key=os.environ.get("OPENAI_API_KEY"), 521 | max_length=100, 522 | ) 523 | 524 | generator = DatasetGenerator(prompt_node) 525 | generated_dataset, original_dataset = generator.generate( 526 | prompt_template=prompt, 527 | fewshot_dataset=fewshot_dataset, 528 | fewshot_examples_per_class=1, # Take 1 fewshot examples per class per prompt 529 | fewshot_sampling_column="label", # We want to sample fewshot examples based on the label column 530 | fewshot_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class 531 | unlabeled_dataset=unlabeled_dataset, 532 | max_prompt_calls=2, 533 | return_unlabeled_dataset=True, # We can return the original unlabelled dataset which might be interesting in this 534 | # case to compare the annotation quality 535 | ) 536 | 537 | generated_dataset = generated_dataset.class_encode_column("label") 538 | 539 | generated_dataset.push_to_hub("your-first-generated-dataset") 540 | ``` 541 | --------------------------------------------------------------------------------