├── tests
│   ├── __init__.py
│   ├── test_default.py
│   ├── data
│   │   └── thiscatdoesnotexist.jpeg
│   ├── test_base.py
│   ├── test_docs.py
│   └── test_vision.py
├── docs
│   ├── API
│   │   ├── grab.md
│   │   ├── vision
│   │   │   ├── timm.md
│   │   │   ├── imageload.md
│   │   │   └── colorhist.md
│   │   └── text
│   │       ├── sentence-enc.md
│   │       └── sense2vec.md
│   ├── images
│   │   ├── icon.png
│   │   ├── timm.png
│   │   ├── sense2vec.png
│   │   ├── imageloader.png
│   │   ├── colorhistogram.png
│   │   ├── columngrabber.png
│   │   └── sentence-encoder.png
│   └── index.md
├── .flake8
├── embetter
│   ├── __init__.py
│   ├── base.py
│   ├── vision
│   │   ├── __init__.py
│   │   ├── _colorhist.py
│   │   ├── _torchvis.py
│   │   └── _loader.py
│   ├── text
│   │   ├── __init__.py
│   │   ├── _s2v.py
│   │   └── _sbert.py
│   ├── error.py
│   └── grab.py
├── mkdocs.yml
├── .github
│   └── workflows
│       └── unittest.yml
├── Makefile
├── LICENCE
├── setup.py
├── .gitignore
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/API/grab.md:
--------------------------------------------------------------------------------
1 | # ColumnGrabber
2 |
3 | ::: embetter.grab.ColumnGrabber
--------------------------------------------------------------------------------
/docs/API/vision/timm.md:
--------------------------------------------------------------------------------
1 | # TimmEncoder
2 |
3 | ::: embetter.vision.TimmEncoder
--------------------------------------------------------------------------------
/docs/API/vision/imageload.md:
--------------------------------------------------------------------------------
1 | # ImageLoader
2 |
3 | ::: embetter.vision.ImageLoader
4 |
--------------------------------------------------------------------------------
/docs/images/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/icon.png
--------------------------------------------------------------------------------
/docs/images/timm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/timm.png
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 160
3 | ignore = E501, C901
4 | extend-ignore = E203, W503
--------------------------------------------------------------------------------
/tests/test_default.py:
--------------------------------------------------------------------------------
1 | def test_true():
2 | """kind of a no-op test"""
3 | assert True
4 |
--------------------------------------------------------------------------------
/docs/images/sense2vec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/sense2vec.png
--------------------------------------------------------------------------------
/docs/API/vision/colorhist.md:
--------------------------------------------------------------------------------
1 | # ColorHistogramEncoder
2 |
3 | ::: embetter.vision.ColorHistogramEncoder
4 |
--------------------------------------------------------------------------------
/docs/images/imageloader.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/imageloader.png
--------------------------------------------------------------------------------
/docs/images/colorhistogram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/colorhistogram.png
--------------------------------------------------------------------------------
/docs/images/columngrabber.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/columngrabber.png
--------------------------------------------------------------------------------
/docs/images/sentence-encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/sentence-encoder.png
--------------------------------------------------------------------------------
/tests/data/thiscatdoesnotexist.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lmcinnes/embetter/main/tests/data/thiscatdoesnotexist.jpeg
--------------------------------------------------------------------------------
/docs/API/text/sentence-enc.md:
--------------------------------------------------------------------------------
1 | # SentenceEncoder
2 |
3 | ## `embetter.text.SentenceEncoder`
4 |
5 | ::: embetter.text.SentenceEncoder
--------------------------------------------------------------------------------
/docs/API/text/sense2vec.md:
--------------------------------------------------------------------------------
1 | # Sense2VecEncoder
2 |
3 | ## `embetter.text.Sense2VecEncoder`
4 |
5 | ::: embetter.text.Sense2VecEncoder
6 |
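A minimal usage sketch; the path below is a hypothetical location of a downloaded
sense2vec model (for example `s2v_reddit_2015_md`), and the sense-tagged keys are
assumptions about what such a model contains.

```python
from embetter.text import Sense2VecEncoder

# Hypothetical path to a locally downloaded sense2vec model.
encoder = Sense2VecEncoder(path="path/to/s2v_reddit_2015_md")

# sense2vec keys are typically sense-tagged, e.g. "<phrase>|NOUN".
encoder.transform(["duck|NOUN", "natural language processing|NOUN"])
```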
--------------------------------------------------------------------------------
/embetter/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from importlib import metadata
3 | except ImportError: # for Python<3.8
4 | import importlib_metadata as metadata
5 |
6 |
7 | __title__ = __name__
8 | __version__ = metadata.version(__title__)
9 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Embetter Docs
2 | repo_url: https://github.com/koaning/embetter
3 | plugins:
4 | - mkdocstrings:
5 | custom_templates: templates
6 | theme:
7 | name: material
8 | logo: images/icon.png
9 | palette:
10 | primary: white
11 | markdown_extensions:
12 | - pymdownx.highlight:
13 | use_pygments: true
14 | - pymdownx.superfences
15 |
16 |
--------------------------------------------------------------------------------
/tests/test_base.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from embetter.grab import ColumnGrabber
3 |
4 |
5 | def test_grab_column():
6 | """Ensure that we can grab a text column."""
7 | data = [{"text": "hi", "foo": 1}, {"text": "yes", "foo": 2}]
8 | dataframe = pd.DataFrame(data)
9 | out = ColumnGrabber("text").fit_transform(dataframe)
10 | assert out == ["hi", "yes"]
11 |
--------------------------------------------------------------------------------
/embetter/base.py:
--------------------------------------------------------------------------------
1 | from sklearn.base import BaseEstimator, TransformerMixin
2 |
3 |
4 | class EmbetterBase(BaseEstimator, TransformerMixin):
5 | """Base class for feature transformers in this library"""
6 |
7 | def fit(self, X, y=None):
8 | """No-op."""
9 | return self
10 |
11 | def partial_fit(self, X, y=None):
12 | """No-op."""
13 | return self
14 |
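# A minimal sketch of a hypothetical subclass: only `transform` needs to be
# implemented, since the no-op `fit`/`partial_fit` come from EmbetterBase.
#
# class UpperCaser(EmbetterBase):
#     """Uppercase every string in a list."""
#
#     def transform(self, X, y=None):
#         return [x.upper() for x in X]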
--------------------------------------------------------------------------------
/embetter/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from embetter.error import NotInstalled
2 | from embetter.vision._loader import ImageLoader
3 | from embetter.vision._colorhist import ColorHistogramEncoder
4 |
5 | try:
6 | from embetter.vision._torchvis import TimmEncoder
7 | except ModuleNotFoundError:
8 | TimmEncoder = NotInstalled("TimmEncoder", "vision")
9 |
10 |
11 | __all__ = ["ImageLoader", "ColorHistogramEncoder", "TimmEncoder"]
12 |
--------------------------------------------------------------------------------
/embetter/text/__init__.py:
--------------------------------------------------------------------------------
1 | from embetter.error import NotInstalled
2 |
3 | try:
4 | from embetter.text._sbert import SentenceEncoder
5 | except ModuleNotFoundError:
6 |     SentenceEncoder = NotInstalled("SentenceEncoder", "sentence-tfm")
7 |
8 | try:
9 | from embetter.text._s2v import Sense2VecEncoder
10 | except ModuleNotFoundError:
11 |     Sense2VecEncoder = NotInstalled("Sense2VecEncoder", "sense2vec")
12 |
13 |
14 | __all__ = ["SentenceEncoder", "Sense2VecEncoder"]
15 |
--------------------------------------------------------------------------------
/tests/test_docs.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from mktestdocs import check_md_file, check_docstring
3 | from embetter.vision import ColorHistogramEncoder, TimmEncoder, ImageLoader
4 | from embetter.text import Sense2VecEncoder, SentenceEncoder
5 | from embetter.grab import ColumnGrabber
6 |
7 |
8 | def test_readme():
9 | """Readme needs to be accurate"""
10 | check_md_file(fpath="README.md")
11 |
12 |
13 | objects = [
14 | ColumnGrabber,
15 | SentenceEncoder,
16 | Sense2VecEncoder,
17 | ColorHistogramEncoder,
18 | TimmEncoder,
19 | ImageLoader,
20 | ]
21 |
22 |
23 | @pytest.mark.parametrize("func", objects, ids=lambda d: d.__name__)
24 | def test_docstring(func):
25 | """Check the docstrings of the components"""
26 | check_docstring(obj=func)
27 |
--------------------------------------------------------------------------------
/embetter/text/_s2v.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sense2vec import Sense2Vec
3 |
4 | from embetter.base import EmbetterBase
5 |
6 |
7 | class Sense2VecEncoder(EmbetterBase):
8 | """
9 | Create a [Sense2Vec encoder](https://github.com/explosion/sense2vec), meant to
10 | help when encoding phrases as opposed to sentences.
11 |
12 | 
13 |
14 | Arguments:
15 | path: path to downloaded model
16 | """
17 |
18 | def __init__(self, path):
19 | self.s2v = Sense2Vec().from_disk(path)
20 |
21 | def transform(self, X, y=None):
22 | """Transforms the phrase text into a numeric representation."""
23 | return np.array([self.s2v[x] for x in X])
24 |
--------------------------------------------------------------------------------
/embetter/error.py:
--------------------------------------------------------------------------------
1 | class NotInstalled:
2 | """
3 | This object is used for optional dependencies. If a backend is not installed we
4 | replace the transformer/language with this object. This allows us to give a friendly
5 | message to the user that they need to install extra dependencies as well as a link
6 | to our documentation page.
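
A sketch of the import-guard pattern, mirroring how the subpackage `__init__`
modules in this repository use the class:

```python
from embetter.error import NotInstalled

try:
    from embetter.text._sbert import SentenceEncoder
except ModuleNotFoundError:
    SentenceEncoder = NotInstalled("SentenceEncoder", "sentence-tfm")
```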
7 | """
8 |
9 | def __init__(self, tool, dep):
10 | self.tool = tool
11 | self.dep = dep
12 |
13 |         msg = f"In order to use {self.tool} you'll need to install it via:\n\n"
14 | msg += f"pip install embetter[{self.dep}]\n\n"
15 | self.msg = msg
16 |
17 | def __getattr__(self, *args, **kwargs):
18 | raise ModuleNotFoundError(self.msg)
19 |
20 | def __call__(self, *args, **kwargs):
21 | raise ModuleNotFoundError(self.msg)
22 |
--------------------------------------------------------------------------------
/.github/workflows/unittest.yml:
--------------------------------------------------------------------------------
1 | name: Code Checks
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | build:
13 | runs-on: ubuntu-latest
14 | strategy:
15 | matrix:
16 | python-version: [3.7, 3.8, 3.9, "3.10"]
17 |
18 | steps:
19 | - uses: actions/checkout@v2
20 | - name: Set up Python ${{ matrix.python-version }}
21 | uses: actions/setup-python@v1
22 | with:
23 | python-version: ${{ matrix.python-version }}
24 | - name: Install Base Dependencies
25 | run: python -m pip install -e .
26 | - name: Install Testing Dependencies
27 | run: make install
28 | - name: Interrogate
29 | run: make interrogate
30 | - name: Black
31 | run: black -t py37 --check embetter tests setup.py
32 | - name: Flake8
33 | run: make flake
34 | - name: Unittest
35 | run: make test
36 |
--------------------------------------------------------------------------------
/tests/test_vision.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from embetter.vision import ImageLoader, ColorHistogramEncoder, TimmEncoder
3 |
4 |
5 | @pytest.mark.parametrize("n_buckets", [5, 10, 25, 128])
6 | def test_color_hist_resize(n_buckets):
7 |     """Check that the ColorHistogramEncoder output shape matches the number of buckets."""
8 | X = ImageLoader().fit_transform(["tests/data/thiscatdoesnotexist.jpeg"])
9 | shape_out = ColorHistogramEncoder(n_buckets=n_buckets).fit_transform(X).shape
10 | shape_exp = (1, n_buckets * 3)
11 | assert shape_exp == shape_out
12 |
13 |
14 | @pytest.mark.parametrize("encode_predictions,size", [(True, 1000), (False, 1280)])
15 | def test_basic_timm(encode_predictions, size):
16 | """Super basic check for torch image model."""
17 | model = TimmEncoder("mobilenetv2_120d", encode_predictions=encode_predictions)
18 | X = ImageLoader().fit_transform(["tests/data/thiscatdoesnotexist.jpeg"])
19 | out = model.fit_transform(X)
20 | assert out.shape == (1, size)
21 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: docs
2 |
3 | black:
4 | black embetter tests setup.py
5 |
6 | flake:
7 | flake8 embetter tests setup.py
8 |
9 | test:
10 | pytest
11 |
12 | install:
13 | python -m pip install -e ".[dev]"
14 | pre-commit install
15 |
16 | interrogate:
17 | interrogate -vv --ignore-nested-functions --ignore-semiprivate --ignore-private --ignore-magic --ignore-module --ignore-init-method --fail-under 100 tests
18 | interrogate -vv --ignore-nested-functions --ignore-semiprivate --ignore-private --ignore-magic --ignore-module --ignore-init-method --fail-under 100 embetter
19 |
20 | pypi:
21 | python setup.py sdist
22 | python setup.py bdist_wheel --universal
23 | twine upload dist/*
24 |
25 | clean:
26 | rm -rf **/.ipynb_checkpoints **/.pytest_cache **/__pycache__ **/**/__pycache__ .ipynb_checkpoints .pytest_cache
27 |
28 | check: clean black flake interrogate test clean
29 |
30 | docs:
31 | cp README.md docs/index.md
32 | python -m mkdocs serve
33 |
34 | deploy-docs:
35 | cp README.md docs/index.md
36 | python -m mkdocs gh-deploy
37 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Vincent D. Warmerdam
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/embetter/grab.py:
--------------------------------------------------------------------------------
1 | from embetter.base import EmbetterBase
2 |
3 |
4 | class ColumnGrabber(EmbetterBase):
5 | """
6 | Component that can grab a pandas column as a list.
7 |
8 | 
9 |
10 | This can be useful when dealing with text encoders as these
11 | sometimes cannot deal with pandas columns.
12 |
13 | Arguments:
14 | colname: the column name to grab from a dataframe
15 |
16 | **Usage**
17 |
18 |     In essence, the `ColumnGrabber` really just selects a single column.
19 |
20 | ```python
21 | import pandas as pd
22 | from embetter.grab import ColumnGrabber
23 |
24 |     # Let's say we start with a csv file with filepaths
25 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]}
26 | df = pd.DataFrame(data)
27 |
28 | # You can use the component in stand-alone fashion
29 | ColumnGrabber("filepaths").fit_transform(df)
30 | ```
31 |
32 | But the most common way to use the `ColumnGrabber` is part of a pipeline.
33 |
34 | ```python
35 | import pandas as pd
36 | from sklearn.pipeline import make_pipeline
37 |
38 | from embetter.grab import ColumnGrabber
39 | from embetter.vision import ImageLoader, ColorHistogramEncoder
40 |
41 |     # Let's say we start with a csv file with filepaths
42 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]}
43 | df = pd.DataFrame(data)
44 |
45 | # You can use the component in stand-alone fashion
46 | ColumnGrabber("filepaths").fit_transform(df)
47 |
48 | # But let's build a pipeline that grabs the column, turns it
49 | # into an image and embeds it.
50 | pipe = make_pipeline(
51 | ColumnGrabber("filepaths"),
52 | ImageLoader(),
53 | ColorHistogramEncoder()
54 | )
55 |
56 | pipe.fit_transform(df)
57 | ```
58 | """
59 |
60 | def __init__(self, colname: str) -> None:
61 | self.colname = colname
62 |
63 | def transform(self, X, y=None):
64 | """
65 | Takes a column from pandas and returns it as a list.
66 | """
67 | return [x for x in X[self.colname]]
68 |
--------------------------------------------------------------------------------
/embetter/vision/_colorhist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from embetter.base import EmbetterBase
3 |
4 |
5 | class ColorHistogramEncoder(EmbetterBase):
6 | """
7 | Encoder that generates an embedding based on the color histogram of the image.
8 |
9 | 
10 |
11 | Arguments:
12 | n_buckets: number of buckets per color
13 |
14 | **Usage**:
15 |
16 | ```python
17 | import pandas as pd
18 | from sklearn.pipeline import make_pipeline
19 |
20 | from embetter.grab import ColumnGrabber
21 | from embetter.vision import ImageLoader, ColorHistogramEncoder
22 |
23 |     # Let's say we start with a csv file with filepaths
24 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]}
25 | df = pd.DataFrame(data)
26 |
27 | # Let's build a pipeline that grabs the column, turns it
28 | # into an image and embeds it.
29 | pipe = make_pipeline(
30 | ColumnGrabber("filepaths"),
31 | ImageLoader(),
32 | ColorHistogramEncoder()
33 | )
34 |
35 | # This pipeline can now encode each image in the dataframe
36 | pipe.fit_transform(df)
37 | ```
38 | """
39 |
40 | def __init__(self, n_buckets=256):
41 | self.n_buckets = n_buckets
42 |
43 | def transform(self, X, y=None):
44 | """
45 |         Takes a sequence of `PIL.Image` and returns a numpy array of shape
46 |         `(len(X), n_buckets * 3)` containing a per-channel color histogram for each image.
47 | """
48 | output = np.zeros((len(X), self.n_buckets * 3))
49 | for i, x in enumerate(X):
50 | arr = np.array(x)
51 | output[i, :] = np.concatenate(
52 | [
53 | np.histogram(
54 | arr[:, :, 0].flatten(),
55 | bins=np.linspace(0, 255, self.n_buckets + 1),
56 | )[0],
57 | np.histogram(
58 | arr[:, :, 1].flatten(),
59 | bins=np.linspace(0, 255, self.n_buckets + 1),
60 | )[0],
61 | np.histogram(
62 | arr[:, :, 2].flatten(),
63 | bins=np.linspace(0, 255, self.n_buckets + 1),
64 | )[0],
65 | ]
66 | )
67 | return output
68 |
--------------------------------------------------------------------------------
/embetter/vision/_torchvis.py:
--------------------------------------------------------------------------------
1 | import timm
2 | from timm.data.transforms_factory import create_transform
3 | from timm.data import resolve_data_config
4 |
5 | import numpy as np
6 | from embetter.base import EmbetterBase
7 |
8 |
9 | class TimmEncoder(EmbetterBase):
10 | """
11 |     Use a pretrained vision model to generate embeddings. The embeddings
12 |     are provided via the lovely `timm` (Torch Image Models) library.
13 |
14 | 
15 |
16 | You can find a list of available models [here](https://rwightman.github.io/pytorch-image-models/models/).
17 |
18 | Arguments:
19 | name: name of the model to use
20 |         encode_predictions: output the model predictions instead of the pooled embedding layer that precedes them
21 |
22 | **Usage**:
23 |
24 | ```python
25 | import pandas as pd
26 | from sklearn.pipeline import make_pipeline
27 |
28 | from embetter.grab import ColumnGrabber
29 | from embetter.vision import ImageLoader, TimmEncoder
30 |
31 |     # Let's say we start with a csv file with filepaths
32 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]}
33 | df = pd.DataFrame(data)
34 |
35 | # Let's build a pipeline that grabs the column, turns it
36 | # into an image and embeds it.
37 | pipe = make_pipeline(
38 | ColumnGrabber("filepaths"),
39 | ImageLoader(),
40 | TimmEncoder(name="mobilenetv3_large_100")
41 | )
42 |
43 | # This pipeline can now encode each image in the dataframe
44 | pipe.fit_transform(df)
45 | ```
46 | """
47 |
48 | def __init__(self, name="mobilenetv3_large_100", encode_predictions=False):
49 | self.name = name
50 | self.encode_predictions = encode_predictions
51 | self.model = timm.create_model(name, pretrained=True, num_classes=0)
52 | if self.encode_predictions:
53 | self.model = timm.create_model(name, pretrained=True)
54 | self.config = resolve_data_config({}, model=self.model)
55 | self.transform_img = create_transform(**self.config)
56 |
57 | def transform(self, X, y=None):
58 | """
59 | Transforms grabbed images into numeric representations.
60 | """
61 | batch = [self.transform_img(x).unsqueeze(0) for x in X]
62 | return np.array([self.model(x).squeeze(0).detach().numpy() for x in batch])
63 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from setuptools import setup, find_packages
3 |
4 |
5 | base_packages = ["scikit-learn>=1.0.0", "pandas>=1.0.0"]
6 |
7 | sentence_encoder_pkgs = ["sentence-transformers>=2.2.2"]
8 | sense2vec_pkgs = ["sense2vec==2.0.0"]
9 | text_packages = sentence_encoder_pkgs + sense2vec_pkgs
10 |
11 | vision_packages = ["timm>=0.6.7"]
12 |
13 | docs_packages = [
14 | "mkdocs==1.1",
15 | "mkdocs-material==4.6.3",
16 | "mkdocstrings==0.8.0",
17 | "mktestdocs==0.1.2",
18 | ]
19 |
20 | test_packages = [
21 | "interrogate>=1.5.0",
22 | "flake8>=3.6.0",
23 | "pytest>=4.0.2",
24 | "black>=19.3b0",
25 | "pre-commit>=2.2.0",
26 | "mktestdocs==0.1.2",
27 | ]
28 |
29 | all_packages = base_packages + text_packages + vision_packages
30 | dev_packages = all_packages + docs_packages + test_packages
31 |
32 |
33 | setup(
34 | name="embetter",
35 | version="0.2.0",
36 | author="Vincent D. Warmerdam",
37 | packages=find_packages(exclude=["notebooks", "docs"]),
38 | description="Just a bunch of useful embeddings to get started quickly.",
39 | long_description=pathlib.Path("README.md").read_text(),
40 | long_description_content_type="text/markdown",
41 |     license_files=("LICENCE",),
42 | url="https://koaning.github.io/embetter/",
43 | project_urls={
44 | "Documentation": "https://koaning.github.io/embetter/",
45 | "Source Code": "https://github.com/koaning/embetter/",
46 | "Issue Tracker": "https://github.com/koaning/embetter/issues",
47 | },
48 | install_requires=base_packages,
49 | extras_require={
50 | "sense2vec": sense2vec_pkgs + base_packages,
51 | "sentence-tfm": sentence_encoder_pkgs + base_packages,
52 | "text": text_packages + base_packages,
53 | "vision": vision_packages + base_packages,
54 | "all": all_packages,
55 | "dev": dev_packages,
56 | },
57 | classifiers=[
58 | "Intended Audience :: Science/Research",
59 | "Programming Language :: Python :: 3",
60 | "Programming Language :: Python :: 3.7",
61 | "Programming Language :: Python :: 3.8",
62 | "Programming Language :: Python :: 3.9",
63 | "Programming Language :: Python :: 3.10",
64 | "License :: OSI Approved :: MIT License",
65 | "Topic :: Scientific/Engineering",
66 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
67 | ],
68 | )
69 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | *.ipynb
131 | .vscode
132 |
--------------------------------------------------------------------------------
/embetter/vision/_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from embetter.base import EmbetterBase
4 |
5 |
6 | class ImageLoader(EmbetterBase):
7 | """
8 | Component that can turn filepaths into a list of PIL.Image objects.
9 |
10 | 
11 |
12 | Arguments:
13 | convert: Color [conversion setting](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.convert) from the Python image library.
14 |         out: Output format; either "pil" for `PIL.Image` objects or "numpy" for pixel arrays.
15 |
16 | **Usage**
17 |
18 | You can use the `ImageLoader` in standalone fashion.
19 |
20 | ```python
21 | from embetter.vision import ImageLoader
22 |
23 | filepath = "tests/data/thiscatdoesnotexist.jpeg"
24 | ImageLoader(convert="RGB").fit_transform([filepath])
25 | ```
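
    If you'd rather have pixel arrays than `PIL.Image` objects, the same call with
    `out="numpy"` returns numpy arrays instead (a small sketch using the same test image):

    ```python
    from embetter.vision import ImageLoader

    filepath = "tests/data/thiscatdoesnotexist.jpeg"
    ImageLoader(convert="RGB", out="numpy").fit_transform([filepath])
    ```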
26 |
27 | But it's more common to see it part of a pipeline.
28 |
29 | ```python
30 | import pandas as pd
31 | from sklearn.pipeline import make_pipeline
32 |
33 | from embetter.grab import ColumnGrabber
34 | from embetter.vision import ImageLoader, ColorHistogramEncoder
35 |
36 |     # Let's say we start with a csv file with filepaths
37 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]}
38 | df = pd.DataFrame(data)
39 |
40 | # Let's build a pipeline that grabs the column, turns it
41 | # into an image and embeds it.
42 | pipe = make_pipeline(
43 | ColumnGrabber("filepaths"),
44 | ImageLoader(),
45 | ColorHistogramEncoder()
46 | )
47 |
48 | pipe.fit_transform(df)
49 | ```
50 |
51 | """
52 |
53 | def __init__(self, convert: str = "RGB", out: str = "pil") -> None:
54 | self.convert = convert
55 | self.out = out
56 |
57 | def fit(self, X, y=None):
58 | """
59 |         No actual "fitting" happens in this method, but it does check the input arguments
60 | per sklearn convention.
61 | """
62 | if self.out not in ["pil", "numpy"]:
63 | raise ValueError(
64 | f"Output format parameter out={self.out} must be either pil/numpy."
65 | )
66 | return self
67 |
68 | def transform(self, X, y=None):
69 | """
70 |         Turn file paths into `PIL.Image` objects or numpy arrays of pixel values, depending on the `out` setting.
71 | """
72 | if self.out == "pil":
73 | return [Image.open(x).convert(self.convert) for x in X]
74 | if self.out == "numpy":
75 | return np.array([np.array(Image.open(x).convert(self.convert)) for x in X])
76 |
--------------------------------------------------------------------------------
/embetter/text/_sbert.py:
--------------------------------------------------------------------------------
1 | from sentence_transformers import SentenceTransformer as SBERT
2 | from embetter.base import EmbetterBase
3 |
4 |
5 | class SentenceEncoder(EmbetterBase):
6 | """
7 | Encoder that can numerically encode sentences.
8 |
9 | 
10 |
11 | Arguments:
12 | name: name of model, see available options
13 |
14 | The following model names should be supported:
15 |
16 | - `all-mpnet-base-v2`
17 | - `multi-qa-mpnet-base-dot-v1`
18 | - `all-distilroberta-v1`
19 | - `all-MiniLM-L12-v2`
20 | - `multi-qa-distilbert-cos-v1`
21 | - `all-MiniLM-L6-v2`
22 | - `multi-qa-MiniLM-L6-cos-v1`
23 | - `paraphrase-multilingual-mpnet-base-v2`
24 | - `paraphrase-albert-small-v2`
25 | - `paraphrase-multilingual-MiniLM-L12-v2`
26 | - `paraphrase-MiniLM-L3-v2`
27 | - `distiluse-base-multilingual-cased-v1`
28 | - `distiluse-base-multilingual-cased-v2`
29 |
30 |     You can find more options, and more information, on the [sentence-transformers docs page](https://www.sbert.net/docs/pretrained_models.html#model-overview).
31 |
32 | **Usage**:
33 |
34 | ```python
35 | import pandas as pd
36 | from sklearn.pipeline import make_pipeline
37 | from sklearn.linear_model import LogisticRegression
38 |
39 | from embetter.grab import ColumnGrabber
40 | from embetter.text import SentenceEncoder
41 |
42 | # Let's suppose this is the input dataframe
43 | dataf = pd.DataFrame({
44 | "text": ["positive sentiment", "super negative"],
45 | "label_col": ["pos", "neg"]
46 | })
47 |
48 | # This pipeline grabs the `text` column from a dataframe
49 |     # which then gets fed into Sentence-Transformers' all-MiniLM-L6-v2.
50 | text_emb_pipeline = make_pipeline(
51 | ColumnGrabber("text"),
52 | SentenceEncoder('all-MiniLM-L6-v2')
53 | )
54 | X = text_emb_pipeline.fit_transform(dataf, dataf['label_col'])
55 |
56 | # This pipeline can also be trained to make predictions, using
57 | # the embedded features.
58 | text_clf_pipeline = make_pipeline(
59 | text_emb_pipeline,
60 | LogisticRegression()
61 | )
62 |
63 | # Prediction example
64 | text_clf_pipeline.fit(dataf, dataf['label_col']).predict(dataf)
65 | ```
66 | """
67 |
68 | def __init__(self, name="all-MiniLM-L6-v2"):
69 | self.name = name
70 | self.tfm = SBERT(name)
71 |
72 | def transform(self, X, y=None):
73 | """Transforms the text into a numeric representation."""
74 | return self.tfm.encode(X)
75 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # embetter
4 |
5 | > "Just a bunch of embeddings to get started quickly."
6 |
7 |
8 |
9 | Embetter implements scikit-learn compatible embeddings that should help get you started quickly.
10 |
11 | ## Install
12 |
13 | You can only install from Github, for now.
14 |
15 | ```
16 | python -m pip install "embetter @ git+https://github.com/koaning/embetter.git"
17 | ```
18 |
19 | ## API Design
20 |
21 | This is what's being implemented now.
22 |
23 | ```python
24 | # Helpers to grab text or image from pandas column.
25 | from embetter.grab import ColumnGrabber
26 |
27 | # Representations/Helpers for computer vision
28 | from embetter.vision import ImageLoader, TimmEncoder, ColorHistogramEncoder
29 |
30 | # Representations for text
31 | from embetter.text import SentenceEncoder, Sense2VecEncoder
32 | ```
33 |
34 |
35 | ## Text Example
36 |
37 | ```python
38 | import pandas as pd
39 | from sklearn.pipeline import make_pipeline
40 | from sklearn.linear_model import LogisticRegression
41 |
42 | from embetter.grab import ColumnGrabber
43 | from embetter.text import SentenceEncoder
44 |
45 | # This pipeline grabs the `text` column from a dataframe
46 | # which then gets fed into Sentence-Transformers' all-MiniLM-L6-v2.
47 | text_emb_pipeline = make_pipeline(
48 | ColumnGrabber("text"),
49 | SentenceEncoder('all-MiniLM-L6-v2')
50 | )
51 |
52 | # This pipeline can also be trained to make predictions, using
53 | # the embedded features.
54 | text_clf_pipeline = make_pipeline(
55 | text_emb_pipeline,
56 | LogisticRegression()
57 | )
58 |
59 | dataf = pd.DataFrame({
60 | "text": ["positive sentiment", "super negative"],
61 | "label_col": ["pos", "neg"]
62 | })
63 | X = text_emb_pipeline.fit_transform(dataf, dataf['label_col'])
64 | text_clf_pipeline.fit(dataf, dataf['label_col']).predict(dataf)
65 | ```
66 |
67 | ## Image Example
68 |
69 | The goal of the API is to allow pipelines like this:
70 |
71 | ```python
72 | import pandas as pd
73 | from sklearn.pipeline import make_pipeline
74 | from sklearn.linear_model import LogisticRegression
75 |
76 | from embetter.grab import ColumnGrabber
77 | from embetter.vision import ImageLoader, TimmEncoder
78 |
79 | # This pipeline grabs the `img_path` column from a dataframe
80 | # then it grabs the image paths and turns them into `PIL.Image` objects
81 | # which then get fed into MobileNetv2 via TorchImageModels (timm).
82 | image_emb_pipeline = make_pipeline(
83 | ColumnGrabber("img_path"),
84 | ImageLoader(convert="RGB"),
85 | TimmEncoder("mobilenetv2_120d")
86 | )
87 |
88 | dataf = pd.DataFrame({
89 | "img_path": ["tests/data/thiscatdoesnotexist.jpeg"]
90 | })
91 | image_emb_pipeline.fit_transform(dataf)
92 | ```
93 |
94 | ## Batched Learning
95 |
96 | All of the encoding tools you've seen here are also compatible
97 | with the [`partial_fit` mechanic](https://scikit-learn.org/0.15/modules/scaling_strategies.html#incremental-learning)
98 | in scikit-learn. That means
99 | you can leverage [scikit-partial](https://github.com/koaning/scikit-partial)
100 | to build pipelines that can handle out-of-core datasets.
101 |
102 | ## Available Components
103 |
104 | The goal of the library is to remain small but to offer a few general tools
105 | that might help with bulk labelling in particular, and with general scikit-learn
106 | pipelines as well.
107 |
108 | | class | link | What it does |
109 | |:-------------------------:|------|-------------------------------------------------------------------------------------------------------|
110 | | `ColumnGrabber` | [docs](https://koaning.github.io/embetter/API/grab/) |  |
111 | | `SentenceEncoder` | [docs](https://koaning.github.io/embetter/API/text/sentence-enc/) |  |
112 | | `Sense2VecEncoder` | [docs](https://koaning.github.io/embetter/API/text/sense2vec/) |  |
113 | | `ImageLoader` | [docs](https://koaning.github.io/embetter/API/vision/imageload/) |  |
114 | | `ColorHistogramEncoder` | [docs](https://koaning.github.io/embetter/API/vision/colorhist/) |  |
115 | | `TimmEncoder` | [docs](https://koaning.github.io/embetter/API/vision/timm/) |  |
116 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # embetter
4 |
5 | > "Just a bunch of useful embeddings to get started quickly."
6 |
7 |
8 |
9 | Embetter implements scikit-learn compatible embeddings for computer vision and text. It should make it very easy to quickly build proofs of concept using scikit-learn pipelines and, in particular, should help with [bulk labelling](https://www.youtube.com/watch?v=gDk7_f3ovIk). It's also meant to play nice with [bulk](https://github.com/koaning/bulk) and [scikit-partial](https://github.com/koaning/scikit-partial).
10 |
11 | ## Install
12 |
13 | You can install embetter via pip:
14 |
15 | ```
16 | python -m pip install embetter
17 | ```
18 |
19 | Many of the embeddings are optional depending on your use-case, so if you
20 | want to be picky and only download the tools that you need:
21 |
22 | ```
23 | python -m pip install "embetter[text]"
24 | python -m pip install "embetter[sense2vec]"
25 | python -m pip install "embetter[sentence-tfm]"
26 | python -m pip install "embetter[vision]"
27 | python -m pip install "embetter[all]"
28 | ```
29 |
30 | ## API Design
31 |
32 | This is what's being implemented now.
33 |
34 | ```python
35 | # Helpers to grab text or image from pandas column.
36 | from embetter.grab import ColumnGrabber
37 |
38 | # Representations/Helpers for computer vision
39 | from embetter.vision import ImageLoader, TimmEncoder, ColorHistogramEncoder
40 |
41 | # Representations for text
42 | from embetter.text import SentenceEncoder, Sense2VecEncoder
43 | ```
44 |
45 | All of these components are scikit-learn compatible, which means that you
46 | can apply them as you would normally in a scikit-learn pipeline. Just be aware
47 | that these components are stateless. They won't require training as these
48 | are all pretrained tools.
49 |
50 | ## Text Example
51 |
52 | ```python
53 | import pandas as pd
54 | from sklearn.pipeline import make_pipeline
55 | from sklearn.linear_model import LogisticRegression
56 |
57 | from embetter.grab import ColumnGrabber
58 | from embetter.text import SentenceEncoder
59 |
60 | # This pipeline grabs the `text` column from a dataframe
61 | # which then gets fed into Sentence-Transformers' all-MiniLM-L6-v2.
62 | text_emb_pipeline = make_pipeline(
63 | ColumnGrabber("text"),
64 | SentenceEncoder('all-MiniLM-L6-v2')
65 | )
66 |
67 | # This pipeline can also be trained to make predictions, using
68 | # the embedded features.
69 | text_clf_pipeline = make_pipeline(
70 | text_emb_pipeline,
71 | LogisticRegression()
72 | )
73 |
74 | dataf = pd.DataFrame({
75 | "text": ["positive sentiment", "super negative"],
76 | "label_col": ["pos", "neg"]
77 | })
78 | X = text_emb_pipeline.fit_transform(dataf, dataf['label_col'])
79 | text_clf_pipeline.fit(dataf, dataf['label_col']).predict(dataf)
80 | ```
81 |
82 | ## Image Example
83 |
84 | The goal of the API is to allow pipelines like this:
85 |
86 | ```python
87 | import pandas as pd
88 | from sklearn.pipeline import make_pipeline
89 | from sklearn.linear_model import LogisticRegression
90 |
91 | from embetter.grab import ColumnGrabber
92 | from embetter.vision import ImageLoader, TimmEncoder
93 |
94 | # This pipeline grabs the `img_path` column from a dataframe
95 | # then it grabs the image paths and turns them into `PIL.Image` objects
96 | # which then get fed into MobileNetv2 via TorchImageModels (timm).
97 | image_emb_pipeline = make_pipeline(
98 | ColumnGrabber("img_path"),
99 | ImageLoader(convert="RGB"),
100 | TimmEncoder("mobilenetv2_120d")
101 | )
102 |
103 | dataf = pd.DataFrame({
104 | "img_path": ["tests/data/thiscatdoesnotexist.jpeg"]
105 | })
106 | image_emb_pipeline.fit_transform(dataf)
107 | ```
108 |
109 | ## Batched Learning
110 |
111 | All of the encoding tools you've seen here are also compatible
112 | with the [`partial_fit` mechanic](https://scikit-learn.org/0.15/modules/scaling_strategies.html#incremental-learning)
113 | in scikit-learn. That means
114 | you can leverage [scikit-partial](https://github.com/koaning/scikit-partial)
115 | to build pipelines that can handle out-of-core datasets.
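
For instance, a minimal sketch of out-of-core learning with scikit-learn's
`SGDClassifier` could look like the snippet below; the dataframe batches are
made up for illustration and would normally be chunks streamed from disk.

```python
import pandas as pd
from sklearn.linear_model import SGDClassifier

from embetter.grab import ColumnGrabber
from embetter.text import SentenceEncoder

# Two tiny, made-up batches standing in for an out-of-core data stream.
batches = [
    pd.DataFrame({"text": ["positive sentiment", "super negative"], "label_col": ["pos", "neg"]}),
    pd.DataFrame({"text": ["really good", "awful experience"], "label_col": ["pos", "neg"]}),
]

encoder = SentenceEncoder("all-MiniLM-L6-v2")
clf = SGDClassifier()

for batch in batches:
    texts = ColumnGrabber("text").fit_transform(batch)
    X = encoder.partial_fit(texts).transform(texts)
    clf.partial_fit(X, batch["label_col"], classes=["neg", "pos"])
```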
116 |
117 | ## Available Components
118 |
119 | The goal of the library is to remain small but to offer a few general tools
120 | that might help with bulk labelling in particular, and with general scikit-learn
121 | pipelines as well.
122 |
123 | | class | link | What it does |
124 | |:-------------------------:|------------------------------------------------------|-------------------------------------------------------------------------------------------------------|
125 | | `ColumnGrabber` | [docs](https://koaning.github.io/embetter/API/grab/) |  |
126 | | `SentenceEncoder` | [docs](https://koaning.github.io/embetter/API/text/sentence-enc/) |  |
127 | | `Sense2VecEncoder` | [docs](https://koaning.github.io/embetter/API/text/sense2vec/) |  |
128 | | `ImageLoader` | [docs](https://koaning.github.io/embetter/API/vision/imageload/) |  |
129 | | `ColorHistogramEncoder` | [docs](https://koaning.github.io/embetter/API/vision/colorhist/) |  |
130 | | `TimmEncoder` | [docs](https://koaning.github.io/embetter/API/vision/timm/) |  |
131 |
--------------------------------------------------------------------------------