├── .coveragerc ├── .github └── workflows │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── delete_doc_comment_trigger.yml │ ├── quality.yml │ ├── tests.yml │ └── upload_pr_documentation.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── RELEASE.md ├── Zero_cost_Zero_time_Zero_shot_Financial_Sentiment_Analysis.ipynb ├── assets └── setfit.png ├── docs ├── README.md └── source │ ├── _config.py │ └── en │ ├── _toctree.yml │ ├── conceptual_guides │ ├── sampling_strategies.mdx │ └── setfit.mdx │ ├── how_to │ ├── absa.mdx │ ├── batch_sizes.mdx │ ├── callbacks.mdx │ ├── classification_heads.mdx │ ├── hyperparameter_optimization.mdx │ ├── knowledge_distillation.mdx │ ├── model_cards.mdx │ ├── multilabel.mdx │ ├── overview.mdx │ ├── v1.0.0_migration_guide.mdx │ └── zero_shot.mdx │ ├── index.mdx │ ├── installation.mdx │ ├── quickstart.mdx │ ├── reference │ ├── main.mdx │ ├── trainer.mdx │ └── utility.mdx │ └── tutorials │ ├── onnx.mdx │ ├── overview.mdx │ └── zero_shot.mdx ├── final_results ├── adapet │ ├── .gitkeep │ ├── albert-xxlarge-v2.tar.gz │ ├── xlm-roberta-base_all__lang_prompt.tar.gz │ ├── xlm-roberta-base_each__lang_prompt.tar.gz │ └── xlm-roberta-base_en__eng_prompt.tar.gz ├── distillation │ └── distillation_results.tar.gz ├── perfect.tar.gz ├── setfit │ ├── all-MiniLM-L6-v2.tar.gz │ ├── all-roberta-large-v1.tar.gz │ ├── paraphrase-mpnet-base-v2-epochs_20-with-augmentation.tar.gz │ ├── paraphrase-mpnet-base-v2-epochs_20.tar.gz │ ├── paraphrase-mpnet-base-v2-epochs_5.tar.gz │ ├── paraphrase-mpnet-base-v2-linear-probe.tar.gz │ └── paraphrase-multilingual-mpnet-base-v2.tar.gz ├── tfew │ ├── t011b_pretrained-v2.tar.gz │ └── t03b_pretrained-v2.tar.gz └── transformers │ ├── distilbert-base-uncased.tar.gz │ ├── roberta-large.tar.gz │ └── xlm-roberta-base.tar.gz ├── notebooks ├── .gitkeep ├── README.md ├── multilabel_HoC.ipynb ├── onnx_model_export.ipynb ├── openvino_inference.ipynb ├── setfit-absa-fiqa.ipynb ├── setfit-onnx-optimum.ipynb ├── setfit-optimum-intel.ipynb ├── text-classification.ipynb ├── text-classification_hyperparameter-search.ipynb ├── text-classification_multilabel.ipynb ├── zero-shot-classification.ipynb └── zero_cost_zero_time_zero_shot_financial_sentiment_analysis.ipynb ├── scripts ├── adapet │ ├── .gitkeep │ └── ADAPET │ │ ├── .gitignore │ │ ├── README.md │ │ ├── bin │ │ ├── dev.sh │ │ ├── init.sh │ │ ├── setup.sh │ │ ├── test.sh │ │ └── train.sh │ │ ├── cli.py │ │ ├── config │ │ ├── BoolQ.json │ │ ├── CB.json │ │ ├── COPA.json │ │ ├── Generic.json │ │ ├── MultiRC.json │ │ ├── RTE.json │ │ ├── ReCoRD.json │ │ ├── WSC.json │ │ ├── WiC.json │ │ └── sst-2.json │ │ ├── requirements.txt │ │ ├── setfit_adapet.py │ │ ├── src │ │ ├── adapet.py │ │ ├── adapet_test_eval.py │ │ ├── data │ │ │ ├── Batcher.py │ │ │ ├── BoolQReader.py │ │ │ ├── CBReader.py │ │ │ ├── COPAReader.py │ │ │ ├── Dataset.py │ │ │ ├── DatasetReader.py │ │ │ ├── GenericReader.py │ │ │ ├── MultiRCReader.py │ │ │ ├── RTEReader.py │ │ │ ├── RecordReader.py │ │ │ ├── WSCReader.py │ │ │ ├── WiCReader.py │ │ │ └── tokenize.py │ │ ├── dev.py │ │ ├── eval │ │ │ ├── Scorer.py │ │ │ ├── Writer.py │ │ │ └── eval_model.py │ │ ├── run_pretrained.py │ │ ├── scripts │ │ │ └── example_convert_sst_2_generic.py │ │ ├── test.py │ │ ├── train.py │ │ └── utils │ │ │ ├── Config.py │ │ │ └── util.py │ │ └── utilcode.py ├── create_summary_table.py ├── perfect │ └── README.md ├── plot_summary_comparison.py ├── setfit │ ├── README.md │ ├── distillation_baseline.py │ ├── run_fewshot.py │ ├── 
run_fewshot_distillation.py │ ├── run_fewshot_multilabel.py │ ├── run_fewshot_multilingual.py │ └── run_zeroshot.py ├── tfew │ ├── README.md │ ├── requirements.txt │ ├── run_tfew_11b.sh │ └── run_tfew_test.sh └── transformers │ ├── README.md │ ├── requirements.txt │ ├── run_fewshot.py │ ├── run_fewshot_multilingual.py │ ├── run_full.py │ ├── run_full_multilingual.py │ ├── run_inference.py │ ├── run_zeroshot.py │ └── utils.py ├── setup.cfg ├── setup.py ├── src └── setfit │ ├── __init__.py │ ├── data.py │ ├── exporters │ ├── __init__.py │ ├── onnx.py │ ├── openvino.py │ └── utils.py │ ├── integrations.py │ ├── logging.py │ ├── losses.py │ ├── model_card.py │ ├── model_card_template.md │ ├── modeling.py │ ├── notebook.py │ ├── sampler.py │ ├── span │ ├── __init__.py │ ├── aspect_extractor.py │ ├── modeling.py │ └── trainer.py │ ├── trainer.py │ ├── trainer_distillation.py │ ├── training_args.py │ └── utils.py ├── tests ├── __init__.py ├── conftest.py ├── exporters │ ├── test_onnx.py │ └── test_openvino.py ├── model_card_pattern.py ├── span │ ├── __init__.py │ ├── aspect_model_card_pattern.py │ ├── polarity_model_card_pattern.py │ ├── test_model_card.py │ ├── test_modeling.py │ └── test_trainer.py ├── test_data.py ├── test_deprecated_trainer.py ├── test_deprecated_trainer_distillation.py ├── test_model_card.py ├── test_modeling.py ├── test_sampler.py ├── test_trainer.py ├── test_trainer_distillation.py ├── test_training_args.py └── utils.py └── utils ├── create_notebook_table.py └── release.py /.coveragerc: -------------------------------------------------------------------------------- 1 | # Configuration file to control (pytest) coverage 2 | [run] 3 | # Run branch coverage, too 4 | branch = True 5 | 6 | [paths] 7 | source = 8 | src/setfit 9 | 10 | [report] 11 | # Regexes for lines to exclude from consideration 12 | exclude_lines = 13 | # Have to re-enable the standard pragma 14 | pragma: no cover 15 | 16 | # Don't complain about missing debug-only code: 17 | def __repr__ 18 | if self\.debug 19 | 20 | # Don't complain if tests don't hit defensive assertion code: 21 | raise AssertionError 22 | raise NotImplementedError 23 | 24 | # Don't complain if non-runnable code isn't run: 25 | if 0: 26 | if __name__ == .__main__.: 27 | 28 | # Don't complain about abstract methods, they aren't run: 29 | @(abc\.)?abstractmethod 30 | 31 | # Ignore TYPE_CHECKING code 32 | if TYPE_CHECKING: 33 | 34 | [html] 35 | directory = coverage_report_html 36 | title = SetFit coverage report -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.sha }} 15 | package: setfit 16 | notebook_folder: setfit_doc 17 | languages: en 18 | secrets: 19 | token: ${{ secrets.HUGGINGFACE_PUSH }} 20 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: 
true 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.event.pull_request.head.sha }} 15 | pr_number: ${{ github.event.number }} 16 | package: setfit 17 | languages: en -------------------------------------------------------------------------------- /.github/workflows/delete_doc_comment_trigger.yml: -------------------------------------------------------------------------------- 1 | name: Delete doc comment trigger 2 | 3 | on: 4 | pull_request: 5 | types: [ closed ] 6 | 7 | 8 | jobs: 9 | delete: 10 | uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main 11 | with: 12 | pr_number: ${{ github.event.number }} -------------------------------------------------------------------------------- /.github/workflows/quality.yml: -------------------------------------------------------------------------------- 1 | name: Quality 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - v*-release 8 | - v*-pre 9 | pull_request: 10 | branches: 11 | - main 12 | - v*-pre 13 | workflow_dispatch: 14 | 15 | jobs: 16 | 17 | check_code_quality: 18 | name: Check code quality 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Checkout code 22 | uses: actions/checkout@v2 23 | - name: Setup Python environment 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: 3.9 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install ".[quality]" 31 | - name: Code quality 32 | run: | 33 | make quality 34 | 35 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Unit tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - v*-release 8 | - v*-pre 9 | pull_request: 10 | branches: 11 | - main 12 | - v*-pre 13 | workflow_dispatch: 14 | 15 | env: 16 | TRANSFORMERS_IS_CI: 1 17 | 18 | jobs: 19 | 20 | test_sampling: 21 | name: Run unit tests 22 | strategy: 23 | matrix: 24 | python-version: ['3.9', '3.10', '3.11', '3.12'] 25 | os: [ubuntu-latest, windows-latest] 26 | requirements: ['.[tests]', '.[compat_tests]'] 27 | fail-fast: false 28 | runs-on: ${{ matrix.os }} 29 | steps: 30 | - name: Checkout code 31 | uses: actions/checkout@v3 32 | 33 | - name: Setup Python environment 34 | uses: actions/setup-python@v4 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | 38 | # Taken from https://github.com/actions/cache?tab=readme-ov-file#creating-a-cache-key 39 | # Use date to invalidate cache every week 40 | - name: Get date 41 | id: get-date 42 | run: | 43 | echo "date=$(/bin/date -u "+%G%V")" >> $GITHUB_OUTPUT 44 | shell: bash 45 | 46 | - name: Try to load cached dependencies 47 | uses: actions/cache@v3 48 | id: restore-cache 49 | with: 50 | path: ${{ env.pythonLocation }} 51 | key: python-dependencies-${{ matrix.os }}-${{ steps.get-date.outputs.date }}-${{ matrix.python-version }}-${{ matrix.requirements }}-${{ hashFiles('setup.py') }}-${{ env.pythonLocation }} 52 | 53 | - name: Install external dependencies on cache miss 54 | run: | 55 | python -m pip install --no-cache-dir --upgrade pip 56 | python -m pip install --no-cache-dir ${{ matrix.requirements }} 57 | python -m pip install '.[codecarbon]' 58 | python -m spacy download en_core_web_lg 59 | python -m spacy download en_core_web_sm 60 | if: steps.restore-cache.outputs.cache-hit != 'true' 61 | 62 | - name: Install the checked-out 
setfit 63 | run: python -m pip install . 64 | 65 | - name: Restore HF models from cache 66 | uses: actions/cache/restore@v3 67 | with: 68 | path: | 69 | ~/.cache/huggingface/hub 70 | ~/.cache/torch 71 | key: hf-models-${{ matrix.os }}-${{ env.NEW_HF_CACHE_HASH }} 72 | restore-keys: | 73 | hf-models-${{ matrix.os }}- 74 | 75 | - name: Run unit tests 76 | shell: bash 77 | run: | 78 | echo "OLD_HF_CACHE_HASH=$(find ~/.cache/huggingface/hub ~/.cache/torch -type f -exec sha256sum {} + | LC_ALL=C sort | sha256sum | cut -d ' ' -f 1)" >> $GITHUB_ENV 79 | pytest -v tests/ 80 | echo "NEW_HF_CACHE_HASH=$(find ~/.cache/huggingface/hub ~/.cache/torch -type f -exec sha256sum {} + | LC_ALL=C sort | sha256sum | cut -d ' ' -f 1)" >> $GITHUB_ENV 81 | 82 | - name: Save new HF models to cache 83 | uses: actions/cache/save@v3 84 | with: 85 | path: | 86 | ~/.cache/huggingface/hub 87 | ~/.cache/torch 88 | key: hf-models-${{ matrix.os }}-${{ env.NEW_HF_CACHE_HASH }} 89 | # Only save cache if the hash has changed 90 | if: env.NEW_HF_CACHE_HASH != env.OLD_HF_CACHE_HASH 91 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: setfit 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # pycharm 2 | .idea 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | *.pyc 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | launch.json 135 | 136 | # VS Code 137 | .history 138 | .vscode/launch.json 139 | 140 | # Temporary results files 141 | scripts/**/results.json 142 | scripts/**/checkpoints 143 | scripts/**/train_script.py 144 | scripts/**/summary_table.csv 145 | scripts/tfew/t-few 146 | scripts/tfew/results 147 | scripts/tfew/run_tmux.sh 148 | 149 | # macOS 150 | .DS_Store 151 | .vscode/settings.json 152 | 153 | # Common SetFit Trainer logging folders 154 | wandb 155 | runs/ 156 | notebooks/perf_metrics.json 157 | notebooks/nc_workspace/ 158 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/setfit/model_card_template.md -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: notebooks 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 
4 | export PYTHONPATH = src 5 | 6 | check_dirs := scripts tests src setup.py 7 | 8 | style: 9 | python -m black --line-length 119 --exclude="scripts/adapet|scripts/tfew" --target-version py39 $(check_dirs) 10 | python -m isort --skip scripts/adapet --skip scripts/tfew $(check_dirs) 11 | 12 | quality: 13 | python -m black --check --line-length 119 --exclude="scripts/adapet|scripts/tfew" --target-version py39 $(check_dirs) 14 | python -m isort --check-only --skip scripts/adapet --skip scripts/tfew $(check_dirs) 15 | python -m flake8 --max-line-length 119 $(check_dirs) 16 | 17 | test: 18 | python -m pytest -sv tests/ 19 | 20 | coverage: 21 | python -m pytest --cov=src --cov-report=term-missing -sv tests/ 22 | 23 | notebooks: 24 | python utils/create_notebook_table.py 25 | 26 | # Release stuff 27 | 28 | pre-release: 29 | python utils/release.py 30 | 31 | pre-patch: 32 | python utils/release.py --patch 33 | 34 | post-release: 35 | python utils/release.py --post_release 36 | 37 | post-patch: 38 | python utils/release.py --post_release --patch 39 | 40 | wheels: 41 | python setup.py bdist_wheel && python setup.py sdist 42 | 43 | wheels_clean: 44 | rm -rf build && rm -rf dist 45 | 46 | pypi_upload: 47 | python -m pip install twine 48 | twine upload dist/* -r pypi 49 | 50 | pypi_test_upload: 51 | python -m pip install twine 52 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 53 | 54 | pypi_test_install: 55 | python -m pip install evaluate==0.3.0 datasets==2.3.2 sentence_transformers==2.2.2 56 | python -m pip install -i https://testpypi.python.org/pypi setfit 57 | python -c "from setfit import *" 58 | echo "🚀 Successfully installed setfit from test.pypi.org" 59 | -------------------------------------------------------------------------------- /assets/setfit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/assets/setfit.png -------------------------------------------------------------------------------- /docs/source/_config.py: -------------------------------------------------------------------------------- 1 | # docstyle-ignore 2 | INSTALL_CONTENT = """ 3 | # SetFit installation 4 | ! pip install setfit 5 | # To install from source instead of the last release, comment the command above and uncomment the following one. 6 | # ! 
pip install git+https://github.com/huggingface/setfit.git
7 | """
8 |
9 | notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
--------------------------------------------------------------------------------
/docs/source/en/_toctree.yml:
--------------------------------------------------------------------------------
1 | - sections:
2 | - local: index
3 | title: SetFit
4 | - local: quickstart
5 | title: Quickstart
6 | - local: installation
7 | title: Installation
8 | title: Get started
9 |
10 | - sections:
11 | - local: tutorials/overview
12 | title: Overview
13 | - local: tutorials/zero_shot
14 | title: Zero-shot Text Classification
15 | - local: tutorials/onnx
16 | title: Efficiently run SetFit with ONNX
17 | title: Tutorials
18 |
19 | - sections:
20 | - local: how_to/overview
21 | title: Overview
22 | - local: how_to/callbacks
23 | title: Callbacks
24 | - local: how_to/model_cards
25 | title: Model Cards
26 | - local: how_to/classification_heads
27 | title: Classification Heads
28 | - local: how_to/multilabel
29 | title: Multilabel Text Classification
30 | - local: how_to/zero_shot
31 | title: Zero-shot Text Classification
32 | - local: how_to/hyperparameter_optimization
33 | title: Hyperparameter Optimization
34 | - local: how_to/knowledge_distillation
35 | title: Knowledge Distillation
36 | - local: how_to/batch_sizes
37 | title: Batch Sizes for Inference
38 | - local: how_to/absa
39 | title: Aspect Based Sentiment Analysis
40 | - local: how_to/v1.0.0_migration_guide
41 | title: v1.0.0 Migration Guide
42 | title: How-to Guides
43 |
44 | - sections:
45 | - local: conceptual_guides/setfit
46 | title: SetFit
47 | - local: conceptual_guides/sampling_strategies
48 | title: Sampling Strategies
49 | title: Conceptual Guides
50 |
51 | - sections:
52 | - local: reference/main
53 | title: Main classes
54 | - local: reference/trainer
55 | title: Trainer classes
56 | - local: reference/utility
57 | title: Utility
58 | title: Reference
--------------------------------------------------------------------------------
/docs/source/en/conceptual_guides/sampling_strategies.mdx:
--------------------------------------------------------------------------------
1 |
2 | # SetFit Sampling Strategies
3 |
4 | SetFit supports various contrastive pair sampling strategies in [`TrainingArguments`]. In this conceptual guide, we will learn about the following four sampling strategies:
5 |
6 | 1. `"oversampling"` (the default)
7 | 2. `"undersampling"`
8 | 3. `"unique"`
9 | 4. `"num_iterations"`
10 |
11 | Consider first reading the [SetFit conceptual guide](../setfit) for a background on contrastive learning and positive & negative pairs.
12 |
13 | ## Running example
14 |
15 | Throughout this conceptual guide, we will use the following example scenario:
16 |
17 | * 3 classes: "happy", "content", and "sad".
18 | * 20 total samples: 8 "happy", 4 "content", and 8 "sad" samples.
19 |
20 | Because the sentence pairs `(X, Y)` and `(Y, X)` result in the same embedding distance/loss, we only want to keep one of those two cases. Furthermore, we don't want pairs where both sentences are the same, e.g. no `(X, X)`.
21 |
22 | The resulting positive and negative pairs can be visualized in a table like below. The `+` and `-` represent positive and negative pairs, respectively. Furthermore, `h-n` represents the n-th "happy" sentence, `c-n` the n-th "content" sentence, and `s-n` the n-th "sad" sentence.
Note that the area below the diagonal is not used as `(X, Y)` and `(Y, X)` result in the same embedding distances, and that the diagonal is not used as we are not interested in pairs where both sentences are identical. 23 | 24 | | |h-1|h-2|h-3|h-4|h-5|h-6|h-7|h-8|c-1|c-2|c-3|c-4|s-1|s-2|s-3|s-4|s-5|s-6|s-7|s-8| 25 | |-------|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| 26 | |**h-1**| | + | + | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - | 27 | |**h-2**| | | + | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - | 28 | |**h-3**| | | | + | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - | 29 | |**h-4**| | | | | + | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - | 30 | |**h-5**| | | | | | + | + | + | - | - | - | - | - | - | - | - | - | - | - | - | 31 | |**h-6**| | | | | | | + | + | - | - | - | - | - | - | - | - | - | - | - | - | 32 | |**h-7**| | | | | | | | + | - | - | - | - | - | - | - | - | - | - | - | - | 33 | |**h-8**| | | | | | | | | - | - | - | - | - | - | - | - | - | - | - | - | 34 | |**c-1**| | | | | | | | | | + | + | + | - | - | - | - | - | - | - | - | 35 | |**c-2**| | | | | | | | | | | + | + | - | - | - | - | - | - | - | - | 36 | |**c-3**| | | | | | | | | | | | + | - | - | - | - | - | - | - | - | 37 | |**c-4**| | | | | | | | | | | | | - | - | - | - | - | - | - | - | 38 | |**s-1**| | | | | | | | | | | | | | + | + | + | + | + | + | + | 39 | |**s-2**| | | | | | | | | | | | | | | + | + | + | + | + | + | 40 | |**s-3**| | | | | | | | | | | | | | | | + | + | + | + | + | 41 | |**s-4**| | | | | | | | | | | | | | | | | + | + | + | + | 42 | |**s-5**| | | | | | | | | | | | | | | | | | + | + | + | 43 | |**s-6**| | | | | | | | | | | | | | | | | | | + | + | 44 | |**s-7**| | | | | | | | | | | | | | | | | | | | + | 45 | |**s-8**| | | | | | | | | | | | | | | | | | | | | 46 | 47 | As shown in the prior table, we have 28 positive pairs for "happy", 6 positive pairs for "content", and another 28 positive pairs for "sad". In total, this is 62 positive pairs. Also, we have 32 negative pairs between "happy" and "content", 64 negative pairs between "happy" and "sad", and 32 negative pairs between "content" and "sad". In total, this is 128 negative pairs. 48 | 49 | ## Oversampling 50 | 51 | By default, SetFit applies the oversampling strategy for its contrastive pairs. This strategy samples an equal amount of positive and negative training pairs, oversampling the minority pair type to match that of the majority pair type. As the number of negative pairs is generally larger than the number of positive pairs, this usually involves oversampling the positive pairs. 52 | 53 | In our running example, this would involve oversampling the 62 positive pairs up to 128, resulting in one epoch of 128 + 128 = 256 pairs. In summary: 54 | 55 | * ✅ An equal amount of positive and negative pairs are sampled. 56 | * ✅ Every possible pair is used. 57 | * ❌ There is some data duplication. 58 | 59 | ## Undersampling 60 | 61 | Like oversampling, this strategy samples an equal amount of positive and negative training pairs. However, it undersamples the majority pair type to match that of the minority pair type. This usually involves undersampling the negative pairs to match the positive pairs. 62 | 63 | In our running example, this would involve undersampling the 128 negative pairs down to 62, resulting in one epoch of 62 + 62 = 124 pairs. In summary: 64 | 65 | * ✅ An equal amount of positive and negative pairs are sampled. 
66 | * ❌ **Not** every possible pair is used.
67 | * ✅ There is **no** data duplication.
68 |
69 | ## Unique
70 |
71 | Thirdly, the unique strategy does not sample an equal amount of positive and negative training pairs. Instead, it simply samples all possible pairs exactly once. No form of oversampling or undersampling is used here.
72 |
73 | In our running example, this would involve sampling all negative and positive pairs, resulting in one epoch of 62 + 128 = 190 pairs. In summary:
74 |
75 | * ❌ **Not** an equal amount of positive and negative pairs are sampled.
76 | * ✅ Every possible pair is used.
77 | * ✅ There is **no** data duplication.
78 |
79 | ## `num_iterations`
80 |
81 | Lastly, SetFit can still be used with a deprecated sampling strategy involving the `num_iterations` training argument. Unlike the other sampling strategies, this strategy does not involve the number of possible pairs. Instead, it samples `num_iterations` positive pairs and `num_iterations` negative pairs for each training sample.
82 |
83 | In our running example, if we assume `num_iterations=20`, then we would sample 20 positive pairs and 20 negative pairs per training sample. Because there are 20 samples, this involves (20 + 20) * 20 = 800 pairs. Because there are only 190 unique pairs, this certainly involves some data duplication. However, it does not guarantee that every possible pair is used. In summary:
84 |
85 | * ✅ An equal amount of positive and negative pairs are sampled.
86 | * ❌ Not necessarily every possible pair is used.
87 | * ❌ There is some data duplication.
--------------------------------------------------------------------------------
/docs/source/en/conceptual_guides/setfit.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Sentence Transformers Finetuning (SetFit)
3 |
4 | SetFit is a model framework to efficiently train text classification models with surprisingly little training data. For example, with only 8 labeled examples per class on the Customer Reviews (CR) sentiment dataset, SetFit is competitive with fine-tuning RoBERTa Large on the full training set of 3k examples. Furthermore, SetFit is fast to train and run inference with, and can easily support multilingual tasks.
5 |
6 | Every SetFit model consists of two parts: a **sentence transformer** embedding model (the body) and a **classifier** (the head). These two parts are trained in two separate phases: the **embedding finetuning phase** and the **classifier training phase**. This conceptual guide will elaborate on the intuition behind these phases, and why SetFit works so well.
7 |
8 | ## Embedding finetuning phase
9 |
10 | The first phase has one primary goal: finetune a sentence transformer embedding model to produce useful embeddings for *our* classification task. The [Hugging Face Hub](https://huggingface.co/models?library=sentence-transformers) already has thousands of sentence transformers available, many of which have been trained to very accurately group the embeddings of texts with similar semantic meaning.
11 |
12 | However, models that are good at Semantic Textual Similarity (STS) are not necessarily immediately good at *our* classification task. For example, according to an embedding model, sentence 1) `"He biked to work."` will be much more similar to 2) `"He drove his car to work."` than to 3) `"Peter decided to take the bicycle to the beach party!"`.
But if our classification task involves classifying texts into transportation modes, then we want our embedding model to place sentences 1 and 3 closely together, and 2 further away.
13 |
14 | To do so, we can finetune the chosen sentence transformer embedding model. The goal here is to nudge the model to use its pretrained knowledge in a different way that better aligns with our classification task, rather than making it completely forget what it has learned.
15 |
16 | For finetuning, SetFit uses **contrastive learning**. This training approach involves creating **positive and negative pairs** of sentences. A sentence pair will be positive if both of the sentences are of the same class, and negative otherwise. For example, in the case of binary "positive"-"negative" sentiment analysis, `("The movie was awesome", "I loved it")` is a positive pair, and `("The movie was awesome", "It was quite disappointing")` is a negative pair.
17 |
18 | During training, the embedding model receives these pairs, and will convert the sentences to embeddings. If the pair is positive, then the training will pull on the model weights such that the two text embeddings become more similar, and vice versa for a negative pair. Through this approach, sentences with the same label will be embedded more similarly, and sentences with different labels less similarly.
19 |
20 | Conveniently, this contrastive learning works with pairs rather than individual samples, and we can create plenty of unique pairs from just a few samples. For example, given 8 positive sentences and 8 negative sentences, we can create 28 positive pairs within each class (56 in total) and 64 negative pairs, for 120 unique training pairs. This number grows quadratically with the number of sentences, and that is why SetFit can train with just a few examples and still correctly finetune the sentence transformer embedding model. However, we should still be wary of overfitting.
21 |
22 | ## Classifier training phase
23 |
24 | Once the sentence transformer embedding model has been finetuned for our task at hand, we can start training the classifier. This phase has one primary goal: create a good mapping from the sentence transformer embeddings to the classes.
25 |
26 | Unlike with the first phase, training the classifier is done from scratch and using the labeled samples directly, rather than using pairs. By default, the classifier is a simple **logistic regression** classifier from scikit-learn. First, all training sentences are fed through the now-finetuned sentence transformer embedding model, and then the sentence embeddings and labels are used to fit the logistic regression classifier. The result is a strong and efficient classifier.
27 |
28 | Using these two parts, SetFit models are efficient, performant and easy to train, even on CPU-only devices.
29 |
--------------------------------------------------------------------------------
/docs/source/en/how_to/batch_sizes.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Batch sizes for Inference
3 | In this how-to guide we will explore the effects of increasing the batch size in [`SetFitModel.predict`].
4 |
5 | ## What are they?
6 | When processing on GPUs, oftentimes not all of the data fits in the GPU's VRAM at once. As a result, the data gets split up into **batches** of some pre-determined batch size. This is done both during training and during inference.
In both scenarios, increasing the batch size often has notable consequences for processing efficiency and VRAM usage, as transferring data to and from the GPU can be relatively slow.
7 |
8 | For inference, it is often recommended to set the batch size high to get notably quicker processing speeds.
9 |
10 | ## In SetFit
11 | The batch size for inference in SetFit is set to 32, but it can be changed by passing a `batch_size` argument to [`SetFitModel.predict`]. For example, on an RTX 3090 with a SetFit model based on the [paraphrase-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-mpnet-base-v2) Sentence Transformer, the following throughputs are reached:
12 |
13 | ![setfit_speed_per_batch_size](https://github.com/huggingface/setfit/assets/37621491/c01d391b-aeba-4a4b-83f8-b09970a0d6e6)
14 |
15 | <Tip>
16 |
17 | Each sentence consists of 11 words in this experiment.
18 |
19 | </Tip>
20 |
21 | The default batch size of 32 does not result in the highest possible throughput on this hardware. Consider experimenting with the batch size to reach your highest possible throughput.
--------------------------------------------------------------------------------
/docs/source/en/how_to/callbacks.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Callbacks
3 | The training of SetFit models can be influenced by callbacks, for example for logging or early stopping.
4 |
5 | This guide will show you what they are and how they can be used.
6 |
7 | ## Callbacks in SetFit
8 |
9 | Callbacks are objects that customize the behaviour of the training loop in the SetFit [`Trainer`]. They can inspect the training loop state (for progress reporting, logging, or inspecting embeddings during training) and take decisions (e.g. early stopping).
10 |
11 | In particular, the [`Trainer`] uses a [`TrainerControl`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerControl) that can be influenced by callbacks to stop training, save models, evaluate, or log, and a [`TrainerState`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerState) which tracks some training loop metrics during training, such as the number of training steps so far.
12 |
13 | SetFit relies on the callbacks implemented in `transformers`, as described in the `transformers` documentation [here](https://huggingface.co/docs/transformers/main_classes/callback).
14 |
15 | ## Default Callbacks
16 |
17 | SetFit uses the `TrainingArguments.report_to` argument to specify which of the built-in callbacks should be enabled. This argument defaults to `"all"`, meaning that all third-party callbacks from `transformers` that are also installed will be enabled, for example the [`TensorBoardCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.integrations.TensorBoardCallback) or the [`WandbCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.integrations.WandbCallback).
18 |
19 | Beyond that, the [`PrinterCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.PrinterCallback) or [`ProgressCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.ProgressCallback) is always enabled to show the training progress, and [`DefaultFlowCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.DefaultFlowCallback) is also always enabled to properly update the `TrainerControl`.
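
If you prefer not to log to any third-party integration at all, the same argument can be used to turn them off. The snippet below is a minimal sketch; `report_to` is forwarded to the `transformers` integration machinery, which treats `"none"` as "no integrations":

```py
from setfit import TrainingArguments

# Disable all third-party integration callbacks (W&B, TensorBoard, ...).
# The progress/printer and flow callbacks mentioned above remain active.
args = TrainingArguments(
    report_to="none",
)
```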
20 |
21 | ## Using Callbacks
22 |
23 | As mentioned, you can use `TrainingArguments.report_to` to specify exactly which callbacks you would like to enable. For example:
24 |
25 | ```py
26 | from setfit import TrainingArguments
27 |
28 | args = TrainingArguments(
29 | ...,
30 | report_to="wandb",
31 | ...,
32 | )
33 | # or
34 | args = TrainingArguments(
35 | ...,
36 | report_to=["wandb", "tensorboard"],
37 | ...,
38 | )
39 | ```
40 | You can also use [`Trainer.add_callback`], [`Trainer.pop_callback`] and [`Trainer.remove_callback`] to influence the trainer callbacks, and you can specify callbacks via the [`Trainer`] init, e.g.:
41 |
42 | ```py
43 | from setfit import Trainer
44 | from transformers import EarlyStoppingCallback
45 |
46 | ...
47 |
48 | trainer = Trainer(
49 | model,
50 | args=args,
51 | train_dataset=train_dataset,
52 | eval_dataset=eval_dataset,
53 | callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
54 | )
55 | trainer.train()
56 | ```
57 |
58 | ## Custom Callbacks
59 |
60 | SetFit supports custom callbacks in the same way that `transformers` does: by subclassing [`TrainerCallback`](https://huggingface.co/docs/transformers/main_classes/callback#transformers.TrainerCallback). This class implements a lot of `on_...` methods that can be overridden. For example, the following script shows a custom callback that saves plots of the tSNE of the training and evaluation embeddings during training.
61 |
62 | ```py
63 | import os
64 | import matplotlib.pyplot as plt
65 | from sklearn.manifold import TSNE
66 | from transformers import TrainerCallback, TrainerControl, TrainerState
67 |
68 | from setfit import SetFitModel, TrainingArguments
69 |
70 | class EmbeddingPlotCallback(TrainerCallback):
71 |     """Simple embedding plotting callback that plots the tSNE of the training and evaluation datasets throughout training."""
72 |     def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
73 |         os.makedirs("logs", exist_ok=True)
74 |
75 |     def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: SetFitModel, **kwargs):
76 |         # `train_dataset` and `eval_dataset` are assumed to be defined elsewhere, e.g. the datasets also passed to the Trainer.
77 |         train_embeddings = model.encode(train_dataset["text"])
78 |         eval_embeddings = model.encode(eval_dataset["text"])
79 |
80 |         fig, (train_ax, eval_ax) = plt.subplots(ncols=2)
81 |
82 |         train_X = TSNE(n_components=2).fit_transform(train_embeddings)
83 |         train_ax.scatter(*train_X.T, c=train_dataset["label"], label=train_dataset["label"])
84 |         train_ax.set_title("Training embeddings")
85 |
86 |         eval_X = TSNE(n_components=2).fit_transform(eval_embeddings)
87 |         eval_ax.scatter(*eval_X.T, c=eval_dataset["label"], label=eval_dataset["label"])
88 |         eval_ax.set_title("Evaluation embeddings")
89 |
90 |         fig.suptitle(f"tSNE of training and evaluation embeddings at step {state.global_step} of {state.max_steps}.")
91 |         fig.savefig(f"logs/step_{state.global_step}.png")
92 | ```
93 |
94 | with
95 |
96 | ```py
97 | trainer = Trainer(
98 |     model=model,
99 |     args=args,
100 |     train_dataset=train_dataset,
101 |     eval_dataset=eval_dataset,
102 |     callbacks=[EmbeddingPlotCallback()]
103 | )
104 | trainer.train()
105 | ```
106 |
107 | The `on_evaluate` from `EmbeddingPlotCallback` will be triggered on every single evaluation call.
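 Note that `on_evaluate` only fires if evaluation actually happens during training, so it has to be enabled in the training arguments. A minimal sketch that evaluates every 20 steps (depending on your `setfit` version, the strategy argument may be called `evaluation_strategy` instead of `eval_strategy`):

```py
from setfit import TrainingArguments

args = TrainingArguments(
    batch_size=32,
    num_epochs=1,
    # Run evaluation every 20 steps so that `on_evaluate` is called repeatedly during training.
    eval_strategy="steps",
    eval_steps=20,
)
```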
In the case of this example, it resulted in the following figures being plotted:
108 |
109 | | Step 20 | Step 40 |
110 | |-------------|-------------|
111 | | ![step_20](https://github.com/huggingface/setfit/assets/37621491/7200d00a-fd48-4038-bcbe-f2d5f1280162) | ![step_40](https://github.com/huggingface/setfit/assets/37621491/be12e3c4-867c-452d-89a0-0677f035516d) |
112 | | **Step 60** | **Step 80** |
113 | | ![step_60](https://github.com/huggingface/setfit/assets/37621491/3a384aa2-51ce-40d7-b02c-a2c986f3aeb4) | ![step_80](https://github.com/huggingface/setfit/assets/37621491/b5aa9835-40cb-4327-9f31-b3ababeca769) |
--------------------------------------------------------------------------------
/docs/source/en/how_to/multilabel.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Multilabel Text Classification
3 |
4 | SetFit supports multilabel classification, allowing multiple labels to be assigned to each instance.
5 |
6 | <Tip>
7 |
8 | Unless each instance can be assigned multiple labels, you generally do not need to specify a `multi_target_strategy`.
9 |
10 | </Tip>
11 |
12 | This guide will show you how to train and use multilabel SetFit models.
13 |
14 | ## Multilabel strategies
15 |
16 | SetFit will initialise a multilabel classification head from `sklearn`; the following options are available for `multi_target_strategy`:
17 |
18 | * `"one-vs-rest"`: uses a [`OneVsRestClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) head.
19 | * `"multi-output"`: uses a [`MultiOutputClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html) head.
20 | * `"classifier-chain"`: uses a [`ClassifierChain`](https://scikit-learn.org/stable/modules/generated/sklearn.multioutput.ClassifierChain.html) head.
21 |
22 | See the [scikit-learn documentation for multiclass and multioutput classification](https://scikit-learn.org/stable/modules/multiclass.html#multiclass-classification) for more details.
23 |
24 | ## Initializing SetFit models with multilabel strategies
25 |
26 | Using the default `LogisticRegression` head, we can apply multi target strategies like so:
27 |
28 | ```py
29 | from setfit import SetFitModel
30 |
31 | model = SetFitModel.from_pretrained(
32 | model_id, # e.g. "BAAI/bge-small-en-v1.5"
33 | multi_target_strategy="multi-output",
34 | )
35 | ```
36 |
37 | With a differentiable head it looks like so:
38 |
39 | ```py
40 | from setfit import SetFitModel
41 |
42 | model = SetFitModel.from_pretrained(
43 | model_id, # e.g. "BAAI/bge-small-en-v1.5"
44 | multi_target_strategy="one-vs-rest",
45 | use_differentiable_head=True,
46 | head_params={"out_features": num_classes},
47 | )
48 | ```
--------------------------------------------------------------------------------
/docs/source/en/how_to/overview.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Overview
3 |
4 | Welcome to the SetFit How-to Guides! The how-to guides offer a more comprehensive overview of all the tools 🤗 SetFit offers and how to use them.
5 | These guides are designed to be concise and code-heavy, written in "show, don't tell" style. For example, using these guides you may learn how to perform hyperparameter optimization, knowledge distillation, apply callbacks, etc.
6 |
7 | Most how-to guides end with an "end to end" script showing all code from the guide for easy adaptation into your own code.
8 |
9 | For simpler documentation explaining SetFit functionality from start to finish, consider visiting the [Tutorials](../tutorials/overview) section or the [quickstart](../quickstart).
10 |
--------------------------------------------------------------------------------
/docs/source/en/how_to/zero_shot.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Zero-shot Text Classification
3 |
4 | [[open-in-colab]]
5 |
6 | Your class names are likely already good descriptors of the text that you're looking to classify. With 🤗 SetFit, you can use these class names together with strong pretrained Sentence Transformer models to get a solid baseline model without any training samples.
7 |
8 | This guide will show you how to perform zero-shot text classification.
9 |
10 | ## Testing dataset
11 |
12 | We'll use the [dair-ai/emotion](https://huggingface.co/datasets/dair-ai/emotion) dataset to test the performance of our zero-shot model.
13 |
14 | ```py
15 | from datasets import load_dataset
16 |
17 | test_dataset = load_dataset("dair-ai/emotion", "split", split="test")
18 | ```
19 |
20 | This dataset stores the class names within the dataset `Features`, so we'll extract the classes like so:
21 | ```py
22 | classes = test_dataset.features["label"].names
23 | # => ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
24 | ```
25 | Otherwise, we could manually set the list of classes.
26 |
27 | ## Synthetic dataset
28 |
29 | Then, we can use [`get_templated_dataset`] to synthetically generate a dummy dataset given these class names.
30 |
31 | ```py
32 | from setfit import get_templated_dataset
33 |
34 | train_dataset = get_templated_dataset(candidate_labels=classes, sample_size=8)
35 | ```
36 | ```py
37 | print(train_dataset)
38 | # => Dataset({
39 | #     features: ['text', 'label'],
40 | #     num_rows: 48
41 | # })
42 | print(train_dataset[0])
43 | # {'text': 'This sentence is sadness', 'label': 0}
44 | ```
45 |
46 | ## Training
47 |
48 | We can use this dataset to train a SetFit model just like normal:
49 |
50 | ```py
51 | from setfit import SetFitModel, Trainer, TrainingArguments
52 |
53 | model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5")
54 |
55 | args = TrainingArguments(
56 | batch_size=32,
57 | num_epochs=1,
58 | )
59 |
60 | trainer = Trainer(
61 | model=model,
62 | args=args,
63 | train_dataset=train_dataset,
64 | eval_dataset=test_dataset,
65 | )
66 | trainer.train()
67 | ```
68 | ```
69 | ***** Running training *****
70 | Num examples = 60
71 | Num epochs = 1
72 | Total optimization steps = 60
73 | Total train batch size = 32
74 | {'embedding_loss': 0.2628, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.02}
75 | {'embedding_loss': 0.0222, 'learning_rate': 3.7037037037037037e-06, 'epoch': 0.83}
76 | {'train_runtime': 15.4717, 'train_samples_per_second': 124.098, 'train_steps_per_second': 3.878, 'epoch': 1.0}
77 | 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00, 6.35it/s]
78 | ```
79 |
80 | Once trained, we can evaluate the model:
81 |
82 | ```py
83 | metrics = trainer.evaluate()
84 | print(metrics)
85 | ```
86 | ```
87 | ***** Running evaluation *****
88 | {'accuracy': 0.591}
89 | ```
90 |
91 | And run predictions:
92 |
93 | ```py
94 | preds = model.predict([
95 | "i am just feeling cranky and blue",
96 | "i feel incredibly lucky just to be able to talk to her",
97 | "you're pissing me off right now",
98 | "i definitely have thalassophobia, don't get me near water like that",
99 |
"i did not see that coming at all", 100 | ]) 101 | print([classes[idx] for idx in preds]) 102 | ``` 103 | ```py 104 | ['sadness', 'joy', 'anger', 'fear', 'surprise'] 105 | ``` 106 | 107 | These predictions all look right! 108 | 109 | ## Baseline 110 | 111 | To show that the zero-shot performance of SetFit works well, we'll compare it against a zero-shot classification model from `transformers`. 112 | 113 | ```py 114 | from transformers import pipeline 115 | from datasets import load_dataset 116 | import evaluate 117 | 118 | # Prepare the testing dataset 119 | test_dataset = load_dataset("dair-ai/emotion", "split", split="test") 120 | classes = test_dataset.features["label"].names 121 | 122 | # Set up the zero-shot classification pipeline from transformers 123 | # Uses 'facebook/bart-large-mnli' by default 124 | pipe = pipeline("zero-shot-classification", device=0) 125 | zeroshot_preds = pipe(test_dataset["text"], batch_size=16, candidate_labels=classes) 126 | preds = [classes.index(pred["labels"][0]) for pred in zeroshot_preds] 127 | 128 | # Compute the accuracy 129 | metric = evaluate.load("accuracy") 130 | transformers_accuracy = metric.compute(predictions=preds, references=test_dataset["label"]) 131 | print(transformers_accuracy) 132 | ``` 133 | ```py 134 | {'accuracy': 0.3765} 135 | ``` 136 | 137 | With its 59.1% accuracy, the 0-shot SetFit heavily outperforms the recommended zero-shot model by `transformers`. 138 | 139 | ## Prediction latency 140 | 141 | Beyond getting higher accuracies, SetFit is much faster too. Let's compute the latency of SetFit with `BAAI/bge-small-en-v1.5` versus the latency of `transformers` with `facebook/bart-large-mnli`. Both tests were performed on a GPU. 142 | 143 | ```py 144 | import time 145 | 146 | start_t = time.time() 147 | pipe(test_dataset["text"], batch_size=32, candidate_labels=classes) 148 | delta_t = time.time() - start_t 149 | print(f"`transformers` with `facebook/bart-large-mnli` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence") 150 | ``` 151 | ``` 152 | `transformers` with `facebook/bart-large-mnli` latency: 31.1765ms per sentence 153 | ``` 154 | 155 | ```py 156 | import time 157 | 158 | start_t = time.time() 159 | model.predict(test_dataset["text"]) 160 | delta_t = time.time() - start_t 161 | print(f"SetFit with `BAAI/bge-small-en-v1.5` latency: {delta_t / len(test_dataset['text']) * 1000:.4f}ms per sentence") 162 | ``` 163 | ``` 164 | SetFit with `BAAI/bge-small-en-v1.5` latency: 0.4600ms per sentence 165 | ``` 166 | 167 | So, SetFit with `BAAI/bge-small-en-v1.5` is 67x faster than `transformers` with `facebook/bart-large-mnli`, alongside being more accurate: 168 | 169 | ![zero_shot_transformers_vs_setfit](https://github.com/huggingface/setfit/assets/37621491/33f574d9-c51b-4e02-8d98-6e04e18427ef) 170 | -------------------------------------------------------------------------------- /docs/source/en/index.mdx: -------------------------------------------------------------------------------- 1 | # SetFit 2 | 3 | 8 | 9 | 🤗 SetFit is an efficient and prompt-free framework for few-shot fine-tuning of [Sentence Transformers](https://sbert.net/). It achieves high accuracy with little labeled data - for instance, with only 8 labeled examples per class on the Customer Reviews sentiment dataset, 🤗 SetFit is competitive with fine-tuning RoBERTa Large on the full training set of 3k examples! 
10 | 11 | Compared to other few-shot learning methods, SetFit has several unique features: 12 | 13 | * 🗣 **No prompts or verbalizers:** Current techniques for few-shot fine-tuning require handcrafted prompts or verbalizers to convert examples into a format suitable for the underlying language model. SetFit dispenses with prompts altogether by generating rich embeddings directly from text examples. 14 | * 🏎 **Fast to train:** SetFit doesn't require large-scale models like T0, Llama or GPT-4 to achieve high accuracy. As a result, it is typically an order of magnitude (or more) faster to train and run inference with. 15 | * 🌎 **Multilingual support**: SetFit can be used with any [Sentence Transformer](https://huggingface.co/models?library=sentence-transformers&sort=downloads) on the Hub, which means you can classify text in multiple languages by simply fine-tuning a multilingual checkpoint. 16 | 17 |
- **Tutorials**: Learn the basics and become familiar with loading pretrained Sentence Transformers and fine-tuning them on data. Start here if you are using 🤗 SetFit for the first time!
- **How-to guides**: Practical guides to help you achieve a specific goal. Take a look at these guides to learn how to use 🤗 SetFit to solve real-world problems.
- **Conceptual guides**: High-level explanations for building a better understanding about important topics such as few-shot and contrastive learning.
- **Reference**: Technical descriptions of how 🤗 SetFit classes and methods work.
37 |
--------------------------------------------------------------------------------
/docs/source/en/installation.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Installation
3 |
4 | Before you start, you'll need to set up your environment and install the appropriate packages. 🤗 SetFit is tested on **Python 3.9+**.
5 |
6 | ## pip
7 |
8 | The most straightforward way to install 🤗 SetFit is with pip:
9 |
10 | ```bash
11 | pip install setfit
12 | ```
13 |
14 | If you have a CUDA-capable graphics card, then it is recommended to [install `torch` with CUDA support](https://pytorch.org/get-started/locally/) to train and perform inference much more quickly:
15 |
16 | ```bash
17 | pip install torch --index-url https://download.pytorch.org/whl/cu118
18 | ```
19 |
20 | ## Installing from source
21 |
22 | Building 🤗 SetFit from source lets you make changes to the code base. To install from source, clone the repository and install 🤗 SetFit in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) with the following commands:
23 |
24 | ```bash
25 | git clone https://github.com/huggingface/setfit.git
26 | cd setfit
27 | pip install -e .
28 | ```
29 |
30 | If you just want the bleeding-edge version without making any changes of your own, then install from source by running:
31 |
32 | ```bash
33 | pip install git+https://github.com/huggingface/setfit.git
34 | ```
35 |
36 | ## Conda
37 |
38 | If conda is your package management system of choice, then you can install 🤗 SetFit like so:
39 |
40 | ```bash
41 | conda install -c conda-forge setfit
42 | ```
--------------------------------------------------------------------------------
/docs/source/en/reference/main.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Main Classes
3 |
4 | ## SetFitModel
5 |
6 | [[autodoc]] SetFitModel
7 | - all
8 | - from_pretrained
9 | - save_pretrained
10 | - push_to_hub
11 | - __call__
12 | - label2id
13 | - id2label
14 |
15 | ## SetFitHead
16 |
17 | [[autodoc]] SetFitHead
18 |
19 | ## SetFitModelCardData
20 |
21 | [[autodoc]] SetFitModelCardData
22 | - to_dict
23 | - to_yaml
24 |
25 | ## AbsaModel
26 |
27 | [[autodoc]] AbsaModel
28 | - __call__
29 | - device
30 | - from_pretrained
31 | - predict
32 | - push_to_hub
33 | - to
34 | - save_pretrained
35 |
36 | ### AspectModel
37 |
38 | [[autodoc]] AspectModel
39 | - __call__
40 | - device
41 | - from_pretrained
42 | - predict
43 | - push_to_hub
44 | - save_pretrained
45 | - to
46 |
47 | ### PolarityModel
48 |
49 | [[autodoc]] PolarityModel
50 | - __call__
51 | - device
52 | - from_pretrained
53 | - predict
54 | - push_to_hub
55 | - save_pretrained
56 | - to
57 |
--------------------------------------------------------------------------------
/docs/source/en/reference/trainer.mdx:
--------------------------------------------------------------------------------
1 |
2 | # Trainer Classes
3 |
4 | ## TrainingArguments
5 |
6 | [[autodoc]] TrainingArguments
7 | - to_dict
8 | - from_dict
9 | - copy
10 | - update
11 |
12 | ## Trainer
13 |
14 | [[autodoc]] Trainer
15 | - add_callback
16 | - apply_hyperparameters
17 | - evaluate
18 | - hyperparameter_search
19 | - pop_callback
20 | - push_to_hub
21 | - remove_callback
22 | - train
23 | - train_classifier
24 | - train_embeddings
25 |
26 | ## DistillationTrainer
27 |
28 | [[autodoc]] DistillationTrainer
29 | - add_callback
30 | - apply_hyperparameters
31 | - evaluate
32 | - hyperparameter_search
33 | - pop_callback
34 | - push_to_hub 35 | - remove_callback 36 | - train 37 | - train_classifier 38 | - train_embeddings 39 | 40 | ## AbsaTrainer 41 | 42 | [[autodoc]] AbsaTrainer 43 | - add_callback 44 | - evaluate 45 | - pop_callback 46 | - push_to_hub 47 | - remove_callback 48 | - train 49 | - train_aspect 50 | - train_polarity 51 | -------------------------------------------------------------------------------- /docs/source/en/reference/utility.mdx: -------------------------------------------------------------------------------- 1 | 2 | # Utility Functions 3 | 4 | [[autodoc]] get_templated_dataset 5 | 6 | [[autodoc]] sample_dataset -------------------------------------------------------------------------------- /docs/source/en/tutorials/overview.mdx: -------------------------------------------------------------------------------- 1 | 2 | # Overview 3 | 4 | Welcome to the SetFit tutorials! These tutorials are designed to walk you through particular applications. For example, we'll delve into topics such as zero-shot text classification, where you'll learn how to use SetFit without any predefined labels or examples during training. See also the [SetFit Notebooks](https://github.com/huggingface/setfit/tree/main/notebooks) for more applications, such as hyperparameter searching and ONNX, though some might be outdated. 5 | 6 | For more concise guides on how to configure SetFit or use it for specific forms of text classification, see the [How-to Guides](../how_to/overview) section. 7 | 8 | If you have any questions about SetFit, feel free to open an [issue](https://github.com/huggingface/setfit/issues). 9 | -------------------------------------------------------------------------------- /final_results/adapet/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/.gitkeep -------------------------------------------------------------------------------- /final_results/adapet/albert-xxlarge-v2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/albert-xxlarge-v2.tar.gz -------------------------------------------------------------------------------- /final_results/adapet/xlm-roberta-base_all__lang_prompt.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/xlm-roberta-base_all__lang_prompt.tar.gz -------------------------------------------------------------------------------- /final_results/adapet/xlm-roberta-base_each__lang_prompt.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/xlm-roberta-base_each__lang_prompt.tar.gz -------------------------------------------------------------------------------- /final_results/adapet/xlm-roberta-base_en__eng_prompt.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/adapet/xlm-roberta-base_en__eng_prompt.tar.gz -------------------------------------------------------------------------------- 
/final_results/distillation/distillation_results.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/distillation/distillation_results.tar.gz -------------------------------------------------------------------------------- /final_results/perfect.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/perfect.tar.gz -------------------------------------------------------------------------------- /final_results/setfit/all-MiniLM-L6-v2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/all-MiniLM-L6-v2.tar.gz -------------------------------------------------------------------------------- /final_results/setfit/all-roberta-large-v1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/all-roberta-large-v1.tar.gz -------------------------------------------------------------------------------- /final_results/setfit/paraphrase-mpnet-base-v2-epochs_20-with-augmentation.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-epochs_20-with-augmentation.tar.gz -------------------------------------------------------------------------------- /final_results/setfit/paraphrase-mpnet-base-v2-epochs_20.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-epochs_20.tar.gz -------------------------------------------------------------------------------- /final_results/setfit/paraphrase-mpnet-base-v2-epochs_5.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-epochs_5.tar.gz -------------------------------------------------------------------------------- /final_results/setfit/paraphrase-mpnet-base-v2-linear-probe.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-mpnet-base-v2-linear-probe.tar.gz -------------------------------------------------------------------------------- /final_results/setfit/paraphrase-multilingual-mpnet-base-v2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/setfit/paraphrase-multilingual-mpnet-base-v2.tar.gz -------------------------------------------------------------------------------- /final_results/tfew/t011b_pretrained-v2.tar.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/tfew/t011b_pretrained-v2.tar.gz -------------------------------------------------------------------------------- /final_results/tfew/t03b_pretrained-v2.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/tfew/t03b_pretrained-v2.tar.gz -------------------------------------------------------------------------------- /final_results/transformers/distilbert-base-uncased.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/transformers/distilbert-base-uncased.tar.gz -------------------------------------------------------------------------------- /final_results/transformers/roberta-large.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/transformers/roberta-large.tar.gz -------------------------------------------------------------------------------- /final_results/transformers/xlm-roberta-base.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/final_results/transformers/xlm-roberta-base.tar.gz -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/notebooks/.gitkeep -------------------------------------------------------------------------------- /scripts/adapet/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/scripts/adapet/.gitkeep -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/.gitignore: -------------------------------------------------------------------------------- 1 | exp_out 2 | exp_out/ 3 | cur_exp_dir 4 | env/ 5 | results/ 6 | wandb/ 7 | lib/ 8 | output/ 9 | adapet_models 10 | adapet_models/ 11 | data 12 | data/ 13 | pretrained_models 14 | pretrained_models/ 15 | .idea/ 16 | eche_ 17 | runs/ 18 | *.pyc 19 | .installed.cfg 20 | develop-eggs 21 | dist 22 | downloads 23 | eggs 24 | parts 25 | src/*.egg-info 26 | lib 27 | lib64 28 | !src/data 29 | env.yaml 30 | pet.yml 31 | results/ 32 | results 33 | venv 34 | venv/ 35 | output_slurm 36 | output_slurm/ 37 | slurm_adapetsetfit.sh 38 | seed_ouput 39 | seed_ouput/ -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/README.md: -------------------------------------------------------------------------------- 1 | # ADAPET for SetFit # 2 | This is a fork of the original ADAPET, which can be found [here](https://github.com/rrmenon10/ADAPET). 3 | 4 | Our results were created in Python 3.6.8 with a 40 GB NVIDIA A100 Tensor Core GPU 5 | 6 | To setup, please clone the repo and follow the instructions below. 
7 | ```` 8 | python3.6 -m venv venv 9 | source venv/bin/activate 10 | pip install -r requirements.txt 11 | pip install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 12 | ```` 13 | 14 | 15 | 16 | To run ADAPET on SetFit datasets, specify the dataset and PLM as arguments to python. Additionally, if you wish to 17 | run on a multilingual dataset, then you can choose to prompt/verbalize in English or in the language in question. 18 | You can also remove the prompt. 19 | 20 | For example, if you wish to run ADAPET on `sst2` with a prompt and with `albert-xxlarge-v2` as the PLM, simply run the following: 21 | ``` 22 | python setfit_adapet.py --pretrained_weight="albert-xxlarge-v2"\ 23 | --english=True\ 24 | --prompt=True\ 25 | --task_name='SetFit/sst2'\ 26 | ``` 27 | 28 | If you wish to run ADAPET on amazon_reviews_multi_ja, prompting in Japanese, with mdeberta-v3-base as the PLM: 29 | ``` 30 | python setfit_adapet.py --pretrained_weight="microsoft/mdeberta-v3-base"\ 31 | --english=False\ 32 | --prompt=True\ 33 | --task_name='SetFit/amazon_reviews_multi_ja'\ 34 | ``` 35 | 36 | This will run ADAPET and evaluate it on the test set for the 8 and 64 splits. 37 | 38 | In the multilingual case, ADAPET runs in the "each" scenario as described in the paper by default. You can change this by adding a multilingual argument, such as 39 | 40 | ``` 41 | python setfit_adapet.py --pretrained_weight="microsoft/mdeberta-v3-base"\ 42 | --english=False\ 43 | --prompt=True\ 44 | --task_name='SetFit/amazon_reviews_multi_ja'\ 45 | --multilingual='all'\ 46 | ``` 47 | 48 | 49 | Note that with default hyperparameters this will take a very long time. If you change to distilbert-base-uncased or another smaller model, ADAPET will be much faster. 50 | 51 | Once ADAPET is done, the results will be written as follows: 52 | ``` 53 | seed_output / model_name / dataset_name / train-{num_samples}-{split_dx} / results.json 54 | ``` 55 | 56 | For non-English datasets, the PLM directory may have different endings. The ending describes the prompting situation: 57 | ```` 58 | {plm}__eng_prompt == prompt and verbalize in English 59 | {plm}__lang_prompt == prompt and verbalize in the language in question 60 | {plm}__lang_no-prompt == take the prompt away and verbalize in the language in question 61 | {plm}__eng_no-prompt == take the prompt away and verbalize in English 62 | ```` 63 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/bin/dev.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -exu 4 | 5 | exp_dir=$1 6 | 7 | python -m src.dev -e $exp_dir 8 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/bin/init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "data/superglue/" ] ; then 4 | mkdir -p data 5 | cd data 6 | mkdir -p superglue 7 | cd superglue 8 | wget "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/combined.zip" 9 | unzip combined.zip 10 | cd ../.. 11 | fi 12 | 13 | if [ ! -d "data/fewglue/" ] ; then 14 | mkdir -p data 15 | cd data 16 | git clone https://github.com/timoschick/fewglue.git 17 | cd fewglue 18 | rm -rf .git 19 | rm README.md 20 | mv FewGLUE/* . 21 | rm -r FewGLUE 22 | cd ../.. 23 | fi 24 | 25 | if [ ! 
-d "env" ] ; then 26 | python -m venv env 27 | source env/bin/activate 28 | pip install --upgrade pip 29 | pip install -r requirements.txt 30 | fi -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/bin/setup.sh: -------------------------------------------------------------------------------- 1 | source env/bin/activate 2 | export ADAPET_ROOT=`pwd` 3 | export PYTHONPATH=$ADAPET_ROOT:$PYTHONPATH 4 | export PYTHON_EXEC=python 5 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/bin/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -exu 4 | 5 | exp_dir=$1 6 | 7 | python -m src.test -e $exp_dir 8 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/bin/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -exu 4 | 5 | config_file=$1 6 | 7 | python -m src.train -c $config_file 8 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/cli.py: -------------------------------------------------------------------------------- 1 | #import argparse 2 | import os 3 | import json 4 | 5 | from src.utils.Config import Config 6 | from src.train import train 7 | 8 | 9 | 10 | def call_adapet(updated_args): 11 | ''' 12 | parser = argparse.ArgumentParser() 13 | 14 | # Arguments for running any datasets 15 | parser.add_argument('-d', "--data_dir", default=configs['generic_data_dir'], 16 | help="Data directory containing train/val/test jsonl files") 17 | parser.add_argument('-p', "--pattern", default=configs['pattern'], 18 | help="Pattern to be used for this dataset") 19 | parser.add_argument('-v', "--dict_verbalizer", type=json.loads, default=configs['dict_verbalizer'], 20 | help="Dictionary mapping label name (in dataset) to the verbalizer to use, e.g. 
'{\"0\": \"Yes\", \"1\": \"No\"}'") 21 | 22 | # Model and training hyperparams 23 | parser.add_argument('-w', '--pretrained_weight', type=str, default=configs['pretrained_weight'], 24 | help='Pretrained model weights from huggingface') 25 | parser.add_argument('-bs', '--batch_size', type=int, default=1, help='batch size during training') 26 | parser.add_argument('--eval_batch_size', type=int, default=configs['batch_size'], 27 | help='batch size during evaluation') 28 | parser.add_argument('--grad_accumulation_factor', type=int, default=16, help='number of gradient accumulation steps') 29 | parser.add_argument('--num_batches', type=int, default=configs['num_batches'], 30 | help='number of batches for experiment; 1 batch = grad_accumulation_factor x batch_size') 31 | 32 | parser.add_argument('--eval_every', type=int, default=configs['eval_every'], 33 | help='number of training batches per evaluation') 34 | parser.add_argument('--max_text_length', type=int, default=256, help='maximum total input sequence length after tokenization for ADAPET') 35 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate for the model') 36 | parser.add_argument('--weight_decay', type=float, default=1e-2, help='weight decay for the optmizer') 37 | parser.add_argument('--grad_clip_norm', type=float, default=1, help='gradient clipping norm') 38 | parser.add_argument('--warmup_ratio', type=float, default=0.06, help='linear warmup over warmup_steps for num_batches') 39 | 40 | # ADAPET hyperparameters 41 | parser.add_argument('--pattern_idx', default=1, help="Pattern index among all patterns available; For SuperGLUE, can use numbers >1 depending on dataset. For a new dataset, please set this to 1.") 42 | parser.add_argument('--mask_alpha', type=float, default=0.105, help='masking ratio for the label conditioning loss') 43 | parser.add_argument('--idx_txt_trim', type=int, default=1, help="TXT_ID of the text that can be trimmed (usually the longer text). Eg. if TXT1 needs to be trimmed, set this to 1.") 44 | parser.add_argument('--max_num_lbl_tok', type=int, default=configs['max_num_lbl_tok'], help="The maximum number of tokens per label for the verbalizer. It will raise an error if the tokenizer tokenizes a label into more than 'max_num_lbl_tok' tokens.") 45 | 46 | # Replicating SuperGLUE results 47 | parser.add_argument('-c', '--config', type=str, default=None, help='Use this for replicating SuperGLUE results.') 48 | ''' 49 | 50 | 51 | #args = final_parser.parse_args() 52 | args = updated_args 53 | 54 | # If even one of these three arguments are provided, we need all three as input 55 | if args.data_dir or args.pattern or args.dict_verbalizer: 56 | assert args.data_dir and args.pattern and args.dict_verbalizer, 'Please enter all of data_dir, pattern, dict_verbalizer!' 57 | 58 | if args.config is not None: 59 | use_config = args.config 60 | else: 61 | assert args.data_dir or args.pattern or args.dict_verbalizer, 'Please enter all of data_dir, pattern, dict_verbalizer if not providing config!' 
62 | use_config = os.path.join("config", "Generic.json") 63 | 64 | update_config = vars(args) 65 | config = Config(use_config, update_config, mkdir=True) 66 | train(config) 67 | print(config.exp_dir) 68 | return config.exp_dir 69 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/BoolQ.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/BoolQ", 4 | "max_text_length": 256, 5 | "batch_size": 1, 6 | "eval_batch_size": 1, 7 | "num_batches": 1000, 8 | "max_num_lbl_tok": 1, 9 | "eval_every": 250, 10 | "warmup_ratio": 0.06, 11 | "mask_alpha": 0.105, 12 | "grad_accumulation_factor": 16, 13 | "seed": 42, 14 | "lr": 1e-5, 15 | "weight_decay": 1e-2, 16 | "pattern_idx": 1, 17 | "eval_train": true 18 | } 19 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/CB.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/CB", 4 | "max_text_length": 256, 5 | "batch_size": 1, 6 | "eval_batch_size": 1, 7 | "num_batches": 1000, 8 | "max_num_lbl_tok": 1, 9 | "mask_alpha": 0.105, 10 | "eval_every": 250, 11 | "warmup_ratio": 0.06, 12 | "grad_accumulation_factor": 16, 13 | "seed": 42, 14 | "lr": 1e-5, 15 | "dropout_rate": 0.1, 16 | "weight_decay": 1e-2, 17 | "pattern_idx": 1, 18 | "eval_train": true 19 | } 20 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/COPA.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/COPA", 4 | "max_text_length": 256, 5 | "batch_size": 1, 6 | "eval_batch_size": 1, 7 | "num_batches": 1000, 8 | "eval_every": 250, 9 | "mask_alpha": 0.105, 10 | "warmup_ratio": 0.06, 11 | "max_num_lbl_tok": 20, 12 | "grad_accumulation_factor": 16, 13 | "seed": 42, 14 | "lr": 1e-5, 15 | "dropout_rate": 0.1, 16 | "weight_decay": 1e-2, 17 | "pattern_idx": 1, 18 | "eval_train": true 19 | } 20 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/Generic.json: -------------------------------------------------------------------------------- 1 | {"pretrained_weight": "microsoft/mdeberta-v3-base", "dataset": "generic", "generic_data_dir": "data/SetFit/amazon_reviews_multi_ja", "pattern": "[TEXT1]これは[LBL]だ", "pattern_idx": 1, "dict_verbalizer": "{\"0\": \"一つ星\", \"1\": \"二つ星\", \"2\": \"三つ星\", \"3\": \"四つ星\", \"4\": \"五つ星\"}", "idx_txt_trim": 1, "max_text_length": 256, "batch_size": 1, "eval_batch_size": 1, "num_batches": 10, "max_num_lbl_tok": 3, "eval_every": 10, "eval_train": true, "warmup_ratio": 0.06, "mask_alpha": 0.105, "grad_accumulation_factor": 16, "seed": 0, "lr": 1e-05, "weight_decay": 0.01} -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/MultiRC.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/MultiRC", 4 | "max_text_length": 512, 5 | "batch_size": 1, 6 | "eval_batch_size": 1, 7 | "num_batches": 1000, 8 | "max_num_lbl_tok": 1, 9 | "mask_alpha": 0.105, 10 | "eval_every": 250, 11 | "warmup_ratio": 0.06, 12 | "grad_accumulation_factor": 16, 13 | "seed": 42, 14 | "lr": 
1e-5, 15 | "dropout_rate": 0.1, 16 | "weight_decay": 1e-2, 17 | "pattern_idx": 1, 18 | "eval_train": true 19 | } 20 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/RTE.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/RTE", 4 | "max_text_length": 256, 5 | "batch_size": 1, 6 | "eval_batch_size": 1, 7 | "num_batches": 1000, 8 | "max_num_lbl_tok": 1, 9 | "mask_alpha": 0.105, 10 | "eval_every": 250, 11 | "warmup_ratio": 0.06, 12 | "grad_accumulation_factor": 16, 13 | "seed": 42, 14 | "lr": 1e-5, 15 | "dropout_rate": 0.1, 16 | "weight_decay": 1e-2, 17 | "pattern_idx": 1, 18 | "eval_train": true 19 | } 20 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/ReCoRD.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/ReCoRD", 4 | "max_text_length": 512, 5 | "batch_size": 1, 6 | "eval_batch_size": 1, 7 | "num_batches": 1000, 8 | "eval_every": 250, 9 | "mask_alpha": 0.105, 10 | "warmup_ratio": 0.06, 11 | "max_num_lbl_tok": 20, 12 | "grad_accumulation_factor": 16, 13 | "seed": 42, 14 | "lr": 1e-5, 15 | "dropout_rate": 0.1, 16 | "weight_decay": 1e-2, 17 | "pattern_idx": 1, 18 | "eval_train": true 19 | } 20 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/WSC.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/WSC", 4 | "max_text_length": 256, 5 | "batch_size": 1, 6 | "eval_batch_size": 1, 7 | "num_batches": 1000, 8 | "eval_every": 250, 9 | "mask_alpha": 0.105, 10 | "warmup_ratio": 0.06, 11 | "max_num_lbl_tok": 20, 12 | "grad_accumulation_factor": 16, 13 | "seed": 42, 14 | "lr": 1e-5, 15 | "weight_decay": 1e-2, 16 | "pattern_idx": 1, 17 | "eval_train": true 18 | } 19 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/WiC.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "albert-xxlarge-v2", 3 | "dataset": "fewglue/WiC", 4 | "max_text_length": 256, 5 | "batch_size": 1, 6 | "eval_batch_size": 8, 7 | "num_batches": 1000, 8 | "eval_every": 250, 9 | "warmup_ratio": 0.06, 10 | "grad_accumulation_factor": 16, 11 | "max_num_lbl_tok": 1, 12 | "seed": 42, 13 | "lr": 1e-5, 14 | "dropout_rate": 0.1, 15 | "weight_decay": 1e-2, 16 | "pattern_idx": 1, 17 | "eval_train": true 18 | } 19 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/config/sst-2.json: -------------------------------------------------------------------------------- 1 | { 2 | "pretrained_weight": "distilbert-base-uncased", 3 | "dataset": "generic", 4 | "generic_data_dir": "data/SST-2", 5 | "pattern": "[TEXT1] it is [LBL]", 6 | "pattern_idx": 1, 7 | "dict_verbalizer": {"0": "No", "1": "Yes"}, 8 | "idx_txt_trim": 1, 9 | "max_text_length": 64, 10 | "batch_size": 16, 11 | "eval_batch_size": 8, 12 | "num_batches": 1000, 13 | "max_num_lbl_tok": 1, 14 | "eval_every": 50, 15 | "warmup_ratio": 0.06, 16 | "mask_alpha": 0.105, 17 | "grad_accumulation_factor": 1, 18 | "seed": 42, 19 | "lr": 1e-5, 20 | "weight_decay": 1e-2 21 | } 22 | 
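The `Generic.json` and `sst-2.json` configs above drive ADAPET through a pattern with `[TEXT1]` and `[LBL]` slots plus a `dict_verbalizer` that maps label ids to words. As a rough illustration only (the real work happens in `GenericReader` and `tokenize_pet_txt`, which also handle trimming and tokenization), the snippet below shows what such a pattern/verbalizer pair turns a labelled example into:

```python
# Illustrative sketch, not the actual GenericReader: how a Generic-style
# pattern plus dict_verbalizer produce a cloze-style training example.
pattern = "[TEXT1] it is [LBL]"            # pattern from config/sst-2.json
dict_verbalizer = {"0": "No", "1": "Yes"}  # label id -> verbalizer word

def build_pet_example(text: str, label: str, mask_token: str = "[MASK]"):
    # The model is shown the pattern with the label slot masked out ...
    masked_input = pattern.replace("[TEXT1]", text).replace("[LBL]", mask_token)
    # ... and is trained to place the gold label's verbalizer word at the mask.
    target_word = dict_verbalizer[label]
    return masked_input, target_word

print(build_pet_example("a gorgeous, witty, seductive movie", "1"))
# ('a gorgeous, witty, seductive movie it is [MASK]', 'Yes')
```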
-------------------------------------------------------------------------------- /scripts/adapet/ADAPET/requirements.txt: -------------------------------------------------------------------------------- 1 | crcmod 2 | pandas==1.1.5 3 | datasets==1.18.2 4 | numpy==1.19 5 | jsonpickle==1.1 6 | scikit-learn==0.23.1 7 | torch===1.5.0 8 | torchvision==0.6.0 9 | transformers==4.15.0 10 | tqdm==4.62.1 11 | sentencepiece==0.1.96 12 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/adapet_test_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from src.adapet import adapet 6 | from src.data.Batcher import Batcher 7 | from src.eval.eval_model import test_eval 8 | from src.utils.Config import Config 9 | from src.utils.util import device 10 | from transformers import * 11 | 12 | 13 | os.environ["WANDB_DISABLED"] = "true" 14 | 15 | def test_evaluation(exp_dir): 16 | config_file = os.path.join(exp_dir, "config.json") 17 | os.path.join(exp_dir, "config.json") 18 | config = Config(config_file, mkdir=False) 19 | 20 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight) 21 | batcher = Batcher(config, tokenizer, config.dataset) 22 | dataset_reader = batcher.get_dataset_reader() 23 | 24 | model = adapet(config, tokenizer, dataset_reader).to(device) 25 | model.load_state_dict(torch.load(os.path.join(exp_dir, "best_model.pt"))) 26 | #model.load_state_dict(torch.load(os.path.join(args.exp_dir, "best_model.pt"))) 27 | test_eval(config, model, batcher) -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/data/Batcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | from torch.utils import data 5 | from src.data.Dataset import Dataset 6 | from src.data.DatasetReader import DatasetReader 7 | from src.utils.util import set_seeds 8 | 9 | 10 | class Batcher(object): 11 | ''' 12 | Batcher is responsible for returning batches of data 13 | ''' 14 | def __init__(self, config, tokenizer, dataset): 15 | ''' 16 | :param config: 17 | :param tokenizer: 18 | :param dataset: 19 | ''' 20 | self.config = config 21 | self.dataset_reader = DatasetReader(config, tokenizer, dataset) 22 | set_seeds(self.config.seed) 23 | 24 | self.train_loader = None 25 | self.dev_loader = None 26 | self.test_loader = None 27 | self.eval_train_loader = None 28 | 29 | self.data_len = None 30 | 31 | self.collate_fn = None 32 | if "record" in self.config.dataset: 33 | self.collate_fn = Batcher.my_collate_fn 34 | 35 | def get_dataset_reader(self): 36 | return self.dataset_reader 37 | 38 | @staticmethod 39 | def my_collate_fn(batch): 40 | 41 | dict_batch = {} 42 | dict_batch["input"] = {} 43 | dict_batch["output"] = {} 44 | 45 | for datapoint in batch: 46 | for (k, v) in datapoint["input"].items(): 47 | if k in dict_batch["input"]: 48 | dict_batch["input"][k].append(v) 49 | else: 50 | dict_batch["input"][k] = [v] 51 | 52 | for (k, v) in datapoint["output"].items(): 53 | if k in dict_batch["output"]: 54 | dict_batch["output"][k].append(v) 55 | else: 56 | dict_batch["output"][k] = [v] 57 | 58 | for (k, list_v) in dict_batch["input"].items(): 59 | if isinstance(list_v[0], int): 60 | dict_batch["input"][k] = torch.tensor(list_v) 61 | for (k, list_v) in dict_batch["output"].items(): 62 | if isinstance(list_v[0], int): 63 | dict_batch["output"][k] = 
torch.tensor(list_v) 64 | 65 | return dict_batch 66 | 67 | def _init_train(self): 68 | ''' 69 | Initialize loader for train data 70 | ''' 71 | train_data = self.dataset_reader.read_dataset("train") 72 | self.train_loader = data.DataLoader(Dataset(train_data), batch_size=self.config.batch_size, shuffle=True, collate_fn=self.my_collate_fn) 73 | 74 | eval_train_data = self.dataset_reader.read_dataset("train", is_eval=True) 75 | self.eval_train_loader = data.DataLoader(Dataset(eval_train_data), batch_size=self.config.eval_batch_size, shuffle=False, collate_fn=self.my_collate_fn) 76 | 77 | 78 | def _init_dev(self): 79 | ''' 80 | Initialize loader for dev data 81 | ''' 82 | dev_data = self.dataset_reader.read_dataset("dev") 83 | self.dev_loader = data.DataLoader(Dataset(dev_data), batch_size=self.config.eval_batch_size, shuffle=False, collate_fn=self.my_collate_fn) 84 | 85 | def _init_test(self): 86 | ''' 87 | Initialize loader for test data 88 | ''' 89 | test_data = self.dataset_reader.read_dataset("test") 90 | self.test_loader = data.DataLoader(Dataset(test_data), batch_size=self.config.eval_batch_size, shuffle=False, collate_fn=self.my_collate_fn) 91 | 92 | def get_train_batch(self): 93 | ''' 94 | Yield train batches 95 | 96 | :return: 97 | ''' 98 | if self.train_loader is None: 99 | self._init_train() 100 | 101 | while True: 102 | for x in self.train_loader: 103 | yield x 104 | 105 | def get_eval_train_batch(self): 106 | ''' 107 | Yield non-shuffled train batches 108 | 109 | :return: 110 | ''' 111 | if self.eval_train_loader is None: 112 | self._init_train() 113 | for x in self.eval_train_loader: 114 | yield x 115 | 116 | def get_dev_batch(self): 117 | ''' 118 | Yield dev batches 119 | 120 | :return: 121 | ''' 122 | if self.dev_loader is None: 123 | self._init_dev() 124 | 125 | for x in self.dev_loader: 126 | yield x 127 | 128 | 129 | def get_test_batch(self): 130 | ''' 131 | Yield test batches 132 | 133 | :return: 134 | ''' 135 | if self.test_loader is None: 136 | self._init_test() 137 | 138 | for x in self.test_loader: 139 | yield x 140 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/data/Dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils import data 3 | import numpy as np 4 | 5 | class Dataset(data.Dataset): 6 | def __init__(self, data): 7 | self.data = data 8 | 9 | def __len__(self): 10 | return len(self.data) 11 | 12 | def __getitem__(self, get_idx): 13 | return self.data[get_idx] -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/data/DatasetReader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from src.data.BoolQReader import BoolQReader 5 | from src.data.CBReader import CBReader 6 | from src.data.RTEReader import RTEReader 7 | from src.data.MultiRCReader import MultiRCReader 8 | from src.data.WiCReader import WiCReader 9 | from src.data.COPAReader import COPAReader 10 | from src.data.RecordReader import RecordReader 11 | from src.data.WSCReader import WSCReader 12 | from src.data.GenericReader import GenericReader 13 | 14 | class DatasetReader(object): 15 | ''' 16 | DatasetReader is responsible for reading dataset 17 | ''' 18 | def __init__(self, config, tokenizer, dataset): 19 | ''' 20 | :param config: 21 | :param tokenizer: 22 | :param dataset: 23 | ''' 24 | self.config = config 25 | self.dataset = dataset 26 | 27 | if 
self.dataset.lower() == "fewglue/boolq": 28 | self.dataset_reader = BoolQReader(self.config, tokenizer) 29 | elif self.dataset.lower() == "fewglue/cb": 30 | self.dataset_reader = CBReader(self.config, tokenizer) 31 | elif self.dataset.lower() == "fewglue/rte": 32 | self.dataset_reader = RTEReader(self.config, tokenizer) 33 | elif self.dataset.lower() == "fewglue/multirc": 34 | self.dataset_reader = MultiRCReader(self.config, tokenizer) 35 | elif self.dataset.lower() == "fewglue/wic": 36 | self.dataset_reader = WiCReader(self.config, tokenizer) 37 | elif self.dataset.lower() == "fewglue/copa": 38 | self.dataset_reader = COPAReader(self.config, tokenizer) 39 | elif self.dataset.lower() == "fewglue/record": 40 | self.dataset_reader = RecordReader(self.config, tokenizer) 41 | elif self.dataset.lower() == "fewglue/wsc": 42 | self.dataset_reader = WSCReader(self.config, tokenizer) 43 | elif self.dataset.lower() == "generic": 44 | self.dataset_reader = GenericReader(self.config, tokenizer) 45 | else: 46 | raise ValueError("Invalid Dataset name") 47 | 48 | def get_num_lbl_tok(self): 49 | ''' 50 | Get number of token in labels for dataset 51 | 52 | :return: 53 | ''' 54 | return self.dataset_reader.get_num_lbl_tok() 55 | 56 | def read_dataset(self, split, is_eval=False): 57 | ''' 58 | Read dataset 59 | 60 | :param split: 61 | :param is_eval: 62 | :return: 63 | ''' 64 | return np.asarray(self.dataset_reader.read_dataset(split, is_eval)) 65 | 66 | def prepare_batch(self, batch, type): 67 | ''' 68 | Prepare batch of data for model 69 | 70 | :param batch: 71 | :param type: pattern to prepare batch with and which mode to use (ex: PET_MLM_PET1) 72 | :return: 73 | ''' 74 | # Prepare for PET MLM objective 75 | if "PET_MLM" in type: 76 | return self.dataset_reader.prepare_pet_mlm_batch(batch, mode=type.replace("PET_MLM_", "")) 77 | # Prepare for evaluation objective 78 | elif "EVAL" in type: 79 | return self.dataset_reader.prepare_eval_pet_batch(batch, mode=type.replace("EVAL_", "")) 80 | # Default is preparing for PET/Decoupled Label objective 81 | else: 82 | return self.dataset_reader.prepare_pet_batch(batch, mode=type) 83 | 84 | def store_test_lbl(self, list_idx, pred_lbl, true_lbl, logits): 85 | ''' 86 | Store test outputs for SuperGLUE to submit to leaderboard 87 | 88 | :param list_idx: 89 | :param pred_lbl: 90 | :param true_lbl: 91 | :param logits: 92 | :return: 93 | ''' 94 | self.dataset_reader.store_test_lbl(list_idx, pred_lbl, true_lbl, logits) 95 | 96 | def flush_file(self, write_file): 97 | ''' 98 | Write out contents of test predictions to file 99 | 100 | :param write_file: 101 | :return: 102 | ''' 103 | self.dataset_reader.flush_file(write_file) 104 | 105 | def get_num_lbl(self): 106 | ''' 107 | Get number of lbls in dataset 108 | 109 | :return: 110 | ''' 111 | return self.dataset_reader.num_lbl 112 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/data/tokenize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import random 4 | from collections import defaultdict 5 | 6 | 7 | def tokenize_pet_mlm_txt(tokenizer, config, txt1, txt2, txt3, txt_trim, mask_idx=None): 8 | ''' 9 | Tokenizes the text by trimming the appropriate txt 10 | 11 | :param tokenizer: 12 | param config: 13 | :param txt1: 14 | :param txt2: 15 | :param txt3: 16 | :param mask_txt1: 17 | :param mask_txt2: 18 | :param mask_txt3: 19 | :param txt_trim: idx of text to trim will never contain label 20 
| :return mask_idx: list of list of idx of mask token in trunc_input_ids (in case lbl is more than 1 token) 21 | ''' 22 | 23 | txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"] 24 | txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"] 25 | txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"] 26 | 27 | # Add 1 to account for CLS rep 28 | tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1 29 | 30 | # Don't need to trim text 31 | if tot_length <= config.max_text_length: 32 | trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length 33 | trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids 34 | 35 | # Trim text 36 | else: 37 | num_trim = tot_length - config.max_text_length 38 | 39 | if txt_trim == 0: 40 | new_txt1_input_ids = txt1_input_ids[:-num_trim] 41 | trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids 42 | elif txt_trim == 1: 43 | new_txt2_input_ids = txt2_input_ids[:-num_trim] 44 | trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids 45 | elif txt_trim == 2: 46 | new_txt_3_input_ids = txt3_input_ids[:-num_trim] 47 | trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt_3_input_ids 48 | else: 49 | raise ValueError("Invalid Txt Trim") 50 | 51 | trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids 52 | 53 | if mask_idx is None: 54 | sample_length = min(tot_length, config.max_text_length) 55 | upto_ratio_mask = np.random.rand() 56 | num_sample = max(int(upto_ratio_mask * config.mask_alpha * sample_length), 2) - 1 57 | mask_idx = random.sample(range(0, sample_length), k=num_sample) 58 | mask_idx = np.asarray(mask_idx) 59 | 60 | # Copy adds mask idx at random positions 61 | unsup_masked_ids = np.copy(trunc_input_ids) 62 | 63 | unsup_masked_ids[mask_idx] = tokenizer.mask_token_id 64 | 65 | return trunc_input_ids, unsup_masked_ids, mask_idx 66 | 67 | def tokenize_pet_txt(tokenizer, config, txt1, txt2, txt3, mask_txt1, mask_txt2, mask_txt3, txt_trim): 68 | ''' 69 | Tokenizes the text by trimming the appropriate txt 70 | 71 | :param txt1: 72 | :param txt2: 73 | :param txt3: 74 | :param mask_txt1: 75 | :param mask_txt2: 76 | :param mask_txt3: 77 | :param txt_trim: text to trim will never contain label 78 | :return trunc_input_ids: list of input ids (each exactly max_config_length) 79 | :return mask_idx: list of list of idx of mask token in trunc_input_ids (in case lbl is more than 1 token) 80 | ''' 81 | txt1_input_ids = tokenizer(txt1, add_special_tokens=False)["input_ids"] 82 | txt2_input_ids = tokenizer(txt2, add_special_tokens=False)["input_ids"] 83 | txt3_input_ids = tokenizer(txt3, add_special_tokens=False)["input_ids"] 84 | 85 | mask_txt1_input_ids = tokenizer(mask_txt1, add_special_tokens=False)["input_ids"] 86 | mask_txt2_input_ids = tokenizer(mask_txt2, add_special_tokens=False)["input_ids"] 87 | mask_txt3_input_ids = tokenizer(mask_txt3, add_special_tokens=False)["input_ids"] 88 | 89 | # Add 1 to account for CLS rep 90 | tot_length = len(txt1_input_ids) + len(txt2_input_ids) + len(txt3_input_ids) + 1 91 | tot_mask_length = len(mask_txt1_input_ids) + len(mask_txt2_input_ids) + len(mask_txt3_input_ids) + 1 92 | 93 | # Don't need to trim text 94 | if tot_length <= config.max_text_length: 95 | trunc_input_ids = [tokenizer.pad_token_id] * config.max_text_length 96 | trunc_input_ids[:tot_length] = txt1_input_ids + txt2_input_ids + txt3_input_ids 97 | 98 | trunc_mask_input_ids = [tokenizer.pad_token_id] * 
config.max_text_length 99 | trunc_mask_input_ids[:tot_mask_length] = mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids 100 | 101 | # Trim text 102 | else: 103 | num_trim = tot_length - config.max_text_length 104 | 105 | if txt_trim == 0: 106 | new_txt1_input_ids = txt1_input_ids[:-num_trim] 107 | new_mask_txt1_input_ids = mask_txt1_input_ids[:-num_trim] 108 | trunc_input_ids = new_txt1_input_ids + txt2_input_ids + txt3_input_ids 109 | trunc_mask_input_ids = new_mask_txt1_input_ids + mask_txt2_input_ids + mask_txt3_input_ids 110 | elif txt_trim == 1: 111 | new_txt2_input_ids = txt2_input_ids[:-num_trim] 112 | new_mask_txt2_input_ids = mask_txt2_input_ids[:-num_trim] 113 | trunc_input_ids = txt1_input_ids + new_txt2_input_ids + txt3_input_ids 114 | trunc_mask_input_ids = mask_txt1_input_ids + new_mask_txt2_input_ids + mask_txt3_input_ids 115 | elif txt_trim == 2: 116 | new_txt_3_input_ids = txt3_input_ids[:-num_trim] 117 | new_mask_txt3_input_ids = mask_txt3_input_ids[:-num_trim] 118 | trunc_input_ids = txt1_input_ids + txt2_input_ids + new_txt_3_input_ids 119 | trunc_mask_input_ids = mask_txt1_input_ids + mask_txt2_input_ids + new_mask_txt3_input_ids 120 | else: 121 | raise ValueError("Invalid Txt Trim") 122 | 123 | 124 | trunc_input_ids = [tokenizer.cls_token_id] + trunc_input_ids 125 | trunc_mask_input_ids = [tokenizer.cls_token_id] + trunc_mask_input_ids 126 | 127 | mask_idx = trunc_mask_input_ids.index(tokenizer.mask_token_id) 128 | 129 | return trunc_input_ids, mask_idx 130 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/dev.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | from transformers import * 6 | 7 | from src.data.Batcher import Batcher 8 | from src.adapet import adapet 9 | from src.utils.Config import Config 10 | from src.eval.eval_model import dev_eval 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('-e', "--exp_dir", required=True) 15 | args = parser.parse_args() 16 | 17 | config_file = os.path.join(args.exp_dir, "config.json") 18 | config = Config(config_file, mkdir=False) 19 | config.eval_dev = True 20 | 21 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight) 22 | batcher = Batcher(config, tokenizer, config.dataset) 23 | dataset_reader = batcher.get_dataset_reader() 24 | 25 | model = adapet(config, tokenizer, dataset_reader).to(device) 26 | model.load_state_dict(torch.load(os.path.join(args.exp_dir, "best_model.pt"))) 27 | dev_acc, dev_logits = dev_eval(config, model, batcher, 0) 28 | 29 | with open(os.path.join(config.exp_dir, "dev_logits.npy"), 'wb') as f: 30 | np.save(f, dev_logits) 31 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/eval/Writer.py: -------------------------------------------------------------------------------- 1 | 2 | class Writer(object): 3 | 4 | def __init__(self, file, dataset_reader): 5 | self.write_file = open(file, 'w+') 6 | self.dataset_reader = dataset_reader 7 | 8 | def add_batch(self, list_idx, list_pred_lbl, list_true_lbl, lbl_logits): 9 | self.dataset_reader.store_test_lbl(list_idx, list_pred_lbl, list_true_lbl, lbl_logits) 10 | 11 | def flush_file(self): 12 | self.dataset_reader.flush_file(self.write_file) -------------------------------------------------------------------------------- 
/scripts/adapet/ADAPET/src/eval/eval_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | #added 5 | import os 6 | import time 7 | import numpy as np 8 | 9 | from src.eval.Scorer import Scorer 10 | from src.eval.Writer import Writer 11 | 12 | def eval(config, model, batch_iter, scorer): 13 | ''' 14 | Evaluate model 15 | 16 | :param config: 17 | :param model: 18 | :param batch_iter: 19 | :param scorer: 20 | :return: 21 | ''' 22 | model.eval() 23 | with torch.no_grad(): 24 | for idx, batch in enumerate(batch_iter): 25 | pred_lbl, lbl_logits = model.predict(batch) 26 | list_idx = batch["input"]["idx"] if isinstance(batch["input"]["idx"], list) else batch["input"]["idx"].cpu().numpy().tolist() 27 | list_lbl = batch["output"]["true_lbl"] if "true_lbl" in batch["output"] else batch["output"]["lbl"] 28 | 29 | if config.dataset.lower() == 'fewglue/record': 30 | true_lbl = torch.tensor([1]) 31 | pred_lbl = torch.tensor([list_lbl[0][pred_lbl[0].item()]]) 32 | scorer.add_batch(list_idx, pred_lbl, true_lbl, lbl_logits.cpu().numpy(), None) 33 | else: 34 | scorer.add_batch(list_idx, pred_lbl, list_lbl, lbl_logits.cpu().numpy(), None) 35 | 36 | 37 | 38 | def dev_eval(config, model, batcher, num_batches, dict_avg_val=None): 39 | ''' 40 | Evaluates the accuracy on the dev partition 41 | 42 | :param config: 43 | :param model: 44 | :param batcher: batcher to get batches of data 45 | :param num_batches: 46 | :param dict_avg_val: dictionary storing metrics 47 | 48 | :return: currrent dev score 49 | ''' 50 | 51 | dict_eval = {} 52 | dict_eval["num_batches"] = num_batches 53 | 54 | if dict_avg_val is not None: 55 | dict_eval.update(dict_avg_val) 56 | 57 | # Get train Score 58 | if config.eval_train: 59 | train_scorer = Scorer(config, config.dataset) 60 | train_iter = batcher.get_eval_train_batch() 61 | eval(config, model, train_iter, train_scorer) 62 | _, train_scores = train_scorer.get_score("train") 63 | dict_eval.update(train_scores) 64 | 65 | # Get dev Score 66 | if config.eval_dev: 67 | dev_scorer = Scorer(config, config.dataset) 68 | dev_iter = batcher.get_dev_batch() 69 | eval(config, model, dev_iter, dev_scorer) 70 | score_eval, dev_scores = dev_scorer.get_score("dev") 71 | dict_eval.update(dev_scores) 72 | dev_logits = dev_scorer.get_logits() 73 | else: 74 | score_eval = 0 75 | dev_logits = None 76 | 77 | with open(config.dev_score_file, 'a+') as f_out: 78 | f_out.write(json.dumps(dict_eval)) 79 | f_out.write('\n') 80 | 81 | return score_eval, dev_logits 82 | 83 | def test_eval(config, model, batcher): 84 | ''' 85 | Evaluates the accuracy on the test partition 86 | 87 | :param config: 88 | :param model: 89 | :param batcher: 90 | ''' 91 | 92 | model.eval() 93 | dataset_reader = batcher.get_dataset_reader() 94 | test_writer = Writer(os.path.join(config.exp_dir, "test.json"), dataset_reader) 95 | 96 | with torch.no_grad(): 97 | #added 98 | pred_labels = [] 99 | pred_logits = [] 100 | t0 = time.time() 101 | for idx, batch in enumerate(batcher.get_test_batch()): 102 | t1 = time.time() 103 | pred_lbl, lbl_logits = model.predict(batch) 104 | 105 | #lbl_logits = lbl_logits.cpu().numpy() 106 | 107 | 108 | pred_labels.extend(pred_lbl.cpu().numpy().tolist()) 109 | pred_logits.extend(lbl_logits.cpu().numpy().tolist()) 110 | 111 | list_idx = batch["input"]["idx"] if isinstance(batch["input"]["idx"], list) else batch["input"][ 112 | "idx"].cpu().numpy().tolist() 113 | list_lbl = batch["output"]["true_lbl"] if "true_lbl" in 
batch["output"] else batch["output"]["lbl"] 114 | 115 | if config.dataset.lower() == 'fewglue/record': 116 | list_idx = batch["input"]["qas_idx"] 117 | list_lbl = batch["input"]["candidate_entity"] 118 | test_writer.add_batch(list_idx, pred_lbl, list_lbl, lbl_logits.cpu().numpy()) 119 | else: 120 | test_writer.add_batch(list_idx, pred_lbl, list_lbl, lbl_logits.cpu().numpy()) 121 | 122 | #added 123 | t2 = time.time() 124 | diff1 = t1-t0 125 | diff2 = t2-t1 126 | diff3 = t2-t0 127 | #json_dict = {'start_loop':diff1, 'inside_loop':diff2, 'once_through':diff3} 128 | #writefile = 'time_difference/'+config.exp_dir 129 | #if not os.path.exists(writefile): 130 | # os.makedirs(writefile) 131 | #print(writefile) 132 | #writefile = writefile+'time.json' 133 | #with open(writefile, "a") as f: 134 | # f.write(json.dumps(json_dict)+ '\n') 135 | t3 = time.time() 136 | print('total inference time: {}'.format(t3-t0)) 137 | #altered 138 | #print(pred_logits) 139 | test_writer.flush_file() 140 | return pred_labels, np.array(pred_logits) 141 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/run_pretrained.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | from transformers import * 6 | 7 | from src.data.Batcher import Batcher 8 | from src.utils.Config import Config 9 | from src.utils.util import device, ParseKwargs 10 | from src.adapet import adapet 11 | from src.eval.eval_model import dev_eval 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-m', "--model_dir", required=True) 16 | parser.add_argument('-c', "--config_file", required=True) 17 | parser.add_argument('-k', '--kwargs', nargs='*', action=ParseKwargs, default={}) 18 | args = parser.parse_args() 19 | 20 | config = Config(args.config_file, args.kwargs, mkdir=True) 21 | 22 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight) 23 | batcher = Batcher(config, tokenizer, config.dataset) 24 | dataset_reader = batcher.get_dataset_reader() 25 | 26 | model = adapet(config, tokenizer, dataset_reader).to(device) 27 | model.load_state_dict(torch.load(os.path.join(args.model_dir, "best_model.pt"))) 28 | dev_eval(config, model, batcher, 0) 29 | 30 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/scripts/example_convert_sst_2_generic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 5 | def read_tsv(filename, max_num_lines): 6 | train_json_filename = filename.replace(".tsv", ".jsonl") 7 | 8 | with open(filename, 'r') as f_in, open(train_json_filename, 'w+') as f_out: 9 | f_in.readline() 10 | for idx, line in enumerate(f_in.readlines()): 11 | if idx < max_num_lines: 12 | tab_split = line.strip('\n').split('\t') 13 | dict_json = {"TEXT1": tab_split[0], "LBL": tab_split[1]} 14 | 15 | f_out.write(json.dumps(dict_json) + '\n') 16 | else: 17 | break 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('-f', "--filepath", required=True) 23 | parser.add_argument('-m', "--max_num_lines", type=int, default=32) 24 | args = parser.parse_args() 25 | 26 | read_tsv(args.filepath, args.max_num_lines) -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/test.py: 
-------------------------------------------------------------------------------- 1 | #import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | from transformers import * 6 | 7 | from src.data.Batcher import Batcher 8 | from src.utils.Config import Config 9 | #altered 10 | from src.utils.util import device 11 | from src.adapet import adapet 12 | from src.eval.eval_model import test_eval 13 | 14 | 15 | def do_test(exp_dir): 16 | #device = torch.device("cpu") 17 | config_file = os.path.join(exp_dir, "config.json") 18 | config = Config(config_file, mkdir=False) 19 | 20 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight) 21 | batcher = Batcher(config, tokenizer, config.dataset) 22 | dataset_reader = batcher.get_dataset_reader() 23 | 24 | model = adapet(config, tokenizer, dataset_reader).to(device) 25 | model.load_state_dict(torch.load(os.path.join(exp_dir, "best_model.pt"))) 26 | #altered 27 | pred_labels, pred_logits = test_eval(config, model, batcher) 28 | 29 | return pred_labels, pred_logits 30 | 31 | #if __name__ == "__main__": 32 | #parser = argparse.ArgumentParser() 33 | #parser.add_argument('-e', "--exp_dir", required=True) 34 | 35 | #args = parser.parse_args() 36 | 37 | ''' 38 | config_file = os.path.join(args.exp_dir, "config.json") 39 | config = Config(config_file, mkdir=False) 40 | 41 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight) 42 | batcher = Batcher(config, tokenizer, config.dataset) 43 | dataset_reader = batcher.get_dataset_reader() 44 | 45 | model = adapet(config, tokenizer, dataset_reader).to(device) 46 | model.load_state_dict(torch.load(os.path.join(args.exp_dir, "best_model.pt"))) 47 | test_eval(config, model, batcher) 48 | ''' 49 | 50 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import re 4 | import numpy as np 5 | import argparse 6 | import logging 7 | os.environ["WANDB_DISABLED"] = "true" 8 | from transformers import * 9 | 10 | from src.eval.eval_model import dev_eval 11 | from src.adapet import adapet 12 | from torch.optim.lr_scheduler import LambdaLR 13 | 14 | from src.data.Batcher import Batcher 15 | from src.utils.Config import Config 16 | from src.utils.util import get_avg_dict_val_store, update_dict_val_store, ParseKwargs 17 | from src.utils.util import set_global_logging_level, device 18 | 19 | 20 | 21 | set_global_logging_level(logging.ERROR) 22 | 23 | # From HuggingFace 24 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 25 | """ 26 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, 27 | after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 28 | 29 | Args: 30 | optimizer (:class:`~torch.optim.Optimizer`): 31 | The optimizer for which to schedule the learning rate. 32 | num_warmup_steps (:obj:`int`): 33 | The number of steps for the warmup phase. 34 | num_training_steps (:obj:`int`): 35 | The total number of training steps. 36 | last_epoch (:obj:`int`, `optional`, defaults to -1): 37 | The index of the last epoch when resuming training. 38 | 39 | Return: 40 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
41 | """ 42 | 43 | def lr_lambda(current_step: int): 44 | if current_step < num_warmup_steps: 45 | return float(current_step) / float(max(1, num_warmup_steps)) 46 | return max( 47 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 48 | ) 49 | 50 | return LambdaLR(optimizer, lr_lambda, last_epoch) 51 | 52 | 53 | def train(config): 54 | ''' 55 | Trains the model 56 | 57 | :param config: 58 | :return: 59 | ''' 60 | 61 | tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight) 62 | batcher = Batcher(config, tokenizer, config.dataset) 63 | dataset_reader = batcher.get_dataset_reader() 64 | model = adapet(config, tokenizer, dataset_reader).to(device) 65 | 66 | ### Create Optimizer 67 | # Ignore weight decay for certain parameters 68 | no_decay_param = ['bias', 'LayerNorm.weight'] 69 | optimizer_grouped_parameters = [ 70 | {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay_param)], 71 | 'weight_decay': config.weight_decay, 72 | 'lr': config.lr}, 73 | {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay_param)], 74 | 'weight_decay': 0.0, 75 | 'lr': config.lr}, 76 | ] 77 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=1e-8) 78 | 79 | #altered 80 | #best_dev_acc = 0 81 | best_dev_acc = -float('inf') 82 | train_iter = batcher.get_train_batch() 83 | dict_val_store = None 84 | 85 | # Number of batches is assuming grad_accumulation_factor forms one batch 86 | tot_num_batches = config.num_batches * config.grad_accumulation_factor 87 | 88 | # Warmup steps and total steps are based on batches, not epochs 89 | num_warmup_steps = config.num_batches * config.warmup_ratio 90 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, config.num_batches) 91 | 92 | for i in range(tot_num_batches): 93 | # Get true batch_idx 94 | batch_idx = (i // config.grad_accumulation_factor) 95 | 96 | model.train() 97 | sup_batch = next(train_iter) 98 | loss, dict_val_update = model(sup_batch) 99 | loss = loss / config.grad_accumulation_factor 100 | loss.backward() 101 | 102 | if (i+1) % config.grad_accumulation_factor == 0: 103 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm) 104 | optimizer.step() 105 | optimizer.zero_grad() 106 | scheduler.step() 107 | 108 | dict_val_store = update_dict_val_store(dict_val_store, dict_val_update, config.grad_accumulation_factor) 109 | print("Finished %d batches" % batch_idx, end='\r') 110 | 111 | if (batch_idx + 1) % config.eval_every == 0 and i % config.grad_accumulation_factor == 0: 112 | dict_avg_val = get_avg_dict_val_store(dict_val_store, config.eval_every) 113 | dict_val_store = None 114 | dev_acc, dev_logits = dev_eval(config, model, batcher, batch_idx, dict_avg_val) 115 | #altered but not used 116 | if type(dev_acc) == str: 117 | f1s = re.findall(r"[-+]?\d*\.\d+|\d+", dev_acc) 118 | dev_acc = float(f1s[0]) 119 | 120 | print("Global Step: %d Acc: %.3f" % (batch_idx, float(dev_acc)) + '\n') 121 | 122 | if dev_acc > best_dev_acc: 123 | best_dev_acc = dev_acc 124 | torch.save(model.state_dict(), os.path.join(config.exp_dir, "best_model.pt")) 125 | with open(os.path.join(config.exp_dir, "dev_logits.npy"), 'wb') as f: 126 | np.save(f, dev_logits) 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('-c', "--config_file", required=True) 132 | parser.add_argument('-k', '--kwargs', nargs='*', action=ParseKwargs, default={}) 133 | 
args = parser.parse_args() 134 | 135 | config = Config(args.config_file, args.kwargs, mkdir=True) 136 | train(config) 137 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/utils/Config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import ast 4 | 5 | from src.utils.util import make_exp_dir 6 | 7 | class Config(object): 8 | def __init__(self, filename=None, kwargs=None, mkdir=True): 9 | # Dataset parameters 10 | self.dataset = "fewglue/BoolQ" 11 | self.num_lbl = 2 12 | self.max_num_lbl_tok = 10 13 | self.max_num_lbl = 10 14 | 15 | # Model and pattern parameters 16 | self.pretrained_weight = "bert-base-uncased" 17 | self.pattern_idx = "random" 18 | 19 | # Duration of training parameters 20 | self.batch_size = 8 21 | self.eval_batch_size = 64 22 | self.num_batches = 1000 23 | self.eval_every = 1 24 | self.grad_accumulation_factor = 1 25 | self.max_text_length = 64 26 | 27 | self.mask_alpha = 0.5 28 | 29 | self.eval_train = False 30 | self.eval_dev = True 31 | 32 | # Where experiments will be located 33 | self.exp_dir = None 34 | self.seed = 42 35 | self.exp_name = "" 36 | 37 | # Training Hyperparameters 38 | self.lr = 1e-3 39 | self.weight_decay = 0 40 | self.grad_clip_norm = 1 41 | self.warmup_ratio = 0 42 | 43 | # Generic dataset hyperparameters 44 | self.pattern = "[TEXT1] and [TEXT2] " 45 | self.idx_txt_trim = -1 # Indexed from 1 46 | self.dict_verbalizer = {"True": "Yes", "False": "No"} 47 | self.data_dir = "data/fewglue/BoolQ" 48 | #Added 49 | self.task_name = 'SetFit/sst2' 50 | 51 | if filename: 52 | self.__dict__.update(json.load(open(filename))) 53 | if kwargs: 54 | self.update_kwargs(kwargs) 55 | 56 | if filename or kwargs: 57 | self.update_exp_config(mkdir) 58 | 59 | def update_kwargs(self, kwargs): 60 | for (k, v) in kwargs.items(): 61 | try: 62 | v = ast.literal_eval(v) 63 | except: 64 | v = v 65 | setattr(self, k, v) 66 | 67 | def update_exp_config(self, mkdir=True): 68 | ''' 69 | Updates the config default values based on parameters passed in from config file 70 | ''' 71 | 72 | 73 | self.base_dir = os.path.join("exp_out", self.dataset, self.pretrained_weight, self.task_name) 74 | if self.exp_name != "": 75 | self.base_dir = os.path.join(self.base_dir, self.exp_name) 76 | 77 | if mkdir: 78 | self.exp_dir = make_exp_dir(self.base_dir) 79 | 80 | if self.exp_dir is not None: 81 | self.dev_pred_file = os.path.join(self.exp_dir, "dev_pred.txt") 82 | self.dev_score_file = os.path.join(self.exp_dir, "dev_scores.json") 83 | self.test_score_file = os.path.join(self.exp_dir, "test_scores.json") 84 | self.save_config(os.path.join(self.exp_dir, os.path.join("config.json"))) 85 | 86 | def to_json(self): 87 | ''' 88 | Converts parameter values in config to json 89 | :return: json 90 | ''' 91 | #altered -- ensure ascii now == False 92 | return json.dumps(self.__dict__, indent=4, sort_keys=True, ensure_ascii=False) 93 | 94 | def save_config(self, filename): 95 | ''' 96 | Saves the config 97 | ''' 98 | with open(filename, 'w+') as fout: 99 | fout.write(self.to_json()) 100 | fout.write('\n') 101 | -------------------------------------------------------------------------------- /scripts/adapet/ADAPET/src/utils/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import os 4 | import sys 5 | import argparse 6 | import subprocess 7 | from shutil import copytree, ignore_patterns 8 | import 
random 9 | import numpy as np 10 | import logging 11 | import re 12 | import sys 13 | 14 | global device; device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | def set_global_logging_level(level=logging.ERROR, prefices=[""]): 17 | """ 18 | Override logging levels of different modules based on their name as a prefix. 19 | It needs to be invoked after the modules have been loaded so that their loggers have been initialized. 20 | 21 | Args: 22 | - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR 23 | - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional. 24 | Default is `[""]` to match all active loggers. 25 | The match is a case-sensitive `module_name.startswith(prefix)` 26 | """ 27 | prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })') 28 | for name in logging.root.manager.loggerDict: 29 | if re.match(prefix_re, name): 30 | logging.getLogger(name).setLevel(level) 31 | 32 | 33 | def set_seeds(seed): 34 | "set random seeds" 35 | random.seed(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | 40 | def make_dir(dir_name): 41 | ''' 42 | Makes a directory if it doesn't exists yet 43 | Args: 44 | dir_name: directory name 45 | ''' 46 | if not os.path.exists(dir_name): 47 | os.makedirs(dir_name) 48 | 49 | def make_exp_dir(base_exp_dir): 50 | ''' 51 | Makes an experiment directory with timestamp 52 | Args: 53 | base_output_dir_name: base output directory name 54 | Returns: 55 | exp_dir_name: experiment directory name 56 | ''' 57 | now = datetime.datetime.now() 58 | ts = "{:04d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}".format(now.year, now.month, now.day, now.hour, now.minute, 59 | now.second) 60 | exp_dir_name = os.path.join(base_exp_dir, ts) 61 | make_dir(exp_dir_name) 62 | 63 | src_file = os.path.join(exp_dir_name, 'src') 64 | 65 | #copytree(os.path.join(os.environ['ADAPET_ROOT'], "src"), src_file, ignore=ignore_patterns('*.pyc', 'tmp*')) 66 | 67 | return exp_dir_name 68 | 69 | def print_mem_usage(loc): 70 | ''' 71 | Print memory usage in GB 72 | :return: 73 | ''' 74 | print("%s mem usage: %.3f GB, %.3f GB, %.3f GB" % (loc, float(torch.cuda.memory_allocated() / 1e9), float(torch.cuda.memory_reserved() / 1e9), float(torch.cuda.max_memory_allocated() / 1e9))) 75 | sys.stdout.flush() 76 | 77 | class ParseKwargs(argparse.Action): 78 | def __call__(self, parser, namespace, values, option_string=None): 79 | setattr(namespace, self.dest, dict()) 80 | for value in values: 81 | key, value = value.split('=') 82 | getattr(namespace, self.dest)[key] = value 83 | 84 | 85 | def update_dict_val_store(dict_val_store, dict_update_val, grad_accumulation_factor): 86 | ''' 87 | Update dict_val_store with dict_update_val 88 | 89 | :param dict_val_store: 90 | :param dict_update_val: 91 | :return: 92 | ''' 93 | if dict_val_store is None: 94 | dict_val_store = dict_update_val 95 | else: 96 | for k in dict_val_store.keys(): 97 | dict_val_store[k] += dict_update_val[k] / grad_accumulation_factor 98 | 99 | return dict_val_store 100 | 101 | def get_avg_dict_val_store(dict_val_store, num_batches=100): 102 | ''' 103 | Get average dictionary val 104 | 105 | :param dict_val_store: 106 | :param eval_every: 107 | :return: 108 | ''' 109 | dict_avg_val = {} 110 | 111 | for k in dict_val_store.keys(): 112 | dict_avg_val[k] = float('%.3f' % (dict_val_store[k].detach().cpu().item() / num_batches)) 113 | 114 | return dict_avg_val 115 | 
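# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the ADAPET sources above).
# It shows how `-k key=value` overrides collected by ParseKwargs reach
# Config.update_kwargs(): argparse stores raw strings, and Config then attempts
# ast.literal_eval on each value, falling back to the raw string when the value
# is not a Python literal. The override names used below (lr, num_batches,
# eval_train, task_name) are taken from the Config defaults shown earlier;
# ParseKwargs is re-declared here only so the snippet runs standalone.
import argparse
import ast


class ParseKwargs(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, dict())
        for value in values:
            key, value = value.split('=')
            getattr(namespace, self.dest)[key] = value


parser = argparse.ArgumentParser()
parser.add_argument('-k', '--kwargs', nargs='*', action=ParseKwargs, default={})
args = parser.parse_args(['-k', 'lr=1e-5', 'num_batches=2000', 'eval_train=True', 'task_name=SetFit/sst2'])

parsed = {}
for key, value in args.kwargs.items():
    try:
        parsed[key] = ast.literal_eval(value)   # numbers and booleans become Python literals
    except (ValueError, SyntaxError):
        parsed[key] = value                     # plain strings such as 'SetFit/sst2' stay as-is
print(parsed)  # {'lr': 1e-05, 'num_batches': 2000, 'eval_train': True, 'task_name': 'SetFit/sst2'}
# ---------------------------------------------------------------------------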
-------------------------------------------------------------------------------- /scripts/create_summary_table.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import tarfile 5 | from collections import defaultdict 6 | from glob import glob 7 | from os import listdir 8 | from os.path import isdir, join, splitext 9 | from typing import List, Tuple 10 | 11 | from numpy import mean, median, std 12 | from scipy.stats import iqr 13 | 14 | 15 | """ 16 | To run: python create_summary_table.py --path scripts/{method_name}/{results}/{model_name} 17 | or: python create_summary_table.py --path scripts/{method_name}/{model_name}.tar.gz 18 | Files are outputted to the directory of the results. 19 | """ 20 | 21 | TEST_DATASET_TO_METRIC = { 22 | "emotion": "accuracy", 23 | "SentEval-CR": "accuracy", 24 | "sst5": "accuracy", 25 | "ag_news": "accuracy", 26 | "enron_spam": "accuracy", 27 | "amazon_counterfactual_en": "matthews_correlation", 28 | } 29 | 30 | 31 | def extract_results(path: str) -> None: 32 | tar = tarfile.open(path, "r:gz") 33 | unzip_path = splitext(splitext(path)[-2])[-2] 34 | tar.extractall(path=os.path.dirname(unzip_path)) 35 | tar.close() 36 | return unzip_path 37 | 38 | 39 | def get_sample_sizes(path: str) -> List[str]: 40 | return sorted(list({int(name.split("-")[-2]) for name in glob(f"{path}/*/train-*-0")})) 41 | 42 | 43 | def get_tfew_sample_sizes(path: str) -> List[str]: 44 | return sorted(list({int(name.split("-")[-2]) for name in glob(f"{path}/train-*-0/seed0")})) 45 | 46 | 47 | def compute_tfew_medians(results_path: str) -> None: 48 | """Given per-split and per-seed T-Few results for multiple dataset, 49 | calculates the median score and interquartile range across all seeds, 50 | and saves them to a `results.json` file in the same path. 
51 | 52 | Args: 53 | results_path: path to T-Few results: `/setfit/scripts/tfew/results/t03b_pretrained` 54 | """ 55 | 56 | for dataset in listdir(results_path): 57 | dataset_path = join(results_path, dataset) 58 | if isdir(dataset_path): 59 | dataset_metric = TEST_DATASET_TO_METRIC[dataset] 60 | sample_sizes = get_tfew_sample_sizes(dataset_path) 61 | 62 | for sample_size in sample_sizes: 63 | split_dirs = sorted(glob(join(dataset_path, f"train-{sample_size}-*"))) 64 | assert split_dirs is not None 65 | 66 | for split_dir in split_dirs: 67 | seed_results_json = sorted(glob(join(split_dir, "seed*/dev_scores.json"))) 68 | seed_metrics = [] 69 | for seed_result_json in seed_results_json: 70 | with open(seed_result_json) as f: 71 | result_dict = json.loads(f.readlines()[-1]) 72 | seed_metrics.append(result_dict[dataset_metric] * 100) 73 | 74 | with open(join(split_dir, "results.json"), "w") as f: 75 | json.dump( 76 | {"score": median(seed_metrics), "measure": dataset_metric, "iqr": iqr(seed_metrics)}, f 77 | ) 78 | 79 | 80 | def get_formatted_ds_metrics(path: str, dataset: str, sample_sizes: List[str]) -> Tuple[str, List[str]]: 81 | formatted_row = [] 82 | metric_name = "" 83 | exact_metrics, exact_stds = {}, {} 84 | 85 | for sample_size in sample_sizes: 86 | result_jsons = sorted(glob(os.path.join(path, dataset, f"train-{sample_size}-*", "results.json"))) 87 | split_metrics = [] 88 | 89 | for result_json in result_jsons: 90 | with open(result_json) as f: 91 | result_dict = json.load(f) 92 | 93 | metric_name = result_dict.get("measure", "N/A") 94 | split_metrics.append(result_dict["score"]) 95 | 96 | exact_metrics[sample_size] = mean(split_metrics) 97 | exact_stds[sample_size] = std(split_metrics) 98 | formatted_row.extend([f"{exact_metrics[sample_size]:.1f}", f"{exact_stds[sample_size]:.1f}"]) 99 | 100 | return metric_name, formatted_row, exact_metrics, exact_stds, sample_sizes 101 | 102 | 103 | def create_summary_table(results_path: str) -> None: 104 | """Given per-split results, creates a summary table of all datasets, 105 | with average metrics and standard deviations. 
106 | 107 | Args: 108 | path: path to per-split results: either `scripts/{method_name}/{results}/{model_name}`, 109 | or `final_results/{method_name}/{model_name}.tar.gz` 110 | """ 111 | 112 | if results_path.endswith("tar.gz"): 113 | unzipped_path = extract_results(results_path) 114 | else: 115 | unzipped_path = results_path 116 | 117 | if "tfew" in unzipped_path: 118 | print("Computing medians for T-Few...") 119 | compute_tfew_medians(unzipped_path) 120 | 121 | sample_sizes = get_sample_sizes(unzipped_path) 122 | header_row = ["dataset", "measure"] 123 | for sample_size in sample_sizes: 124 | header_row.append(f"{sample_size}_avg") 125 | header_row.append(f"{sample_size}_std") 126 | 127 | csv_lines = [header_row] 128 | 129 | means, stds = defaultdict(list), defaultdict(list) 130 | for dataset in next(os.walk(unzipped_path))[1]: 131 | metric_name, formatted_metrics, exact_metrics, exact_stds, sample_sizes = get_formatted_ds_metrics( 132 | unzipped_path, dataset, sample_sizes 133 | ) 134 | dataset_row = [dataset, metric_name, *formatted_metrics] 135 | csv_lines.append(dataset_row) 136 | 137 | # Collect exact metrics for overall average and std calculation 138 | for sample_size in sample_sizes: 139 | means[sample_size].append(exact_metrics[sample_size]) 140 | stds[sample_size].append(exact_stds[sample_size]) 141 | 142 | # Generate row for overall average 143 | formatted_average_row = [] 144 | for sample_size in sample_sizes: 145 | overall_average = mean(means[sample_size]) 146 | overall_std = mean(stds[sample_size]) 147 | formatted_average_row.extend([f"{overall_average:.1f}", f"{overall_std:.1f}"]) 148 | csv_lines.append(["Average", "N/A", *formatted_average_row]) 149 | 150 | output_path = os.path.join(unzipped_path, "summary_table.csv") 151 | print("=" * 80) 152 | print("Summary table:\n") 153 | with open(output_path, "w") as f: 154 | for line in csv_lines: 155 | f.write(",".join(line) + "\n") 156 | print(", ".join(line)) 157 | print("=" * 80) 158 | print(f"Saved summary table to {output_path}") 159 | 160 | 161 | def main() -> None: 162 | parser = argparse.ArgumentParser() 163 | 164 | parser.add_argument("--path", type=str) 165 | args = parser.parse_args() 166 | 167 | create_summary_table(args.path) 168 | 169 | 170 | if __name__ == "__main__": 171 | main() 172 | -------------------------------------------------------------------------------- /scripts/perfect/README.md: -------------------------------------------------------------------------------- 1 | # Running PERFECT 2 | 3 | Follow the steps below to run the baselines based on the `PERFECT` paper: [_PERFECT: Prompt-free and Efficient Few-shot Learning with Language Models_](https://arxiv.org/abs/2204.01172). 4 | 5 | ## Setup 6 | 7 | To get started, first create a Python virtual environment, e.g. 
with `conda`: 8 | 9 | ``` 10 | conda create -n baselines-perfect python=3.10 && conda activate baselines-perfect 11 | ``` 12 | 13 | Next, clone [our fork](https://github.com/SetFit/perfect) of the [`PERFECT` codebase](https://github.com/facebookresearch/perfect), and install the required dependencies: 14 | 15 | ``` 16 | git clone git+https://github.com/SetFit/perfect.git 17 | cd perfect 18 | python setup.py develop 19 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 20 | python -m pip install -r requirements.txt 21 | ``` 22 | 23 | Next, download and process the datasets: 24 | 25 | ``` 26 | wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar 27 | tar -xvf datasets.tar 28 | mv original/ datasets 29 | python fewshot/process_datasets.py 30 | ``` 31 | 32 | ## Usage example 33 | 34 | To train and evaluate `PERFECT` on 8 and 64 examples (per class) across all the SetFit test datasets, run: 35 | 36 | ``` 37 | cd fewshot/ 38 | bash scripts/run_setfit_baselines.sh 39 | ``` -------------------------------------------------------------------------------- /scripts/plot_summary_comparison.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import string 5 | import sys 6 | from collections import defaultdict 7 | from glob import glob 8 | from pathlib import Path 9 | from typing import List, Tuple 10 | 11 | import matplotlib.pyplot as plt 12 | import pandas as pd 13 | 14 | 15 | """ 16 | To run: 17 | python plot_summary_comparison.py --paths scripts/{method_name}/results/{model_name} 18 | Multiple paths can be provided. The produced plots are outputted to scripts/images/v_{id}/{dataset}.png. 19 | 20 | See https://github.com/huggingface/setfit/pull/268#issuecomment-1434549208 for an example of the plots 21 | produced by this script. 22 | """ 23 | 24 | 25 | def get_sample_sizes(path: str) -> List[str]: 26 | return sorted(list({int(name.split("-")[-2]) for name in glob(f"{path}/*/train-*-0")})) 27 | 28 | 29 | def get_formatted_ds_metrics(path: str, dataset: str, sample_sizes: List[str]) -> Tuple[str, List[str]]: 30 | split_metrics = defaultdict(list) 31 | 32 | for sample_size in sample_sizes: 33 | result_jsons = sorted(glob(os.path.join(path, dataset, f"train-{sample_size}-*", "results.json"))) 34 | for result_json in result_jsons: 35 | with open(result_json) as f: 36 | result_dict = json.load(f) 37 | 38 | metric_name = result_dict.get("measure", "N/A") 39 | split_metrics[sample_size].append(result_dict["score"]) 40 | 41 | return metric_name, split_metrics 42 | 43 | 44 | def plot_summary_comparison(paths: List[str]) -> None: 45 | """Given a list of paths to output directories produced by e.g. `scripts/setfit/run_fewshot.py`, 46 | produce and save boxplots that compare the various results. 47 | 48 | The plots are saved to scripts/images/v_{id}/{dataset}.png, i.e. one plot per dataset. 
49 | 50 | Args: 51 | paths (List[str]): List of paths to output directories, generally 52 | `scripts/{method_name}/results/{model_name}` 53 | """ 54 | 55 | # Parse the result paths 56 | dataset_to_df = defaultdict(pd.DataFrame) 57 | dataset_to_metric = {} 58 | for path_index, path in enumerate(paths): 59 | ds_to_metric, this_dataset_to_df = get_summary_df(path) 60 | for dataset, df in this_dataset_to_df.items(): 61 | df["path_index"] = path_index 62 | dataset_to_df[dataset] = pd.concat((dataset_to_df[dataset], df)) 63 | dataset_to_metric = dataset_to_metric | ds_to_metric 64 | 65 | # Prepare folder for storing figures 66 | image_dir = Path("scripts") / "images" 67 | image_dir.mkdir(exist_ok=True) 68 | new_version = ( 69 | max([int(path.name[2:]) for path in image_dir.glob("v_*/") if path.name[2:].isdigit()], default=0) + 1 70 | ) 71 | output_dir = image_dir / f"v_{new_version}" 72 | output_dir.mkdir() 73 | 74 | # Save a copy the executed command in output directory 75 | (output_dir / "command.txt").write_text("python " + " ".join(sys.argv)) 76 | 77 | # Create the plots per each dataset 78 | for dataset, df in dataset_to_df.items(): 79 | columns = [column for column in df.columns if not column.startswith("path")] 80 | fig, axes = plt.subplots(ncols=len(columns), sharey=True) 81 | for column_index, column in enumerate(columns): 82 | ax = axes[column_index] if len(columns) > 1 else axes 83 | 84 | # Set the y label only for the first column 85 | if column_index == 0: 86 | ax.set_ylabel(dataset_to_metric[dataset]) 87 | 88 | # Set positions to 0, 0.25, ..., one position per boxplot 89 | # This places the boxplots closer together 90 | n_boxplots = len(df["path_index"].unique()) 91 | allotted_box_width = 0.2 92 | positions = [allotted_box_width * i for i in range(n_boxplots)] 93 | ax.set_xlim(-allotted_box_width * 0.75, allotted_box_width * (n_boxplots - 0.25)) 94 | 95 | df[[column, "path_index"]].groupby("path_index", sort=True).boxplot( 96 | subplots=False, ax=ax, column=column, positions=positions 97 | ) 98 | 99 | k_shot = column.split("-")[-1] 100 | ax.set_xlabel(f"{k_shot}-shot") 101 | if n_boxplots > 1: 102 | # If there are multiple boxplots, override the labels at the bottom generated by pandas 103 | if n_boxplots <= 26: 104 | ax.set_xticklabels(string.ascii_uppercase[:n_boxplots]) 105 | else: 106 | ax.set_xticklabels(range(n_boxplots)) 107 | else: 108 | # Otherwise, just remove the xticks 109 | ax.tick_params(labelbottom=False) 110 | 111 | if n_boxplots > 1: 112 | fig.suptitle( 113 | f"Comparison between various baselines on the {dataset}\ndataset under various $K$-shot conditions" 114 | ) 115 | else: 116 | fig.suptitle(f"Results on the {dataset} dataset under various $K$-shot conditions") 117 | fig.tight_layout() 118 | plt.savefig(str(output_dir / dataset)) 119 | 120 | 121 | def get_summary_df(path: str) -> None: 122 | """Given per-split results, return a mapping from dataset to metrics (e.g. 
"accuracy") and 123 | a mapping from dataset to pandas DataFrame that stores the results 124 | 125 | Args: 126 | path: path to per-split results: generally `scripts/{method_name}/results/{model_name}`, 127 | """ 128 | 129 | sample_sizes = get_sample_sizes(path) 130 | header_row = ["dataset", "measure"] 131 | for sample_size in sample_sizes: 132 | header_row.append(f"{sample_size}_avg") 133 | header_row.append(f"{sample_size}_std") 134 | 135 | dataset_to_metric = {} 136 | dataset_to_df = {} 137 | for dataset in next(os.walk(path))[1]: 138 | metric_name, split_metrics = get_formatted_ds_metrics(path, dataset, sample_sizes) 139 | dataset_df = pd.DataFrame(split_metrics.values(), index=[f"{dataset}-{key}" for key in split_metrics]).T 140 | dataset_to_metric[dataset] = metric_name 141 | dataset_to_df[dataset] = dataset_df 142 | return dataset_to_metric, dataset_to_df 143 | 144 | 145 | def main() -> None: 146 | parser = argparse.ArgumentParser() 147 | 148 | parser.add_argument("--paths", nargs="+", type=str) 149 | args = parser.parse_args() 150 | 151 | if args.paths: 152 | plot_summary_comparison(args.paths) 153 | else: 154 | raise Exception("Please provide at least one path via the `--paths` CLI argument.") 155 | 156 | 157 | if __name__ == "__main__": 158 | main() 159 | -------------------------------------------------------------------------------- /scripts/setfit/README.md: -------------------------------------------------------------------------------- 1 | # Running SetFit 2 | 3 | ## Setup 4 | 5 | To run the scripts, first create a Python virtual environment, e.g. with `conda`: 6 | 7 | ``` 8 | conda create -n baselines-setfit python=3.9 && conda activate baselines-setfit 9 | ``` 10 | 11 | Next, install the required dependencies: 12 | 13 | ``` 14 | python -m pip install setfit 15 | ``` 16 | 17 | ## Usage 18 | 19 | To train and evaluate `SetFit` on 8 examples (per class) on the `sst2` dataset, run: 20 | 21 | ``` 22 | python run_fewshot.py --sample_sizes=8 --datasets=sst2 23 | ``` 24 | 25 | This will use the default settings used in the paper, including `paraphrase-mpnet-base-v2` as the backbone model. Results will be saved in the `results` directory. To run `SetFit` across all the development datasets used in the paper, run: 26 | 27 | ``` 28 | python run_fewshot.py --sample_sizes=8 --is_dev_set=true 29 | ``` 30 | 31 | Similarly, you can run `SetFit` over all the test datasets in the paper by running: 32 | 33 | ``` 34 | python run_fewshot.py --sample_sizes=8 --is_test_set=true 35 | ``` 36 | 37 | ### Exhaustive example 38 | 39 | The following is an example with all argument options and their default values. 
40 | Note that you can run on a series of datasets and sample sizes: 41 | 42 | ``` 43 | python run_fewshot.py \ 44 | --model paraphrase-mpnet-base-v2 \ 45 | --datasets sst2 ag_news bbc-news \ 46 | --sample_sizes 8 64 \ 47 | --num_epochs 1 \ 48 | --num_iterations 20 \ 49 | --batch_size 16 \ 50 | --max_seq_length 256 \ 51 | --classifier logistic_regression \ 52 | --loss CosineSimilarityLoss \ 53 | --exp_name "" \ 54 | --add_normalization_layer \ 55 | ``` 56 | 57 | ### Multilingual experiments 58 | 59 | We provide three different ways to run `SetFit` in multilingual settings: 60 | 61 | * `each`: train on data in target language 62 | * `en`: train on English data only 63 | * `all`: train on data in all languages 64 | 65 | To train `SetFit` in one of these settings, run: 66 | 67 | ``` 68 | python run_fewshot_multilingual.py \ 69 | --model sentence-transformers/paraphrase-multilingual-mpnet-base-v2 \ 70 | --datasets amazon_reviews_multi_de amazon_reviews_multi_es \ 71 | --sample_sizes 8 \ 72 | --multilinguality=each 73 | ``` 74 | 75 | To train `SetFit` on all the multilingual test sets in the paper, run: 76 | 77 | ``` 78 | python run_fewshot_multilingual.py \ 79 | --model=sentence-transformers/paraphrase-multilingual-mpnet-base-v2 \ 80 | --multilinguality=each 81 | ``` 82 | 83 | ### Multilabel experiments 84 | 85 | To run `SetFit` on one of our multilabel datasets, run: 86 | 87 | ``` 88 | python run_fewshot_multilabel.py \ 89 | --sample_sizes=8 64 \ 90 | --datasets=go_emotions 91 | ``` 92 | 93 | # Zero-shot Text Classification with SetFit 94 | Although `SetFit` was designed for few-shot learning, the method can also be applied in scenarios where no labeled data is available. The main trick is to create synthetic examples that resemble the classification task, and then train a `SetFit` model on them. 95 | 96 | Remarkably, this simple technique typically outperforms the zero-shot pipeline in 🤗 Transformers, and can generate predictions 5x (or more) faster! 97 | 98 | To create the synthetic training examples, the label names for the task are required. 99 | The labels can be taken from a `--reference_dataset` or supplied explicitly using `--candidate_labels`. 100 | If neither is supplied, `--eval_dataset` is used as the reference dataset. 101 | 102 | To evaluate zero-shot `SetFit` on the `emotion` dataset, run: 103 | 104 | ``` 105 | python run_zeroshot.py --eval_dataset=SetFit/emotion 106 | ``` 107 | 108 | 109 | To evaluate on a custom dataset with custom label names: 110 | 111 | ``` 112 | python run_zeroshot.py --eval_dataset=[dataset_name] --candidate_labels [label_1 label_2 ...] 
113 | ``` 114 | -------------------------------------------------------------------------------- /scripts/setfit/distillation_baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from datasets import Dataset 5 | from evaluate import load 6 | from transformers import ( 7 | AutoModelForSequenceClassification, 8 | AutoTokenizer, 9 | DataCollatorWithPadding, 10 | Trainer, 11 | TrainingArguments, 12 | ) 13 | 14 | 15 | class RegressionTrainer(Trainer): 16 | def compute_loss(self, model, inputs, return_outputs=False): 17 | labels = inputs.pop("labels") 18 | outputs = model(**inputs) 19 | logits = outputs.logits 20 | loss = F.mse_loss(logits, labels) 21 | return (loss, outputs) if return_outputs else loss 22 | 23 | 24 | class BaselineDistillation: 25 | def __init__(self, student_model_name, num_epochs, batch_size) -> None: 26 | self.student_model_name = student_model_name 27 | self.num_epochs = num_epochs 28 | self.batch_size = batch_size 29 | self.tokenizer = AutoTokenizer.from_pretrained(student_model_name) 30 | self.seq_len = 64 31 | self.learning_rate = 6e-5 32 | 33 | def update_metric(self, metric): 34 | self.metric = load(metric) 35 | self.metric_name = metric 36 | 37 | def bl_student_preprocess(self, examples): 38 | label = examples["score"] 39 | examples = self.tokenizer( 40 | examples["text"], 41 | truncation=True, 42 | padding="max_length", 43 | max_length=self.seq_len, 44 | ) 45 | # Change this to real number 46 | examples["label"] = [float(i) for i in label] 47 | return examples 48 | 49 | def compute_metrics_for_regression(self, eval_pred): 50 | logits, labels = eval_pred 51 | predictions = np.argmax(logits, axis=-1) 52 | hot_labels = np.argmax(labels, axis=-1) 53 | return self.metric.compute(predictions=predictions, references=hot_labels) 54 | 55 | # ----------------------------------------------------------------# 56 | # ------------------------ Student training ----------------------# 57 | # ----------------------------------------------------------------# 58 | def standard_model_distillation(self, train_raw_student, x_test, y_test, num_classes): 59 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 60 | 61 | value2hot = {} 62 | for i in range(num_classes): 63 | a = [0] * num_classes 64 | a[i] = 1 65 | value2hot.update({i: a}) 66 | 67 | test_dict = {"text": x_test, "score": [value2hot[i] for i in y_test]} 68 | raw_test_ds = Dataset.from_dict(test_dict) 69 | 70 | # validation and test sets are the same 71 | ds = { 72 | "train": train_raw_student, 73 | "validation": raw_test_ds, 74 | "test": raw_test_ds, 75 | } 76 | for split in ds: 77 | ds[split] = ds[split].map(self.bl_student_preprocess, remove_columns=["text", "score"]) 78 | 79 | training_args = TrainingArguments( 80 | output_dir="baseline_distil_model", 81 | learning_rate=self.learning_rate, 82 | per_device_train_batch_size=self.batch_size, 83 | per_device_eval_batch_size=self.batch_size, 84 | num_train_epochs=self.num_epochs, 85 | eval_strategy="no", 86 | save_strategy="no", 87 | load_best_model_at_end=False, 88 | weight_decay=0.01, 89 | push_to_hub=False, 90 | ) 91 | 92 | # define data_collator 93 | data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer) 94 | 95 | # define student model 96 | student_model = AutoModelForSequenceClassification.from_pretrained( 97 | self.student_model_name, num_labels=num_classes 98 | ).to(device) 99 | 100 | trainer = RegressionTrainer( 101 | 
student_model, 102 | args=training_args, 103 | train_dataset=ds["train"], 104 | eval_dataset=ds["validation"], 105 | data_collator=data_collator, 106 | tokenizer=self.tokenizer, 107 | compute_metrics=self.compute_metrics_for_regression, 108 | ) 109 | 110 | trainer.train() 111 | 112 | trainer.eval_dataset = ds["test"] 113 | # acc = round(trainer.evaluate()["eval_accuracy"], 3) 114 | 115 | score = trainer.evaluate()[f"eval_{self.metric_name}"] 116 | return {self.metric_name: score} 117 | -------------------------------------------------------------------------------- /scripts/tfew/README.md: -------------------------------------------------------------------------------- 1 | # Running T-Few 2 | 3 | These scripts run the baselines based on the `T-Few` paper: [_Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning_](https://arxiv.org/abs/2205.05638). 4 | 5 | ## Setup 6 | 7 | To run the scripts, first create a Python virtual environment, e.g. with `conda`: 8 | 9 | ``` 10 | conda create -n baselines-tfew python=3.10 && conda activate baselines-tfew 11 | ``` 12 | 13 | Next, clone our `T-Few` fork, and install the required dependencies: 14 | 15 | ``` 16 | cd scripts/tfew 17 | git clone https://github.com/SetFit/t-few.git 18 | python -m pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 19 | ``` 20 | 21 | The steps above only need to be done once. In addition, every time you start a new session, you will need to run: 22 | ``` 23 | cd scripts/tfew 24 | . t-few/bin/start.sh 25 | ``` 26 | This sets up some required environment variables, including `PYTHONPATH`, `OUTPUT_PATH` (where results will be saved) and `CONFIG_PATH` (where the config `.json` files are stored). 27 | It also sets `CUDA_VISIBLE_DEVICES=0`. To use a different GPU, edit the file `t-few/bin/start.sh`. 28 | 29 | ## Usage example 30 | 31 | To train and evaluate `T-Few` (3B) on 8 examples (per class) on the `sst2` dataset, run: 32 | 33 | ``` 34 | python -m t-few.src.pl_train \ 35 | -c t03b.json+ia3.json+emotion.json \ 36 | -k load_weight="t-few/pretrained_checkpoints/t03b_ia3_finish.pt" \ 37 | exp_name=tfew_03b_pretrained/emotion/train-8 \ 38 | num_shot=8 \ 39 | batch_size=1 \ 40 | eval_batch_size=2 \ 41 | grad_accum_factor=8 \ 42 | ``` 43 | 44 | This will fine-tune the 3 billion parameter pretrained model using the (IA)^3 method from the `T-Few` paper, and then run the evaluation. For all our baselines, we use the default settings from the `T-Few` paper. `T-Few` comes in 2 versions: one with a 3 billion (3B) base model, and one with an 11 billion (11B) base model. 45 | 46 | You can run `T-Few` (3B) over all the supported test datasets in the `SetFit` paper by running: 47 | 48 | ``` 49 | ./run_tfew_test.sh 50 | ``` 51 | 52 | Similarly, to run `T-Few` (11B) over all test datasets, run: 53 | 54 | ``` 55 | ./run_tfew_11b.sh 56 | ``` 57 | 58 | Results will be saved to the `scripts/tfew/results` directory. 59 | The results are comprised of 10 directories, one for each training split. 60 | Each of these directories contains 5 results, one for each randomly selected training prompt. 61 | To retrieve the median score across all prompts (for each split), run the following on each dataset: 62 | 63 | ``` 64 | python scripts/create_summary_table.py --path scripts/tfew/results/{experiment_name} 65 | ``` 66 | 67 | The summary table will be saved in `results/{experiment_name}`. 
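For orientation, the results directory produced by the commands above is laid out roughly as follows (using `t03b_pretrained` and the `emotion` dataset as examples); the per-split `results.json` (median score and IQR over the five prompt seeds) and the final `summary_table.csv` are written by `scripts/create_summary_table.py`:

```
scripts/tfew/results/t03b_pretrained/
├── emotion/
│   ├── train-8-0/
│   │   ├── seed0/dev_scores.json
│   │   ├── ...
│   │   ├── seed4/dev_scores.json
│   │   └── results.json
│   ├── train-8-1/
│   └── ...
├── sst5/
│   └── ...
└── summary_table.csv
```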
68 | -------------------------------------------------------------------------------- /scripts/tfew/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.0+cu113 2 | datasets==2.0.0 3 | transformers==4.15.0 4 | pytorch-lightning==1.5.8 5 | torchmetrics==0.6.2 6 | psutil==5.9.0 7 | deepspeed==0.5.10 8 | sentencepiece==0.1.96 9 | scipy 10 | ipdb 11 | evaluate==0.1.2 12 | scikit-learn 13 | git+https://github.com/SetFit/promptsource.git -------------------------------------------------------------------------------- /scripts/tfew/run_tfew_11b.sh: -------------------------------------------------------------------------------- 1 | for dataset in amazon_counterfactual_en emotion enron_spam SentEval-CR sst5 2 | do 3 | for sample_size in 8 64 4 | do 5 | for train_split in 0 1 2 3 4 5 6 7 8 9 6 | do 7 | for seed in 0 1 2 3 4 8 | do 9 | 10 | python -m src.pl_train -c t011b.json+ia3.json+${dataset}.json \ 11 | -k load_weight="t-few/pretrained_checkpoints/t011b_ia3_finish.pt" \ 12 | exp_name=t011b_pretrained/${dataset}/train-${sample_size}-${train_split}/seed${seed} \ 13 | train_split=${train_split} \ 14 | few_shot_random_seed=${seed} \ 15 | seed=${seed} \ 16 | num_shot=$sample_size \ 17 | batch_size=1 \ 18 | eval_batch_size=2 \ 19 | grad_accum_factor=8 \ 20 | eval_before_training=0 21 | done 22 | done 23 | done 24 | done 25 | -------------------------------------------------------------------------------- /scripts/tfew/run_tfew_test.sh: -------------------------------------------------------------------------------- 1 | for dataset in amazon_counterfactual_en emotion enron_spam SentEval-CR sst5 2 | do 3 | for sample_size in 8 64 4 | do 5 | for train_split in 0 1 2 3 4 5 6 7 8 9 6 | do 7 | for seed in 0 1 2 3 4 8 | do 9 | python -m src.pl_train -c t03b.json+ia3.json+${dataset}.json \ 10 | -k load_weight="t-few/pretrained_checkpoints/t03b_ia3_finish.pt" \ 11 | exp_name=t03b_pretrained/${dataset}/train-${sample_size}-${train_split}/seed${seed} \ 12 | train_split=${train_split} \ 13 | few_shot_random_seed=${seed} \ 14 | seed=${seed} \ 15 | num_shot=$sample_size \ 16 | batch_size=8 \ 17 | eval_batch_size=16 \ 18 | grad_accum_factor=1 \ 19 | eval_before_training=0 20 | done 21 | done 22 | done 23 | done 24 | -------------------------------------------------------------------------------- /scripts/transformers/README.md: -------------------------------------------------------------------------------- 1 | # Transformers Baselines 2 | 3 | This folder contains the scripts used to train the 🤗 Transformers baselines quoted in the SetFit paper: [_Efficient Few-Shot Learning Without Prompts_](https://arxiv.org/abs/2209.11055). 4 | 5 | ## Setup 6 | 7 | To run the scripts, first create a Python virtual environment, e.g. 
with `conda`: 8 | 9 | ``` 10 | conda create -n baselines-transformers python=3.9 && conda activate baselines-transformers 11 | ``` 12 | 13 | Next, install the required dependencies 14 | 15 | ``` 16 | python -m pip install setfit 17 | python -m pip install -r requirements.txt 18 | ``` 19 | 20 | ## Usage 21 | 22 | ### Fewshot finetuning 23 | 24 | To finetune a pretrained model on a single dataset under the SetFit organization, run: 25 | 26 | ``` 27 | python run_fewshot.py train-single-dataset \ 28 | --model-id=distilbert-base-uncased \ 29 | --dataset-id=sst2 \ 30 | --metric=accuracy \ 31 | --learning-rate=2e-5 \ 32 | --batch-size=4 33 | ``` 34 | 35 | To finetune a pretrained model on all the test datasets used in SetFit, run: 36 | 37 | ``` 38 | python run_fewshot.py train-all-datasets --model-ckpt=distilbert-base-uncased --batch-size=4 39 | ``` 40 | 41 | ### Full finetuning 42 | 43 | To finetune a pretrained model on a single dataset under the SetFit organization, run: 44 | 45 | ``` 46 | python run_full.py train-single-dataset \ 47 | --model-id=distilbert-base-uncased \ 48 | --dataset-id=sst2 \ 49 | --metric=accuracy \ 50 | --learning-rate=2e-5 \ 51 | --batch-size=24 52 | ``` 53 | 54 | To finetune a pretrained model on all the test datasets used in SetFit, run: 55 | 56 | ``` 57 | python run_full.py train-all-datasets --model-id=distilbert-base-uncased --batch-size=24 58 | ``` 59 | 60 | ### Multilingual finetuning 61 | 62 | We provide three different ways to run SetFit in multilingual settings: 63 | 64 | * `each`: train on data in target language 65 | * `en`: train on English data only 66 | * `all`: train on data in all languages 67 | 68 | To finetune a baseline in one of these setting, run: 69 | 70 | ``` 71 | python run_fewshot_multilingual.py train-single-dataset \ 72 | --model-id=xlm-roberta-base \ 73 | --dataset-id=amazon_reviews_multi_en \ 74 | --metric=mae \ 75 | --learning-rate=2e-5 \ 76 | --batch-size=4 \ 77 | --multilinguality=each 78 | ``` 79 | 80 | To finetune a baseline on all the multilingual test sets in the paper, run: 81 | 82 | ``` 83 | python run_fewshot_multilingual.py train-all-datasets \ 84 | --model=xlm-roberta-base \ 85 | --learning-rate=2e-5 \ 86 | --batch-size=4 \ 87 | --multilinguality=each 88 | ``` 89 | 90 | ### Inference benchmark 91 | 92 | To run the inference benchmark, run: 93 | 94 | ``` 95 | python run_inference.py --model-id=distilbert-base-uncased__sst2__train-16-4 --dataset-id=sst2 --num-samples=100 96 | ``` 97 | -------------------------------------------------------------------------------- /scripts/transformers/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers[sentencepiece,optuna]==4.20.0 2 | torch==1.11 3 | scikit-learn 4 | typer 5 | -------------------------------------------------------------------------------- /scripts/transformers/run_full.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from pathlib import Path 3 | 4 | import torch 5 | import typer 6 | from datasets import load_dataset 7 | from evaluate import load 8 | from transformers import ( 9 | AutoModelForSequenceClassification, 10 | AutoTokenizer, 11 | EarlyStoppingCallback, 12 | Trainer, 13 | TrainingArguments, 14 | ) 15 | 16 | from setfit.utils import DEV_DATASET_TO_METRIC, TEST_DATASET_TO_METRIC 17 | from utils import get_label_mappings, save_metrics 18 | 19 | 20 | app = typer.Typer() 21 | 22 | 23 | RESULTS_PATH = Path("results") 24 | RESULTS_PATH.mkdir(parents=True, 
exist_ok=True) 25 | 26 | 27 | @app.command() 28 | def train_single_dataset( 29 | model_id: str = "distilbert-base-uncased", 30 | dataset_id: str = "sst2", 31 | metric: str = "accuracy", 32 | learning_rate: float = 2e-5, 33 | batch_size: int = 4, 34 | num_train_epochs: int = 20, 35 | push_to_hub: bool = False, 36 | ): 37 | """Fine-tunes a pretrained checkpoint on the fewshot training sets""" 38 | # Load dataset 39 | dataset = load_dataset(f"SetFit/{dataset_id}") 40 | model_name = model_id.split("/")[-1] 41 | 42 | # Create metrics directory 43 | metrics_dir = RESULTS_PATH / Path(f"{model_name}-lr-{learning_rate}/{dataset_id}") 44 | metrics_dir.mkdir(parents=True, exist_ok=True) 45 | # Create split directory 46 | metrics_split_dir = metrics_dir / "train-full" 47 | metrics_split_dir.mkdir(parents=True, exist_ok=True) 48 | metrics_filepath = metrics_split_dir / "results.json" 49 | # Skip previously evaluated model 50 | if metrics_filepath.is_file(): 51 | typer.echo("INFO -- model already trained, skipping ...") 52 | return 53 | 54 | # Load tokenizer and preprocess 55 | tokenizer = AutoTokenizer.from_pretrained(model_id) 56 | 57 | def tokenize_dataset(example): 58 | return tokenizer(example["text"], truncation=True, max_length=512) 59 | 60 | tokenized_dataset = dataset.map(tokenize_dataset, batched=True) 61 | # Create training and validation splits 62 | train_eval_dataset = tokenized_dataset["train"].train_test_split(seed=42, test_size=0.2) 63 | # Load model - we use a `model_init()` function here to load a fresh model with each fewshot training run 64 | num_labels, label2id, id2label = get_label_mappings(dataset["train"]) 65 | 66 | def model_init(): 67 | return AutoModelForSequenceClassification.from_pretrained( 68 | model_id, num_labels=num_labels, id2label=id2label, label2id=label2id 69 | ) 70 | 71 | # Define metrics 72 | metric_fn = load(metric) 73 | 74 | def compute_metrics(pred): 75 | labels = pred.label_ids 76 | preds = pred.predictions.argmax(-1) 77 | return metric_fn.compute(predictions=preds, references=labels) 78 | 79 | # Define hyperparameters 80 | training_args = TrainingArguments( 81 | output_dir="checkpoints/full/", 82 | overwrite_output_dir=True, 83 | num_train_epochs=num_train_epochs, 84 | learning_rate=learning_rate, 85 | per_device_train_batch_size=batch_size, 86 | per_device_eval_batch_size=batch_size, 87 | weight_decay=0.001, 88 | eval_strategy="epoch", 89 | logging_steps=100, 90 | metric_for_best_model=metric, 91 | load_best_model_at_end=True, 92 | save_strategy="epoch", 93 | save_total_limit=1, 94 | fp16=True, 95 | report_to="none", 96 | ) 97 | 98 | if push_to_hub: 99 | ckpt_name = f"{model_name}-finetuned-{dataset_id}-train-full" 100 | training_args.push_to_hub = True 101 | training_args.hub_strategy = ("end",) 102 | training_args.hub_model_id = f"SetFit/{ckpt_name}" 103 | 104 | callbacks = [EarlyStoppingCallback(early_stopping_patience=3)] 105 | 106 | trainer = Trainer( 107 | model_init=model_init, 108 | args=training_args, 109 | compute_metrics=compute_metrics, 110 | train_dataset=train_eval_dataset["train"], 111 | eval_dataset=train_eval_dataset["test"], 112 | tokenizer=tokenizer, 113 | callbacks=callbacks, 114 | ) 115 | trainer.train() 116 | 117 | # Compute final metrics on full test set 118 | metrics = trainer.evaluate(tokenized_dataset["test"]) 119 | eval_metrics = {} 120 | eval_metrics["score"] = metrics[f"eval_{metric}"] * 100.0 121 | eval_metrics["measure"] = metric 122 | 123 | # Save metrics 124 | save_metrics(eval_metrics, metrics_filepath) 125 | 126 | if 
push_to_hub: 127 | trainer.push_to_hub("Checkpoint upload", blocking=False) 128 | 129 | # Flush CUDA cache 130 | del trainer 131 | gc.collect() 132 | torch.cuda.empty_cache() 133 | 134 | 135 | @app.command() 136 | def train_all_datasets( 137 | model_id: str = "distilbert-base-uncased", 138 | learning_rate: float = 2e-5, 139 | batch_size: int = 4, 140 | num_train_epochs: int = 20, 141 | push_to_hub: bool = False, 142 | is_dev_set: bool = False, 143 | ): 144 | """Fine-tunes a pretrained checkpoint on all of the SetFit development/test datasets.""" 145 | if is_dev_set: 146 | DATASET_TO_METRIC = DEV_DATASET_TO_METRIC 147 | else: 148 | DATASET_TO_METRIC = TEST_DATASET_TO_METRIC 149 | 150 | for dataset_id, metric in DATASET_TO_METRIC.items(): 151 | typer.echo(f"🏋️🏋️🏋️ Fine-tuning on dataset {dataset_id} 🏋️🏋️🏋️") 152 | train_single_dataset( 153 | model_id=model_id, 154 | dataset_id=dataset_id, 155 | metric=metric, 156 | learning_rate=learning_rate, 157 | batch_size=batch_size, 158 | num_train_epochs=num_train_epochs, 159 | push_to_hub=push_to_hub, 160 | ) 161 | typer.echo("Training complete!") 162 | 163 | 164 | if __name__ == "__main__": 165 | app() 166 | -------------------------------------------------------------------------------- /scripts/transformers/run_inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from time import perf_counter 3 | 4 | import numpy as np 5 | import torch 6 | import typer 7 | from datasets import Dataset, load_dataset 8 | from transformers import pipeline 9 | 10 | 11 | RESULTS_PATH = Path("results") 12 | RESULTS_PATH.mkdir(parents=True, exist_ok=True) 13 | 14 | device_id = 0 if torch.cuda.is_available() else -1 15 | 16 | 17 | def time_pipeline(pipe: pipeline, dataset: Dataset): 18 | latencies = [] 19 | # Warm up 20 | for _ in range(10): 21 | _ = pipe("Warming up the pipeline :)") 22 | # Timed run 23 | total_start_time = perf_counter() 24 | for row in dataset: 25 | start_time = perf_counter() 26 | _ = pipe(row["text"]) 27 | latency = perf_counter() - start_time 28 | latencies.append(latency) 29 | total_time_ms = (perf_counter() - total_start_time) * 1_000 30 | # Compute run statistics 31 | time_avg_ms = 1_000 * np.mean(latencies) 32 | time_std_ms = 1_000 * np.std(latencies) 33 | time_p95_ms = 1_000 * np.percentile(latencies, 95) 34 | print( 35 | f"P95 latency (ms) - {time_p95_ms}; Average latency (ms) - {time_avg_ms:.2f} +\- {time_std_ms:.2f};", # noqa 36 | time_p95_ms, 37 | f"Total time (ms) - {total_time_ms:.2f}", 38 | ) 39 | 40 | 41 | def main( 42 | model_id: str = "distilbert-base-uncased__sst2__train-16-4", dataset_id: str = "sst2", num_samples: int = None 43 | ): 44 | # Load dataset 45 | dataset = load_dataset(f"SetFit/{dataset_id}", split="test") 46 | if num_samples is not None: 47 | dataset = dataset.shuffle(seed=42).select(range(num_samples)) 48 | # Load pipeline 49 | pipe = pipeline("text-classification", model=f"SetFit/{model_id}", device=device_id) 50 | # Time it! 
51 | time_pipeline(pipe, dataset) 52 | 53 | 54 | if __name__ == "__main__": 55 | typer.run(main) 56 | -------------------------------------------------------------------------------- /scripts/transformers/run_zeroshot.py: -------------------------------------------------------------------------------- 1 | # Add zeroshot pipeline script here 2 | -------------------------------------------------------------------------------- /scripts/transformers/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Tuple 3 | 4 | from datasets import Dataset 5 | 6 | 7 | def get_label_mappings(dataset: Dataset) -> Tuple[int, dict, dict]: 8 | """Returns the label mappings of the dataset.""" 9 | label_ids = dataset.unique("label") 10 | label_names = dataset.unique("label_text") 11 | label2id = {label: idx for label, idx in zip(label_names, label_ids)} 12 | id2label = {idx: label for label, idx in label2id.items()} 13 | num_labels = len(label_ids) 14 | return num_labels, label2id, id2label 15 | 16 | 17 | def save_metrics(metrics: dict, metrics_filepath): 18 | with open(metrics_filepath, "w") as f: 19 | json.dump(metrics, f) 20 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | multi_line_output = 3 3 | include_trailing_comma = True 4 | force_grid_wrap = 0 5 | use_parentheses = True 6 | ensure_newline_before_comments = True 7 | line_length = 119 8 | lines_after_imports = 2 9 | 10 | [flake8] 11 | ignore = E203, E501, W503 12 | max-line-length = 119 13 | per-file-ignores = 14 | # imported but unused 15 | __init__.py: F401 16 | exclude = 17 | results 18 | scripts/adapet 19 | scripts/tfew 20 | 21 | [tool:pytest] 22 | testpaths = tests 23 | addopts = --cov=setfit --durations=10 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | from pathlib import Path 3 | 4 | from setuptools import find_packages, setup 5 | 6 | 7 | README_TEXT = (Path(__file__).parent / "README.md").read_text(encoding="utf-8") 8 | 9 | MAINTAINER = "Lewis Tunstall, Tom Aarsen" 10 | MAINTAINER_EMAIL = "lewis@huggingface.co" 11 | 12 | INTEGRATIONS_REQUIRE = ["optuna"] 13 | REQUIRED_PKGS = [ 14 | "datasets>=2.15.0", 15 | "sentence-transformers[train]>=3", 16 | "transformers>=4.41.0", 17 | "evaluate>=0.3.0", 18 | "huggingface_hub>=0.24.0", 19 | "scikit-learn", 20 | "packaging", 21 | ] 22 | ABSA_REQUIRE = ["spacy<3.7.6"] 23 | QUALITY_REQUIRE = ["black", "flake8", "isort", "tabulate"] 24 | ONNX_REQUIRE = ["onnxruntime", "onnx!=1.16.2", "skl2onnx"] 25 | OPENVINO_REQUIRE = ["hummingbird-ml", "openvino"] 26 | TESTS_REQUIRE = ["pytest", "pytest-cov"] + ONNX_REQUIRE + OPENVINO_REQUIRE + ABSA_REQUIRE 27 | DOCS_REQUIRE = ["hf-doc-builder>=0.3.0"] 28 | CODECARBON_REQUIRE = ["codecarbon<2.6.0"] 29 | # 2.7.* fails with AttributeError: 'EmissionsTracker' object has no attribute '_cloud' 30 | # 2.6.* has an accidental print statement spamming the terminal 31 | EXTRAS_REQUIRE = { 32 | "optuna": INTEGRATIONS_REQUIRE, 33 | "quality": QUALITY_REQUIRE, 34 | "tests": TESTS_REQUIRE, 35 | "onnx": ONNX_REQUIRE, 36 | "openvino": ONNX_REQUIRE + OPENVINO_REQUIRE, 37 | "docs": DOCS_REQUIRE, 38 | "absa": ABSA_REQUIRE, 39 | "codecarbon": CODECARBON_REQUIRE, 40 | } 41 | 42 | 43 | def combine_requirements(base_keys): 44 | 
return list(set(k for v in base_keys for k in EXTRAS_REQUIRE[v])) 45 | 46 | 47 | EXTRAS_REQUIRE["dev"] = combine_requirements([k for k in EXTRAS_REQUIRE]) 48 | # For the combatibility tests we add pandas<2, as pandas 2.0.0 onwards is incompatible with old datasets versions, 49 | # and we assume few to no users would use old datasets versions with new pandas versions. 50 | # The only alternative is incrementing the minimum version for datasets, which seems unnecessary. 51 | # Beyond that, fsspec is set to <2023.12.0 as that version is incompatible with datasets<=2.15.0 52 | EXTRAS_REQUIRE["compat_tests"] = ( 53 | [requirement.replace(">=", "==") for requirement in REQUIRED_PKGS] 54 | + TESTS_REQUIRE 55 | + ["pandas<2", "fsspec<2023.12.0"] 56 | ) 57 | 58 | setup( 59 | name="setfit", 60 | version="1.2.0.dev0", 61 | description="Efficient few-shot learning with Sentence Transformers", 62 | long_description=README_TEXT, 63 | long_description_content_type="text/markdown", 64 | maintainer=MAINTAINER, 65 | maintainer_email=MAINTAINER_EMAIL, 66 | url="https://github.com/huggingface/setfit", 67 | download_url="https://github.com/huggingface/setfit/tags", 68 | license="Apache 2.0", 69 | package_dir={"": "src"}, 70 | packages=find_packages("src"), 71 | include_package_data=True, 72 | install_requires=REQUIRED_PKGS, 73 | extras_require=EXTRAS_REQUIRE, 74 | classifiers=[ 75 | "Development Status :: 5 - Production/Stable", 76 | "Intended Audience :: Developers", 77 | "Intended Audience :: Education", 78 | "Intended Audience :: Science/Research", 79 | "License :: OSI Approved :: Apache Software License", 80 | "Operating System :: OS Independent", 81 | "Programming Language :: Python :: 3", 82 | "Programming Language :: Python :: 3.9", 83 | "Programming Language :: Python :: 3.10", 84 | "Programming Language :: Python :: 3.11", 85 | "Programming Language :: Python :: 3.12", 86 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 87 | ], 88 | keywords="nlp, machine learning, fewshot learning, transformers", 89 | zip_safe=False, # Required for mypy to find the py.typed file 90 | ) 91 | -------------------------------------------------------------------------------- /src/setfit/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.2.0.dev0" 2 | 3 | import importlib 4 | import os 5 | import warnings 6 | 7 | from .data import get_templated_dataset, sample_dataset 8 | from .model_card import SetFitModelCardData 9 | from .modeling import SetFitHead, SetFitModel 10 | from .span import AbsaModel, AbsaTrainer, AspectExtractor, AspectModel, PolarityModel 11 | from .trainer import SetFitTrainer, Trainer 12 | from .trainer_distillation import DistillationSetFitTrainer, DistillationTrainer 13 | from .training_args import TrainingArguments 14 | 15 | 16 | # Ensure that DeprecationWarnings are shown by default, as recommended by 17 | # https://docs.python.org/3/library/warnings.html#overriding-the-default-filter 18 | warnings.filterwarnings("default", category=DeprecationWarning) 19 | 20 | # If codecarbon is installed and the log level is not defined, 21 | # automatically overwrite the default to "error" 22 | if importlib.util.find_spec("codecarbon") and "CODECARBON_LOG_LEVEL" not in os.environ: 23 | os.environ["CODECARBON_LOG_LEVEL"] = "error" 24 | -------------------------------------------------------------------------------- /src/setfit/exporters/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/src/setfit/exporters/__init__.py -------------------------------------------------------------------------------- /src/setfit/exporters/openvino.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import openvino.runtime as ov 4 | 5 | from setfit import SetFitModel 6 | from setfit.exporters.onnx import export_onnx 7 | 8 | 9 | def export_to_openvino( 10 | model: SetFitModel, 11 | output_path: str = "model.xml", 12 | ) -> None: 13 | """Export a PyTorch backed SetFit model to OpenVINO Intermediate Representation. 14 | 15 | Args: 16 | model_body (`SentenceTransformer`): The model_body from a SetFit model body. This should be a 17 | SentenceTransformer. 18 | model_head (`torch.nn.Module` or `LogisticRegression`): The SetFit model head. This can be either a 19 | dense layer SetFitHead or a Sklearn estimator. 20 | output_path (`str`): The path where will be stored the generated OpenVINO model. At a minimum it needs to contain 21 | the name of the final file. 22 | ignore_ir_version (`bool`): Whether to ignore the IR version used in sklearn. The version is often missmatched 23 | with the transformer models. Setting this to true coerces the versions to be the same. This might 24 | cause errors but in practice works. If this is set to False you need to ensure that the IR versions 25 | align between the transformer and the sklearn onnx representation. 26 | """ 27 | 28 | # Load the model and get all of the parts. 29 | OPENVINO_SUPPORTED_OPSET = 13 30 | 31 | model.model_body.cpu() 32 | onnx_path = output_path.replace(".xml", ".onnx") 33 | 34 | export_onnx( 35 | model.model_body, 36 | model.model_head, 37 | opset=OPENVINO_SUPPORTED_OPSET, 38 | output_path=onnx_path, 39 | ignore_ir_version=True, 40 | use_hummingbird=True, 41 | ) 42 | 43 | # Save the final model. 44 | ov_model = ov.Core().read_model(onnx_path) 45 | ov.serialize(ov_model, output_path) 46 | 47 | os.remove(onnx_path) 48 | -------------------------------------------------------------------------------- /src/setfit/exporters/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: 5 | """Perform attention-aware mean pooling. 6 | 7 | This method takes in embeddings of shape (batch, sequence, embedding_size) and performs average 8 | pooling across the sequence dimension to yield embeddings of size (batch, embedding_size). 9 | 10 | From: 11 | https://github.com/UKPLab/sentence-transformers/blob/0b5ef4be93d2b21de3a918a084b48aab6ba48595/sentence_transformers/model_card_templates.py#L134 # noqa: E501 12 | 13 | Args: 14 | token_embeddings (`torch.Tensor`): The embeddings we wish to pool over of shape 15 | (batch, sequence, embedding_size). This will pool over the sequence to yield 16 | (batch, embedding_size). 17 | attention_mask (`torch.Tensor`): The binary attention mask across the embedings of shape 18 | 19 | Returns: 20 | (`torch.Tensor`) The mean pooled embeddings of size (batch, embedding_size). 
21 | """ 22 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 23 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 24 | -------------------------------------------------------------------------------- /src/setfit/integrations.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | from typing import TYPE_CHECKING 3 | 4 | from .utils import BestRun 5 | 6 | 7 | if TYPE_CHECKING: 8 | from .trainer import Trainer 9 | 10 | 11 | def is_optuna_available() -> bool: 12 | return importlib.util.find_spec("optuna") is not None 13 | 14 | 15 | def default_hp_search_backend(): 16 | if is_optuna_available(): 17 | return "optuna" 18 | 19 | 20 | def run_hp_search_optuna(trainer: "Trainer", n_trials: int, direction: str, **kwargs) -> BestRun: 21 | import optuna 22 | 23 | # Heavily inspired by transformers.integrations.run_hp_search_optuna 24 | # https://github.com/huggingface/transformers/blob/cbb8a37929c3860210f95c9ec99b8b84b8cf57a1/src/transformers/integrations.py#L160 25 | def _objective(trial): 26 | trainer.objective = None 27 | trainer.train(trial=trial) 28 | # If there hasn't been any evaluation during the training loop. 29 | if getattr(trainer, "objective", None) is None: 30 | metrics = trainer.evaluate() 31 | trainer.objective = trainer.compute_objective(metrics) 32 | return trainer.objective 33 | 34 | timeout = kwargs.pop("timeout", None) 35 | n_jobs = kwargs.pop("n_jobs", 1) 36 | study = optuna.create_study(direction=direction, **kwargs) 37 | study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs) 38 | best_trial = study.best_trial 39 | return BestRun(str(best_trial.number), best_trial.value, best_trial.params, study) 40 | -------------------------------------------------------------------------------- /src/setfit/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SupConLoss(nn.Module): 6 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 7 | 8 | It also supports the unsupervised contrastive loss in SimCLR. 9 | """ 10 | 11 | def __init__(self, model, temperature=0.07, contrast_mode="all", base_temperature=0.07): 12 | super(SupConLoss, self).__init__() 13 | self.model = model 14 | self.temperature = temperature 15 | self.contrast_mode = contrast_mode 16 | self.base_temperature = base_temperature 17 | 18 | def forward(self, sentence_features, labels=None, mask=None): 19 | """Computes loss for model. 20 | 21 | If both `labels` and `mask` are None, it degenerates to SimCLR unsupervised loss: 22 | https://arxiv.org/pdf/2002.05709.pdf 23 | 24 | Args: 25 | features: hidden vector of shape [bsz, n_views, ...]. 26 | labels: ground truth of shape [bsz]. 27 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 28 | has the same class as sample i. Can be asymmetric. 29 | 30 | Returns: 31 | A loss scalar. 
32 | """ 33 | features = self.model(sentence_features[0])["sentence_embedding"] 34 | 35 | # Normalize embeddings 36 | features = torch.nn.functional.normalize(features, p=2, dim=1) 37 | 38 | # Add n_views dimension 39 | features = torch.unsqueeze(features, 1) 40 | 41 | device = features.device 42 | 43 | if len(features.shape) < 3: 44 | raise ValueError("`features` needs to be [bsz, n_views, ...]," "at least 3 dimensions are required") 45 | if len(features.shape) > 3: 46 | features = features.view(features.shape[0], features.shape[1], -1) 47 | 48 | batch_size = features.shape[0] 49 | if labels is not None and mask is not None: 50 | raise ValueError("Cannot define both `labels` and `mask`") 51 | elif labels is None and mask is None: 52 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 53 | elif labels is not None: 54 | labels = labels.contiguous().view(-1, 1) 55 | if labels.shape[0] != batch_size: 56 | raise ValueError("Num of labels does not match num of features") 57 | mask = torch.eq(labels, labels.T).float().to(device) 58 | else: 59 | mask = mask.float().to(device) 60 | 61 | contrast_count = features.shape[1] 62 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 63 | if self.contrast_mode == "one": 64 | anchor_feature = features[:, 0] 65 | anchor_count = 1 66 | elif self.contrast_mode == "all": 67 | anchor_feature = contrast_feature 68 | anchor_count = contrast_count 69 | else: 70 | raise ValueError("Unknown mode: {}".format(self.contrast_mode)) 71 | 72 | # Compute logits 73 | anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature) 74 | # For numerical stability 75 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 76 | logits = anchor_dot_contrast - logits_max.detach() 77 | 78 | # Tile mask 79 | mask = mask.repeat(anchor_count, contrast_count) 80 | # Mask-out self-contrast cases 81 | logits_mask = torch.scatter( 82 | torch.ones_like(mask), 83 | 1, 84 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 85 | 0, 86 | ) 87 | mask = mask * logits_mask 88 | 89 | # Compute log_prob 90 | exp_logits = torch.exp(logits) * logits_mask 91 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 92 | 93 | # Compute mean of log-likelihood over positive 94 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 95 | 96 | # Loss 97 | loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos 98 | loss = loss.view(anchor_count, batch_size).mean() 99 | 100 | return loss 101 | -------------------------------------------------------------------------------- /src/setfit/notebook.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from transformers.utils.notebook import NotebookProgressCallback 4 | 5 | 6 | class SetFitNotebookProgressCallback(NotebookProgressCallback): 7 | """ 8 | A variation of NotebookProgressCallback that accepts logs/metrics other than "loss" and "eval_loss". 9 | In particular, it accepts "embedding_loss", "aspect_embedding_loss", and "polarity_embedding_loss" 10 | and the corresponding metrics for the validation set. 
11 | """ 12 | 13 | def on_log(self, *args, logs=None, **kwargs): 14 | if logs is not None: 15 | logs = {key if key != "embedding_loss" else "loss": value for key, value in logs.items()} 16 | return super().on_log(*args, logs=logs, **kwargs) 17 | 18 | def on_evaluate(self, args, state, control, metrics=None, **kwargs): 19 | if self.training_tracker is not None: 20 | values = {"Training Loss": "No log", "Validation Loss": "No log"} 21 | for log in reversed(state.log_history): 22 | if loss_logs := { 23 | key for key in log if key in ("embedding_loss", "aspect_embedding_loss", "polarity_embedding_loss") 24 | }: 25 | values["Training Loss"] = log[loss_logs.pop()] 26 | break 27 | 28 | if self.first_column == "Epoch": 29 | values["Epoch"] = int(state.epoch) 30 | else: 31 | values["Step"] = state.global_step 32 | metric_key_prefix = "eval" 33 | for k in metrics: 34 | if k.endswith("_loss"): 35 | metric_key_prefix = re.sub(r"\_loss$", "", k) 36 | _ = metrics.pop("total_flos", None) 37 | _ = metrics.pop("epoch", None) 38 | _ = metrics.pop(f"{metric_key_prefix}_runtime", None) 39 | _ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None) 40 | _ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None) 41 | _ = metrics.pop(f"{metric_key_prefix}_jit_compilation_time", None) 42 | for k, v in metrics.items(): 43 | splits = k.split("_") 44 | name = " ".join([part.capitalize() for part in splits[1:]]) 45 | if name in ("Embedding Loss", "Aspect Embedding Loss", "Polarity Embedding Loss"): 46 | # Single dataset 47 | name = "Validation Loss" 48 | values[name] = v 49 | self.training_tracker.write_line(values) 50 | self.training_tracker.remove_child() 51 | self.prediction_bar = None 52 | # Evaluation takes a long time so we should force the next update. 53 | self._force_next_update = True 54 | -------------------------------------------------------------------------------- /src/setfit/span/__init__.py: -------------------------------------------------------------------------------- 1 | from .aspect_extractor import AspectExtractor 2 | from .modeling import AbsaModel, AspectModel, PolarityModel 3 | from .trainer import AbsaTrainer 4 | -------------------------------------------------------------------------------- /src/setfit/span/aspect_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, List, Tuple 2 | 3 | 4 | if TYPE_CHECKING: 5 | from spacy.tokens import Doc 6 | 7 | 8 | class AspectExtractor: 9 | def __init__(self, spacy_model: str) -> None: 10 | super().__init__() 11 | import spacy 12 | 13 | self.nlp = spacy.load(spacy_model) 14 | 15 | def find_groups(self, aspect_mask: List[bool]): 16 | start = None 17 | for idx, flag in enumerate(aspect_mask): 18 | if flag: 19 | if start is None: 20 | start = idx 21 | else: 22 | if start is not None: 23 | yield slice(start, idx) 24 | start = None 25 | if start is not None: 26 | yield slice(start, idx + 1) 27 | 28 | def __call__(self, texts: List[str]) -> Tuple[List["Doc"], List[slice]]: 29 | aspects_list = [] 30 | docs = list(self.nlp.pipe(texts)) 31 | for doc in docs: 32 | aspect_mask = [token.pos_ in ("NOUN", "PROPN") for token in doc] 33 | aspects_list.append(list(self.find_groups(aspect_mask))) 34 | return docs, aspects_list 35 | -------------------------------------------------------------------------------- /src/setfit/utils.py: -------------------------------------------------------------------------------- 1 | import types 2 | from contextlib import contextmanager 3 | from 
dataclasses import dataclass, field 4 | from time import monotonic_ns 5 | from typing import Any, Dict, List, NamedTuple, Optional, Tuple 6 | 7 | from datasets import Dataset, DatasetDict, load_dataset 8 | from sentence_transformers import losses 9 | from transformers.utils import copy_func 10 | 11 | from .data import create_fewshot_splits, create_fewshot_splits_multilabel 12 | from .losses import SupConLoss 13 | 14 | 15 | SEC_TO_NS_SCALE = 1000000000 16 | 17 | 18 | DEV_DATASET_TO_METRIC = { 19 | "sst2": "accuracy", 20 | "imdb": "accuracy", 21 | "subj": "accuracy", 22 | "bbc-news": "accuracy", 23 | "enron_spam": "accuracy", 24 | "student-question-categories": "accuracy", 25 | "TREC-QC": "accuracy", 26 | "toxic_conversations": "matthews_correlation", 27 | } 28 | 29 | TEST_DATASET_TO_METRIC = { 30 | "emotion": "accuracy", 31 | "SentEval-CR": "accuracy", 32 | "sst5": "accuracy", 33 | "ag_news": "accuracy", 34 | "enron_spam": "accuracy", 35 | "amazon_counterfactual_en": "matthews_correlation", 36 | } 37 | 38 | MULTILINGUAL_DATASET_TO_METRIC = { 39 | f"amazon_reviews_multi_{lang}": "mae" for lang in ["en", "de", "es", "fr", "ja", "zh"] 40 | } 41 | 42 | LOSS_NAME_TO_CLASS = { 43 | "CosineSimilarityLoss": losses.CosineSimilarityLoss, 44 | "ContrastiveLoss": losses.ContrastiveLoss, 45 | "OnlineContrastiveLoss": losses.OnlineContrastiveLoss, 46 | "BatchSemiHardTripletLoss": losses.BatchSemiHardTripletLoss, 47 | "BatchAllTripletLoss": losses.BatchAllTripletLoss, 48 | "BatchHardTripletLoss": losses.BatchHardTripletLoss, 49 | "BatchHardSoftMarginTripletLoss": losses.BatchHardSoftMarginTripletLoss, 50 | "SupConLoss": SupConLoss, 51 | } 52 | 53 | 54 | def default_hp_space_optuna(trial) -> Dict[str, Any]: 55 | from transformers.integrations import is_optuna_available 56 | 57 | assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`" 58 | return { 59 | "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True), 60 | "num_epochs": trial.suggest_int("num_epochs", 1, 5), 61 | "num_iterations": trial.suggest_categorical("num_iterations", [5, 10, 20]), 62 | "seed": trial.suggest_int("seed", 1, 40), 63 | "batch_size": trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]), 64 | } 65 | 66 | 67 | def load_data_splits( 68 | dataset: str, sample_sizes: List[int], add_data_augmentation: bool = False 69 | ) -> Tuple[DatasetDict, Dataset]: 70 | """Loads a dataset from the Hugging Face Hub and returns the test split and few-shot training splits.""" 71 | print(f"\n\n\n============== {dataset} ============") 72 | # Load one of the SetFit training sets from the Hugging Face Hub 73 | train_split = load_dataset(f"SetFit/{dataset}", split="train") 74 | train_splits = create_fewshot_splits(train_split, sample_sizes, add_data_augmentation, f"SetFit/{dataset}") 75 | test_split = load_dataset(f"SetFit/{dataset}", split="test") 76 | print(f"Test set: {len(test_split)}") 77 | return train_splits, test_split 78 | 79 | 80 | def load_data_splits_multilabel(dataset: str, sample_sizes: List[int]) -> Tuple[DatasetDict, Dataset]: 81 | """Loads a dataset from the Hugging Face Hub and returns the test split and few-shot training splits.""" 82 | print(f"\n\n\n============== {dataset} ============") 83 | # Load one of the SetFit training sets from the Hugging Face Hub 84 | train_split = load_dataset(f"SetFit/{dataset}", "multilabel", split="train") 85 | train_splits = create_fewshot_splits_multilabel(train_split, sample_sizes) 86 | test_split = load_dataset(f"SetFit/{dataset}", 
"multilabel", split="test") 87 | print(f"Test set: {len(test_split)}") 88 | return train_splits, test_split 89 | 90 | 91 | @dataclass 92 | class Benchmark: 93 | """ 94 | Performs simple benchmarks of code portions (measures elapsed time). 95 | 96 | Typical usage example: 97 | 98 | bench = Benchmark() 99 | with bench.track("Foo function"): 100 | foo() 101 | with bench.track("Bar function"): 102 | bar() 103 | bench.summary() 104 | """ 105 | 106 | out_path: Optional[str] = None 107 | summary_msg: str = field(default_factory=str) 108 | 109 | def print(self, msg: str) -> None: 110 | """ 111 | Prints to system out and optionally to specified out_path. 112 | """ 113 | print(msg) 114 | 115 | if self.out_path is not None: 116 | with open(self.out_path, "a+") as f: 117 | f.write(msg + "\n") 118 | 119 | @contextmanager 120 | def track(self, step): 121 | """ 122 | Computes the elapsed time for given code context. 123 | """ 124 | start = monotonic_ns() 125 | yield 126 | ns = monotonic_ns() - start 127 | msg = f"\n{'*' * 70}\n'{step}' took {ns / SEC_TO_NS_SCALE:.3f}s ({ns:,}ns)\n{'*' * 70}\n" 128 | print(msg) 129 | self.summary_msg += msg + "\n" 130 | 131 | def summary(self) -> None: 132 | """ 133 | Prints summary of all benchmarks performed. 134 | """ 135 | self.print(f"\n{'#' * 30}\nBenchmark Summary:\n{'#' * 30}\n\n{self.summary_msg}") 136 | 137 | 138 | class BestRun(NamedTuple): 139 | """ 140 | The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]). 141 | 142 | Parameters: 143 | run_id (`str`): 144 | The id of the best run. 145 | objective (`float`): 146 | The objective that was obtained for this run. 147 | hyperparameters (`Dict[str, Any]`): 148 | The hyperparameters picked to get this run. 149 | backend (`Any`): 150 | The relevant internal object used for optimization. For optuna this is the `study` object. 
151 | """ 152 | 153 | run_id: str 154 | objective: float 155 | hyperparameters: Dict[str, Any] 156 | backend: Any = None 157 | 158 | 159 | def set_docstring(method, docstring, cls=None): 160 | copied_function = copy_func(method) 161 | copied_function.__doc__ = docstring 162 | return types.MethodType(copied_function, cls or method.__self__) 163 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from datasets import Dataset 3 | 4 | from setfit import AbsaModel, SetFitModel 5 | 6 | 7 | @pytest.fixture() 8 | def model() -> SetFitModel: 9 | return SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") 10 | 11 | 12 | @pytest.fixture() 13 | def absa_model() -> AbsaModel: 14 | return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm") 15 | 16 | 17 | @pytest.fixture() 18 | def trained_absa_model() -> AbsaModel: 19 | return AbsaModel.from_pretrained( 20 | "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect", 21 | "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity", 22 | ) 23 | 24 | 25 | @pytest.fixture() 26 | def absa_dataset() -> Dataset: 27 | texts = [ 28 | "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.", 29 | "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine.", 30 | "Food is great and inexpensive.", 31 | "Good bagels and good cream cheese.", 32 | "Good bagels and good cream cheese.", 33 | ] 34 | spans = ["food", "ambiance", "Food", "bagels", "cream cheese"] 35 | labels = ["negative", "negative", "positive", "positive", "positive"] 36 | ordinals = [0, 0, 0, 0, 0] 37 | return Dataset.from_dict({"text": texts, "span": spans, "label": labels, "ordinal": ordinals}) 38 | -------------------------------------------------------------------------------- /tests/exporters/test_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import onnxruntime 5 | import pytest 6 | from transformers import AutoTokenizer 7 | 8 | from setfit import SetFitModel 9 | from setfit.data import get_templated_dataset 10 | from setfit.exporters.onnx import export_onnx 11 | from setfit.trainer import Trainer 12 | from setfit.training_args import TrainingArguments 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "model_path, input_text", 17 | [ 18 | ("lewtun/my-awesome-setfit-model", ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]), 19 | ( 20 | "lewtun/setfit-ethos-multilabel-example", 21 | ["I'm a really hateful guy!", "I hate this one person in particular!"], 22 | ), 23 | ], 24 | ) 25 | def test_export_onnx_sklearn_head(model_path, input_text): 26 | """Test that the exported `ONNX` model returns the same predictions as the original model.""" 27 | model = SetFitModel.from_pretrained(model_path) 28 | 29 | # Export the sklearn based model 30 | output_path = "model.onnx" 31 | try: 32 | export_onnx(model.model_body, model.model_head, opset=12, output_path=output_path) 33 | 
34 | # Check that the model was saved. 35 | assert output_path in os.listdir(), "Model not saved to output_path" 36 | 37 | # Run inference using the original model. 38 | pytorch_preds = model(input_text) 39 | 40 | # Run inference using the exported onnx model. 41 | tokenizer = AutoTokenizer.from_pretrained(model_path) 42 | inputs = tokenizer( 43 | input_text, 44 | padding=True, 45 | truncation=True, 46 | return_attention_mask=True, 47 | return_token_type_ids=True, 48 | return_tensors="np", 49 | ) 50 | # Map inputs to int64 from int32 51 | inputs = {key: value.astype("int64") for key, value in inputs.items()} 52 | 53 | session = onnxruntime.InferenceSession(output_path) 54 | 55 | onnx_preds = session.run(None, dict(inputs))[0] 56 | 57 | # Compare the results and ensure that we get the same predictions. 58 | assert np.array_equal(onnx_preds, pytorch_preds) 59 | 60 | finally: 61 | # Cleanup the model. 62 | os.remove(output_path) 63 | 64 | 65 | @pytest.mark.skip("ONNX exporting of SetFit model with Torch head not yet supported.") 66 | @pytest.mark.parametrize("out_features", [1, 2, 3]) 67 | def test_export_onnx_torch_head(out_features): 68 | """Test that the exported `ONNX` model returns the same predictions as the original model.""" 69 | dataset = get_templated_dataset(reference_dataset="SetFit/SentEval-CR") 70 | model_path = "sentence-transformers/paraphrase-albert-small-v2" 71 | model = SetFitModel.from_pretrained( 72 | model_path, use_differentiable_head=True, head_params={"out_features": out_features} 73 | ) 74 | 75 | args = TrainingArguments( 76 | num_iterations=15, 77 | num_epochs=(1, 15), 78 | batch_size=16, 79 | body_learning_rate=(2e-5, 1e-5), 80 | head_learning_rate=1e-2, 81 | l2_weight=0.0, 82 | end_to_end=True, 83 | ) 84 | trainer = Trainer( 85 | model=model, 86 | args=args, 87 | train_dataset=dataset, 88 | eval_dataset=dataset, 89 | column_mapping={"text": "text", "label": "label"}, 90 | ) 91 | trainer.train() 92 | 93 | # Export the sklearn based model 94 | output_path = "model.onnx" 95 | try: 96 | export_onnx(model.model_body, model.model_head, opset=12, output_path=output_path) 97 | 98 | # Check that the model was saved. 99 | assert output_path in os.listdir(), "Model not saved to output_path" 100 | 101 | # Run inference using the original model. 102 | input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"] 103 | pytorch_preds = model(input_text) 104 | 105 | # Run inference using the exported onnx model. 106 | tokenizer = AutoTokenizer.from_pretrained(model_path) 107 | inputs = tokenizer( 108 | input_text, 109 | padding=True, 110 | truncation=True, 111 | return_attention_mask=True, 112 | return_token_type_ids=True, 113 | return_tensors="np", 114 | ) 115 | # Map inputs to int64 from int32 116 | inputs = {key: value.astype("int64") for key, value in inputs.items()} 117 | 118 | session = onnxruntime.InferenceSession(output_path) 119 | 120 | onnx_preds = session.run(None, dict(inputs))[0] 121 | 122 | # Compare the results and ensure that we get the same predictions. 123 | assert np.array_equal(onnx_preds, pytorch_preds) 124 | 125 | finally: 126 | # Cleanup the model. 
127 | os.remove(output_path) 128 | -------------------------------------------------------------------------------- /tests/exporters/test_openvino.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import openvino.runtime as ov 5 | import pytest 6 | from transformers import AutoTokenizer 7 | 8 | from setfit import SetFitModel 9 | from setfit.exporters.openvino import export_to_openvino 10 | 11 | 12 | @pytest.mark.skip( 13 | reason="OpenVINO exporting broke since openvino==2022.3.0, while this version is not supported for Python 3.11 onwards. " 14 | "To allow us to add Python 3.11+ support, we are skipping this test until OpenVINO support is fixed." 15 | ) 16 | def test_export_to_openvino(): 17 | """Test that the exported `OpenVINO` model returns the same predictions as the original model.""" 18 | model_path = "lewtun/my-awesome-setfit-model" 19 | model = SetFitModel.from_pretrained(model_path) 20 | 21 | # Export the sklearn based model 22 | output_path = "model.xml" 23 | export_to_openvino(model, output_path=output_path) 24 | 25 | # Check that the model was saved. 26 | assert output_path in os.listdir(), "Model not saved to output_path" 27 | 28 | # Run inference using the original model. 29 | input_text = ["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"] 30 | pytorch_preds = model(input_text) 31 | 32 | # Run inference using the exported OpenVINO model. 33 | tokenizer = AutoTokenizer.from_pretrained(model_path) 34 | inputs = tokenizer( 35 | input_text, 36 | padding=True, 37 | truncation=True, 38 | return_attention_mask=True, 39 | return_token_type_ids=True, 40 | return_tensors="np", 41 | ) 42 | 43 | inputs_dict = dict(inputs) 44 | 45 | core = ov.Core() 46 | ov_model = core.read_model(output_path) 47 | compiled_model = core.compile_model(ov_model, "CPU") 48 | 49 | ov_preds = compiled_model(inputs_dict)[compiled_model.outputs[0]] 50 | 51 | # Compare the results and ensure that we get the same predictions. 52 | assert np.array_equal(ov_preds, pytorch_preds) 53 | 54 | # Cleanup the model. 55 | os.remove(output_path) 56 | os.remove(output_path.replace(".xml", ".bin")) 57 | -------------------------------------------------------------------------------- /tests/model_card_pattern.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | import re 4 | 5 | 6 | MODEL_CARD_PATTERN = re.compile( 7 | """\ 8 | --- 9 | .* 10 | --- 11 | 12 | \# SetFit with sentence\-transformers/paraphrase\-albert\-small\-v2 on SST2 13 | 14 | This is a \[SetFit\]\(https://github\.com/huggingface/setfit\) model trained on the \[SST2\]\(https://huggingface\.co/datasets/sst2\) dataset that can be used for Text Classification\. This SetFit model uses \[sentence\-transformers/paraphrase\-albert\-small\-v2\]\(https://huggingface\.co/sentence\-transformers/paraphrase\-albert\-small\-v2\) as the Sentence Transformer embedding model\. A \[LogisticRegression\]\(https://scikit\-learn\.org/stable/modules/generated/sklearn\.linear_model\.LogisticRegression\.html\) instance is used for classification\. 15 | 16 | The model has been trained using an efficient few\-shot learning technique that involves: 17 | 18 | 1\. Fine\-tuning a \[Sentence Transformer\]\(https://www\.sbert\.net\) with contrastive learning\. 19 | 2\. Training a classification head with features from the fine\-tuned Sentence Transformer\. 
20 | 21 | ## Model Details 22 | 23 | ### Model Description 24 | - \*\*Model Type:\*\* SetFit 25 | - \*\*Sentence Transformer body:\*\* \[sentence-transformers/paraphrase-albert-small-v2\]\(https://huggingface.co/sentence-transformers/paraphrase-albert-small-v2\) * 26 | - \*\*Classification head:\*\* a \[LogisticRegression\]\(https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\) instance * 27 | - \*\*Maximum Sequence Length:\*\* 100 tokens 28 | - \*\*Number of Classes:\*\* 2 classes 29 | - \*\*Training Dataset:\*\* \[SST2\]\(https://huggingface.co/datasets/sst2\) 30 | - \*\*Language:\*\* en 31 | - \*\*License:\*\* apache-2.0 32 | 33 | ### Model Sources 34 | 35 | - \*\*Repository:\*\* \[SetFit on GitHub\]\(https://github.com/huggingface/setfit\) 36 | - \*\*Paper:\*\* \[Efficient Few-Shot Learning Without Prompts\]\(https://arxiv.org/abs/2209.11055\) 37 | - \*\*Blogpost:\*\* \[SetFit: Efficient Few-Shot Learning Without Prompts\]\(https://huggingface.co/blog/setfit\) 38 | 39 | ### Model Labels 40 | \| Label\s+\| Examples\s+\| 41 | \|:-+\|:-+\| 42 | \| positive\s+\| [^\|]+ \| 43 | \| negative\s+\| [^\|]+ \| 44 | 45 | ## Evaluation 46 | 47 | ### Metrics 48 | \| Label \| Accuracy \| 49 | \|:--------\|:---------\| 50 | \| \*\*all\*\* \| [\d\.]+\s+\| 51 | 52 | ## Uses 53 | 54 | ### Direct Use for Inference 55 | 56 | First install the SetFit library: 57 | 58 | ```bash 59 | pip install setfit 60 | ``` 61 | 62 | Then you can load this model and run inference. 63 | 64 | ```python 65 | from setfit import SetFitModel 66 | 67 | # Download from the [^H]+ Hub 68 | model = SetFitModel.from_pretrained\("tomaarsen/setfit-paraphrase-albert-small-v2-sst2"\) 69 | # Run inference 70 | preds = model\(".+"\) 71 | ``` 72 | 73 | 78 | 79 | 84 | 85 | 90 | 91 | 96 | 97 | ## Training Details 98 | 99 | ### Training Set Metrics 100 | \| Training set \| Min \| Median \| Max \| 101 | \|:-------------\|:----\|:-------\|:----\| 102 | \| Word count \| 3 \| 7.875 \| 18 \| 103 | 104 | \| Label \| Training Sample Count \| 105 | \|:---------\|:----------------------\| 106 | \| negative \| 8 \| 107 | \| positive \| 8 \| 108 | 109 | ### Training Hyperparameters 110 | - batch_size: \(1, 1\) 111 | - num_epochs: \(1, 16\) 112 | - max_steps: 2 113 | - sampling_strategy: oversampling 114 | - body_learning_rate: \(2e-05, 1e-05\) 115 | - head_learning_rate: 0.01 116 | - loss: CosineSimilarityLoss 117 | - distance_metric: cosine_distance 118 | - margin: 0.25 119 | - end_to_end: False 120 | - use_amp: False 121 | - warmup_proportion: 0.1 122 | - l2_weight: 0.01 123 | - seed: 42 124 | - eval_max_steps: -1 125 | - load_best_model_at_end: False 126 | 127 | ### Training Results 128 | \| Epoch \| Step \| Training Loss \| Validation Loss \| 129 | \|:-----:\|:----:\|:-------------:\|:---------------:\| 130 | (\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\| [\d\.]+ +\|\n)+ 131 | ### Environmental Impact 132 | Carbon emissions were measured using \[CodeCarbon\]\(https://github.com/mlco2/codecarbon\)\. 
133 | - \*\*Carbon Emitted\*\*: [\d\.]+ kg of CO2 134 | - \*\*Hours Used\*\*: [\d\.]+ hours 135 | 136 | ### Training Hardware 137 | - \*\*On Cloud\*\*: (Yes|No) 138 | - \*\*GPU Model\*\*: [^\n]+ 139 | - \*\*CPU Model\*\*: [^\n]+ 140 | - \*\*RAM Size\*\*: [\d\.]+ GB 141 | 142 | ### Framework Versions 143 | - Python: [^\n]+ 144 | - SetFit: [^\n]+ 145 | - Sentence Transformers: [^\n]+ 146 | - Transformers: [^\n]+ 147 | - PyTorch: [^\n]+ 148 | - Datasets: [^\n]+ 149 | - Tokenizers: [^\n]+ 150 | 151 | ## Citation 152 | 153 | ### BibTeX 154 | ```bibtex 155 | @article{https://doi.org/10.48550/arxiv.2209.11055, 156 | doi = {10.48550/ARXIV.2209.11055}, 157 | url = {https://arxiv.org/abs/2209.11055}, 158 | author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren}, 159 | keywords = {Computation and Language \(cs.CL\), FOS: Computer and information sciences, FOS: Computer and information sciences}, 160 | title = {Efficient Few-Shot Learning Without Prompts}, 161 | publisher = {arXiv}, 162 | year = \{2022\}, 163 | copyright = {Creative Commons Attribution 4.0 International} 164 | } 165 | ``` 166 | 167 | 172 | 173 | 178 | 179 | """, 184 | flags=re.DOTALL, 185 | ) 186 | -------------------------------------------------------------------------------- /tests/span/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/setfit/6be4d6b34b3d5e4fdae6a8cfc120af4ba11cc160/tests/span/__init__.py -------------------------------------------------------------------------------- /tests/span/test_model_card.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from datasets import Dataset 4 | 5 | from setfit import AbsaModel, AbsaTrainer, SetFitModelCardData, TrainingArguments 6 | 7 | from .aspect_model_card_pattern import ASPECT_MODEL_CARD_PATTERN 8 | from .polarity_model_card_pattern import POLARITY_MODEL_CARD_PATTERN 9 | 10 | 11 | def test_model_card(absa_dataset: Dataset, tmp_path: Path) -> None: 12 | model = AbsaModel.from_pretrained( 13 | "sentence-transformers/paraphrase-albert-small-v2", 14 | model_card_data=SetFitModelCardData( 15 | model_id="tomaarsen/setfit-absa-paraphrase-albert-small-v2-laptops", 16 | language=["en"], 17 | license="apache-2.0", 18 | ), 19 | ) 20 | 21 | args = TrainingArguments( 22 | str(tmp_path), 23 | report_to="codecarbon", 24 | batch_size=1, 25 | eval_steps=1, 26 | logging_steps=1, 27 | max_steps=2, 28 | eval_strategy="steps", 29 | save_strategy="no", 30 | ) 31 | trainer = AbsaTrainer( 32 | model=model, 33 | args=args, 34 | train_dataset=absa_dataset, 35 | eval_dataset=absa_dataset, 36 | ) 37 | trainer.train() 38 | trainer.evaluate() 39 | 40 | path = tmp_path / "aspect" 41 | model.aspect_model.create_model_card(path, model_name=str(path)) 42 | with open(path / "README.md", "r", encoding="utf8") as f: 43 | model_card = f.read() 44 | assert ASPECT_MODEL_CARD_PATTERN.fullmatch(model_card) 45 | 46 | path = tmp_path / "polarity" 47 | model.polarity_model.create_model_card(path, model_name=str(path)) 48 | with open(path / "README.md", "r", encoding="utf8") as f: 49 | model_card = f.read() 50 | assert POLARITY_MODEL_CARD_PATTERN.fullmatch(model_card) 51 | -------------------------------------------------------------------------------- /tests/span/test_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 
3 | from datasets import Dataset 4 | from pytest import LogCaptureFixture 5 | from transformers import TrainerCallback 6 | 7 | from setfit import AbsaTrainer 8 | from setfit.logging import get_logger 9 | from setfit.span.modeling import AbsaModel 10 | 11 | 12 | def test_trainer(absa_model: AbsaModel, absa_dataset: Dataset) -> None: 13 | trainer = AbsaTrainer(absa_model, train_dataset=absa_dataset, eval_dataset=absa_dataset) 14 | trainer.train() 15 | 16 | metrics = trainer.evaluate() 17 | assert "aspect" in metrics 18 | assert "polarity" in metrics 19 | assert "accuracy" in metrics["aspect"] 20 | assert "accuracy" in metrics["polarity"] 21 | assert metrics["aspect"]["accuracy"] > 0.0 22 | assert metrics["polarity"]["accuracy"] > 0.0 23 | new_metrics = trainer.evaluate(absa_dataset) 24 | assert metrics == new_metrics 25 | 26 | predict = absa_model.predict("Best pizza outside of Italy and really tasty.") 27 | assert {"span": "pizza", "polarity": "positive"} in predict 28 | predict = absa_model.predict(["Best pizza outside of Italy and really tasty.", "This is another sentence"]) 29 | assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list) 30 | predict = absa_model(["Best pizza outside of Italy and really tasty.", "This is another sentence"]) 31 | assert isinstance(predict, list) and len(predict) == 2 and isinstance(predict[0], list) 32 | 33 | 34 | def test_trainer_callbacks(absa_model: AbsaModel) -> None: 35 | trainer = AbsaTrainer(absa_model) 36 | assert len(trainer.aspect_trainer.st_trainer.callback_handler.callbacks) >= 2 37 | num_callbacks = len(trainer.aspect_trainer.st_trainer.callback_handler.callbacks) 38 | callback_names = { 39 | callback.__class__.__name__ for callback in trainer.aspect_trainer.st_trainer.callback_handler.callbacks 40 | } 41 | assert {"DefaultFlowCallback", "ProgressCallback"} <= callback_names 42 | 43 | class TestCallback(TrainerCallback): 44 | pass 45 | 46 | callback = TestCallback() 47 | trainer.add_callback(callback) 48 | assert len(trainer.aspect_trainer.st_trainer.callback_handler.callbacks) == num_callbacks + 1 49 | assert len(trainer.polarity_trainer.st_trainer.callback_handler.callbacks) == num_callbacks + 1 50 | assert trainer.aspect_trainer.st_trainer.callback_handler.callbacks[-1] == callback 51 | assert trainer.polarity_trainer.st_trainer.callback_handler.callbacks[-1] == callback 52 | 53 | assert trainer.pop_callback(callback) == (callback, callback) 54 | trainer.add_callback(callback) 55 | assert trainer.aspect_trainer.st_trainer.callback_handler.callbacks[-1] == callback 56 | assert trainer.polarity_trainer.st_trainer.callback_handler.callbacks[-1] == callback 57 | trainer.remove_callback(callback) 58 | assert callback not in trainer.aspect_trainer.st_trainer.callback_handler.callbacks 59 | assert callback not in trainer.polarity_trainer.st_trainer.callback_handler.callbacks 60 | 61 | 62 | def test_train_ordinal_too_high(absa_model: AbsaModel, caplog: LogCaptureFixture) -> None: 63 | logger = get_logger("setfit") 64 | logger.propagate = True 65 | 66 | absa_dataset = Dataset.from_dict( 67 | { 68 | "text": [ 69 | "It is about food and ambiance, and imagine how dreadful it will be it we only had to listen to an idle engine." 
70 | ], 71 | "span": ["food"], 72 | "label": ["negative"], 73 | "ordinal": [1], 74 | } 75 | ) 76 | with caplog.at_level(logging.INFO): 77 | trainer = AbsaTrainer(absa_model, train_dataset=absa_dataset) 78 | assert len(trainer.aspect_trainer.train_dataset) == 3 79 | assert len(trainer.polarity_trainer.train_dataset) == 0 80 | # These tests are ignored as the caplog is inconsistent: 81 | # assert len(caplog.record_tuples) == 1 82 | # assert caplog.record_tuples[0][2] == ( 83 | # "The ordinal of 1 for span 'food' in 'It is about food and ambiance, and imagine how dreadful it will be " 84 | # "it we only had to listen to an idle engine.' is too high. Skipping this sample." 85 | # ) 86 | # assert caplog.record_tuples[0][1] == logging.INFO 87 | 88 | logger.propagate = False 89 | 90 | 91 | def test_train_column_mapping(absa_model: AbsaModel, absa_dataset: Dataset) -> None: 92 | absa_dataset = absa_dataset.rename_columns({"text": "sentence", "span": "aspect"}) 93 | trainer = AbsaTrainer( 94 | absa_model, train_dataset=absa_dataset, column_mapping={"sentence": "text", "aspect": "span"} 95 | ) 96 | trainer.train() 97 | -------------------------------------------------------------------------------- /tests/test_deprecated_trainer_distillation.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pytest 4 | from datasets import Dataset 5 | from sentence_transformers.losses import CosineSimilarityLoss 6 | 7 | from setfit import DistillationSetFitTrainer, SetFitTrainer 8 | from setfit.modeling import SetFitModel 9 | 10 | 11 | class DistillationSetFitTrainerTest(TestCase): 12 | def setUp(self): 13 | self.teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") 14 | self.student_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2") 15 | self.num_iterations = 1 16 | 17 | def test_trainer_works_with_default_columns(self): 18 | dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) 19 | # train a teacher model 20 | teacher_trainer = SetFitTrainer( 21 | model=self.teacher_model, 22 | train_dataset=dataset, 23 | eval_dataset=dataset, 24 | loss_class=CosineSimilarityLoss, 25 | metric="accuracy", 26 | ) 27 | # Teacher Train and evaluate 28 | teacher_trainer.train() 29 | teacher_model = teacher_trainer.model 30 | 31 | student_trainer = DistillationSetFitTrainer( 32 | teacher_model=teacher_model, 33 | train_dataset=dataset, 34 | student_model=self.student_model, 35 | eval_dataset=dataset, 36 | loss_class=CosineSimilarityLoss, 37 | metric="accuracy", 38 | ) 39 | 40 | # Student Train and evaluate 41 | student_trainer.train() 42 | metrics = student_trainer.evaluate() 43 | print("Student results: ", metrics) 44 | self.assertEqual(metrics["accuracy"], 1.0) 45 | 46 | def test_trainer_raises_error_with_missing_label(self): 47 | labeled_dataset = Dataset.from_dict( 48 | {"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]} 49 | ) 50 | # train a teacher model 51 | teacher_trainer = SetFitTrainer( 52 | model=self.teacher_model, 53 | train_dataset=labeled_dataset, 54 | eval_dataset=labeled_dataset, 55 | metric="accuracy", 56 | num_iterations=self.num_iterations, 57 | ) 58 | # Teacher Train and evaluate 59 | teacher_trainer.train() 60 | 61 | unlabeled_dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) 62 | student_trainer = DistillationSetFitTrainer( 63 | 
teacher_model=self.teacher_model, 64 | student_model=self.student_model, 65 | train_dataset=unlabeled_dataset, 66 | eval_dataset=labeled_dataset, 67 | num_iterations=self.num_iterations, 68 | ) 69 | student_trainer.train() 70 | metrics = student_trainer.evaluate() 71 | print("Student results: ", metrics) 72 | self.assertEqual(metrics["accuracy"], 1.0) 73 | 74 | def test_trainer_raises_error_with_missing_text(self): 75 | dataset = Dataset.from_dict({"label": [0, 1, 2], "extra_column": ["d", "e", "f"]}) 76 | with pytest.raises(ValueError): 77 | DistillationSetFitTrainer( 78 | teacher_model=self.teacher_model, 79 | train_dataset=dataset, 80 | student_model=self.student_model, 81 | eval_dataset=dataset, 82 | num_iterations=self.num_iterations, 83 | ) 84 | 85 | def test_column_mapping_with_missing_text(self): 86 | dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]}) 87 | with pytest.raises(ValueError): 88 | DistillationSetFitTrainer( 89 | teacher_model=self.teacher_model, 90 | train_dataset=dataset, 91 | student_model=self.student_model, 92 | eval_dataset=dataset, 93 | num_iterations=self.num_iterations, 94 | column_mapping={"label_new": "label"}, 95 | ) 96 | 97 | def test_column_mapping_multilabel(self): 98 | dataset = Dataset.from_dict({"text_new": ["a", "b", "c"], "label_new": [[0, 1], [1, 2], [2, 0]]}) 99 | 100 | trainer = DistillationSetFitTrainer( 101 | teacher_model=self.teacher_model, 102 | train_dataset=dataset, 103 | student_model=self.student_model, 104 | eval_dataset=dataset, 105 | num_iterations=self.num_iterations, 106 | column_mapping={"text_new": "text", "label_new": "label"}, 107 | ) 108 | 109 | trainer._validate_column_mapping(dataset) 110 | formatted_dataset = trainer._apply_column_mapping(dataset, trainer.column_mapping) 111 | 112 | assert formatted_dataset.column_names == ["text", "label"] 113 | assert formatted_dataset[0]["text"] == "a" 114 | assert formatted_dataset[0]["label"] == [0, 1] 115 | assert formatted_dataset[1]["text"] == "b" 116 | -------------------------------------------------------------------------------- /tests/test_model_card.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import datasets 4 | import pytest 5 | from datasets import Dataset, load_dataset 6 | from packaging.version import Version, parse 7 | 8 | from setfit import SetFitModel, SetFitModelCardData, Trainer, TrainingArguments 9 | from setfit.data import sample_dataset 10 | from setfit.model_card import generate_model_card, is_on_huggingface 11 | 12 | from .model_card_pattern import MODEL_CARD_PATTERN 13 | 14 | 15 | def test_model_card(tmp_path: Path) -> None: 16 | dataset = load_dataset("sst2") 17 | train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8) 18 | eval_dataset = dataset["validation"].select(range(10)) 19 | model = SetFitModel.from_pretrained( 20 | "sentence-transformers/paraphrase-albert-small-v2", 21 | labels=["negative", "positive"], 22 | model_card_data=SetFitModelCardData( 23 | model_id="tomaarsen/setfit-paraphrase-albert-small-v2-sst2", 24 | dataset_id="sst2", 25 | dataset_name="SST2", 26 | language=["en"], 27 | license="apache-2.0", 28 | ), 29 | ) 30 | 31 | args = TrainingArguments( 32 | str(tmp_path), 33 | report_to="codecarbon", 34 | batch_size=1, 35 | eval_steps=1, 36 | logging_steps=1, 37 | max_steps=2, 38 | eval_strategy="steps", 39 | save_strategy="no", 40 | ) 41 | trainer = Trainer( 42 | model=model, 43 | args=args, 44 | 
train_dataset=train_dataset, 45 | eval_dataset=eval_dataset, 46 | column_mapping={"sentence": "text"}, 47 | ) 48 | trainer.train() 49 | trainer.evaluate() 50 | model_card = generate_model_card(trainer.model) 51 | assert MODEL_CARD_PATTERN.fullmatch(model_card) 52 | 53 | 54 | def test_model_card_languages() -> None: 55 | model = SetFitModel.from_pretrained( 56 | "sentence-transformers/paraphrase-albert-small-v2", 57 | model_card_data=SetFitModelCardData( 58 | language=["en", "nl", "de"], 59 | ), 60 | ) 61 | model_card = model.generate_model_card() 62 | assert "**Languages:** en, nl, de" in model_card 63 | 64 | 65 | def test_is_on_huggingface_edge_case() -> None: 66 | assert not is_on_huggingface("test_value") 67 | assert not is_on_huggingface("a/test/value") 68 | 69 | 70 | @pytest.mark.skipif( 71 | parse(datasets.__version__) < Version("2.14.0"), reason="Inferring dataset_id only works from datasets >= 2.14.0" 72 | ) 73 | @pytest.mark.parametrize("dataset_id", ("SetFit/emotion", "SetFit/sst2")) 74 | def test_infer_dataset_id(dataset_id: str) -> None: 75 | model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") 76 | train_dataset = load_dataset(dataset_id, split="train") 77 | 78 | # This triggers inferring the dataset_id from train_dataset 79 | Trainer(model=model, train_dataset=train_dataset) 80 | assert model.model_card_data.dataset_id == dataset_id 81 | 82 | 83 | def test_cant_infer_dataset_id(): 84 | model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2") 85 | train_dataset = Dataset.from_dict({"text": ["a", "b", "c", "d"], "label": [0, 1, 1, 0]}) 86 | 87 | # This triggers inferring the dataset_id from train_dataset 88 | Trainer(model=model, train_dataset=train_dataset) 89 | assert model.model_card_data.dataset_id is None 90 | -------------------------------------------------------------------------------- /tests/test_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from setfit.sampler import ContrastiveDataset 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "sampling_strategy, expected_pos_pairs, expected_neg_pairs", 9 | [("unique", 4, 2), ("undersampling", 2, 2), ("oversampling", 4, 4)], 10 | ) 11 | def test_sentence_pairs_generation(sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int): 12 | sentences = np.array(["sent 1", "sent 2", "sent 3"]) 13 | labels = np.array(["label 1", "label 1", "label 2"]) 14 | 15 | multilabel = False 16 | 17 | data_sampler = ContrastiveDataset(sentences, labels, multilabel, sampling_strategy=sampling_strategy) 18 | 19 | assert data_sampler.len_pos_pairs == expected_pos_pairs 20 | assert data_sampler.len_neg_pairs == expected_neg_pairs 21 | 22 | pairs = [i for i in data_sampler] 23 | 24 | assert len(pairs) == expected_pos_pairs + expected_neg_pairs 25 | assert pairs[0] == {"sentence_1": "sent 1", "sentence_2": "sent 1", "label": 1.0} 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "sampling_strategy, expected_pos_pairs, expected_neg_pairs", 30 | [("unique", 6, 4), ("undersampling", 4, 4), ("oversampling", 6, 6)], 31 | ) 32 | def test_sentence_pairs_generation_multilabel( 33 | sampling_strategy: str, expected_pos_pairs: int, expected_neg_pairs: int 34 | ): 35 | sentences = np.array(["sent 1", "sent 2", "sent 3", "sent 4"]) 36 | labels = np.array([[1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 37 | 38 | multilabel = True 39 | 40 | data_sampler = ContrastiveDataset(sentences, labels, 
multilabel, sampling_strategy=sampling_strategy) 41 | assert data_sampler.len_pos_pairs == expected_pos_pairs 42 | assert data_sampler.len_neg_pairs == expected_neg_pairs 43 | 44 | pairs = [i for i in data_sampler] 45 | assert len(pairs) == expected_pos_pairs + expected_neg_pairs 46 | -------------------------------------------------------------------------------- /tests/test_training_args.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import pytest 4 | from transformers import IntervalStrategy 5 | 6 | from setfit.training_args import TrainingArguments 7 | 8 | 9 | class TestTrainingArguments(TestCase): 10 | def test_raises_error_with_wrong_warmup_proportion(self): 11 | # warmup_proportion must not be > 1.0 12 | with pytest.raises(ValueError): 13 | TrainingArguments(warmup_proportion=1.1) 14 | 15 | # warmup_proportion must not be < 0.0 16 | with pytest.raises(ValueError): 17 | TrainingArguments(warmup_proportion=-0.1) 18 | 19 | def test_batch_sizes(self): 20 | batch_size_A = 12 21 | batch_size_B = 4 22 | 23 | args = TrainingArguments(batch_size=batch_size_A) 24 | self.assertEqual(args.batch_size, (batch_size_A, batch_size_A)) 25 | self.assertEqual(args.embedding_batch_size, batch_size_A) 26 | self.assertEqual(args.classifier_batch_size, batch_size_A) 27 | 28 | args = TrainingArguments(batch_size=(batch_size_A, batch_size_B)) 29 | self.assertEqual(args.batch_size, (batch_size_A, batch_size_B)) 30 | self.assertEqual(args.embedding_batch_size, batch_size_A) 31 | self.assertEqual(args.classifier_batch_size, batch_size_B) 32 | 33 | def test_num_epochs(self): 34 | num_epochs_A = 12 35 | num_epochs_B = 4 36 | 37 | args = TrainingArguments(num_epochs=num_epochs_A) 38 | self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_A)) 39 | self.assertEqual(args.embedding_num_epochs, num_epochs_A) 40 | self.assertEqual(args.classifier_num_epochs, num_epochs_A) 41 | 42 | args = TrainingArguments(num_epochs=(num_epochs_A, num_epochs_B)) 43 | self.assertEqual(args.num_epochs, (num_epochs_A, num_epochs_B)) 44 | self.assertEqual(args.embedding_num_epochs, num_epochs_A) 45 | self.assertEqual(args.classifier_num_epochs, num_epochs_B) 46 | 47 | def test_learning_rates(self): 48 | learning_rate_A = 1e-2 49 | learning_rate_B = 1e-3 50 | 51 | base = TrainingArguments() 52 | 53 | args = TrainingArguments(body_learning_rate=learning_rate_A) 54 | self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_A)) 55 | self.assertEqual(args.body_embedding_learning_rate, learning_rate_A) 56 | self.assertEqual(args.body_classifier_learning_rate, learning_rate_A) 57 | self.assertEqual(args.head_learning_rate, base.head_learning_rate) 58 | 59 | args = TrainingArguments(body_learning_rate=(learning_rate_A, learning_rate_B)) 60 | self.assertEqual(args.body_learning_rate, (learning_rate_A, learning_rate_B)) 61 | self.assertEqual(args.body_embedding_learning_rate, learning_rate_A) 62 | self.assertEqual(args.body_classifier_learning_rate, learning_rate_B) 63 | self.assertEqual(args.head_learning_rate, base.head_learning_rate) 64 | 65 | def test_report_to(self): 66 | args = TrainingArguments(report_to="none") 67 | self.assertEqual(args.report_to, ["none"]) 68 | args = TrainingArguments(report_to=["none"]) 69 | self.assertEqual(args.report_to, ["none"]) 70 | args = TrainingArguments(report_to="hello") 71 | self.assertEqual(args.report_to, ["hello"]) 72 | 73 | def test_eval_steps_without_eval_strat(self): 74 | args = 
TrainingArguments(eval_steps=5) 75 | self.assertEqual(args.eval_strategy, IntervalStrategy.STEPS) 76 | 77 | def test_eval_strat_steps_without_eval_steps(self): 78 | args = TrainingArguments(eval_strategy="steps") 79 | self.assertEqual(args.eval_steps, args.logging_steps) 80 | with self.assertRaises(ValueError): 81 | TrainingArguments(eval_strategy="steps", logging_steps=0, logging_strategy="no") 82 | 83 | def test_load_best_model(self): 84 | with self.assertRaises(ValueError): 85 | TrainingArguments(load_best_model_at_end=True, eval_strategy="steps", save_strategy="epoch") 86 | with self.assertRaises(ValueError): 87 | TrainingArguments( 88 | load_best_model_at_end=True, 89 | eval_strategy="steps", 90 | save_strategy="steps", 91 | eval_steps=100, 92 | save_steps=50, 93 | ) 94 | # No error: save_steps is a round multiple of eval_steps 95 | TrainingArguments( 96 | load_best_model_at_end=True, 97 | eval_strategy="steps", 98 | save_strategy="steps", 99 | eval_steps=50, 100 | save_steps=100, 101 | ) 102 | 103 | def test_logging_steps_zero(self): 104 | with self.assertRaises(ValueError): 105 | TrainingArguments(logging_strategy="steps", logging_steps=0) 106 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import tempfile 4 | 5 | 6 | class SafeTemporaryDirectory(tempfile.TemporaryDirectory): 7 | """ 8 | The GitHub Actions CI on Windows sometimes raises a NotADirectoryError when cleaning up the temporary directory. 9 | This class is a workaround to avoid the error. 10 | 11 | Unlike tempfile.TemporaryDirectory(ignore_cleanup_errors=True), this also works on Python 3.8 and 3.9. 12 | """ 13 | 14 | def __init__(self, *args, **kwargs) -> None: 15 | kwargs["ignore_cleanup_errors"] = True 16 | try: 17 | super().__init__(*args, **kwargs) 18 | except TypeError: 19 | del kwargs["ignore_cleanup_errors"] 20 | super().__init__(*args, **kwargs) 21 | 22 | def __exit__(self, *args, **kwargs): 23 | try: 24 | super().__exit__(*args, **kwargs) 25 | except NotADirectoryError: 26 | pass 27 | -------------------------------------------------------------------------------- /utils/create_notebook_table.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | GITHUB_PATH_PREFIX = "huggingface/setfit/blob/main/notebooks/" 4 | 5 | CHAPTER_TO_NB = { 6 | "Text Classification (Multiclass)": "text-classification", 7 | "Text Classification (Multilabel)": "text-classification_multilabel", 8 | "Zero-Shot Text Classification": "zero-shot-classification", 9 | "Hyperparameter search": "text-classification_hyperparameter-search", 10 | } 11 | 12 | 13 | def _find_text_in_file(filename, start_prompt, end_prompt): 14 | """ 15 | Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty 16 | lines. 17 | 18 | Copied from: https://github.com/huggingface/transformers/blob/16f0b7d72c6d4e122957392c342b074aa2c5c519/utils/check_table.py#L30 19 | """ 20 | with open(filename, "r", encoding="utf-8", newline="\n") as f: 21 | lines = f.readlines() 22 | # Find the start prompt. 
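    # The table sits between `start_prompt` and `end_prompt`: scan forward to the start prompt and
    # step just past it, then scan to the end prompt and step back before it, trimming blank lines
    # on both sides so only the table itself is returned.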
23 | start_index = 0 24 | while not lines[start_index].startswith(start_prompt): 25 | start_index += 1 26 | start_index += 1 27 | 28 | end_index = start_index 29 | while not lines[end_index].startswith(end_prompt): 30 | end_index += 1 31 | end_index -= 1 32 | 33 | while len(lines[start_index]) <= 1: 34 | start_index += 1 35 | while len(lines[end_index]) <= 1: 36 | end_index -= 1 37 | end_index += 1 38 | return "".join(lines[start_index:end_index]), start_index, end_index, lines 39 | 40 | 41 | def create_table(): 42 | data = {"Notebook": [], "Colab": [], "Kaggle": [], "Gradient": [], "Studio Lab": []} 43 | for title, nb in CHAPTER_TO_NB.items(): 44 | nb_path = f"{GITHUB_PATH_PREFIX}{nb}.ipynb" 45 | data["Notebook"].append(title) 46 | data["Colab"].append( 47 | f"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/{nb_path})" 48 | ) 49 | data["Kaggle"].append( 50 | f"[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/{nb_path})" 51 | ) 52 | data["Gradient"].append( 53 | f"[![Gradient](https://assets.paperspace.io/img/gradient-badge.svg)](https://console.paperspace.com/github/{nb_path})" 54 | ) 55 | data["Studio Lab"].append( 56 | f"[![Open In SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/{nb_path})" 57 | ) 58 | return pd.DataFrame(data).to_markdown(index=False) + "\n" 59 | 60 | 61 | def main(): 62 | table = create_table() 63 | _, start_index, end_index, lines = _find_text_in_file( 64 | filename="notebooks/README.md", 65 | start_prompt="", 66 | end_prompt="", 67 | ) 68 | 69 | with open("notebooks/README.md", "w", encoding="utf-8", newline="\n") as f: 70 | f.writelines(lines[:start_index] + [table] + lines[end_index:]) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /utils/release.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
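# Typical invocation, sketched from the flags defined at the bottom of this file (run from the
# repository root so the relative paths in REPLACE_FILES resolve):
#
#   python utils/release.py                  # set the release version (drops .dev0 or bumps the minor version)
#   python utils/release.py --patch          # prepare a patch release (bumps the micro version)
#   python utils/release.py --post_release   # switch back to the next .dev0 development version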
15 | 16 | import argparse 17 | import os 18 | import re 19 | 20 | import packaging.version 21 | 22 | 23 | REPLACE_PATTERNS = { 24 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'), 25 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'), 26 | } 27 | REPLACE_FILES = { 28 | "init": "src/setfit/__init__.py", 29 | "setup": "setup.py", 30 | } 31 | README_FILE = "README.md" 32 | 33 | 34 | def update_version_in_file(fname, version, pattern): 35 | """Update the version in one file using a specific pattern.""" 36 | with open(fname, "r", encoding="utf-8", newline="\n") as f: 37 | code = f.read() 38 | re_pattern, replace = REPLACE_PATTERNS[pattern] 39 | replace = replace.replace("VERSION", version) 40 | code = re_pattern.sub(replace, code) 41 | with open(fname, "w", encoding="utf-8", newline="\n") as f: 42 | f.write(code) 43 | 44 | 45 | def global_version_update(version, patch=False): 46 | """Update the version in all needed files.""" 47 | for pattern, fname in REPLACE_FILES.items(): 48 | update_version_in_file(fname, version, pattern) 49 | 50 | 51 | def get_version(): 52 | """Reads the current version in the __init__.""" 53 | with open(REPLACE_FILES["init"], "r") as f: 54 | code = f.read() 55 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] 56 | return packaging.version.parse(default_version) 57 | 58 | 59 | def pre_release_work(patch=False): 60 | """Do all the necessary pre-release steps.""" 61 | # First let's get the default version: base version if we are in dev, bump minor otherwise. 62 | default_version = get_version() 63 | if patch and default_version.is_devrelease: 64 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") 65 | if default_version.is_devrelease: 66 | default_version = default_version.base_version 67 | elif patch: 68 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" 69 | else: 70 | default_version = f"{default_version.major}.{default_version.minor + 1}.0" 71 | 72 | # Now let's ask nicely if that's the right one. 73 | version = input(f"Which version are you releasing? [{default_version}]") 74 | if len(version) == 0: 75 | version = default_version 76 | 77 | print(f"Updating version to {version}.") 78 | global_version_update(version, patch=patch) 79 | 80 | 81 | def post_release_work(): 82 | """Do all the necesarry post-release steps.""" 83 | # First let's get the current version 84 | current_version = get_version() 85 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" 86 | current_version = current_version.base_version 87 | 88 | # Check with the user we got that right. 89 | version = input(f"Which version are we developing now? 
[{dev_version}]") 90 | if len(version) == 0: 91 | version = dev_version 92 | 93 | print(f"Updating version to {version}.") 94 | global_version_update(version) 95 | 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.") 100 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") 101 | args = parser.parse_args() 102 | if not args.post_release: 103 | pre_release_work(patch=args.patch) 104 | elif args.patch: 105 | print("Nothing to do after a patch :-)") 106 | else: 107 | post_release_work() --------------------------------------------------------------------------------