├── tests ├── __init__.py ├── test_standalone_single_label.py ├── test_standalone.py ├── test_spacy_external_single_label.py ├── test_spacy_internal_single_label.py ├── test_standalone_multi_label.py ├── test_spacy_external_multi_label.py ├── test_spacy_external.py ├── test_spacy_internal_multi_label.py ├── test_spacy_external_zero_shot.py └── test_spacy_internal.py ├── classy_classification ├── examples │ ├── __init__.py │ ├── individual_transformer.py │ ├── spacy_few_shot_external.py │ ├── spacy_internal_embeddings.py │ ├── spacy_zero_shot_external.py │ └── data.py ├── classifiers │ ├── __init__.py │ ├── classy_standalone.py │ ├── classy_spacy.py │ └── classy_skeleton.py └── __init__.py ├── logo.png ├── .vscode └── settings.json ├── setup.cfg ├── CITATION.cff ├── .pre-commit-config.yaml ├── LICENSE ├── .github └── workflows │ ├── python-publish.yml │ └── python-package.yml ├── pyproject.toml ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /classy_classification/examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /classy_classification/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/davidberenstein1957/classy-classification/HEAD/logo.png -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestEnabled": true, 3 | "python.linting.flake8Enabled": true, 4 | "python.formatting.provider": "black", 5 | "editor.rulers": [119], 6 | "python.linting.enabled": true, 7 | } 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | max-complexity = 18 4 | docstring-convention=google 5 | exclude = .git,__pycache__,build,dist 6 | select = C,E,F,W,B,B950 7 | ignore = 8 | E203,E266,E501,W503 9 | enable = 10 | W0614 11 | per-file-ignores = 12 | test_*.py: D 13 | -------------------------------------------------------------------------------- /classy_classification/examples/individual_transformer.py: -------------------------------------------------------------------------------- 1 | from classy_classification import ClassyClassifier 2 | 3 | from .data import training_data, validation_data 4 | 5 | classifier = ClassyClassifier(data=training_data) 6 | print(classifier(validation_data[0])) 7 | print(classifier.pipe(validation_data)) 8 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.0.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: Berenstein 5 | given-names: David 6 | title: "Classy Classification - an easy and intuitive approach to few-shot classification using sentence-transformers or spaCy models, or zero-shot classification with Hugging Face." 7 | version: 0.6.0 8 | date-released: 2022-12-31 9 | -------------------------------------------------------------------------------- /classy_classification/examples/spacy_few_shot_external.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | import classy_classification # noqa: F401 4 | 5 | from .data import training_data, validation_data 6 | 7 | nlp = spacy.blank("en") 8 | nlp.add_pipe("classy_classification", config={"data": training_data, "include_sent": True}) 9 | print([sent._.cats for sent in nlp(validation_data[0]).sents]) 10 | print([doc._.cats for doc in nlp.pipe(validation_data)]) 11 | -------------------------------------------------------------------------------- /classy_classification/examples/spacy_internal_embeddings.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | import classy_classification # noqa: F401 4 | 5 | from .data import training_data, validation_data 6 | 7 | nlp = spacy.load("en_core_web_md") 8 | nlp.add_pipe("classy_classification", config={"data": training_data, "model": "spacy", "include_sent": True}) 9 | print([sent._.cats for sent in nlp(validation_data[0]).sents]) 10 | print([doc._.cats for doc in nlp.pipe(validation_data)]) 11 | -------------------------------------------------------------------------------- /classy_classification/examples/spacy_zero_shot_external.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | import classy_classification # noqa: F401 4 | 5 | from .data import training_data, validation_data 6 | 7 | nlp = spacy.blank("en") 8 | nlp.add_pipe( 9 | "classy_classification", config={"data": list(training_data.keys()), "cat_type": "zero", "include_sent": True} 10 | ) 11 | print([sent._.cats for sent in nlp(validation_data[0]).sents]) 12 | print([doc._.cats for doc in nlp.pipe(validation_data)]) 13 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: end-of-file-fixer 6 | - repo: https://github.com/psf/black 7 | rev: 22.3.0 8 | hooks: 9 | - id: black 10 | 11 | # Execute flake8 on all changed files (make sure the version is the same as in pyproject) 12 | - repo: https://github.com/pycqa/flake8 13 | rev: 4.0.1 14 | hooks: 15 | - id: flake8 16 | additional_dependencies: 17 | ["flake8-docstrings", "flake8-bugbear", "pep8-naming"] 18 | -------------------------------------------------------------------------------- /tests/test_standalone_single_label.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from classy_classification import ClassyClassifier 4 | from classy_classification.examples.data import training_data_single_class 5 | 6 | 7 | @pytest.fixture 8 | def standalone_single_label(): 9 | classifier = ClassyClassifier(data=training_data_single_class) 10 | return classifier 11 | 12 | 13 | def test_standalone_single_label(standalone_single_label): 14 | _ = standalone_single_label(training_data_single_class["politics"][0]) 15
| _ = standalone_single_label.pipe(training_data_single_class["politics"]) 16 | -------------------------------------------------------------------------------- /tests/test_standalone.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | 3 | import pytest 4 | 5 | from classy_classification import ClassyClassifier 6 | from classy_classification.examples.data import training_data, validation_data 7 | 8 | 9 | @pytest.fixture 10 | def standalone(): 11 | classifier = ClassyClassifier(data=training_data) 12 | return classifier 13 | 14 | 15 | def test_standalone(standalone): 16 | pred = standalone(validation_data[0]) 17 | assert isclose(sum(pred.values()), 1) 18 | 19 | preds = standalone.pipe(validation_data) 20 | for pred in preds: 21 | assert isclose(sum(pred.values()), 1) 22 | -------------------------------------------------------------------------------- /tests/test_spacy_external_single_label.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import spacy 3 | 4 | from classy_classification.examples.data import training_data_single_class 5 | 6 | 7 | @pytest.fixture 8 | def spacy_external_single_label(): 9 | nlp = spacy.blank("en") 10 | nlp.add_pipe( 11 | "classy_classification", 12 | config={"data": training_data_single_class}, 13 | ) 14 | return nlp 15 | 16 | 17 | def test_spacy_external_single_label(spacy_external_single_label): 18 | _ = spacy_external_single_label(training_data_single_class["politics"][0]) 19 | _ = spacy_external_single_label.pipe(training_data_single_class["politics"]) 20 | -------------------------------------------------------------------------------- /tests/test_spacy_internal_single_label.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import spacy 3 | 4 | from classy_classification.examples.data import training_data_single_class 5 | 6 | 7 | @pytest.fixture(params=["en_core_web_md", "en_core_web_trf"]) 8 | def spacy_internal_single_label(request): 9 | nlp = spacy.load(request.param) 10 | nlp.add_pipe("classy_classification", config={"data": training_data_single_class}) 11 | return nlp 12 | 13 | 14 | def test_spacy_internal_single_label(spacy_internal_single_label): 15 | _ = spacy_internal_single_label(training_data_single_class["politics"][0]) 16 | _ = spacy_internal_single_label.pipe(training_data_single_class["politics"]) 17 | -------------------------------------------------------------------------------- /tests/test_standalone_multi_label.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from classy_classification import ClassyClassifier 4 | from classy_classification.examples.data import ( 5 | training_data_multi_label, 6 | validation_data, 7 | ) 8 | 9 | 10 | @pytest.fixture 11 | def standalone_multi_label(): 12 | classifier = ClassyClassifier(data=training_data_multi_label, multi_label=True) 13 | return classifier 14 | 15 | 16 | def test_standalone_multi_label(standalone_multi_label): 17 | pred = standalone_multi_label(validation_data[0]) 18 | assert pred 19 | 20 | preds = standalone_multi_label.pipe(validation_data) 21 | for pred in preds: 22 | assert pred 23 | -------------------------------------------------------------------------------- /tests/test_spacy_external_multi_label.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import spacy 3 | 4 | from 
classy_classification.examples.data import ( 5 | training_data_multi_label, 6 | validation_data, 7 | ) 8 | 9 | 10 | @pytest.fixture 11 | def spacy_external_multi_label(): 12 | nlp = spacy.blank("en") 13 | nlp.add_pipe( 14 | "classy_classification", 15 | config={"data": training_data_multi_label, "include_sent": True, "multi_label": True}, 16 | ) 17 | return nlp 18 | 19 | 20 | def test_spacy_external_multi_label(spacy_external_multi_label): 21 | doc = spacy_external_multi_label(validation_data[0]) 22 | assert doc._.cats 23 | for sent in doc.sents: 24 | assert sent._.cats 25 | 26 | docs = spacy_external_multi_label.pipe(validation_data) 27 | for doc in docs: 28 | assert doc._.cats 29 | for sent in doc.sents: 30 | assert sent._.cats 31 | -------------------------------------------------------------------------------- /tests/test_spacy_external.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | 3 | import pytest 4 | import spacy 5 | 6 | from classy_classification.examples.data import training_data, validation_data 7 | 8 | 9 | @pytest.fixture 10 | def spacy_external(): 11 | nlp = spacy.blank("en") 12 | nlp.add_pipe( 13 | "classy_classification", 14 | config={ 15 | "data": training_data, 16 | "include_sent": True, 17 | }, 18 | ) 19 | return nlp 20 | 21 | 22 | def test_spacy_external(spacy_external): 23 | doc = spacy_external(validation_data[0]) 24 | assert isclose(sum(doc._.cats.values()), 1) 25 | for sent in doc.sents: 26 | assert isclose(sum(sent._.cats.values()), 1) 27 | 28 | docs = spacy_external.pipe(validation_data) 29 | for doc in docs: 30 | assert isclose(sum(doc._.cats.values()), 1) 31 | for sent in doc.sents: 32 | assert isclose(sum(sent._.cats.values()), 1) 33 | -------------------------------------------------------------------------------- /tests/test_spacy_internal_multi_label.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import spacy 3 | 4 | from classy_classification.examples.data import ( 5 | training_data_multi_label, 6 | validation_data, 7 | ) 8 | 9 | 10 | @pytest.fixture(params=["en_core_web_md", "en_core_web_trf"]) 11 | def spacy_internal_multi_label(request): 12 | nlp = spacy.load(request.param) 13 | nlp.add_pipe( 14 | "classy_classification", 15 | config={"data": training_data_multi_label, "model": "spacy", "include_sent": True, "multi_label": True}, 16 | ) 17 | return nlp 18 | 19 | 20 | def test_spacy_internal_multi_label(spacy_internal_multi_label): 21 | doc = spacy_internal_multi_label(validation_data[0]) 22 | assert doc._.cats 23 | for sent in doc.sents: 24 | assert sent._.cats 25 | 26 | docs = spacy_internal_multi_label.pipe(validation_data) 27 | for doc in docs: 28 | assert doc._.cats 29 | for sent in doc.sents: 30 | assert sent._.cats 31 | -------------------------------------------------------------------------------- /tests/test_spacy_external_zero_shot.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | 3 | import pytest 4 | import spacy 5 | 6 | from classy_classification.examples.data import training_data, validation_data 7 | 8 | 9 | @pytest.fixture 10 | def spacy_external_zero_shot(): 11 | nlp = spacy.blank("en") 12 | nlp.add_pipe( 13 | "classy_classification", config={"data": list(training_data.keys()), "cat_type": "zero", "include_sent": True} 14 | ) 15 | return nlp 16 | 17 | 18 | def test_spacy_external_zero_shot(spacy_external_zero_shot): 19 | doc =
spacy_external_zero_shot(validation_data[0]) 20 | assert isclose(sum(doc._.cats.values()), 1, abs_tol=0.05) 21 | for sent in doc.sents: 22 | assert isclose(sum(sent._.cats.values()), 1, abs_tol=0.05) 23 | 24 | docs = spacy_external_zero_shot.pipe(validation_data) 25 | for doc in docs: 26 | assert isclose(sum(doc._.cats.values()), 1, abs_tol=0.05) 27 | for sent in doc.sents: 28 | assert isclose(sum(sent._.cats.values()), 1, abs_tol=0.05) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 David Berenstein 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation.
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [created] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /tests/test_spacy_internal.py: -------------------------------------------------------------------------------- 1 | from math import isclose 2 | 3 | import pytest 4 | import spacy 5 | 6 | from classy_classification.examples.data import training_data, validation_data 7 | 8 | 9 | @pytest.fixture(params=["en_core_web_md", "en_core_web_trf"]) 10 | def spacy_internal(request): 11 | nlp = spacy.load(request.param) 12 | nlp.add_pipe( 13 | "classy_classification", 14 | config={ 15 | "data": training_data, 16 | "model": "spacy", 17 | "include_sent": True, 18 | }, 19 | ) 20 | return nlp 21 | 22 | 23 | def test_spacy_internal(spacy_internal): 24 | doc = spacy_internal(validation_data[0]) 25 | assert isclose(sum(doc._.cats.values()), 1) 26 | for sent in doc.sents: 27 | assert isclose(sum(sent._.cats.values()), 1) 28 | 29 | docs = spacy_internal.pipe(validation_data) 30 | for doc in docs: 31 | assert isclose(sum(doc._.cats.values()), 1) 32 | for sent in doc.sents: 33 | assert isclose(sum(sent._.cats.values()), 1) 34 | 35 | 36 | # @pytest.fixture 37 | # def spacy_internal_trf(): 38 | # nlp = spacy.load("en_core_web_trf") 39 | # nlp.add_pipe( 40 | # "classy_classification", 41 | # config={ 42 | # "data": training_data, 43 | # "model": "spacy", 44 | # "include_sent": True, 45 | # }, 46 | # ) 47 | # return nlp 48 | 49 | 50 | # def test_spacy_internal_trf(spacy_internal_trf): 51 | # doc = spacy_internal(validation_data[0]) 52 | # assert isclose(sum(doc._.cats.values()), 1) 53 | # for sent in doc.sents: 54 | # assert isclose(sum(sent._.cats.values()), 1) 55 | 56 | # docs = spacy_internal.pipe(validation_data) 57 | # for doc in docs: 58 | # assert isclose(sum(doc._.cats.values()), 1) 59 | # for sent in doc.sents: 60 | # assert isclose(sum(sent._.cats.values()), 1) 61 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | pull_request: 10 | branches: [main] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.8", "3.9", "3.10", "3.11"] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | sudo pip install -U setuptools 29 | 
python -m ensurepip --upgrade 30 | python -m pip install --upgrade pip 31 | python -m pip install --upgrade setuptools 32 | python -m pip install flake8 pytest pytest-cov 33 | python -m pip install "poetry<2.0.0" 34 | poetry export -f requirements.txt -o requirements.txt --without-hashes 35 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 36 | python -m spacy download en_core_web_md 37 | python -m spacy download en_core_web_trf 38 | - name: Lint with flake8 39 | run: | 40 | # stop the build if there are Python syntax errors or undefined names 41 | flake8 . --count --max-complexity=18 --enable=W0614 --select=C,E,F,W,B,B950 --ignore=E203,E266,E501,W503 --exclude=.git,__pycache__,build,dist --max-line-length=119 --show-source --statistics 42 | - name: Test with pytest 43 | run: | 44 | pytest --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html 45 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "classy-classification" 3 | version = "1.0.2" 4 | description = "Have you ever struggled with needing a spaCy TextCategorizer but didn't have the time to train one from scratch? Classy Classification is the way to go!" 5 | authors = ["David Berenstein "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://github.com/davidberenstein1957/classy-classification" 9 | repository = "https://github.com/davidberenstein1957/classy-classification" 10 | documentation = "https://github.com/davidberenstein1957/classy-classification" 11 | keywords = ["spacy", "rasa", "few-shot classification", "nlu", "sentence-transformers"] 12 | classifiers = [ 13 | "Development Status :: 4 - Beta", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Science/Research", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python :: 3.8", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Topic :: Scientific/Engineering", 23 | "Topic :: Software Development" 24 | ] 25 | packages = [{include = "classy_classification"}] 26 | 27 | [tool.poetry.dependencies] 28 | python = ">=3.8,<3.12" 29 | spacy = {extras = ["transformers"], version = "^3.0"} 30 | sentence-transformers = ">2,<4" 31 | scikit-learn = "^1.0" 32 | pandas = ">1" 33 | transformers = {extras = ["torch"], version = ">4.20,<5"} 34 | setuptools = ">65.5.0" 35 | 36 | [tool.poetry.plugins."spacy_factories"] 37 | "spacy" = "classy_classification.__init__:make_text_categorizer" 38 | 39 | [tool.poetry.dev-dependencies] 40 | pytest = "^7.0.1" 41 | flake8 = "^4.0.1" 42 | black = "^22.3.0" 43 | flake8-bugbear = "^22.3.23" 44 | flake8-docstrings = "^1.6.0" 45 | isort = "^5.10.1" 46 | pep8-naming = "^0.12.1" 47 | pre-commit = "^2.17.0" 48 | jupyterlab = "^3.5.2" 49 | ipython = "^8.8.0" 50 | jupyter = "^1.0.0" 51 | ipykernel = "^6.20.1" 52 | 53 | [build-system] 54 | requires = ["poetry-core>=1.0.0"] 55 | build-backend = "poetry.core.masonry.api" 56 | 57 | [tool.pytest.ini_options] 58 | testpaths = "tests" 59 | 60 | [tool.black] 61 | line-length = 119 62 | experimental-string-processing = true 63 | 64 | [tool.isort] 65 | profile = "black" 66 | src_paths = ["classy_classification"] 67 | --------------------------------------------------------------------------------
/classy_classification/examples/data.py: -------------------------------------------------------------------------------- 1 | training_data = { 2 | "politics": [ 3 | "Putin orders troops into pro-Russian regions of eastern Ukraine.", 4 | "The president decided not to go through with his speech.", 5 | "There is much uncertainty surrounding the coming elections.", 6 | "Democrats are engaged in a ‘new politics of evasion’", 7 | ], 8 | "sports": [ 9 | "The soccer team lost.", 10 | "The team won by two against zero.", 11 | "I love all sport.", 12 | "The olympics were amazing.", 13 | "Yesterday, the tennis players wrapped up wimbledon.", 14 | ], 15 | "weather": [ 16 | "It is going to be sunny outside.", 17 | "Heavy rainfall and wind during the afternoon.", 18 | "Clear skies in the morning, but mist in the evening.", 19 | "It is cold during the winter.", 20 | "There is going to be a storm with heavy rainfall.", 21 | ], 22 | } 23 | training_data_multi_label = { 24 | "politics": [ 25 | "Putin orders troops into pro-Russian regions of eastern Ukraine.", 26 | "The president decided not to go through with his speech.", 27 | "There is much uncertainty surrounding the coming elections.", 28 | "Democrats are engaged in a ‘new politics of evasion’", 29 | "The soccer team lost.", 30 | "The team won by two against zero.", 31 | "I love all sport.", 32 | ], 33 | "sports": [ 34 | "Putin orders troops into pro-Russian regions of eastern Ukraine.", 35 | "The president decided not to go through with his speech.", 36 | "There is much uncertainty surrounding the coming elections.", 37 | "Democrats are engaged in a ‘new politics of evasion’", 38 | "The soccer team lost.", 39 | "The team won by two against zero.", 40 | "I love all sport.", 41 | "The olympics were amazing.", 42 | "Yesterday, the tennis players wrapped up wimbledon.", 43 | ], 44 | "weather": [ 45 | "It is going to be sunny outside.", 46 | "Heavy rainfall and wind during the afternoon.", 47 | "Clear skies in the morning, but mist in the evening.", 48 | "It is cold during the winter.", 49 | "There is going to be a storm with heavy rainfall.", 50 | ], 51 | } 52 | training_data_single_class = {"politics": training_data["politics"]} 53 | 54 | validation_data = [ 55 | "I am surely talking about politics.", 56 | "Sports is all you need.", 57 | "The weather is amazing and sunny and cloudy.", 58 | ] 59 | -------------------------------------------------------------------------------- /classy_classification/classifiers/classy_standalone.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import List, Union 3 | 4 | from .classy_spacy import ClassyExternal, ClassySkeletonFewShot 5 | 6 | 7 | class ClassyStandalone(ClassyExternal): 8 | def __call__(self, text: str) -> dict: 9 | """predict the class for an input text 10 | 11 | Args: 12 | text (str): an input text 13 | 14 | Returns: 15 | dict: a key-class proba-value dict 16 | """ 17 | embeddings = self.get_embeddings([text]) 18 | 19 | return self.get_prediction(embeddings)[0] 20 | 21 | def pipe(self, text: List[str]) -> List[dict]: 22 | """retrieve predictions for multiple texts 23 | 24 | Args: 25 | text (List[str]): a list of texts 26 | 27 | Returns: 28 | List[dict]: list of key-class proba-value dict 29 | """ 30 | embeddings = self.get_embeddings(text) 31 | 32 | return self.get_prediction(embeddings) 33 | 34 | 35 | class ClassySentenceTransformer(ClassyStandalone, ClassySkeletonFewShot): 36 | def __init__( 37 | self, 38 | data: dict, 39
model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 40 | device: str = "cpu", 41 | multi_label: bool = False, 42 | config: Union[dict, None] = None, 43 | verbose: bool = False, 44 | ) -> None: 45 | """initialize a classy skeleton for classification using a SVC config and some input training data. 46 | 47 | Args: 48 | data (dict): training data. example 49 | { 50 | "class_1": ["example"], 51 | "class 2": ["example"] 52 | }, 53 | device (str): device "cuda"/"cpu", 54 | config (dict, optional): a SVC config. 55 | example 56 | { 57 | "C": [1, 2, 5, 10, 20, 100], 58 | "kernel": ["linear"], 59 | "max_cross_validation_folds": 5 60 | }. 61 | """ 62 | self.multi_label = multi_label 63 | if isinstance(data, dict): 64 | self.data = collections.OrderedDict(sorted(data.items())) 65 | elif isinstance(data, list): # in case of zero-shot classification 66 | self.data = data 67 | self.model = model 68 | self.device = device 69 | self.verbose = verbose 70 | self.set_embedding_model() 71 | self.set_training_data() 72 | self.set_config(config) 73 | self.set_classification_model() 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Downloaded models 132 | *.model 133 | *.model.* 134 | # OS 135 | **/*.DS_Store 136 | /models 137 | *.onnx 138 | test.py 139 | -------------------------------------------------------------------------------- /classy_classification/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union 3 | 4 | from spacy.language import Language 5 | 6 | from .classifiers.classy_spacy import ( 7 | ClassySpacyExternalFewShot, 8 | ClassySpacyExternalZeroShot, 9 | ClassySpacyInternalFewShot, 10 | ) 11 | from .classifiers.classy_standalone import ClassySentenceTransformer as ClassyClassifier 12 | 13 | __all__ = [ 14 | "ClassyClassifier", 15 | "ClassySpacyExternalFewShot", 16 | "ClassySpacyExternalZeroShot", 17 | "ClassySpacyInternalFewShot", 18 | ] 19 | 20 | logging.captureWarnings(True) 21 | 22 | 23 | @Language.factory( 24 | "classy_classification", 25 | default_config={ 26 | "data": None, 27 | "model": None, 28 | "device": "cpu", 29 | "config": None, 30 | "cat_type": "few", 31 | "multi_label": False, 32 | "include_doc": True, 33 | "include_sent": False, 34 | "verbose": False, 35 | }, 36 | ) 37 | def make_text_categorizer( 38 | nlp: Language, 39 | name: str, 40 | data: Union[dict, list], 41 | device: str = "cpu", 42 | config: dict = None, 43 | model: str = None, 44 | cat_type: str = "few", 45 | multi_label: bool = False, 46 | include_doc: bool = True, 47 | include_sent: bool = False, 48 | verbose: bool = False, 49 | ): 50 | if model == "spacy" and cat_type == "zero": 51 | raise NotImplementedError("Cannot use spacy internal embeddings with zero-shot classification") 52 | elif model == "spacy" and cat_type == "few": 53 | return ClassySpacyInternalFewShot( 54 | nlp=nlp, 55 | name=name, 56 | data=data, 57 | config=config, 58 | include_doc=include_doc, 59 | include_sent=include_sent, 60 | multi_label=multi_label, 61 | verbose=verbose, 62 | ) 63 | 64 | elif model != "spacy" and cat_type == "zero": 65 | return ClassySpacyExternalZeroShot( 66 | nlp=nlp, 67 | name=name, 68 | data=data, 69 | device=device, 70 | model=model, 71 | include_doc=include_doc, 72 | include_sent=include_sent, 73 | multi_label=multi_label, 74 | verbose=verbose, 75 | ) 76 | elif model != "spacy" and cat_type == "few": 77 | return ClassySpacyExternalFewShot( 78 | nlp=nlp, 79 | name=name, 80 | data=data, 81 | device=device, 82 | model=model, 83 | config=config, 84 | include_doc=include_doc, 85 | include_sent=include_sent, 86 | multi_label=multi_label, 87 | verbose=verbose, 88 | ) 89 | else: 90 | raise NotImplementedError( 91 | f"`model` as `{model}` is not valid; it takes arguments `spacy` and `transformer`. " 92 | f"`cat_type` as `{cat_type}` is not valid; it takes arguments `zero` and `few`."
93 | ) 94 | -------------------------------------------------------------------------------- /classy_classification/classifiers/classy_spacy.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import warnings 3 | from typing import List, Union 4 | 5 | import numpy as np 6 | from spacy import __version__, util 7 | from spacy.tokens import Doc 8 | 9 | from .classy_skeleton import ClassyExternal, ClassySkeleton, ClassySkeletonFewShot 10 | 11 | 12 | class ClassySpacy: 13 | def sentence_pipe(self, doc: Doc): 14 | if doc.has_extension("trf_data"): 15 | disable = [comp[0] for comp in self.nlp.components if comp[0] != "transformer"] 16 | texts = [sent.text for sent in doc.sents] 17 | sent_docs = self.nlp.pipe(texts, disable=disable) 18 | else: 19 | sent_docs = [sent.as_doc() for sent in doc.sents] 20 | inferred_sent_docs = self.pipe(iter(sent_docs), include_sent=False) 21 | for sent_doc, sent in zip(inferred_sent_docs, doc.sents): 22 | sent._.cats = sent_doc._.cats 23 | 24 | def __call__(self, doc: Doc): 25 | """ 26 | It takes a doc, gets the embeddings from the doc, reshapes the embeddings, gets the prediction from the embeddings, 27 | and then sets the prediction results for the doc 28 | 29 | :param doc: Doc 30 | :type doc: Doc 31 | :return: The doc object with the predicted categories and the predicted categories for each sentence. 32 | """ 33 | if self.include_doc: 34 | embeddings = self.get_embeddings([doc]) 35 | embeddings = embeddings.reshape(1, -1) 36 | doc._.cats = self.get_prediction(embeddings)[0] 37 | 38 | if self.include_sent: 39 | self.sentence_pipe(doc) 40 | 41 | return doc 42 | 43 | def pipe(self, stream, batch_size=128, include_sent=None): 44 | """ 45 | predict the class for a spacy Doc stream 46 | 47 | Args: 48 | stream (Doc): a spacy doc 49 | 50 | Returns: 51 | Doc: spacy doc with ._.cats key-class proba-value dict 52 | """ 53 | if include_sent is None: 54 | include_sent = self.include_sent 55 | for docs in util.minibatch(stream, size=batch_size): 56 | embeddings = self.get_embeddings(docs) 57 | pred_results = [None] * len(embeddings) # placeholders so every doc is still yielded when include_doc is False 58 | if self.include_doc: 59 | pred_results = self.get_prediction(embeddings) 60 | 61 | for doc, pred_result in zip(docs, pred_results): 62 | if self.include_doc: 63 | doc._.cats = pred_result 64 | if include_sent: 65 | self.sentence_pipe(doc) 66 | 67 | yield doc 68 | 69 | 70 | class ClassySpacyInternal(ClassySpacy): 71 | def get_embeddings(self, docs: Union[List[Doc], List[str]]) -> List[float]: 72 | """Retrieve embeddings from text. 73 | Overwrites function from the classySkeleton that is used to get embeddings for training data to fetch internal 74 | spaCy embeddings. 75 | 76 | Args: 77 | docs (Union[List[Doc], List[str]]): a list of texts or spaCy Docs 78 | 79 | Returns: 80 | List[float]: a list of embeddings 81 | """ 82 | if not ((len(self.nlp.vocab.vectors)) or ("transformer" in self.nlp.component_names)): 83 | raise NotImplementedError( 84 | "internal spaCy embeddings need to be derived from md/lg/trf spaCy models, not from sm models."
85 | ) 86 | 87 | if isinstance(docs, list): 88 | if isinstance(docs[0], str): 89 | docs = self.nlp.pipe(docs, disable=["tagger", "parser", "attribute_ruler", "lemmatizer", "ner"]) 90 | elif isinstance(docs[0], Doc): 91 | pass 92 | else: 93 | raise ValueError("Expected a list of strings or spaCy Docs.") 94 | 95 | embeddings = [] 96 | for doc in docs: 97 | if doc.has_vector: 98 | embeddings.append(doc.vector) 99 | elif doc.has_extension("trf_data"): 100 | # check if the installed spaCy version is 3.7.0 or newer 101 | major, minor, patch = map(int, __version__.split(".")) 102 | is_greater_than_3_7 = (major > 3) or (major == 3 and minor >= 7) 103 | if is_greater_than_3_7: 104 | embeddings.append(doc._.trf_data.all_outputs[0].data[-1]) 105 | else: 106 | embeddings.append(doc._.trf_data.model_output.pooler_output[0]) 107 | else: 108 | warnings.warn( 109 | f"None of the words in the text `{str(doc)}` have vectors. Returning zeros.", stacklevel=1 110 | ) 111 | embeddings.append(np.zeros(self.nlp.vocab.vectors_length)) 112 | return np.array(embeddings) 113 | 114 | 115 | class ClassySpacyInternalFewShot(ClassySpacyInternal, ClassySkeletonFewShot): 116 | def __init__(self, *args, **kwargs): 117 | ClassySkeletonFewShot.__init__(self, *args, **kwargs) 118 | 119 | 120 | class ClassySpacyExternalFewShot(ClassySpacy, ClassyExternal, ClassySkeletonFewShot): 121 | def __init__( 122 | self, 123 | model: str = None, 124 | device: str = "cpu", 125 | *args, 126 | **kwargs, 127 | ): 128 | if model is None: 129 | model = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" 130 | self.model = model 131 | self.device = device 132 | self.set_embedding_model() 133 | ClassySkeletonFewShot.__init__(self, *args, **kwargs) 134 | 135 | 136 | class ClassySpacyExternalZeroShot(ClassySpacy, ClassySkeleton): 137 | def __init__( 138 | self, 139 | model: str = None, 140 | device: str = "cpu", 141 | multi_label: bool = False, 142 | *args, 143 | **kwargs, 144 | ): 145 | if model is None: 146 | model = "typeform/distilbert-base-uncased-mnli" 147 | self.model = model 148 | self.device = device 149 | self.multi_label = multi_label 150 | ClassySkeleton.__init__(self, *args, **kwargs) 151 | 152 | def set_classification_model(self, model: str = None, device: str = None): 153 | """set the zero-shot classification pipeline based on a transformers model name or path 154 | 155 | Args: 156 | model (str, optional): the model name. Defaults to self.model, if no model is provided.
157 | """ 158 | if model: # update if overwritten 159 | self.model = model 160 | if device: 161 | self.device = device 162 | 163 | try: 164 | from optimum.pipelines import pipeline 165 | 166 | if self.device in ["gpu", "cuda", 0]: 167 | self.device = 0 168 | else: 169 | self.device = -1 170 | 171 | self.pipeline = pipeline( 172 | "zero-shot-classification", model=model, device=self.device, top_k=None, accelerator="ort" 173 | ) 174 | except Exception: 175 | from transformers import pipeline 176 | 177 | if self.device in ["gpu", "cuda", 0]: 178 | self.device = 0 179 | else: 180 | self.device = -1 181 | 182 | self.pipeline = pipeline("zero-shot-classification", model=self.model, device=self.device, top_k=None) 183 | 184 | def set_config(self, _: dict = None): 185 | """Zero-shot models don't require a config""" 186 | pass 187 | 188 | def set_training_data(self, _: dict = None): 189 | """Zero-shot models don't require training data""" 190 | pass 191 | 192 | def set_embedding_model(self, _: dict = None): 193 | """Zero-shot models don't require embeddings models""" 194 | pass 195 | 196 | def get_embeddings(self, docs: Union[List[Doc], List[str]]): 197 | """Zero-shot models don't require embeddings""" 198 | pass 199 | 200 | def format_prediction(self, prediction): 201 | """ 202 | It takes a prediction dictionary and returns a list of dictionaries, where each dictionary has a single key-value 203 | pair 204 | 205 | :param prediction: The prediction returned by the model 206 | :return: A list of dictionaries. 207 | """ 208 | if importlib.util.find_spec("fast-sentence-transformers") is None: 209 | return {pred[0]: pred[1] for pred in zip(prediction.get("labels"), prediction.get("scores"))} 210 | else: 211 | return {self.data[pred[0]]: pred[1] for pred in prediction} 212 | 213 | def set_pred_results_for_doc(self, doc: Doc): 214 | """ 215 | It takes a spaCy Doc object, runs it through the pipeline, and then adds the predictions to the Doc object 216 | 217 | :param doc: Doc 218 | :type doc: Doc 219 | :return: A list of dictionaries. 
220 | """ 221 | pred_results = self.pipeline([sent.text for sent in list(doc.sents)], self.data) 222 | pred_results = [self.format_prediction(pred) for pred in pred_results] 223 | for sent, pred in zip(doc.sents, pred_results): 224 | sent._.cats = pred 225 | return doc 226 | 227 | def __call__(self, doc: Doc) -> Doc: 228 | """ 229 | predict the class for a spacy Doc 230 | 231 | Args: 232 | doc (Doc): a spacy doc 233 | 234 | Returns: 235 | Doc: spacy doc with ._.cats key-class proba-value dict 236 | """ 237 | if self.include_doc: 238 | pred_result = self.pipeline(doc.text, self.data, multi_label=self.multi_label) 239 | doc._.cats = self.format_prediction(pred_result) 240 | if self.include_sent: 241 | self.sentence_pipe(doc) 242 | 243 | return doc 244 | 245 | def pipe(self, stream, batch_size=128, include_sent=None): 246 | """ 247 | predict the class for a spacy Doc stream 248 | 249 | Args: 250 | stream (Doc): a spacy doc 251 | 252 | Returns: 253 | Doc: spacy doc with ._.cats key-class proba-value dict 254 | """ 255 | if include_sent is None: 256 | include_sent = self.include_sent 257 | for docs in util.minibatch(stream, size=batch_size): 258 | predictions = [doc.text for doc in docs] 259 | if self.include_doc: 260 | predictions = self.pipeline(predictions, self.data, multi_label=self.multi_label) 261 | predictions = [self.format_prediction(pred) for pred in predictions] 262 | for doc, pred_result in zip(docs, predictions): 263 | if self.include_doc: 264 | doc._.cats = pred_result 265 | if include_sent: 266 | self.sentence_pipe(doc) 267 | 268 | yield doc 269 | -------------------------------------------------------------------------------- /classy_classification/classifiers/classy_skeleton.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sentence_transformers import SentenceTransformer 7 | from sklearn import preprocessing 8 | from sklearn.model_selection import GridSearchCV 9 | from sklearn.multiclass import OneVsRestClassifier 10 | from sklearn.svm import SVC, OneClassSVM 11 | from spacy.language import Language 12 | from spacy.tokens import Doc, Span 13 | 14 | 15 | class ClassySkeleton: 16 | def __init__( 17 | self, 18 | nlp: Language, 19 | name: str, 20 | data: dict, 21 | include_doc: bool = True, 22 | include_sent: bool = False, 23 | multi_label: bool = False, 24 | config: Union[dict, None] = None, 25 | verbose: bool = True, 26 | ) -> None: 27 | """initialize a classy skeleton for classification using a SVC config and some input training data. 28 | 29 | Args: 30 | data (dict): training data. example 31 | { 32 | "class_1": ["example"], 33 | "class 2": ["example"] 34 | }, 35 | device (str): device "cuda"/"cpu", 36 | config (_type_, optional): a SVC config. 37 | example 38 | { 39 | "C": [1, 2, 5, 10, 20, 100], 40 | "kernel": ["linear"], 41 | "max_cross_validation_folds": 5 42 | }. 
43 | """ 44 | 45 | self.multi_label = multi_label 46 | 47 | if isinstance(data, dict): 48 | self.data = collections.OrderedDict(sorted(data.items())) 49 | elif isinstance(data, list): # in case of zero-shot classification 50 | self.data = data 51 | 52 | self.name = name 53 | self.nlp = nlp 54 | self.verbose = verbose 55 | self.include_doc = include_doc 56 | self.include_sent = include_sent 57 | if include_sent: 58 | Span.set_extension("cats", default=None, force=True) 59 | if "sentencizer" not in nlp.pipe_names: 60 | nlp.add_pipe("sentencizer") 61 | if include_doc: 62 | Doc.set_extension("cats", default=None, force=True) 63 | self.set_training_data() 64 | self.set_config(config) 65 | self.set_classification_model() 66 | 67 | def set_training_data(self): 68 | """Overwritten by super class""" 69 | raise NotImplementedError("Needs to be overwritten by superclass") 70 | 71 | def set_config(self): 72 | """Overwritten by super class""" 73 | raise NotImplementedError("Needs to be overwritten by superclass") 74 | 75 | def set_classification_model(self): 76 | """Overwritten by super class""" 77 | raise NotImplementedError("Needs to be overwritten by superclass") 78 | 79 | def get_embeddings(self): 80 | """Overwritten by super class""" 81 | raise NotImplementedError("Needs to be overwritten by superclass") 82 | 83 | def __call__(self): 84 | """Overwritten by super class""" 85 | raise NotImplementedError("Needs to be overwritten by superclass") 86 | 87 | def pipe(self): 88 | """Overwritten by super class""" 89 | raise NotImplementedError("Needs to be overwritten by superclass") 90 | 91 | def get_prediction(self, embeddings: List[List]) -> List[dict]: 92 | """get the predicitons for a list om embeddings 93 | 94 | Args: 95 | embeddings (List[List]): a list of text embeddings. 96 | 97 | Returns: 98 | List[dict]: list of key-class proba-value dict 99 | """ 100 | if len(self.label_list) > 1: 101 | pred_result = self.clf.predict_proba(embeddings) 102 | pred_result = self.proba_to_dict(pred_result) 103 | else: 104 | pred_result = self.clf.predict(embeddings) 105 | label = self.label_list[0] 106 | pred_result = [ 107 | {label: 1, f"not_{label}": 0} if pred == 1 else {label: 0, f"not_{label}": 1} for pred in pred_result 108 | ] 109 | return pred_result 110 | 111 | 112 | class ClassySkeletonFewShot(ClassySkeleton): 113 | def set_config(self, config: Union[dict, None] = None): 114 | """ 115 | > This function sets the config attribute of the class to the config parameter if the config parameter is not None, 116 | otherwise it sets the config attribute to a default value 117 | 118 | :param config: A dictionary of parameters to be used in the SVM 119 | :type config: Union[dict, None] 120 | """ 121 | 122 | if config is None: 123 | if len(self.label_list) > 1: 124 | config = { 125 | "C": [1, 2, 5, 10, 20, 50, 100], 126 | "kernel": ["linear", "rbf", "poly", "sigmoid"], 127 | "max_cross_validation_folds": 5, 128 | "seed": None, 129 | } 130 | else: 131 | config = { 132 | "nu": 0.1, 133 | "kernel": "rbf", 134 | } 135 | 136 | self.config = config 137 | 138 | def set_classification_model(self, config: dict = None): 139 | """Set and fit the SVC model. 140 | 141 | Args: 142 | config (dict, optional): A config containing keys for SVC kernels, C, max_cross_validation_folds. 143 | Defaults to None if self.config needs to be used. 
144 | """ 145 | if config: # update if overwritten 146 | self.config = config 147 | 148 | if len(self.label_list) > 1: 149 | self.svm = SVC( 150 | probability=True, 151 | class_weight="balanced", 152 | verbose=self.verbose, 153 | random_state=self.config.get("seed"), 154 | ) 155 | 156 | # NOTE: consifer using multi_target_strategy "one-vs-one", "one-vs-rest", "output-code" 157 | if self.multi_label: 158 | self.svm = OneVsRestClassifier(self.svm) 159 | param_addition = "estimator__" 160 | cv_splits = None 161 | else: 162 | param_addition = "" 163 | folds = self.config["max_cross_validation_folds"] 164 | cv_splits = max(2, min(folds, np.min(np.bincount(self.y)) // 5)) 165 | 166 | tuned_parameters = [ 167 | { 168 | f"{param_addition}{key}": value 169 | for key, value in self.config.items() 170 | if key not in ["random_state", "max_cross_validation_folds", "seed"] 171 | } 172 | ] 173 | 174 | self.clf = GridSearchCV( 175 | self.svm, 176 | param_grid=tuned_parameters, 177 | cv=cv_splits, 178 | scoring="f1_weighted", 179 | verbose=self.verbose, 180 | ) 181 | self.clf.fit(self.X, self.y) 182 | elif len(self.label_list) == 1: 183 | if self.multi_label: 184 | raise ValueError("Cannot apply one class classification with multiple-labels.") 185 | self.clf = OneClassSVM(verbose=self.verbose, **self.config) 186 | self.clf.fit(self.X) 187 | else: 188 | raise ValueError("Provide input data with Dict[key, List].") 189 | 190 | def proba_to_dict(self, pred_results: List[List]) -> List[dict]: 191 | """converts probability prediciton to a formatted key-class proba-value list 192 | 193 | Args: 194 | pred_results (_List[List]): a list of prediction probabilities. 195 | 196 | Returns: 197 | List[dict]: list of key-class proba-value dict 198 | """ 199 | 200 | pred_dict = [] 201 | for pred in pred_results: 202 | pred_dict.append({label: value for label, value in zip(self.label_list, pred)}) 203 | 204 | return pred_dict 205 | 206 | def set_training_data(self, data: dict = None): 207 | """_summary_ 208 | 209 | Args: 210 | data (dict, optional): a dict containing category keys and lists ov example values. 211 | Defaults to None if self.data needs to be used. 
212 | """ 213 | if data: # update if overwritten 214 | self.data = data 215 | 216 | labels = [] 217 | X = [] 218 | self.label_list = list(self.data.keys()) 219 | for key, value in self.data.items(): 220 | labels += len(value) * [key] 221 | X += value 222 | 223 | if self.multi_label: 224 | df = pd.DataFrame(data={"X": X, "labels": labels}) 225 | groups = df.groupby("X").agg(list).to_records().tolist() 226 | X = [group[0] for group in groups] 227 | labels = [group[1] for group in groups] 228 | self.le = preprocessing.MultiLabelBinarizer() 229 | else: 230 | self.le = preprocessing.LabelEncoder() 231 | 232 | self.y = self.le.fit_transform(labels) 233 | 234 | self.X = self.get_embeddings(X) 235 | 236 | if data: # update if overwritten 237 | self.set_classification_model() 238 | 239 | 240 | class ClassyExternal: 241 | def get_embeddings(self, docs: Union[List[Doc], List[str]]) -> List[List[float]]: 242 | """retrieve embeddings from the SentenceTransformer model for a text or list of texts 243 | 244 | Args: 245 | X (List[str]): input texts 246 | 247 | Returns: 248 | List[List[float]]: output embeddings 249 | """ 250 | # inputs = self.tokenizer(X, padding=True, truncation=True, max_length=512, return_tensors="pt") 251 | # ort_inputs = {k: v.cpu().numpy() for k, v in inputs.items()} 252 | 253 | # return self.session.run(None, ort_inputs)[0] 254 | docs = list(docs) 255 | if isinstance(docs, list): 256 | if isinstance(docs[0], str): 257 | pass 258 | elif isinstance(docs[0], Doc): 259 | docs = [doc.text for doc in docs] 260 | else: 261 | raise ValueError("This should be a List") 262 | 263 | return self.encoder.encode(docs, show_progress_bar=self.verbose) 264 | 265 | def set_embedding_model(self, model: str = None, device: str = "cpu"): 266 | """set the embedding model based on a sentencetransformer model or path 267 | 268 | Args: 269 | model (str, optional): the model name. Defaults to self.model, if no model is provided. 270 | """ 271 | if model: # update if overwritten 272 | self.model = model 273 | if device: 274 | self.device = device 275 | 276 | if self.device in ["gpu", "cuda", 0]: 277 | self.device = None # If None, checks if a GPU can be used. 278 | else: 279 | self.device = "cpu" 280 | self.encoder = SentenceTransformer(self.model, device=self.device) 281 | 282 | if model: # update if overwritten 283 | self.set_training_data() 284 | self.set_classification_model() 285 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Classy Classification 2 | Have you ever struggled with needing a [Spacy TextCategorizer](https://spacy.io/api/textcategorizer) but didn't have the time to train one from scratch? Classy Classification is the way to go! For few-shot classification using [sentence-transformers](https://github.com/UKPLab/sentence-transformers) or [spaCy models](https://spacy.io/usage/models), provide a dictionary with labels and examples, or just provide a list of labels for zero shot-classification with [Hugginface zero-shot classifiers](https://huggingface.co/models?pipeline_tag=zero-shot-classification). 
3 | 4 | [![Current Release Version](https://img.shields.io/github/release/pandora-intelligence/classy-classification.svg?style=flat-square&logo=github)](https://github.com/pandora-intelligence/classy-classification/releases) 5 | [![pypi Version](https://img.shields.io/pypi/v/classy-classification.svg?style=flat-square&logo=pypi&logoColor=white)](https://pypi.org/project/classy-classification/) 6 | [![PyPi downloads](https://static.pepy.tech/personalized-badge/classy-classification?period=total&units=international_system&left_color=grey&right_color=orange&left_text=pip%20downloads)](https://pypi.org/project/classy-classification/) 7 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/ambv/black) 8 | 9 | # Install 10 | ``` pip install classy-classification``` 11 | 12 | ## SetFit support 13 | 14 | I got a lot of requests for SetFit support, but I decided to create a [separate package](https://github.com/davidberenstein1957/spacy-setfit) for this. Feel free to check it out. ❤️ 15 | 16 | # Quickstart 17 | ## SpaCy embeddings 18 | ```python 19 | import spacy 20 | # or import standalone 21 | # from classy_classification import ClassyClassifier 22 | 23 | data = { 24 | "furniture": ["This text is about chairs.", 25 | "Couches, benches and televisions.", 26 | "I really need to get a new sofa."], 27 | "kitchen": ["There also exist things like fridges.", 28 | "I hope to be getting a new stove today.", 29 | "Do you also have some ovens."] 30 | } 31 | 32 | nlp = spacy.load("en_core_web_trf") 33 | nlp.add_pipe( 34 | "classy_classification", 35 | config={ 36 | "data": data, 37 | "model": "spacy" 38 | } 39 | ) 40 | 41 | print(nlp("I am looking for kitchen appliances.")._.cats) 42 | 43 | # Output: 44 | # 45 | # {"furniture": 0.21, "kitchen": 0.79} 46 | ``` 47 | ### Sentence level classification 48 | ```python 49 | import spacy 50 | 51 | data = { 52 | "furniture": ["This text is about chairs.", 53 | "Couches, benches and televisions.", 54 | "I really need to get a new sofa."], 55 | "kitchen": ["There also exist things like fridges.", 56 | "I hope to be getting a new stove today.", 57 | "Do you also have some ovens."] 58 | } 59 | 60 | nlp.add_pipe( 61 | "classy_classification", 62 | config={ 63 | "data": data, 64 | "model": "spacy", 65 | "include_sent": True 66 | } 67 | ) 68 | 69 | print(list(nlp("I am looking for kitchen appliances. And I love doing so.").sents)[0]._.cats) 70 | 71 | # Output: 72 | # 73 | # {"furniture": 0.21, "kitchen": 0.79} 74 | ``` 75 | 76 | ### Define random seed and verbosity 77 | 78 | ```python 79 | 80 | nlp.add_pipe( 81 | "classy_classification", 82 | config={ 83 | "data": data, 84 | "verbose": True, 85 | "config": {"seed": 42} 86 | } 87 | ) 88 | ``` 89 | 90 | ### Multi-label classification 91 | 92 | Sometimes multiple labels are necessary to fully describe the contents of a text. In that case, we want to make use of the **multi-label** implementation, where the sum of label scores is not limited to 1. Just pass the same training data to multiple keys.
93 | 94 | ```python 95 | import spacy 96 | 97 | data = { 98 | "furniture": ["This text is about chairs.", 99 | "Couches, benches and televisions.", 100 | "I really need to get a new sofa.", 101 | "We have a new dinner table.", 102 | "There also exist things like fridges.", 103 | "I hope to be getting a new stove today.", 104 | "Do you also have some ovens.", 105 | "We have a new dinner table."], 106 | "kitchen": ["There also exist things like fridges.", 107 | "I hope to be getting a new stove today.", 108 | "Do you also have some ovens.", 109 | "We have a new dinner table.", 110 | "There also exist things like fridges.", 111 | "I hope to be getting a new stove today.", 112 | "Do you also have some ovens.", 113 | "We have a new dinner table."] 114 | } 115 | 116 | nlp = spacy.load("en_core_web_md") 117 | nlp.add_pipe( 118 | "classy_classification", 119 | config={ 120 | "data": data, 121 | "model": "spacy", 122 | "multi_label": True, 123 | } 124 | ) 125 | 126 | print(nlp("I am looking for furniture and kitchen equipment.")._.cats) 127 | 128 | # Output: 129 | # 130 | # {"furniture": 0.92, "kitchen": 0.91} 131 | ``` 132 | 133 | ### Outlier detection 134 | 135 | Sometimes it is worthwhile to be able to do outlier detection or binary classification. This can either be approached using 136 | a binary training dataset; however, I have also implemented support for a `OneClassSVM` for [outlier detection using a single label](https://scikit-learn.org/stable/modules/generated/sklearn.svm.OneClassSVM.html). Note that this method does not return probabilities; the data is still formatted as label-score value pairs to ensure uniformity. 137 | 138 | Approach 1: 139 | 140 | ```python 141 | import spacy 142 | 143 | data_binary = { 144 | "inlier": ["This text is about chairs.", 145 | "Couches, benches and televisions.", 146 | "I really need to get a new sofa."], 147 | "outlier": ["Text about kitchen equipment", 148 | "This text is about politics", 149 | "Comments about AI and stuff."] 150 | } 151 | 152 | nlp = spacy.load("en_core_web_md") 153 | nlp.add_pipe( 154 | "classy_classification", 155 | config={ 156 | "data": data_binary, 157 | } 158 | ) 159 | 160 | print(nlp("This text is a random text")._.cats) 161 | 162 | # Output: 163 | # 164 | # {'inlier': 0.2926672385488411, 'outlier': 0.707332761451159} 165 | ``` 166 | 167 | Approach 2: 168 | 169 | ```python 170 | import spacy 171 | 172 | data_singular = { 173 | "furniture": ["This text is about chairs.", 174 | "Couches, benches and televisions.", 175 | "I really need to get a new sofa.", 176 | "We have a new dinner table."] 177 | } 178 | nlp = spacy.load("en_core_web_md") 179 | nlp.add_pipe( 180 | "classy_classification", 181 | config={ 182 | "data": data_singular, 183 | } 184 | ) 185 | 186 | print(nlp("This text is a random text")._.cats) 187 | 188 | # Output: 189 | # 190 | # {'furniture': 0, 'not_furniture': 1} 191 | ``` 192 | 193 | ## Sentence-transformer embeddings 194 | 195 | ```python 196 | import spacy 197 | 198 | data = { 199 | "furniture": ["This text is about chairs.", 200 | "Couches, benches and televisions.", 201 | "I really need to get a new sofa."], 202 | "kitchen": ["There also exist things like fridges.", 203 | "I hope to be getting a new stove today.", 204 | "Do you also have some ovens."] 205 | } 206 | 207 | nlp = spacy.blank("en") 208 | nlp.add_pipe( 209 | "classy_classification", 210 | config={ 211 | "data": data, 212 | "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 213 | "device": "gpu" 214 | } 215 | ) 216 |
217 | print(nlp("I am looking for kitchen appliances.")._.cats) 218 | 219 | # Output: 220 | # 221 | # {"furniture": 0.21, "kitchen": 0.79} 222 | ``` 223 | 224 | ## Hugging Face zero-shot classifiers 225 | 226 | ```python 227 | import spacy 228 | 229 | data = ["furniture", "kitchen"] 230 | 231 | nlp = spacy.blank("en") 232 | nlp.add_pipe( 233 | "classy_classification", 234 | config={ 235 | "data": data, 236 | "model": "typeform/distilbert-base-uncased-mnli", 237 | "cat_type": "zero", 238 | "device": "gpu" 239 | } 240 | ) 241 | 242 | print(nlp("I am looking for kitchen appliances.")._.cats) 243 | 244 | # Output: 245 | # 246 | # {"furniture": 0.21, "kitchen": 0.79} 247 | ``` 248 | 249 | # Credits 250 | 251 | ## Inspiration Drawn From 252 | 253 | [Hugging Face](https://huggingface.co/) does offer some nice models for few/zero-shot classification, but these are not tailored to multi-lingual approaches. Rasa NLU has [a nice approach](https://rasa.com/blog/rasa-nlu-in-depth-part-1-intent-classification/) for this, but it's too embedded in their codebase for easy usage outside of Rasa/chatbots. Additionally, it made sense to integrate [sentence-transformers](https://github.com/UKPLab/sentence-transformers) and [Hugging Face zero-shot models](https://huggingface.co/models?pipeline_tag=zero-shot-classification), instead of default [word embeddings](https://arxiv.org/abs/1301.3781). Finally, I decided to integrate with spaCy, since training a custom [spaCy TextCategorizer](https://spacy.io/api/textcategorizer) seems like a lot of hassle if you want something quick and dirty. 254 | 255 | - [Scikit-learn](https://github.com/scikit-learn/scikit-learn) 256 | - [Rasa NLU](https://github.com/RasaHQ/rasa) 257 | - [Sentence Transformers](https://github.com/UKPLab/sentence-transformers) 258 | - [spaCy](https://github.com/explosion/spaCy) 259 | 260 | ## Or buy me a coffee 261 | 262 | [!["Buy Me A Coffee"](https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png)](https://www.buymeacoffee.com/98kf2552674) 263 | 264 | # Standalone usage without spaCy 265 | 266 | ```python 267 | 268 | from classy_classification import ClassyClassifier 269 | 270 | data = { 271 | "furniture": ["This text is about chairs.", 272 | "Couches, benches and televisions.", 273 | "I really need to get a new sofa."], 274 | "kitchen": ["There also exist things like fridges.", 275 | "I hope to be getting a new stove today.", 276 | "Do you also have some ovens."] 277 | } 278 | 279 | classifier = ClassyClassifier(data=data) 280 | classifier("I am looking for kitchen appliances.") 281 | classifier.pipe(["I am looking for kitchen appliances."]) 282 | 283 | # overwrite training data 284 | classifier.set_training_data(data=data) 285 | classifier("I am looking for kitchen appliances.") 286 | 287 | # overwrite [embedding model](https://www.sbert.net/docs/pretrained_models.html) 288 | classifier.set_embedding_model(model="paraphrase-MiniLM-L3-v2") 289 | classifier("I am looking for kitchen appliances.") 290 | 291 | # overwrite SVC config 292 | classifier.set_classification_model( 293 | config={ 294 | "C": [1, 2, 5, 10, 20, 100], 295 | "kernel": ["linear"], 296 | "max_cross_validation_folds": 5 297 | } 298 | ) 299 | classifier("I am looking for kitchen appliances.") 300 | ``` 301 | 302 | ## Save and load models 303 | 304 | ```python 305 | import pickle 306 | 307 | data = { 308 | "furniture": ["This text is about chairs.", 309 | "Couches, benches and televisions.", 310 | "I really need to get a new sofa."], 311 | "kitchen": ["There also exist things like fridges.", 312 |
"I hope to be getting a new stove today.", 311 | "Do you also have some ovens."] 312 | } 313 | classifier = classyClassifier(data=data) 314 | 315 | with open("./classifier.pkl", "wb") as f: 316 | pickle.dump(classifier, f) 317 | 318 | f = open("./classifier.pkl", "rb") 319 | classifier = pickle.load(f) 320 | classifier("I am looking for kitchen appliances.") 321 | ``` 322 | --------------------------------------------------------------------------------