├── tests ├── __init__.py ├── test_default.py ├── data │ └── thiscatdoesnotexist.jpeg ├── test_base.py ├── test_docs.py └── test_vision.py ├── docs ├── API │ ├── grab.md │ ├── vision │ │ ├── timm.md │ │ ├── imageload.md │ │ └── colorhist.md │ └── text │ │ ├── sentence-enc.md │ │ └── sense2vec.md ├── images │ ├── icon.png │ ├── timm.png │ ├── sense2vec.png │ ├── imageloader.png │ ├── colorhistogram.png │ ├── columngrabber.png │ └── sentence-encoder.png └── index.md ├── .flake8 ├── embetter ├── __init__.py ├── base.py ├── vision │ ├── __init__.py │ ├── _colorhist.py │ ├── _torchvis.py │ └── _loader.py ├── text │ ├── __init__.py │ ├── _s2v.py │ └── _sbert.py ├── error.py └── grab.py ├── mkdocs.yml ├── .github └── workflows │ └── unittest.yml ├── Makefile ├── LICENCE ├── setup.py ├── .gitignore └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/API/grab.md: -------------------------------------------------------------------------------- 1 | # ColumnGrabber 2 | 3 | ::: embetter.grab.ColumnGrabber -------------------------------------------------------------------------------- /docs/API/vision/timm.md: -------------------------------------------------------------------------------- 1 | # TimmEncoder 2 | 3 | ::: embetter.vision.TimmEncoder -------------------------------------------------------------------------------- /docs/API/vision/imageload.md: -------------------------------------------------------------------------------- 1 | # ImageLoader 2 | 3 | ::: embetter.vision.ImageLoader 4 | -------------------------------------------------------------------------------- /docs/images/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/icon.png -------------------------------------------------------------------------------- /docs/images/timm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/timm.png -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 160 3 | ignore = E501, C901 4 | extend-ignore = E203, W503 -------------------------------------------------------------------------------- /tests/test_default.py: -------------------------------------------------------------------------------- 1 | def test_true(): 2 | """kind of a no-op test""" 3 | assert True 4 | -------------------------------------------------------------------------------- /docs/images/sense2vec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/sense2vec.png -------------------------------------------------------------------------------- /docs/API/vision/colorhist.md: -------------------------------------------------------------------------------- 1 | # ColorHistogramEncoder 2 | 3 | ::: embetter.vision.ColorHistogramEncoder 4 | -------------------------------------------------------------------------------- /docs/images/imageloader.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/imageloader.png 
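A quick note on `tests/test_default.py` above: it is only a no-op placeholder. A slightly more useful smoke test — a sketch, not part of the repository — could assert that the installed package exposes the version string that `embetter/__init__.py` (shown a little further down) derives from `importlib.metadata`:

```python
import embetter


def test_version_is_exposed():
    """The installed package should report a non-empty version string."""
    assert isinstance(embetter.__version__, str)
    assert embetter.__version__ != ""
```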
-------------------------------------------------------------------------------- /docs/images/colorhistogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/colorhistogram.png -------------------------------------------------------------------------------- /docs/images/columngrabber.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/columngrabber.png -------------------------------------------------------------------------------- /docs/images/sentence-encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/docs/images/sentence-encoder.png -------------------------------------------------------------------------------- /tests/data/thiscatdoesnotexist.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/embetter/main/tests/data/thiscatdoesnotexist.jpeg -------------------------------------------------------------------------------- /docs/API/text/sentence-enc.md: -------------------------------------------------------------------------------- 1 | # SentenceEncoder 2 | 3 | ## `embetter.text.SentenceEncoder` 4 | 5 | ::: embetter.text.SentenceEncoder -------------------------------------------------------------------------------- /docs/API/text/sense2vec.md: -------------------------------------------------------------------------------- 1 | # Sense2VecEncoder 2 | 3 | ## `embetter.text.Sense2VecEncoder` 4 | 5 | ::: embetter.text.Sense2VecEncoder 6 | -------------------------------------------------------------------------------- /embetter/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from importlib import metadata 3 | except ImportError: # for Python<3.8 4 | import importlib_metadata as metadata 5 | 6 | 7 | __title__ = __name__ 8 | __version__ = metadata.version(__title__) 9 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Embetter Docs 2 | repo_url: https://github.com/koaning/embetter 3 | plugins: 4 | - mkdocstrings: 5 | custom_templates: templates 6 | theme: 7 | name: material 8 | logo: images/icon.png 9 | palette: 10 | primary: white 11 | markdown_extensions: 12 | - pymdownx.highlight: 13 | use_pygments: true 14 | - pymdownx.superfences 15 | 16 | -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from embetter.grab import ColumnGrabber 3 | 4 | 5 | def test_grab_column(): 6 | """Ensure that we can grab a text column.""" 7 | data = [{"text": "hi", "foo": 1}, {"text": "yes", "foo": 2}] 8 | dataframe = pd.DataFrame(data) 9 | out = ColumnGrabber("text").fit_transform(dataframe) 10 | assert out == ["hi", "yes"] 11 | -------------------------------------------------------------------------------- /embetter/base.py: -------------------------------------------------------------------------------- 1 | from sklearn.base import BaseEstimator, TransformerMixin 2 | 3 | 4 | class EmbetterBase(BaseEstimator, TransformerMixin): 5 | """Base class for feature transformers in 
this library""" 6 | 7 | def fit(self, X, y=None): 8 | """No-op.""" 9 | return self 10 | 11 | def partial_fit(self, X, y=None): 12 | """No-op.""" 13 | return self 14 | -------------------------------------------------------------------------------- /embetter/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from embetter.error import NotInstalled 2 | from embetter.vision._loader import ImageLoader 3 | from embetter.vision._colorhist import ColorHistogramEncoder 4 | 5 | try: 6 | from embetter.vision._torchvis import TimmEncoder 7 | except ModuleNotFoundError: 8 | TimmEncoder = NotInstalled("TimmEncoder", "vision") 9 | 10 | 11 | __all__ = ["ImageLoader", "ColorHistogramEncoder", "TimmEncoder"] 12 | -------------------------------------------------------------------------------- /embetter/text/__init__.py: -------------------------------------------------------------------------------- 1 | from embetter.error import NotInstalled 2 | 3 | try: 4 | from embetter.text._sbert import SentenceEncoder 5 | except ModuleNotFoundError: 6 | SentenceEncoder = NotInstalled("SentenceEncoder", "sbert") 7 | 8 | try: 9 | from embetter.text._s2v import Sense2VecEncoder 10 | except ModuleNotFoundError: 11 | Sense2VecEncoder = NotInstalled("Sense2VecEncoder", "sbert") 12 | 13 | 14 | __all__ = ["SentenceEncoder", "Sense2VecEncoder"] 15 | -------------------------------------------------------------------------------- /tests/test_docs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mktestdocs import check_md_file, check_docstring 3 | from embetter.vision import ColorHistogramEncoder, TimmEncoder, ImageLoader 4 | from embetter.text import Sense2VecEncoder, SentenceEncoder 5 | from embetter.grab import ColumnGrabber 6 | 7 | 8 | def test_readme(): 9 | """Readme needs to be accurate""" 10 | check_md_file(fpath="README.md") 11 | 12 | 13 | objects = [ 14 | ColumnGrabber, 15 | SentenceEncoder, 16 | Sense2VecEncoder, 17 | ColorHistogramEncoder, 18 | TimmEncoder, 19 | ImageLoader, 20 | ] 21 | 22 | 23 | @pytest.mark.parametrize("func", objects, ids=lambda d: d.__name__) 24 | def test_docstring(func): 25 | """Check the docstrings of the components""" 26 | check_docstring(obj=func) 27 | -------------------------------------------------------------------------------- /embetter/text/_s2v.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sense2vec import Sense2Vec 3 | 4 | from embetter.base import BaseEstimator 5 | 6 | 7 | class Sense2VecEncoder(BaseEstimator): 8 | """ 9 | Create a [Sense2Vec encoder](https://github.com/explosion/sense2vec), meant to 10 | help when encoding phrases as opposed to sentences. 11 | 12 | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/sense2vec.png) 13 | 14 | Arguments: 15 | path: path to downloaded model 16 | """ 17 | 18 | def __init__(self, path): 19 | self.s2v = Sense2Vec().from_disk(path) 20 | 21 | def transform(self, X, y=None): 22 | """Transforms the phrase text into a numeric representation.""" 23 | return np.array([self.s2v[x] for x in X]) 24 | -------------------------------------------------------------------------------- /embetter/error.py: -------------------------------------------------------------------------------- 1 | class NotInstalled: 2 | """ 3 | This object is used for optional dependencies. 
If a backend is not installed we 4 | replace the transformer/language with this object. This allows us to give a friendly 5 | message to the user that they need to install extra dependencies as well as a link 6 | to our documentation page. 7 | """ 8 | 9 | def __init__(self, tool, dep): 10 | self.tool = tool 11 | self.dep = dep 12 | 13 | msg = f"In order to use {self.tool} you'll need to install via;\n\n" 14 | msg += f"pip install embetter[{self.dep}]\n\n" 15 | self.msg = msg 16 | 17 | def __getattr__(self, *args, **kwargs): 18 | raise ModuleNotFoundError(self.msg) 19 | 20 | def __call__(self, *args, **kwargs): 21 | raise ModuleNotFoundError(self.msg) 22 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yml: -------------------------------------------------------------------------------- 1 | name: Code Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.7, 3.8, 3.9, "3.10"] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v1 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install Base Dependencies 25 | run: python -m pip install -e . 26 | - name: Install Testing Dependencies 27 | run: make install 28 | - name: Interrogate 29 | run: make interrogate 30 | - name: Black 31 | run: black -t py37 --check embetter tests setup.py 32 | - name: Flake8 33 | run: make flake 34 | - name: Unittest 35 | run: make test 36 | -------------------------------------------------------------------------------- /tests/test_vision.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from embetter.vision import ImageLoader, ColorHistogramEncoder, TimmEncoder 3 | 4 | 5 | @pytest.mark.parametrize("n_buckets", [5, 10, 25, 128]) 6 | def test_color_hist_resize(n_buckets): 7 | """Make sure we can resize and it fits""" 8 | X = ImageLoader().fit_transform(["tests/data/thiscatdoesnotexist.jpeg"]) 9 | shape_out = ColorHistogramEncoder(n_buckets=n_buckets).fit_transform(X).shape 10 | shape_exp = (1, n_buckets * 3) 11 | assert shape_exp == shape_out 12 | 13 | 14 | @pytest.mark.parametrize("encode_predictions,size", [(True, 1000), (False, 1280)]) 15 | def test_basic_timm(encode_predictions, size): 16 | """Super basic check for torch image model.""" 17 | model = TimmEncoder("mobilenetv2_120d", encode_predictions=encode_predictions) 18 | X = ImageLoader().fit_transform(["tests/data/thiscatdoesnotexist.jpeg"]) 19 | out = model.fit_transform(X) 20 | assert out.shape == (1, size) 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: docs 2 | 3 | black: 4 | black embetter tests setup.py 5 | 6 | flake: 7 | flake8 embetter tests setup.py 8 | 9 | test: 10 | pytest 11 | 12 | install: 13 | python -m pip install -e ".[dev]" 14 | pre-commit install 15 | 16 | interrogate: 17 | interrogate -vv --ignore-nested-functions --ignore-semiprivate --ignore-private --ignore-magic --ignore-module --ignore-init-method --fail-under 100 tests 18 | interrogate -vv --ignore-nested-functions --ignore-semiprivate --ignore-private --ignore-magic --ignore-module --ignore-init-method --fail-under 100 embetter 19 | 20 | pypi: 21 | python setup.py sdist 22 
| python setup.py bdist_wheel --universal 23 | twine upload dist/* 24 | 25 | clean: 26 | rm -rf **/.ipynb_checkpoints **/.pytest_cache **/__pycache__ **/**/__pycache__ .ipynb_checkpoints .pytest_cache 27 | 28 | check: clean black flake interrogate test clean 29 | 30 | docs: 31 | cp README.md docs/index.md 32 | python -m mkdocs serve 33 | 34 | deploy-docs: 35 | cp README.md docs/index.md 36 | python -m mkdocs gh-deploy 37 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Vincent D. Warmerdam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /embetter/grab.py: -------------------------------------------------------------------------------- 1 | from embetter.base import EmbetterBase 2 | 3 | 4 | class ColumnGrabber(EmbetterBase): 5 | """ 6 | Component that can grab a pandas column as a list. 7 | 8 | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/columngrabber.png) 9 | 10 | This can be useful when dealing with text encoders as these 11 | sometimes cannot deal with pandas columns. 12 | 13 | Arguments: 14 | colname: the column name to grab from a dataframe 15 | 16 | **Usage** 17 | 18 | In essence, the `ColumnGrabber` really just selects a single column. 19 | 20 | ```python 21 | import pandas as pd 22 | from embetter.grab import ColumnGrabber 23 | 24 | # Let's say we start with a csv file with filepaths 25 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]} 26 | df = pd.DataFrame(data) 27 | 28 | # You can use the component in stand-alone fashion 29 | ColumnGrabber("filepaths").fit_transform(df) 30 | ``` 31 | 32 | But the most common way to use the `ColumnGrabber` is as part of a pipeline. 33 | 34 | ```python 35 | import pandas as pd 36 | from sklearn.pipeline import make_pipeline 37 | 38 | from embetter.grab import ColumnGrabber 39 | from embetter.vision import ImageLoader, ColorHistogramEncoder 40 | 41 | # Let's say we start with a csv file with filepaths 42 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]} 43 | df = pd.DataFrame(data) 44 | 45 | # You can use the component in stand-alone fashion 46 | ColumnGrabber("filepaths").fit_transform(df) 47 | 48 | # But let's build a pipeline that grabs the column, turns it 49 | # into an image and embeds it. 
50 | pipe = make_pipeline( 51 | ColumnGrabber("filepaths"), 52 | ImageLoader(), 53 | ColorHistogramEncoder() 54 | ) 55 | 56 | pipe.fit_transform(df) 57 | ``` 58 | """ 59 | 60 | def __init__(self, colname: str) -> None: 61 | self.colname = colname 62 | 63 | def transform(self, X, y=None): 64 | """ 65 | Takes a column from pandas and returns it as a list. 66 | """ 67 | return [x for x in X[self.colname]] 68 | -------------------------------------------------------------------------------- /embetter/vision/_colorhist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from embetter.base import EmbetterBase 3 | 4 | 5 | class ColorHistogramEncoder(EmbetterBase): 6 | """ 7 | Encoder that generates an embedding based on the color histogram of the image. 8 | 9 | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/colorhistogram.png) 10 | 11 | Arguments: 12 | n_buckets: number of buckets per color channel 13 | 14 | **Usage**: 15 | 16 | ```python 17 | import pandas as pd 18 | from sklearn.pipeline import make_pipeline 19 | 20 | from embetter.grab import ColumnGrabber 21 | from embetter.vision import ImageLoader, ColorHistogramEncoder 22 | 23 | # Let's say we start with a csv file with filepaths 24 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]} 25 | df = pd.DataFrame(data) 26 | 27 | # Let's build a pipeline that grabs the column, turns it 28 | # into an image and embeds it. 29 | pipe = make_pipeline( 30 | ColumnGrabber("filepaths"), 31 | ImageLoader(), 32 | ColorHistogramEncoder() 33 | ) 34 | 35 | # This pipeline can now encode each image in the dataframe 36 | pipe.fit_transform(df) 37 | ``` 38 | """ 39 | 40 | def __init__(self, n_buckets=256): 41 | self.n_buckets = n_buckets 42 | 43 | def transform(self, X, y=None): 44 | """ 45 | Takes a sequence of `PIL.Image` and returns a numpy array representing 46 | a color histogram for each. 47 | """ 48 | output = np.zeros((len(X), self.n_buckets * 3)) 49 | for i, x in enumerate(X): 50 | arr = np.array(x) 51 | output[i, :] = np.concatenate( 52 | [ 53 | np.histogram( 54 | arr[:, :, 0].flatten(), 55 | bins=np.linspace(0, 255, self.n_buckets + 1), 56 | )[0], 57 | np.histogram( 58 | arr[:, :, 1].flatten(), 59 | bins=np.linspace(0, 255, self.n_buckets + 1), 60 | )[0], 61 | np.histogram( 62 | arr[:, :, 2].flatten(), 63 | bins=np.linspace(0, 255, self.n_buckets + 1), 64 | )[0], 65 | ] 66 | ) 67 | return output 68 | -------------------------------------------------------------------------------- /embetter/vision/_torchvis.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from timm.data.transforms_factory import create_transform 3 | from timm.data import resolve_data_config 4 | 5 | import numpy as np 6 | from embetter.base import EmbetterBase 7 | 8 | 9 | class TimmEncoder(EmbetterBase): 10 | """ 11 | Use a pretrained vision model to generate embeddings. Embeddings 12 | are provided via the lovely `timm` (TorchImageModels) library. 13 | 14 | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/timm.png) 15 | 16 | You can find a list of available models [here](https://rwightman.github.io/pytorch-image-models/models/). 
17 | 18 | Arguments: 19 | name: name of the model to use 20 | encode_predictions: output the model's predictions instead of the pooled embedding layer that precedes them 21 | 22 | **Usage**: 23 | 24 | ```python 25 | import pandas as pd 26 | from sklearn.pipeline import make_pipeline 27 | 28 | from embetter.grab import ColumnGrabber 29 | from embetter.vision import ImageLoader, TimmEncoder 30 | 31 | # Let's say we start with a csv file with filepaths 32 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]} 33 | df = pd.DataFrame(data) 34 | 35 | # Let's build a pipeline that grabs the column, turns it 36 | # into an image and embeds it. 37 | pipe = make_pipeline( 38 | ColumnGrabber("filepaths"), 39 | ImageLoader(), 40 | TimmEncoder(name="mobilenetv3_large_100") 41 | ) 42 | 43 | # This pipeline can now encode each image in the dataframe 44 | pipe.fit_transform(df) 45 | ``` 46 | """ 47 | 48 | def __init__(self, name="mobilenetv3_large_100", encode_predictions=False): 49 | self.name = name 50 | self.encode_predictions = encode_predictions 51 | self.model = timm.create_model(name, pretrained=True, num_classes=0) 52 | if self.encode_predictions: 53 | self.model = timm.create_model(name, pretrained=True) 54 | self.config = resolve_data_config({}, model=self.model) 55 | self.transform_img = create_transform(**self.config) 56 | 57 | def transform(self, X, y=None): 58 | """ 59 | Transforms grabbed images into numeric representations. 60 | """ 61 | batch = [self.transform_img(x).unsqueeze(0) for x in X] 62 | return np.array([self.model(x).squeeze(0).detach().numpy() for x in batch]) 63 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from setuptools import setup, find_packages 3 | 4 | 5 | base_packages = ["scikit-learn>=1.0.0", "pandas>=1.0.0"] 6 | 7 | sentence_encoder_pkgs = ["sentence-transformers>=2.2.2"] 8 | sense2vec_pkgs = ["sense2vec==2.0.0"] 9 | text_packages = sentence_encoder_pkgs + sense2vec_pkgs 10 | 11 | vision_packages = ["timm>=0.6.7"] 12 | 13 | docs_packages = [ 14 | "mkdocs==1.1", 15 | "mkdocs-material==4.6.3", 16 | "mkdocstrings==0.8.0", 17 | "mktestdocs==0.1.2", 18 | ] 19 | 20 | test_packages = [ 21 | "interrogate>=1.5.0", 22 | "flake8>=3.6.0", 23 | "pytest>=4.0.2", 24 | "black>=19.3b0", 25 | "pre-commit>=2.2.0", 26 | "mktestdocs==0.1.2", 27 | ] 28 | 29 | all_packages = base_packages + text_packages + vision_packages 30 | dev_packages = all_packages + docs_packages + test_packages 31 | 32 | 33 | setup( 34 | name="embetter", 35 | version="0.2.0", 36 | author="Vincent D. 
Warmerdam", 37 | packages=find_packages(exclude=["notebooks", "docs"]), 38 | description="Just a bunch of useful embeddings to get started quickly.", 39 | long_description=pathlib.Path("README.md").read_text(), 40 | long_description_content_type="text/markdown", 41 | license_files = ("LICENSE"), 42 | url="https://koaning.github.io/embetter/", 43 | project_urls={ 44 | "Documentation": "https://koaning.github.io/embetter/", 45 | "Source Code": "https://github.com/koaning/embetter/", 46 | "Issue Tracker": "https://github.com/koaning/embetter/issues", 47 | }, 48 | install_requires=base_packages, 49 | extras_require={ 50 | "sense2vec": sense2vec_pkgs + base_packages, 51 | "sentence-tfm": sentence_encoder_pkgs + base_packages, 52 | "text": text_packages + base_packages, 53 | "vision": vision_packages + base_packages, 54 | "all": all_packages, 55 | "dev": dev_packages, 56 | }, 57 | classifiers=[ 58 | "Intended Audience :: Science/Research", 59 | "Programming Language :: Python :: 3", 60 | "Programming Language :: Python :: 3.7", 61 | "Programming Language :: Python :: 3.8", 62 | "Programming Language :: Python :: 3.9", 63 | "Programming Language :: Python :: 3.10", 64 | "License :: OSI Approved :: MIT License", 65 | "Topic :: Scientific/Engineering", 66 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 67 | ], 68 | ) 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | *.ipynb 131 | .vscode 132 | -------------------------------------------------------------------------------- /embetter/vision/_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from embetter.base import EmbetterBase 4 | 5 | 6 | class ImageLoader(EmbetterBase): 7 | """ 8 | Component that can turn filepaths into a list of PIL.Image objects. 9 | 10 | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/imageloader.png) 11 | 12 | Arguments: 13 | convert: Color [conversion setting](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.convert) from the Python image library. 14 | out: What kind of image output format to expect. 15 | 16 | **Usage** 17 | 18 | You can use the `ImageLoader` in standalone fashion. 19 | 20 | ```python 21 | from embetter.vision import ImageLoader 22 | 23 | filepath = "tests/data/thiscatdoesnotexist.jpeg" 24 | ImageLoader(convert="RGB").fit_transform([filepath]) 25 | ``` 26 | 27 | But it's more common to see it as part of a pipeline. 28 | 29 | ```python 30 | import pandas as pd 31 | from sklearn.pipeline import make_pipeline 32 | 33 | from embetter.grab import ColumnGrabber 34 | from embetter.vision import ImageLoader, ColorHistogramEncoder 35 | 36 | # Let's say we start with a csv file with filepaths 37 | data = {"filepaths": ["tests/data/thiscatdoesnotexist.jpeg"]} 38 | df = pd.DataFrame(data) 39 | 40 | # Let's build a pipeline that grabs the column, turns it 41 | # into an image and embeds it. 42 | pipe = make_pipeline( 43 | ColumnGrabber("filepaths"), 44 | ImageLoader(), 45 | ColorHistogramEncoder() 46 | ) 47 | 48 | pipe.fit_transform(df) 49 | ``` 50 | 51 | """ 52 | 53 | def __init__(self, convert: str = "RGB", out: str = "pil") -> None: 54 | self.convert = convert 55 | self.out = out 56 | 57 | def fit(self, X, y=None): 58 | """ 59 | No actual "fitting" happens in this method, but it does check the input arguments 60 | per sklearn convention. 61 | """ 62 | if self.out not in ["pil", "numpy"]: 63 | raise ValueError( 64 | f"Output format parameter out={self.out} must be either pil/numpy." 65 | ) 66 | return self 67 | 68 | def transform(self, X, y=None): 69 | """ 70 | Turn file paths into a list of `PIL.Image` objects or a numpy array of pixel values, depending on `out`. 71 | """ 72 | if self.out == "pil": 73 | return [Image.open(x).convert(self.convert) for x in X] 74 | if self.out == "numpy": 75 | return np.array([np.array(Image.open(x).convert(self.convert)) for x in X]) 76 | -------------------------------------------------------------------------------- /embetter/text/_sbert.py: -------------------------------------------------------------------------------- 1 | from sentence_transformers import SentenceTransformer as SBERT 2 | from embetter.base import EmbetterBase 3 | 4 | 5 | class SentenceEncoder(EmbetterBase): 6 | """ 7 | Encoder that can numerically encode sentences. 
8 | 9 | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/sentence-encoder.png) 10 | 11 | Arguments: 12 | name: name of model, see available options 13 | 14 | The following model names should be supported: 15 | 16 | - `all-mpnet-base-v2` 17 | - `multi-qa-mpnet-base-dot-v1` 18 | - `all-distilroberta-v1` 19 | - `all-MiniLM-L12-v2` 20 | - `multi-qa-distilbert-cos-v1` 21 | - `all-MiniLM-L6-v2` 22 | - `multi-qa-MiniLM-L6-cos-v1` 23 | - `paraphrase-multilingual-mpnet-base-v2` 24 | - `paraphrase-albert-small-v2` 25 | - `paraphrase-multilingual-MiniLM-L12-v2` 26 | - `paraphrase-MiniLM-L3-v2` 27 | - `distiluse-base-multilingual-cased-v1` 28 | - `distiluse-base-multilingual-cased-v2` 29 | 30 | You can find more options, and more information, on the [sentence-transformers docs page](https://www.sbert.net/docs/pretrained_models.html#model-overview). 31 | 32 | **Usage**: 33 | 34 | ```python 35 | import pandas as pd 36 | from sklearn.pipeline import make_pipeline 37 | from sklearn.linear_model import LogisticRegression 38 | 39 | from embetter.grab import ColumnGrabber 40 | from embetter.text import SentenceEncoder 41 | 42 | # Let's suppose this is the input dataframe 43 | dataf = pd.DataFrame({ 44 | "text": ["positive sentiment", "super negative"], 45 | "label_col": ["pos", "neg"] 46 | }) 47 | 48 | # This pipeline grabs the `text` column from a dataframe 49 | # which then gets fed into Sentence-Transformers' all-MiniLM-L6-v2. 50 | text_emb_pipeline = make_pipeline( 51 | ColumnGrabber("text"), 52 | SentenceEncoder('all-MiniLM-L6-v2') 53 | ) 54 | X = text_emb_pipeline.fit_transform(dataf, dataf['label_col']) 55 | 56 | # This pipeline can also be trained to make predictions, using 57 | # the embedded features. 58 | text_clf_pipeline = make_pipeline( 59 | text_emb_pipeline, 60 | LogisticRegression() 61 | ) 62 | 63 | # Prediction example 64 | text_clf_pipeline.fit(dataf, dataf['label_col']).predict(dataf) 65 | ``` 66 | """ 67 | 68 | def __init__(self, name="all-MiniLM-L6-v2"): 69 | self.name = name 70 | self.tfm = SBERT(name) 71 | 72 | def transform(self, X, y=None): 73 | """Transforms the text into a numeric representation.""" 74 | return self.tfm.encode(X) 75 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # embetter 4 | 5 | > "Just a bunch of embeddings to get started quickly." 6 | 7 | 
8 | 9 | Embetter implements scikit-learn compatible embeddings that should help get you started quickly. 10 | 11 | ## Install 12 | 13 | You can only install from Github, for now. 14 | 15 | ``` 16 | python -m pip install "embetter @ git+https://github.com/koaning/embetter.git" 17 | ``` 18 | 19 | ## API Design 20 | 21 | This is what's being implemented now. 22 | 23 | ```python 24 | # Helpers to grab text or image from pandas column. 25 | from embetter.grab import ColumnGrabber 26 | 27 | # Representations/Helpers for computer vision 28 | from embetter.vision import ImageLoader, TimmEncoder, ColorHistogramEncoder 29 | 30 | # Representations for text 31 | from embetter.text import SentenceEncoder, Sense2VecEncoder 32 | ``` 33 | 34 | 35 | ## Text Example 36 | 37 | ```python 38 | import pandas as pd 39 | from sklearn.pipeline import make_pipeline 40 | from sklearn.linear_model import LogisticRegression 41 | 42 | from embetter.grab import ColumnGrabber 43 | from embetter.text import SentenceEncoder 44 | 45 | # This pipeline grabs the `text` column from a dataframe 46 | # which then get fed into Sentence-Transformers' all-MiniLM-L6-v2. 47 | text_emb_pipeline = make_pipeline( 48 | ColumnGrabber("text"), 49 | SentenceEncoder('all-MiniLM-L6-v2') 50 | ) 51 | 52 | # This pipeline can also be trained to make predictions, using 53 | # the embedded features. 54 | text_clf_pipeline = make_pipeline( 55 | text_emb_pipeline, 56 | LogisticRegression() 57 | ) 58 | 59 | dataf = pd.DataFrame({ 60 | "text": ["positive sentiment", "super negative"], 61 | "label_col": ["pos", "neg"] 62 | }) 63 | X = text_emb_pipeline.fit_transform(dataf, dataf['label_col']) 64 | text_clf_pipeline.fit(dataf, dataf['label_col']).predict(dataf) 65 | ``` 66 | 67 | ## Image Example 68 | 69 | The goal of the API is to allow pipelines like this: 70 | 71 | ```python 72 | import pandas as pd 73 | from sklearn.pipeline import make_pipeline 74 | from sklearn.linear_model import LogisticRegression 75 | 76 | from embetter.grab import ColumnGrabber 77 | from embetter.vision import ImageLoader, TimmEncoder 78 | 79 | # This pipeline grabs the `img_path` column from a dataframe 80 | # then it grabs the image paths and turns them into `PIL.Image` objects 81 | # which then get fed into MobileNetv2 via TorchImageModels (timm). 82 | image_emb_pipeline = make_pipeline( 83 | ColumnGrabber("img_path"), 84 | ImageLoader(convert="RGB"), 85 | TimmEncoder("mobilenetv2_120d") 86 | ) 87 | 88 | dataf = pd.DataFrame({ 89 | "img_path": ["tests/data/thiscatdoesnotexist.jpeg"] 90 | }) 91 | image_emb_pipeline.fit_transform(dataf) 92 | ``` 93 | 94 | ## Batched Learning 95 | 96 | All of the encoding tools you've seen here are also compatible 97 | with the [`partial_fit` mechanic](https://scikit-learn.org/0.15/modules/scaling_strategies.html#incremental-learning) 98 | in scikit-learn. That means 99 | you can leverage [scikit-partial](https://github.com/koaning/scikit-partial) 100 | to build pipelines that can handle out-of-core datasets. 101 | 102 | ## Available Components 103 | 104 | The goal of the library is remain small but to offer a few general tools 105 | that might help with bulk labelling in particular, but general scikit-learn 106 | pipelines as well. 
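To make the `partial_fit` remark above concrete, here is a minimal sketch — not taken from the repository — of processing a dataframe in chunks. Because the encoders are pretrained and stateless, `partial_fit` is a no-op that simply returns the component; it exists so these transformers slot into incremental, out-of-core pipelines. The two toy batches below are invented for illustration.

```python
import pandas as pd

from embetter.grab import ColumnGrabber
from embetter.text import SentenceEncoder

grabber = ColumnGrabber("text")
encoder = SentenceEncoder("all-MiniLM-L6-v2")

# Pretend these chunks arrive one at a time, e.g. via pd.read_csv(..., chunksize=...).
batches = [
    pd.DataFrame({"text": ["positive sentiment", "super negative"]}),
    pd.DataFrame({"text": ["another chunk of rows arrives later"]}),
]

for batch in batches:
    texts = grabber.transform(batch)       # list of strings from the "text" column
    encoder.partial_fit(texts)             # no-op, keeps the incremental-learning API happy
    embeddings = encoder.transform(texts)  # numpy array of sentence embeddings for this chunk
```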
107 | 108 | | class | link | What it does | 109 | |:-------------------------:|------|-------------------------------------------------------------------------------------------------------| 110 | | `ColumnGrabber` | docs | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/columngrabber.png) | 111 | | `SentenceEncoder` | docs | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/sentence-encoder.png) | 112 | | `Sense2VecEncoder` | docs | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/sense2vec.png) | 113 | | `ImageLoader` | docs | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/imageloader.png) | 114 | | `ColorHistogramEncoder` | docs | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/colorhistogram.png) | 115 | | `TimmEncoder` | docs | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/timm.png) | 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # embetter 4 | 5 | > "Just a bunch of useful embeddings to get started quickly." 6 | 7 |
8 | 9 | Embetter implements scikit-learn compatible embeddings for computer vision and text. It should make it very easy to quickly build proofs of concept using scikit-learn pipelines and, in particular, should help with [bulk labelling](https://www.youtube.com/watch?v=gDk7_f3ovIk). It's also meant to play nice with [bulk](https://github.com/koaning/bulk) and [scikit-partial](https://github.com/koaning/scikit-partial). 10 | 11 | ## Install 12 | 13 | You can install embetter via pip. 14 | 15 | ``` 16 | python -m pip install embetter 17 | ``` 18 | 19 | Many of the embeddings are optional, depending on your use-case, so if you 20 | prefer to cherry-pick, you can install only the tools that you need: 21 | 22 | ``` 23 | python -m pip install "embetter[text]" 24 | python -m pip install "embetter[sense2vec]" 25 | python -m pip install "embetter[sentence-tfm]" 26 | python -m pip install "embetter[vision]" 27 | python -m pip install "embetter[all]" 28 | ``` 29 | 30 | ## API Design 31 | 32 | This is what's being implemented now. 33 | 34 | ```python 35 | # Helpers to grab text or image from pandas column. 36 | from embetter.grab import ColumnGrabber 37 | 38 | # Representations/Helpers for computer vision 39 | from embetter.vision import ImageLoader, TimmEncoder, ColorHistogramEncoder 40 | 41 | # Representations for text 42 | from embetter.text import SentenceEncoder, Sense2VecEncoder 43 | ``` 44 | 45 | All of these components are scikit-learn compatible, which means that you 46 | can apply them as you would normally in a scikit-learn pipeline. Just be aware 47 | that these components are stateless. They won't require training as these 48 | are all pretrained tools. 49 | 50 | ## Text Example 51 | 52 | ```python 53 | import pandas as pd 54 | from sklearn.pipeline import make_pipeline 55 | from sklearn.linear_model import LogisticRegression 56 | 57 | from embetter.grab import ColumnGrabber 58 | from embetter.text import SentenceEncoder 59 | 60 | # This pipeline grabs the `text` column from a dataframe 61 | # which then gets fed into Sentence-Transformers' all-MiniLM-L6-v2. 62 | text_emb_pipeline = make_pipeline( 63 | ColumnGrabber("text"), 64 | SentenceEncoder('all-MiniLM-L6-v2') 65 | ) 66 | 67 | # This pipeline can also be trained to make predictions, using 68 | # the embedded features. 69 | text_clf_pipeline = make_pipeline( 70 | text_emb_pipeline, 71 | LogisticRegression() 72 | ) 73 | 74 | dataf = pd.DataFrame({ 75 | "text": ["positive sentiment", "super negative"], 76 | "label_col": ["pos", "neg"] 77 | }) 78 | X = text_emb_pipeline.fit_transform(dataf, dataf['label_col']) 79 | text_clf_pipeline.fit(dataf, dataf['label_col']).predict(dataf) 80 | ``` 81 | 82 | ## Image Example 83 | 84 | The goal of the API is to allow pipelines like this: 85 | 86 | ```python 87 | import pandas as pd 88 | from sklearn.pipeline import make_pipeline 89 | from sklearn.linear_model import LogisticRegression 90 | 91 | from embetter.grab import ColumnGrabber 92 | from embetter.vision import ImageLoader, TimmEncoder 93 | 94 | # This pipeline grabs the `img_path` column from a dataframe 95 | # then it grabs the image paths and turns them into `PIL.Image` objects 96 | # which then get fed into MobileNetv2 via TorchImageModels (timm). 
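# Note: with the default encode_predictions=False, TimmEncoder drops the
# classification head, so "mobilenetv2_120d" yields a 1280-dimensional feature
# vector per row; encode_predictions=True returns the 1000-dimensional
# prediction output instead (see tests/test_vision.py).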
97 | image_emb_pipeline = make_pipeline( 98 | ColumnGrabber("img_path"), 99 | ImageLoader(convert="RGB"), 100 | TimmEncoder("mobilenetv2_120d") 101 | ) 102 | 103 | dataf = pd.DataFrame({ 104 | "img_path": ["tests/data/thiscatdoesnotexist.jpeg"] 105 | }) 106 | image_emb_pipeline.fit_transform(dataf) 107 | ``` 108 | 109 | ## Batched Learning 110 | 111 | All of the encoding tools you've seen here are also compatible 112 | with the [`partial_fit` mechanic](https://scikit-learn.org/0.15/modules/scaling_strategies.html#incremental-learning) 113 | in scikit-learn. That means 114 | you can leverage [scikit-partial](https://github.com/koaning/scikit-partial) 115 | to build pipelines that can handle out-of-core datasets. 116 | 117 | ## Available Components 118 | 119 | The goal of the library is to remain small but to offer a few general tools 120 | that might help with bulk labelling in particular, and with general scikit-learn 121 | pipelines as well. 122 | 123 | | class | link | What it does | 124 | |:-------------------------:|------------------------------------------------------|-------------------------------------------------------------------------------------------------------| 125 | | `ColumnGrabber` | [docs](https://koaning.github.io/embetter/API/grab/) | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/columngrabber.png) | 126 | | `SentenceEncoder` | [docs](https://koaning.github.io/embetter/API/text/sentence-enc/) | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/sentence-encoder.png) | 127 | | `Sense2VecEncoder` | [docs](https://koaning.github.io/embetter/API/text/sense2vec/) | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/sense2vec.png) | 128 | | `ImageLoader` | [docs](https://koaning.github.io/embetter/API/vision/imageload/) | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/imageloader.png) | 129 | | `ColorHistogramEncoder` | [docs](https://koaning.github.io/embetter/API/vision/colorhist/) | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/colorhistogram.png) | 130 | | `TimmEncoder` | [docs](https://koaning.github.io/embetter/API/vision/timm/) | ![](https://raw.githubusercontent.com/koaning/embetter/main/docs/images/timm.png) | 131 | --------------------------------------------------------------------------------
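As a closing illustration of how the components in the table combine, here is a small sketch — not part of the repository — that uses scikit-learn's `FeatureUnion` (via `make_union`) to concatenate two image representations for the test image shipped with the repo. The `n_buckets=64` choice is arbitrary, and the 1280-dimensional timm output matches what `tests/test_vision.py` checks for `mobilenetv2_120d`.

```python
import pandas as pd
from sklearn.pipeline import make_pipeline, make_union

from embetter.grab import ColumnGrabber
from embetter.vision import ImageLoader, ColorHistogramEncoder, TimmEncoder

df = pd.DataFrame({"img_path": ["tests/data/thiscatdoesnotexist.jpeg"]})

# Grab the path column, load each image once, then compute two embeddings
# side by side: a 3 * 64 color histogram and a 1280-dim timm feature vector.
pipe = make_pipeline(
    ColumnGrabber("img_path"),
    ImageLoader(convert="RGB"),
    make_union(
        ColorHistogramEncoder(n_buckets=64),
        TimmEncoder("mobilenetv2_120d"),
    ),
)

features = pipe.fit_transform(df)  # shape: (1, 64 * 3 + 1280)
```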