├── .coveragerc
├── .github
│   └── workflows
│       ├── build_documentation.yml
│       ├── build_pr_documentation.yml
│       ├── delete_doc_comment_trigger.yml
│       ├── quality.yml
│       ├── tests.yml
│       └── upload_pr_documentation.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── RELEASE.md
├── Zero_cost_Zero_time_Zero_shot_Financial_Sentiment_Analysis.ipynb
├── assets
│   └── setfit.png
├── docs
│   ├── README.md
│   └── source
│       ├── _config.py
│       └── en
│           ├── _toctree.yml
│           ├── conceptual_guides
│           │   ├── sampling_strategies.mdx
│           │   └── setfit.mdx
│           ├── how_to
│           │   ├── absa.mdx
│           │   ├── batch_sizes.mdx
│           │   ├── callbacks.mdx
│           │   ├── classification_heads.mdx
│           │   ├── hyperparameter_optimization.mdx
│           │   ├── knowledge_distillation.mdx
│           │   ├── model_cards.mdx
│           │   ├── multilabel.mdx
│           │   ├── overview.mdx
│           │   ├── v1.0.0_migration_guide.mdx
│           │   └── zero_shot.mdx
│           ├── index.mdx
│           ├── installation.mdx
│           ├── quickstart.mdx
│           ├── reference
│           │   ├── main.mdx
│           │   ├── trainer.mdx
│           │   └── utility.mdx
│           └── tutorials
│               ├── onnx.mdx
│               ├── overview.mdx
│               └── zero_shot.mdx
├── final_results
│   ├── adapet
│   │   ├── .gitkeep
│   │   ├── albert-xxlarge-v2.tar.gz
│   │   ├── xlm-roberta-base_all__lang_prompt.tar.gz
│   │   ├── xlm-roberta-base_each__lang_prompt.tar.gz
│   │   └── xlm-roberta-base_en__eng_prompt.tar.gz
│   ├── distillation
│   │   └── distillation_results.tar.gz
│   ├── perfect.tar.gz
│   ├── setfit
│   │   ├── all-MiniLM-L6-v2.tar.gz
│   │   ├── all-roberta-large-v1.tar.gz
│   │   ├── paraphrase-mpnet-base-v2-epochs_20-with-augmentation.tar.gz
│   │   ├── paraphrase-mpnet-base-v2-epochs_20.tar.gz
│   │   ├── paraphrase-mpnet-base-v2-epochs_5.tar.gz
│   │   ├── paraphrase-mpnet-base-v2-linear-probe.tar.gz
│   │   └── paraphrase-multilingual-mpnet-base-v2.tar.gz
│   ├── tfew
│   │   ├── t011b_pretrained-v2.tar.gz
│   │   └── t03b_pretrained-v2.tar.gz
│   └── transformers
│       ├── distilbert-base-uncased.tar.gz
│       ├── roberta-large.tar.gz
│       └── xlm-roberta-base.tar.gz
├── notebooks
│   ├── .gitkeep
│   ├── README.md
│   ├── multilabel_HoC.ipynb
│   ├── onnx_model_export.ipynb
│   ├── openvino_inference.ipynb
│   ├── setfit-absa-fiqa.ipynb
│   ├── setfit-onnx-optimum.ipynb
│   ├── setfit-optimum-intel.ipynb
│   ├── text-classification.ipynb
│   ├── text-classification_hyperparameter-search.ipynb
│   ├── text-classification_multilabel.ipynb
│   ├── zero-shot-classification.ipynb
│   └── zero_cost_zero_time_zero_shot_financial_sentiment_analysis.ipynb
├── scripts
│   ├── adapet
│   │   ├── .gitkeep
│   │   └── ADAPET
│   │       ├── .gitignore
│   │       ├── README.md
│   │       ├── bin
│   │       │   ├── dev.sh
│   │       │   ├── init.sh
│   │       │   ├── setup.sh
│   │       │   ├── test.sh
│   │       │   └── train.sh
│   │       ├── cli.py
│   │       ├── config
│   │       │   ├── BoolQ.json
│   │       │   ├── CB.json
│   │       │   ├── COPA.json
│   │       │   ├── Generic.json
│   │       │   ├── MultiRC.json
│   │       │   ├── RTE.json
│   │       │   ├── ReCoRD.json
│   │       │   ├── WSC.json
│   │       │   ├── WiC.json
│   │       │   └── sst-2.json
│   │       ├── requirements.txt
│   │       ├── setfit_adapet.py
│   │       ├── src
│   │       │   ├── adapet.py
│   │       │   ├── adapet_test_eval.py
│   │       │   ├── data
│   │       │   │   ├── Batcher.py
│   │       │   │   ├── BoolQReader.py
│   │       │   │   ├── CBReader.py
│   │       │   │   ├── COPAReader.py
│   │       │   │   ├── Dataset.py
│   │       │   │   ├── DatasetReader.py
│   │       │   │   ├── GenericReader.py
│   │       │   │   ├── MultiRCReader.py
│   │       │   │   ├── RTEReader.py
│   │       │   │   ├── RecordReader.py
│   │       │   │   ├── WSCReader.py
│   │       │   │   ├── WiCReader.py
│   │       │   │   └── tokenize.py
│   │       │   ├── dev.py
│   │       │   ├── eval
│   │       │   │   ├── Scorer.py
│   │       │   │   ├── Writer.py
│   │       │   │   └── eval_model.py
│   │       │   ├── run_pretrained.py
│   │       │   ├── scripts
│   │       │   │   └── example_convert_sst_2_generic.py
│   │       │   ├── test.py
│   │       │   ├── train.py
│   │       │   └── utils
│   │       │       ├── Config.py
│   │       │       └── util.py
│   │       └── utilcode.py
│   ├── create_summary_table.py
│   ├── perfect
│   │   └── README.md
│   ├── plot_summary_comparison.py
│   ├── setfit
│   │   ├── README.md
│   │   ├── distillation_baseline.py
│   │   ├── run_fewshot.py
│   │   ├── run_fewshot_distillation.py
│   │   ├── run_fewshot_multilabel.py
│   │   ├── run_fewshot_multilingual.py
│   │   └── run_zeroshot.py
│   ├── tfew
│   │   ├── README.md
│   │   ├── requirements.txt
│   │   ├── run_tfew_11b.sh
│   │   └── run_tfew_test.sh
│   └── transformers
│       ├── README.md
│       ├── requirements.txt
│       ├── run_fewshot.py
│       ├── run_fewshot_multilingual.py
│       ├── run_full.py
│       ├── run_full_multilingual.py
│       ├── run_inference.py
│       ├── run_zeroshot.py
│       └── utils.py
├── setup.cfg
├── setup.py
├── src
│   └── setfit
│       ├── __init__.py
│       ├── data.py
│       ├── exporters
│       │   ├── __init__.py
│       │   ├── onnx.py
│       │   ├── openvino.py
│       │   └── utils.py
│       ├── integrations.py
│       ├── logging.py
│       ├── losses.py
│       ├── model_card.py
│       ├── model_card_template.md
│       ├── modeling.py
│       ├── notebook.py
│       ├── sampler.py
│       ├── span
│       │   ├── __init__.py
│       │   ├── aspect_extractor.py
│       │   ├── modeling.py
│       │   └── trainer.py
│       ├── trainer.py
│       ├── trainer_distillation.py
│       ├── training_args.py
│       └── utils.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── exporters
│   │   ├── test_onnx.py
│   │   └── test_openvino.py
│   ├── model_card_pattern.py
│   ├── span
│   │   ├── __init__.py
│   │   ├── aspect_model_card_pattern.py
│   │   ├── polarity_model_card_pattern.py
│   │   ├── test_model_card.py
│   │   ├── test_modeling.py
│   │   └── test_trainer.py
│   ├── test_data.py
│   ├── test_deprecated_trainer.py
│   ├── test_deprecated_trainer_distillation.py
│   ├── test_model_card.py
│   ├── test_modeling.py
│   ├── test_sampler.py
│   ├── test_trainer.py
│   ├── test_trainer_distillation.py
│   ├── test_training_args.py
│   └── utils.py
└── utils
    ├── create_notebook_table.py
    └── release.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | # Configuration file to control (pytest) coverage
2 | [run]
3 | # Run branch coverage, too
4 | branch = True
5 |
6 | [paths]
7 | source =
8 | src/setfit
9 |
10 | [report]
11 | # Regexes for lines to exclude from consideration
12 | exclude_lines =
13 | # Have to re-enable the standard pragma
14 | pragma: no cover
15 |
16 | # Don't complain about missing debug-only code:
17 | def __repr__
18 | if self\.debug
19 |
20 | # Don't complain if tests don't hit defensive assertion code:
21 | raise AssertionError
22 | raise NotImplementedError
23 |
24 | # Don't complain if non-runnable code isn't run:
25 | if 0:
26 | if __name__ == .__main__.:
27 |
28 | # Don't complain about abstract methods, they aren't run:
29 | @(abc\.)?abstractmethod
30 |
31 | # Ignore TYPE_CHECKING code
32 | if TYPE_CHECKING:
33 |
34 | [html]
35 | directory = coverage_report_html
36 | title = SetFit coverage report
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - doc-builder*
8 | - v*-release
9 |
10 | jobs:
11 | build:
12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
13 | with:
14 | commit_sha: ${{ github.sha }}
15 | package: setfit
16 | notebook_folder: setfit_doc
17 | languages: en
18 | secrets:
19 | token: ${{ secrets.HUGGINGFACE_PUSH }}
20 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
--------------------------------------------------------------------------------
/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 |
3 | on:
4 | pull_request:
5 |
6 | concurrency:
7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
8 | cancel-in-progress: true
9 |
10 | jobs:
11 | build:
12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
13 | with:
14 | commit_sha: ${{ github.event.pull_request.head.sha }}
15 | pr_number: ${{ github.event.number }}
16 | package: setfit
17 | languages: en
--------------------------------------------------------------------------------
/.github/workflows/delete_doc_comment_trigger.yml:
--------------------------------------------------------------------------------
1 | name: Delete doc comment trigger
2 |
3 | on:
4 | pull_request:
5 | types: [ closed ]
6 |
7 |
8 | jobs:
9 | delete:
10 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
11 | with:
12 | pr_number: ${{ github.event.number }}
--------------------------------------------------------------------------------
/.github/workflows/quality.yml:
--------------------------------------------------------------------------------
1 | name: Quality
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - v*-release
8 | - v*-pre
9 | pull_request:
10 | branches:
11 | - main
12 | - v*-pre
13 | workflow_dispatch:
14 |
15 | jobs:
16 |
17 | check_code_quality:
18 | name: Check code quality
19 | runs-on: ubuntu-latest
20 | steps:
21 | - name: Checkout code
22 | uses: actions/checkout@v2
23 | - name: Setup Python environment
24 | uses: actions/setup-python@v2
25 | with:
26 | python-version: 3.9
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | python -m pip install ".[quality]"
31 | - name: Code quality
32 | run: |
33 | make quality
34 |
35 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Unit tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - v*-release
8 | - v*-pre
9 | pull_request:
10 | branches:
11 | - main
12 | - v*-pre
13 | workflow_dispatch:
14 |
15 | env:
16 | TRANSFORMERS_IS_CI: 1
17 |
18 | jobs:
19 |
20 | test_sampling:
21 | name: Run unit tests
22 | strategy:
23 | matrix:
24 | python-version: ['3.9', '3.10', '3.11', '3.12']
25 | os: [ubuntu-latest, windows-latest]
26 | requirements: ['.[tests]', '.[compat_tests]']
27 | fail-fast: false
28 | runs-on: ${{ matrix.os }}
29 | steps:
30 | - name: Checkout code
31 | uses: actions/checkout@v3
32 |
33 | - name: Setup Python environment
34 | uses: actions/setup-python@v4
35 | with:
36 | python-version: ${{ matrix.python-version }}
37 |
38 | # Taken from https://github.com/actions/cache?tab=readme-ov-file#creating-a-cache-key
39 | # Use date to invalidate cache every week
40 | - name: Get date
41 | id: get-date
42 | run: |
43 | echo "date=$(/bin/date -u "+%G%V")" >> $GITHUB_OUTPUT
44 | shell: bash
45 |
46 | - name: Try to load cached dependencies
47 | uses: actions/cache@v3
48 | id: restore-cache
49 | with:
50 | path: ${{ env.pythonLocation }}
51 | key: python-dependencies-${{ matrix.os }}-${{ steps.get-date.outputs.date }}-${{ matrix.python-version }}-${{ matrix.requirements }}-${{ hashFiles('setup.py') }}-${{ env.pythonLocation }}
52 |
53 | - name: Install external dependencies on cache miss
54 | run: |
55 | python -m pip install --no-cache-dir --upgrade pip
56 | python -m pip install --no-cache-dir ${{ matrix.requirements }}
57 | python -m pip install '.[codecarbon]'
58 | python -m spacy download en_core_web_lg
59 | python -m spacy download en_core_web_sm
60 | if: steps.restore-cache.outputs.cache-hit != 'true'
61 |
62 | - name: Install the checked-out setfit
63 | run: python -m pip install .
64 |
65 | - name: Restore HF models from cache
66 | uses: actions/cache/restore@v3
67 | with:
68 | path: |
69 | ~/.cache/huggingface/hub
70 | ~/.cache/torch
71 | key: hf-models-${{ matrix.os }}-${{ env.NEW_HF_CACHE_HASH }}
72 | restore-keys: |
73 | hf-models-${{ matrix.os }}-
74 |
75 | - name: Run unit tests
76 | shell: bash
77 | run: |
78 | echo "OLD_HF_CACHE_HASH=$(find ~/.cache/huggingface/hub ~/.cache/torch -type f -exec sha256sum {} + | LC_ALL=C sort | sha256sum | cut -d ' ' -f 1)" >> $GITHUB_ENV
79 | pytest -v tests/
80 | echo "NEW_HF_CACHE_HASH=$(find ~/.cache/huggingface/hub ~/.cache/torch -type f -exec sha256sum {} + | LC_ALL=C sort | sha256sum | cut -d ' ' -f 1)" >> $GITHUB_ENV
81 |
82 | - name: Save new HF models to cache
83 | uses: actions/cache/save@v3
84 | with:
85 | path: |
86 | ~/.cache/huggingface/hub
87 | ~/.cache/torch
88 | key: hf-models-${{ matrix.os }}-${{ env.NEW_HF_CACHE_HASH }}
89 | # Only save cache if the hash has changed
90 | if: env.NEW_HF_CACHE_HASH != env.OLD_HF_CACHE_HASH
91 |
--------------------------------------------------------------------------------
/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Build PR Documentation"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | build:
11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 | with:
13 | package_name: setfit
14 | secrets:
15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # pycharm
2 | .idea
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 | *.pyc
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 | launch.json
135 |
136 | # VS Code
137 | .history
138 | .vscode/launch.json
139 |
140 | # Temporary results files
141 | scripts/**/results.json
142 | scripts/**/checkpoints
143 | scripts/**/train_script.py
144 | scripts/**/summary_table.csv
145 | scripts/tfew/t-few
146 | scripts/tfew/results
147 | scripts/tfew/run_tmux.sh
148 |
149 | # macOS
150 | .DS_Store
151 | .vscode/settings.json
152 |
153 | # Common SetFit Trainer logging folders
154 | wandb
155 | runs/
156 | notebooks/perf_metrics.json
157 | notebooks/nc_workspace/
158 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include src/setfit/model_card_template.md
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: notebooks
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := scripts tests src setup.py
7 |
8 | style:
9 | python -m black --line-length 119 --exclude="scripts/adapet|scripts/tfew" --target-version py39 $(check_dirs)
10 | python -m isort --skip scripts/adapet --skip scripts/tfew $(check_dirs)
11 |
12 | quality:
13 | python -m black --check --line-length 119 --exclude="scripts/adapet|scripts/tfew" --target-version py39 $(check_dirs)
14 | python -m isort --check-only --skip scripts/adapet --skip scripts/tfew $(check_dirs)
15 | python -m flake8 --max-line-length 119 $(check_dirs)
16 |
17 | test:
18 | python -m pytest -sv tests/
19 |
20 | coverage:
21 | python -m pytest --cov=src --cov-report=term-missing -sv tests/
22 |
23 | notebooks:
24 | python utils/create_notebook_table.py
25 |
26 | # Release stuff
27 |
28 | pre-release:
29 | python utils/release.py
30 |
31 | pre-patch:
32 | python utils/release.py --patch
33 |
34 | post-release:
35 | python utils/release.py --post_release
36 |
37 | post-patch:
38 | python utils/release.py --post_release --patch
39 |
40 | wheels:
41 | python setup.py bdist_wheel && python setup.py sdist
42 |
43 | wheels_clean:
44 | rm -rf build && rm -rf dist
45 |
46 | pypi_upload:
47 | python -m pip install twine
48 | twine upload dist/* -r pypi
49 |
50 | pypi_test_upload:
51 | python -m pip install twine
52 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
53 |
54 | pypi_test_install:
55 | python -m pip install evaluate==0.3.0 datasets==2.3.2 sentence_transformers==2.2.2
56 | python -m pip install -i https://testpypi.python.org/pypi setfit
57 | python -c "from setfit import *"
58 | echo "🚀 Successfully installed setfit from test.pypi.org"
59 |
--------------------------------------------------------------------------------
/assets/setfit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/assets/setfit.png
--------------------------------------------------------------------------------
/docs/source/_config.py:
--------------------------------------------------------------------------------
1 | # docstyle-ignore
2 | INSTALL_CONTENT = """
3 | # SetFit installation
4 | ! pip install setfit
5 | # To install from source instead of the last release, comment the command above and uncomment the following one.
6 | # ! pip install git+https://github.com/huggingface/setfit.git
7 | """
8 |
9 | notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
--------------------------------------------------------------------------------
/docs/source/en/_toctree.yml:
--------------------------------------------------------------------------------
1 | - sections:
2 | - local: index
3 | title: SetFit
4 | - local: quickstart
5 | title: Quickstart
6 | - local: installation
7 | title: Installation
8 | title: Get started
9 |
10 | - sections:
11 | - local: tutorials/overview
12 | title: Overview
13 | - local: tutorials/zero_shot
14 | title: Zero-shot Text Classification
15 | - local: tutorials/onnx
16 | title: Efficiently run SetFit with ONNX
17 | title: Tutorials
18 |
19 | - sections:
20 | - local: how_to/overview
21 | title: Overview
22 | - local: how_to/callbacks
23 | title: Callbacks
24 | - local: how_to/model_cards
25 | title: Model Cards
26 | - local: how_to/classification_heads
27 | title: Classification Heads
28 | - local: how_to/multilabel
29 | title: Multilabel Text Classification
30 | - local: how_to/zero_shot
31 | title: Zero-shot Text Classification
32 | - local: how_to/hyperparameter_optimization
33 | title: Hyperparameter Optimization
34 | - local: how_to/knowledge_distillation
35 | title: Knowledge Distillation
36 | - local: how_to/batch_sizes
37 | title: Batch Sizes for Inference
38 | - local: how_to/absa
39 | title: Aspect Based Sentiment Analysis
40 | - local: how_to/v1.0.0_migration_guide
41 | title: v1.0.0 Migration Guide
42 | title: How-to Guides
43 |
44 | - sections:
45 | - local: conceptual_guides/setfit
46 | title: SetFit
47 | - local: conceptual_guides/sampling_strategies
48 | title: Sampling Strategies
49 | title: Conceptual Guides
50 |
51 | - sections:
52 | - local: reference/main
53 | title: Main classes
54 | - local: reference/trainer
55 | title: Trainer classes
56 | - local: reference/utility
57 | title: Utility
58 | title: Reference
--------------------------------------------------------------------------------
/docs/source/en/conceptual_guides/sampling_strategies.mdx:
--------------------------------------------------------------------------------
1 |
2 | # SetFit Sampling Strategies
3 |
4 | SetFit supports various contrastive pair sampling strategies in [`TrainingArguments`]. In this conceptual guide, we will learn about the following four sampling strategies:
5 |
6 | 1. `"oversampling"` (the default)
7 | 2. `"undersampling"`
8 | 3. `"unique"`
9 | 4. `"num_iterations"`
10 |
11 | Consider first reading the [SetFit conceptual guide](../setfit) for a background on contrastive learning and positive & negative pairs.
12 |
13 | ## Running example
14 |
15 | Throughout this conceptual guide, we will use the following example scenario:
16 |
17 | * 3 classes: "happy", "content", and "sad".
18 | * 20 total samples: 8 "happy", 4 "content", and 8 "sad" samples.
19 |
20 | Considering that the sentence pairs `(X, Y)` and `(Y, X)` result in the same embedding distance/loss, we only want to consider one of those two cases. Furthermore, we don't want pairs where both sentences are the same, e.g. no `(X, X)`.
21 |
22 | The resulting positive and negative pairs can be visualized in a table like below. The `+` and `-` represent positive and negative pairs, respectively. Furthermore, `h-n` represents the n-th "happy" sentence, `c-n` the n-th "content" sentence, and `s-n` the n-th "sad" sentence. Note that the area below the diagonal is not used as `(X, Y)` and `(Y, X)` result in the same embedding distances, and that the diagonal is not used as we are not interested in pairs where both sentences are identical.
23 |
24 | | |h-1|h-2|h-3|h-4|h-5|h-6|h-7|h-8|c-1|c-2|c-3|c-4|s-1|s-2|s-3|s-4|s-5|s-6|s-7|s-8|
25 | |-------|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
26 | |**h-1**| | + | + | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
27 | |**h-2**| | | + | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
28 | |**h-3**| | | | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
29 | |**h-4**| | | | | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
30 | |**h-5**| | | | | | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
31 | |**h-6**| | | | | | | + | + | - | - | - | - | - | - | - | - | - | - | - | - |
32 | |**h-7**| | | | | | | | + | - | - | - | - | - | - | - | - | - | - | - | - |
33 | |**h-8**| | | | | | | | | - | - | - | - | - | - | - | - | - | - | - | - |
34 | |**c-1**| | | | | | | | | | + | + | + | - | - | - | - | - | - | - | - |
35 | |**c-2**| | | | | | | | | | | + | + | - | - | - | - | - | - | - | - |
36 | |**c-3**| | | | | | | | | | | | + | - | - | - | - | - | - | - | - |
37 | |**c-4**| | | | | | | | | | | | | - | - | - | - | - | - | - | - |
38 | |**s-1**| | | | | | | | | | | | | | + | + | + | + | + | + | + |
39 | |**s-2**| | | | | | | | | | | | | | | + | + | + | + | + | + |
40 | |**s-3**| | | | | | | | | | | | | | | | + | + | + | + | + |
41 | |**s-4**| | | | | | | | | | | | | | | | | + | + | + | + |
42 | |**s-5**| | | | | | | | | | | | | | | | | | + | + | + |
43 | |**s-6**| | | | | | | | | | | | | | | | | | | + | + |
44 | |**s-7**| | | | | | | | | | | | | | | | | | | | + |
45 | |**s-8**| | | | | | | | | | | | | | | | | | | | |
46 |
47 | As shown in the prior table, we have 28 positive pairs for "happy", 6 positive pairs for "content", and another 28 positive pairs for "sad". In total, this is 62 positive pairs. Also, we have 32 negative pairs between "happy" and "content", 64 negative pairs between "happy" and "sad", and 32 negative pairs between "content" and "sad". In total, this is 128 negative pairs.
48 |
49 | ## Oversampling
50 |
51 | By default, SetFit applies the oversampling strategy for its contrastive pairs. This strategy samples an equal amount of positive and negative training pairs, oversampling the minority pair type to match that of the majority pair type. As the number of negative pairs is generally larger than the number of positive pairs, this usually involves oversampling the positive pairs.
52 |
53 | In our running example, this would involve oversampling the 62 positive pairs up to 128, resulting in one epoch of 128 + 128 = 256 pairs. In summary:
54 |
55 | * ✅ An equal amount of positive and negative pairs are sampled.
56 | * ✅ Every possible pair is used.
57 | * ❌ There is some data duplication.
58 |
59 | ## Undersampling
60 |
61 | Like oversampling, this strategy samples an equal amount of positive and negative training pairs. However, it undersamples the majority pair type to match that of the minority pair type. This usually involves undersampling the negative pairs to match the positive pairs.
62 |
63 | In our running example, this would involve undersampling the 128 negative pairs down to 62, resulting in one epoch of 62 + 62 = 124 pairs. In summary:
64 |
65 | * ✅ An equal amount of positive and negative pairs are sampled.
66 | * ❌ **Not** every possible pair is used.
67 | * ✅ There is **no** data duplication.
68 |
69 | ## Unique
70 |
71 | Thirdly, the unique strategy does not sample an equal amount of positive and negative training pairs. Instead, it simply samples all possible pairs exactly once. No form of oversampling or undersampling is used here.
72 |
73 | In our running example, this would involve sampling all negative and positive pairs, resulting in one epoch of 62 + 128 = 190 pairs. In summary:
74 |
75 | * ❌ **Not** an equal amount of positive and negative pairs are sampled.
76 | * ✅ Every possible pair is used.
77 | * ✅ There is **no** data duplication.
78 |
79 | ## `num_iterations`
80 |
81 | Lastly, SetFit can still be used with a deprecated sampling strategy involving the `num_iterations` training argument. Unlike the other sampling strategies, this strategy does not involve the number of possible pairs. Instead, it samples `num_iterations` positive pairs and `num_iterations` negative pairs for each training sample.
82 |
83 | In our running example, if we assume `num_iterations=20`, then we would sample 20 positive pairs and 20 negative pairs per training sample. Because there are 20 samples, this involves (20 + 20) * 20 = 800 pairs. Because there are only 190 unique pairs, this certainly involves some data duplication. However, it does not guarantee that every possible pair is used. In summary:
84 |
85 | * ✅ An equal amount of positive and negative pairs are sampled.
86 | * ❌ Not necessarily every possible pair is used.
87 | * ❌ There is some data duplication.
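88 |
89 | As a minimal illustration, each of these strategies is selected through the [`TrainingArguments`] discussed above; only the `sampling_strategy` and `num_iterations` arguments are involved:
90 |
91 | ```py
92 | from setfit import TrainingArguments
93 |
94 | # Default: oversample the minority pair type (usually the positive pairs)
95 | # so that there are as many positive as negative pairs.
96 | args = TrainingArguments(sampling_strategy="oversampling")
97 |
98 | # Undersample the majority pair type (usually the negative pairs) instead.
99 | args = TrainingArguments(sampling_strategy="undersampling")
100 |
101 | # Sample every possible pair exactly once, without balancing.
102 | args = TrainingArguments(sampling_strategy="unique")
103 |
104 | # Deprecated: sample 20 positive and 20 negative pairs per training sample.
105 | args = TrainingArguments(num_iterations=20)
106 | ```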
--------------------------------------------------------------------------------
/docs/source/en/conceptual_guides/setfit.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Sentence Transformers Finetuning (SetFit)
3 |
4 | SetFit is a model framework to efficiently train text classification models with surprisingly little training data. For example, with only 8 labeled examples per class on the Customer Reviews (CR) sentiment dataset, SetFit is competitive with fine-tuning RoBERTa Large on the full training set of 3k examples. Furthermore, SetFit is fast to train and run inference with, and can easily support multilingual tasks.
5 |
6 | Every SetFit model consists of two parts: a **sentence transformer** embedding model (the body) and a **classifier** (the head). These two parts are trained in two separate phases: the **embedding finetuning phase** and the **classifier training phase**. This conceptual guide will elaborate on the intuition behind these phases, and why SetFit works so well.
7 |
8 | ## Embedding finetuning phase
9 |
10 | The first phase has one primary goal: finetune a sentence transformer embedding model to produce useful embeddings for *our* classification task. The [Hugging Face Hub](https://huggingface.co/models?library=sentence-transformers) already has thousands of sentence transformers available, many of which have been trained to very accurately group the embeddings of texts with similar semantic meaning.
11 |
12 | However, models that are good at Semantic Textual Similarity (STS) are not necessarily immediately good at *our* classification task. For example, according to an embedding model, sentence 1) `"He biked to work."` will be much more similar to 2) `"He drove his car to work."` than to 3) `"Peter decided to take the bicycle to the beach party!"`. But if our classification task involves classifying texts into transportation modes, then we want our embedding model to place sentences 1 and 3 close together, and 2 further away.
13 |
14 | To do so, we can finetune the chosen sentence transformer embedding model. The goal here is to nudge the model to use its pretrained knowledge in a different way that better aligns with our classification task, rather than making it completely forget what it has learned.
15 |
16 | For finetuning, SetFit uses **contrastive learning**. This training approach involves creating **positive and negative pairs** of sentences. A sentence pair will be positive if both of the sentences are of the same class, and negative otherwise. For example, in the case of binary "positive"-"negative" sentiment analysis, `("The movie was awesome", "I loved it")` is a positive pair, and `("The movie was awesome", "It was quite disappointing")` is a negative pair.
17 |
18 | During training, the embedding model receives these pairs, and will convert the sentences to embeddings. If the pair is positive, then it will pull on the model weights such that the text embeddings will be more similar, and vice versa for a negative pair. Through this approach, sentences with the same label will be embedded more similarly, and sentences with different labels less similarly.
19 |
20 | Conveniently, this contrastive learning works with pairs rather than individual samples, and we can create plenty of unique pairs from just a few samples. For example, given 8 positive sentences and 8 negative sentences, we can create 28 + 28 = 56 positive pairs and 8 * 8 = 64 negative pairs, for 120 unique training pairs. This number grows quadratically with the number of sentences, and that is why SetFit can train with just a few examples and still correctly finetune the sentence transformer embedding model. However, we should still be wary of overfitting.
21 |
22 | ## Classifier training phase
23 |
24 | Once the sentence transformer embedding model has been finetuned for our task at hand, we can start training the classifier. This phase has one primary goal: create a good mapping from the sentence transformer embeddings to the classes.
25 |
26 | Unlike with the first phase, training the classifier is done from scratch and using the labeled samples directly, rather than using pairs. By default, the classifier is a simple **logistic regression** classifier from scikit-learn. First, all training sentences are fed through the now-finetuned sentence transformer embedding model, and then the sentence embeddings and labels are used to fit the logistic regression classifier. The result is a strong and efficient classifier.
27 |
28 | Using these two parts, SetFit models are efficient, performant and easy to train, even on CPU-only devices.
29 |
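30 | To make these two phases concrete, below is a minimal sketch that mimics them outside of SetFit, assuming only `sentence-transformers` and `scikit-learn` are installed. The SetFit [`Trainer`] automates all of this, including the contrastive finetuning itself, which is only hinted at here by the pair generation:
31 |
32 | ```py
33 | from itertools import combinations
34 |
35 | from sentence_transformers import SentenceTransformer
36 | from sklearn.linear_model import LogisticRegression
37 |
38 | texts = ["Great movie", "I loved it", "What a film", "Fantastic cast",
39 |          "Boring plot", "Terrible acting", "Awful pacing", "A waste of time"]
40 | labels = [1, 1, 1, 1, 0, 0, 0, 0]
41 |
42 | # Phase 1 (pair generation only): positive pairs share a label, negative pairs do not.
43 | pairs = list(combinations(range(len(texts)), 2))
44 | positive_pairs = [(i, j) for i, j in pairs if labels[i] == labels[j]]  # 6 + 6 = 12 pairs
45 | negative_pairs = [(i, j) for i, j in pairs if labels[i] != labels[j]]  # 4 * 4 = 16 pairs
46 |
47 | # Phase 2 (classifier training): embed the texts with the (ideally finetuned)
48 | # sentence transformer body and fit a logistic regression head on the embeddings.
49 | body = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
50 | head = LogisticRegression().fit(body.encode(texts), labels)
51 | print(head.predict(body.encode(["I would happily watch it again"])))
52 | ```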
--------------------------------------------------------------------------------
/docs/source/en/how_to/batch_sizes.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Batch sizes for Inference
3 | In this how-to guide we will explore the effects of increasing the batch size in [`SetFitModel.predict`].
4 |
5 | ## What are they?
6 | When processing on GPUs, often not all of the data fits in the GPU's VRAM at once. As a result, the data is split up into **batches** of some (often pre-determined) batch size. This is done both during training and during inference. In both scenarios, increasing the batch size often has notable consequences for processing efficiency and VRAM memory usage, as transferring data to and from the GPU can be relatively slow.
7 |
8 | For inference, it is often recommended to set the batch size high to get notably quicker processing speeds.
9 |
10 | ## In SetFit
11 | The batch size for inference in SetFit defaults to 32, but it can be changed by passing a `batch_size` argument to [`SetFitModel.predict`]. For example, on an RTX 3090 with a SetFit model based on the [paraphrase-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-mpnet-base-v2) Sentence Transformer, the following throughputs are reached:
12 |
13 | 
14 |
15 |
16 |
17 | Each sentence consists of 11 words in this experiment.
18 |
19 |
20 |
21 | The default batch size of 32 does not result in the highest possible throughput on this hardware. Consider experimenting with the batch size to reach your highest possible throughput.
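22 |
23 | For example, a minimal sketch of passing a larger batch size (the model path is a placeholder for any trained SetFit checkpoint):
24 |
25 | ```py
26 | from setfit import SetFitModel
27 |
28 | model = SetFitModel.from_pretrained("path/to/your-trained-setfit-model")  # placeholder
29 |
30 | sentences = ["I loved the new movie!", "The plot was rather predictable."] * 500
31 | # Larger batches usually mean higher GPU throughput, at the cost of more VRAM.
32 | preds = model.predict(sentences, batch_size=128)
33 | ```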
--------------------------------------------------------------------------------
/docs/source/en/how_to/callbacks.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Callbacks
3 | SetFit models can be influenced by callbacks, for example for logging or early stopping.
4 |
5 | This guide will show you what they are and how they can be used.
6 |
7 | ## Callbacks in SetFit
8 |
9 | Callbacks are objects that customize the behaviour of the training loop in the SetFit [`Trainer`]. They can inspect the training loop state (for progress reporting, logging, or inspecting embeddings during training) and take decisions (e.g. early stopping).
10 |
11 | In particular, the [`Trainer`] uses a [`TrainerControl`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerControl) that can be influenced by callbacks to stop training, save models, evaluate, or log, and a [`TrainerState`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerState) which tracks some training loop metrics during training, such as the number of training steps so far.
12 |
13 | SetFit relies on the Callbacks implemented in `transformers`, as described in the `transformers` documentation [here](https://huggingface.co/docs/transformers/main_classes/callback).
14 |
15 | ## Default Callbacks
16 |
17 | SetFit uses the `TrainingArguments.report_to` argument to specify which of the built-in callbacks should be enabled. This argument defaults to `"all"`, meaning that all third-party callbacks from `transformers` that are also installed will be enabled, for example the [`TensorBoardCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.integrations.TensorBoardCallback) or the [`WandbCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.integrations.WandbCallback).
18 |
19 | Beyond that, the [`PrinterCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.PrinterCallback) or [`ProgressCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.ProgressCallback) is always enabled to show the training progress, and [`DefaultFlowCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.DefaultFlowCallback) is also always enabled to properly update the `TrainerControl`.
20 |
21 | ## Using Callbacks
22 |
23 | As mentioned, you can use `TrainingArguments.report_to` to specify exactly which callbacks you would like to enable. For example:
24 |
25 | ```py
26 | from setfit import TrainingArguments
27 |
28 | args = TrainingArguments(
29 | ...,
30 | report_to="wandb",
31 | ...,
32 | )
33 | # or
34 | args = TrainingArguments(
35 | ...,
36 | report_to=["wandb", "tensorboard"],
37 | ...,
38 | )
39 | ```
40 | You can also use [`Trainer.add_callback`], [`Trainer.pop_callback`] and [`Trainer.remove_callback`] to influence the trainer callbacks, and you can specify callbacks via the [`Trainer`] init, e.g.:
41 |
42 | ```py
43 | from transformers import EarlyStoppingCallback
44 | from setfit import Trainer
45 | ...
46 |
47 | trainer = Trainer(
48 | model,
49 | args=args,
50 | train_dataset=train_dataset,
51 | eval_dataset=eval_dataset,
52 | callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
53 | )
54 | trainer.train()
55 | ```
56 |
57 | ## Custom Callbacks
58 |
59 | SetFit supports custom callbacks in the same way that `transformers` does: by subclassing [`TrainerCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerCallback). This class implements a lot of `on_...` methods that can be overridden. For example, the following script shows a custom callback that saves plots of the tSNE of the training and evaluation embeddings during training.
60 |
61 | ```py
62 | import os
63 | import matplotlib.pyplot as plt
64 | from sklearn.manifold import TSNE
65 | from transformers import TrainerCallback, TrainerControl, TrainerState
66 | from setfit import SetFitModel, TrainingArguments
67 |
68 | class EmbeddingPlotCallback(TrainerCallback):
69 |     """Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
70 |     def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
71 |         os.makedirs("logs", exist_ok=True)
72 |
73 |     def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
74 |         train_embeddings = model.encode(train_dataset["text"])
75 |         eval_embeddings = model.encode(eval_dataset["text"])
76 |
77 |         fig, (train_ax, eval_ax) = plt.subplots(ncols=2)
78 |
79 |         train_X = TSNE(n_components=2).fit_transform(train_embeddings)
80 |         train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
81 |         train_ax.set_title("Training embeddings")
82 |
83 |         eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
84 |         eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
85 |         eval_ax.set_title("Evaluation embeddings")
86 |
87 |         fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
88 |         fig.savefig(f"logs/step_{state.global_step}.png")
89 | ```
90 |
91 | This callback can then be passed to the [`Trainer`]:
92 |
93 | ```py
94 | trainer = Trainer(
95 |     model=model,
96 |     args=args,
97 |     train_dataset=train_dataset,
98 |     eval_dataset=eval_dataset,
99 |     callbacks=[EmbeddingPlotCallback()]
100 | )
101 | trainer.train()
102 | ```
103 |
104 | The `on_evaluate` from `EmbeddingPlotCallback` will be triggered on every single evaluation call. In the case of this example, it resulted in the following figures being plotted:
105 |
106 | | Step 20 | Step 40 |
107 | |-------------|-------------|
108 | |  |  |
109 | | **Step 60** | **Step 80** |
110 | |  |  |
/docs/source/en/how_to/multilabel.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Multilabel Text Classification
3 |
4 | SetFit supports multilabel classification, allowing multiple labels to be assigned to each instance.
5 |
6 |
7 |
8 | Unless your instances can be assigned multiple labels at once, you typically do not need to specify a multi target strategy.
9 |
10 |
11 |
12 | This guide will show you how to train and use multilabel SetFit models.
13 |
14 | ## Multilabel strategies
15 |
16 | SetFit will initialise a multilabel classification head from `sklearn` - the following options are available for `multi_target_strategy`:
17 |
18 | * `"one-vs-rest"`: uses a [`OneVsRestClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) head.
19 | * `"multi-output"`: uses a [`MultiOutputClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html) head.
20 | * `"classifier-chain"`: uses a [`ClassifierChain`](https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.ClassifierChain.html) head.
21 |
22 | See the [scikit-learn documentation for multiclass and multioutput classification](https://scikit-learn.org/stable/modules/multiclass.html#multiclass-classification) for more details.
23 |
24 | ## Initializing SetFit models with multilabel strategies
25 |
26 | Using the default LogisticRegression head, we can apply multi target strategies like so:
27 |
28 | ```py
29 | from setfit import SetFitModel
30 |
31 | model = SetFitModel.from_pretrained(
32 | model_id, # e.g. "BAAI/bge-small-en-v1.5"
33 | multi_target_strategy="multi-output",
34 | )
35 | ```
36 |
37 | With a differentiable head it looks like so:
38 |
39 | ```py
40 | from setfit import SetFitModel
41 |
42 | model = SetFitModel.from_pretrained(
43 | model_id, # e.g. "BAAI/bge-small-en-v1.5"
44 |     multi_target_strategy="one-vs-rest",
45 | use_differentiable_head=True,
46 | head_params={"out_features": num_classes},
47 | )
48 | ```
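49 |
50 | For completeness, here is a minimal end-to-end sketch with hypothetical data. With a multi target strategy, each label is a binary vector with one entry per class:
51 |
52 | ```py
53 | from datasets import Dataset
54 | from setfit import SetFitModel, Trainer
55 |
56 | # Hypothetical data; each label vector is [food, service].
57 | train_dataset = Dataset.from_dict({
58 |     "text": [
59 |         "The pasta was delicious",
60 |         "We waited an hour before anyone took our order",
61 |         "Great food, but the staff was rude",
62 |     ],
63 |     "label": [[1, 0], [0, 1], [1, 1]],
64 | })
65 |
66 | model = SetFitModel.from_pretrained(
67 |     "BAAI/bge-small-en-v1.5",
68 |     multi_target_strategy="multi-output",
69 | )
70 | trainer = Trainer(model=model, train_dataset=train_dataset)
71 | trainer.train()
72 |
73 | preds = model.predict(["The soup was cold, and so was the service"])
74 | # => one 0/1 prediction per class, e.g. [[1, 1]]
75 | ```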
--------------------------------------------------------------------------------
/docs/source/en/how_to/overview.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Overview
3 |
4 | Welcome to the SetFit How-to Guides! The how-to guides offer a more comprehensive overview of all the tools 🤗 SetFit offers and how to use them.
5 | These guides are designed to be concise and code-heavy, written in a "show, don't tell" style. For example, these guides show you how to perform hyperparameter optimization, apply knowledge distillation, use callbacks, etc.
6 |
7 | Most how-to guides end with an "end to end" script showing all code from the guide for easy adaptation into your own code.
8 |
9 | For simpler documentation explaining SetFit functionality from start to finish, consider visiting the [Tutorials](../tutorials/overview) section or the [quickstart](../quickstart).
10 |
--------------------------------------------------------------------------------
/docs/source/en/how_to/zero_shot.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Zero-shot Text Classification
3 |
4 | [[open-in-colab]]
5 |
6 | Your class names are likely already good descriptors of the text that you're looking to classify. With 🤗 SetFit, you can use these class names together with strong pretrained Sentence Transformer models to get a solid baseline model without any training samples.
7 |
8 | This guide will show you how to perform zero-shot text classification.
9 |
10 | ## Testing dataset
11 |
12 | We'll use the [dair-ai/emotion](https://huggingface.co/datasets/dair-ai/emotion) dataset to test the performance of our zero-shot model.
13 |
14 | ```py
15 | from datasets import load_dataset
16 |
17 | test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
18 | ```
19 |
20 | This dataset stores the class names within the dataset `Features`, so we'll extract the classes like so:
21 | ```py
22 | classes = test_dataset.features["label"].names
23 | # => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
24 | ```
25 | Otherwise, we could manually set the list of classes.
26 |
27 | ## Synthetic dataset
28 |
29 | Then, we can use [`get_templated_dataset`] to synthetically generate a dummy dataset given these class names.
30 |
31 | ```py
32 | from setfit import get_templated_dataset
33 |
34 | train_dataset = get_templated_dataset(candidate_labels=classes, sample_size=8)
35 | ```
36 | ```py
37 | print(train_dataset)
38 | # => Dataset({
39 | # features: ['text', 'label'],
40 | # num_rows: 48
41 | # })
42 | print(train_dataset[0])
43 | # {'text': 'This sentence is sadness', 'label': 0}
44 | ```
45 |
46 | ## Training
47 |
48 | We can use this dataset to train a SetFit model just like normal:
49 |
50 | ```py
51 | from setfit import SetFitModel, Trainer, TrainingArguments
52 |
53 | model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")
54 |
55 | args = TrainingArguments(
56 | batch_size=32,
57 | num_epochs=1,
58 | )
59 |
60 | trainer = Trainer(
61 | model=model,
62 | args=args,
63 | train_dataset=train_dataset,
64 | eval_dataset=test_dataset,
65 | )
66 | trainer.train()
67 | ```
68 | ```
69 | ***** Running training *****
70 | Num examples = 60
71 | Num epochs = 1
72 | Total optimization steps = 60
73 | Total train batch size = 32
74 | {'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}
75 | {'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}
76 | {'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}
77 | 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00, 6.35it/s]
78 | ```
79 |
80 | Once trained, we can evaluate the model:
81 |
82 | ```py
83 | metrics = trainer.evaluate()
84 | print(metrics)
85 | ```
86 | ```
87 | ***** Running evaluation *****
88 | {'accuracy': 0.591}
89 | ```
90 |
91 | And run predictions:
92 |
93 | ```py
94 | preds = model.predict([
95 | "i am just feeling cranky and blue",
96 | "i feel incredibly lucky just to be able to talk to her",
97 | "you're pissing me off right now",
98 | "i definitely have thalassophobia, don't get me near water like that",
99 | "i did not see that coming at all",
100 | ])
101 | print([classes[idx] for idx in preds])
102 | ```
103 | ```py
104 | ['sadness', 'joy', 'anger', 'fear', 'surprise']
105 | ```
106 |
107 | These predictions all look right!
108 |
109 | ## Baseline
110 |
111 | To show that the zero-shot performance of SetFit works well, we'll compare it against a zero-shot classification model from `transformers`.
112 |
113 | ```py
114 | from transformers import pipeline
115 | from datasets import load_dataset
116 | import evaluate
117 |
118 | # Prepare the testing dataset
119 | test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
120 | classes = test_dataset.features["label"].names
121 |
122 | # Set up the zero-shot classification pipeline from transformers
123 | # Uses 'facebook/bart-large-mnli' by default
124 | pipe = pipeline("zero-shot-classification", device=0)
125 | zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes)
126 | preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds]
127 |
128 | # Compute the accuracy
129 | metric = evaluate.load("accuracy")
130 | transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"])
131 | print(transformers_accuracy)
132 | ```
133 | ```py
134 | {'accuracy': 0.3765}
135 | ```
136 |
137 | With its 59.1% accuracy, zero-shot SetFit heavily outperforms the 37.65% accuracy of the default zero-shot model from `transformers`.
138 |
139 | ## Prediction latency
140 |
141 | Beyond getting higher accuracies, SetFit is much faster too. Let's compute the latency of SetFit with `BAAI/bge-small-en-v1.5` versus the latency of `transformers` with `facebook/bart-large-mnli`. Both tests were performed on a GPU.
142 |
143 | ```py
144 | import time
145 |
146 | start_t = time.time()
147 | pipe(test_dataset["text"], batch_size=32, candidate_labels=classes)
148 | delta_t = time.time() - start_t
149 | print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
150 | ```
151 | ```
152 | `transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence
153 | ```
154 |
155 | ```py
156 | import time
157 |
158 | start_t = time.time()
159 | model.predict(test_dataset["text"])
160 | delta_t = time.time() - start_t
161 | print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence")
162 | ```
163 | ```
164 | SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence
165 | ```
166 |
167 | So, SetFit with `BAAI/bge-small-en-v1.5` is 67x faster than `transformers` with `facebook/bart-large-mnli`, alongside being more accurate:
168 |
169 | 
170 |
--------------------------------------------------------------------------------
/docs/source/en/index.mdx:
--------------------------------------------------------------------------------
1 | # SetFit
2 |
3 |
8 |
9 | 🤗 SetFit is an efficient and prompt-free framework for few-shot fine-tuning of [Sentence Transformers](https://sbert.net/). It achieves high accuracy with little labeled data - for instance, with only 8 labeled examples per class on the Customer Reviews sentiment dataset, 🤗 SetFit is competitive with fine-tuning RoBERTa Large on the full training set of 3k examples!
10 |
11 | Compared to other few-shot learning methods, SetFit has several unique features:
12 |
13 | * 🗣 **No prompts or verbalizers:** Current techniques for few-shot fine-tuning require handcrafted prompts or verbalizers to convert examples into a format suitable for the underlying language model. SetFit dispenses with prompts altogether by generating rich embeddings directly from text examples.
14 | * 🏎 **Fast to train:** SetFit doesn't require large-scale models like T0, Llama or GPT-4 to achieve high accuracy. As a result, it is typically an order of magnitude (or more) faster to train and run inference with.
15 | * 🌎 **Multilingual support**: SetFit can be used with any [Sentence Transformer](https://huggingface.co/models?library=sentence-transformers&sort=downloads) on the Hub, which means you can classify text in multiple languages by simply fine-tuning a multilingual checkpoint.
16 |
17 |
37 |
--------------------------------------------------------------------------------
/docs/source/en/installation.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Installation
3 |
4 | Before you start, you'll need to set up your environment and install the appropriate packages. 🤗 SetFit is tested on **Python 3.9+**.
5 |
6 | ## pip
7 |
8 | The most straightforward way to install 🤗 SetFit is with pip:
9 |
10 | ```bash
11 | pip install setfit
12 | ```
13 |
14 | If you have a CUDA-capable graphics card, then it is recommended to [install `torch` with CUDA support](https://pytorch.org/get-started/locally/) to train and perform inference much more quickly:
15 |
16 | ```bash
17 | pip install torch --index-url https://download.pytorch.org/whl/cu118
18 | ```
19 |
20 | ## Installing from source
21 |
22 | Building 🤗 SetFit from source lets you make changes to the code base. To install from the source, clone the repository and install 🤗 SetFit in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) with the following commands:
23 |
24 | ```bash
25 | git clone https://github.com/huggingface/setfit.git
26 | cd setfit
27 | pip install -e .
28 | ```
29 |
30 | If you just want the bleeding-edge version without making any changes of your own, then install from source by running:
31 |
32 | ```bash
33 | pip install git+https://github.com/huggingface/setfit.git
34 | ```
35 |
36 | ## Conda
37 |
38 | If conda is your package management system of choice, then you can install 🤗 SetFit like so:
39 |
40 | ```bash
41 | conda install -c conda-forge setfit
42 | ```
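43 |
44 | Regardless of the installation method, you can quickly verify the result by printing the installed version, for example:
45 |
46 | ```py
47 | import setfit
48 |
49 | print(setfit.__version__)
50 | ```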
--------------------------------------------------------------------------------
/docs/source/en/reference/main.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Main Classes
3 |
4 | ## SetFitModel
5 |
6 | [[autodoc]] SetFitModel
7 | - all
8 | - from_pretrained
9 | - save_pretrained
10 | - push_to_hub
11 | - __call__
12 | - label2id
13 | - id2label
14 |
15 | ## SetFitHead
16 |
17 | [[autodoc]] SetFitHead
18 |
19 | ## SetFitModelCardData
20 |
21 | [[autodoc]] SetFitModelCardData
22 | - to_dict
23 | - to_yaml
24 |
25 | ## AbsaModel
26 |
27 | [[autodoc]] AbsaModel
28 | - __call__
29 | - device
30 | - from_pretrained
31 | - predict
32 | - push_to_hub
33 | - to
34 | - save_pretrained
35 |
36 | ### AspectModel
37 |
38 | [[autodoc]] AspectModel
39 | - __call__
40 | - device
41 | - from_pretrained
42 | - predict
43 | - push_to_hub
44 | - save_pretrained
45 | - to
46 |
47 | ### PolarityModel
48 |
49 | [[autodoc]] PolarityModel
50 | - __call__
51 | - device
52 | - from_pretrained
53 | - predict
54 | - push_to_hub
55 | - save_pretrained
56 | - to
57 |
--------------------------------------------------------------------------------
/docs/source/en/reference/trainer.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Trainer Classes
3 |
4 | ## TrainingArguments
5 |
6 | [[autodoc]] TrainingArguments
7 | - to_dict
8 | - from_dict
9 | - copy
10 | - update
11 |
12 | ## Trainer
13 |
14 | [[autodoc]] Trainer
15 | - add_callback
16 | - apply_hyperparameters
17 | - evaluate
18 | - hyperparameter_search
19 | - pop_callback
20 | - push_to_hub
21 | - remove_callback
22 | - train
23 | - train_classifier
24 | - train_embeddings
25 |
26 | ## DistillationTrainer
27 |
28 | [[autodoc]] DistillationTrainer
29 | - add_callback
30 | - apply_hyperparameters
31 | - evaluate
32 | - hyperparameter_search
33 | - pop_callback
34 | - push_to_hub
35 | - remove_callback
36 | - train
37 | - train_classifier
38 | - train_embeddings
39 |
40 | ## AbsaTrainer
41 |
42 | [[autodoc]] AbsaTrainer
43 | - add_callback
44 | - evaluate
45 | - pop_callback
46 | - push_to_hub
47 | - remove_callback
48 | - train
49 | - train_aspect
50 | - train_polarity
51 |
--------------------------------------------------------------------------------
/docs/source/en/reference/utility.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Utility Functions
3 |
4 | [[autodoc]] get_templated_dataset
5 |
6 | [[autodoc]] sample_dataset
--------------------------------------------------------------------------------
/docs/source/en/tutorials/overview.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Overview
3 |
4 | Welcome to the SetFit tutorials! These tutorials are designed to walk you through particular applications. For example, we'll delve into topics such as zero-shot text classification, where you'll learn how to use SetFit without any labeled training examples. See also the [SetFit Notebooks](https://github.com/huggingface/setfit/tree/main/notebooks) for more applications, such as hyperparameter searching and ONNX, though some might be outdated.
5 |
6 | For more concise guides on how to configure SetFit or use it for specific forms of text classification, see the [How-to Guides](../how_to/overview) section.
7 |
8 | If you have any questions about SetFit, feel free to open an [issue](https://github.com/huggingface/setfit/issues).
9 |
--------------------------------------------------------------------------------
/final_results/adapet/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/.gitkeep
--------------------------------------------------------------------------------
/final_results/adapet/albert-xxlarge-v2.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/albert-xxlarge-v2.tar.gz
--------------------------------------------------------------------------------
/final_results/adapet/xlm-roberta-base_all__lang_prompt.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/xlm-roberta-base_all__lang_prompt.tar.gz
--------------------------------------------------------------------------------
/final_results/adapet/xlm-roberta-base_each__lang_prompt.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/xlm-roberta-base_each__lang_prompt.tar.gz
--------------------------------------------------------------------------------
/final_results/adapet/xlm-roberta-base_en__eng_prompt.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/xlm-roberta-base_en__eng_prompt.tar.gz
--------------------------------------------------------------------------------
/final_results/distillation/distillation_results.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/distillation/distillation_results.tar.gz
--------------------------------------------------------------------------------
/final_results/perfect.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/perfect.tar.gz
--------------------------------------------------------------------------------
/final_results/setfit/all-MiniLM-L6-v2.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/all-MiniLM-L6-v2.tar.gz
--------------------------------------------------------------------------------
/final_results/setfit/all-roberta-large-v1.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/all-roberta-large-v1.tar.gz
--------------------------------------------------------------------------------
/final_results/setfit/paraphrase-mpnet-base-v2-epochs_20-with-augmentation.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-epochs_20-with-augmentation.tar.gz
--------------------------------------------------------------------------------
/final_results/setfit/paraphrase-mpnet-base-v2-epochs_20.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-epochs_20.tar.gz
--------------------------------------------------------------------------------
/final_results/setfit/paraphrase-mpnet-base-v2-epochs_5.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-epochs_5.tar.gz
--------------------------------------------------------------------------------
/final_results/setfit/paraphrase-mpnet-base-v2-linear-probe.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-linear-probe.tar.gz
--------------------------------------------------------------------------------
/final_results/setfit/paraphrase-multilingual-mpnet-base-v2.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-multilingual-mpnet-base-v2.tar.gz
--------------------------------------------------------------------------------
/final_results/tfew/t011b_pretrained-v2.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/tfew/t011b_pretrained-v2.tar.gz
--------------------------------------------------------------------------------
/final_results/tfew/t03b_pretrained-v2.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/tfew/t03b_pretrained-v2.tar.gz
--------------------------------------------------------------------------------
/final_results/transformers/distilbert-base-uncased.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/transformers/distilbert-base-uncased.tar.gz
--------------------------------------------------------------------------------
/final_results/transformers/roberta-large.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/transformers/roberta-large.tar.gz
--------------------------------------------------------------------------------
/final_results/transformers/xlm-roberta-base.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/transformers/xlm-roberta-base.tar.gz
--------------------------------------------------------------------------------
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/notebooks/.gitkeep
--------------------------------------------------------------------------------
/scripts/adapet/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/scripts/adapet/.gitkeep
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/.gitignore:
--------------------------------------------------------------------------------
1 | exp_out
2 | exp_out/
3 | cur_exp_dir
4 | env/
5 | results/
6 | wandb/
7 | lib/
8 | output/
9 | adapet_models
10 | adapet_models/
11 | data
12 | data/
13 | pretrained_models
14 | pretrained_models/
15 | .idea/
16 | eche_
17 | runs/
18 | *.pyc
19 | .installed.cfg
20 | develop-eggs
21 | dist
22 | downloads
23 | eggs
24 | parts
25 | src/*.egg-info
26 | lib
27 | lib64
28 | !src/data
29 | env.yaml
30 | pet.yml
31 | results/
32 | results
33 | venv
34 | venv/
35 | output_slurm
36 | output_slurm/
37 | slurm_adapetsetfit.sh
38 | seed_ouput
39 | seed_ouput/
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/README.md:
--------------------------------------------------------------------------------
1 | # ADAPET for SetFit #
2 | This is a fork of the original ADAPET, which can be found [here](https://github.com/rrmenon10/ADAPET).
3 |
 4 | Our results were created with Python 3.6.8 on a 40 GB NVIDIA A100 Tensor Core GPU.
 5 | 
 6 | To set up, clone the repo and follow the instructions below.
7 | ````
8 | python3.6 -m venv venv
 9 | source venv/bin/activate
10 | pip install -r requirements.txt
11 | pip install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
12 | ````
13 |
14 |
15 |
 16 | To run ADAPET on SetFit datasets, specify the dataset and PLM as arguments. Additionally, if you wish to
 17 | run on a multilingual dataset, you can choose to prompt/verbalize in English or in the language in question.
 18 | You can also remove the prompt.
19 |
 20 | For example, to run ADAPET on `sst2` with a prompt and with `albert-xxlarge-v2` as the PLM, simply run:
21 | ```
 22 | python setfit_adapet.py --pretrained_weight="albert-xxlarge-v2" \
 23 | --english=True \
 24 | --prompt=True \
 25 | --task_name='SetFit/sst2'
26 | ```
27 |
 28 | If you wish to run ADAPET on `amazon_reviews_multi_ja`, prompting in Japanese, with `microsoft/mdeberta-v3-base`:
29 | ```
 30 | python setfit_adapet.py --pretrained_weight="microsoft/mdeberta-v3-base" \
 31 | --english=False \
 32 | --prompt=True \
 33 | --task_name='SetFit/amazon_reviews_multi_ja'
34 | ```
35 |
36 | This will run ADAPET and evaluate it on the test set for the 8 and 64 splits.
37 |
 38 | In the multilingual case, ADAPET runs in the "each" scenario described in the paper by default. You can change this by adding a multilingual argument, for example:
39 |
40 | ```
 41 | python setfit_adapet.py --pretrained_weight="microsoft/mdeberta-v3-base" \
 42 | --english=False \
 43 | --prompt=True \
 44 | --task_name='SetFit/amazon_reviews_multi_ja' \
 45 | --multilingual='all'
46 | ```
47 |
48 |
 49 | Note that with default hyperparameters this will take a very long time. If you switch to `distilbert-base-uncased` or another smaller model, ADAPET will run much faster.
50 |
 51 | Once ADAPET is done, the results will be written as follows:
52 | ```
 53 | seed_output / model_name / dataset_name / train-{num_samples}-{split_idx} / results.json
54 | ```
55 |
 56 | For non-English datasets, the PLM directory name may have different endings. The ending describes the prompting setup:
57 | ````
 58 | {plm}__eng_prompt == prompt and verbalize in English
 59 | {plm}__lang_prompt == prompt and verbalize in the language in question
 60 | {plm}__lang_no-prompt == remove the prompt and verbalize in the language in question
 61 | {plm}__eng_no-prompt == remove the prompt and verbalize in English
62 | ````
63 |
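 64 | To aggregate these files afterwards, a minimal sketch (not part of the original scripts; it assumes only the `seed_output` layout shown above) could look like:
 65 | 
 66 | ```python
 67 | import json
 68 | from pathlib import Path
 69 | 
 70 | # Walk seed_output/{model_name}/{dataset_name}/train-{num_samples}-{split_idx}/results.json
 71 | for path in sorted(Path("seed_output").glob("*/*/train-*/results.json")):
 72 |     model_name, dataset_name, split = path.parts[1:4]
 73 |     with open(path) as f:
 74 |         print(model_name, dataset_name, split, json.load(f))
 75 | ```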
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/bin/dev.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -exu
4 |
5 | exp_dir=$1
6 |
7 | python -m src.dev -e $exp_dir
8 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/bin/init.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ ! -d "data/superglue/" ] ; then
4 | mkdir -p data
5 | cd data
6 | mkdir -p superglue
7 | cd superglue
8 | wget "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/combined.zip"
9 | unzip combined.zip
10 | cd ../..
11 | fi
12 |
13 | if [ ! -d "data/fewglue/" ] ; then
14 | mkdir -p data
15 | cd data
16 | git clone https://github.com/timoschick/fewglue.git
17 | cd fewglue
18 | rm -rf .git
19 | rm README.md
20 | mv FewGLUE/* .
21 | rm -r FewGLUE
22 | cd ../..
23 | fi
24 |
25 | if [ ! -d "env" ] ; then
26 | python -m venv env
27 | source env/bin/activate
28 | pip install --upgrade pip
29 | pip install -r requirements.txt
30 | fi
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/bin/setup.sh:
--------------------------------------------------------------------------------
1 | source env/bin/activate
2 | export ADAPET_ROOT=`pwd`
3 | export PYTHONPATH=$ADAPET_ROOT:$PYTHONPATH
4 | export PYTHON_EXEC=python
5 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/bin/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -exu
4 |
5 | exp_dir=$1
6 |
7 | python -m src.test -e $exp_dir
8 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/bin/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -exu
4 |
5 | config_file=$1
6 |
7 | python -m src.train -c $config_file
8 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/cli.py:
--------------------------------------------------------------------------------
1 | #import argparse
2 | import os
3 | import json
4 |
5 | from src.utils.Config import Config
6 | from src.train import train
7 |
8 |
9 |
10 | def call_adapet(updated_args):
11 | '''
12 | parser = argparse.ArgumentParser()
13 |
14 | # Arguments for running any datasets
15 | parser.add_argument('-d', "--data_dir", default=configs['generic_data_dir'],
16 | help="Data directory containing train/val/test jsonl files")
17 | parser.add_argument('-p', "--pattern", default=configs['pattern'],
18 | help="Pattern to be used for this dataset")
19 | parser.add_argument('-v', "--dict_verbalizer", type=json.loads, default=configs['dict_verbalizer'],
20 | help="Dictionary mapping label name (in dataset) to the verbalizer to use, e.g. '{\"0\": \"Yes\", \"1\": \"No\"}'")
21 |
22 | # Model and training hyperparams
23 | parser.add_argument('-w', '--pretrained_weight', type=str, default=configs['pretrained_weight'],
24 | help='Pretrained model weights from huggingface')
25 | parser.add_argument('-bs', '--batch_size', type=int, default=1, help='batch size during training')
26 | parser.add_argument('--eval_batch_size', type=int, default=configs['batch_size'],
27 | help='batch size during evaluation')
28 | parser.add_argument('--grad_accumulation_factor', type=int, default=16, help='number of gradient accumulation steps')
29 | parser.add_argument('--num_batches', type=int, default=configs['num_batches'],
30 | help='number of batches for experiment; 1 batch = grad_accumulation_factor x batch_size')
31 |
32 | parser.add_argument('--eval_every', type=int, default=configs['eval_every'],
33 | help='number of training batches per evaluation')
34 | parser.add_argument('--max_text_length', type=int, default=256, help='maximum total input sequence length after tokenization for ADAPET')
35 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate for the model')
 36 | parser.add_argument('--weight_decay', type=float, default=1e-2, help='weight decay for the optimizer')
37 | parser.add_argument('--grad_clip_norm', type=float, default=1, help='gradient clipping norm')
38 | parser.add_argument('--warmup_ratio', type=float, default=0.06, help='linear warmup over warmup_steps for num_batches')
39 |
40 | # ADAPET hyperparameters
41 | parser.add_argument('--pattern_idx', default=1, help="Pattern index among all patterns available; For SuperGLUE, can use numbers >1 depending on dataset. For a new dataset, please set this to 1.")
42 | parser.add_argument('--mask_alpha', type=float, default=0.105, help='masking ratio for the label conditioning loss')
43 | parser.add_argument('--idx_txt_trim', type=int, default=1, help="TXT_ID of the text that can be trimmed (usually the longer text). Eg. if TXT1 needs to be trimmed, set this to 1.")
44 | parser.add_argument('--max_num_lbl_tok', type=int, default=configs['max_num_lbl_tok'], help="The maximum number of tokens per label for the verbalizer. It will raise an error if the tokenizer tokenizes a label into more than 'max_num_lbl_tok' tokens.")
45 |
46 | # Replicating SuperGLUE results
47 | parser.add_argument('-c', '--config', type=str, default=None, help='Use this for replicating SuperGLUE results.')
48 | '''
49 |
50 |
51 | #args = final_parser.parse_args()
52 | args = updated_args
53 |
 54 | # If even one of these three arguments is provided, we need all three as input
55 | if args.data_dir or args.pattern or args.dict_verbalizer:
56 | assert args.data_dir and args.pattern and args.dict_verbalizer, 'Please enter all of data_dir, pattern, dict_verbalizer!'
57 |
58 | if args.config is not None:
59 | use_config = args.config
60 | else:
61 | assert args.data_dir or args.pattern or args.dict_verbalizer, 'Please enter all of data_dir, pattern, dict_verbalizer if not providing config!'
62 | use_config = os.path.join("config", "Generic.json")
63 |
64 | update_config = vars(args)
65 | config = Config(use_config, update_config, mkdir=True)
66 | train(config)
67 | print(config.exp_dir)
68 | return config.exp_dir
69 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/BoolQ.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/BoolQ",
4 | "max_text_length": 256,
5 | "batch_size": 1,
6 | "eval_batch_size": 1,
7 | "num_batches": 1000,
8 | "max_num_lbl_tok": 1,
9 | "eval_every": 250,
10 | "warmup_ratio": 0.06,
11 | "mask_alpha": 0.105,
12 | "grad_accumulation_factor": 16,
13 | "seed": 42,
14 | "lr": 1e-5,
15 | "weight_decay": 1e-2,
16 | "pattern_idx": 1,
17 | "eval_train": true
18 | }
19 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/CB.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/CB",
4 | "max_text_length": 256,
5 | "batch_size": 1,
6 | "eval_batch_size": 1,
7 | "num_batches": 1000,
8 | "max_num_lbl_tok": 1,
9 | "mask_alpha": 0.105,
10 | "eval_every": 250,
11 | "warmup_ratio": 0.06,
12 | "grad_accumulation_factor": 16,
13 | "seed": 42,
14 | "lr": 1e-5,
15 | "dropout_rate": 0.1,
16 | "weight_decay": 1e-2,
17 | "pattern_idx": 1,
18 | "eval_train": true
19 | }
20 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/COPA.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/COPA",
4 | "max_text_length": 256,
5 | "batch_size": 1,
6 | "eval_batch_size": 1,
7 | "num_batches": 1000,
8 | "eval_every": 250,
9 | "mask_alpha": 0.105,
10 | "warmup_ratio": 0.06,
11 | "max_num_lbl_tok": 20,
12 | "grad_accumulation_factor": 16,
13 | "seed": 42,
14 | "lr": 1e-5,
15 | "dropout_rate": 0.1,
16 | "weight_decay": 1e-2,
17 | "pattern_idx": 1,
18 | "eval_train": true
19 | }
20 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/Generic.json:
--------------------------------------------------------------------------------
1 | {"pretrained_weight": "microsoft/mdeberta-v3-base", "dataset": "generic", "generic_data_dir": "data/SetFit/amazon_reviews_multi_ja", "pattern": "[TEXT1]これは[LBL]だ", "pattern_idx": 1, "dict_verbalizer": "{\"0\": \"一つ星\", \"1\": \"二つ星\", \"2\": \"三つ星\", \"3\": \"四つ星\", \"4\": \"五つ星\"}", "idx_txt_trim": 1, "max_text_length": 256, "batch_size": 1, "eval_batch_size": 1, "num_batches": 10, "max_num_lbl_tok": 3, "eval_every": 10, "eval_train": true, "warmup_ratio": 0.06, "mask_alpha": 0.105, "grad_accumulation_factor": 16, "seed": 0, "lr": 1e-05, "weight_decay": 0.01}
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/MultiRC.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/MultiRC",
4 | "max_text_length": 512,
5 | "batch_size": 1,
6 | "eval_batch_size": 1,
7 | "num_batches": 1000,
8 | "max_num_lbl_tok": 1,
9 | "mask_alpha": 0.105,
10 | "eval_every": 250,
11 | "warmup_ratio": 0.06,
12 | "grad_accumulation_factor": 16,
13 | "seed": 42,
14 | "lr": 1e-5,
15 | "dropout_rate": 0.1,
16 | "weight_decay": 1e-2,
17 | "pattern_idx": 1,
18 | "eval_train": true
19 | }
20 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/RTE.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/RTE",
4 | "max_text_length": 256,
5 | "batch_size": 1,
6 | "eval_batch_size": 1,
7 | "num_batches": 1000,
8 | "max_num_lbl_tok": 1,
9 | "mask_alpha": 0.105,
10 | "eval_every": 250,
11 | "warmup_ratio": 0.06,
12 | "grad_accumulation_factor": 16,
13 | "seed": 42,
14 | "lr": 1e-5,
15 | "dropout_rate": 0.1,
16 | "weight_decay": 1e-2,
17 | "pattern_idx": 1,
18 | "eval_train": true
19 | }
20 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/ReCoRD.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/ReCoRD",
4 | "max_text_length": 512,
5 | "batch_size": 1,
6 | "eval_batch_size": 1,
7 | "num_batches": 1000,
8 | "eval_every": 250,
9 | "mask_alpha": 0.105,
10 | "warmup_ratio": 0.06,
11 | "max_num_lbl_tok": 20,
12 | "grad_accumulation_factor": 16,
13 | "seed": 42,
14 | "lr": 1e-5,
15 | "dropout_rate": 0.1,
16 | "weight_decay": 1e-2,
17 | "pattern_idx": 1,
18 | "eval_train": true
19 | }
20 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/WSC.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/WSC",
4 | "max_text_length": 256,
5 | "batch_size": 1,
6 | "eval_batch_size": 1,
7 | "num_batches": 1000,
8 | "eval_every": 250,
9 | "mask_alpha": 0.105,
10 | "warmup_ratio": 0.06,
11 | "max_num_lbl_tok": 20,
12 | "grad_accumulation_factor": 16,
13 | "seed": 42,
14 | "lr": 1e-5,
15 | "weight_decay": 1e-2,
16 | "pattern_idx": 1,
17 | "eval_train": true
18 | }
19 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/WiC.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "albert-xxlarge-v2",
3 | "dataset": "fewglue/WiC",
4 | "max_text_length": 256,
5 | "batch_size": 1,
6 | "eval_batch_size": 8,
7 | "num_batches": 1000,
8 | "eval_every": 250,
9 | "warmup_ratio": 0.06,
10 | "grad_accumulation_factor": 16,
11 | "max_num_lbl_tok": 1,
12 | "seed": 42,
13 | "lr": 1e-5,
14 | "dropout_rate": 0.1,
15 | "weight_decay": 1e-2,
16 | "pattern_idx": 1,
17 | "eval_train": true
18 | }
19 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/config/sst-2.json:
--------------------------------------------------------------------------------
1 | {
2 | "pretrained_weight": "distilbert-base-uncased",
3 | "dataset": "generic",
4 | "generic_data_dir": "data/SST-2",
5 | "pattern": "[TEXT1] it is [LBL]",
6 | "pattern_idx": 1,
7 | "dict_verbalizer": {"0": "No", "1": "Yes"},
8 | "idx_txt_trim": 1,
9 | "max_text_length": 64,
10 | "batch_size": 16,
11 | "eval_batch_size": 8,
12 | "num_batches": 1000,
13 | "max_num_lbl_tok": 1,
14 | "eval_every": 50,
15 | "warmup_ratio": 0.06,
16 | "mask_alpha": 0.105,
17 | "grad_accumulation_factor": 1,
18 | "seed": 42,
19 | "lr": 1e-5,
20 | "weight_decay": 1e-2
21 | }
22 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/requirements.txt:
--------------------------------------------------------------------------------
1 | crcmod
2 | pandas==1.1.5
3 | datasets==1.18.2
4 | numpy==1.19
5 | jsonpickle==1.1
6 | scikit-learn==0.23.1
 7 | torch==1.5.0
8 | torchvision==0.6.0
9 | transformers==4.15.0
10 | tqdm==4.62.1
11 | sentencepiece==0.1.96
12 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/adapet_test_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 | from src.adapet import adapet
6 | from src.data.Batcher import Batcher
7 | from src.eval.eval_model import test_eval
8 | from src.utils.Config import Config
9 | from src.utils.util import device
10 | from transformers import *
11 |
12 |
13 | os.environ["WANDB_DISABLED"] = "true"
14 |
15 | def test_evaluation(exp_dir):
16 | config_file = os.path.join(exp_dir, "config.json")
 17 | 
18 | config = Config(config_file, mkdir=False)
19 |
20 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
21 | batcher = Batcher(config, tokenizer, config.dataset)
22 | dataset_reader = batcher.get_dataset_reader()
23 |
24 | model = adapet(config, tokenizer, dataset_reader).to(device)
25 | model.load_state_dict(torch.load(os.path.join(exp_dir, "best_model.pt")))
26 | #model.load_state_dict(torch.load(os.path.join(args.exp_dir, "best_model.pt")))
27 | test_eval(config, model, batcher)
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/data/Batcher.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import math
4 | from torch.utils import data
5 | from src.data.Dataset import Dataset
6 | from src.data.DatasetReader import DatasetReader
7 | from src.utils.util import set_seeds
8 |
9 |
10 | class Batcher(object):
11 | '''
12 | Batcher is responsible for returning batches of data
13 | '''
14 | def __init__(self, config, tokenizer, dataset):
15 | '''
16 | :param config:
17 | :param tokenizer:
18 | :param dataset:
19 | '''
20 | self.config = config
21 | self.dataset_reader = DatasetReader(config, tokenizer, dataset)
22 | set_seeds(self.config.seed)
23 |
24 | self.train_loader = None
25 | self.dev_loader = None
26 | self.test_loader = None
27 | self.eval_train_loader = None
28 |
29 | self.data_len = None
30 |
31 | self.collate_fn = None
32 | if "record" in self.config.dataset:
33 | self.collate_fn = Batcher.my_collate_fn
34 |
35 | def get_dataset_reader(self):
36 | return self.dataset_reader
37 |
38 | @staticmethod
39 | def my_collate_fn(batch):
40 |
 41 | dict_batch = {}  # merge per-datapoint "input"/"output" dicts into dicts of lists
42 | dict_batch["input"] = {}
43 | dict_batch["output"] = {}
44 |
45 | for datapoint in batch:
46 | for (k, v) in datapoint["input"].items():
47 | if k in dict_batch["input"]:
48 | dict_batch["input"][k].append(v)
49 | else:
50 | dict_batch["input"][k] = [v]
51 |
52 | for (k, v) in datapoint["output"].items():
53 | if k in dict_batch["output"]:
54 | dict_batch["output"][k].append(v)
55 | else:
56 | dict_batch["output"][k] = [v]
57 |
58 | for (k, list_v) in dict_batch["input"].items():
59 | if isinstance(list_v[0], int):
60 | dict_batch["input"][k] = torch.tensor(list_v)
61 | for (k, list_v) in dict_batch["output"].items():
62 | if isinstance(list_v[0], int):
63 | dict_batch["output"][k] = torch.tensor(list_v)
64 |
65 | return dict_batch
66 |
67 | def _init_train(self):
68 | '''
69 | Initialize loader for train data
70 | '''
71 | train_data = self.dataset_reader.read_dataset("train")
72 | self.train_loader = data.DataLoader(Dataset(train_data), batch_size=self.config.batch_size, shuffle=True, collate_fn=self.my_collate_fn)
73 |
74 | eval_train_data = self.dataset_reader.read_dataset("train", is_eval=True)
75 | self.eval_train_loader = data.DataLoader(Dataset(eval_train_data), batch_size=self.config.eval_batch_size, shuffle=False, collate_fn=self.my_collate_fn)
76 |
77 |
78 | def _init_dev(self):
79 | '''
80 | Initialize loader for dev data
81 | '''
82 | dev_data = self.dataset_reader.read_dataset("dev")
83 | self.dev_loader = data.DataLoader(Dataset(dev_data), batch_size=self.config.eval_batch_size, shuffle=False, collate_fn=self.my_collate_fn)
84 |
85 | def _init_test(self):
86 | '''
87 | Initialize loader for test data
88 | '''
89 | test_data = self.dataset_reader.read_dataset("test")
90 | self.test_loader = data.DataLoader(Dataset(test_data), batch_size=self.config.eval_batch_size, shuffle=False, collate_fn=self.my_collate_fn)
91 |
92 | def get_train_batch(self):
93 | '''
94 | Yield train batches
95 |
96 | :return:
97 | '''
98 | if self.train_loader is None:
99 | self._init_train()
100 |
 101 | while True:  # cycle through the loader indefinitely; num_batches bounds training, not epochs
102 | for x in self.train_loader:
103 | yield x
104 |
105 | def get_eval_train_batch(self):
106 | '''
107 | Yield non-shuffled train batches
108 |
109 | :return:
110 | '''
111 | if self.eval_train_loader is None:
112 | self._init_train()
113 | for x in self.eval_train_loader:
114 | yield x
115 |
116 | def get_dev_batch(self):
117 | '''
118 | Yield dev batches
119 |
120 | :return:
121 | '''
122 | if self.dev_loader is None:
123 | self._init_dev()
124 |
125 | for x in self.dev_loader:
126 | yield x
127 |
128 |
129 | def get_test_batch(self):
130 | '''
131 | Yield test batches
132 |
133 | :return:
134 | '''
135 | if self.test_loader is None:
136 | self._init_test()
137 |
138 | for x in self.test_loader:
139 | yield x
140 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/data/Dataset.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils import data
3 | import numpy as np
4 |
5 | class Dataset(data.Dataset):
6 | def __init__(self, data):
7 | self.data = data
8 |
9 | def __len__(self):
10 | return len(self.data)
11 |
12 | def __getitem__(self, get_idx):
13 | return self.data[get_idx]
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/data/DatasetReader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | from src.data.BoolQReader import BoolQReader
5 | from src.data.CBReader import CBReader
6 | from src.data.RTEReader import RTEReader
7 | from src.data.MultiRCReader import MultiRCReader
8 | from src.data.WiCReader import WiCReader
9 | from src.data.COPAReader import COPAReader
10 | from src.data.RecordReader import RecordReader
11 | from src.data.WSCReader import WSCReader
12 | from src.data.GenericReader import GenericReader
13 |
14 | class DatasetReader(object):
15 | '''
16 | DatasetReader is responsible for reading dataset
17 | '''
18 | def __init__(self, config, tokenizer, dataset):
19 | '''
20 | :param config:
21 | :param tokenizer:
22 | :param dataset:
23 | '''
24 | self.config = config
25 | self.dataset = dataset
26 |
27 | if self.dataset.lower() == "fewglue/boolq":
28 | self.dataset_reader = BoolQReader(self.config, tokenizer)
29 | elif self.dataset.lower() == "fewglue/cb":
30 | self.dataset_reader = CBReader(self.config, tokenizer)
31 | elif self.dataset.lower() == "fewglue/rte":
32 | self.dataset_reader = RTEReader(self.config, tokenizer)
33 | elif self.dataset.lower() == "fewglue/multirc":
34 | self.dataset_reader = MultiRCReader(self.config, tokenizer)
35 | elif self.dataset.lower() == "fewglue/wic":
36 | self.dataset_reader = WiCReader(self.config, tokenizer)
37 | elif self.dataset.lower() == "fewglue/copa":
38 | self.dataset_reader = COPAReader(self.config, tokenizer)
39 | elif self.dataset.lower() == "fewglue/record":
40 | self.dataset_reader = RecordReader(self.config, tokenizer)
41 | elif self.dataset.lower() == "fewglue/wsc":
42 | self.dataset_reader = WSCReader(self.config, tokenizer)
43 | elif self.dataset.lower() == "generic":
44 | self.dataset_reader = GenericReader(self.config, tokenizer)
45 | else:
46 | raise ValueError("Invalid Dataset name")
47 |
48 | def get_num_lbl_tok(self):
49 | '''
 50 | Get number of tokens in labels for the dataset
51 |
52 | :return:
53 | '''
54 | return self.dataset_reader.get_num_lbl_tok()
55 |
56 | def read_dataset(self, split, is_eval=False):
57 | '''
58 | Read dataset
59 |
60 | :param split:
61 | :param is_eval:
62 | :return:
63 | '''
64 | return np.asarray(self.dataset_reader.read_dataset(split, is_eval))
65 |
66 | def prepare_batch(self, batch, type):
67 | '''
68 | Prepare batch of data for model
69 |
70 | :param batch:
71 | :param type: pattern to prepare batch with and which mode to use (ex: PET_MLM_PET1)
72 | :return:
73 | '''
74 | # Prepare for PET MLM objective
75 | if "PET_MLM" in type:
76 | return self.dataset_reader.prepare_pet_mlm_batch(batch, mode=type.replace("PET_MLM_", ""))
77 | # Prepare for evaluation objective
78 | elif "EVAL" in type:
79 | return self.dataset_reader.prepare_eval_pet_batch(batch, mode=type.replace("EVAL_", ""))
80 | # Default is preparing for PET/Decoupled Label objective
81 | else:
82 | return self.dataset_reader.prepare_pet_batch(batch, mode=type)
83 |
84 | def store_test_lbl(self, list_idx, pred_lbl, true_lbl, logits):
85 | '''
86 | Store test outputs for SuperGLUE to submit to leaderboard
87 |
88 | :param list_idx:
89 | :param pred_lbl:
90 | :param true_lbl:
91 | :param logits:
92 | :return:
93 | '''
94 | self.dataset_reader.store_test_lbl(list_idx, pred_lbl, true_lbl, logits)
95 |
96 | def flush_file(self, write_file):
97 | '''
98 | Write out contents of test predictions to file
99 |
100 | :param write_file:
101 | :return:
102 | '''
103 | self.dataset_reader.flush_file(write_file)
104 |
105 | def get_num_lbl(self):
106 | '''
 107 | Get number of labels in the dataset
108 |
109 | :return:
110 | '''
111 | return self.dataset_reader.num_lbl
112 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/data/tokenize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import random
4 | from collections import defaultdict
5 |
6 |
7 | def tokenize_pet_mlm_txt(tokenizer, config, txt1, txt2, txt3, txt_trim, mask_idx=None):
8 | '''
9 | Tokenizes the text by trimming the appropriate txt
10 |
11 | :param tokenizer:
 12 | :param config:
 13 | :param txt1:
 14 | :param txt2:
 15 | :param txt3:
 16 | :param txt_trim: idx of the text to trim; the trimmed text never contains the label
 17 | :param mask_idx: optional precomputed mask positions; sampled at random when None
 18 | :return trunc_input_ids: list of input ids (CLS token plus at most config.max_text_length ids)
 19 | :return unsup_masked_ids: copy of trunc_input_ids with mask tokens inserted at mask_idx
 20 | :return mask_idx: array of idx of the mask tokens in trunc_input_ids
 21 | '''
22 |
23 | txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"]
24 | txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"]
25 | txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"]
26 |
27 | # Add 1 to account for CLS rep
28 | tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1
29 |
30 | # Don't need to trim text
31 | if tot_length <= config.max_text_length:
32 | trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length
33 | trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids
34 |
35 | # Trim text
36 | else:
37 | num_trim = tot_length - config.max_text_length
38 |
39 | if txt_trim == 0:
40 | new_txt1_input_ids = txt1_input_ids[:-num_trim]
41 | trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids
42 | elif txt_trim == 1:
43 | new_txt2_input_ids = txt2_input_ids[:-num_trim]
44 | trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids
45 | elif txt_trim == 2:
46 | new_txt_3_input_ids = txt3_input_ids[:-num_trim]
47 | trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt_3_input_ids
48 | else:
49 | raise ValueError("Invalid Txt Trim")
50 |
51 | trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids
52 |
53 | if mask_idx is None:
54 | sample_length = min(tot_length, config.max_text_length)
55 | upto_ratio_mask = np.random.rand()
56 | num_sample = max(int(upto_ratio_mask * config.mask_alpha * sample_length), 2) - 1
57 | mask_idx = random.sample(range(0, sample_length), k=num_sample)
58 | mask_idx = np.asarray(mask_idx)
59 |
60 | # Copy adds mask idx at random positions
61 | unsup_masked_ids = np.copy(trunc_input_ids)
62 |
63 | unsup_masked_ids[mask_idx] = tokenizer.mask_token_id
64 |
65 | return trunc_input_ids, unsup_masked_ids, mask_idx
66 |
67 | def tokenize_pet_txt(tokenizer, config, txt1, txt2, txt3, mask_txt1, mask_txt2, mask_txt3, txt_trim):
68 | '''
69 | Tokenizes the text by trimming the appropriate txt
70 |
71 | :param txt1:
72 | :param txt2:
73 | :param txt3:
74 | :param mask_txt1:
75 | :param mask_txt2:
76 | :param mask_txt3:
 77 | :param txt_trim: idx of the text to trim; the trimmed text never contains the label
 78 | :return trunc_input_ids: list of input ids (CLS token plus config.max_text_length ids)
 79 | :return mask_idx: idx of the mask token in trunc_input_ids
80 | '''
81 | txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"]
82 | txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"]
83 | txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"]
84 |
85 | mask_txt1_input_ids = tokenizer(mask_txt1, add_special_tokens=False)["input_ids"]
86 | mask_txt2_input_ids = tokenizer(mask_txt2, add_special_tokens=False)["input_ids"]
87 | mask_txt3_input_ids = tokenizer(mask_txt3, add_special_tokens=False)["input_ids"]
88 |
89 | # Add 1 to account for CLS rep
90 | tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1
91 | tot_mask_length = len(mask_txt1_input_ids) + len(mask_txt2_input_ids) + len(mask_txt3_input_ids) + 1
92 |
93 | # Don't need to trim text
94 | if tot_length <= config.max_text_length:
95 | trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length
96 | trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids
97 |
98 | trunc_mask_input_ids = [tokenizer.pad_token_id] * config.max_text_length
99 | trunc_mask_input_ids[:tot_mask_length] = mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids
100 |
101 | # Trim text
102 | else:
103 | num_trim = tot_length - config.max_text_length
104 |
105 | if txt_trim == 0:
106 | new_txt1_input_ids = txt1_input_ids[:-num_trim]
107 | new_mask_txt1_input_ids = mask_txt1_input_ids[:-num_trim]
108 | trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids
109 | trunc_mask_input_ids = new_mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids
110 | elif txt_trim == 1:
111 | new_txt2_input_ids = txt2_input_ids[:-num_trim]
112 | new_mask_txt2_input_ids = mask_txt2_input_ids[:-num_trim]
113 | trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids
114 | trunc_mask_input_ids = mask_txt1_input_ids + new_mask_txt2_input_ids + mask_txt3_input_ids
115 | elif txt_trim == 2:
116 | new_txt_3_input_ids = txt3_input_ids[:-num_trim]
117 | new_mask_txt3_input_ids = mask_txt3_input_ids[:-num_trim]
118 | trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt_3_input_ids
119 | trunc_mask_input_ids = mask_txt1_input_ids + mask_txt2_input_ids + new_mask_txt3_input_ids
120 | else:
121 | raise ValueError("Invalid Txt Trim")
122 |
123 |
124 | trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids
125 | trunc_mask_input_ids = [tokenizer.cls_token_id] + trunc_mask_input_ids
126 |
127 | mask_idx = trunc_mask_input_ids.index(tokenizer.mask_token_id)
128 |
129 | return trunc_input_ids, mask_idx
130 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/dev.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import numpy as np
5 | from transformers import *
6 |
7 | from src.data.Batcher import Batcher
8 | from src.adapet import adapet
 9 | from src.utils.Config import Config
 10 | from src.eval.eval_model import dev_eval
 11 | from src.utils.util import device  # cuda if available, else cpu
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('-e', "--exp_dir", required=True)
15 | args = parser.parse_args()
16 |
17 | config_file = os.path.join(args.exp_dir, "config.json")
18 | config = Config(config_file, mkdir=False)
19 | config.eval_dev = True
20 |
21 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
22 | batcher = Batcher(config, tokenizer, config.dataset)
23 | dataset_reader = batcher.get_dataset_reader()
24 |
25 | model = adapet(config, tokenizer, dataset_reader).to(device)
26 | model.load_state_dict(torch.load(os.path.join(args.exp_dir, "best_model.pt")))
27 | dev_acc, dev_logits = dev_eval(config, model, batcher, 0)
28 |
29 | with open(os.path.join(config.exp_dir, "dev_logits.npy"), 'wb') as f:
30 | np.save(f, dev_logits)
31 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/eval/Writer.py:
--------------------------------------------------------------------------------
1 |
2 | class Writer(object):
3 |
4 | def __init__(self, file, dataset_reader):
5 | self.write_file = open(file, 'w+')
6 | self.dataset_reader = dataset_reader
7 |
8 | def add_batch(self, list_idx, list_pred_lbl, list_true_lbl, lbl_logits):
9 | self.dataset_reader.store_test_lbl(list_idx, list_pred_lbl, list_true_lbl, lbl_logits)
10 |
11 | def flush_file(self):
12 | self.dataset_reader.flush_file(self.write_file)
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/eval/eval_model.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | import torch
 4 | #added
 5 | import time
 6 | 
 7 | import numpy as np
8 |
9 | from src.eval.Scorer import Scorer
10 | from src.eval.Writer import Writer
11 |
12 | def eval(config, model, batch_iter, scorer):
13 | '''
14 | Evaluate model
15 |
16 | :param config:
17 | :param model:
18 | :param batch_iter:
19 | :param scorer:
20 | :return:
21 | '''
22 | model.eval()
23 | with torch.no_grad():
24 | for idx, batch in enumerate(batch_iter):
25 | pred_lbl, lbl_logits = model.predict(batch)
26 | list_idx = batch["input"]["idx"] if isinstance(batch["input"]["idx"], list) else batch["input"]["idx"].cpu().numpy().tolist()
27 | list_lbl = batch["output"]["true_lbl"] if "true_lbl" in batch["output"] else batch["output"]["lbl"]
28 |
29 | if config.dataset.lower() == 'fewglue/record':
30 | true_lbl = torch.tensor([1])
31 | pred_lbl = torch.tensor([list_lbl[0][pred_lbl[0].item()]])
32 | scorer.add_batch(list_idx, pred_lbl, true_lbl, lbl_logits.cpu().numpy(), None)
33 | else:
34 | scorer.add_batch(list_idx, pred_lbl, list_lbl, lbl_logits.cpu().numpy(), None)
35 |
36 |
37 |
38 | def dev_eval(config, model, batcher, num_batches, dict_avg_val=None):
39 | '''
40 | Evaluates the accuracy on the dev partition
41 |
42 | :param config:
43 | :param model:
44 | :param batcher: batcher to get batches of data
45 | :param num_batches:
46 | :param dict_avg_val: dictionary storing metrics
47 |
 48 | :return: current dev score
49 | '''
50 |
51 | dict_eval = {}
52 | dict_eval["num_batches"] = num_batches
53 |
54 | if dict_avg_val is not None:
55 | dict_eval.update(dict_avg_val)
56 |
57 | # Get train Score
58 | if config.eval_train:
59 | train_scorer = Scorer(config, config.dataset)
60 | train_iter = batcher.get_eval_train_batch()
61 | eval(config, model, train_iter, train_scorer)
62 | _, train_scores = train_scorer.get_score("train")
63 | dict_eval.update(train_scores)
64 |
65 | # Get dev Score
66 | if config.eval_dev:
67 | dev_scorer = Scorer(config, config.dataset)
68 | dev_iter = batcher.get_dev_batch()
69 | eval(config, model, dev_iter, dev_scorer)
70 | score_eval, dev_scores = dev_scorer.get_score("dev")
71 | dict_eval.update(dev_scores)
72 | dev_logits = dev_scorer.get_logits()
73 | else:
74 | score_eval = 0
75 | dev_logits = None
76 |
77 | with open(config.dev_score_file, 'a+') as f_out:
78 | f_out.write(json.dumps(dict_eval))
79 | f_out.write('\n')
80 |
81 | return score_eval, dev_logits
82 |
83 | def test_eval(config, model, batcher):
84 | '''
85 | Evaluates the accuracy on the test partition
86 |
87 | :param config:
88 | :param model:
89 | :param batcher:
90 | '''
91 |
92 | model.eval()
93 | dataset_reader = batcher.get_dataset_reader()
94 | test_writer = Writer(os.path.join(config.exp_dir, "test.json"), dataset_reader)
95 |
96 | with torch.no_grad():
97 | #added
98 | pred_labels = []
99 | pred_logits = []
100 | t0 = time.time()
101 | for idx, batch in enumerate(batcher.get_test_batch()):
102 | t1 = time.time()
103 | pred_lbl, lbl_logits = model.predict(batch)
104 |
105 | #lbl_logits = lbl_logits.cpu().numpy()
106 |
107 |
108 | pred_labels.extend(pred_lbl.cpu().numpy().tolist())
109 | pred_logits.extend(lbl_logits.cpu().numpy().tolist())
110 |
111 | list_idx = batch["input"]["idx"] if isinstance(batch["input"]["idx"], list) else batch["input"][
112 | "idx"].cpu().numpy().tolist()
113 | list_lbl = batch["output"]["true_lbl"] if "true_lbl" in batch["output"] else batch["output"]["lbl"]
114 |
115 | if config.dataset.lower() == 'fewglue/record':
116 | list_idx = batch["input"]["qas_idx"]
117 | list_lbl = batch["input"]["candidate_entity"]
118 | test_writer.add_batch(list_idx, pred_lbl, list_lbl, lbl_logits.cpu().numpy())
119 | else:
120 | test_writer.add_batch(list_idx, pred_lbl, list_lbl, lbl_logits.cpu().numpy())
121 |
122 | #added
123 | t2 = time.time()
124 | diff1 = t1-t0
125 | diff2 = t2-t1
126 | diff3 = t2-t0
127 | #json_dict = {'start_loop':diff1, 'inside_loop':diff2, 'once_through':diff3}
128 | #writefile = 'time_difference/'+config.exp_dir
129 | #if not os.path.exists(writefile):
130 | # os.makedirs(writefile)
131 | #print(writefile)
132 | #writefile = writefile+'time.json'
133 | #with open(writefile, "a") as f:
134 | # f.write(json.dumps(json_dict)+ '\n')
135 | t3 = time.time()
136 | print('total inference time: {}'.format(t3-t0))
137 | #altered
138 | #print(pred_logits)
139 | test_writer.flush_file()
140 | return pred_labels, np.array(pred_logits)
141 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/run_pretrained.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import numpy as np
5 | from transformers import *
6 |
7 | from src.data.Batcher import Batcher
8 | from src.utils.Config import Config
9 | from src.utils.util import device, ParseKwargs
10 | from src.adapet import adapet
11 | from src.eval.eval_model import dev_eval
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('-m', "--model_dir", required=True)
16 | parser.add_argument('-c', "--config_file", required=True)
17 | parser.add_argument('-k', '--kwargs', nargs='*', action=ParseKwargs, default={})
18 | args = parser.parse_args()
19 |
20 | config = Config(args.config_file, args.kwargs, mkdir=True)
21 |
22 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
23 | batcher = Batcher(config, tokenizer, config.dataset)
24 | dataset_reader = batcher.get_dataset_reader()
25 |
26 | model = adapet(config, tokenizer, dataset_reader).to(device)
27 | model.load_state_dict(torch.load(os.path.join(args.model_dir, "best_model.pt")))
28 | dev_eval(config, model, batcher, 0)
29 |
30 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/scripts/example_convert_sst_2_generic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 |
5 | def read_tsv(filename, max_num_lines):
6 | train_json_filename = filename.replace(".tsv", ".jsonl")
7 |
8 | with open(filename, 'r') as f_in, open(train_json_filename, 'w+') as f_out:
9 | f_in.readline()
10 | for idx, line in enumerate(f_in.readlines()):
11 | if idx < max_num_lines:
12 | tab_split = line.strip('\n').split('\t')
13 | dict_json = {"TEXT1": tab_split[0], "LBL": tab_split[1]}
14 |
15 | f_out.write(json.dumps(dict_json) + '\n')
16 | else:
17 | break
18 |
19 |
20 | if __name__ == "__main__":
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('-f', "--filepath", required=True)
23 | parser.add_argument('-m', "--max_num_lines", type=int, default=32)
24 | args = parser.parse_args()
25 |
26 | read_tsv(args.filepath, args.max_num_lines)
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/test.py:
--------------------------------------------------------------------------------
1 | #import argparse
2 | import os
3 | import torch
4 | import numpy as np
5 | from transformers import *
6 |
7 | from src.data.Batcher import Batcher
8 | from src.utils.Config import Config
9 | #altered
10 | from src.utils.util import device
11 | from src.adapet import adapet
12 | from src.eval.eval_model import test_eval
13 |
14 |
15 | def do_test(exp_dir):
16 | #device = torch.device("cpu")
17 | config_file = os.path.join(exp_dir, "config.json")
18 | config = Config(config_file, mkdir=False)
19 |
20 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
21 | batcher = Batcher(config, tokenizer, config.dataset)
22 | dataset_reader = batcher.get_dataset_reader()
23 |
24 | model = adapet(config, tokenizer, dataset_reader).to(device)
25 | model.load_state_dict(torch.load(os.path.join(exp_dir, "best_model.pt")))
26 | #altered
27 | pred_labels, pred_logits = test_eval(config, model, batcher)
28 |
29 | return pred_labels, pred_logits
30 |
31 | #if __name__ == "__main__":
32 | #parser = argparse.ArgumentParser()
33 | #parser.add_argument('-e', "--exp_dir", required=True)
34 |
35 | #args = parser.parse_args()
36 |
37 | '''
38 | config_file = os.path.join(args.exp_dir, "config.json")
39 | config = Config(config_file, mkdir=False)
40 |
41 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
42 | batcher = Batcher(config, tokenizer, config.dataset)
43 | dataset_reader = batcher.get_dataset_reader()
44 |
45 | model = adapet(config, tokenizer, dataset_reader).to(device)
46 | model.load_state_dict(torch.load(os.path.join(args.exp_dir, "best_model.pt")))
47 | test_eval(config, model, batcher)
48 | '''
49 |
50 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import re
4 | import numpy as np
5 | import argparse
6 | import logging
7 | os.environ["WANDB_DISABLED"] = "true"
8 | from transformers import *
9 |
10 | from src.eval.eval_model import dev_eval
11 | from src.adapet import adapet
12 | from torch.optim.lr_scheduler import LambdaLR
13 |
14 | from src.data.Batcher import Batcher
15 | from src.utils.Config import Config
16 | from src.utils.util import get_avg_dict_val_store, update_dict_val_store, ParseKwargs
17 | from src.utils.util import set_global_logging_level, device
18 |
19 |
20 |
21 | set_global_logging_level(logging.ERROR)
22 |
23 | # From HuggingFace
24 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
25 | """
26 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
27 | after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
28 |
29 | Args:
30 | optimizer (:class:`~torch.optim.Optimizer`):
31 | The optimizer for which to schedule the learning rate.
32 | num_warmup_steps (:obj:`int`):
33 | The number of steps for the warmup phase.
34 | num_training_steps (:obj:`int`):
35 | The total number of training steps.
36 | last_epoch (:obj:`int`, `optional`, defaults to -1):
37 | The index of the last epoch when resuming training.
38 |
39 | Return:
40 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
41 | """
42 |
43 | def lr_lambda(current_step: int):
44 | if current_step < num_warmup_steps:
45 | return float(current_step) / float(max(1, num_warmup_steps))
46 | return max(
47 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
48 | )
49 |
50 | return LambdaLR(optimizer, lr_lambda, last_epoch)
51 |
52 |
53 | def train(config):
54 | '''
55 | Trains the model
56 |
57 | :param config:
58 | :return:
59 | '''
60 |
61 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
62 | batcher = Batcher(config, tokenizer, config.dataset)
63 | dataset_reader = batcher.get_dataset_reader()
64 | model = adapet(config, tokenizer, dataset_reader).to(device)
65 |
66 | ### Create Optimizer
67 | # Ignore weight decay for certain parameters
68 | no_decay_param = ['bias', 'LayerNorm.weight']
69 | optimizer_grouped_parameters = [
70 | {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay_param)],
71 | 'weight_decay': config.weight_decay,
72 | 'lr': config.lr},
73 | {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay_param)],
74 | 'weight_decay': 0.0,
75 | 'lr': config.lr},
76 | ]
77 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-8)
78 |
79 | #altered
80 | #best_dev_acc = 0
81 | best_dev_acc = -float('inf')
82 | train_iter = batcher.get_train_batch()
83 | dict_val_store = None
84 |
85 | # Number of batches is assuming grad_accumulation_factor forms one batch
86 | tot_num_batches = config.num_batches * config.grad_accumulation_factor
87 |
88 | # Warmup steps and total steps are based on batches, not epochs
89 | num_warmup_steps = config.num_batches * config.warmup_ratio
90 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, config.num_batches)
91 |
92 | for i in range(tot_num_batches):
93 | # Get true batch_idx
94 | batch_idx = (i // config.grad_accumulation_factor)
95 |
96 | model.train()
97 | sup_batch = next(train_iter)
98 | loss, dict_val_update = model(sup_batch)
99 | loss = loss / config.grad_accumulation_factor
100 | loss.backward()
101 |
102 | if (i+1) % config.grad_accumulation_factor == 0:
103 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
104 | optimizer.step()
105 | optimizer.zero_grad()
106 | scheduler.step()
107 |
108 | dict_val_store = update_dict_val_store(dict_val_store, dict_val_update, config.grad_accumulation_factor)
109 | print("Finished %d batches" % batch_idx, end='\r')
110 |
111 | if (batch_idx + 1) % config.eval_every == 0 and i % config.grad_accumulation_factor == 0:
112 | dict_avg_val = get_avg_dict_val_store(dict_val_store, config.eval_every)
113 | dict_val_store = None
114 | dev_acc, dev_logits = dev_eval(config, model, batcher, batch_idx, dict_avg_val)
115 | #altered but not used
116 | if type(dev_acc) == str:
117 | f1s = re.findall(r"[-+]?\d*\.\d+|\d+", dev_acc)
118 | dev_acc = float(f1s[0])
119 |
120 | print("Global Step: %d Acc: %.3f" % (batch_idx, float(dev_acc)) + '\n')
121 |
122 | if dev_acc > best_dev_acc:
123 | best_dev_acc = dev_acc
124 | torch.save(model.state_dict(), os.path.join(config.exp_dir, "best_model.pt"))
125 | with open(os.path.join(config.exp_dir, "dev_logits.npy"), 'wb') as f:
126 | np.save(f, dev_logits)
127 |
128 |
129 | if __name__ == "__main__":
130 | parser = argparse.ArgumentParser()
131 | parser.add_argument('-c', "--config_file", required=True)
132 | parser.add_argument('-k', '--kwargs', nargs='*', action=ParseKwargs, default={})
133 | args = parser.parse_args()
134 |
135 | config = Config(args.config_file, args.kwargs, mkdir=True)
136 | train(config)
137 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/utils/Config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import ast
4 |
5 | from src.utils.util import make_exp_dir
6 |
7 | class Config(object):
8 | def __init__(self, filename=None, kwargs=None, mkdir=True):
9 | # Dataset parameters
10 | self.dataset = "fewglue/BoolQ"
11 | self.num_lbl = 2
12 | self.max_num_lbl_tok = 10
13 | self.max_num_lbl = 10
14 |
15 | # Model and pattern parameters
16 | self.pretrained_weight = "bert-base-uncased"
17 | self.pattern_idx = "random"
18 |
19 | # Duration of training parameters
20 | self.batch_size = 8
21 | self.eval_batch_size = 64
22 | self.num_batches = 1000
23 | self.eval_every = 1
24 | self.grad_accumulation_factor = 1
25 | self.max_text_length = 64
26 |
27 | self.mask_alpha = 0.5
28 |
29 | self.eval_train = False
30 | self.eval_dev = True
31 |
32 | # Where experiments will be located
33 | self.exp_dir = None
34 | self.seed = 42
35 | self.exp_name = ""
36 |
37 | # Training Hyperparameters
38 | self.lr = 1e-3
39 | self.weight_decay = 0
40 | self.grad_clip_norm = 1
41 | self.warmup_ratio = 0
42 |
43 | # Generic dataset hyperparameters
44 | self.pattern = "[TEXT1] and [TEXT2] "
45 | self.idx_txt_trim = -1 # Indexed from 1
46 | self.dict_verbalizer = {"True": "Yes", "False": "No"}
47 | self.data_dir = "data/fewglue/BoolQ"
48 | #Added
49 | self.task_name = 'SetFit/sst2'
50 |
51 | if filename:
52 | self.__dict__.update(json.load(open(filename)))
53 | if kwargs:
54 | self.update_kwargs(kwargs)
55 |
56 | if filename or kwargs:
57 | self.update_exp_config(mkdir)
58 |
59 | def update_kwargs(self, kwargs):
60 | for (k, v) in kwargs.items():
61 | try:
62 | v = ast.literal_eval(v)
 63 | except Exception:
 64 | pass  # keep the raw string when it is not a valid Python literal
65 | setattr(self, k, v)
66 |
67 | def update_exp_config(self, mkdir=True):
68 | '''
69 | Updates the config default values based on parameters passed in from config file
70 | '''
71 |
72 |
73 | self.base_dir = os.path.join("exp_out", self.dataset, self.pretrained_weight, self.task_name)
74 | if self.exp_name != "":
75 | self.base_dir = os.path.join(self.base_dir, self.exp_name)
76 |
77 | if mkdir:
78 | self.exp_dir = make_exp_dir(self.base_dir)
79 |
80 | if self.exp_dir is not None:
81 | self.dev_pred_file = os.path.join(self.exp_dir, "dev_pred.txt")
82 | self.dev_score_file = os.path.join(self.exp_dir, "dev_scores.json")
83 | self.test_score_file = os.path.join(self.exp_dir, "test_scores.json")
 84 | self.save_config(os.path.join(self.exp_dir, "config.json"))
85 |
86 | def to_json(self):
87 | '''
88 | Converts parameter values in config to json
89 | :return: json
90 | '''
91 | #altered -- ensure ascii now == False
92 | return json.dumps(self.__dict__, indent=4, sort_keys=True, ensure_ascii=False)
93 |
94 | def save_config(self, filename):
95 | '''
96 | Saves the config
97 | '''
98 | with open(filename, 'w+') as fout:
99 | fout.write(self.to_json())
100 | fout.write('\n')
101 |
--------------------------------------------------------------------------------
/scripts/adapet/ADAPET/src/utils/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datetime
3 | import os
4 | import sys
5 | import argparse
6 | import subprocess
7 | from shutil import copytree, ignore_patterns
8 | import random
9 | import numpy as np
10 | import logging
11 | import re
13 |
14 | global device; device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 |
16 | def set_global_logging_level(level=logging.ERROR, prefices=[""]):
17 | """
18 | Override logging levels of different modules based on their name as a prefix.
19 | It needs to be invoked after the modules have been loaded so that their loggers have been initialized.
20 |
21 | Args:
22 | - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
23 | - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
24 | Default is `[""]` to match all active loggers.
25 | The match is a case-sensitive `module_name.startswith(prefix)`
26 | """
27 | prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
28 | for name in logging.root.manager.loggerDict:
29 | if re.match(prefix_re, name):
30 | logging.getLogger(name).setLevel(level)
31 |
32 |
33 | def set_seeds(seed):
34 | "set random seeds"
35 | random.seed(seed)
36 | np.random.seed(seed)
37 | torch.manual_seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 |
40 | def make_dir(dir_name):
41 | '''
42 |     Makes a directory if it doesn't exist yet
43 | Args:
44 | dir_name: directory name
45 | '''
46 | if not os.path.exists(dir_name):
47 | os.makedirs(dir_name)
48 |
49 | def make_exp_dir(base_exp_dir):
50 | '''
51 | Makes an experiment directory with timestamp
52 | Args:
53 |         base_exp_dir: base output directory name
54 | Returns:
55 | exp_dir_name: experiment directory name
56 | '''
57 | now = datetime.datetime.now()
58 | ts = "{:04d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}".format(now.year, now.month, now.day, now.hour, now.minute,
59 | now.second)
60 | exp_dir_name = os.path.join(base_exp_dir, ts)
61 | make_dir(exp_dir_name)
62 |
63 | src_file = os.path.join(exp_dir_name, 'src')
64 |
65 | #copytree(os.path.join(os.environ['ADAPET_ROOT'], "src"), src_file, ignore=ignore_patterns('*.pyc', 'tmp*'))
66 |
67 | return exp_dir_name
68 |
69 | def print_mem_usage(loc):
70 | '''
71 | Print memory usage in GB
72 | :return:
73 | '''
74 | print("%s mem usage: %.3f GB, %.3f GB, %.3f GB" % (loc, float(torch.cuda.memory_allocated() / 1e9), float(torch.cuda.memory_reserved() / 1e9), float(torch.cuda.max_memory_allocated() / 1e9)))
75 | sys.stdout.flush()
76 |
77 | class ParseKwargs(argparse.Action):
78 | def __call__(self, parser, namespace, values, option_string=None):
79 | setattr(namespace, self.dest, dict())
80 | for value in values:
81 | key, value = value.split('=')
82 | getattr(namespace, self.dest)[key] = value
83 |
84 |
85 | def update_dict_val_store(dict_val_store, dict_update_val, grad_accumulation_factor):
86 | '''
87 | Update dict_val_store with dict_update_val
88 |
89 | :param dict_val_store:
90 | :param dict_update_val:
91 | :return:
92 | '''
93 | if dict_val_store is None:
94 | dict_val_store = dict_update_val
95 | else:
96 | for k in dict_val_store.keys():
97 | dict_val_store[k] += dict_update_val[k] / grad_accumulation_factor
98 |
99 | return dict_val_store
100 |
101 | def get_avg_dict_val_store(dict_val_store, num_batches=100):
102 | '''
103 | Get average dictionary val
104 |
105 | :param dict_val_store:
106 |     :param num_batches:
107 | :return:
108 | '''
109 | dict_avg_val = {}
110 |
111 | for k in dict_val_store.keys():
112 | dict_avg_val[k] = float('%.3f' % (dict_val_store[k].detach().cpu().item() / num_batches))
113 |
114 | return dict_avg_val
115 |
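116 | # Illustrative note: with CLI kwargs such as `-k lr=1e-4 eval_dev=True`,
117 | # ParseKwargs yields {"lr": "1e-4", "eval_dev": "True"}; Config.update_kwargs
118 | # then coerces these strings into Python literals via ast.literal_eval.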
--------------------------------------------------------------------------------
/scripts/create_summary_table.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import tarfile
5 | from collections import defaultdict
6 | from glob import glob
7 | from os import listdir
8 | from os.path import isdir, join, splitext
9 | from typing import List, Tuple
10 |
11 | from numpy import mean, median, std
12 | from scipy.stats import iqr
13 |
14 |
15 | """
16 | To run: python create_summary_table.py --path scripts/{method_name}/results/{model_name}
17 | or: python create_summary_table.py --path scripts/{method_name}/{model_name}.tar.gz
18 | Output files are written to the results directory.
19 | """
20 |
21 | TEST_DATASET_TO_METRIC = {
22 | "emotion": "accuracy",
23 | "SentEval-CR": "accuracy",
24 | "sst5": "accuracy",
25 | "ag_news": "accuracy",
26 | "enron_spam": "accuracy",
27 | "amazon_counterfactual_en": "matthews_correlation",
28 | }
29 |
30 |
31 | def extract_results(path: str) -> None:
32 | tar = tarfile.open(path, "r:gz")
33 | unzip_path = splitext(splitext(path)[-2])[-2]
34 | tar.extractall(path=os.path.dirname(unzip_path))
35 | tar.close()
36 | return unzip_path
37 |
38 |
39 | def get_sample_sizes(path: str) -> List[int]:
40 | return sorted(list({int(name.split("-")[-2]) for name in glob(f"{path}/*/train-*-0")}))
41 |
42 |
43 | def get_tfew_sample_sizes(path: str) -> List[int]:
44 | return sorted(list({int(name.split("-")[-2]) for name in glob(f"{path}/train-*-0/seed0")}))
45 |
46 |
47 | def compute_tfew_medians(results_path: str) -> None:
48 |     """Given per-split and per-seed T-Few results for multiple datasets,
49 | calculates the median score and interquartile range across all seeds,
50 | and saves them to a `results.json` file in the same path.
51 |
52 | Args:
53 | results_path: path to T-Few results: `/setfit/scripts/tfew/results/t03b_pretrained`
54 | """
55 |
56 | for dataset in listdir(results_path):
57 | dataset_path = join(results_path, dataset)
58 | if isdir(dataset_path):
59 | dataset_metric = TEST_DATASET_TO_METRIC[dataset]
60 | sample_sizes = get_tfew_sample_sizes(dataset_path)
61 |
62 | for sample_size in sample_sizes:
63 | split_dirs = sorted(glob(join(dataset_path, f"train-{sample_size}-*")))
64 | assert split_dirs is not None
65 |
66 | for split_dir in split_dirs:
67 | seed_results_json = sorted(glob(join(split_dir, "seed*/dev_scores.json")))
68 | seed_metrics = []
69 | for seed_result_json in seed_results_json:
70 | with open(seed_result_json) as f:
71 | result_dict = json.loads(f.readlines()[-1])
72 | seed_metrics.append(result_dict[dataset_metric] * 100)
73 |
74 | with open(join(split_dir, "results.json"), "w") as f:
75 | json.dump(
76 | {"score": median(seed_metrics), "measure": dataset_metric, "iqr": iqr(seed_metrics)}, f
77 | )
78 |
79 |
80 | def get_formatted_ds_metrics(path: str, dataset: str, sample_sizes: List[int]) -> Tuple[str, List[str], dict, dict, List[int]]:
81 | formatted_row = []
82 | metric_name = ""
83 | exact_metrics, exact_stds = {}, {}
84 |
85 | for sample_size in sample_sizes:
86 | result_jsons = sorted(glob(os.path.join(path, dataset, f"train-{sample_size}-*", "results.json")))
87 | split_metrics = []
88 |
89 | for result_json in result_jsons:
90 | with open(result_json) as f:
91 | result_dict = json.load(f)
92 |
93 | metric_name = result_dict.get("measure", "N/A")
94 | split_metrics.append(result_dict["score"])
95 |
96 | exact_metrics[sample_size] = mean(split_metrics)
97 | exact_stds[sample_size] = std(split_metrics)
98 | formatted_row.extend([f"{exact_metrics[sample_size]:.1f}", f"{exact_stds[sample_size]:.1f}"])
99 |
100 | return metric_name, formatted_row, exact_metrics, exact_stds, sample_sizes
101 |
102 |
103 | def create_summary_table(results_path: str) -> None:
104 | """Given per-split results, creates a summary table of all datasets,
105 | with average metrics and standard deviations.
106 |
107 | Args:
108 |         results_path: path to per-split results: either `scripts/{method_name}/results/{model_name}`,
109 | or `final_results/{method_name}/{model_name}.tar.gz`
110 | """
111 |
112 | if results_path.endswith("tar.gz"):
113 | unzipped_path = extract_results(results_path)
114 | else:
115 | unzipped_path = results_path
116 |
117 | if "tfew" in unzipped_path:
118 | print("Computing medians for T-Few...")
119 | compute_tfew_medians(unzipped_path)
120 |
121 | sample_sizes = get_sample_sizes(unzipped_path)
122 | header_row = ["dataset", "measure"]
123 | for sample_size in sample_sizes:
124 | header_row.append(f"{sample_size}_avg")
125 | header_row.append(f"{sample_size}_std")
126 |
127 | csv_lines = [header_row]
128 |
129 | means, stds = defaultdict(list), defaultdict(list)
130 | for dataset in next(os.walk(unzipped_path))[1]:
131 | metric_name, formatted_metrics, exact_metrics, exact_stds, sample_sizes = get_formatted_ds_metrics(
132 | unzipped_path, dataset, sample_sizes
133 | )
134 | dataset_row = [dataset, metric_name, *formatted_metrics]
135 | csv_lines.append(dataset_row)
136 |
137 | # Collect exact metrics for overall average and std calculation
138 | for sample_size in sample_sizes:
139 | means[sample_size].append(exact_metrics[sample_size])
140 | stds[sample_size].append(exact_stds[sample_size])
141 |
142 | # Generate row for overall average
143 | formatted_average_row = []
144 | for sample_size in sample_sizes:
145 | overall_average = mean(means[sample_size])
146 | overall_std = mean(stds[sample_size])
147 | formatted_average_row.extend([f"{overall_average:.1f}", f"{overall_std:.1f}"])
148 | csv_lines.append(["Average", "N/A", *formatted_average_row])
149 |
150 | output_path = os.path.join(unzipped_path, "summary_table.csv")
151 | print("=" * 80)
152 | print("Summary table:\n")
153 | with open(output_path, "w") as f:
154 | for line in csv_lines:
155 | f.write(",".join(line) + "\n")
156 | print(", ".join(line))
157 | print("=" * 80)
158 | print(f"Saved summary table to {output_path}")
159 |
160 |
161 | def main() -> None:
162 | parser = argparse.ArgumentParser()
163 |
164 | parser.add_argument("--path", type=str)
165 | args = parser.parse_args()
166 |
167 | create_summary_table(args.path)
168 |
169 |
170 | if __name__ == "__main__":
171 | main()
172 |
--------------------------------------------------------------------------------
/scripts/perfect/README.md:
--------------------------------------------------------------------------------
1 | # Running PERFECT
2 |
3 | Follow the steps below to run the baselines based on the `PERFECT` paper: [_PERFECT: Prompt-free and Efficient Few-shot Learning with Language Models_](https://arxiv.org/abs/2204.01172).
4 |
5 | ## Setup
6 |
7 | To get started, first create a Python virtual environment, e.g. with `conda`:
8 |
9 | ```
10 | conda create -n baselines-perfect python=3.10 && conda activate baselines-perfect
11 | ```
12 |
13 | Next, clone [our fork](https://github.com/SetFit/perfect) of the [`PERFECT` codebase](https://github.com/facebookresearch/perfect), and install the required dependencies:
14 |
15 | ```
16 | git clone https://github.com/SetFit/perfect.git
17 | cd perfect
18 | python setup.py develop
19 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
20 | python -m pip install -r requirements.txt
21 | ```
22 |
23 | Next, download and process the datasets:
24 |
25 | ```
26 | wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar
27 | tar -xvf datasets.tar
28 | mv original/ datasets
29 | python fewshot/process_datasets.py
30 | ```
31 |
32 | ## Usage example
33 |
34 | To train and evaluate `PERFECT` on 8 and 64 examples (per class) across all the SetFit test datasets, run:
35 |
36 | ```
37 | cd fewshot/
38 | bash scripts/run_setfit_baselines.sh
39 | ```
--------------------------------------------------------------------------------
/scripts/plot_summary_comparison.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import string
5 | import sys
6 | from collections import defaultdict
7 | from glob import glob
8 | from pathlib import Path
9 | from typing import List, Tuple
10 |
11 | import matplotlib.pyplot as plt
12 | import pandas as pd
13 |
14 |
15 | """
16 | To run:
17 | python plot_summary_comparison.py --paths scripts/{method_name}/results/{model_name}
18 | Multiple paths can be provided. The produced plots are outputted to scripts/images/v_{id}/{dataset}.png.
19 |
20 | See https://github.com/huggingface/setfit/pull/268#issuecomment-1434549208 for an example of the plots
21 | produced by this script.
22 | """
23 |
24 |
25 | def get_sample_sizes(path: str) -> List[int]:
26 | return sorted(list({int(name.split("-")[-2]) for name in glob(f"{path}/*/train-*-0")}))
27 |
28 |
29 | def get_formatted_ds_metrics(path: str, dataset: str, sample_sizes: List[int]) -> Tuple[str, dict]:
30 | split_metrics = defaultdict(list)
31 |
32 | for sample_size in sample_sizes:
33 | result_jsons = sorted(glob(os.path.join(path, dataset, f"train-{sample_size}-*", "results.json")))
34 | for result_json in result_jsons:
35 | with open(result_json) as f:
36 | result_dict = json.load(f)
37 |
38 | metric_name = result_dict.get("measure", "N/A")
39 | split_metrics[sample_size].append(result_dict["score"])
40 |
41 | return metric_name, split_metrics
42 |
43 |
44 | def plot_summary_comparison(paths: List[str]) -> None:
45 | """Given a list of paths to output directories produced by e.g. `scripts/setfit/run_fewshot.py`,
46 | produce and save boxplots that compare the various results.
47 |
48 | The plots are saved to scripts/images/v_{id}/{dataset}.png, i.e. one plot per dataset.
49 |
50 | Args:
51 | paths (List[str]): List of paths to output directories, generally
52 | `scripts/{method_name}/results/{model_name}`
53 | """
54 |
55 | # Parse the result paths
56 | dataset_to_df = defaultdict(pd.DataFrame)
57 | dataset_to_metric = {}
58 | for path_index, path in enumerate(paths):
59 | ds_to_metric, this_dataset_to_df = get_summary_df(path)
60 | for dataset, df in this_dataset_to_df.items():
61 | df["path_index"] = path_index
62 | dataset_to_df[dataset] = pd.concat((dataset_to_df[dataset], df))
63 | dataset_to_metric = dataset_to_metric | ds_to_metric
64 |
65 | # Prepare folder for storing figures
66 | image_dir = Path("scripts") / "images"
67 | image_dir.mkdir(exist_ok=True)
68 | new_version = (
69 | max([int(path.name[2:]) for path in image_dir.glob("v_*/") if path.name[2:].isdigit()], default=0) + 1
70 | )
71 | output_dir = image_dir / f"v_{new_version}"
72 | output_dir.mkdir()
73 |
74 |     # Save a copy of the executed command in the output directory
75 | (output_dir / "command.txt").write_text("python " + " ".join(sys.argv))
76 |
77 |     # Create one plot per dataset
78 | for dataset, df in dataset_to_df.items():
79 | columns = [column for column in df.columns if not column.startswith("path")]
80 | fig, axes = plt.subplots(ncols=len(columns), sharey=True)
81 | for column_index, column in enumerate(columns):
82 | ax = axes[column_index] if len(columns) > 1 else axes
83 |
84 | # Set the y label only for the first column
85 | if column_index == 0:
86 | ax.set_ylabel(dataset_to_metric[dataset])
87 |
88 | # Set positions to 0, 0.25, ..., one position per boxplot
89 | # This places the boxplots closer together
90 | n_boxplots = len(df["path_index"].unique())
91 | allotted_box_width = 0.2
92 | positions = [allotted_box_width * i for i in range(n_boxplots)]
93 | ax.set_xlim(-allotted_box_width * 0.75, allotted_box_width * (n_boxplots - 0.25))
94 |
95 | df[[column, "path_index"]].groupby("path_index", sort=True).boxplot(
96 | subplots=False, ax=ax, column=column, positions=positions
97 | )
98 |
99 | k_shot = column.split("-")[-1]
100 | ax.set_xlabel(f"{k_shot}-shot")
101 | if n_boxplots > 1:
102 | # If there are multiple boxplots, override the labels at the bottom generated by pandas
103 | if n_boxplots <= 26:
104 | ax.set_xticklabels(string.ascii_uppercase[:n_boxplots])
105 | else:
106 | ax.set_xticklabels(range(n_boxplots))
107 | else:
108 | # Otherwise, just remove the xticks
109 | ax.tick_params(labelbottom=False)
110 |
111 | if n_boxplots > 1:
112 | fig.suptitle(
113 | f"Comparison between various baselines on the {dataset}\ndataset under various $K$-shot conditions"
114 | )
115 | else:
116 | fig.suptitle(f"Results on the {dataset} dataset under various $K$-shot conditions")
117 | fig.tight_layout()
118 | plt.savefig(str(output_dir / dataset))
119 |
120 |
121 | def get_summary_df(path: str) -> Tuple[dict, dict]:
122 | """Given per-split results, return a mapping from dataset to metrics (e.g. "accuracy") and
123 | a mapping from dataset to pandas DataFrame that stores the results
124 |
125 | Args:
126 | path: path to per-split results: generally `scripts/{method_name}/results/{model_name}`,
127 | """
128 |
129 | sample_sizes = get_sample_sizes(path)
130 | header_row = ["dataset", "measure"]
131 | for sample_size in sample_sizes:
132 | header_row.append(f"{sample_size}_avg")
133 | header_row.append(f"{sample_size}_std")
134 |
135 | dataset_to_metric = {}
136 | dataset_to_df = {}
137 | for dataset in next(os.walk(path))[1]:
138 | metric_name, split_metrics = get_formatted_ds_metrics(path, dataset, sample_sizes)
139 | dataset_df = pd.DataFrame(split_metrics.values(), index=[f"{dataset}-{key}" for key in split_metrics]).T
140 | dataset_to_metric[dataset] = metric_name
141 | dataset_to_df[dataset] = dataset_df
142 | return dataset_to_metric, dataset_to_df
143 |
144 |
145 | def main() -> None:
146 | parser = argparse.ArgumentParser()
147 |
148 | parser.add_argument("--paths", nargs="+", type=str)
149 | args = parser.parse_args()
150 |
151 | if args.paths:
152 | plot_summary_comparison(args.paths)
153 | else:
154 | raise Exception("Please provide at least one path via the `--paths` CLI argument.")
155 |
156 |
157 | if __name__ == "__main__":
158 | main()
159 |
--------------------------------------------------------------------------------
/scripts/setfit/README.md:
--------------------------------------------------------------------------------
1 | # Running SetFit
2 |
3 | ## Setup
4 |
5 | To run the scripts, first create a Python virtual environment, e.g. with `conda`:
6 |
7 | ```
8 | conda create -n baselines-setfit python=3.9 && conda activate baselines-setfit
9 | ```
10 |
11 | Next, install the required dependencies:
12 |
13 | ```
14 | python -m pip install setfit
15 | ```
16 |
17 | ## Usage
18 |
19 | To train and evaluate `SetFit` on 8 examples (per class) on the `sst2` dataset, run:
20 |
21 | ```
22 | python run_fewshot.py --sample_sizes=8 --datasets=sst2
23 | ```
24 |
25 | This will use the default settings used in the paper, including `paraphrase-mpnet-base-v2` as the backbone model. Results will be saved in the `results` directory. To run `SetFit` across all the development datasets used in the paper, run:
26 |
27 | ```
28 | python run_fewshot.py --sample_sizes=8 --is_dev_set=true
29 | ```
30 |
31 | Similarly, you can run `SetFit` over all the test datasets in the paper by running:
32 |
33 | ```
34 | python run_fewshot.py --sample_sizes=8 --is_test_set=true
35 | ```
36 |
37 | ### Exhaustive example
38 |
39 | The following is an example with all argument options and their default values.
40 | Note that you can run on a series of datasets and sample sizes:
41 |
42 | ```
43 | python run_fewshot.py \
44 | --model paraphrase-mpnet-base-v2 \
45 | --datasets sst2 ag_news bbc-news \
46 | --sample_sizes 8 64 \
47 | --num_epochs 1 \
48 | --num_iterations 20 \
49 | --batch_size 16 \
50 | --max_seq_length 256 \
51 | --classifier logistic_regression \
52 | --loss CosineSimilarityLoss \
53 | --exp_name "" \
54 | --add_normalization_layer \
55 | ```
56 |
57 | ### Multilingual experiments
58 |
59 | We provide three different ways to run `SetFit` in multilingual settings:
60 |
61 | * `each`: train on data in target language
62 | * `en`: train on English data only
63 | * `all`: train on data in all languages
64 |
65 | To train `SetFit` in one of these settings, run:
66 |
67 | ```
68 | python run_fewshot_multilingual.py \
69 | --model sentence-transformers/paraphrase-multilingual-mpnet-base-v2 \
70 | --datasets amazon_reviews_multi_de amazon_reviews_multi_es \
71 | --sample_sizes 8 \
72 | --multilinguality=each
73 | ```
74 |
75 | To train `SetFit` on all the multilingual test sets in the paper, run:
76 |
77 | ```
78 | python run_fewshot_multilingual.py \
79 | --model=sentence-transformers/paraphrase-multilingual-mpnet-base-v2 \
80 | --multilinguality=each
81 | ```
82 |
83 | ### Multilabel experiments
84 |
85 | To run `SetFit` on one of our multilabel datasets, run:
86 |
87 | ```
88 | python run_fewshot_multilabel.py \
89 | --sample_sizes=8 64 \
90 | --datasets=go_emotions
91 | ```
92 |
93 | # Zero-shot Text Classification with SetFit
94 | Although `SetFit` was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create synthetic examples that resemble the classification task, and then train a `SetFit` model on them.
95 |
96 | Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and generates predictions around 5x (or more) faster!
97 |
98 | To create the synthetic training examples, the label names for the task are required.
99 | The labels can be taken from a `--reference_dataset` or supplied explicitly using `--candidate_labels`.
100 | If neither is supplied, `--eval_dataset` is used as the reference dataset. A minimal Python sketch of this approach is included at the end of this section.
101 |
102 | To evaluate zero-shot `SetFit` on the `emotion` dataset, run:
103 |
104 | ```
105 | python run_zeroshot.py --eval_dataset=SetFit/emotion
106 | ```
107 |
108 |
109 | To evaluate on a custom dataset with custom label names:
110 |
111 | ```
112 | python run_zeroshot.py --eval_dataset=[dataset_name] --candidate_labels [label_1 label_2 ...]
113 | ```
114 |
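115 | For reference, the core of this zero-shot trick can be sketched in a few lines with the `setfit` library itself. The labels, template and checkpoint below are illustrative choices rather than the script's defaults:
116 |
117 | ```
118 | from datasets import load_dataset
119 | from setfit import SetFitModel, Trainer, get_templated_dataset
120 |
121 | # Build synthetic examples such as "This sentence is joy", one batch per label
122 | train_dataset = get_templated_dataset(
123 |     candidate_labels=["sadness", "joy", "love", "anger", "fear", "surprise"],
124 |     template="This sentence is {}",
125 | )
126 | eval_dataset = load_dataset("SetFit/emotion", split="test")
127 |
128 | model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
129 | trainer = Trainer(model=model, train_dataset=train_dataset, eval_dataset=eval_dataset)
130 | trainer.train()
131 | print(trainer.evaluate())
132 | ```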
--------------------------------------------------------------------------------
/scripts/setfit/distillation_baseline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from datasets import Dataset
5 | from evaluate import load
6 | from transformers import (
7 | AutoModelForSequenceClassification,
8 | AutoTokenizer,
9 | DataCollatorWithPadding,
10 | Trainer,
11 | TrainingArguments,
12 | )
13 |
14 |
15 | class RegressionTrainer(Trainer):
16 | def compute_loss(self, model, inputs, return_outputs=False):
17 | labels = inputs.pop("labels")
18 | outputs = model(**inputs)
19 | logits = outputs.logits
20 | loss = F.mse_loss(logits, labels)
21 | return (loss, outputs) if return_outputs else loss
22 |
23 |
24 | class BaselineDistillation:
25 | def __init__(self, student_model_name, num_epochs, batch_size) -> None:
26 | self.student_model_name = student_model_name
27 | self.num_epochs = num_epochs
28 | self.batch_size = batch_size
29 | self.tokenizer = AutoTokenizer.from_pretrained(student_model_name)
30 | self.seq_len = 64
31 | self.learning_rate = 6e-5
32 |
33 | def update_metric(self, metric):
34 | self.metric = load(metric)
35 | self.metric_name = metric
36 |
37 | def bl_student_preprocess(self, examples):
38 | label = examples["score"]
39 | examples = self.tokenizer(
40 | examples["text"],
41 | truncation=True,
42 | padding="max_length",
43 | max_length=self.seq_len,
44 | )
45 |         # Cast the soft labels to floats so the MSE loss can regress onto them
46 | examples["label"] = [float(i) for i in label]
47 | return examples
48 |
49 | def compute_metrics_for_regression(self, eval_pred):
50 | logits, labels = eval_pred
51 | predictions = np.argmax(logits, axis=-1)
52 | hot_labels = np.argmax(labels, axis=-1)
53 | return self.metric.compute(predictions=predictions, references=hot_labels)
54 |
55 | # ----------------------------------------------------------------#
56 | # ------------------------ Student training ----------------------#
57 | # ----------------------------------------------------------------#
58 | def standard_model_distillation(self, train_raw_student, x_test, y_test, num_classes):
59 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
60 |
61 | value2hot = {}
62 | for i in range(num_classes):
63 | a = [0] * num_classes
64 | a[i] = 1
65 | value2hot.update({i: a})
66 |
67 | test_dict = {"text": x_test, "score": [value2hot[i] for i in y_test]}
68 | raw_test_ds = Dataset.from_dict(test_dict)
69 |
70 | # validation and test sets are the same
71 | ds = {
72 | "train": train_raw_student,
73 | "validation": raw_test_ds,
74 | "test": raw_test_ds,
75 | }
76 | for split in ds:
77 | ds[split] = ds[split].map(self.bl_student_preprocess, remove_columns=["text", "score"])
78 |
79 | training_args = TrainingArguments(
80 | output_dir="baseline_distil_model",
81 | learning_rate=self.learning_rate,
82 | per_device_train_batch_size=self.batch_size,
83 | per_device_eval_batch_size=self.batch_size,
84 | num_train_epochs=self.num_epochs,
85 | eval_strategy="no",
86 | save_strategy="no",
87 | load_best_model_at_end=False,
88 | weight_decay=0.01,
89 | push_to_hub=False,
90 | )
91 |
92 | # define data_collator
93 | data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
94 |
95 | # define student model
96 | student_model = AutoModelForSequenceClassification.from_pretrained(
97 | self.student_model_name, num_labels=num_classes
98 | ).to(device)
99 |
100 | trainer = RegressionTrainer(
101 | student_model,
102 | args=training_args,
103 | train_dataset=ds["train"],
104 | eval_dataset=ds["validation"],
105 | data_collator=data_collator,
106 | tokenizer=self.tokenizer,
107 | compute_metrics=self.compute_metrics_for_regression,
108 | )
109 |
110 | trainer.train()
111 |
112 | trainer.eval_dataset = ds["test"]
113 | # acc = round(trainer.evaluate()["eval_accuracy"], 3)
114 |
115 | score = trainer.evaluate()[f"eval_{self.metric_name}"]
116 | return {self.metric_name: score}
117 |
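118 | # Illustrative usage (assumed, not part of the original script): regress a small
119 | # student onto a teacher's soft label distributions over unlabeled texts.
120 | #
121 | #   bd = BaselineDistillation("distilbert-base-uncased", num_epochs=5, batch_size=16)
122 | #   bd.update_metric("accuracy")
123 | #   train_ds = Dataset.from_dict({"text": unlabeled_texts, "score": teacher_probs})
124 | #   bd.standard_model_distillation(train_ds, x_test, y_test, num_classes=2)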
--------------------------------------------------------------------------------
/scripts/tfew/README.md:
--------------------------------------------------------------------------------
1 | # Running T-Few
2 |
3 | These scripts run the baselines based on the `T-Few` paper: [_Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning_](https://arxiv.org/abs/2205.05638).
4 |
5 | ## Setup
6 |
7 | To run the scripts, first create a Python virtual environment, e.g. with `conda`:
8 |
9 | ```
10 | conda create -n baselines-tfew python=3.10 && conda activate baselines-tfew
11 | ```
12 |
13 | Next, clone our `T-Few` fork, and install the required dependencies:
14 |
15 | ```
16 | cd scripts/tfew
17 | git clone https://github.com/SetFit/t-few.git
18 | python -m pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113
19 | ```
20 |
21 | The steps above only need to be done once. In addition, every time you start a new session, you will need to run:
22 | ```
23 | cd scripts/tfew
24 | . t-few/bin/start.sh
25 | ```
26 | This sets up some required environment variables, including `PYTHONPATH`, `OUTPUT_PATH` (where results will be saved) and `CONFIG_PATH` (where the config `.json` files are stored).
27 | It also sets `CUDA_VISIBLE_DEVICES=0`. To use a different GPU, edit the file `t-few/bin/start.sh`.
28 |
29 | ## Usage example
30 |
31 | To train and evaluate `T-Few` (3B) on 8 examples (per class) on the `emotion` dataset, run:
32 |
33 | ```
34 | python -m t-few.src.pl_train \
35 | -c t03b.json+ia3.json+emotion.json \
36 | -k load_weight="t-few/pretrained_checkpoints/t03b_ia3_finish.pt" \
37 | exp_name=tfew_03b_pretrained/emotion/train-8 \
38 | num_shot=8 \
39 | batch_size=1 \
40 | eval_batch_size=2 \
41 |     grad_accum_factor=8
42 | ```
43 |
44 | This will fine-tune the 3 billion parameter pretrained model using the (IA)^3 method from the `T-Few` paper, and then run the evaluation. For all our baselines, we use the default settings from the `T-Few` paper. `T-Few` comes in two versions: one with a 3 billion (3B) base model, and one with an 11 billion (11B) base model.
45 |
46 | You can run `T-Few` (3B) over all the supported test datasets in the `SetFit` paper by running:
47 |
48 | ```
49 | ./run_tfew_test.sh
50 | ```
51 |
52 | Similarly, to run `T-Few` (11B) over all test datasets, run:
53 |
54 | ```
55 | ./run_tfew_11b.sh
56 | ```
57 |
58 | Results will be saved to the `scripts/tfew/results` directory.
59 | The results consist of 10 directories, one for each training split.
60 | Each of these directories contains 5 results, one for each randomly selected training prompt.
61 | To retrieve the median score across all prompts (for each split), run the following on each dataset:
62 |
63 | ```
64 | python scripts/create_summary_table.py --path scripts/tfew/results/{experiment_name}
65 | ```
66 |
67 | The summary table will be saved in `scripts/tfew/results/{experiment_name}`.
68 |
--------------------------------------------------------------------------------
/scripts/tfew/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.0+cu113
2 | datasets==2.0.0
3 | transformers==4.15.0
4 | pytorch-lightning==1.5.8
5 | torchmetrics==0.6.2
6 | psutil==5.9.0
7 | deepspeed==0.5.10
8 | sentencepiece==0.1.96
9 | scipy
10 | ipdb
11 | evaluate==0.1.2
12 | scikit-learn
13 | git+https://github.com/SetFit/promptsource.git
--------------------------------------------------------------------------------
/scripts/tfew/run_tfew_11b.sh:
--------------------------------------------------------------------------------
1 | for dataset in amazon_counterfactual_en emotion enron_spam SentEval-CR sst5
2 | do
3 | for sample_size in 8 64
4 | do
5 | for train_split in 0 1 2 3 4 5 6 7 8 9
6 | do
7 | for seed in 0 1 2 3 4
8 | do
9 |
10 | python -m src.pl_train -c t011b.json+ia3.json+${dataset}.json \
11 | -k load_weight="t-few/pretrained_checkpoints/t011b_ia3_finish.pt" \
12 | exp_name=t011b_pretrained/${dataset}/train-${sample_size}-${train_split}/seed${seed} \
13 | train_split=${train_split} \
14 | few_shot_random_seed=${seed} \
15 | seed=${seed} \
16 | num_shot=$sample_size \
17 | batch_size=1 \
18 | eval_batch_size=2 \
19 | grad_accum_factor=8 \
20 | eval_before_training=0
21 | done
22 | done
23 | done
24 | done
25 |
--------------------------------------------------------------------------------
/scripts/tfew/run_tfew_test.sh:
--------------------------------------------------------------------------------
1 | for dataset in amazon_counterfactual_en emotion enron_spam SentEval-CR sst5
2 | do
3 | for sample_size in 8 64
4 | do
5 | for train_split in 0 1 2 3 4 5 6 7 8 9
6 | do
7 | for seed in 0 1 2 3 4
8 | do
9 | python -m src.pl_train -c t03b.json+ia3.json+${dataset}.json \
10 | -k load_weight="t-few/pretrained_checkpoints/t03b_ia3_finish.pt" \
11 | exp_name=t03b_pretrained/${dataset}/train-${sample_size}-${train_split}/seed${seed} \
12 | train_split=${train_split} \
13 | few_shot_random_seed=${seed} \
14 | seed=${seed} \
15 | num_shot=$sample_size \
16 | batch_size=8 \
17 | eval_batch_size=16 \
18 | grad_accum_factor=1 \
19 | eval_before_training=0
20 | done
21 | done
22 | done
23 | done
24 |
--------------------------------------------------------------------------------
/scripts/transformers/README.md:
--------------------------------------------------------------------------------
1 | # Transformers Baselines
2 |
3 | This folder contains the scripts used to train the 🤗 Transformers baselines quoted in the SetFit paper: [_Efficient Few-Shot Learning Without Prompts_](https://arxiv.org/abs/2209.11055).
4 |
5 | ## Setup
6 |
7 | To run the scripts, first create a Python virtual environment, e.g. with `conda`:
8 |
9 | ```
10 | conda create -n baselines-transformers python=3.9 && conda activate baselines-transformers
11 | ```
12 |
13 | Next, install the required dependencies:
14 |
15 | ```
16 | python -m pip install setfit
17 | python -m pip install -r requirements.txt
18 | ```
19 |
20 | ## Usage
21 |
22 | ### Fewshot finetuning
23 |
24 | To finetune a pretrained model on a single dataset under the SetFit organization, run:
25 |
26 | ```
27 | python run_fewshot.py train-single-dataset \
28 | --model-id=distilbert-base-uncased \
29 | --dataset-id=sst2 \
30 | --metric=accuracy \
31 | --learning-rate=2e-5 \
32 | --batch-size=4
33 | ```
34 |
35 | To finetune a pretrained model on all the test datasets used in SetFit, run:
36 |
37 | ```
38 | python run_fewshot.py train-all-datasets --model-id=distilbert-base-uncased --batch-size=4
39 | ```
40 |
41 | ### Full finetuning
42 |
43 | To finetune a pretrained model on a single dataset under the SetFit organization, run:
44 |
45 | ```
46 | python run_full.py train-single-dataset \
47 | --model-id=distilbert-base-uncased \
48 | --dataset-id=sst2 \
49 | --metric=accuracy \
50 | --learning-rate=2e-5 \
51 | --batch-size=24
52 | ```
53 |
54 | To finetune a pretrained model on all the test datasets used in SetFit, run:
55 |
56 | ```
57 | python run_full.py train-all-datasets --model-id=distilbert-base-uncased --batch-size=24
58 | ```
59 |
60 | ### Multilingual finetuning
61 |
62 | We provide three different ways to run SetFit in multilingual settings:
63 |
64 | * `each`: train on data in target language
65 | * `en`: train on English data only
66 | * `all`: train on data in all languages
67 |
68 | To finetune a baseline in one of these settings, run:
69 |
70 | ```
71 | python run_fewshot_multilingual.py train-single-dataset \
72 | --model-id=xlm-roberta-base \
73 | --dataset-id=amazon_reviews_multi_en \
74 | --metric=mae \
75 | --learning-rate=2e-5 \
76 | --batch-size=4 \
77 | --multilinguality=each
78 | ```
79 |
80 | To finetune a baseline on all the multilingual test sets in the paper, run:
81 |
82 | ```
83 | python run_fewshot_multilingual.py train-all-datasets \
84 |     --model-id=xlm-roberta-base \
85 | --learning-rate=2e-5 \
86 | --batch-size=4 \
87 | --multilinguality=each
88 | ```
89 |
90 | ### Inference benchmark
91 |
92 | To run the inference benchmark, run:
93 |
94 | ```
95 | python run_inference.py --model-id=distilbert-base-uncased__sst2__train-16-4 --dataset-id=sst2 --num-samples=100
96 | ```
97 |
--------------------------------------------------------------------------------
/scripts/transformers/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers[sentencepiece,optuna]==4.20.0
2 | torch==1.11
3 | scikit-learn
4 | typer
5 |
--------------------------------------------------------------------------------
/scripts/transformers/run_full.py:
--------------------------------------------------------------------------------
1 | import gc
2 | from pathlib import Path
3 |
4 | import torch
5 | import typer
6 | from datasets import load_dataset
7 | from evaluate import load
8 | from transformers import (
9 | AutoModelForSequenceClassification,
10 | AutoTokenizer,
11 | EarlyStoppingCallback,
12 | Trainer,
13 | TrainingArguments,
14 | )
15 |
16 | from setfit.utils import DEV_DATASET_TO_METRIC, TEST_DATASET_TO_METRIC
17 | from utils import get_label_mappings, save_metrics
18 |
19 |
20 | app = typer.Typer()
21 |
22 |
23 | RESULTS_PATH = Path("results")
24 | RESULTS_PATH.mkdir(parents=True, exist_ok=True)
25 |
26 |
27 | @app.command()
28 | def train_single_dataset(
29 | model_id: str = "distilbert-base-uncased",
30 | dataset_id: str = "sst2",
31 | metric: str = "accuracy",
32 | learning_rate: float = 2e-5,
33 | batch_size: int = 4,
34 | num_train_epochs: int = 20,
35 | push_to_hub: bool = False,
36 | ):
37 |     """Fine-tunes a pretrained checkpoint on the full training set"""
38 | # Load dataset
39 | dataset = load_dataset(f"SetFit/{dataset_id}")
40 | model_name = model_id.split("/")[-1]
41 |
42 | # Create metrics directory
43 | metrics_dir = RESULTS_PATH / Path(f"{model_name}-lr-{learning_rate}/{dataset_id}")
44 | metrics_dir.mkdir(parents=True, exist_ok=True)
45 | # Create split directory
46 | metrics_split_dir = metrics_dir / "train-full"
47 | metrics_split_dir.mkdir(parents=True, exist_ok=True)
48 | metrics_filepath = metrics_split_dir / "results.json"
49 | # Skip previously evaluated model
50 | if metrics_filepath.is_file():
51 | typer.echo("INFO -- model already trained, skipping ...")
52 | return
53 |
54 | # Load tokenizer and preprocess
55 | tokenizer = AutoTokenizer.from_pretrained(model_id)
56 |
57 | def tokenize_dataset(example):
58 | return tokenizer(example["text"], truncation=True, max_length=512)
59 |
60 | tokenized_dataset = dataset.map(tokenize_dataset, batched=True)
61 | # Create training and validation splits
62 | train_eval_dataset = tokenized_dataset["train"].train_test_split(seed=42, test_size=0.2)
63 |     # Load model - we use a `model_init()` function here to load a fresh model for each training run
64 | num_labels, label2id, id2label = get_label_mappings(dataset["train"])
65 |
66 | def model_init():
67 | return AutoModelForSequenceClassification.from_pretrained(
68 | model_id, num_labels=num_labels, id2label=id2label, label2id=label2id
69 | )
70 |
71 | # Define metrics
72 | metric_fn = load(metric)
73 |
74 | def compute_metrics(pred):
75 | labels = pred.label_ids
76 | preds = pred.predictions.argmax(-1)
77 | return metric_fn.compute(predictions=preds, references=labels)
78 |
79 | # Define hyperparameters
80 | training_args = TrainingArguments(
81 | output_dir="checkpoints/full/",
82 | overwrite_output_dir=True,
83 | num_train_epochs=num_train_epochs,
84 | learning_rate=learning_rate,
85 | per_device_train_batch_size=batch_size,
86 | per_device_eval_batch_size=batch_size,
87 | weight_decay=0.001,
88 | eval_strategy="epoch",
89 | logging_steps=100,
90 | metric_for_best_model=metric,
91 | load_best_model_at_end=True,
92 | save_strategy="epoch",
93 | save_total_limit=1,
94 | fp16=True,
95 | report_to="none",
96 | )
97 |
98 | if push_to_hub:
99 | ckpt_name = f"{model_name}-finetuned-{dataset_id}-train-full"
100 | training_args.push_to_hub = True
101 |         training_args.hub_strategy = "end"
102 | training_args.hub_model_id = f"SetFit/{ckpt_name}"
103 |
104 | callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
105 |
106 | trainer = Trainer(
107 | model_init=model_init,
108 | args=training_args,
109 | compute_metrics=compute_metrics,
110 | train_dataset=train_eval_dataset["train"],
111 | eval_dataset=train_eval_dataset["test"],
112 | tokenizer=tokenizer,
113 | callbacks=callbacks,
114 | )
115 | trainer.train()
116 |
117 | # Compute final metrics on full test set
118 | metrics = trainer.evaluate(tokenized_dataset["test"])
119 | eval_metrics = {}
120 | eval_metrics["score"] = metrics[f"eval_{metric}"] * 100.0
121 | eval_metrics["measure"] = metric
122 |
123 | # Save metrics
124 | save_metrics(eval_metrics, metrics_filepath)
125 |
126 | if push_to_hub:
127 | trainer.push_to_hub("Checkpoint upload", blocking=False)
128 |
129 | # Flush CUDA cache
130 | del trainer
131 | gc.collect()
132 | torch.cuda.empty_cache()
133 |
134 |
135 | @app.command()
136 | def train_all_datasets(
137 | model_id: str = "distilbert-base-uncased",
138 | learning_rate: float = 2e-5,
139 | batch_size: int = 4,
140 | num_train_epochs: int = 20,
141 | push_to_hub: bool = False,
142 | is_dev_set: bool = False,
143 | ):
144 | """Fine-tunes a pretrained checkpoint on all of the SetFit development/test datasets."""
145 | if is_dev_set:
146 | DATASET_TO_METRIC = DEV_DATASET_TO_METRIC
147 | else:
148 | DATASET_TO_METRIC = TEST_DATASET_TO_METRIC
149 |
150 | for dataset_id, metric in DATASET_TO_METRIC.items():
151 | typer.echo(f"🏋️🏋️🏋️ Fine-tuning on dataset {dataset_id} 🏋️🏋️🏋️")
152 | train_single_dataset(
153 | model_id=model_id,
154 | dataset_id=dataset_id,
155 | metric=metric,
156 | learning_rate=learning_rate,
157 | batch_size=batch_size,
158 | num_train_epochs=num_train_epochs,
159 | push_to_hub=push_to_hub,
160 | )
161 | typer.echo("Training complete!")
162 |
163 |
164 | if __name__ == "__main__":
165 | app()
166 |
--------------------------------------------------------------------------------
/scripts/transformers/run_inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from time import perf_counter
3 |
4 | import numpy as np
5 | import torch
6 | import typer
7 | from datasets import Dataset, load_dataset
8 | from transformers import Pipeline, pipeline
9 |
10 |
11 | RESULTS_PATH = Path("results")
12 | RESULTS_PATH.mkdir(parents=True, exist_ok=True)
13 |
14 | device_id = 0 if torch.cuda.is_available() else -1
15 |
16 |
17 | def time_pipeline(pipe: Pipeline, dataset: Dataset):
18 | latencies = []
19 | # Warm up
20 | for _ in range(10):
21 | _ = pipe("Warming up the pipeline :)")
22 | # Timed run
23 | total_start_time = perf_counter()
24 | for row in dataset:
25 | start_time = perf_counter()
26 | _ = pipe(row["text"])
27 | latency = perf_counter() - start_time
28 | latencies.append(latency)
29 | total_time_ms = (perf_counter() - total_start_time) * 1_000
30 | # Compute run statistics
31 | time_avg_ms = 1_000 * np.mean(latencies)
32 | time_std_ms = 1_000 * np.std(latencies)
33 | time_p95_ms = 1_000 * np.percentile(latencies, 95)
34 |     print(
35 |         f"P95 latency (ms) - {time_p95_ms:.2f}; Average latency (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f}; "  # noqa
36 |         f"Total time (ms) - {total_time_ms:.2f}"
37 |     )
39 |
40 |
41 | def main(
42 | model_id: str = "distilbert-base-uncased__sst2__train-16-4", dataset_id: str = "sst2", num_samples: int = None
43 | ):
44 | # Load dataset
45 | dataset = load_dataset(f"SetFit/{dataset_id}", split="test")
46 | if num_samples is not None:
47 | dataset = dataset.shuffle(seed=42).select(range(num_samples))
48 | # Load pipeline
49 | pipe = pipeline("text-classification", model=f"SetFit/{model_id}", device=device_id)
50 | # Time it!
51 | time_pipeline(pipe, dataset)
52 |
53 |
54 | if __name__ == "__main__":
55 | typer.run(main)
56 |
--------------------------------------------------------------------------------
/scripts/transformers/run_zeroshot.py:
--------------------------------------------------------------------------------
1 | # Add zeroshot pipeline script here
2 |
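3 | # A minimal sketch (an assumption, not the paper's exact baseline): evaluate the
4 | # stock 🤗 Transformers zero-shot pipeline on a SetFit test set and score it with
5 | # `evaluate`. The default model below is an illustrative choice.
6 | import typer
7 | from datasets import load_dataset
8 | from evaluate import load
9 | from transformers import pipeline
10 |
11 | from utils import get_label_mappings
12 |
13 |
14 | def main(
15 |     model_id: str = "facebook/bart-large-mnli",
16 |     dataset_id: str = "sst2",
17 |     metric: str = "accuracy",
18 |     num_samples: int = None,
19 | ):
20 |     # Load the test split and optionally subsample it for a quick run
21 |     dataset = load_dataset(f"SetFit/{dataset_id}", split="test")
22 |     if num_samples is not None:
23 |         dataset = dataset.shuffle(seed=42).select(range(num_samples))
24 |     _, label2id, _ = get_label_mappings(dataset)
25 |     pipe = pipeline("zero-shot-classification", model=model_id)
26 |     # The pipeline returns candidate labels sorted by score; keep the top one
27 |     predictions = [
28 |         label2id[pipe(text, candidate_labels=list(label2id))["labels"][0]]
29 |         for text in dataset["text"]
30 |     ]
31 |     score = load(metric).compute(predictions=predictions, references=dataset["label"])
32 |     typer.echo(score)
33 |
34 |
35 | if __name__ == "__main__":
36 |     typer.run(main)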
--------------------------------------------------------------------------------
/scripts/transformers/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Tuple
3 |
4 | from datasets import Dataset
5 |
6 |
7 | def get_label_mappings(dataset: Dataset) -> Tuple[int, dict, dict]:
8 | """Returns the label mappings of the dataset."""
9 | label_ids = dataset.unique("label")
10 | label_names = dataset.unique("label_text")
11 | label2id = {label: idx for label, idx in zip(label_names, label_ids)}
12 | id2label = {idx: label for label, idx in label2id.items()}
13 | num_labels = len(label_ids)
14 | return num_labels, label2id, id2label
15 |
16 |
17 | def save_metrics(metrics: dict, metrics_filepath):
18 | with open(metrics_filepath, "w") as f:
19 | json.dump(metrics, f)
20 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | multi_line_output = 3
3 | include_trailing_comma = True
4 | force_grid_wrap = 0
5 | use_parentheses = True
6 | ensure_newline_before_comments = True
7 | line_length = 119
8 | lines_after_imports = 2
9 |
10 | [flake8]
11 | ignore = E203, E501, W503
12 | max-line-length = 119
13 | per-file-ignores =
14 | # imported but unused
15 | __init__.py: F401
16 | exclude =
17 | results
18 | scripts/adapet
19 | scripts/tfew
20 |
21 | [tool:pytest]
22 | testpaths = tests
23 | addopts = --cov=setfit --durations=10
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | from pathlib import Path
3 |
4 | from setuptools import find_packages, setup
5 |
6 |
7 | README_TEXT = (Path(__file__).parent / "README.md").read_text(encoding="utf-8")
8 |
9 | MAINTAINER = "Lewis Tunstall, Tom Aarsen"
10 | MAINTAINER_EMAIL = "lewis@huggingface.co"
11 |
12 | INTEGRATIONS_REQUIRE = ["optuna"]
13 | REQUIRED_PKGS = [
14 | "datasets>=2.15.0",
15 | "sentence-transformers[train]>=3",
16 | "transformers>=4.41.0",
17 | "evaluate>=0.3.0",
18 | "huggingface_hub>=0.24.0",
19 | "scikit-learn",
20 | "packaging",
21 | ]
22 | ABSA_REQUIRE = ["spacy<3.7.6"]
23 | QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"]
24 | ONNX_REQUIRE = ["onnxruntime", "onnx!=1.16.2", "skl2onnx"]
25 | OPENVINO_REQUIRE = ["hummingbird-ml", "openvino"]
26 | TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE + ABSA_REQUIRE
27 | DOCS_REQUIRE = ["hf-doc-builder>=0.3.0"]
28 | CODECARBON_REQUIRE = ["codecarbon<2.6.0"]
29 | # 2.7.* fails with AttributeError: 'EmissionsTracker' object has no attribute '_cloud'
30 | # 2.6.* has an accidental print statement spamming the terminal
31 | EXTRAS_REQUIRE = {
32 | "optuna": INTEGRATIONS_REQUIRE,
33 | "quality": QUALITY_REQUIRE,
34 | "tests": TESTS_REQUIRE,
35 | "onnx": ONNX_REQUIRE,
36 | "openvino": ONNX_REQUIRE + OPENVINO_REQUIRE,
37 | "docs": DOCS_REQUIRE,
38 | "absa": ABSA_REQUIRE,
39 | "codecarbon": CODECARBON_REQUIRE,
40 | }
41 |
42 |
43 | def combine_requirements(base_keys):
44 | return list(set(k for v in base_keys for k in EXTRAS_REQUIRE[v]))
45 |
46 |
47 | EXTRAS_REQUIRE["dev"] = combine_requirements([k for k in EXTRAS_REQUIRE])
48 | # For the compatibility tests we add pandas<2, as pandas 2.0.0 onwards is incompatible with old datasets versions,
49 | # and we assume few to no users would use old datasets versions with new pandas versions.
50 | # The only alternative is incrementing the minimum version for datasets, which seems unnecessary.
51 | # Beyond that, fsspec is set to <2023.12.0 as that version is incompatible with datasets<=2.15.0
52 | EXTRAS_REQUIRE["compat_tests"] = (
53 | [requirement.replace(">=", "==") for requirement in REQUIRED_PKGS]
54 | + TESTS_REQUIRE
55 | + ["pandas<2", "fsspec<2023.12.0"]
56 | )
57 |
58 | setup(
59 | name="setfit",
60 | version="1.2.0.dev0",
61 | description="Efficient few-shot learning with Sentence Transformers",
62 | long_description=README_TEXT,
63 | long_description_content_type="text/markdown",
64 | maintainer=MAINTAINER,
65 | maintainer_email=MAINTAINER_EMAIL,
66 | url="https://github.com/huggingface/setfit",
67 | download_url="https://github.com/huggingface/setfit/tags",
68 | license="Apache 2.0",
69 | package_dir={"": "src"},
70 | packages=find_packages("src"),
71 | include_package_data=True,
72 | install_requires=REQUIRED_PKGS,
73 | extras_require=EXTRAS_REQUIRE,
74 | classifiers=[
75 | "Development Status :: 5 - Production/Stable",
76 | "Intended Audience :: Developers",
77 | "Intended Audience :: Education",
78 | "Intended Audience :: Science/Research",
79 | "License :: OSI Approved :: Apache Software License",
80 | "Operating System :: OS Independent",
81 | "Programming Language :: Python :: 3",
82 | "Programming Language :: Python :: 3.9",
83 | "Programming Language :: Python :: 3.10",
84 | "Programming Language :: Python :: 3.11",
85 | "Programming Language :: Python :: 3.12",
86 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
87 | ],
88 | keywords="nlp, machine learning, fewshot learning, transformers",
89 | zip_safe=False, # Required for mypy to find the py.typed file
90 | )
91 |
--------------------------------------------------------------------------------
/src/setfit/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.2.0.dev0"
2 |
3 | import importlib
4 | import os
5 | import warnings
6 |
7 | from .data import get_templated_dataset, sample_dataset
8 | from .model_card import SetFitModelCardData
9 | from .modeling import SetFitHead, SetFitModel
10 | from .span import AbsaModel, AbsaTrainer, AspectExtractor, AspectModel, PolarityModel
11 | from .trainer import SetFitTrainer, Trainer
12 | from .trainer_distillation import DistillationSetFitTrainer, DistillationTrainer
13 | from .training_args import TrainingArguments
14 |
15 |
16 | # Ensure that DeprecationWarnings are shown by default, as recommended by
17 | # https://docs.python.org/3/library/warnings.html#overriding-the-default-filter
18 | warnings.filterwarnings("default", category=DeprecationWarning)
19 |
20 | # If codecarbon is installed and the log level is not defined,
21 | # automatically overwrite the default to "error"
22 | if importlib.util.find_spec("codecarbon") and "CODECARBON_LOG_LEVEL" not in os.environ:
23 | os.environ["CODECARBON_LOG_LEVEL"] = "error"
24 |
--------------------------------------------------------------------------------
/src/setfit/exporters/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/src/setfit/exporters/__init__.py
--------------------------------------------------------------------------------
/src/setfit/exporters/openvino.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import openvino.runtime as ov
4 |
5 | from setfit import SetFitModel
6 | from setfit.exporters.onnx import export_onnx
7 |
8 |
9 | def export_to_openvino(
10 | model: SetFitModel,
11 | output_path: str = "model.xml",
12 | ) -> None:
13 | """Export a PyTorch backed SetFit model to OpenVINO Intermediate Representation.
14 |
15 |     Args:
16 |         model (`SetFitModel`): The SetFit model to export. Its `model_body` should be a
17 |             SentenceTransformer and its `model_head` either a dense `SetFitHead` or a
18 |             scikit-learn estimator; both are exported together.
19 |         output_path (`str`): The path where the generated OpenVINO model will be stored.
20 |             At a minimum it needs to contain the name of the final file.
21 |     """
27 |
28 | # Load the model and get all of the parts.
29 | OPENVINO_SUPPORTED_OPSET = 13
30 |
31 | model.model_body.cpu()
32 | onnx_path = output_path.replace(".xml", ".onnx")
33 |
34 | export_onnx(
35 | model.model_body,
36 | model.model_head,
37 | opset=OPENVINO_SUPPORTED_OPSET,
38 | output_path=onnx_path,
39 | ignore_ir_version=True,
40 | use_hummingbird=True,
41 | )
42 |
43 | # Save the final model.
44 | ov_model = ov.Core().read_model(onnx_path)
45 | ov.serialize(ov_model, output_path)
46 |
47 | os.remove(onnx_path)
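48 |
49 | # Illustrative usage (the checkpoint id is a placeholder, not a real model):
50 | #   model = SetFitModel.from_pretrained("a-trained-setfit-checkpoint")
51 | #   export_to_openvino(model, output_path="model.xml")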
48 |
--------------------------------------------------------------------------------
/src/setfit/exporters/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
5 | """Perform attention-aware mean pooling.
6 |
7 | This method takes in embeddings of shape (batch, sequence, embedding_size) and performs average
8 | pooling across the sequence dimension to yield embeddings of size (batch, embedding_size).
9 |
10 | From:
11 | https://github.com/UKPLab/sentence-transformers/blob/0b5ef4be93d2b21de3a918a084b48aab6ba48595/sentence_transformers/model_card_templates.py#L134 # noqa: E501
12 |
13 | Args:
14 | token_embeddings (`torch.Tensor`): The embeddings we wish to pool over of shape
15 | (batch, sequence, embedding_size). This will pool over the sequence to yield
16 | (batch, embedding_size).
17 |         attention_mask (`torch.Tensor`): The binary attention mask of shape (batch, sequence).
18 |
19 | Returns:
20 | (`torch.Tensor`) The mean pooled embeddings of size (batch, embedding_size).
21 | """
22 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
23 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
24 |
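25 | # Worked example (illustrative): with one unmasked and one masked token,
26 | #   emb = torch.tensor([[[1.0, 1.0], [3.0, 3.0]]])  # (batch=1, sequence=2, dim=2)
27 | #   mask = torch.tensor([[1, 0]])                   # the second token is padding
28 | #   mean_pooling(emb, mask)                         # -> tensor([[1.0, 1.0]])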
--------------------------------------------------------------------------------
/src/setfit/integrations.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | from typing import TYPE_CHECKING
3 |
4 | from .utils import BestRun
5 |
6 |
7 | if TYPE_CHECKING:
8 | from .trainer import Trainer
9 |
10 |
11 | def is_optuna_available() -> bool:
12 | return importlib.util.find_spec("optuna") is not None
13 |
14 |
15 | def default_hp_search_backend():
16 | if is_optuna_available():
17 | return "optuna"
18 |
19 |
20 | def run_hp_search_optuna(trainer: "Trainer", n_trials: int, direction: str, **kwargs) -> BestRun:
21 | import optuna
22 |
23 | # Heavily inspired by transformers.integrations.run_hp_search_optuna
24 | # https://github.com/huggingface/transformers/blob/cbb8a37929c3860210f95c9ec99b8b84b8cf57a1/src/transformers/integrations.py#L160
25 | def _objective(trial):
26 | trainer.objective = None
27 | trainer.train(trial=trial)
28 | # If there hasn't been any evaluation during the training loop.
29 | if getattr(trainer, "objective", None) is None:
30 | metrics = trainer.evaluate()
31 | trainer.objective = trainer.compute_objective(metrics)
32 | return trainer.objective
33 |
34 | timeout = kwargs.pop("timeout", None)
35 | n_jobs = kwargs.pop("n_jobs", 1)
36 | study = optuna.create_study(direction=direction, **kwargs)
37 | study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
38 | best_trial = study.best_trial
39 | return BestRun(str(best_trial.number), best_trial.value, best_trial.params, study)
40 |
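41 | # Illustrative note: this function is not called directly; `Trainer.hyperparameter_search`
42 | # dispatches to it when optuna is installed (see `default_hp_search_backend` above).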
--------------------------------------------------------------------------------
/src/setfit/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class SupConLoss(nn.Module):
6 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
7 |
8 | It also supports the unsupervised contrastive loss in SimCLR.
9 | """
10 |
11 | def __init__(self, model, temperature=0.07, contrast_mode="all", base_temperature=0.07):
12 | super(SupConLoss, self).__init__()
13 | self.model = model
14 | self.temperature = temperature
15 | self.contrast_mode = contrast_mode
16 | self.base_temperature = base_temperature
17 |
18 | def forward(self, sentence_features, labels=None, mask=None):
19 | """Computes loss for model.
20 |
21 | If both `labels` and `mask` are None, it degenerates to SimCLR unsupervised loss:
22 | https://arxiv.org/pdf/2002.05709.pdf
23 |
24 | Args:
25 |             sentence_features: model inputs; the first element is encoded into a hidden vector of shape [bsz, n_views, ...].
26 | labels: ground truth of shape [bsz].
27 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
28 | has the same class as sample i. Can be asymmetric.
29 |
30 | Returns:
31 | A loss scalar.
32 | """
33 | features = self.model(sentence_features[0])["sentence_embedding"]
34 |
35 | # Normalize embeddings
36 | features = torch.nn.functional.normalize(features, p=2, dim=1)
37 |
38 | # Add n_views dimension
39 | features = torch.unsqueeze(features, 1)
40 |
41 | device = features.device
42 |
43 | if len(features.shape) < 3:
44 |             raise ValueError("`features` needs to be [bsz, n_views, ...], at least 3 dimensions are required")
45 | if len(features.shape) > 3:
46 | features = features.view(features.shape[0], features.shape[1], -1)
47 |
48 | batch_size = features.shape[0]
49 | if labels is not None and mask is not None:
50 | raise ValueError("Cannot define both `labels` and `mask`")
51 | elif labels is None and mask is None:
52 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
53 | elif labels is not None:
54 | labels = labels.contiguous().view(-1, 1)
55 | if labels.shape[0] != batch_size:
56 | raise ValueError("Num of labels does not match num of features")
57 | mask = torch.eq(labels, labels.T).float().to(device)
58 | else:
59 | mask = mask.float().to(device)
60 |
61 | contrast_count = features.shape[1]
62 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
63 | if self.contrast_mode == "one":
64 | anchor_feature = features[:, 0]
65 | anchor_count = 1
66 | elif self.contrast_mode == "all":
67 | anchor_feature = contrast_feature
68 | anchor_count = contrast_count
69 | else:
70 | raise ValueError("Unknown mode: {}".format(self.contrast_mode))
71 |
72 | # Compute logits
73 | anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
74 | # For numerical stability
75 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
76 | logits = anchor_dot_contrast - logits_max.detach()
77 |
78 | # Tile mask
79 | mask = mask.repeat(anchor_count, contrast_count)
80 | # Mask-out self-contrast cases
81 | logits_mask = torch.scatter(
82 | torch.ones_like(mask),
83 | 1,
84 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
85 | 0,
86 | )
87 | mask = mask * logits_mask
88 |
89 | # Compute log_prob
90 | exp_logits = torch.exp(logits) * logits_mask
91 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
92 |
93 | # Compute mean of log-likelihood over positive
94 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
95 |
96 | # Loss
97 | loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
98 | loss = loss.view(anchor_count, batch_size).mean()
99 |
100 | return loss
101 |
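A minimal usage sketch, not part of this file (the model id and example data are placeholders): SupConLoss follows the sentence-transformers loss interface, taking the embedding model at construction and `(sentence_features, labels)` per batch, so it can stand in for e.g. CosineSimilarityLoss in a classic `fit()` loop.

```python
from sentence_transformers import InputExample, SentenceTransformer
from torch.utils.data import DataLoader

from setfit.losses import SupConLoss

model = SentenceTransformer("sentence-transformers/paraphrase-albert-small-v2")
# Every anchor needs at least one same-label partner in the batch, otherwise
# the mean over positives divides by zero.
train_examples = [
    InputExample(texts=["the movie was great"], label=1),
    InputExample(texts=["loved every minute"], label=1),
    InputExample(texts=["a dreadful film"], label=0),
    InputExample(texts=["i want my money back"], label=0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=4)
model.fit(train_objectives=[(train_dataloader, SupConLoss(model=model))], epochs=1)
```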
--------------------------------------------------------------------------------
/src/setfit/notebook.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from transformers.utils.notebook import NotebookProgressCallback
4 |
5 |
6 | class SetFitNotebookProgressCallback(NotebookProgressCallback):
7 | """
8 | A variation of NotebookProgressCallback that accepts logs/metrics other than "loss" and "eval_loss".
9 | In particular, it accepts "embedding_loss", "aspect_embedding_loss", and "polarity_embedding_loss"
10 | and the corresponding metrics for the validation set.
11 | """
12 |
13 | def on_log(self, *args, logs=None, **kwargs):
14 | if logs is not None:
15 | logs = {key if key != "embedding_loss" else "loss": value for key, value in logs.items()}
16 | return super().on_log(*args, logs=logs, **kwargs)
17 |
18 | def on_evaluate(self, args, state, control, metrics=None, **kwargs):
19 | if self.training_tracker is not None:
20 | values = {"Training Loss": "No log", "Validation Loss": "No log"}
21 | for log in reversed(state.log_history):
22 | if loss_logs := {
23 | key for key in log if key in ("embedding_loss", "aspect_embedding_loss", "polarity_embedding_loss")
24 | }:
25 | values["Training Loss"] = log[loss_logs.pop()]
26 | break
27 |
28 | if self.first_column == "Epoch":
29 | values["Epoch"] = int(state.epoch)
30 | else:
31 | values["Step"] = state.global_step
32 | metric_key_prefix = "eval"
33 | for k in metrics:
34 | if k.endswith("_loss"):
35 | metric_key_prefix = re.sub(r"\_loss$", "", k)
36 | _ = metrics.pop("total_flos", None)
37 | _ = metrics.pop("epoch", None)
38 | _ = metrics.pop(f"{metric_key_prefix}_runtime", None)
39 | _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
40 | _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
41 | _ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None)
42 | for k, v in metrics.items():
43 | splits = k.split("_")
44 | name = " ".join([part.capitalize() for part in splits[1:]])
45 | if name in ("Embedding Loss", "Aspect Embedding Loss", "Polarity Embedding Loss"):
46 | # Single evaluation dataset: report its embedding loss as the validation loss
47 | name = "Validation Loss"
48 | values[name] = v
49 | self.training_tracker.write_line(values)
50 | self.training_tracker.remove_child()
51 | self.prediction_bar = None
52 | # Evaluation takes a long time so we should force the next update.
53 | self._force_next_update = True
54 |
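An illustrative registration sketch (placeholder model and data; this assumes the `callbacks` argument of `Trainer`, which forwards callbacks to its handler): in a Jupyter session this callback replaces the stock progress table so that `embedding_loss`-style logs are displayed.

```python
from datasets import Dataset

from setfit import SetFitModel, Trainer
from setfit.notebook import SetFitNotebookProgressCallback

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
train_dataset = Dataset.from_dict({"text": ["good", "bad"], "label": [1, 0]})

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    callbacks=[SetFitNotebookProgressCallback()],
)
trainer.train()
```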
--------------------------------------------------------------------------------
/src/setfit/span/__init__.py:
--------------------------------------------------------------------------------
1 | from .aspect_extractor import AspectExtractor
2 | from .modeling import AbsaModel, AspectModel, PolarityModel
3 | from .trainer import AbsaTrainer
4 |
--------------------------------------------------------------------------------
/src/setfit/span/aspect_extractor.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, List, Tuple
2 |
3 |
4 | if TYPE_CHECKING:
5 | from spacy.tokens import Doc
6 |
7 |
8 | class AspectExtractor:
9 | def __init__(self, spacy_model: str) -> None:
10 | super().__init__()
11 | import spacy
12 |
13 | self.nlp = spacy.load(spacy_model)
14 |
15 | def find_groups(self, aspect_mask: List[bool]):
16 | start = None
17 | for idx, flag in enumerate(aspect_mask):
18 | if flag:
19 | if start is None:
20 | start = idx
21 | else:
22 | if start is not None:
23 | yield slice(start, idx)
24 | start = None
25 | if start is not None:
26 | yield slice(start, idx + 1)
27 |
28 | def __call__(self, texts: List[str]) -> Tuple[List["Doc"], List[slice]]:
29 | aspects_list = []
30 | docs = list(self.nlp.pipe(texts))
31 | for doc in docs:
32 | aspect_mask = [token.pos_ in ("NOUN", "PROPN") for token in doc]
33 | aspects_list.append(list(self.find_groups(aspect_mask)))
34 | return docs, aspects_list
35 |
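A quick sketch of the `find_groups` contract: contiguous `True` runs in the token-level mask come back as `slice` objects over the token sequence. The method reads no instance state, so it can be exercised unbound without loading a spaCy model.

```python
from setfit.span.aspect_extractor import AspectExtractor

mask = [False, True, True, False, True]
# Two runs of noun-like tokens: positions 1-2 and position 4.
assert list(AspectExtractor.find_groups(None, mask)) == [slice(1, 3), slice(4, 5)]
```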
--------------------------------------------------------------------------------
/src/setfit/utils.py:
--------------------------------------------------------------------------------
1 | import types
2 | from contextlib import contextmanager
3 | from dataclasses import dataclass, field
4 | from time import monotonic_ns
5 | from typing import Any, Dict, List, NamedTuple, Optional, Tuple
6 |
7 | from datasets import Dataset, DatasetDict, load_dataset
8 | from sentence_transformers import losses
9 | from transformers.utils import copy_func
10 |
11 | from .data import create_fewshot_splits, create_fewshot_splits_multilabel
12 | from .losses import SupConLoss
13 |
14 |
15 | SEC_TO_NS_SCALE = 1_000_000_000
16 |
17 |
18 | DEV_DATASET_TO_METRIC = {
19 | "sst2": "accuracy",
20 | "imdb": "accuracy",
21 | "subj": "accuracy",
22 | "bbc-news": "accuracy",
23 | "enron_spam": "accuracy",
24 | "student-question-categories": "accuracy",
25 | "TREC-QC": "accuracy",
26 | "toxic_conversations": "matthews_correlation",
27 | }
28 |
29 | TEST_DATASET_TO_METRIC = {
30 | "emotion": "accuracy",
31 | "SentEval-CR": "accuracy",
32 | "sst5": "accuracy",
33 | "ag_news": "accuracy",
34 | "enron_spam": "accuracy",
35 | "amazon_counterfactual_en": "matthews_correlation",
36 | }
37 |
38 | MULTILINGUAL_DATASET_TO_METRIC = {
39 | f"amazon_reviews_multi_{lang}": "mae" for lang in ["en", "de", "es", "fr", "ja", "zh"]
40 | }
41 |
42 | LOSS_NAME_TO_CLASS = {
43 | "CosineSimilarityLoss": losses.CosineSimilarityLoss,
44 | "ContrastiveLoss": losses.ContrastiveLoss,
45 | "OnlineContrastiveLoss": losses.OnlineContrastiveLoss,
46 | "BatchSemiHardTripletLoss": losses.BatchSemiHardTripletLoss,
47 | "BatchAllTripletLoss": losses.BatchAllTripletLoss,
48 | "BatchHardTripletLoss": losses.BatchHardTripletLoss,
49 | "BatchHardSoftMarginTripletLoss": losses.BatchHardSoftMarginTripletLoss,
50 | "SupConLoss": SupConLoss,
51 | }
52 |
53 |
54 | def default_hp_space_optuna(trial) -> Dict[str, Any]:
55 | from transformers.integrations import is_optuna_available
56 |
57 | assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
58 | return {
59 | "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
60 | "num_epochs": trial.suggest_int("num_epochs", 1, 5),
61 | "num_iterations": trial.suggest_categorical("num_iterations", [5, 10, 20]),
62 | "seed": trial.suggest_int("seed", 1, 40),
63 | "batch_size": trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]),
64 | }
65 |
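A hedged end-to-end sketch of how this default space is consumed (the model id and toy dataset are placeholders): `Trainer.hyperparameter_search` falls back to it when no `hp_space` is given, runs an Optuna study via an `_objective` loop like the one shown earlier, and returns the `BestRun` defined further below.

```python
from datasets import Dataset

from setfit import SetFitModel, Trainer
from setfit.utils import default_hp_space_optuna

def model_init(params):
    # `params` holds the sampled model hyperparameters for this trial (unused here)
    return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")

dataset = Dataset.from_dict({"text": ["good", "great", "bad", "awful"], "label": [1, 1, 0, 0]})
trainer = Trainer(model_init=model_init, train_dataset=dataset, eval_dataset=dataset)

best_run = trainer.hyperparameter_search(hp_space=default_hp_space_optuna, n_trials=2)
trainer.apply_hyperparameters(best_run.hyperparameters, final_model=True)
trainer.train()
```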
66 |
67 | def load_data_splits(
68 | dataset: str, sample_sizes: List[int], add_data_augmentation: bool = False
69 | ) -> Tuple[DatasetDict, Dataset]:
70 | """Loads a dataset from the Hugging Face Hub and returns the test split and few-shot training splits."""
71 | print(f"\n\n\n============== {dataset} ============")
72 | # Load one of the SetFit training sets from the Hugging Face Hub
73 | train_split = load_dataset(f"SetFit/{dataset}", split="train")
74 | train_splits = create_fewshot_splits(train_split, sample_sizes, add_data_augmentation, f"SetFit/{dataset}")
75 | test_split = load_dataset(f"SetFit/{dataset}", split="test")
76 | print(f"Test set: {len(test_split)}")
77 | return train_splits, test_split
78 |
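An example call, as a sketch (the dataset name follows the `SetFit/<name>` Hub convention used above):

```python
from setfit.utils import load_data_splits

train_splits, test_split = load_data_splits("sst2", sample_sizes=[8, 64])
# train_splits is a DatasetDict with one few-shot split per (sample size, seed) pair
print(len(train_splits), len(test_split))
```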
79 |
80 | def load_data_splits_multilabel(dataset: str, sample_sizes: List[int]) -> Tuple[DatasetDict, Dataset]:
81 | """Loads a dataset from the Hugging Face Hub and returns the test split and few-shot training splits."""
82 | print(f"\n\n\n============== {dataset} ============")
83 | # Load one of the SetFit training sets from the Hugging Face Hub
84 | train_split = load_dataset(f"SetFit/{dataset}", "multilabel", split="train")
85 | train_splits = create_fewshot_splits_multilabel(train_split, sample_sizes)
86 | test_split = load_dataset(f"SetFit/{dataset}", "multilabel", split="test")
87 | print(f"Test set: {len(test_split)}")
88 | return train_splits, test_split
89 |
90 |
91 | @dataclass
92 | class Benchmark:
93 | """
94 | Performs simple benchmarks of code portions (measures elapsed time).
95 |
96 | Typical usage example:
97 |
98 | bench = Benchmark()
99 | with bench.track("Foo function"):
100 | foo()
101 | with bench.track("Bar function"):
102 | bar()
103 | bench.summary()
104 | """
105 |
106 | out_path: Optional[str] = None
107 | summary_msg: str = field(default_factory=str)
108 |
109 | def print(self, msg: str) -> None:
110 | """
111 | Prints to system out and optionally to specified out_path.
112 | """
113 | print(msg)
114 |
115 | if self.out_path is not None:
116 | with open(self.out_path, "a+") as f:
117 | f.write(msg + "\n")
118 |
119 | @contextmanager
120 | def track(self, step):
121 | """
122 | Computes the elapsed time for given code context.
123 | """
124 | start = monotonic_ns()
125 | yield
126 | ns = monotonic_ns() - start
127 | msg = f"\n{'*' * 70}\n'{step}' took {ns / SEC_TO_NS_SCALE:.3f}s ({ns:,}ns)\n{'*' * 70}\n"
128 | print(msg)
129 | self.summary_msg += msg + "\n"
130 |
131 | def summary(self) -> None:
132 | """
133 | Prints summary of all benchmarks performed.
134 | """
135 | self.print(f"\n{'#' * 30}\nBenchmark Summary:\n{'#' * 30}\n\n{self.summary_msg}")
136 |
137 |
138 | class BestRun(NamedTuple):
139 | """
140 | The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).
141 |
142 | Parameters:
143 | run_id (`str`):
144 | The id of the best run.
145 | objective (`float`):
146 | The objective that was obtained for this run.
147 | hyperparameters (`Dict[str, Any]`):
148 | The hyperparameters picked to get this run.
149 | backend (`Any`):
150 | The relevant internal object used for optimization. For optuna this is the `study` object.
151 | """
152 |
153 | run_id: str
154 | objective: float
155 | hyperparameters: Dict[str, Any]
156 | backend: Any = None
157 |
158 |
159 | def set_docstring(method, docstring, cls=None):
160 | copied_function = copy_func(method)
161 | copied_function.__doc__ = docstring
162 | return types.MethodType(copied_function, cls or method.__self__)
163 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from datasets import Dataset
3 |
4 | from setfit import AbsaModel, SetFitModel
5 |
6 |
7 | @pytest.fixture()
8 | def model() -> SetFitModel:
9 | return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
10 |
11 |
12 | @pytest.fixture()
13 | def absa_model() -> AbsaModel:
14 | return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm")
15 |
16 |
17 | @pytest.fixture()
18 | def trained_absa_model() -> AbsaModel:
19 | return AbsaModel.from_pretrained(
20 | "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
21 | "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
22 | )
23 |
24 |
25 | @pytest.fixture()
26 | def absa_dataset() -> Dataset:
27 | texts = [
28 | "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.",
29 | "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.",
30 | "Food is great and inexpensive.",
31 | "Good bagels and good cream cheese.",
32 | "Good bagels and good cream cheese.",
33 | ]
34 | spans = ["food", "ambiance", "Food", "bagels", "cream cheese"]
35 | labels = ["negative", "negative", "positive", "positive", "positive"]
36 | ordinals = [0, 0, 0, 0, 0]
37 | return Dataset.from_dict({"text": texts, "span": spans, "label": labels, "ordinal": ordinals})
38 |
--------------------------------------------------------------------------------
/tests/exporters/test_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import onnxruntime
5 | import pytest
6 | from transformers import AutoTokenizer
7 |
8 | from setfit import SetFitModel
9 | from setfit.data import get_templated_dataset
10 | from setfit.exporters.onnx import export_onnx
11 | from setfit.trainer import Trainer
12 | from setfit.training_args import TrainingArguments
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "model_path, input_text",
17 | [
18 | ("lewtun/my-awesome-setfit-model", ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]),
19 | (
20 | "lewtun/setfit-ethos-multilabel-example",
21 | ["I'm a really hateful guy!", "I hate this one person in particular!"],
22 | ),
23 | ],
24 | )
25 | def test_export_onnx_sklearn_head(model_path, input_text):
26 | """Test that the exported `ONNX` model returns the same predictions as the original model."""
27 | model = SetFitModel.from_pretrained(model_path)
28 |
29 | # Export the sklearn based model
30 | output_path = "model.onnx"
31 | try:
32 | export_onnx(model.model_body, model.model_head, opset=12, output_path=output_path)
33 |
34 | # Check that the model was saved.
35 | assert output_path in os.listdir(), "Model not saved to output_path"
36 |
37 | # Run inference using the original model.
38 | pytorch_preds = model(input_text)
39 |
40 | # Run inference using the exported onnx model.
41 | tokenizer = AutoTokenizer.from_pretrained(model_path)
42 | inputs = tokenizer(
43 | input_text,
44 | padding=True,
45 | truncation=True,
46 | return_attention_mask=True,
47 | return_token_type_ids=True,
48 | return_tensors="np",
49 | )
50 | # Map inputs to int64 from int32
51 | inputs = {key: value.astype("int64") for key, value in inputs.items()}
52 |
53 | session = onnxruntime.InferenceSession(output_path)
54 |
55 | onnx_preds = session.run(None, dict(inputs))[0]
56 |
57 | # Compare the results and ensure that we get the same predictions.
58 | assert np.array_equal(onnx_preds, pytorch_preds)
59 |
60 | finally:
61 | # Cleanup the model.
62 | os.remove(output_path)
63 |
64 |
65 | @pytest.mark.skip("ONNX exporting of SetFit model with Torch head not yet supported.")
66 | @pytest.mark.parametrize("out_features", [1, 2, 3])
67 | def test_export_onnx_torch_head(out_features):
68 | """Test that the exported `ONNX` model returns the same predictions as the original model."""
69 | dataset = get_templated_dataset(reference_dataset="SetFit/SentEval-CR")
70 | model_path = "sentence-transformers/paraphrase-albert-small-v2"
71 | model = SetFitModel.from_pretrained(
72 | model_path, use_differentiable_head=True, head_params={"out_features": out_features}
73 | )
74 |
75 | args = TrainingArguments(
76 | num_iterations=15,
77 | num_epochs=(1, 15),
78 | batch_size=16,
79 | body_learning_rate=(2e-5, 1e-5),
80 | head_learning_rate=1e-2,
81 | l2_weight=0.0,
82 | end_to_end=True,
83 | )
84 | trainer = Trainer(
85 | model=model,
86 | args=args,
87 | train_dataset=dataset,
88 | eval_dataset=dataset,
89 | column_mapping={"text": "text", "label": "label"},
90 | )
91 | trainer.train()
92 |
93 | # Export the torch-based model
94 | output_path = "model.onnx"
95 | try:
96 | export_onnx(model.model_body, model.model_head, opset=12, output_path=output_path)
97 |
98 | # Check that the model was saved.
99 | assert output_path in os.listdir(), "Model not saved to output_path"
100 |
101 | # Run inference using the original model.
102 | input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]
103 | pytorch_preds = model(input_text)
104 |
105 | # Run inference using the exported onnx model.
106 | tokenizer = AutoTokenizer.from_pretrained(model_path)
107 | inputs = tokenizer(
108 | input_text,
109 | padding=True,
110 | truncation=True,
111 | return_attention_mask=True,
112 | return_token_type_ids=True,
113 | return_tensors="np",
114 | )
115 | # Map inputs to int64 from int32
116 | inputs = {key: value.astype("int64") for key, value in inputs.items()}
117 |
118 | session = onnxruntime.InferenceSession(output_path)
119 |
120 | onnx_preds = session.run(None, dict(inputs))[0]
121 |
122 | # Compare the results and ensure that we get the same predictions.
123 | assert np.array_equal(onnx_preds, pytorch_preds)
124 |
125 | finally:
126 | # Cleanup the model.
127 | os.remove(output_path)
128 |
--------------------------------------------------------------------------------
/tests/exporters/test_openvino.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import openvino.runtime as ov
5 | import pytest
6 | from transformers import AutoTokenizer
7 |
8 | from setfit import SetFitModel
9 | from setfit.exporters.openvino import export_to_openvino
10 |
11 |
12 | @pytest.mark.skip(
13 | reason="OpenVINO exporting broke since openvino==2022.3.0, while this version is not supported for Python 3.11 onwards. "
14 | "To allow us to add Python 3.11+ support, we are skipping this test until OpenVINO support is fixed."
15 | )
16 | def test_export_to_openvino():
17 | """Test that the exported `OpenVINO` model returns the same predictions as the original model."""
18 | model_path = "lewtun/my-awesome-setfit-model"
19 | model = SetFitModel.from_pretrained(model_path)
20 |
21 | # Export the sklearn based model
22 | output_path = "model.xml"
23 | export_to_openvino(model, output_path=output_path)
24 |
25 | # Check that the model was saved.
26 | assert output_path in os.listdir(), "Model not saved to output_path"
27 |
28 | # Run inference using the original model.
29 | input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]
30 | pytorch_preds = model(input_text)
31 |
32 | # Run inference using the exported OpenVINO model.
33 | tokenizer = AutoTokenizer.from_pretrained(model_path)
34 | inputs = tokenizer(
35 | input_text,
36 | padding=True,
37 | truncation=True,
38 | return_attention_mask=True,
39 | return_token_type_ids=True,
40 | return_tensors="np",
41 | )
42 |
43 | inputs_dict = dict(inputs)
44 |
45 | core = ov.Core()
46 | ov_model = core.read_model(output_path)
47 | compiled_model = core.compile_model(ov_model, "CPU")
48 |
49 | ov_preds = compiled_model(inputs_dict)[compiled_model.outputs[0]]
50 |
51 | # Compare the results and ensure that we get the same predictions.
52 | assert np.array_equal(ov_preds, pytorch_preds)
53 |
54 | # Cleanup the model.
55 | os.remove(output_path)
56 | os.remove(output_path.replace(".xml", ".bin"))
57 |
--------------------------------------------------------------------------------
/tests/model_card_pattern.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 |
3 | import re
4 |
5 |
6 | MODEL_CARD_PATTERN = re.compile(
7 | """\
8 | ---
9 | .*
10 | ---
11 |
12 | \# SetFit with sentence\-transformers/paraphrase\-albert\-small\-v2 on SST2
13 |
14 | This is a \[SetFit\]\(https://github\.com/huggingface/setfit\) model trained on the \[SST2\]\(https://huggingface\.co/datasets/sst2\) dataset that can be used for Text Classification\. This SetFit model uses \[sentence\-transformers/paraphrase\-albert\-small\-v2\]\(https://huggingface\.co/sentence\-transformers/paraphrase\-albert\-small\-v2\) as the Sentence Transformer embedding model\. A \[LogisticRegression\]\(https://scikit\-learn\.org/stable/modules/generated/sklearn\.linear_model\.LogisticRegression\.html\) instance is used for classification\.
15 |
16 | The model has been trained using an efficient few\-shot learning technique that involves:
17 |
18 | 1\. Fine\-tuning a \[Sentence Transformer\]\(https://www\.sbert\.net\) with contrastive learning\.
19 | 2\. Training a classification head with features from the fine\-tuned Sentence Transformer\.
20 |
21 | ## Model Details
22 |
23 | ### Model Description
24 | - \*\*Model Type:\*\* SetFit
25 | - \*\*Sentence Transformer body:\*\* \[sentence-transformers/paraphrase-albert-small-v2\]\(https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2\) *
26 | - \*\*Classification head:\*\* a \[LogisticRegression\]\(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\) instance *
27 | - \*\*Maximum Sequence Length:\*\* 100 tokens
28 | - \*\*Number of Classes:\*\* 2 classes
29 | - \*\*Training Dataset:\*\* \[SST2\]\(https://huggingface.co/datasets/sst2\)
30 | - \*\*Language:\*\* en
31 | - \*\*License:\*\* apache-2.0
32 |
33 | ### Model Sources
34 |
35 | - \*\*Repository:\*\* \[SetFit on GitHub\]\(https://github.com/huggingface/setfit\)
36 | - \*\*Paper:\*\* \[Efficient Few-Shot Learning Without Prompts\]\(https://arxiv.org/abs/2209.11055\)
37 | - \*\*Blogpost:\*\* \[SetFit: Efficient Few-Shot Learning Without Prompts\]\(https://huggingface.co/blog/setfit\)
38 |
39 | ### Model Labels
40 | \| Label\s+\| Examples\s+\|
41 | \|:-+\|:-+\|
42 | \| positive\s+\| [^\|]+ \|
43 | \| negative\s+\| [^\|]+ \|
44 |
45 | ## Evaluation
46 |
47 | ### Metrics
48 | \| Label \| Accuracy \|
49 | \|:--------\|:---------\|
50 | \| \*\*all\*\* \| [\d\.]+\s+\|
51 |
52 | ## Uses
53 |
54 | ### Direct Use for Inference
55 |
56 | First install the SetFit library:
57 |
58 | ```bash
59 | pip install setfit
60 | ```
61 |
62 | Then you can load this model and run inference.
63 |
64 | ```python
65 | from setfit import SetFitModel
66 |
67 | # Download from the [^H]+ Hub
68 | model = SetFitModel.from_pretrained\("tomaarsen/setfit-paraphrase-albert-small-v2-sst2"\)
69 | # Run inference
70 | preds = model\(".+"\)
71 | ```
72 |
73 |
78 |
79 |
84 |
85 |
90 |
91 |
96 |
97 | ## Training Details
98 |
99 | ### Training Set Metrics
100 | \| Training set \| Min \| Median \| Max \|
101 | \|:-------------\|:----\|:-------\|:----\|
102 | \| Word count \| 3 \| 7.875 \| 18 \|
103 |
104 | \| Label \| Training Sample Count \|
105 | \|:---------\|:----------------------\|
106 | \| negative \| 8 \|
107 | \| positive \| 8 \|
108 |
109 | ### Training Hyperparameters
110 | - batch_size: \(1, 1\)
111 | - num_epochs: \(1, 16\)
112 | - max_steps: 2
113 | - sampling_strategy: oversampling
114 | - body_learning_rate: \(2e-05, 1e-05\)
115 | - head_learning_rate: 0.01
116 | - loss: CosineSimilarityLoss
117 | - distance_metric: cosine_distance
118 | - margin: 0.25
119 | - end_to_end: False
120 | - use_amp: False
121 | - warmup_proportion: 0.1
122 | - l2_weight: 0.01
123 | - seed: 42
124 | - eval_max_steps: -1
125 | - load_best_model_at_end: False
126 |
127 | ### Training Results
128 | \| Epoch \| Step \| Training Loss \| Validation Loss \|
129 | \|:-----:\|:----:\|:-------------:\|:---------------:\|
130 | (\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+
131 | ### Environmental Impact
132 | Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\.
133 | - \*\*Carbon Emitted\*\*: [\d\.]+ kg of CO2
134 | - \*\*Hours Used\*\*: [\d\.]+ hours
135 |
136 | ### Training Hardware
137 | - \*\*On Cloud\*\*: (Yes|No)
138 | - \*\*GPU Model\*\*: [^\n]+
139 | - \*\*CPU Model\*\*: [^\n]+
140 | - \*\*RAM Size\*\*: [\d\.]+ GB
141 |
142 | ### Framework Versions
143 | - Python: [^\n]+
144 | - SetFit: [^\n]+
145 | - Sentence Transformers: [^\n]+
146 | - Transformers: [^\n]+
147 | - PyTorch: [^\n]+
148 | - Datasets: [^\n]+
149 | - Tokenizers: [^\n]+
150 |
151 | ## Citation
152 |
153 | ### BibTeX
154 | ```bibtex
155 | @article{https://doi.org/10.48550/arxiv.2209.11055,
156 | doi = {10.48550/ARXIV.2209.11055},
157 | url = {https://arxiv.org/abs/2209.11055},
158 | author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
159 | keywords = {Computation and Language \(cs.CL\), FOS: Computer and information sciences, FOS: Computer and information sciences},
160 | title = {Efficient Few-Shot Learning Without Prompts},
161 | publisher = {arXiv},
162 | year = \{2022\},
163 | copyright = {Creative Commons Attribution 4.0 International}
164 | }
165 | ```
166 |
167 |
172 |
173 |
178 |
179 | """,
184 | flags=re.DOTALL,
185 | )
186 |
--------------------------------------------------------------------------------
/tests/span/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/tests/span/__init__.py
--------------------------------------------------------------------------------
/tests/span/test_model_card.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from datasets import Dataset
4 |
5 | from setfit import AbsaModel, AbsaTrainer, SetFitModelCardData, TrainingArguments
6 |
7 | from .aspect_model_card_pattern import ASPECT_MODEL_CARD_PATTERN
8 | from .polarity_model_card_pattern import POLARITY_MODEL_CARD_PATTERN
9 |
10 |
11 | def test_model_card(absa_dataset: Dataset, tmp_path: Path) -> None:
12 | model = AbsaModel.from_pretrained(
13 | "sentence-transformers/paraphrase-albert-small-v2",
14 | model_card_data=SetFitModelCardData(
15 | model_id="tomaarsen/setfit-absa-paraphrase-albert-small-v2-laptops",
16 | language=["en"],
17 | license="apache-2.0",
18 | ),
19 | )
20 |
21 | args = TrainingArguments(
22 | str(tmp_path),
23 | report_to="codecarbon",
24 | batch_size=1,
25 | eval_steps=1,
26 | logging_steps=1,
27 | max_steps=2,
28 | eval_strategy="steps",
29 | save_strategy="no",
30 | )
31 | trainer = AbsaTrainer(
32 | model=model,
33 | args=args,
34 | train_dataset=absa_dataset,
35 | eval_dataset=absa_dataset,
36 | )
37 | trainer.train()
38 | trainer.evaluate()
39 |
40 | path = tmp_path / "aspect"
41 | model.aspect_model.create_model_card(path, model_name=str(path))
42 | with open(path / "README.md", "r", encoding="utf8") as f:
43 | model_card = f.read()
44 | assert ASPECT_MODEL_CARD_PATTERN.fullmatch(model_card)
45 |
46 | path = tmp_path / "polarity"
47 | model.polarity_model.create_model_card(path, model_name=str(path))
48 | with open(path / "README.md", "r", encoding="utf8") as f:
49 | model_card = f.read()
50 | assert POLARITY_MODEL_CARD_PATTERN.fullmatch(model_card)
51 |
--------------------------------------------------------------------------------
/tests/span/test_trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from datasets import Dataset
4 | from pytest import LogCaptureFixture
5 | from transformers import TrainerCallback
6 |
7 | from setfit import AbsaTrainer
8 | from setfit.logging import get_logger
9 | from setfit.span.modeling import AbsaModel
10 |
11 |
12 | def test_trainer(absa_model: AbsaModel, absa_dataset: Dataset) -> None:
13 | trainer = AbsaTrainer(absa_model, train_dataset=absa_dataset, eval_dataset=absa_dataset)
14 | trainer.train()
15 |
16 | metrics = trainer.evaluate()
17 | assert "aspect" in metrics
18 | assert "polarity" in metrics
19 | assert "accuracy" in metrics["aspect"]
20 | assert "accuracy" in metrics["polarity"]
21 | assert metrics["aspect"]["accuracy"] > 0.0
22 | assert metrics["polarity"]["accuracy"] > 0.0
23 | new_metrics = trainer.evaluate(absa_dataset)
24 | assert metrics == new_metrics
25 |
26 | predict = absa_model.predict("Best pizza outside of Italy and really tasty.")
27 | assert {"span": "pizza", "polarity": "positive"} in predict
28 | predict = absa_model.predict(["Best pizza outside of Italy and really tasty.", "This is another sentence"])
29 | assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list)
30 | predict = absa_model(["Best pizza outside of Italy and really tasty.", "This is another sentence"])
31 | assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list)
32 |
33 |
34 | def test_trainer_callbacks(absa_model: AbsaModel) -> None:
35 | trainer = AbsaTrainer(absa_model)
36 | assert len(trainer.aspect_trainer.st_trainer.callback_handler.callbacks) >= 2
37 | num_callbacks = len(trainer.aspect_trainer.st_trainer.callback_handler.callbacks)
38 | callback_names = {
39 | callback.__class__.__name__ for callback in trainer.aspect_trainer.st_trainer.callback_handler.callbacks
40 | }
41 | assert {"DefaultFlowCallback", "ProgressCallback"} <= callback_names
42 |
43 | class TestCallback(TrainerCallback):
44 | pass
45 |
46 | callback = TestCallback()
47 | trainer.add_callback(callback)
48 | assert len(trainer.aspect_trainer.st_trainer.callback_handler.callbacks) == num_callbacks + 1
49 | assert len(trainer.polarity_trainer.st_trainer.callback_handler.callbacks) == num_callbacks + 1
50 | assert trainer.aspect_trainer.st_trainer.callback_handler.callbacks[-1] == callback
51 | assert trainer.polarity_trainer.st_trainer.callback_handler.callbacks[-1] == callback
52 |
53 | assert trainer.pop_callback(callback) == (callback, callback)
54 | trainer.add_callback(callback)
55 | assert trainer.aspect_trainer.st_trainer.callback_handler.callbacks[-1] == callback
56 | assert trainer.polarity_trainer.st_trainer.callback_handler.callbacks[-1] == callback
57 | trainer.remove_callback(callback)
58 | assert callback not in trainer.aspect_trainer.st_trainer.callback_handler.callbacks
59 | assert callback not in trainer.polarity_trainer.st_trainer.callback_handler.callbacks
60 |
61 |
62 | def test_train_ordinal_too_high(absa_model: AbsaModel, caplog: LogCaptureFixture) -> None:
63 | logger = get_logger("setfit")
64 | logger.propagate = True
65 |
66 | absa_dataset = Dataset.from_dict(
67 | {
68 | "text": [
69 | "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine."
70 | ],
71 | "span": ["food"],
72 | "label": ["negative"],
73 | "ordinal": [1],
74 | }
75 | )
76 | with caplog.at_level(logging.INFO):
77 | trainer = AbsaTrainer(absa_model, train_dataset=absa_dataset)
78 | assert len(trainer.aspect_trainer.train_dataset) == 3
79 | assert len(trainer.polarity_trainer.train_dataset) == 0
80 | # These tests are ignored as the caplog is inconsistent:
81 | # assert len(caplog.record_tuples) == 1
82 | # assert caplog.record_tuples[0][2] == (
83 | # "The ordinal of 1 for span 'food' in 'It is about food and ambiance, and imagine how dreadful it will be "
84 | # "it we only had to listen to an idle engine.' is too high. Skipping this sample."
85 | # )
86 | # assert caplog.record_tuples[0][1] == logging.INFO
87 |
88 | logger.propagate = False
89 |
90 |
91 | def test_train_column_mapping(absa_model: AbsaModel, absa_dataset: Dataset) -> None:
92 | absa_dataset = absa_dataset.rename_columns({"text": "sentence", "span": "aspect"})
93 | trainer = AbsaTrainer(
94 | absa_model, train_dataset=absa_dataset, column_mapping={"sentence": "text", "aspect": "span"}
95 | )
96 | trainer.train()
97 |
--------------------------------------------------------------------------------
/tests/test_deprecated_trainer_distillation.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import pytest
4 | from datasets import Dataset
5 | from sentence_transformers.losses import CosineSimilarityLoss
6 |
7 | from setfit import DistillationSetFitTrainer, SetFitTrainer
8 | from setfit.modeling import SetFitModel
9 |
10 |
11 | class DistillationSetFitTrainerTest(TestCase):
12 | def setUp(self):
13 | self.teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
14 | self.student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
15 | self.num_iterations = 1
16 |
17 | def test_trainer_works_with_default_columns(self):
18 | dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
19 | # train a teacher model
20 | teacher_trainer = SetFitTrainer(
21 | model=self.teacher_model,
22 | train_dataset=dataset,
23 | eval_dataset=dataset,
24 | loss_class=CosineSimilarityLoss,
25 | metric="accuracy",
26 | )
27 | # Teacher Train and evaluate
28 | teacher_trainer.train()
29 | teacher_model = teacher_trainer.model
30 |
31 | student_trainer = DistillationSetFitTrainer(
32 | teacher_model=teacher_model,
33 | train_dataset=dataset,
34 | student_model=self.student_model,
35 | eval_dataset=dataset,
36 | loss_class=CosineSimilarityLoss,
37 | metric="accuracy",
38 | )
39 |
40 | # Student Train and evaluate
41 | student_trainer.train()
42 | metrics = student_trainer.evaluate()
43 | print("Student results: ", metrics)
44 | self.assertEqual(metrics["accuracy"], 1.0)
45 |
46 | def test_trainer_works_with_missing_label(self):
47 | labeled_dataset = Dataset.from_dict(
48 | {"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}
49 | )
50 | # train a teacher model
51 | teacher_trainer = SetFitTrainer(
52 | model=self.teacher_model,
53 | train_dataset=labeled_dataset,
54 | eval_dataset=labeled_dataset,
55 | metric="accuracy",
56 | num_iterations=self.num_iterations,
57 | )
58 | # Teacher Train and evaluate
59 | teacher_trainer.train()
60 |
61 | unlabeled_dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
62 | student_trainer = DistillationSetFitTrainer(
63 | teacher_model=self.teacher_model,
64 | student_model=self.student_model,
65 | train_dataset=unlabeled_dataset,
66 | eval_dataset=labeled_dataset,
67 | num_iterations=self.num_iterations,
68 | )
69 | student_trainer.train()
70 | metrics = student_trainer.evaluate()
71 | print("Student results: ", metrics)
72 | self.assertEqual(metrics["accuracy"], 1.0)
73 |
74 | def test_trainer_raises_error_with_missing_text(self):
75 | dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
76 | with pytest.raises(ValueError):
77 | DistillationSetFitTrainer(
78 | teacher_model=self.teacher_model,
79 | train_dataset=dataset,
80 | student_model=self.student_model,
81 | eval_dataset=dataset,
82 | num_iterations=self.num_iterations,
83 | )
84 |
85 | def test_column_mapping_with_missing_label(self):
86 | dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
87 | with pytest.raises(ValueError):
88 | DistillationSetFitTrainer(
89 | teacher_model=self.teacher_model,
90 | train_dataset=dataset,
91 | student_model=self.student_model,
92 | eval_dataset=dataset,
93 | num_iterations=self.num_iterations,
94 | column_mapping={"label_new": "label"},
95 | )
96 |
97 | def test_column_mapping_multilabel(self):
98 | dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]})
99 |
100 | trainer = DistillationSetFitTrainer(
101 | teacher_model=self.teacher_model,
102 | train_dataset=dataset,
103 | student_model=self.student_model,
104 | eval_dataset=dataset,
105 | num_iterations=self.num_iterations,
106 | column_mapping={"text_new": "text", "label_new": "label"},
107 | )
108 |
109 | trainer._validate_column_mapping(dataset)
110 | formatted_dataset = trainer._apply_column_mapping(dataset, trainer.column_mapping)
111 |
112 | assert formatted_dataset.column_names == ["text", "label"]
113 | assert formatted_dataset[0]["text"] == "a"
114 | assert formatted_dataset[0]["label"] == [0, 1]
115 | assert formatted_dataset[1]["text"] == "b"
116 |
--------------------------------------------------------------------------------
/tests/test_model_card.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import datasets
4 | import pytest
5 | from datasets import Dataset, load_dataset
6 | from packaging.version import Version, parse
7 |
8 | from setfit import SetFitModel, SetFitModelCardData, Trainer, TrainingArguments
9 | from setfit.data import sample_dataset
10 | from setfit.model_card import generate_model_card, is_on_huggingface
11 |
12 | from .model_card_pattern import MODEL_CARD_PATTERN
13 |
14 |
15 | def test_model_card(tmp_path: Path) -> None:
16 | dataset = load_dataset("sst2")
17 | train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
18 | eval_dataset = dataset["validation"].select(range(10))
19 | model = SetFitModel.from_pretrained(
20 | "sentence-transformers/paraphrase-albert-small-v2",
21 | labels=["negative", "positive"],
22 | model_card_data=SetFitModelCardData(
23 | model_id="tomaarsen/setfit-paraphrase-albert-small-v2-sst2",
24 | dataset_id="sst2",
25 | dataset_name="SST2",
26 | language=["en"],
27 | license="apache-2.0",
28 | ),
29 | )
30 |
31 | args = TrainingArguments(
32 | str(tmp_path),
33 | report_to="codecarbon",
34 | batch_size=1,
35 | eval_steps=1,
36 | logging_steps=1,
37 | max_steps=2,
38 | eval_strategy="steps",
39 | save_strategy="no",
40 | )
41 | trainer = Trainer(
42 | model=model,
43 | args=args,
44 | train_dataset=train_dataset,
45 | eval_dataset=eval_dataset,
46 | column_mapping={"sentence": "text"},
47 | )
48 | trainer.train()
49 | trainer.evaluate()
50 | model_card = generate_model_card(trainer.model)
51 | assert MODEL_CARD_PATTERN.fullmatch(model_card)
52 |
53 |
54 | def test_model_card_languages() -> None:
55 | model = SetFitModel.from_pretrained(
56 | "sentence-transformers/paraphrase-albert-small-v2",
57 | model_card_data=SetFitModelCardData(
58 | language=["en", "nl", "de"],
59 | ),
60 | )
61 | model_card = model.generate_model_card()
62 | assert "**Languages:** en, nl, de" in model_card
63 |
64 |
65 | def test_is_on_huggingface_edge_case() -> None:
66 | assert not is_on_huggingface("test_value")
67 | assert not is_on_huggingface("a/test/value")
68 |
69 |
70 | @pytest.mark.skipif(
71 | parse(datasets.__version__) < Version("2.14.0"), reason="Inferring dataset_id only works from datasets >= 2.14.0"
72 | )
73 | @pytest.mark.parametrize("dataset_id", ("SetFit/emotion", "SetFit/sst2"))
74 | def test_infer_dataset_id(dataset_id: str) -> None:
75 | model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
76 | train_dataset = load_dataset(dataset_id, split="train")
77 |
78 | # This triggers inferring the dataset_id from train_dataset
79 | Trainer(model=model, train_dataset=train_dataset)
80 | assert model.model_card_data.dataset_id == dataset_id
81 |
82 |
83 | def test_cant_infer_dataset_id():
84 | model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2")
85 | train_dataset = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 1, 1, 0]})
86 |
87 | # This triggers inferring the dataset_id from train_dataset
88 | Trainer(model=model, train_dataset=train_dataset)
89 | assert model.model_card_data.dataset_id is None
90 |
--------------------------------------------------------------------------------
/tests/test_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from setfit.sampler import ContrastiveDataset
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "sampling_strategy, expected_pos_pairs, expected_neg_pairs",
9 | [("unique", 4, 2), ("undersampling", 2, 2), ("oversampling", 4, 4)],
10 | )
11 | def test_sentence_pairs_generation(sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int):
12 | sentences = np.array(["sent 1", "sent 2", "sent 3"])
13 | labels = np.array(["label 1", "label 1", "label 2"])
14 |
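# Why "unique" yields 4 positive / 2 negative pairs here: positives enumerate
# every distinct same-label pair including self-pairs, (1,1), (1,2), (2,2), (3,3);
# negatives enumerate the cross-label pairs (1,3), (2,3). "undersampling" trims
# positives down to the negative count, while "oversampling" resamples negatives
# up to the positive count.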
15 | multilabel = False
16 |
17 | data_sampler = ContrastiveDataset(sentences, labels, multilabel, sampling_strategy=sampling_strategy)
18 |
19 | assert data_sampler.len_pos_pairs == expected_pos_pairs
20 | assert data_sampler.len_neg_pairs == expected_neg_pairs
21 |
22 | pairs = [i for i in data_sampler]
23 |
24 | assert len(pairs) == expected_pos_pairs + expected_neg_pairs
25 | assert pairs[0] == {"sentence_1": "sent 1", "sentence_2": "sent 1", "label": 1.0}
26 |
27 |
28 | @pytest.mark.parametrize(
29 | "sampling_strategy, expected_pos_pairs, expected_neg_pairs",
30 | [("unique", 6, 4), ("undersampling", 4, 4), ("oversampling", 6, 6)],
31 | )
32 | def test_sentence_pairs_generation_multilabel(
33 | sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int
34 | ):
35 | sentences = np.array(["sent 1", "sent 2", "sent 3", "sent 4"])
36 | labels = np.array([[1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
37 |
38 | multilabel = True
39 |
40 | data_sampler = ContrastiveDataset(sentences, labels, multilabel, sampling_strategy=sampling_strategy)
41 | assert data_sampler.len_pos_pairs == expected_pos_pairs
42 | assert data_sampler.len_neg_pairs == expected_neg_pairs
43 |
44 | pairs = [i for i in data_sampler]
45 | assert len(pairs) == expected_pos_pairs + expected_neg_pairs
46 |
--------------------------------------------------------------------------------
/tests/test_training_args.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import pytest
4 | from transformers import IntervalStrategy
5 |
6 | from setfit.training_args import TrainingArguments
7 |
8 |
9 | class TestTrainingArguments(TestCase):
10 | def test_raises_error_with_wrong_warmup_proportion(self):
11 | # warmup_proportion must not be > 1.0
12 | with pytest.raises(ValueError):
13 | TrainingArguments(warmup_proportion=1.1)
14 |
15 | # warmup_proportion must not be < 0.0
16 | with pytest.raises(ValueError):
17 | TrainingArguments(warmup_proportion=-0.1)
18 |
19 | def test_batch_sizes(self):
20 | batch_size_A = 12
21 | batch_size_B = 4
22 |
23 | args = TrainingArguments(batch_size=batch_size_A)
24 | self.assertEqual(args.batch_size, (batch_size_A, batch_size_A))
25 | self.assertEqual(args.embedding_batch_size, batch_size_A)
26 | self.assertEqual(args.classifier_batch_size, batch_size_A)
27 |
28 | args = TrainingArguments(batch_size=(batch_size_A, batch_size_B))
29 | self.assertEqual(args.batch_size, (batch_size_A, batch_size_B))
30 | self.assertEqual(args.embedding_batch_size, batch_size_A)
31 | self.assertEqual(args.classifier_batch_size, batch_size_B)
32 |
33 | def test_num_epochs(self):
34 | num_epochs_A = 12
35 | num_epochs_B = 4
36 |
37 | args = TrainingArguments(num_epochs=num_epochs_A)
38 | self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_A))
39 | self.assertEqual(args.embedding_num_epochs, num_epochs_A)
40 | self.assertEqual(args.classifier_num_epochs, num_epochs_A)
41 |
42 | args = TrainingArguments(num_epochs=(num_epochs_A, num_epochs_B))
43 | self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_B))
44 | self.assertEqual(args.embedding_num_epochs, num_epochs_A)
45 | self.assertEqual(args.classifier_num_epochs, num_epochs_B)
46 |
47 | def test_learning_rates(self):
48 | learning_rate_A = 1e-2
49 | learning_rate_B = 1e-3
50 |
51 | base = TrainingArguments()
52 |
53 | args = TrainingArguments(body_learning_rate=learning_rate_A)
54 | self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_A))
55 | self.assertEqual(args.body_embedding_learning_rate, learning_rate_A)
56 | self.assertEqual(args.body_classifier_learning_rate, learning_rate_A)
57 | self.assertEqual(args.head_learning_rate, base.head_learning_rate)
58 |
59 | args = TrainingArguments(body_learning_rate=(learning_rate_A, learning_rate_B))
60 | self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_B))
61 | self.assertEqual(args.body_embedding_learning_rate, learning_rate_A)
62 | self.assertEqual(args.body_classifier_learning_rate, learning_rate_B)
63 | self.assertEqual(args.head_learning_rate, base.head_learning_rate)
64 |
65 | def test_report_to(self):
66 | args = TrainingArguments(report_to="none")
67 | self.assertEqual(args.report_to, ["none"])
68 | args = TrainingArguments(report_to=["none"])
69 | self.assertEqual(args.report_to, ["none"])
70 | args = TrainingArguments(report_to="hello")
71 | self.assertEqual(args.report_to, ["hello"])
72 |
73 | def test_eval_steps_without_eval_strat(self):
74 | args = TrainingArguments(eval_steps=5)
75 | self.assertEqual(args.eval_strategy, IntervalStrategy.STEPS)
76 |
77 | def test_eval_strat_steps_without_eval_steps(self):
78 | args = TrainingArguments(eval_strategy="steps")
79 | self.assertEqual(args.eval_steps, args.logging_steps)
80 | with self.assertRaises(ValueError):
81 | TrainingArguments(eval_strategy="steps", logging_steps=0, logging_strategy="no")
82 |
83 | def test_load_best_model(self):
84 | with self.assertRaises(ValueError):
85 | TrainingArguments(load_best_model_at_end=True, eval_strategy="steps", save_strategy="epoch")
86 | with self.assertRaises(ValueError):
87 | TrainingArguments(
88 | load_best_model_at_end=True,
89 | eval_strategy="steps",
90 | save_strategy="steps",
91 | eval_steps=100,
92 | save_steps=50,
93 | )
94 | # No error: save_steps is a round multiple of eval_steps
95 | TrainingArguments(
96 | load_best_model_at_end=True,
97 | eval_strategy="steps",
98 | save_strategy="steps",
99 | eval_steps=50,
100 | save_steps=100,
101 | )
102 |
103 | def test_logging_steps_zero(self):
104 | with self.assertRaises(ValueError):
105 | TrainingArguments(logging_strategy="steps", logging_steps=0)
106 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import tempfile
4 |
5 |
6 | class SafeTemporaryDirectory(tempfile.TemporaryDirectory):
7 | """
8 | The GitHub Actions CI on Windows sometimes raises a NotADirectoryError when cleaning up the temporary directory.
9 | This class is a workaround to avoid the error.
10 |
11 | Unlike tempfile.TemporaryDirectory(ignore_cleanup_errors=True), this also works on Python 3.8 and 3.9.
12 | """
13 |
14 | def __init__(self, *args, **kwargs) -> None:
15 | kwargs["ignore_cleanup_errors"] = True
16 | try:
17 | super().__init__(*args, **kwargs)
18 | except TypeError:
19 | del kwargs["ignore_cleanup_errors"]
20 | super().__init__(*args, **kwargs)
21 |
22 | def __exit__(self, *args, **kwargs):
23 | try:
24 | super().__exit__(*args, **kwargs)
25 | except NotADirectoryError:
26 | pass
27 |
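Usage is a drop-in swap for `tempfile.TemporaryDirectory` (a sketch):

```python
from tests.utils import SafeTemporaryDirectory

with SafeTemporaryDirectory() as tmp_dir:
    ...  # save models or fixtures into tmp_dir; Windows cleanup errors are swallowed on exit
```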
--------------------------------------------------------------------------------
/utils/create_notebook_table.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | GITHUB_PATH_PREFIX = "huggingface/setfit/blob/main/notebooks/"
4 |
5 | CHAPTER_TO_NB = {
6 | "Text Classification (Multiclass)": "text-classification",
7 | "Text Classification (Multilabel)": "text-classification_multilabel",
8 | "Zero-Shot Text Classification": "zero-shot-classification",
9 | "Hyperparameter search": "text-classification_hyperparameter-search",
10 | }
11 |
12 |
13 | def _find_text_in_file(filename, start_prompt, end_prompt):
14 | """
15 | Find the text in `filename` between a line beginning with `start_prompt` and a line beginning with `end_prompt`,
16 | trimming empty lines at both ends.
17 |
18 | Copied from: https://github.com/huggingface/transformers/blob/16f0b7d72c6d4e122957392c342b074aa2c5c519/utils/check_table.py#L30
19 | """
20 | with open(filename, "r", encoding="utf-8", newline="\n") as f:
21 | lines = f.readlines()
22 | # Find the start prompt.
23 | start_index = 0
24 | while not lines[start_index].startswith(start_prompt):
25 | start_index += 1
26 | start_index += 1
27 |
28 | end_index = start_index
29 | while not lines[end_index].startswith(end_prompt):
30 | end_index += 1
31 | end_index -= 1
32 |
33 | while len(lines[start_index]) <= 1:
34 | start_index += 1
35 | while len(lines[end_index]) <= 1:
36 | end_index -= 1
37 | end_index += 1
38 | return "".join(lines[start_index:end_index]), start_index, end_index, lines
39 |
40 |
41 | def create_table():
42 | data = {"Notebook": [], "Colab": [], "Kaggle": [], "Gradient": [], "Studio Lab": []}
43 | for title, nb in CHAPTER_TO_NB.items():
44 | nb_path = f"{GITHUB_PATH_PREFIX}{nb}.ipynb"
45 | data["Notebook"].append(title)
46 | data["Colab"].append(
47 | f"[](https://colab.research.google.com/github/{nb_path})"
48 | )
49 | data["Kaggle"].append(
50 | f"[](https://kaggle.com/kernels/welcome?src=https://github.com/{nb_path})"
51 | )
52 | data["Gradient"].append(
53 | f"[](https://console.paperspace.com/github/{nb_path})"
54 | )
55 | data["Studio Lab"].append(
56 | f"[](https://studiolab.sagemaker.aws/import/github/{nb_path})"
57 | )
58 | return pd.DataFrame(data).to_markdown(index=False) + "\n"
59 |
60 |
61 | def main():
62 | table = create_table()
63 | _, start_index, end_index, lines = _find_text_in_file(
64 | filename="notebooks/README.md",
65 | start_prompt="",
66 | end_prompt="",
67 | )
68 |
69 | with open("notebooks/README.md", "w", encoding="utf-8", newline="\n") as f:
70 | f.writelines(lines[:start_index] + [table] + lines[end_index:])
71 |
72 |
73 | if __name__ == "__main__":
74 | main()
75 |
--------------------------------------------------------------------------------
/utils/release.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import os
18 | import re
19 |
20 | import packaging.version
21 |
22 |
23 | REPLACE_PATTERNS = {
24 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
25 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
26 | }
27 | REPLACE_FILES = {
28 | "init": "src/setfit/__init__.py",
29 | "setup": "setup.py",
30 | }
31 | README_FILE = "README.md"
32 |
33 |
34 | def update_version_in_file(fname, version, pattern):
35 | """Update the version in one file using a specific pattern."""
36 | with open(fname, "r", encoding="utf-8", newline="\n") as f:
37 | code = f.read()
38 | re_pattern, replace = REPLACE_PATTERNS[pattern]
39 | replace = replace.replace("VERSION", version)
40 | code = re_pattern.sub(replace, code)
41 | with open(fname, "w", encoding="utf-8", newline="\n") as f:
42 | f.write(code)
43 |
44 |
45 | def global_version_update(version, patch=False):
46 | """Update the version in all needed files."""
47 | for pattern, fname in REPLACE_FILES.items():
48 | update_version_in_file(fname, version, pattern)
49 |
50 |
51 | def get_version():
52 | """Reads the current version in the __init__."""
53 | with open(REPLACE_FILES["init"], "r") as f:
54 | code = f.read()
55 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
56 | return packaging.version.parse(default_version)
57 |
58 |
59 | def pre_release_work(patch=False):
60 | """Do all the necessary pre-release steps."""
61 | # First let's get the default version: base version if we are in dev, bump minor otherwise.
62 | default_version = get_version()
63 | if patch and default_version.is_devrelease:
64 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
65 | if default_version.is_devrelease:
66 | default_version = default_version.base_version
67 | elif patch:
68 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
69 | else:
70 | default_version = f"{default_version.major}.{default_version.minor + 1}.0"
71 |
72 | # Now let's ask nicely if that's the right one.
73 | version = input(f"Which version are you releasing? [{default_version}]")
74 | if len(version) == 0:
75 | version = default_version
76 |
77 | print(f"Updating version to {version}.")
78 | global_version_update(version, patch=patch)
79 |
80 |
81 | def post_release_work():
82 | """Do all the necesarry post-release steps."""
83 | # First let's get the current version
84 | current_version = get_version()
85 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
86 | current_version = current_version.base_version
87 |
88 | # Check with the user we got that right.
89 | version = input(f"Which version are we developing now? [{dev_version}]")
90 | if len(version) == 0:
91 | version = dev_version
92 |
93 | print(f"Updating version to {version}.")
94 | global_version_update(version)
95 |
96 |
97 | if __name__ == "__main__":
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
100 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
101 | args = parser.parse_args()
102 | if not args.post_release:
103 | pre_release_work(patch=args.patch)
104 | elif args.patch:
105 | print("Nothing to do after a patch :-)")
106 | else:
107 | post_release_work()
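For reference, the invocations implied by the argparse setup above, run from the repository root:

    python utils/release.py                  # pre-release: bump to the next minor version
    python utils/release.py --patch         # pre-release: bump the micro (patch) version
    python utils/release.py --post_release  # post-release: switch to the next .dev0 version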
--------------------------------------------------------------------------------