├── .github
│   └── workflows
│       ├── publish-package.yml
│       ├── pylint.yml
│       └── pytest.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── pyproject.toml
├── requirements.txt
├── resources
│   ├── generation_workflow.png
│   ├── logo_fabricator.drawio_dark.png
│   └── logo_fabricator.drawio_white.png
├── setup.py
├── src
│   └── fabricator
│       ├── __init__.py
│       ├── dataset_generator.py
│       ├── dataset_transformations
│       │   ├── __init__.py
│       │   ├── question_answering.py
│       │   ├── text_classification.py
│       │   └── token_classification.py
│       ├── prompts
│       │   ├── __init__.py
│       │   ├── base.py
│       │   └── utils.py
│       ├── samplers
│       │   ├── __init__.py
│       │   └── samplers.py
│       └── utils.py
├── tests
│   ├── test_dataset_generator.py
│   ├── test_dataset_sampler.py
│   ├── test_dataset_transformations.py
│   └── test_prompts.py
└── tutorials
    ├── TUTORIAL-1_OVERVIEW.md
    ├── TUTORIAL-2_GENERATION_WORKFLOWS.md
    └── TUTORIAL-3_ADVANCED-GENERATION.md
/.github/workflows/publish-package.yml:
--------------------------------------------------------------------------------
1 | name: Publish Release to PyPI
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | lint:
9 | name: Linting Checks
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - name: Set up Python
14 | uses: actions/setup-python@v4
15 | with:
16 | python-version: "3.10"
17 | - name: Install dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | python -m pip install -r requirements.txt
21 | pip install pylint
22 | - name: Analysing the code with pylint
23 | run: |
24 | pylint $(git ls-files '*.py') --fail-under 9.3
25 |
26 | build:
27 | name: Build distribution
28 | runs-on: ubuntu-latest
29 |
30 | steps:
31 | - uses: actions/checkout@v4
32 | - name: Set up Python
33 | uses: actions/setup-python@v4
34 | with:
35 | python-version: "3.10"
36 | - name: Install pypa/build
37 | run: >-
38 | python3 -m
39 | pip install
40 | build
41 | --user
42 | - name: Build binary wheel and source tarball
43 | run: python3 -m build
44 | - name: Store distribution packages
45 | uses: actions/upload-artifact@v3
46 | with:
47 | name: python-package-distributions
48 | path: dist/
49 |
50 | publish-to-testpypi:
51 | name: Publish distribution to TestPyPI
52 | needs:
53 | - lint
54 | - build
55 | runs-on: ubuntu-latest
56 |
57 | environment:
58 | name: testpypi
59 | url: https://test.pypi.org/p/fabricator-ai
60 |
61 | permissions:
62 | contents: read
63 | id-token: write
64 |
65 | steps:
66 | - name: Download all the dists
67 | uses: actions/download-artifact@v3
68 | with:
69 | name: python-package-distributions
70 | path: dist/
71 | - name: Publish distribution to TestPyPI
72 | uses: pypa/gh-action-pypi-publish@release/v1
73 | with:
74 | repository-url: https://test.pypi.org/legacy/
75 |
76 | publish-to-pypi:
77 | name: >-
78 | Publish Release to PyPI
79 | if: startsWith(github.ref, 'refs/tags/')
80 | needs:
81 | - publish-to-testpypi
82 | runs-on: ubuntu-latest
83 | environment:
84 | name: pypi
85 | url: https://pypi.org/p/fabricator-ai
86 | permissions:
87 | contents: read
88 | id-token: write
89 | steps:
90 | - name: Download all the dists
91 | uses: actions/download-artifact@v3
92 | with:
93 | name: python-package-distributions
94 | path: dist/
95 | - name: Publish distribution to PyPI
96 | uses: pypa/gh-action-pypi-publish@release/v1
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: Pylint
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: ["3.8", "3.9", "3.10"]
11 | steps:
12 | - uses: actions/checkout@v3
13 | - name: Set up Python ${{ matrix.python-version }}
14 | uses: actions/setup-python@v3
15 | with:
16 | python-version: ${{ matrix.python-version }}
17 | - name: Install dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | python -m pip install -r requirements.txt
21 | pip install pylint
22 | - name: Analysing the code with pylint
23 | run: |
24 | pylint $(git ls-files '*.py') --fail-under 9.3
25 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python package
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.8", "3.9", "3.10"]
20 |
21 | steps:
22 | - uses: actions/checkout@v3
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v3
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install pytest pytest-sugar
31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
32 | - name: Test with pytest
33 | run: |
34 | python -m pytest
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # Logs
156 | logs/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 |
4 | A flexible open-source framework to generate datasets with large language models.
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | [Installation](#installation) | [Basic Concepts](#basic-concepts) | [Examples](#examples) | [Tutorials](tutorials/TUTORIAL-1_OVERVIEW.md) |
14 | [Paper](https://arxiv.org/abs/2309.09582) | [Citation](#citation)
15 |
16 |
17 |
18 |
19 | ## News
20 |
21 | - **[10/23]** We released the first version of this repository on PyPI. You can install it via `pip install fabricator-ai`.
22 | - **[10/23]** Our paper got accepted at EMNLP 2023. You can find the preprint [here](https://arxiv.org/abs/2309.09582). You can find the experimental scripts under release v0.1.0.
23 | - **[09/23]** Support for `gpt-3.5-turbo-instruct` added in the new [Haystack](https://github.com/deepset-ai/haystack) release!
24 | - **[08/23]** Added several experimental scripts to investigate the generation and annotation abilities of `gpt-3.5-turbo` on various downstream tasks, as well as the influence of few-shot examples on performance.
25 | - **[07/23]** Refactored major classes - you can now simply use our BasePrompt class to create your own customized prompts for every downstream task!
26 | - **[07/23]** Added dataset transformations for token classification to prompt LLMs with textual spans rather than with lists of tags.
27 | - **[06/23]** Initial version of fabricator supporting text classification and question answering tasks.
28 |
29 | ## Overview
30 |
31 | This repository:
32 |
33 | - is an easy-to-use open-source library to generate datasets with large language models. If you want to train
34 | a model on a specific domain / label distribution / downstream task, you can use this framework to generate
35 | a dataset for it.
36 | - builds on top of deepset's haystack and huggingface's datasets libraries. Thus, we support a wide range
37 | of language models, and you can load and use the generated datasets just as you would any other dataset from the
38 | Datasets library for your model training.
39 | - is highly flexible and offers various adaptation options such as
40 | prompt customization, integration and sampling of fewshot examples, or annotation of unlabeled datasets.
41 |
42 | ## Installation
43 | Using conda:
44 | ```
45 | git clone git@github.com:flairNLP/fabricator.git
46 | cd fabricator
47 | conda create -y -n fabricator python=3.10
48 | conda activate fabricator
49 | pip install fabricator-ai
50 | ```
51 |
52 | If you want to install in editable mode, you can use the following command:
53 | ```
54 | pip install -e .
55 | ```
56 |
57 | ## Basic Concepts
58 |
59 | This framework is based on the idea of using large language models to generate datasets for specific tasks. To do so,
60 | we need four basic modules: a dataset, a prompt, a language model and a generator:
61 | - Dataset : We use [huggingface's datasets library](https://github.com/huggingface/datasets) to load fewshot or
62 | unlabeled datasets and store the generated or annotated datasets with their `Dataset` class. Once
63 | created, you can share the dataset with others via the hub or use it for your model training.
64 | - Prompt : A prompt is the instruction given to the language model. It can be a simple sentence or a more complex
65 | template with placeholders. We provide an easy interface for custom dataset generation prompts in which you can specify
66 | label options for the LLM to choose from, provide fewshot examples to support the prompt, or annotate an unlabeled
67 | dataset in a specific way.
68 | - LLM : We use [deepset's haystack library](https://github.com/deepset-ai/haystack) as our LLM interface. Haystack
69 | supports a wide range of LLMs, including OpenAI models, all models from the HuggingFace model hub, and many more.
70 | - Generator : The generator is the core of this framework. It takes a dataset, a prompt and an LLM and generates a
71 | dataset based on your specifications.
72 |
73 | ## Examples
74 |
75 | With our library, you can generate datasets for any task you want. You can start as simply
76 | as this:
77 |
78 | ### Generate a dataset from scratch
79 |
80 | ```python
81 | import os
82 | from haystack.nodes import PromptNode
83 | from fabricator import DatasetGenerator
84 | from fabricator.prompts import BasePrompt
85 |
86 | prompt = BasePrompt(
87 | task_description="Generate a short movie review.",
88 | )
89 |
90 | prompt_node = PromptNode(
91 | model_name_or_path="gpt-3.5-turbo",
92 | api_key=os.environ.get("OPENAI_API_KEY"),
93 | max_length=100,
94 | )
95 |
96 | generator = DatasetGenerator(prompt_node)
97 | generated_dataset = generator.generate(
98 | prompt_template=prompt,
99 | max_prompt_calls=10,
100 | )
101 |
102 | generated_dataset.push_to_hub("your-first-generated-dataset")
103 | ```
104 |
105 | In our tutorials, we show how to create classification datasets with label options to choose from, how to include
106 | fewshot examples, and how to annotate unlabeled data with predefined categories. A minimal sketch of such an annotation workflow is shown below.
107 |
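For instance, annotating unlabeled data with label options and fewshot examples might look roughly like the following sketch. The `BasePrompt` keyword arguments (`label_options`, `generate_data_for_column`, `fewshot_example_columns`) and the `{}` placeholder for the label options are assumptions based on the attributes used elsewhere in this repository; see `src/fabricator/prompts/base.py` and the tutorials for the exact interface.

```python
import os

from datasets import Dataset
from haystack.nodes import PromptNode
from fabricator import DatasetGenerator
from fabricator.prompts import BasePrompt

# Toy few-shot and unlabeled data; the "text" and "label" column names are illustrative.
fewshot_dataset = Dataset.from_dict({
    "text": ["The movie was fantastic!", "A dull, lifeless film."],
    "label": ["positive", "negative"],
})
unlabeled_dataset = Dataset.from_dict({
    "text": ["I would watch it again any time.", "Two hours I will never get back."],
})

prompt = BasePrompt(
    task_description="Annotate the movie review with one of the following labels: {}.",
    label_options=["positive", "negative"],
    generate_data_for_column="label",
    fewshot_example_columns="text",
)

prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key=os.environ.get("OPENAI_API_KEY"),
    max_length=100,
)

generator = DatasetGenerator(prompt_node)
annotated_dataset = generator.generate(
    prompt_template=prompt,
    fewshot_dataset=fewshot_dataset,
    fewshot_sampling_strategy="stratified",
    fewshot_examples_per_class=1,
    fewshot_sampling_column="label",
    unlabeled_dataset=unlabeled_dataset,
    max_prompt_calls=2,
)
```
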
108 | ## Citation
109 |
110 | If you find this repository useful, please cite our work.
111 |
112 | ```
113 | @inproceedings{golde2023fabricator,
114 | title = "Fabricator: An Open Source Toolkit for Generating Labeled Training Data with Teacher {LLM}s",
115 | author = "Golde, Jonas and Haller, Patrick and Hamborg, Felix and Risch, Julian and Akbik, Alan",
116 | booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing: System Demonstrations",
117 | month = dec,
118 | year = "2023",
119 | address = "Singapore",
120 | publisher = "Association for Computational Linguistics",
121 | url = "https://aclanthology.org/2023.emnlp-demo.1",
122 | pages = "1--11",
123 | }
124 | ```
125 |
126 |
127 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 |
4 | [tool.pylint.format]
5 | max-line-length = 119
6 |
7 | [tool.pylint.MASTER]
8 | ignore-patterns = '''test_.*\.py,
9 | setup.py
10 | '''
11 | ignore-paths = [ "^paper_experiments/.*$"]
12 |
13 | [tool.pylint."MESSAGE CONTROL"]
14 | disable = '''logging-fstring-interpolation,
15 | missing-module-docstring,
16 | '''
17 |
18 | [tool.pytest.ini_options]
19 | pythonpath = [
20 | "src"
21 | ]
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets
2 | farm-haystack>=1.18.0
3 | loguru
4 |
--------------------------------------------------------------------------------
/resources/generation_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flairNLP/fabricator/8dc23321a67ad03bef371a560452f0e9b6ff5afa/resources/generation_workflow.png
--------------------------------------------------------------------------------
/resources/logo_fabricator.drawio_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flairNLP/fabricator/8dc23321a67ad03bef371a560452f0e9b6ff5afa/resources/logo_fabricator.drawio_dark.png
--------------------------------------------------------------------------------
/resources/logo_fabricator.drawio_white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flairNLP/fabricator/8dc23321a67ad03bef371a560452f0e9b6ff5afa/resources/logo_fabricator.drawio_white.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from setuptools import setup, find_packages
3 |
4 |
5 | def requirements():
6 | with open("requirements.txt", "r") as f:
7 | return f.read().splitlines()
8 |
9 |
10 | setup(
11 | name='fabricator-ai',
12 | version='0.2.0',
13 | author='Humboldt University Berlin, deepset GmbH',
14 | author_email="goldejon@informatik.hu-berlin.de",
15 | description='Conveniently generating datasets with large language models.',
16 | long_description=Path("README.md").read_text(encoding="utf-8"),
17 | long_description_content_type="text/markdown",
18 | package_dir={"": "src"},
19 | packages=find_packages("src"),
20 | license="Apache 2.0",
21 | python_requires=">=3.8",
22 | install_requires=requirements(),
23 | )
24 |
--------------------------------------------------------------------------------
/src/fabricator/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 |
3 | from .prompts import (
4 | BasePrompt,
5 | infer_prompt_from_dataset,
6 | infer_prompt_from_task_template
7 | )
8 | from .dataset_transformations import *
9 | from .samplers import *
10 | from .dataset_generator import DatasetGenerator
11 |
--------------------------------------------------------------------------------
/src/fabricator/dataset_generator.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 |
4 | from pathlib import Path
5 | from collections import defaultdict
6 | from typing import Any, Callable, Dict, Optional, Union, Tuple, List
7 | from tqdm import tqdm
8 | from loguru import logger
9 |
10 | from datasets import Dataset
11 | from numpy.random import choice
12 | from haystack.nodes import PromptNode
13 | from haystack.nodes import PromptTemplate as HaystackPromptTemplate
14 |
15 | from .prompts import BasePrompt
16 | from .samplers import single_label_stratified_sample
17 | from .utils import log_dir, create_timestamp_path
18 |
19 |
20 | class DatasetGenerator:
21 | """The DatasetGenerator class is the main class of the fabricator package.
22 | It generates datasets based on a prompt template. The main function is generate()."""
23 |
24 | def __init__(self, prompt_node: PromptNode, max_tries: int = 10):
25 | """Initialize the DatasetGenerator with a prompt node.
26 |
27 | Args:
28 | prompt_node (PromptNode): Prompt node / LLM from haystack.
29 | """
30 | self.prompt_node = prompt_node
31 | self._base_log_dir = log_dir()
32 | self._max_tries = max_tries
33 |
34 | def _setup_log(self, prompt_template: BasePrompt) -> Path:
35 | """For every generation run create a new log file.
36 | Current format: <timestamp>_<prompt template class name>.jsonl
37 |
38 | Args:
39 | prompt_template (BasePrompt): Prompt template to generate the dataset with.
40 |
41 | Returns:
42 | Path: Path to the log file.
43 |
44 | """
45 | timestamp_path = create_timestamp_path(self._base_log_dir)
46 | log_file = Path(f"{timestamp_path}_{prompt_template.__class__.__name__}.jsonl")
47 | log_file.parent.mkdir(parents=True, exist_ok=True)
48 | log_file.touch()
49 | return log_file
50 |
51 | def generate(
52 | self,
53 | prompt_template: BasePrompt,
54 | fewshot_dataset: Optional[Dataset] = None,
55 | fewshot_sampling_strategy: Optional[str] = None,
56 | fewshot_examples_per_class: int = None,
57 | fewshot_sampling_column: Optional[str] = None,
58 | unlabeled_dataset: Optional[Dataset] = None,
59 | return_unlabeled_dataset: bool = False,
60 | max_prompt_calls: int = 10,
61 | num_samples_to_generate: int = 10,
62 | timeout_per_prompt: Optional[int] = None,
63 | log_every_n_api_calls: int = 25,
64 | dummy_response: Optional[Union[str, Callable]] = None
65 | ) -> Union[Dataset, Tuple[Dataset, Dataset]]:
66 | """Generate a dataset based on a prompt template and support examples.
67 | Optionally, unlabeled examples can be provided to annotate unlabeled data.
68 |
69 | Args:
70 | prompt_template (BasePrompt): Prompt template to generate the dataset with.
71 | fewshot_dataset (Dataset): Support examples to generate the dataset from. Defaults to None.
72 | fewshot_sampling_strategy (str, optional): Sampling strategy for support examples.
73 | Defaults to None and means all fewshot examples are used or limited by number of
74 | fewshot_examples_per_class.
75 | "uniform" sampling strategy means that fewshot examples for a uniformly sampled label are used.
76 | "stratified" sampling strategy means that fewshot examples uniformly selected from each label.
77 | fewshot_examples_per_class (int, optional): Number of support examples for a certain class per prompt.
78 | Defaults to None.
79 | fewshot_sampling_column (str, optional): Column to sample from. Defaults to None and function will try
80 | to sample from the generate_data_for_column attribute of the prompt template.
81 | unlabeled_dataset (Optional[Dataset], optional): Unlabeled examples to annotate. Defaults to None.
82 | return_unlabeled_dataset (bool, optional): Whether to return the original dataset. Defaults to False.
83 | max_prompt_calls (int, optional): Maximum number of prompt calls. Defaults to 10.
84 | num_samples_to_generate (int, optional): Number of samples to generate. Defaults to 10.
85 | timeout_per_prompt (Optional[int], optional): Timeout per prompt call. Defaults to None.
86 | log_every_n_api_calls (int, optional): Log every n api calls. Defaults to 25.
87 | dummy_response (Optional[Union[str, Callable]], optional): Dummy response for dry runs. Defaults to None.
88 |
89 | Returns:
90 | Union[Dataset, Tuple[Dataset, Dataset]]: Generated dataset or tuple of generated dataset and original
91 | dataset.
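        Example (a minimal, illustrative sketch; the variable names and the "label" column are assumptions):
            generated = generator.generate(
                prompt_template=prompt,
                fewshot_dataset=fewshot_dataset,
                fewshot_sampling_strategy="stratified",
                fewshot_examples_per_class=2,
                fewshot_sampling_column="label",
                max_prompt_calls=20,
                num_samples_to_generate=20,
            )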
92 | """
93 | if fewshot_dataset:
94 | self._assert_fewshot_dataset_matches_prompt(prompt_template, fewshot_dataset)
95 |
96 | assert fewshot_sampling_strategy in [None, "uniform", "stratified"], \
97 | "Sampling strategy must be 'uniform' or 'stratified'"
98 |
99 | if fewshot_dataset and not fewshot_sampling_column:
100 | fewshot_sampling_column = prompt_template.generate_data_for_column[0]
101 |
102 | generated_dataset, original_dataset = self._inner_generate_loop(
103 | prompt_template,
104 | fewshot_dataset,
105 | fewshot_examples_per_class,
106 | fewshot_sampling_strategy,
107 | fewshot_sampling_column,
108 | unlabeled_dataset,
109 | return_unlabeled_dataset,
110 | max_prompt_calls,
111 | num_samples_to_generate,
112 | timeout_per_prompt,
113 | log_every_n_api_calls,
114 | dummy_response
115 | )
116 |
117 | if return_unlabeled_dataset:
118 | return generated_dataset, original_dataset
119 |
120 | return generated_dataset
121 |
122 | def _try_generate(
123 | self, prompt_text: str, invocation_context: Dict, dummy_response: Optional[Union[str, Callable]]
124 | ) -> Optional[str]:
125 | """Tries to generate a single example. Restrict the time spent on this.
126 |
127 | Args:
128 | prompt_text: Prompt text to generate an example for.
129 | invocation_context: Invocation context to generate an example for.
130 | dummy_response: If provided (string or callable), return a dummy example instead of calling the LLM.
131 |
132 | Returns:
133 | Generated example
134 | """
135 |
136 | if dummy_response:
137 |
138 | if isinstance(dummy_response, str):
139 | logger.info(f"Returning dummy response: {dummy_response}")
140 | return dummy_response
141 |
142 | if callable(dummy_response):
143 | dummy_value = dummy_response(prompt_text)
144 | logger.info(f"Returning dummy response: {dummy_response}")
145 | return dummy_value
146 |
147 | raise ValueError("Dummy response must be a string or a callable")
148 |
149 | # Haystack internally uses timeouts and retries, so we don't have to do it ourselves.
150 | # We don't catch authentication errors here, because we want to fail fast.
151 | try:
152 | prediction = self.prompt_node.run(
153 | prompt_template=HaystackPromptTemplate(prompt=prompt_text),
154 | invocation_context=invocation_context,
155 | )[0]["results"]
156 | except Exception as error:
157 | logger.error(f"Error while generating example: {error}")
158 | return None
159 |
160 | return prediction
161 |
162 | def _inner_generate_loop(
163 | self,
164 | prompt_template: BasePrompt,
165 | fewshot_dataset: Dataset,
166 | fewshot_examples_per_class: int,
167 | fewshot_sampling_strategy: str,
168 | fewshot_sampling_column: str,
169 | unlabeled_dataset: Dataset,
170 | return_unlabeled_dataset: bool,
171 | max_prompt_calls: int,
172 | num_samples_to_generate: int,
173 | timeout_per_prompt: Optional[int],
174 | log_every_n_api_calls: int = 25,
175 | dummy_response: Optional[Union[str, Callable]] = None
176 | ):
177 | current_tries_left = self._max_tries
178 | current_log_file = self._setup_log(prompt_template)
179 |
180 | generated_dataset = defaultdict(list)
181 | original_dataset = defaultdict(list)
182 |
183 | if unlabeled_dataset:
184 | api_calls = range(min(max_prompt_calls, len(unlabeled_dataset)))
185 | else:
186 | api_calls = range(min(max_prompt_calls, num_samples_to_generate))
187 |
188 | for prompt_call_idx, unlabeled_example_idx in tqdm(
189 | enumerate(api_calls, start=1), desc="Generating dataset", total=len(api_calls)
190 | ):
191 | fewshot_examples = None
192 | unlabeled_example = None
193 | invocation_context = None
194 | prompt_labels = None
195 |
196 | if prompt_template.label_options:
197 | # At some point: how can we do label-conditioned generation without fewshot examples? Currently it
198 | # requires a second parameter to sample from label options and not from fewshot examples
199 | prompt_labels = prompt_template.label_options
200 |
201 | if fewshot_dataset:
202 | prompt_labels, fewshot_examples = self._sample_fewshot_examples(
203 | prompt_template, fewshot_dataset, fewshot_sampling_strategy, fewshot_examples_per_class,
204 | fewshot_sampling_column
205 | )
206 |
207 | prompt_text = prompt_template.get_prompt_text(prompt_labels, fewshot_examples)
208 |
209 | if unlabeled_dataset:
210 | unlabeled_example = unlabeled_dataset[unlabeled_example_idx]
211 | invocation_context = prompt_template.filter_example_by_columns(
212 | unlabeled_example, prompt_template.fewshot_example_columns
213 | )
214 |
215 | if log_every_n_api_calls > 0:
216 | if prompt_call_idx % log_every_n_api_calls == 0:
217 | logger.info(
218 | f"Current prompt call: {prompt_call_idx}: \n"
219 | f"Prompt: {prompt_text} \n"
220 | f"Invocation context: {invocation_context} \n"
221 | )
222 |
223 | prediction = self._try_generate(prompt_text, invocation_context, dummy_response)
224 |
225 | if prediction is None:
226 | current_tries_left -= 1
227 | logger.warning(f"Could not generate example for prompt {prompt_text}.")
228 | if current_tries_left == 0:
229 | logger.warning(
230 | f"Max tries ({self._max_tries}) exceeded. Returning generated dataset with"
231 | f" {len(generated_dataset)} examples."
232 | )
233 | break
234 |
235 | if len(prediction) == 1:
236 | prediction = prediction[0]
237 |
238 | # If we have a target variable, we re-use the relevant columns of the input example
239 | # and add the prediction to the generated dataset
240 | if prompt_template.generate_data_for_column and unlabeled_example:
241 | generated_sample = prompt_template.filter_example_by_columns(
242 | unlabeled_example, prompt_template.fewshot_example_columns
243 | )
244 |
245 | for key, value in generated_sample.items():
246 | generated_dataset[key].append(value)
247 |
248 | # Try to safely convert the prediction to the type of the target variable
249 | if not prompt_template.generate_data_for_column[0] in unlabeled_example:
250 | prediction = self._convert_prediction(
251 | prediction, type(prompt_template.generate_data_for_column[0])
252 | )
253 |
254 | generated_dataset[prompt_template.generate_data_for_column[0]].append(prediction)
255 |
256 | else:
257 | generated_dataset[prompt_template.DEFAULT_TEXT_COLUMN[0]].append(prediction)
258 | if prompt_labels and isinstance(prompt_labels, str):
259 | generated_dataset[prompt_template.DEFAULT_LABEL_COLUMN[0]].append(prompt_labels)
260 |
261 | log_entry = {
262 | "prompt": prompt_text,
263 | "invocation_context": invocation_context,
264 | "prediction": prediction,
265 | "target": prompt_template.generate_data_for_column[0]
266 | if prompt_template.generate_data_for_column
267 | else prompt_template.DEFAULT_TEXT_COLUMN[0],
268 | }
269 | with open(current_log_file, "a", encoding="utf-8") as log_file:
270 | log_file.write(f"{json.dumps(log_entry)}\n")
271 |
272 | if return_unlabeled_dataset:
273 | for key, value in unlabeled_example.items():
274 | original_dataset[key].append(value)
275 |
276 | if prompt_call_idx >= max_prompt_calls:
277 | logger.info("Reached maximum number of prompt calls ({}).", max_prompt_calls)
278 | break
279 |
280 | if len(generated_dataset) >= num_samples_to_generate:
281 | logger.info("Generated {} samples.", num_samples_to_generate)
282 | break
283 |
284 | if timeout_per_prompt is not None:
285 | time.sleep(timeout_per_prompt)
286 |
287 | generated_dataset = Dataset.from_dict(generated_dataset)
288 |
289 | if return_unlabeled_dataset:
290 | original_dataset = Dataset.from_dict(original_dataset)
291 | return generated_dataset, original_dataset
292 |
293 | return generated_dataset, None
294 |
295 | def _convert_prediction(self, prediction: str, target_type: type) -> Any:
296 | """Converts a prediction to the target type.
297 |
298 | Args:
299 | prediction: Prediction to convert.
300 | target_type: Type to convert the prediction to.
301 |
302 | Returns:
303 | Converted prediction.
304 | """
305 |
306 | if isinstance(prediction, target_type):
307 | return prediction
308 |
309 | try:
310 | return target_type(prediction)
311 | except ValueError:
312 | logger.warning(
313 | "Could not convert prediction {} to type {}. "
314 | "Returning original prediction.", repr(prediction), target_type
315 | )
316 | return prediction
317 |
318 | @staticmethod
319 | def _sample_fewshot_examples(
320 | prompt_template: BasePrompt,
321 | fewshot_dataset: Dataset,
322 | fewshot_sampling_strategy: str,
323 | fewshot_examples_per_class: int,
324 | fewshot_sampling_column: str
325 | ) -> Tuple[Union[List[str], str], Dataset]:
326 |
327 | if fewshot_sampling_strategy == "uniform":
328 | prompt_labels = choice(prompt_template.label_options, 1)[0]
329 | fewshot_examples = fewshot_dataset.filter(
330 | lambda example: example[fewshot_sampling_column] == prompt_labels
331 | ).shuffle().select(range(fewshot_examples_per_class))
332 |
333 | elif fewshot_sampling_strategy == "stratified":
334 | prompt_labels = prompt_template.label_options
335 | fewshot_examples = single_label_stratified_sample(
336 | fewshot_dataset,
337 | fewshot_sampling_column,
338 | fewshot_examples_per_class
339 | )
340 |
341 | else:
342 | prompt_labels = prompt_template.label_options if prompt_template.label_options else None
343 | if fewshot_examples_per_class:
344 | fewshot_examples = fewshot_dataset.shuffle().select(range(fewshot_examples_per_class))
345 | else:
346 | fewshot_examples = fewshot_dataset.shuffle()
347 |
348 | assert len(fewshot_examples) > 0, f"Could not find any fewshot examples for label(s) {prompt_labels}. " \
349 | f"Ensure that labels of fewshot examples match the label_options " \
350 | f"from the prompt."
351 |
352 | return prompt_labels, fewshot_examples
353 |
354 | @staticmethod
355 | def _assert_fewshot_dataset_matches_prompt(prompt_template: BasePrompt, fewshot_dataset: Dataset) -> None:
356 | """Asserts that the prompt template is valid and all columns are present in the fewshot dataset."""
357 | assert all(
358 | field in fewshot_dataset.column_names for field in prompt_template.relevant_columns_for_fewshot_examples
359 | ), "Not all required variables of the prompt template occur in the support examples."
360 |
--------------------------------------------------------------------------------
/src/fabricator/dataset_transformations/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "preprocess_squad_format",
3 | "postprocess_squad_format",
4 | "calculate_answer_start",
5 | "convert_label_ids_to_texts",
6 | "get_labels_from_dataset",
7 | "replace_class_labels",
8 | "convert_token_labels_to_spans",
9 | "convert_spans_to_token_labels",
10 | ]
11 |
12 | from .question_answering import preprocess_squad_format, postprocess_squad_format, calculate_answer_start
13 | from .text_classification import convert_label_ids_to_texts, get_labels_from_dataset, replace_class_labels
14 | from .token_classification import convert_token_labels_to_spans, convert_spans_to_token_labels
15 |
--------------------------------------------------------------------------------
/src/fabricator/dataset_transformations/question_answering.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset
2 | from loguru import logger
3 |
4 |
5 | def preprocess_squad_format(dataset: Dataset) -> Dataset:
6 | """Preprocesses a dataset in SQuAD format (nested answers) to a dataset in SQuAD format that has flat answers.
7 | {"answer": {"text": "answer", "start": 0}} -> {"text": "answer"}
8 |
9 | Args:
10 | dataset (Dataset): A huggingface dataset in SQuAD format.
11 |
12 | Returns:
13 | Dataset: A huggingface dataset in SQuAD format with flat answers.
14 | """
15 |
16 | def preprocess(example):
17 | if example["answers"]:
18 | example["answers"] = example["answers"].pop()
19 | else:
20 | example["answers"] = ""
21 | return example
22 |
23 | dataset = dataset.flatten().rename_column("answers.text", "answers").map(preprocess)
24 | return dataset
25 |
26 |
27 | def postprocess_squad_format(dataset: Dataset, add_answer_start: bool = True) -> Dataset:
28 | """Postprocesses a dataset in SQuAD format (flat answers) to a dataset in SQuAD format that has nested answers.
29 | {"text": "answer"} -> {"answer": {"text": "answer", "start": 0}}
30 |
31 | Args:
32 | dataset (Dataset): A huggingface dataset in SQuAD format.
33 | add_answer_start (bool, optional): Whether to add the answer start index to the dataset. Defaults to True.
34 |
35 | Returns:
36 | Dataset: A huggingface dataset in SQuAD format with nested answers.
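    Example (illustrative; assumes the answer occurs exactly once in the context):
        {"context": "The capital of Germany is Berlin.", "answers": "berlin"}
        -> {"context": "The capital of Germany is Berlin.",
            "answers": {"text": ["Berlin"], "answer_start": [26]}}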
37 | """
38 | # remove punctuation and whitespace from the start and end of the answer
39 | def remove_punctuation(example):
40 | example["answers"] = example["answers"].strip(".,;!? ")
41 | return example
42 |
43 | dataset = dataset.map(remove_punctuation)
44 |
45 | if add_answer_start:
46 | dataset = dataset.map(calculate_answer_start)
47 |
48 | def unify_answers(example):
49 | is_answerable = "answer_start" in example
50 | if is_answerable:
51 | example["answers"] = {"text": [example["answers"]], "answer_start": [example["answer_start"]]}
52 | else:
53 | example["answers"] = {"text": [], "answer_start": []}
54 | return example
55 |
56 | dataset = dataset.map(unify_answers)
57 | if "answer_start" in dataset.column_names:
58 | dataset = dataset.remove_columns("answer_start")
59 | return dataset
60 |
61 |
62 | def calculate_answer_start(example):
63 | """Calculates the answer start index for a SQuAD example.
64 |
65 | Args:
66 | example (Dict): A SQuAD example.
67 |
68 | Returns:
69 | Dict: The SQuAD example with the answer start index added.
70 | """
71 | answer_start = example["context"].lower().find(example["answers"].lower())
72 | if answer_start < 0:
73 | logger.info(
74 | 'Could not calculate the answer start because the context "{}" ' 'does not contain the answer "{}".',
75 | example["context"],
76 | example["answers"],
77 | )
78 | answer_start = -1
79 | else:
80 | # check that the answer doesn't occur more than once in the context
81 | second_answer_start = example["context"].lower().find(example["answers"].lower(), answer_start + 1)
82 | if second_answer_start >= 0:
83 | logger.info("Could not calculate the answer start because the context contains the answer more than once.")
84 | answer_start = -1
85 | else:
86 | # correct potential wrong capitalization of the answer compared to the context
87 | example["answers"] = example["context"][answer_start : answer_start + len(example["answers"])]
88 | example["answer_start"] = answer_start
89 | return example
90 |
--------------------------------------------------------------------------------
/src/fabricator/dataset_transformations/text_classification.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Union
2 |
3 | from datasets import Dataset, DatasetDict, ClassLabel, Sequence
4 |
5 |
6 | def get_labels_from_dataset(dataset: Union[Dataset, DatasetDict], label_column: str) -> List[str]:
7 | """Gets the list of labels from a huggingface Dataset.
8 |
9 | Args:
10 | dataset (Union[Dataset, DatasetDict]): huggingface Dataset
11 | label_column (str): name of the column with the labels
12 |
13 | Returns:
14 | List[str]: list of labels
15 | """
16 | if isinstance(dataset, DatasetDict):
17 | tmp_ref_dataset = dataset["train"]
18 | else:
19 | tmp_ref_dataset = dataset
20 |
21 | if isinstance(tmp_ref_dataset.features[label_column], ClassLabel):
22 | features = tmp_ref_dataset.features[label_column]
23 | elif isinstance(tmp_ref_dataset.features[label_column], Sequence):
24 | features = tmp_ref_dataset.features[label_column].feature
25 | else:
26 | raise ValueError(f"Label column {label_column} is not of type ClassLabel or Sequence.")
27 | return features.names
28 |
29 |
30 | def replace_class_labels(id2label: Union[Dict[str, str], Dict[int, str]], expanded_labels: Dict) -> Dict:
31 | """Replaces class labels with expanded labels, i.e. label LOC should be expanded to LOCATION.
32 | Values of id2label need to match keys of expanded_labels.
33 |
34 | Args:
35 | id2label (Dict): mapping from label ids to label names
36 | expanded_labels (Dict): mapping from label names or label ids to expanded label names
37 |
38 | Returns:
39 | Dict: mapping from label ids to label names with expanded labels
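    Example (illustrative):
        replace_class_labels({0: "O", 1: "LOC"}, {"LOC": "location"})
        -> {0: "O", 1: "location"}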
40 | """
41 | if all(isinstance(k, int) for k in expanded_labels.keys()):
42 | type_keys = "int"
43 | elif all(isinstance(k, str) for k in expanded_labels.keys()):
44 | type_keys = "str"
45 | else:
46 | raise ValueError("Keys of expanded_labels must be either all ints or all strings.")
47 |
48 | if not all(isinstance(v, str) for v in expanded_labels.values()):
49 | raise ValueError("Values of expanded_labels must be strings.")
50 |
51 | if type_keys == "str":
52 | replaced_id2label = {}
53 | for idx, tag in id2label.items():
54 | if tag in expanded_labels:
55 | replaced_id2label[idx] = expanded_labels[tag]
56 | else:
57 | replaced_id2label[idx] = tag
58 | else:
59 | replaced_id2label = expanded_labels
60 | return replaced_id2label
61 |
62 |
63 | def convert_label_ids_to_texts(
64 | dataset: Union[Dataset, DatasetDict],
65 | label_column: str,
66 | expanded_label_mapping: Dict = None,
67 | return_label_options: bool = False,
68 | ) -> Union[Dataset, DatasetDict, Tuple[Union[Dataset, DatasetDict], List[str]]]:
69 | """Converts label IDs to natural language labels for any classification problem with a single label such as text
70 | classification. Note that if the function is not applied to a Dataset, the label column will contain the IDs.
71 | If the function is applied, the label column will contain the natural language labels.
72 |
73 | Args:
74 | dataset (Dataset): huggingface Dataset with label ids.
75 | label_column (str): name of the label column.
76 | expanded_label_mapping (Dict, optional): dictionary mapping label ids to natural language labels.
77 | Defaults to None.
78 | return_label_options (bool, optional): whether to return the list of possible labels. Defaults to False.
79 |
80 | Returns:
81 | Union[Dataset, Tuple[Dataset, List[str]]]: huggingface Dataset with natural language labels and, if requested,
82 | the list of natural language label options.
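    Example (illustrative; assumes a ClassLabel feature with names ["negative", "positive"]):
        dataset, label_options = convert_label_ids_to_texts(dataset, "label", return_label_options=True)
        # the "label" column now holds "negative"/"positive" instead of 0/1,
        # and label_options == ["negative", "positive"]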
83 | """
84 | labels = get_labels_from_dataset(dataset, label_column)
85 | id2label = dict(enumerate(labels))
86 |
87 | if expanded_label_mapping is not None:
88 | id2label = replace_class_labels(id2label, expanded_label_mapping)
89 |
90 | new_label_column = f"{label_column}_natural_language"
91 | label_options = list(id2label.values())
92 |
93 | def labels_to_natural_language(examples):
94 | examples[new_label_column] = id2label[examples[label_column]]
95 | return examples
96 |
97 | dataset = (
98 | dataset.map(labels_to_natural_language)
99 | .remove_columns(label_column)
100 | .rename_column(new_label_column, label_column)
101 | )
102 |
103 | if return_label_options:
104 | return dataset, label_options
105 |
106 | return dataset
107 |
--------------------------------------------------------------------------------
/src/fabricator/dataset_transformations/token_classification.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Dict, List, Tuple, Union
3 | from datasets import Dataset, Sequence
4 |
5 | from loguru import logger
6 |
7 | # These are fixed for encoding the prompt and decoding the output of the LLM
8 | SPAN_ANNOTATION_TEMPLATE = "{entity} is {label} entity."
9 | SPAN_ANNOTATION_REGEX = r'(.+) is (.+) entity\.'
10 |
11 |
12 | def convert_token_labels_to_spans(
13 | dataset: Dataset,
14 | token_column: str,
15 | label_column: str,
16 | expanded_label_mapping: Dict = None,
17 | return_label_options: bool = False
18 | ) -> Union[Dataset, Tuple[Dataset, List[str]]]:
19 | """Converts token level labels to spans. Useful for NER tasks to prompt the LLM with natural language labels.
20 |
21 | Args:
22 | dataset (Dataset): huggingface Dataset with token level labels
23 | token_column (str): name of the column with the tokens
24 | label_column (str): name of the column with the token level labels
25 | expanded_label_mapping (Dict): mapping from label ids to label names. Defaults to None.
26 | return_label_options (bool): whether to return a list of all possible annotations of the provided dataset
27 |
28 | Returns:
29 | Tuple[Dataset, List[str]]: huggingface Dataset with span labels and list of possible labels for the prompt
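    Example (illustrative; the label feature names are assumptions):
        With a label feature whose names are ["O", "B-PER", "I-PER", "B-LOC"], the example
        tokens=["John", "Doe", "lives", "in", "Berlin"], ner_tags=[1, 2, 0, 0, 3]
        becomes token column "John Doe lives in Berlin" and
        label column "John Doe is PER entity.\nBerlin is LOC entity."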
30 | """
31 | if expanded_label_mapping:
32 | if not len(expanded_label_mapping) == len(dataset.features[label_column].feature.names):
33 | raise ValueError(
34 | f"Length of expanded label mapping and original number of labels in dataset do not match.\n"
35 | f"Original labels: {dataset.features[label_column].feature.names}"
36 | f"Expanded labels: {list(expanded_label_mapping.values())}"
37 | )
38 | id2label = expanded_label_mapping
39 | elif isinstance(dataset.features[label_column], Sequence):
40 | id2label = dict(enumerate(dataset.features[label_column].feature.names))
41 | else:
42 | raise ValueError("Labels must be a Sequence feature or expanded_label_mapping must be provided.")
43 |
44 | span_column = "span_annotations"
45 |
46 | def labels_to_spans(example):
47 | span_annotations = [id2label.get(label).replace("B-", "").replace("I-", "") for label in example[label_column]]
48 |
49 | annotations_for_prompt = ""
50 |
51 | current_entity = None
52 | current_entity_type = None
53 | for idx, span_annotation in enumerate(span_annotations):
54 | if span_annotation == "O":
55 | if current_entity is not None:
56 | annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
57 | label=current_entity_type) + "\n"
58 | current_entity = None
59 | current_entity_type = None
60 | continue
61 | if current_entity is None:
62 | current_entity = example[token_column][idx]
63 | current_entity_type = span_annotation
64 | continue
65 | if current_entity_type == span_annotation:
66 | current_entity += " " + example[token_column][idx]
67 | else:
68 | annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
69 | label=current_entity_type) + "\n"
70 | current_entity = example[token_column][idx]
71 | current_entity_type = span_annotation
72 |
73 | if current_entity is not None:
74 | annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
75 | label=current_entity_type) + "\n"
76 |
77 | example[token_column] = " ".join(example[token_column])
78 | example[span_column] = annotations_for_prompt.rstrip("\n")
79 | return example
80 |
81 | dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(span_column, label_column)
82 |
83 | if return_label_options:
84 |         # Sequences come in BIO format; spans make boundaries explicit, so we drop the B-/I- prefixes
85 | label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()})
86 |
87 | # Ignore "outside" tokens
88 | if "O" in label_options:
89 | label_options.remove("O")
90 |
91 | return dataset, label_options
92 |
93 | return dataset
94 |
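A minimal usage sketch for converting token-level labels into span annotations, assuming the conll2003 dataset from the Hugging Face Hub:

    from datasets import load_dataset
    from fabricator.dataset_transformations.token_classification import convert_token_labels_to_spans

    dataset = load_dataset("conll2003", split="train")
    dataset, label_options = convert_token_labels_to_spans(
        dataset, token_column="tokens", label_column="ner_tags", return_label_options=True
    )
    print(label_options)           # e.g. ["PER", "ORG", "LOC", "MISC"] (order may vary)
    print(dataset[0]["ner_tags"])  # "EU is ORG entity.\nGerman is MISC entity.\nBritish is MISC entity."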
95 |
96 | def convert_spans_to_token_labels(
97 | dataset: Dataset,
98 | token_column: str,
99 | label_column: str,
100 | id2label: Dict,
101 | annotate_identical_words: bool = False
102 | ) -> Dataset:
103 | """Converts span level labels to token level labels.
104 | First, the function extracts all entities with its annotated types.
105 | Second, if annotations are present, the function converts them to a tag sequence in BIO format.
106 | If not present, simply return tag sequence of O-tokens.
107 | This is useful for NER tasks to decode the output of the LLM.
108 |
109 | Args:
110 | dataset (Dataset): huggingface Dataset with span level labels
111 | token_column (str): name of the column with the tokens
112 | label_column (str): name of the column with the span level labels
113 | id2label (Dict): mapping from label ids to label names
114 | annotate_identical_words (bool): whether to annotate all identical words in a sentence with a found entity
115 | type
116 |
117 | Returns:
118 | Dataset: huggingface Dataset with token level labels in BIO format
119 | """
120 | new_label_column = "sequence_tags"
121 | lower_label2id = {label.lower(): idx for idx, label in id2label.items()}
122 |
123 | def labels_to_spans(example):
124 | span_annotations = example[label_column].split("\n")
125 |
126 | ner_tag_tuples = []
127 |
128 | for span_annotation in span_annotations:
129 | matches = re.match(SPAN_ANNOTATION_REGEX, span_annotation)
130 | if matches:
131 | matched_entity = matches.group(1)
132 | matched_label = matches.group(2)
133 |
134 | span_tokens = matched_entity.split(" ")
135 |                 span_labels = ["B-" + matched_label if idx == 0 else "I-" + matched_label
136 | for idx, _ in enumerate(span_tokens)]
137 |
138 | for token, label in zip(span_tokens, span_labels):
139 | label_id = lower_label2id.get(label.lower())
140 | if label_id is None:
141 | logger.info(f"Entity {token} with label {label} is not in id2label: {id2label}.")
142 | else:
143 | ner_tag_tuples.append((token, label_id))
144 | else:
145 | pass
146 |
147 | if ner_tag_tuples:
148 | lower_tokens = example[token_column].lower().split(" ")
149 | # initialize all tokens with O type
150 | ner_tags = [0] * len(lower_tokens)
151 | for reference_token, entity_type_id in ner_tag_tuples:
152 | if lower_tokens.count(reference_token.lower()) == 0:
153 | logger.info(
154 |                         f"Entity {reference_token} was not found in the tokens: {lower_tokens}. "
155 | f"Thus, setting label to O."
156 | )
157 | elif lower_tokens.count(reference_token.lower()) > 1:
158 | if annotate_identical_words:
159 | insert_at_idxs = [index for index, value in enumerate(lower_tokens)
160 | if value == reference_token.lower()]
161 | for insert_at_idx in insert_at_idxs:
162 | ner_tags[insert_at_idx] = entity_type_id
163 | else:
164 | logger.info(
165 | f"Entity {reference_token} occurs more than once: {lower_tokens}. "
166 | f"Thus, setting label to O."
167 | )
168 | else:
169 | insert_at_idx = lower_tokens.index(reference_token.lower())
170 | ner_tags[insert_at_idx] = entity_type_id
171 | else:
172 | ner_tags = [0] * len(example[token_column].split(" "))
173 |
174 | example[token_column] = example[token_column].split(" ")
175 | example[new_label_column] = ner_tags
176 |
177 | return example
178 |
179 | dataset = (
180 | dataset.map(labels_to_spans)
181 | .remove_columns(label_column)
182 | .rename_column(new_label_column, label_column)
183 | )
184 |
185 | return dataset
186 |
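A minimal round-trip sketch for decoding span annotations back into token-level label ids, assuming the conll2003 dataset and the span format produced by convert_token_labels_to_spans above:

    from datasets import load_dataset
    from fabricator.dataset_transformations.token_classification import (
        convert_spans_to_token_labels,
        convert_token_labels_to_spans,
    )

    dataset = load_dataset("conll2003", split="train").select(range(10))
    id2label = dict(enumerate(dataset.features["ner_tags"].feature.names))

    dataset = convert_token_labels_to_spans(dataset, token_column="tokens", label_column="ner_tags")
    dataset = convert_spans_to_token_labels(dataset, token_column="tokens", label_column="ner_tags", id2label=id2label)
    print(dataset[0]["ner_tags"])  # BIO label ids, e.g. [3, 0, 7, 0, 0, 0, 7, 0, 0]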
187 |
188 | def replace_token_labels(id2label: Dict, expanded_labels: Dict) -> Dict:
189 | """Replaces token level labels with expanded labels, i.e. label PER should be expanded to PERSON.
190 | Values of id2label need to match keys of expanded_labels.
191 |
192 | Args:
193 | id2label (Dict): mapping from label ids to label names
194 | expanded_labels (Dict): mapping from label names to expanded label names
195 |
196 | Returns:
197 | Dict: mapping from label ids to label names with expanded labels
198 | """
199 | replaced_id2label = {}
200 | for idx, tag in id2label.items():
201 | if tag.startswith("B-") or tag.startswith("I-"):
202 | prefix, label = tag.split("-", 1)
203 | if label in expanded_labels:
204 | new_label = expanded_labels[label]
205 | new_label_bio = f"{prefix}-{new_label}"
206 | replaced_id2label[idx] = new_label_bio
207 | else:
208 | replaced_id2label[idx] = tag
209 | else:
210 | replaced_id2label[idx] = tag
211 | return replaced_id2label
212 |
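A minimal sketch of expanding terse CoNLL-style tags into natural-language labels, using a hypothetical id2label mapping:

    id2label = {0: "O", 1: "B-PER", 2: "I-PER", 3: "B-ORG", 4: "I-ORG"}
    expanded = replace_token_labels(id2label, {"PER": "person", "ORG": "organization"})
    # expanded == {0: "O", 1: "B-person", 2: "I-person", 3: "B-organization", 4: "I-organization"}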
--------------------------------------------------------------------------------
/src/fabricator/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "BasePrompt",
3 | "infer_prompt_from_dataset",
4 | "infer_prompt_from_task_template"
5 | ]
6 |
7 | from .base import BasePrompt
8 | from .utils import infer_prompt_from_dataset, infer_prompt_from_task_template
9 |
--------------------------------------------------------------------------------
/src/fabricator/prompts/base.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Union, Optional
2 |
3 | from datasets import Dataset
4 | from loguru import logger
5 |
6 |
7 | class BasePrompt:
8 | """Base class for prompt generation. This class formats the prompt for the fewshot / support set examples
9 |     and the target variable such that the dataset generator can simply insert them into the invocation context."""
10 |
11 | DEFAULT_TEXT_COLUMN = ["text"]
12 | DEFAULT_LABEL_COLUMN = ["label"]
13 |
14 | def __init__(
15 | self,
16 | task_description: str,
17 | generate_data_for_column: Optional[str] = None,
18 | fewshot_example_columns: Optional[Union[List[str], str]] = None,
19 | label_options: Optional[List[str]] = None,
20 | fewshot_formatting_template: Optional[str] = None,
21 | target_formatting_template: Optional[str] = None,
22 | fewshot_example_separator: str = "\n\n",
23 | inner_fewshot_example_separator: str = "\n",
24 | ):
25 | """Base class for prompt generation. This class formats the prompt for the fewshot / support set examples.
26 |
27 | Args:
28 |             task_description (str): Task description for the prompt (prefix).
29 | generate_data_for_column (Optional[str], optional): The column name to generate data for. Defaults to None.
30 | fewshot_example_columns (Union[List[str], str]): List of strings or string of column names for the
31 | fewshot / support set examples. Defaults to None.
32 | label_options (Optional[ClassificationOptions], optional): Label options for the LLM to choose from.
33 | Defaults to None.
34 | fewshot_formatting_template (Optional[str], optional): Template for formatting the fewshot / support set
35 | examples. Defaults to None.
36 | target_formatting_template (Optional[str], optional): Template for formatting the target variable.
37 | Defaults to None.
38 | fewshot_example_separator (str, optional): Separator between the fewshot / support set examples.
39 | Defaults to "\n\n".
40 |             inner_fewshot_example_separator (str, optional): Separator between the fields of a single fewshot example.
41 | Defaults to "\n".
42 |
43 | Raises:
44 | AttributeError: If label_options is not a dict or list
45 | KeyError: If the task_description cannot be formatted with the variable 'label_options'
46 |             ValueError: You need to specify either generate_data_for_column or
47 |                 generate_data_for_column + fewshot_example_columns; only fewshot_example_columns is not supported.
48 | """
49 | self.task_description = task_description
50 |
51 | if label_options:
52 | self._assert_task_description_is_formattable(task_description)
53 | self.label_options = label_options
54 |
55 | if isinstance(generate_data_for_column, str) and generate_data_for_column:
56 | generate_data_for_column = [generate_data_for_column]
57 | self.generate_data_for_column = generate_data_for_column
58 |
59 | if isinstance(fewshot_example_columns, str) and fewshot_example_columns:
60 | fewshot_example_columns = [fewshot_example_columns]
61 | self.fewshot_example_columns = fewshot_example_columns
62 | self.fewshot_example_separator = fewshot_example_separator
63 | self.inner_fewshot_example_separator = inner_fewshot_example_separator
64 |
65 | if fewshot_example_columns:
66 | self.relevant_columns_for_fewshot_examples = self.fewshot_example_columns + self.generate_data_for_column
67 | elif generate_data_for_column:
68 | self.relevant_columns_for_fewshot_examples = self.generate_data_for_column
69 | else:
70 | self.relevant_columns_for_fewshot_examples = None
71 |
72 | # Create prompt template for fewshot examples
73 | if self.relevant_columns_for_fewshot_examples:
74 | if fewshot_formatting_template is None:
75 | self.fewshot_prompt = self.inner_fewshot_example_separator.join(
76 | [f"{var}: {{{var}}}" for var in self.relevant_columns_for_fewshot_examples]
77 | )
78 | else:
79 | self.fewshot_prompt = fewshot_formatting_template
80 |
81 | # Create format template for targets
82 | if target_formatting_template is None:
83 | self.target_formatting_template = self._infer_target_formatting_template()
84 | else:
85 | self.target_formatting_template = target_formatting_template
86 |
87 | logger.info(self._log_prompt())
88 |
89 | @staticmethod
90 | def _assert_task_description_is_formattable(task_description: str) -> None:
91 | """Checks if task_description is formattable.
92 |
93 | Args:
94 | task_description (str): Task description for the prompt (prefix).
95 | """
96 | if "testxyz" not in task_description.format("testxyz"):
97 | raise KeyError("If you provide label_options, you need the task_description to be formattable like"
98 | " 'Generate a {} text.'")
99 |
100 | def _infer_target_formatting_template(self) -> str:
101 | """Infer target formatting template from input columns and label column.
102 |
103 | Returns:
104 | str: Target formatting template
105 | """
106 | if self.generate_data_for_column and self.fewshot_example_columns:
107 | target_template = self.inner_fewshot_example_separator.join(
108 | [f"{var}: {{{var}}}" for var in self.fewshot_example_columns] +
109 | [f"{self.generate_data_for_column[0]}: "]
110 | )
111 |
112 | elif self.generate_data_for_column and not self.fewshot_example_columns:
113 | target_template = f"{self.generate_data_for_column[0]}: "
114 |
115 | elif not self.generate_data_for_column and not self.fewshot_example_columns:
116 | target_template = f"{self.DEFAULT_TEXT_COLUMN[0]}: "
117 |
118 | else:
119 | raise ValueError("Either generate_data_for_column or generate_data_for_column + fewshot_example_columns "
120 | "must be provided to infer target template.")
121 |
122 | return target_template
123 |
124 | def _log_prompt(self) -> str:
125 | """Log prompt.
126 |
127 | Returns:
128 | str: Prompt text
129 | """
130 | label = None
131 | fewshot_examples = None
132 |
133 | if self.label_options:
134 | label = "EXAMPLE LABEL"
135 |
136 | if self.relevant_columns_for_fewshot_examples:
137 | fewshot_examples = {}
138 | for column in self.relevant_columns_for_fewshot_examples:
139 | fewshot_examples[column] = [f"EXAMPLE TEXT FOR COLUMN {column}"]
140 | fewshot_examples = Dataset.from_dict(fewshot_examples)
141 |
142 | return "\nThe prompt to the LLM will be like:\n" + 10*"-" + "\n"\
143 | + self.get_prompt_text(label, fewshot_examples) + "\n" + 10*"-"
144 |
145 | @staticmethod
146 | def filter_example_by_columns(example: Dict[str, str], columns: List[str]) -> Dict[str, str]:
147 | """Filter single example by columns.
148 |
149 | Args:
150 | example (Dict[str, str]): Example to filter
151 | columns (List[str]): Columns to keep
152 |
153 | Returns:
154 | Dict[str, str]: Filtered example
155 | """
156 | filtered_example = {key: value for key, value in example.items() if key in columns}
157 | return filtered_example
158 |
159 | def filter_examples_by_columns(self, dataset: Dataset, columns: List[str]) -> List[Dict[str, str]]:
160 | """Filter examples by columns.
161 |
162 | Args:
163 | dataset (Dataset): Dataset to filter
164 | columns (List[str]): Columns to keep
165 |
166 | Returns:
167 | List[Dict[str, str]]: Filtered examples
168 | """
169 | filtered_inputs = []
170 | for example in dataset:
171 | filtered_inputs.append(self.filter_example_by_columns(example, columns))
172 | return filtered_inputs
173 |
174 | def get_prompt_text(self, labels: Union[str, List[str]] = None, examples: Optional[Dataset] = None) -> str:
175 | """Get prompt text for the given examples.
176 |
177 | Args:
178 | labels (Union[str, List[str]], optional): Label(s) to use for the prompt. Defaults to None.
179 | examples (Dataset): Examples to use for the prompt
180 |
181 | Returns:
182 | str: Prompt text
183 | """
184 | if isinstance(labels, list):
185 | labels = ", ".join(labels)
186 |
187 | if labels:
188 | task_description = self.task_description.format(labels)
189 | else:
190 | task_description = self.task_description
191 |
192 | if examples:
193 | examples = self.filter_examples_by_columns(examples, self.relevant_columns_for_fewshot_examples)
194 | formatted_examples = [self.fewshot_prompt.format(**example) for example in examples]
195 | prompt_text = self.fewshot_example_separator.join(
196 | [task_description] + formatted_examples + [self.target_formatting_template]
197 | )
198 | else:
199 | prompt_text = self.fewshot_example_separator.join(
200 | [task_description] + [self.target_formatting_template]
201 | )
202 | return prompt_text
203 |
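A minimal sketch of how BasePrompt assembles a prompt from fewshot examples, assuming a small in-memory dataset with "text" and "label" columns:

    from datasets import Dataset
    from fabricator.prompts import BasePrompt

    fewshot_examples = Dataset.from_dict({
        "text": ["This movie is great!", "This movie is bad!"],
        "label": ["positive", "negative"],
    })
    prompt = BasePrompt(
        task_description="Generate movie reviews.",
        generate_data_for_column="label",
        fewshot_example_columns="text",
    )
    print(prompt.get_prompt_text(None, fewshot_examples))
    # Generate movie reviews.
    #
    # text: This movie is great!
    # label: positive
    #
    # text: This movie is bad!
    # label: negative
    #
    # text: {text}
    # label: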
--------------------------------------------------------------------------------
/src/fabricator/prompts/utils.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset, QuestionAnsweringExtractive, TextClassification, TaskTemplate
2 | from .base import BasePrompt
3 |
4 | DEFAULT_TEXT_CLASSIFICATION = "Classify the following texts exactly into one of the following " \
5 |                               "categories: {}."
6 | DEFAULT_QA = "Given a context and a question, generate an answer that occurs exactly and only once in the text."
7 |
8 |
9 | def infer_prompt_from_task_template(task_template: TaskTemplate) -> BasePrompt:
10 |     """Infer a BasePrompt with correct parameters from a task template's metadata.
11 |
12 | Args:
13 | task_template (TaskTemplate): task template from which to infer the prompt.
14 |
15 | Returns:
16 | BasePrompt: with correct parameters.
17 | """
18 | if isinstance(task_template, QuestionAnsweringExtractive):
19 | return BasePrompt(
20 | task_description=DEFAULT_QA,
21 | generate_data_for_column=task_template.answers_column, # assuming the dataset was preprocessed
22 | # with preprocess_squad_format, otherwise dataset.task_templates[0]["answers_column"]
23 | fewshot_example_columns=[task_template.context_column, task_template.question_column],
24 | )
25 |
26 | if isinstance(task_template, TextClassification):
27 | return BasePrompt(
28 | task_description=DEFAULT_TEXT_CLASSIFICATION,
29 | generate_data_for_column=task_template.label_column,
30 | fewshot_example_columns=task_template.text_column,
31 | label_options=task_template.label_schema["labels"].names,
32 | )
33 |
34 | raise ValueError(
35 | f"Automatic prompt is only supported for QuestionAnsweringExtractive and "
36 | f"TextClassification tasks but not for {type(task_template)}. You need to "
37 | f"specify the prompt manually."
38 | )
39 |
40 |
41 | def infer_prompt_from_dataset(dataset: Dataset):
42 |     """Infer a BasePrompt with correct parameters from a dataset's metadata."""
43 | if not dataset.task_templates:
44 | raise ValueError(
45 | "Dataset must have exactly one task template but there is none. You need to specify the "
46 | "prompt manually."
47 | )
48 |
49 | if len(dataset.task_templates) > 1:
50 | raise ValueError(
51 | f"Automatic prompt is only supported for datasets with exactly one task template but yours "
52 | f"has {len(dataset.task_templates)}. You need to specify the prompt manually."
53 | )
54 |
55 | return infer_prompt_from_task_template(dataset.task_templates[0])
56 |
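A minimal sketch, assuming a dataset that ships exactly one TextClassification task template (for example imdb in datasets versions that still provide task templates):

    from datasets import load_dataset
    from fabricator.prompts import infer_prompt_from_dataset

    dataset = load_dataset("imdb", split="train")
    prompt = infer_prompt_from_dataset(dataset)
    # BasePrompt with the default classification task description and the dataset's label names as label_options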
--------------------------------------------------------------------------------
/src/fabricator/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "single_label_task_sampler",
3 | "single_label_stratified_sample",
4 | "random_sampler",
5 | "ml_mc_sampler"
6 | ]
7 |
8 | from .samplers import single_label_task_sampler, single_label_stratified_sample, \
9 | random_sampler, ml_mc_sampler
10 |
--------------------------------------------------------------------------------
/src/fabricator/samplers/samplers.py:
--------------------------------------------------------------------------------
1 | """Sampling methods
2 |
3 | NOTE: None of these methods guarantees that all labels are contained in the samples.
4 | """
5 | import random
6 | from typing import Dict, List, Set, Union, Tuple
7 | from collections import defaultdict, deque
8 | from itertools import cycle
9 |
10 | from datasets import ClassLabel, Dataset, Sequence, Value
11 | from loguru import logger
12 | from tqdm import tqdm
13 |
14 |
15 | def random_sampler(dataset: Dataset, num_examples: int) -> Dataset:
16 | """Random sampler"""
17 | return dataset.select(random.sample(range(len(dataset)), num_examples))
18 |
19 |
20 | def single_label_task_sampler(
21 | dataset: Dataset, label_column: str, num_examples: int, return_unused_split: bool = False
22 | ) -> Dataset:
23 | """Sampler for single label tasks, like text classification
24 |
25 | Args:
26 | dataset: Dataset
27 | label_column: Name of the label column
28 | num_examples: Number of examples to sample
29 | return_unused_split: Whether to return the unused split
30 |
31 | Approach:
32 | num_examples > len(dataset): Samples all examples
33 | num_examples < len(dataset): Samples at least one example per label
34 |         num_examples < number of classes: Samples only num_examples examples (not all classes will be covered)
35 | """
36 |
37 | if num_examples > len(dataset):
38 | return dataset
39 |
40 | if "train" in dataset:
41 | dataset = dataset["train"]
42 |
43 | pbar = tqdm(total=num_examples, desc="Sampling")
44 |
45 | class_labels = _infer_class_labels(dataset, label_column)
46 | num_classes = len(class_labels)
47 |
48 | unique_classes_sampled = set()
49 | total_examples_sampled = 0
50 |
51 | sampled_indices = []
52 |
53 | while total_examples_sampled < num_examples:
54 |         # Let's be as random as possible and sample from the entire dataset
55 | idx = random.sample(range(len(dataset)), 1)[0]
56 |
57 |         # Skip already sampled indices
58 | if idx in sampled_indices:
59 | continue
60 |
61 | sample = dataset.select([idx])[0]
62 | label = sample[label_column]
63 |
64 | # First sample at least one example per label
65 | if label not in unique_classes_sampled:
66 | unique_classes_sampled.add(label)
67 | sampled_indices.append(idx)
68 | total_examples_sampled += 1
69 | pbar.update(1)
70 |
71 | # Further sample if we collected at least one example per label
72 | elif len(sampled_indices) < num_examples and len(unique_classes_sampled) == num_classes:
73 | sampled_indices.append(idx)
74 | total_examples_sampled += 1
75 | pbar.update(1)
76 |
77 | if return_unused_split:
78 | unused_indices = list(set(range(len(dataset))) - set(sampled_indices))
79 | return dataset.select(sampled_indices), dataset.select(unused_indices)
80 |
81 | return dataset.select(sampled_indices)
82 |
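A minimal sketch, assuming the imdb dataset with its two sentiment classes:

    from datasets import load_dataset
    from fabricator.samplers import single_label_task_sampler

    dataset = load_dataset("imdb", split="train")
    sample = single_label_task_sampler(dataset, label_column="label", num_examples=10)
    # 10 examples containing at least one example per class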
83 |
84 | def _alternate_classes(dataset: Dataset, column: str) -> Dataset:
85 | """Alternate the occurrence of each class in the dataset.
86 |
87 | Args:
88 | dataset: Dataset
89 | column: Name of the column to alternate the classes of
90 |
91 | Returns:
92 | Dataset with the classes alternated
93 | """
94 |
95 | # Group the indices of each unique value in 'target' column
96 | targets = defaultdict(deque)
97 | for i, elem in enumerate(dataset[column]):
98 | targets[elem].append(i)
99 |
100 | # Create a cycle iterator from targets
101 | targets_cycle = cycle(targets.keys())
102 |
103 | # Alternate the occurrence of each class
104 | alternate_indices = []
105 | for target in targets_cycle:
106 | if not targets[target]: # If this class has no more indices, remove it from the cycle
107 | del targets[target]
108 | else: # Otherwise, add the next index of this class to the list
109 | alternate_indices.append(targets[target].popleft())
110 |
111 | # If there are no more indices, break the loop
112 | if not targets:
113 | break
114 |
115 | # Create new dataset from the alternate indices
116 | alternate_dataset = dataset.select(alternate_indices)
117 |
118 | return alternate_dataset
119 |
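A small sketch of the alternation on a toy dataset (illustrative only, since the helper is private):

    from datasets import Dataset

    toy = Dataset.from_dict({"label": ["a", "a", "b", "b", "c"]})
    print(_alternate_classes(toy, "label")["label"])  # ["a", "b", "c", "a", "b"]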
120 |
121 | def single_label_stratified_sample(
122 | dataset: Dataset,
123 | label_column: str,
124 | num_examples_per_class: int,
125 | return_unused_split: bool = False
126 | ) -> Union[Dataset, Tuple[Dataset, Dataset]]:
127 | """Stratified sampling for single label tasks, like text classification.
128 |
129 | Args:
130 | dataset: Dataset
131 | label_column: Name of the label column
132 | num_examples_per_class: Number of examples to sample per class
133 | return_unused_split: If True, return the unused split of the dataset
134 |
135 | Returns:
136 | Dataset: Stratified sample of the dataset
137 | """
138 | # Ensure the 'k' value is valid
139 | if num_examples_per_class <= 0:
140 | raise ValueError("'num_examples_per_class' should be a positive integer.")
141 |
142 | # Group the indices of each unique value in 'target' column
143 | targets = defaultdict(list)
144 | for i, elem in enumerate(dataset[label_column]):
145 | targets[elem].append(i)
146 |
147 | # Check if k is smaller or equal than the size of the smallest group
148 | if num_examples_per_class > min(len(indices) for indices in targets.values()):
149 | raise ValueError(
150 | "'num_examples_per_class' is greater than the size of the smallest group in the target column."
151 | )
152 |
153 | # Stratified sampling
154 | sample_indices = []
155 | for indices in targets.values():
156 | sample_indices.extend(random.sample(indices, num_examples_per_class))
157 |
158 | # Create new dataset from the sample
159 | sample_dataset = dataset.select(sample_indices)
160 | sample_dataset = _alternate_classes(sample_dataset, label_column)
161 |
162 | if return_unused_split:
163 | unused_indices = list(set(range(len(dataset))) - set(sample_indices))
164 | return sample_dataset, dataset.select(unused_indices)
165 |
166 | return sample_dataset
167 |
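A minimal sketch, assuming the trec dataset with six coarse labels:

    from datasets import load_dataset
    from fabricator.samplers import single_label_stratified_sample

    dataset = load_dataset("trec", split="train")
    sample = single_label_stratified_sample(dataset, label_column="coarse_label", num_examples_per_class=2)
    # 12 examples, two per class, with the classes alternating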
168 |
169 | def ml_mc_sampler(dataset: Dataset, labels_column: str, num_examples: int) -> Dataset:
170 | """Multi label multi class sampler
171 |
172 | Args:
173 | dataset: Dataset
174 |         labels_column: Name of the column holding the labels
175 |         num_examples: Number of examples to sample; if -1, sample until the subset contains every label
176 |
177 | """
178 |
179 | if num_examples > len(dataset):
180 | return dataset
181 |
182 | if "train" in dataset:
183 | dataset = dataset["train"]
184 |
185 | total_labels = _infer_class_labels(dataset, labels_column)
186 | num_classes = len(total_labels)
187 |
188 |     # Because of random sampling we cannot guarantee that we ever sample all examples,
189 |     # nor do we know if all labels are present. We therefore use a max-tries counter
190 |     # so we don't get stuck in an infinite while loop.
191 | if num_examples == -1:
192 | max_tries = 2 * len(dataset)
193 | else:
194 | max_tries = -1
195 |
196 | tries = 0
197 |
198 | pbar = tqdm(total=num_examples, desc="Sampling")
199 |
200 | unique_classes_sampled: Set[str] = set()
201 | total_examples_sampled = 0
202 |
203 | sampled_indices = []
204 |
205 | while len(unique_classes_sampled) < len(total_labels):
206 |         # Let's be as random as possible and sample from the entire dataset
207 | idx = random.sample(range(len(dataset)), 1)[0]
208 |
209 |         # Skip already sampled indices
210 | if idx in sampled_indices:
211 | continue
212 |
213 | sample = dataset.select([idx])[0]
214 |
215 | labels = sample[labels_column]
216 |
217 | if not isinstance(labels, list):
218 | labels = [labels]
219 |
220 | labels_found = [label for label in labels if label not in unique_classes_sampled]
221 |
222 | # Check if current sample contains labels not found yet
223 | complements = _relative_complements(labels_found, unique_classes_sampled)
224 |
225 | if len(complements) > 0:
226 | unique_classes_sampled.update(complements)
227 | sampled_indices.append(idx)
228 | total_examples_sampled += 1
229 | pbar.update(1)
230 |
231 | # Further sample if we collected at least one example per label
232 | elif len(sampled_indices) < num_examples and len(unique_classes_sampled) == num_classes:
233 | sampled_indices.append(idx)
234 | total_examples_sampled += 1
235 | pbar.update(1)
236 |
237 | if num_examples != -1 and total_examples_sampled == num_examples:
238 | break
239 |
240 | tries += 1
241 | if tries == max_tries:
242 |             logger.info("Stopping sampling: max tries (={}) exceeded.", max_tries)
243 | break
244 |
245 | return dataset.select(sampled_indices)
246 |
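A minimal sketch, assuming the conll2003 dataset, whose pos_tags column holds one label per token:

    from datasets import load_dataset
    from fabricator.samplers import ml_mc_sampler

    dataset = load_dataset("conll2003", split="train")
    sample = ml_mc_sampler(dataset, labels_column="pos_tags", num_examples=10)
    # 10 examples; coverage of all tag types is not guaranteed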
247 |
248 | def _infer_class_labels(dataset: Dataset, label_column: str) -> Dict[int, str]:
249 | """Infer the total set of labels"""
250 | features = dataset.features
251 |
252 | if label_column not in features:
253 | raise ValueError(f"Label column {label_column} not found in dataset")
254 |
255 | if isinstance(features[label_column], Value):
256 | logger.info("Label column {} is of type Value. Inferring labels from dataset", label_column)
257 | class_labels = dataset.class_encode_column(label_column).features[label_column].names
258 | elif isinstance(features[label_column], ClassLabel):
259 | class_labels = features[label_column].names
260 | elif isinstance(features[label_column], Sequence):
261 | class_labels = features[label_column].feature.names
262 | else:
263 | raise ValueError(
264 | f"Label column {label_column} is of type {type(features[label_column])}. Expected Value, "
265 | f"ClassLabel or Sequence"
266 | )
267 |
268 | return dict(enumerate(class_labels))
269 |
270 |
271 | def _relative_complements(list1: List, list2: Union[List, Set]) -> Set:
272 | """a \\ b"""
273 | return set(list1) - set(list2)
274 |
--------------------------------------------------------------------------------
/src/fabricator/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 |
4 |
5 | def log_dir():
6 | """Returns the log directory.
7 |
8 | Note:
9 | Keep it simple for now
10 | """
11 | return os.environ.get("LOG_DIR", "./logs")
12 |
13 |
14 | def create_timestamp_path(directory: str):
15 | """Returns a timestamped path for logging."""
16 | return os.path.join(directory, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
17 |
18 |
19 | def save_create_directory(path: str):
20 | """Creates a directory if it does not exist."""
21 | if not os.path.exists(path):
22 | os.makedirs(path)
23 |
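A minimal sketch combining the three helpers into a timestamped run directory (assumes write access to the log directory):

    run_directory = create_timestamp_path(log_dir())  # e.g. "./logs/20240101-120000"
    save_create_directory(run_directory)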
--------------------------------------------------------------------------------
/tests/test_dataset_generator.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from datasets import Dataset, load_dataset
4 |
5 | from fabricator import DatasetGenerator
6 | from fabricator.prompts import BasePrompt
7 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts
8 |
9 |
10 | class TestDatasetGenerator(unittest.TestCase):
11 | """Testcase for Prompts"""
12 |
13 | def setUp(self) -> None:
14 | """Set up test dataset"""
15 | self.text_classification_dataset = Dataset.from_dict({
16 | "text": ["This movie is great!", "This movie is bad!"],
17 | "label": ["positive", "negative"]
18 | })
19 |
20 |         # We are using dummy responses here because we are not testing the LLM itself.
21 | self.generator = DatasetGenerator(None)
22 |
23 | def test_simple_generation(self):
24 | """Test simple generation without fewshot examples."""
25 | prompt = BasePrompt(
26 | task_description="Generate a short movie review.",
27 | )
28 |
29 | generated_dataset = self.generator.generate(
30 | prompt_template=prompt,
31 | max_prompt_calls=2,
32 | dummy_response="A dummy movie review."
33 | )
34 |
35 | self.assertEqual(len(generated_dataset), 2)
36 | self.assertEqual(generated_dataset.features["text"].dtype, "string")
37 | self.assertIn("text", generated_dataset.features)
38 |
39 | def test_simple_generation_with_label_options(self):
40 | """Test simple generation without fewshot examples with label options."""
41 | prompt = BasePrompt(
42 | task_description="Generate a short {} movie review.",
43 | label_options=["positive", "negative"],
44 | )
45 |
46 | generated_dataset = self.generator.generate(
47 | prompt_template=prompt,
48 | max_prompt_calls=2,
49 | dummy_response="A dummy movie review."
50 | )
51 |
52 | self.assertEqual(len(generated_dataset), 2)
53 | self.assertEqual(generated_dataset.features["text"].dtype, "string")
54 | self.assertIn("text", generated_dataset.features)
55 |
56 | def test_generation_with_fewshot_examples(self):
57 | label_options = ["positive", "negative"]
58 |
59 | prompt = BasePrompt(
60 | task_description="Generate a {} movie review.",
61 | label_options=label_options,
62 | generate_data_for_column="text",
63 | )
64 |
65 | generated_dataset = self.generator.generate(
66 | prompt_template=prompt,
67 | fewshot_dataset=self.text_classification_dataset,
68 | fewshot_examples_per_class=1,
69 | fewshot_sampling_strategy="uniform",
70 | fewshot_sampling_column="label",
71 | max_prompt_calls=2,
72 | dummy_response="A dummy movie review."
73 | )
74 |
75 | self.assertEqual(len(generated_dataset), 2)
76 | self.assertEqual(generated_dataset.features["text"].dtype, "string")
77 | self.assertIn("text", generated_dataset.features)
78 |
79 | def test_annotation_with_fewshot_and_unlabeled_examples(self):
80 | label_options = ["positive", "negative"]
81 |
82 | unlabeled_dataset = Dataset.from_dict({
83 | "text": ["This movie was a blast!", "This movie was not bad!"],
84 | })
85 |
86 | prompt = BasePrompt(
87 | task_description="Annotate movie reviews as either: {}.",
88 | label_options=label_options,
89 | generate_data_for_column="label",
90 | fewshot_example_columns="text",
91 | )
92 |
93 | generated_dataset = self.generator.generate(
94 | prompt_template=prompt,
95 | fewshot_dataset=self.text_classification_dataset,
96 | fewshot_examples_per_class=1,
97 | fewshot_sampling_strategy="stratified",
98 | unlabeled_dataset=unlabeled_dataset,
99 | max_prompt_calls=2,
100 | dummy_response="A dummy movie review."
101 | )
102 |
103 | self.assertEqual(len(generated_dataset), 2)
104 | self.assertEqual(generated_dataset.features["text"].dtype, "string")
105 | self.assertEqual(generated_dataset.features["label"].dtype, "string")
106 | self.assertIn("text", generated_dataset.features)
107 | self.assertIn("label", generated_dataset.features)
108 | self.assertEqual(generated_dataset[0]["text"], "This movie was a blast!")
109 | self.assertEqual(generated_dataset[1]["text"], "This movie was not bad!")
110 |
111 |     def test_translation(self):
112 | fewshot_dataset = Dataset.from_dict({
113 | "german": ["Der Film ist großartig!", "Der Film ist schlecht!"],
114 | "english": ["This movie is great!", "This movie is bad!"],
115 | })
116 |
117 | unlabeled_dataset = Dataset.from_dict({
118 | "english": ["This movie was a blast!", "This movie was not bad!"],
119 | })
120 |
121 | prompt = BasePrompt(
122 | task_description="Translate to german:", # Since we do not have a label column,
123 | # we can just specify the task description
124 | generate_data_for_column="german",
125 | fewshot_example_columns="english",
126 | )
127 |
128 | generated_dataset = self.generator.generate(
129 | prompt_template=prompt,
130 | fewshot_dataset=fewshot_dataset,
131 | fewshot_examples_per_class=2, # Take both fewshot examples per prompt
132 | fewshot_sampling_strategy=None,
133 | # Since we do not have a class label column, we can just set this to None
134 | # (default)
135 | unlabeled_dataset=unlabeled_dataset,
136 | max_prompt_calls=2,
137 | dummy_response="A dummy movie review."
138 | )
139 |
140 | self.assertEqual(len(generated_dataset), 2)
141 | self.assertEqual(generated_dataset.features["english"].dtype, "string")
142 | self.assertEqual(generated_dataset.features["german"].dtype, "string")
143 | self.assertIn("english", generated_dataset.features)
144 | self.assertIn("german", generated_dataset.features)
145 |
146 | def test_textual_similarity(self):
147 | dataset = load_dataset("glue", "mrpc", split="train")
148 | dataset, label_options = convert_label_ids_to_texts(dataset, "label", return_label_options=True) # convert the
149 | # label ids to text labels and return the label options
150 |
151 | fewshot_dataset = dataset.select(range(10))
152 | unlabeled_dataset = dataset.select(range(10, 20))
153 |
154 | prompt = BasePrompt(
155 | task_description="Annotate the sentence pair whether it is: {}",
156 | label_options=label_options,
157 | generate_data_for_column="label",
158 | fewshot_example_columns=["sentence1", "sentence2"],
159 | )
160 |
161 | generated_dataset, original_dataset = self.generator.generate(
162 | prompt_template=prompt,
163 | fewshot_dataset=fewshot_dataset,
164 | fewshot_examples_per_class=1,
165 | fewshot_sampling_column="label",
166 | fewshot_sampling_strategy="stratified",
167 | unlabeled_dataset=unlabeled_dataset,
168 | max_prompt_calls=2,
169 | return_unlabeled_dataset=True,
170 | dummy_response="A dummy movie review."
171 | )
172 |
173 | self.assertEqual(len(generated_dataset), 2)
174 | self.assertEqual(generated_dataset.features["sentence1"].dtype, "string")
175 | self.assertEqual(generated_dataset.features["sentence2"].dtype, "string")
176 | self.assertEqual(generated_dataset.features["label"].dtype, "string")
177 |
178 | def test_sampling_all_fewshot_examples(self):
179 | """Test sampling all fewshot examples"""
180 | prompt = BasePrompt(
181 | task_description="Generate a short movie review.",
182 | )
183 |
184 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
185 | prompt_template=prompt,
186 | fewshot_dataset=self.text_classification_dataset,
187 | fewshot_sampling_strategy=None,
188 | fewshot_examples_per_class=None,
189 | fewshot_sampling_column=None,
190 | )
191 |
192 | self.assertEqual(prompt_labels, None)
193 | self.assertEqual(len(fewshot_examples), 2)
194 |
195 | def test_sampling_all_fewshot_examples_with_label_options(self):
196 | """Test sampling all fewshot examples with label options"""
197 | prompt = BasePrompt(
198 | task_description="Generate a short movie review: {}.",
199 | label_options=["positive", "negative"],
200 | )
201 |
202 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
203 | prompt_template=prompt,
204 | fewshot_dataset=self.text_classification_dataset,
205 | fewshot_sampling_strategy=None,
206 | fewshot_examples_per_class=None,
207 | fewshot_sampling_column=None,
208 | )
209 |
210 | prompt_text = prompt.task_description.format(prompt_labels)
211 | self.assertIn("positive", prompt_text)
212 | self.assertIn("negative", prompt_text)
213 |
214 | def test_sampling_uniform_fewshot_examples(self):
215 | """Test uniform sampling fewshot examples"""
216 | prompt = BasePrompt(
217 | task_description="Generate a short movie review: {}.",
218 | label_options=["positive", "negative"],
219 | )
220 |
221 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
222 | prompt_template=prompt,
223 | fewshot_dataset=self.text_classification_dataset,
224 | fewshot_sampling_strategy="uniform",
225 | fewshot_examples_per_class=1,
226 | fewshot_sampling_column="label",
227 | )
228 |
229 | self.assertEqual(len(fewshot_examples), 1)
230 | self.assertIn(prompt_labels, ["positive", "negative"])
231 |
232 | def test_sampling_stratified_fewshot_examples(self):
233 | """Test stratified sampling fewshot examples"""
234 | larger_fewshot_dataset = Dataset.from_dict({
235 | "text": ["This movie is great!", "This movie is bad!", "This movie is great!", "This movie is bad!"],
236 | "label": ["positive", "negative", "positive", "negative"]
237 | })
238 |
239 | prompt = BasePrompt(
240 | task_description="Generate a short movie review: {}.",
241 | label_options=["positive", "negative"],
242 | )
243 |
244 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
245 | prompt_template=prompt,
246 | fewshot_dataset=larger_fewshot_dataset,
247 | fewshot_sampling_strategy="stratified",
248 | fewshot_examples_per_class=1,
249 | fewshot_sampling_column="label",
250 | )
251 |
252 | self.assertEqual(len(fewshot_examples), 2)
253 | self.assertEqual(len(set(fewshot_examples["label"])), 2)
254 | self.assertEqual(len(prompt_labels), 2)
255 |
256 | def test_sampling_uniform_fewshot_examples_without_number_of_examples(self):
257 | """Test failure of uniform sampling fewshot examples if attributes are missing"""
258 | prompt = BasePrompt(
259 | task_description="Generate a short movie review: {}.",
260 | label_options=["positive", "negative"],
261 | )
262 |
263 | with self.assertRaises(KeyError):
264 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
265 | prompt_template=prompt,
266 | fewshot_dataset=self.text_classification_dataset,
267 | fewshot_sampling_strategy="uniform",
268 | fewshot_examples_per_class=None,
269 | fewshot_sampling_column=None,
270 | )
271 |
272 | def test_sampling_uniform_fewshot_examples_with_number_of_examples_without_sampling_column(self):
273 | """Test failure of uniform sampling fewshot examples if attributes are missing"""
274 | prompt = BasePrompt(
275 | task_description="Generate a short movie review: {}.",
276 | label_options=["positive", "negative"],
277 | )
278 |
279 | with self.assertRaises(KeyError):
280 | prompt_labels, fewshot_examples = self.generator._sample_fewshot_examples(
281 | prompt_template=prompt,
282 | fewshot_dataset=self.text_classification_dataset,
283 | fewshot_sampling_strategy="uniform",
284 | fewshot_examples_per_class=1,
285 | fewshot_sampling_column=None,
286 | )
287 |
288 | def test_dummy_response(self):
289 |
290 | prompt = BasePrompt(
291 | task_description="Generate a short movie review.",
292 | )
293 | generated_dataset = self.generator.generate(
294 | prompt_template=prompt,
295 | max_prompt_calls=2,
296 | dummy_response=lambda _: "This is a dummy movie review."
297 | )
298 |
299 | self.assertEqual(len(generated_dataset), 2)
300 | self.assertEqual(generated_dataset.features["text"].dtype, "string")
301 | self.assertIn("text", generated_dataset.features)
302 | self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review.")
303 | self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review.")
304 |
305 |
306 | generated_dataset = self.generator.generate(
307 | prompt_template=prompt,
308 | max_prompt_calls=2,
309 | dummy_response="This is a dummy movie review as a string."
310 | )
311 |
312 | self.assertEqual(len(generated_dataset), 2)
313 | self.assertEqual(generated_dataset.features["text"].dtype, "string")
314 | self.assertIn("text", generated_dataset.features)
315 | self.assertEqual(generated_dataset[0]["text"], "This is a dummy movie review as a string.")
316 | self.assertEqual(generated_dataset[1]["text"], "This is a dummy movie review as a string.")
317 |
--------------------------------------------------------------------------------
/tests/test_dataset_sampler.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from collections import Counter
4 | from datasets import load_dataset
5 |
6 | from fabricator.samplers import random_sampler, single_label_task_sampler, ml_mc_sampler, \
7 | single_label_stratified_sample
8 |
9 |
10 | def _flatten(l):
11 | return [item for sublist in l for item in sublist]
12 |
13 |
14 | class TestDatasetSamplerMethodsSingleLabel(unittest.TestCase):
15 | """Testcase for dataset sampler methods"""
16 |
17 | def setUp(self) -> None:
18 | self.dataset = load_dataset("imdb", split="train")
19 |
20 | def test_random_sampler(self):
21 | """Test random sampler"""
22 | random_sample = random_sampler(self.dataset, num_examples=10)
23 | self.assertEqual(len(random_sample), 10)
24 |
25 | def test_single_label_task_sampler(self):
26 | """Test single label task sampler. We use imdb which has two labels: positive and negative"""
27 | single_label_sample = single_label_task_sampler(self.dataset, label_column="label", num_examples=2)
28 | self.assertEqual(len(single_label_sample), 2)
29 | labels = list(single_label_sample["label"])
30 | self.assertEqual(len(set(labels)), 2)
31 |
32 | def test_single_label_task_sampler_more_examples_than_ds(self):
33 | """Test single label task sampler with more examples than dataset"""
34 | subset_dataset = self.dataset.select(range(100))
35 | single_label_sample = single_label_task_sampler(subset_dataset, label_column="label", num_examples=110)
36 | self.assertEqual(len(single_label_sample), 100)
37 |
38 |
39 | class TestDatasetSamplerMethodsMultiLabel(unittest.TestCase):
40 | """Testcase for multilabel dataset sampler methods"""
41 |
42 | def setUp(self) -> None:
43 | """Load dataset"""
44 | self.dataset = load_dataset("conll2003", split="train")
45 |
46 | def test_ml_mc_sampler(self):
47 | """Test multilabel multiclass sampler"""
48 | subset_dataset = ml_mc_sampler(self.dataset, labels_column="pos_tags", num_examples=10)
49 | label_idxs = list(range(len(self.dataset.features["pos_tags"].feature.names)))
50 | self.assertEqual(len(subset_dataset), 10)
51 |
52 | tags = set(_flatten([sample["pos_tags"] for sample in subset_dataset]))
53 |
54 | # We do not guarantee, that all tags are contained in sampled examples
55 | self.assertLessEqual(len(tags), len(label_idxs))
56 |
57 |
58 | class TestStratifiedSampler(unittest.TestCase):
59 |
60 | def setUp(self) -> None:
61 | """Load dataset"""
62 | self.dataset = load_dataset("trec", split="train")
63 |
64 | def test_stratified_sampler(self):
65 | """Test stratified sampler"""
66 | subset_dataset = single_label_stratified_sample(self.dataset, label_column="coarse_label",
67 | num_examples_per_class=2)
68 | label_idxs = list(range(len(self.dataset.features["coarse_label"].names)))
69 | self.assertEqual(len(subset_dataset), 2 * len(label_idxs))
70 |
71 |         for occurrences in Counter(subset_dataset["coarse_label"]).values():
72 |             self.assertEqual(occurrences, 2)
73 |
--------------------------------------------------------------------------------
/tests/test_dataset_transformations.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from datasets import load_dataset
4 |
5 | from fabricator.prompts import BasePrompt
6 | from fabricator.dataset_transformations.question_answering import *
7 | from fabricator.dataset_transformations.text_classification import *
8 | from fabricator.dataset_transformations.token_classification import *
9 |
10 |
11 | class TestTransformationsTextClassification(unittest.TestCase):
12 | """Testcase for ClassLabelPrompt"""
13 |
14 | def setUp(self) -> None:
15 | self.dataset = load_dataset("trec", split="train")
16 |
17 | def test_label_ids_to_textual_label(self):
18 | """Test transformation output only"""
19 | dataset, label_options = convert_label_ids_to_texts(self.dataset, "coarse_label", return_label_options=True)
20 | self.assertEqual(len(label_options), 6)
21 | self.assertEqual(set(label_options), set(self.dataset.features["coarse_label"].names))
22 | self.assertEqual(type(dataset[0]["coarse_label"]), str)
23 | self.assertNotEqual(type(dataset[0]["coarse_label"]), int)
24 | self.assertIn(dataset[0]["coarse_label"], label_options)
25 |
26 | def test_formatting_with_textual_labels(self):
27 | """Test formatting with textual labels"""
28 | dataset, label_options = convert_label_ids_to_texts(self.dataset, "coarse_label", return_label_options=True)
29 | fewshot_examples = dataset.select([1, 2, 3])
30 | prompt = BasePrompt(
31 | task_description="Annotate the question into following categories: {}.",
32 | generate_data_for_column="coarse_label",
33 | fewshot_example_columns="text",
34 | label_options=label_options,
35 | )
36 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
37 | self.assertIn("text: What films featured the character Popeye Doyle ?\ncoarse_label: ENTY", raw_prompt)
38 | for label in label_options:
39 | self.assertIn(label, raw_prompt)
40 |
41 | def test_expanded_textual_labels(self):
42 | """Test formatting with expanded textual labels"""
43 | extended_mapping = {
44 | "DESC": "Description",
45 | "ENTY": "Entity",
46 | "ABBR": "Abbreviation",
47 | "HUM": "Human",
48 | "NUM": "Number",
49 | "LOC": "Location",
50 | }
51 | dataset, label_options = convert_label_ids_to_texts(
52 | self.dataset, "coarse_label", expanded_label_mapping=extended_mapping, return_label_options=True
53 | )
54 | self.assertIn("Location", label_options)
55 | self.assertNotIn("LOC", label_options)
56 | fewshot_examples = dataset.select([1, 2, 3])
57 | prompt = BasePrompt(
58 | task_description="Annotate the question into following categories: {}.",
59 | generate_data_for_column="coarse_label",
60 | fewshot_example_columns="text",
61 | label_options=label_options,
62 | )
63 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
64 | self.assertIn("text: What films featured the character Popeye Doyle ?\ncoarse_label: Entity", raw_prompt)
65 | self.assertNotIn("ENTY", raw_prompt)
66 | for label in label_options:
67 | self.assertIn(label, raw_prompt)
68 |
69 | def test_textual_labels_to_label_ids(self):
70 | """Test conversion back to label ids"""
71 | dataset, label_options = convert_label_ids_to_texts(self.dataset, "coarse_label", return_label_options=True)
72 | self.assertIn(dataset[0]["coarse_label"], label_options)
73 | dataset = dataset.class_encode_column("coarse_label")
74 | self.assertIn(dataset[0]["coarse_label"], range(len(label_options)))
75 |
76 | def test_false_inputs_raises_error(self):
77 | """Test that false inputs raise errors"""
78 | with self.assertRaises(AttributeError):
79 | dataset, label_options = convert_label_ids_to_texts("coarse_label", "dataset")
80 |
81 |
82 | class TestTransformationsTokenClassification(unittest.TestCase):
83 | """Testcase for TokenLabelTransformations"""
84 |
85 | def setUp(self) -> None:
86 | self.dataset = load_dataset("conll2003", split="train").select(range(150))
87 |
88 | def test_bio_tokens_to_spans(self):
89 | """Test transformation output only (BIO to spans)"""
90 | dataset, label_options = convert_token_labels_to_spans(
91 | self.dataset, "tokens", "ner_tags", return_label_options=True
92 | )
93 | self.assertEqual(len(label_options), 4)
94 | self.assertEqual(type(dataset[0]["ner_tags"]), str)
95 | self.assertNotEqual(type(dataset[0]["ner_tags"]), int)
96 | spans = [span for span in dataset[0]["ner_tags"].split("\n")]
97 | for span in spans:
98 | self.assertTrue(any([label in span for label in label_options]))
99 |
100 | def test_formatting_with_span_labels(self):
101 | """Test formatting with span labels"""
102 | dataset, label_options = convert_token_labels_to_spans(
103 | dataset=self.dataset,
104 | token_column="tokens",
105 | label_column="ner_tags",
106 | return_label_options=True
107 | )
108 | fewshot_examples = dataset.select([1, 2, 3])
109 | prompt = BasePrompt(
110 | task_description="Annotate each of the following tokens with the following labels: {}.",
111 | generate_data_for_column="ner_tags",
112 | fewshot_example_columns="tokens",
113 | label_options=label_options,
114 | )
115 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
116 | self.assertIn("Peter Blackburn is PER entity.", raw_prompt)
117 | self.assertIn("BRUSSELS is LOC entity.", raw_prompt)
118 | for label in label_options:
119 | self.assertIn(label, raw_prompt)
120 |
121 | def test_expanded_textual_labels(self):
122 | """Test formatting with expanded textual labels"""
123 | expanded_label_mapping = {
124 | 0: "O",
125 | 1: "B-person",
126 | 2: "I-person",
127 | 3: "B-location",
128 | 4: "I-location",
129 | 5: "B-organization",
130 | 6: "I-organization",
131 | 7: "B-miscellaneous",
132 | 8: "I-miscellaneous",
133 | }
134 |
135 | dataset, label_options = convert_token_labels_to_spans(
136 | dataset=self.dataset,
137 | token_column="tokens",
138 | label_column="ner_tags",
139 | expanded_label_mapping=expanded_label_mapping,
140 | return_label_options=True
141 | )
142 | fewshot_examples = dataset.select([1, 2, 3])
143 | prompt = BasePrompt(
144 | task_description="Annotate each of the following tokens with the following labels: {}.",
145 | generate_data_for_column="ner_tags",
146 | fewshot_example_columns="tokens",
147 | label_options=label_options,
148 | )
149 | raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
150 | self.assertIn("Peter Blackburn is person entity.", raw_prompt)
151 | self.assertNotIn("PER", raw_prompt)
152 | for label in label_options:
153 | self.assertIn(label, raw_prompt)
154 |
155 | def test_textual_labels_to_label_ids(self):
156 | """Test conversion back to label ids on token-level"""
157 | dataset, label_options = convert_token_labels_to_spans(
158 | dataset=self.dataset,
159 | token_column="tokens",
160 | label_column="ner_tags",
161 | return_label_options=True
162 | )
163 | id2label = dict(enumerate(self.dataset.features["ner_tags"].feature.names))
164 | self.assertEqual(dataset[0]["ner_tags"], "EU is ORG entity.\nGerman is MISC entity.\nBritish is MISC entity.")
165 | dataset = dataset.select(range(10))
166 | dataset = convert_spans_to_token_labels(
167 | dataset=dataset,
168 | token_column="tokens",
169 | label_column="ner_tags",
170 | id2label=id2label
171 | )
172 | for label in dataset[0]["ner_tags"]:
173 | self.assertIn(label, id2label.keys())
174 |
175 | def test_false_inputs_raises_error(self):
176 | """Test that false inputs raise errors"""
177 | with self.assertRaises(AttributeError):
178 | dataset, label_options = convert_token_labels_to_spans(
179 | "ner_tags", "tokens", {1: "a", 2: "b", 3: "c"}
180 | )
181 |
182 | with self.assertRaises(AttributeError):
183 | dataset, label_options = convert_token_labels_to_spans(
184 | {1: "a", 2: "b", 3: "c"}, "tokens", "ner_tags"
185 | )
186 |
187 |
188 | class TestTransformationsQuestionAnswering(unittest.TestCase):
189 | """Testcase for QA Transformations"""
190 |
191 | def setUp(self) -> None:
192 | self.dataset = load_dataset("squad_v2", split="train")
193 |
194 | def test_squad_preprocessing(self):
195 |         """Test transformation from squad format into flat structure"""
196 | self.assertEqual(type(self.dataset[0]["answers"]), dict)
197 | dataset = preprocess_squad_format(self.dataset.select(range(30)))
198 | self.assertEqual(type(dataset[0]["answers"]), str)
199 | with self.assertRaises(KeyError):
200 | x = dataset[0]["answer_start"]
201 |
202 | def test_squad_postprocessing(self):
203 |         """Test transformation from flat structure into squad format"""
204 | dataset = preprocess_squad_format(self.dataset.select(range(50)))
205 | dataset = postprocess_squad_format(dataset)
206 | self.assertEqual(type(dataset[0]["answers"]), dict)
207 | self.assertIn("answer_start", dataset[0]["answers"])
208 |
--------------------------------------------------------------------------------
/tests/test_prompts.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from datasets import load_dataset, Dataset, QuestionAnsweringExtractive, TextClassification, Summarization
4 |
5 | from fabricator.prompts import (
6 | BasePrompt,
7 | infer_prompt_from_task_template,
8 | )
9 |
10 |
11 | class TestPrompt(unittest.TestCase):
12 | """Testcase for Prompts"""
13 |
14 | def setUp(self) -> None:
15 | """Set up test dataset"""
16 | self.dataset = Dataset.from_dict({
17 | "text": ["This movie is great!", "This movie is bad!"],
18 | "label": ["positive", "negative"]
19 | })
20 |
21 | def test_plain_template(self):
22 | """Test plain prompt template"""
23 | prompt_template = BasePrompt(task_description="Generate movies reviews.")
24 | self.assertEqual(prompt_template.get_prompt_text(), 'Generate movies reviews.\n\ntext: ')
25 | self.assertEqual(prompt_template.target_formatting_template, "text: ")
26 | self.assertEqual(prompt_template.generate_data_for_column, None)
27 | self.assertEqual(prompt_template.fewshot_example_columns, None)
28 |
29 | with self.assertRaises(TypeError):
30 | prompt_template = BasePrompt()
31 |
32 | def test_template_with_label_options(self):
33 | """Test prompt template with label options"""
34 | label_options = ["positive", "negative"]
35 | prompt_template = BasePrompt(
36 | task_description="Generate a {} movie review.",
37 | label_options=label_options,
38 | )
39 | self.assertIn("positive", prompt_template.get_prompt_text(label_options[0]))
40 | self.assertIn("negative", prompt_template.get_prompt_text(label_options[1]))
41 | self.assertEqual(prompt_template.target_formatting_template, "text: ")
42 |
43 | def test_initialization_only_target_column(self):
44 | """Test initialization with only target column"""
45 | prompt_template = BasePrompt(
46 | task_description="Generate similar movie reviews.",
47 | generate_data_for_column="text",
48 | )
49 | self.assertEqual(prompt_template.relevant_columns_for_fewshot_examples, ["text"])
50 | self.assertEqual(type(prompt_template.generate_data_for_column), list)
51 | self.assertEqual(len(prompt_template.generate_data_for_column), 1)
52 | self.assertEqual(prompt_template.fewshot_example_columns, None)
53 |
54 | prompt_text = 'Generate similar movie reviews.\n\ntext: This movie is great!\n\ntext: ' \
55 | 'This movie is bad!\n\ntext: '
56 |
57 | self.assertEqual(prompt_template.get_prompt_text(None, self.dataset), prompt_text)
58 |
59 | def test_initialization_target_and_fewshot_columns(self):
60 | """Test initialization with target and fewshot columns"""
61 | prompt_template = BasePrompt(
62 | task_description="Generate movie reviews.",
63 | generate_data_for_column="label",
64 | fewshot_example_columns="text"
65 | )
66 | self.assertEqual(prompt_template.relevant_columns_for_fewshot_examples, ["text", "label"])
67 | self.assertEqual(type(prompt_template.generate_data_for_column), list)
68 | self.assertEqual(len(prompt_template.generate_data_for_column), 1)
69 | self.assertEqual(type(prompt_template.fewshot_example_columns), list)
70 |
71 | prompt_text = 'Generate movie reviews.\n\ntext: This movie is great!\nlabel: positive\n\n' \
72 | 'text: This movie is bad!\nlabel: negative\n\ntext: {text}\nlabel: '
73 |
74 | self.assertEqual(prompt_template.get_prompt_text(None, self.dataset), prompt_text)
75 |
76 | def test_initialization_with_multiple_fewshot_columns(self):
77 | """Test initialization with multiple fewshot columns"""
78 | text_label_prompt = BasePrompt(
79 | task_description="Test two fewshot columns.",
80 | generate_data_for_column="label",
81 | fewshot_example_columns=["fewshot1", "fewshot2"],
82 | )
83 | self.assertEqual(text_label_prompt.relevant_columns_for_fewshot_examples, ["fewshot1", "fewshot2", "label"])
84 | self.assertEqual(type(text_label_prompt.fewshot_example_columns), list)
85 | self.assertEqual(len(text_label_prompt.fewshot_example_columns), 2)
86 |
87 | def test_custom_formatting_template(self):
88 | label_options = ["positive", "negative"]
89 |
90 | fewshot_examples = Dataset.from_dict({
91 | "text": ["This movie is great!", "This movie is bad!"],
92 | "label": label_options
93 | })
94 |
95 | prompt = BasePrompt(
96 | task_description="Annotate the sentiment of the following movie review whether it is: {}.",
97 | generate_data_for_column="label",
98 | fewshot_example_columns="text",
99 | fewshot_formatting_template="Movie Review: {text}\nSentiment: {label}",
100 | target_formatting_template="Movie Review: {text}\nSentiment: ",
101 | label_options=label_options,
102 | )
103 |
104 | prompt_text = prompt.get_prompt_text(label_options, fewshot_examples)
105 |
106 | self.assertIn("whether it is: positive, negative.", prompt_text)
107 | self.assertIn("Movie Review: This movie is great!\nSentiment: positive", prompt_text)
108 | self.assertIn("Movie Review: This movie is bad!\nSentiment: negative", prompt_text)
109 | self.assertIn("Movie Review: {text}\nSentiment: ", prompt.target_formatting_template)
110 |
111 |
112 | class TestDownstreamTasks(unittest.TestCase):
113 | """Testcase for downstream tasks"""
114 |
115 | def setUp(self) -> None:
116 | """Set up test datasets"""
117 |
118 | def preprocess_qa(example):
119 | if example["answer"]:
120 | example["answer"] = example["answer"].pop()
121 | else:
122 | example["answer"] = ""
123 | return example
124 |
125 | self.text_classification = load_dataset("trec", split="train").select([1, 2, 3])
126 | self.question_answering = load_dataset("squad", split="train").flatten()\
127 | .rename_column("answers.text", "answer").map(preprocess_qa).select([1, 2, 3])
128 | self.ner = load_dataset("conll2003", split="train").select([1, 2, 3])
129 | self.translation = load_dataset("opus100", "de-nl", split="test").flatten()\
130 | .rename_columns({"translation.de": "german", "translation.nl": "dutch"}).select([1, 2, 3])
131 |
132 | def test_translation(self):
133 | prompt = BasePrompt(
134 | task_description="Given a german phrase, translate it into dutch.",
135 | generate_data_for_column="dutch",
136 | fewshot_example_columns="german",
137 | )
138 |
139 | raw_prompt = prompt.get_prompt_text(None, self.translation)
140 | self.assertIn("Marktorganisation für Wein(1)", raw_prompt)
141 | self.assertIn("Gelet op Verordening (EG) nr. 1493/1999", raw_prompt)
142 | self.assertIn("na z'n partijtje golf", raw_prompt)
143 | self.assertIn("dutch: ", raw_prompt)
144 |
145 | def test_text_classification(self):
146 | label_options = self.text_classification.features["coarse_label"].names
147 |
148 | prompt = BasePrompt(
149 | task_description="Classify the question into one of the following categories: {}",
150 | label_options=label_options,
151 | generate_data_for_column="coarse_label",
152 | fewshot_example_columns="text",
153 | )
154 |
155 | raw_prompt = prompt.get_prompt_text(label_options, self.text_classification)
156 | self.assertIn(", ".join(label_options), raw_prompt)
157 | self.assertIn("What fowl grabs the spotlight after the Chinese Year of the Monkey ?", raw_prompt)
158 | self.assertIn("How can I find a list of celebrities ' real names ?", raw_prompt)
159 | self.assertIn("What films featured the character Popeye Doyle ?", raw_prompt)
160 | self.assertIn("coarse_label: 2", raw_prompt)
161 |
162 | def test_named_entity_recognition(self):
163 | label_options = self.ner.features["ner_tags"].feature.names
164 |
165 | prompt = BasePrompt(
166 | task_description="Classify each token into one of the following categories: {}",
167 | generate_data_for_column="ner_tags",
168 | fewshot_example_columns="tokens",
169 | label_options=label_options,
170 | )
171 |
172 | raw_prompt = prompt.get_prompt_text(label_options, self.ner)
173 | self.assertIn(", ".join(label_options), raw_prompt)
174 | self.assertIn("'BRUSSELS', '1996-08-22'", raw_prompt)
175 | self.assertIn("'Peter', 'Blackburn'", raw_prompt)
176 | self.assertIn("3, 4, 0, 0, 0, 0, 0, 0, 7, 0", raw_prompt)
177 | self.assertIn("ner_tags: [1, 2]", raw_prompt)
178 |
179 | def test_question_answering(self):
180 | prompt = BasePrompt(
181 | task_description="Given context and question, answer the question.",
182 | generate_data_for_column="answer",
183 | fewshot_example_columns=["context", "question"],
184 | )
185 |
186 | raw_prompt = prompt.get_prompt_text(None, self.question_answering)
187 | self.assertIn("answer: the Main Building", raw_prompt)
188 | self.assertIn("context: Architecturally, the school", raw_prompt)
189 | self.assertIn("question: The Basilica", raw_prompt)
190 | self.assertIn("context: {context}", raw_prompt)
191 |
192 |
193 | class TestAutoInference(unittest.TestCase):
194 | """Testcase for AutoInference"""
195 |
196 | def test_auto_infer_text_label_prompt(self):
197 | """Test auto inference of QuestionAnsweringExtractive task template"""
198 | task_template = QuestionAnsweringExtractive()
199 | prompt = infer_prompt_from_task_template(task_template)
200 | self.assertIsInstance(prompt, BasePrompt)
201 | self.assertEqual(prompt.fewshot_example_columns, ["context", "question"])
202 | self.assertEqual(prompt.generate_data_for_column, ["answers"])
203 |
204 | def test_auto_infer_class_label_prompt(self):
205 | """Test auto inference of TextClassification task template"""
206 | task_template = TextClassification()
207 | task_template.label_schema["labels"].names = ["neg", "pos"]
208 | prompt = infer_prompt_from_task_template(task_template)
209 | self.assertIsInstance(prompt, BasePrompt)
210 | self.assertEqual(prompt.fewshot_example_columns, ["text"])
211 | self.assertEqual(prompt.generate_data_for_column, ["labels"])
212 |
213 | def test_auto_infer_fails_for_unsupported_task(self):
214 | """Test auto inference of prompt fails for unsupported task template Summarization"""
215 | with self.assertRaises(ValueError):
216 | infer_prompt_from_task_template(Summarization())
217 |
--------------------------------------------------------------------------------
/tutorials/TUTORIAL-1_OVERVIEW.md:
--------------------------------------------------------------------------------
1 | # Tutorial 1: Fabricator Introduction
2 |
3 | ## 1) Dataset Generation
4 |
5 | ### 1.1) Recipe for Dataset Generation 📚
6 | When starting from scratch, to generate an arbitrary dataset, you need to create instances of:
7 |
8 | - **_Datasets_**: For few-shot examples and final storage of (pair-wise) data to train a small PLM.
9 | - **_LLMs_**: To annotate existing, unlabeled datasets or generate completely new ones.
10 | - **_Prompts_**: To format the provided inputs (task description, few-shot examples, etc.) to prompt the LLM.
11 | - **_Orchestrator_**: To align all components and steer the generation process.
12 |
13 | ### 1.2) Creating a Workflow From Scratch Requires Careful Consideration of Intricate Details 👨🍳
14 | The following figure illustrates the typical generation workflow when using large language models as teachers for
15 | smaller pre-trained language models (PLMs) like BERT. Establishing this workflow demands attention to implementation
16 | details and requires boilerplate code. Further, the setup process may vary depending on the particular LLM and dataset format.
17 |
18 |
19 |
20 | *Figure: The generation workflow when using LLMs as teachers for smaller PLMs such as BERT.*
21 |
22 |
23 | ### 1.3) Efficiently Generate Datasets With Fabricator 🍜
24 |
25 | With fabricator, you simply need to define your generation settings,
26 | e.g. how many few-shot examples to include per prompt, how to sample few-shot instances from a pool of available
27 | examples, or which LLM to use. In addition, everything is built on top of Hugging Face's
28 | [datasets](https://github.com/huggingface/datasets) library so that you can directly
29 | incorporate the generated datasets into your usual training workflows or share them via the Hugging Face Hub.
30 |
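For example, a complete generation run can be as compact as the following minimal sketch (assuming an OpenAI API key is available in the environment); the individual components are explained in detail in the next section.

```python
import os
from haystack.nodes import PromptNode
from fabricator import DatasetGenerator
from fabricator.prompts import BasePrompt

# Prompt: what the LLM should generate
prompt = BasePrompt(task_description="Generate a very very short movie review.")

# LLM: any model supported by haystack's PromptNode
prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key=os.environ.get("OPENAI_API_KEY"),
    max_length=100,
)

# Orchestrator: aligns all components and steers the generation process
generator = DatasetGenerator(prompt_node)
generated_dataset = generator.generate(prompt_template=prompt, max_prompt_calls=10)
```
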
31 | ## 2) Fabricator Components
32 |
33 | ### 2.1) Datasets
34 |
35 | Datasets are built upon the `Dataset` class of the huggingface datasets library. They are used to store the data in a
36 | tabular format and provide a convenient way to access the data. The generated datasets will always be in this format such
37 | that they can be easily integrated with standard machine learning frameworks or shared with the research community via
38 | the huggingface hub.
39 |
40 | ```python
41 | from datasets import load_dataset
42 |
43 | # Load the imdb dataset from the huggingface hub, i.e. for annotation
44 | dataset = load_dataset("imdb")
45 |
46 | # Load custom dataset from a jsonl file
47 | dataset = load_dataset("json", data_files="path/to/file.jsonl")
48 |
49 | # Share generated dataset with the huggingface hub
50 | dataset.push_to_hub("my-dataset")
51 | ```
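
Small few-shot pools can also be created directly from Python dictionaries; this is the pattern used throughout the following tutorials:

```python
from datasets import Dataset

# A tiny, hand-written few-shot pool
fewshot_dataset = Dataset.from_dict({
    "text": ["This movie is great!", "This movie is bad!"],
    "label": ["positive", "negative"],
})
```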
52 |
53 | ### 2.2) LLMs
54 | We simply use haystack's `PromptNode` as our LLM interface. The PromptNode is a wrapper around multiple LLM backends, such as the
55 | OpenAI models or any model available on the huggingface hub. You can set all generation-related parameters such as
56 | temperature, top_k, or the maximum generation length via the PromptNode (see also the [documentation](https://docs.haystack.deepset.ai/docs/prompt_node)).
57 |
58 | ```python
59 | import os
60 | from haystack.nodes import PromptNode
61 |
62 | # Load a model from huggingface hub
63 | prompt_node = PromptNode("google/flan-t5-base")
64 |
65 | # Create a PromptNode with the OpenAI API
66 | prompt_node = PromptNode(
67 | model_name_or_path="text-davinci-003",
68 | api_key=os.environ.get("OPENAI_API_KEY"),
69 | )
70 | ```
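
Generation parameters can be set on the same object. A minimal sketch is shown below; note that forwarding sampling parameters such as `temperature` through `model_kwargs` is an assumption that depends on your haystack version, so check the PromptNode documentation linked above.

```python
import os
from haystack.nodes import PromptNode

prompt_node = PromptNode(
    model_name_or_path="gpt-3.5-turbo",
    api_key=os.environ.get("OPENAI_API_KEY"),
    max_length=100,                      # maximum generation length
    model_kwargs={"temperature": 0.7},   # assumed pass-through for sampling parameters
)
```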
71 |
72 | ### 2.3) Prompts
73 |
74 | The `BasePrompt` class is used to format the prompts for the LLMs. This class is highly flexible and thus can be
75 | adapted to various settings:
76 | - define a `task_description` (e.g. "Generate a [_label_] movie review.") to generate data for a certain class, e.g. a movie review for the label "positive".
77 | - include pre-defined `label_options` (e.g. "Annotate the following review with one of the following labels: positive, negative.") when annotating unlabeled datasets.
78 | - customize the format of few-shot examples inside the prompt
79 |
80 | #### Prompt for generating plain text
81 |
82 | ```python
83 | from fabricator.prompts import BasePrompt
84 |
85 | prompt_template = BasePrompt(task_description="Generate a movie review.")
86 | print(prompt_template.get_prompt_text())
87 | ```
88 | Prompt during generation:
89 | ```console
90 | Generate a movie review.
91 |
92 | text:
93 | ```
94 |
95 | #### Prompt for label-conditioned generation with label options
96 |
97 | ```python
98 | from fabricator.prompts import BasePrompt
99 |
100 | label_options = ["positive", "negative"]
101 | prompt_template = BasePrompt(
102 | task_description="Generate a {} movie review.",
103 | label_options=label_options,
104 | )
105 | ```
106 |
107 | Label-conditioned prompts during generation:
108 | ```console
109 | Generate a positive movie review.
110 |
111 | text:
112 | ---
113 |
114 | Generate a negative movie review.
115 |
116 | text:
117 | ---
118 | ```
119 |
120 | Note: In the orchestrator class, you can define the desired label distribution, e.g. sampling uniformly
121 | from both labels in each iteration (a sketch of the corresponding sampling arguments follows at the end of this section).
122 |
123 | #### Prompt with few-shot examples
124 |
125 | ```python
126 | from datasets import Dataset
127 | from fabricator.prompts import BasePrompt
128 |
129 | label_options = ["positive", "negative"]
130 |
131 | fewshot_examples = Dataset.from_dict({
132 | "text": ["This movie is great!", "This movie is bad!"],
133 | "label": label_options
134 | })
135 |
136 | prompt_template = BasePrompt(
137 | task_description="Generate a {} movie review.",
138 | label_options=label_options,
139 | generate_data_for_column="text",
140 | )
141 | ```
142 |
143 | Prompts with few-shot examples during generation:
144 | ```console
145 | Generate a positive movie review.
146 |
147 | text: This movie is great!
148 |
149 | text:
150 | ---
151 |
152 | Generate a negative movie review.
153 |
154 | text: This movie is bad!
155 |
156 | text:
157 | ---
158 | ```
159 |
160 | Note: The `generate_data_for_column` attribute defines the column of the few-shot dataset for which additional data is generated.
161 | As previously shown, the orchestrator samples a label and includes a matching few-shot example, as sketched below.
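
For illustration, here is a minimal sketch of how the orchestrator controls this sampling; it reuses `prompt_template` and `fewshot_examples` from the snippet above and assumes a `prompt_node` as in section 2.2 (the arguments are covered in detail in Tutorial 2):

```python
from fabricator import DatasetGenerator

generator = DatasetGenerator(prompt_node)
generated_dataset = generator.generate(
    prompt_template=prompt_template,
    fewshot_dataset=fewshot_examples,
    fewshot_examples_per_class=1,          # few-shot examples per label and prompt
    fewshot_sampling_strategy="uniform",   # sample both labels with equal probability
    fewshot_sampling_column="label",       # column holding the class labels
    max_prompt_calls=10,
)
```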
162 |
163 | ### 2.4) DatasetGenerator
164 |
165 | The `DatasetGenerator` class is fabricator's orchestrator. It takes a `Dataset`, a `PromptNode` and a
166 | `BasePrompt` as inputs and generates the final dataset based on these instances. The `generate` method returns a `Dataset` object that can be used with standard machine learning
167 | frameworks such as [flair](https://github.com/flairNLP/flair), deepset's [haystack](https://github.com/deepset-ai/haystack), or Hugging Face's [transformers](https://github.com/huggingface/transformers).
168 |
169 | ```python
170 | from haystack.nodes import PromptNode
171 | from fabricator import DatasetGenerator
172 | from fabricator.prompts import BasePrompt
173 |
174 | prompt = BasePrompt(task_description="Generate a movie review.")
175 |
176 | prompt_node = PromptNode("google/flan-t5-base")
177 |
178 | generator = DatasetGenerator(prompt_node)
179 | generated_dataset = generator.generate(
180 |     prompt_template=prompt,
181 |     max_prompt_calls=10,
182 | )
183 | ```
184 |
185 | In the following [tutorial](TUTORIAL-2_GENERATION_WORKFLOWS.md), we introduce the different generation processes covered by fabricator.
186 |
--------------------------------------------------------------------------------
/tutorials/TUTORIAL-2_GENERATION_WORKFLOWS.md:
--------------------------------------------------------------------------------
1 | # Tutorial 2: Generation Workflows
2 |
3 | In this tutorial, you will learn:
4 | 1. how to generate datasets
5 | 2. how to annotate unlabeled datasets
6 | 3. how to configure hyperparameters for your generation process
7 |
8 | ## 1) Generating Datasets
9 |
10 | ### 1.1) Generating Plain Text
11 |
12 | In this example, we demonstrate how to combine the fabricator components introduced in the previous tutorial to create a
13 | movie review dataset. We don't explicitly direct the large language model (LLM) to generate movie reviews for
14 | specific labels (such as binary sentiment), nor do we offer a few examples to guide the LLM in generating similar content.
15 | Instead, all it requires is a task description. The LLM then produces a dataset containing movie reviews based on the
16 | provided instructions. This dataset can be easily uploaded to the Hugging Face Hub.
17 |
18 | ```python
19 | import os
20 | from haystack.nodes import PromptNode
21 | from fabricator import DatasetGenerator
22 | from fabricator.prompts import BasePrompt
23 |
24 | prompt = BasePrompt(
25 | task_description="Generate a very very short movie review.",
26 | )
27 |
28 | prompt_node = PromptNode(
29 | model_name_or_path="gpt-3.5-turbo",
30 | api_key=os.environ.get("OPENAI_API_KEY"),
31 | max_length=100,
32 | )
33 |
34 | generator = DatasetGenerator(prompt_node)
35 | generated_dataset = generator.generate(
36 | prompt_template=prompt,
37 | max_prompt_calls=10,
38 | )
39 |
40 | generated_dataset.push_to_hub("your-first-generated-dataset")
41 | ```
42 |
43 | ### 1.2) Generate Label-Conditioned Datasets With Label Options and Few-Shot Examples
44 | To create datasets that are conditioned on specific labels and use few-shot examples,
45 | we need a few-shot dataset that is already annotated. The prompt should have the same labels as those in the few-shot
46 | dataset. Additionally, as explained in a previous tutorial, we must set the `generate_data_for_column`
47 | parameter to specify the column in the dataset for which we want to generate text.
48 |
49 | In the dataset generator, we define certain hyperparameters for the generation process. `fewshot_examples_per_class`
50 | determines how many few-shot examples are incorporated for each class per prompt. `fewshot_sampling_strategy`
51 | can be set to either "uniform" if each label has an equal chance of being sampled,
52 | or "stratified" if the distribution from the few-shot dataset needs to be preserved.
53 | `fewshot_sampling_column` specifies the dataset column representing the classes. `max_prompt_calls`
54 | limits how many times the LLM is prompted.
55 |
56 | Crucially, the prompt instance contains all the details about how a single prompt for a specific data point should
57 | be structured. This includes information like which few-shot examples should appear alongside which task instruction.
58 | On the other hand, the dataset generator defines the overall generation process,
59 | such as determining the label distribution, specified by the `fewshot_sampling_column`.
60 |
61 | ```python
62 | import os
63 | from datasets import Dataset
64 | from haystack.nodes import PromptNode
65 | from fabricator import DatasetGenerator
66 | from fabricator.prompts import BasePrompt
67 |
68 | label_options = ["positive", "negative"]
69 |
70 | fewshot_dataset = Dataset.from_dict({
71 | "text": ["This movie is great!", "This movie is bad!"],
72 | "label": label_options
73 | })
74 |
75 | prompt = BasePrompt(
76 | task_description="Generate a {} movie review.",
77 | label_options=label_options,
78 | generate_data_for_column="text",
79 | )
80 |
81 | prompt_node = PromptNode(
82 | model_name_or_path="gpt-3.5-turbo",
83 | api_key=os.environ.get("OPENAI_API_KEY"),
84 | max_length=100,
85 | )
86 |
87 | generator = DatasetGenerator(prompt_node)
88 | generated_dataset = generator.generate(
89 | prompt_template=prompt,
90 | fewshot_dataset=fewshot_dataset,
91 | fewshot_examples_per_class=1,
92 | fewshot_sampling_strategy="uniform",
93 | fewshot_sampling_column="label",
94 | max_prompt_calls=10,
95 | )
96 |
97 | generated_dataset.push_to_hub("your-first-generated-dataset")
98 | ```
99 |
100 | ## 2) Annotating Unlabeled Data With Few-Shot Examples
101 |
102 | This example demonstrates how to add annotations to unlabeled data using few-shot examples. We have a few-shot dataset containing two columns: `text` and `label`, and an unlabeled dataset with only a `text` column. The goal is to annotate the unlabeled dataset using information from the few-shot dataset.
103 |
104 | To achieve this, we utilize the `DatasetGenerator.generate()` method. To begin, we provide the `unlabeled_dataset` argument, indicating the dataset we want to annotate. Additionally, we specify the `fewshot_examples_per_class` argument, determining how many few-shot examples to use for each class. In this scenario, we choose one example per class.
105 |
106 | The `fewshot_sampling_strategy` argument dictates how the few-shot dataset is sampled. In this case, we employ a stratified sampling strategy. This means that the generator will select precisely one example from each class within the few-shot dataset.
107 |
108 | It's worth noting that there's no need to explicitly specify the `fewshot_sampling_column` argument. By default, the generator uses the column specified in `generate_data_for_column` for this purpose.
109 |
110 | ```python
111 | import os
112 | from datasets import Dataset
113 | from haystack.nodes import PromptNode
114 | from fabricator import DatasetGenerator
115 | from fabricator.prompts import BasePrompt
116 |
117 | label_options = ["positive", "negative"]
118 |
119 | fewshot_dataset = Dataset.from_dict({
120 | "text": ["This movie is great!", "This movie is bad!"],
121 | "label": label_options
122 | })
123 |
124 | unlabeled_dataset = Dataset.from_dict({
125 | "text": ["This movie was a blast!", "This movie was not bad!"],
126 | })
127 |
128 | prompt = BasePrompt(
129 | task_description="Annotate movie reviews as either: {}.",
130 | label_options=label_options,
131 | generate_data_for_column="label",
132 | fewshot_example_columns="text",
133 | )
134 |
135 | prompt_node = PromptNode(
136 | model_name_or_path="gpt-3.5-turbo",
137 | api_key=os.environ.get("OPENAI_API_KEY"),
138 | max_length=100,
139 | )
140 |
141 | generator = DatasetGenerator(prompt_node)
142 | generated_dataset = generator.generate(
143 | prompt_template=prompt,
144 | fewshot_dataset=fewshot_dataset,
145 | fewshot_examples_per_class=1,
146 | fewshot_sampling_strategy="stratified",
147 | unlabeled_dataset=unlabeled_dataset,
148 | max_prompt_calls=10,
149 | )
150 |
151 | generated_dataset.push_to_hub("your-first-generated-dataset")
152 | ```
153 |
--------------------------------------------------------------------------------
/tutorials/TUTORIAL-3_ADVANCED-GENERATION.md:
--------------------------------------------------------------------------------
1 | # Tutorial 3: Advanced Dataset Generation
2 |
3 | ## Customizing Prompts
4 |
5 | Sometimes, you want to customize your prompt to your specific needs. For example, you might want to add a custom
6 | formatting template (the default takes the column names of the dataset):
7 |
8 | ```python
9 | from datasets import Dataset
10 | from fabricator.prompts import BasePrompt
11 |
12 | label_options = ["positive", "negative"]
13 |
14 | fewshot_examples = Dataset.from_dict({
15 | "text": ["This movie is great!", "This movie is bad!"],
16 | "label": label_options
17 | })
18 |
19 | prompt = BasePrompt(
20 | task_description="Annotate the sentiment of the following movie review whether it is: {}.",
21 | generate_data_for_column="label",
22 | fewshot_example_columns="text",
23 | fewshot_formatting_template="Movie Review: {text}\nSentiment: {label}",
24 | target_formatting_template="Movie Review: {text}\nSentiment: ",
25 | label_options=label_options,
26 | )
27 |
28 | print(prompt.get_prompt_text(label_options, fewshot_examples))
29 | ```
30 |
31 | which yields:
32 |
33 | ```text
34 | Annotate the sentiment of the following movie review whether it is: positive, negative.
35 |
36 | Movie Review: This movie is great!
37 | Sentiment: positive
38 |
39 | Movie Review: This movie is bad!
40 | Sentiment: negative
41 |
42 | Movie Review: {text}
43 | Sentiment:
44 | ```
45 |
46 | ## Inferring the Prompt from Dataset Info
47 |
48 | Huggingface `Dataset` objects carry task metadata from which a prompt can be inferred. This can be achieved by using
49 | the `infer_prompt_from_dataset` function. This function takes a dataset
50 | as input and returns a `BasePrompt` object. The `BasePrompt` object contains the task description, the label options
51 | and the respective columns, and can be used to generate a dataset with the `DatasetGenerator` class.
52 |
53 | ```python
54 | from datasets import load_dataset
55 | from fabricator.prompts import infer_prompt_from_dataset
56 |
57 | dataset = load_dataset("imdb", split="train")
58 | prompt = infer_prompt_from_dataset(dataset)
59 |
60 | print(prompt.get_prompt_text() + "\n---")
61 |
62 | label_options = dataset.features["label"].names
63 | fewshot_example = dataset.shuffle(seed=42).select([0])
64 |
65 | print(prompt.get_prompt_text(label_options, fewshot_example))
66 | ```
67 |
68 | The output of this script is:
69 |
70 | ```text
71 | Classify the following texts exactly into one of the following categories: {}.
72 |
73 | text: {text}
74 | label:
75 | ---
76 | Classify the following texts exactly into one of the following categories: neg, pos.
77 |
78 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
79 | label: 1
80 |
81 | text: {text}
82 | label:
83 | ```
84 |
85 | This feature is particularly useful if you have nested structures that follow a common format, such as for
86 | extractive question answering:
87 |
88 | ```python
89 | from datasets import load_dataset
90 | from fabricator.prompts import infer_prompt_from_dataset
91 |
92 | dataset = load_dataset("squad_v2", split="train")
93 | prompt = infer_prompt_from_dataset(dataset)
94 |
95 | print(prompt.get_prompt_text() + "\n---")
96 |
97 | fewshot_example = dataset.shuffle(seed=42).select([0])
98 |
99 | # Extractive QA has no label options, so we pass None
100 | print(prompt.get_prompt_text(None, fewshot_example))
101 | ```
102 |
103 | This script outputs:
104 |
105 | ```text
106 | Given a context and a question, generate an answer that occurs exactly and only once in the text.
107 |
108 | context: {context}
109 | question: {question}
110 | answers:
111 | ---
112 | Given a context and a question, generate an answer that occurs exactly and only once in the text.
113 |
114 | context: The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff:
115 | question: What term characterizes the intersection of the rites with the Roman Catholic Church?
116 | answers: {'text': ['full union'], 'answer_start': [104]}
117 |
118 | context: {context}
119 | question: {question}
120 | answers:
121 | ```
122 |
123 | ## Preprocess datasets
124 |
125 | In the previous example, we highlighted the simplicity of generating prompts using Hugging Face Datasets information.
126 | However, for optimal utilization of LLMs in generating text, it's essential to incorporate label names instead of IDs
127 | for text classification. Similarly, for question answering tasks, plain substrings are preferred over JSON-formatted
128 | strings. The following example illustrates these limitations.
129 |
130 | ```text
131 | Classify the following texts exactly into one of the following categories: **neg, pos**.
132 |
133 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
134 | **label: 1**
135 |
136 | ---
137 |
138 | Given a context and a question, generate an answer that occurs exactly and only once in the text.
139 |
140 | context: The Roman Catholic Church canon law also includes the main five rites (groups) of churches which are in full union with the Roman Catholic Church and the Supreme Pontiff:
141 | question: What term characterizes the intersection of the rites with the Roman Catholic Church?
142 | answers: **{'text': ['full union'], 'answer_start': [104]}**
143 | ```
144 |
145 | To overcome this, we provide a range of preprocessing functions for various downstream tasks.
146 |
147 | ### Text Classification
148 |
149 | The `convert_label_ids_to_texts` function transforms the label IDs of your text classification dataset into textual
150 | labels. By default, it uses the label names specified in the dataset's features.
151 |
152 | ```python
153 | from datasets import load_dataset
154 | from fabricator.prompts import infer_prompt_from_dataset
155 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts
156 |
157 | dataset = load_dataset("imdb", split="train")
158 | prompt = infer_prompt_from_dataset(dataset)
159 | dataset, label_options = convert_label_ids_to_texts(
160 | dataset=dataset,
161 | label_column="label",
162 | return_label_options=True
163 | )
164 |
165 | fewshot_example = dataset.shuffle(seed=42).select([0])
166 | print(prompt.get_prompt_text(label_options, fewshot_example))
167 | ```
168 |
169 | Which yields:
170 |
171 | ```text
172 | Classify the following texts exactly into one of the following categories: neg, pos.
173 |
174 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
175 | label: pos
176 |
177 | text: {text}
178 | label:
179 | ```
180 |
181 | If we want to provide more meaningful names we can do so by specifying `expanded_label_mapping`.
182 | Remember to update the label options accordingly in the `BasePrompt` class.
183 |
184 | ```python
185 | extended_mapping = {0: "negative", 1: "positive"}
186 | dataset, label_options = convert_label_ids_to_texts(
187 | dataset=dataset,
188 | label_column="label",
189 | expanded_label_mapping=extended_mapping,
190 | return_label_options=True
191 | )
192 | prompt.label_options = label_options
193 | ```
194 |
195 | This yields:
196 |
197 | ```text
198 | Classify the following texts exactly into one of the following categories: positive, negative.
199 |
200 | text: There is no relation at all between Fortier and Profiler but the fact that both are police series [...]
201 | label: positive
202 |
203 | text: {text}
204 | label:
205 | ```
206 |
207 | Once the dataset is generated, one can easily convert the string labels back to label IDs by
208 | using huggingface's `class_encode_column` function.
209 |
210 | ```python
211 | dataset = dataset.class_encode_column("label")
212 | print("Labels: " + str(dataset["label"][:5]))
213 | print("Features: " + str(dataset.features["label"]))
214 | ```
215 |
216 | Which yields:
217 |
218 | ```text
219 | Labels: [0, 1, 1, 0, 0]
220 | Features: ClassLabel(names=['negative', 'positive'], id=None)
221 | ```
222 |
223 | Note: While generating the dataset, the model is supposed to assign labels based on the specific options
224 | provided. However, we do not filter the data if it doesn't adhere to these predefined labels.
225 | Therefore, it's important to double-check if the annotations match the expected label options.
226 | If they don't, you should make corrections accordingly, for example by filtering out invalid labels as sketched below.
227 |
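A minimal sketch of such a sanity check is shown below. It assumes `generated_dataset` is the output of `DatasetGenerator.generate()` and that `label_options` holds the expected labels:

```python
label_options = ["negative", "positive"]

# Drop rows whose generated label is not one of the expected options ...
generated_dataset = generated_dataset.filter(lambda example: example["label"] in label_options)

# ... before converting the remaining string labels back to label IDs
generated_dataset = generated_dataset.class_encode_column("label")
```
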
228 | ### Question Answering (Extractive)
229 |
230 | In question answering tasks, we offer two functions to handle dataset processing: preprocessing and postprocessing.
231 | The preprocessing function is responsible for transforming datasets from SQuAD format into flat strings.
232 | On the other hand, the postprocessing function reverses this process by converting flat predictions back into
233 | SQuAD format. It achieves this by determining the starting point of the answer and checking if the answer cannot be
234 | found in the given context or if it occurs multiple times.
235 |
236 | ```python
237 | from datasets import load_dataset
238 | from fabricator.prompts import infer_prompt_from_dataset
239 | from fabricator.dataset_transformations.question_answering import preprocess_squad_format, postprocess_squad_format
240 |
241 | dataset = load_dataset("squad_v2", split="train")
242 | prompt = infer_prompt_from_dataset(dataset)
243 |
244 | dataset = preprocess_squad_format(dataset)
245 |
246 | print(prompt.get_prompt_text(None, dataset.select([0])) + "\n---")
247 |
248 | dataset = postprocess_squad_format(dataset)
249 | print(dataset[0]["answers"])
250 | ```
251 |
252 | Which yields:
253 |
254 | ```text
255 | Given a context and a question, generate an answer that occurs exactly and only once in the text.
256 |
257 | context: Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".
258 | question: When did Beyonce start becoming popular?
259 | answers: in the late 1990s
260 |
261 | context: {context}
262 | question: {question}
263 | answers:
264 | ---
265 | {'start': 269, 'text': 'in the late 1990s'}
266 | ```
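
Conceptually, the span reconstruction in `postprocess_squad_format` can be thought of as the simplified sketch below. This is an illustration, not the actual implementation; in particular, how missing or ambiguous answers are represented here is an assumption.

```python
def to_squad_answer(context: str, answer: str) -> dict:
    """Map a flat answer string back to a SQuAD-style span (simplified sketch)."""
    if not answer or context.count(answer) != 1:
        # Answer is empty, not found in the context, or occurs multiple times:
        # no unambiguous span can be recovered.
        return {"start": -1, "text": ""}
    return {"start": context.find(answer), "text": answer}

context = "... rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child ..."
print(to_squad_answer(context, "in the late 1990s"))
```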
267 |
268 | ### Named Entity Recognition
269 |
270 | If you attempt to create a dataset for named entity recognition without any preprocessing, the prompt might be
271 | challenging for the language model to understand.
272 |
273 | ```python
274 | from datasets import load_dataset
275 | from fabricator.prompts import BasePrompt
276 |
277 | dataset = load_dataset("conll2003", split="train")
278 | prompt = BasePrompt(
279 | task_description="Annotate each token with its named entity label: {}.",
280 | generate_data_for_column="ner_tags",
281 | fewshot_example_columns=["tokens"],
282 | label_options=dataset.features["ner_tags"].feature.names,
283 | )
284 |
285 | print(prompt.get_prompt_text(prompt.label_options, dataset.select([0])))
286 | ```
287 |
288 | Which outputs:
289 |
290 | ```text
291 | Annotate each token with its named entity label: O, B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC.
292 |
293 | tokens: ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
294 | ner_tags: [3, 0, 7, 0, 0, 0, 7, 0, 0]
295 |
296 | tokens: {tokens}
297 | ner_tags:
298 | ```
299 |
300 | To enhance prompt clarity, we can preprocess the dataset by converting labels into spans. This conversion can be
301 | accomplished using the `convert_token_labels_to_spans` function. Additionally, the function will provide the
302 | available label options:
303 |
304 | ```python
305 | from datasets import load_dataset
306 | from fabricator.prompts import BasePrompt
307 | from fabricator.dataset_transformations import convert_token_labels_to_spans
308 |
309 | dataset = load_dataset("conll2003", split="train")
310 | dataset, label_options = convert_token_labels_to_spans(dataset, "tokens", "ner_tags", return_label_options=True)
311 | prompt = BasePrompt(
312 | task_description="Annotate each token with its named entity label: {}.",
313 | generate_data_for_column="ner_tags",
314 | fewshot_example_columns=["tokens"],
315 | label_options=label_options,
316 | )
317 |
318 | print(prompt.get_prompt_text(prompt.label_options, dataset.select([0])))
319 | ```
320 | Output:
321 | ```text
322 | Annotate each token with its named entity label: MISC, ORG, PER, LOC.
323 |
324 | tokens: EU rejects German call to boycott British lamb .
325 | ner_tags: EU is ORG entity.
326 | German is MISC entity.
327 | British is MISC entity.
328 |
329 | tokens: {tokens}
330 | ner_tags:
331 | ```
332 |
333 | As in text classification, we can also specify more semantically precise labels with the `expanded_label_mapping`:
334 |
335 | ```python
336 | expanded_label_mapping = {
337 | 0: "O",
338 | 1: "B-person",
339 | 2: "I-person",
340 | 3: "B-location",
341 | 4: "I-location",
342 | 5: "B-organization",
343 | 6: "I-organization",
344 | 7: "B-miscellaneous",
345 | 8: "I-miscellaneous",
346 | }
347 |
348 | dataset, label_options = convert_token_labels_to_spans(
349 | dataset=dataset,
350 | token_column="tokens",
351 | label_column="ner_tags",
352 | expanded_label_mapping=expanded_label_mapping,
353 |     return_label_options=True
354 | )
355 | ```
356 |
357 | Output:
358 |
359 | ```text
360 | Annotate each token with its named entity label: organization, person, location, miscellaneous.
361 |
362 | tokens: EU rejects German call to boycott British lamb .
363 | ner_tags: EU is organization entity.
364 | German is miscellaneous entity.
365 | British is miscellaneous entity.
366 |
367 | tokens: {tokens}
368 | ner_tags:
369 | ```
370 |
371 | Once the dataset is created, we can use the `convert_spans_to_token_labels` function to convert spans back to label
372 | IDs. This function will only add spans that occur exactly once in the text. If a span occurs multiple times, it will be
373 | ignored. Note: this step is rather slow since it is currently built on spaCy. We are working on a faster implementation and welcome
374 | contributions.
375 |
376 | ```python
377 | from fabricator.dataset_transformations import convert_spans_to_token_labels
378 |
379 | dataset = convert_spans_to_token_labels(
380 | dataset=dataset.select(range(20)),
381 | token_column="tokens",
382 | label_column="ner_tags",
383 | id2label=expanded_label_mapping
384 | )
385 | ```
386 |
387 | Outputs:
388 |
389 | ```Text
390 | {'id': '1', 'tokens': ['Peter', 'Blackburn'], 'pos_tags': [22, 22], 'chunk_tags': [11, 12], 'ner_tags': [1, 2]}
391 | ```
392 |
393 | ## Adapt to arbitrary datasets
394 |
395 | The `BasePrompt` class is designed to be easily adaptable to arbitrary datasets. Just like in the examples for text
396 | classification, token classification or question answering, you can specify the task description, the column to generate
397 | data for and the fewshot example columns. The only difference is that you have to specify the optional label options
398 | yourself.
399 |
400 | ### Translation
401 |
402 | ```python
403 | import os
404 | from datasets import Dataset
405 | from haystack.nodes import PromptNode
406 | from fabricator import DatasetGenerator
407 | from fabricator.prompts import BasePrompt
408 |
409 | fewshot_dataset = Dataset.from_dict({
410 | "german": ["Der Film ist großartig!", "Der Film ist schlecht!"],
411 | "english": ["This movie is great!", "This movie is bad!"],
412 | })
413 |
414 | unlabeled_dataset = Dataset.from_dict({
415 | "english": ["This movie was a blast!", "This movie was not bad!"],
416 | })
417 |
418 | prompt = BasePrompt(
419 |     task_description="Translate to German:", # Since we do not have a label column,
420 | # we can just specify the task description
421 | generate_data_for_column="german",
422 | fewshot_example_columns="english",
423 | )
424 |
425 | prompt_node = PromptNode(
426 | model_name_or_path="gpt-3.5-turbo",
427 | api_key=os.environ.get("OPENAI_API_KEY"),
428 | max_length=100,
429 | )
430 |
431 | generator = DatasetGenerator(prompt_node)
432 | generated_dataset = generator.generate(
433 | prompt_template=prompt,
434 | fewshot_dataset=fewshot_dataset,
435 | fewshot_examples_per_class=2, # Take both fewshot examples per prompt
436 | fewshot_sampling_strategy=None, # Since we do not have a class label column, we can just set this to None
437 | # (default)
438 | unlabeled_dataset=unlabeled_dataset,
439 | max_prompt_calls=2,
440 | )
441 |
442 | generated_dataset.push_to_hub("your-first-generated-dataset")
443 | ```
444 |
445 | ### Textual similarity
446 |
447 | ```python
448 | import os
449 | from datasets import load_dataset
450 | from haystack.nodes import PromptNode
451 | from fabricator import DatasetGenerator
452 | from fabricator.prompts import BasePrompt
453 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts
454 |
455 | dataset = load_dataset("glue", "mrpc", split="train")
456 | dataset, label_options = convert_label_ids_to_texts(dataset, "label", return_label_options=True) # convert the
457 | # label ids to text labels and return the label options
458 |
459 | fewshot_dataset = dataset.select(range(10))
460 | unlabeled_dataset = dataset.select(range(10, 20))
461 |
462 | prompt = BasePrompt(
463 | task_description="Annotate the sentence pair whether it is: {}",
464 | label_options=label_options,
465 | generate_data_for_column="label",
466 | fewshot_example_columns=["sentence1", "sentence2"], # we can pass an array of columns to use for the fewshot
467 | )
468 |
469 | prompt_node = PromptNode(
470 | model_name_or_path="gpt-3.5-turbo",
471 | api_key=os.environ.get("OPENAI_API_KEY"),
472 | max_length=100,
473 | )
474 |
475 | generator = DatasetGenerator(prompt_node)
476 | generated_dataset, original_dataset = generator.generate(
477 | prompt_template=prompt,
478 | fewshot_dataset=fewshot_dataset,
479 |     fewshot_examples_per_class=1, # Take 1 fewshot example per class per prompt
480 | fewshot_sampling_column="label", # We want to sample fewshot examples based on the label column
481 | fewshot_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class
482 | unlabeled_dataset=unlabeled_dataset,
483 | max_prompt_calls=2,
484 | return_unlabeled_dataset=True, # We can return the original unlabelled dataset which might be interesting in this
485 | # case to compare the annotation quality
486 | )
487 |
488 | generated_dataset = generated_dataset.class_encode_column("label")
489 |
490 | generated_dataset.push_to_hub("your-first-generated-dataset")
491 | ```
492 |
493 | You can also easily switch out columns to be annotated if you want, for example, to generate a second sentence given a
494 | first sentence and a label, like this:
495 |
496 | ```python
497 | import os
498 | from datasets import load_dataset
499 | from haystack.nodes import PromptNode
500 | from fabricator import DatasetGenerator
501 | from fabricator.prompts import BasePrompt
502 | from fabricator.dataset_transformations.text_classification import convert_label_ids_to_texts
503 |
504 | dataset = load_dataset("glue", "mrpc", split="train")
505 | dataset, label_options = convert_label_ids_to_texts(dataset, "label", return_label_options=True) # convert the
506 | # label ids to text labels and return the label options
507 |
508 | fewshot_dataset = dataset.select(range(10))
509 | unlabeled_dataset = dataset.select(range(10, 20))
510 |
511 | prompt = BasePrompt(
512 | task_description="Generate a sentence that is {} to sentence1.",
513 | label_options=label_options,
514 | generate_data_for_column="sentence2",
515 | fewshot_example_columns=["sentence1", "label"], # we can pass an array of columns to use for the fewshot
516 | )
517 |
518 | prompt_node = PromptNode(
519 | model_name_or_path="gpt-3.5-turbo",
520 | api_key=os.environ.get("OPENAI_API_KEY"),
521 | max_length=100,
522 | )
523 |
524 | generator = DatasetGenerator(prompt_node)
525 | generated_dataset, original_dataset = generator.generate(
526 | prompt_template=prompt,
527 | fewshot_dataset=fewshot_dataset,
528 |     fewshot_examples_per_class=1, # Take 1 fewshot example per class per prompt
529 | fewshot_sampling_column="label", # We want to sample fewshot examples based on the label column
530 | fewshot_sampling_strategy="stratified", # We want to sample fewshot examples stratified by class
531 | unlabeled_dataset=unlabeled_dataset,
532 | max_prompt_calls=2,
533 | return_unlabeled_dataset=True, # We can return the original unlabelled dataset which might be interesting in this
534 | # case to compare the annotation quality
535 | )
536 |
537 | generated_dataset = generated_dataset.class_encode_column("label")
538 |
539 | generated_dataset.push_to_hub("your-first-generated-dataset")
540 | ```
541 |
--------------------------------------------------------------------------------