├── benchmarks
    ├── __init__.py
    ├── data.py
    ├── results
    │   ├── train_benchmark_results.json
    │   └── train_test_benchmark_results.json
    └── run_benchmarks.py
├── semhash
    ├── __init__.py
    ├── version.py
    ├── records.py
    ├── utils.py
    ├── index.py
    ├── datamodels.py
    └── semhash.py
├── assets
    └── images
    │   ├── semhash_logo.png
    │   └── semhash_logo_v2.png
├── tests
    ├── data
    │   └── test_model
    │   │   ├── model.safetensors
    │   │   ├── config.json
    │   │   ├── modules.json
    │   │   └── README.md
    ├── conftest.py
    ├── test_datamodels.py
    └── test_semhash.py
├── Makefile
├── CITATION.cff
├── LICENSE
├── .pre-commit-config.yaml
├── .github
    └── workflows
    │   └── ci.yaml
├── pyproject.toml
├── .gitignore
└── README.md

/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/semhash/__init__.py:
--------------------------------------------------------------------------------
1 | from semhash.semhash import SemHash
2 | 
3 | __all__ = ["SemHash"]
4 | 
--------------------------------------------------------------------------------
/assets/images/semhash_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/semhash/HEAD/assets/images/semhash_logo.png
--------------------------------------------------------------------------------
/semhash/version.py:
--------------------------------------------------------------------------------
1 | __version_triple__ = (0, 3, 3)
2 | __version__ = ".".join(map(str, __version_triple__))
3 | 
--------------------------------------------------------------------------------
/assets/images/semhash_logo_v2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/semhash/HEAD/assets/images/semhash_logo_v2.png
--------------------------------------------------------------------------------
/tests/data/test_model/model.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MinishLab/semhash/HEAD/tests/data/test_model/model.safetensors
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | clean:
2 | 
3 | 
4 | venv:
5 | 	uv venv
6 | 
7 | install: venv
8 | 	uv sync --all-extras
9 | 	uv run pre-commit install
10 | 
11 | install-no-pre-commit:
12 | 	uv pip install ".[dev]"
13 | 
14 | fix:
15 | 	uv run pre-commit run --all-files
16 | 
17 | test:
18 | 	uv run pytest --cov=semhash --cov-report=term-missing
19 | 
--------------------------------------------------------------------------------
/tests/data/test_model/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "model_type": "model2vec",
3 |   "architectures": [
4 |     "StaticModel"
5 |   ],
6 |   "tokenizer_name": "baai/bge-base-en-v1.5",
7 |   "apply_pca": 128,
8 |   "apply_zipf": true,
9 |   "hidden_dim": 128,
10 |   "seq_length": 1000000,
11 |   "normalize": true
12 | }
13 | 
--------------------------------------------------------------------------------
/tests/data/test_model/modules.json:
--------------------------------------------------------------------------------
1 | [
2 |   {
3 |     "idx": 0,
4 |     "name": "0",
5 |     "path": ".",
6 |     "type": "sentence_transformers.models.StaticEmbedding"
7 |   },
8 |   {
9 |     "idx": 1,
10 |     "name": "1",
11 |     "path": "1_Normalize",
12 |     "type": 
"sentence_transformers.models.Normalize" 13 | } 14 | ] 15 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use SemHash in your research, please cite it as below." 3 | title: "SemHash: Fast Semantic Text Deduplication & Filtering" 4 | authors: 5 | - family-names: "van Dongen" 6 | given-names: "Thomas" 7 | - family-names: "Tulkens" 8 | given-names: "Stephan" 9 | doi: 10.5281/zenodo.17265942 10 | license: MIT 11 | url: "https://github.com/MinishLab/semhash" 12 | repository-code: "https://github.com/MinishLab/semhash" 13 | date-released: "2025-01-05" 14 | 15 | preferred-citation: 16 | type: software 17 | title: "SemHash: Fast Semantic Text Deduplication & Filtering" 18 | authors: 19 | - family-names: "van Dongen" 20 | given-names: "Thomas" 21 | - family-names: "Tulkens" 22 | given-names: "Stephan" 23 | year: 2025 24 | publisher: Zenodo 25 | doi: 10.5281/zenodo.17265942 26 | url: "https://github.com/MinishLab/semhash" 27 | license: MIT 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 The Minish Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-ast 8 | description: Simply check whether files parse as valid python. 9 | - id: trailing-whitespace 10 | description: Trims trailing whitespace 11 | - id: end-of-file-fixer 12 | description: Makes sure files end in a newline and only a newline. 13 | - id: check-added-large-files 14 | args: ['--maxkb=5000'] 15 | description: Prevent giant files from being committed. 16 | - id: check-case-conflict 17 | description: Check for files with names that would conflict on case-insensitive filesystems like MacOS/Windows. 
18 | - repo: https://github.com/jsh9/pydoclint 19 | rev: 0.5.3 20 | hooks: 21 | - id: pydoclint 22 | - repo: https://github.com/astral-sh/ruff-pre-commit 23 | rev: v0.4.10 24 | hooks: 25 | - id: ruff 26 | args: [ --fix ] 27 | - id: ruff-format 28 | - repo: local 29 | hooks: 30 | - id: mypy 31 | name: mypy 32 | entry: mypy 33 | language: python 34 | types: [python] 35 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from model2vec import StaticModel 3 | 4 | 5 | @pytest.fixture 6 | def model() -> StaticModel: 7 | """Load a model for testing.""" 8 | return StaticModel.from_pretrained("tests/data/test_model") 9 | 10 | 11 | @pytest.fixture(params=[True, False], ids=["use_ann=True", "use_ann=False"]) 12 | def use_ann(request: pytest.FixtureRequest) -> bool: 13 | """Whether to use approximate nearest neighbors or not.""" 14 | return request.param 15 | 16 | 17 | @pytest.fixture 18 | def train_texts() -> list[str]: 19 | """A list of train texts for testing outlier and representative filtering.""" 20 | return [ 21 | "apple", 22 | "banana", 23 | "cherry", 24 | "strawberry", 25 | "blueberry", 26 | "raspberry", 27 | "blackberry", 28 | "peach", 29 | "plum", 30 | "grape", 31 | "mango", 32 | "papaya", 33 | "pineapple", 34 | "watermelon", 35 | "orange", 36 | "lemon", 37 | "lime", 38 | "tangerine", 39 | "car", # Outlier 40 | "bicycle", # Outlier 41 | ] 42 | 43 | 44 | @pytest.fixture 45 | def test_texts() -> list[str]: 46 | """A list of test texts for testing outlier and representative filtering.""" 47 | return [ 48 | "apple", 49 | "banana", 50 | "kiwi", 51 | "fig", 52 | "apricot", 53 | "grapefruit", 54 | "pomegranate", 55 | "motorcycle", # Outlier 56 | "plane", # Outlier 57 | ] 58 | -------------------------------------------------------------------------------- /semhash/records.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from semhash.datamodels import DeduplicationResult, DuplicateRecord 4 | 5 | 6 | def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str: 7 | r""" 8 | Turn a record into a single string. 9 | 10 | Uses self.columns to determine the order of the text segments. 11 | Each text is cleaned by replacing '\t' with ' '. The texts are then joined by '\t'. 12 | 13 | :param record: A record to unpack. 14 | :param columns: Columns to unpack. 15 | :return: A single string representation of the record. 
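    Example (illustrative values, not from the library's own docs): dict_to_string({"question": "a\tb", "context": "c"}, columns=["question", "context"])
    returns "a b\tc" — the tab inside the first value is replaced by a space, and the two column values are then joined by a tab.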
16 | """ 17 | return "\t".join(record.get(c, "").replace("\t", " ") for c in columns) 18 | 19 | 20 | def map_deduplication_result_to_strings(result: DeduplicationResult, columns: Sequence[str]) -> DeduplicationResult: 21 | """Convert the record and duplicates in each DuplicateRecord back to strings if self.was_string is True.""" 22 | deduplicated_str = [dict_to_string(r, columns) for r in result.selected] 23 | mapped = [] 24 | for dup_rec in result.duplicates: 25 | record_as_str = dict_to_string(dup_rec.record, columns) 26 | duplicates_as_str = [(dict_to_string(r, columns), score) for r, score in dup_rec.duplicates] 27 | mapped.append( 28 | DuplicateRecord( 29 | record=record_as_str, 30 | duplicates=duplicates_as_str, 31 | exact=dup_rec.exact, 32 | ) 33 | ) 34 | return DeduplicationResult(selected=deduplicated_str, filtered=mapped, threshold=result.threshold) 35 | 36 | 37 | def add_scores_to_records(records: list[dict[str, str]]) -> list[tuple[dict[str, str], float]]: 38 | """Add scores to records and return a DeduplicationResult.""" 39 | return [(record, 1.0) for record in records] 40 | -------------------------------------------------------------------------------- /semhash/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol, Sequence, Union 2 | 3 | import numpy as np 4 | from frozendict import frozendict 5 | 6 | 7 | class Encoder(Protocol): 8 | """An encoder protocol for SemHash.""" 9 | 10 | def encode( 11 | self, 12 | sentences: Union[list[str], str, Sequence[str]], 13 | **kwargs: Any, 14 | ) -> np.ndarray: 15 | """ 16 | Encode a list of sentences into embeddings. 17 | 18 | :param sentences: A list of sentences to encode. 19 | :param **kwargs: Additional keyword arguments. 20 | :return: The embeddings of the sentences. 21 | """ 22 | ... # pragma: no cover 23 | 24 | 25 | def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]: 26 | """Convert a record to a frozendict.""" 27 | return frozendict({k: record.get(k, "") for k in columns}) 28 | 29 | 30 | def compute_candidate_limit( 31 | total: int, 32 | selection_size: int, 33 | fraction: float = 0.1, 34 | min_candidates: int = 100, 35 | max_candidates: int = 1000, 36 | ) -> int: 37 | """ 38 | Compute the 'auto' candidate limit based on the total number of records. 39 | 40 | :param total: Total number of records. 41 | :param selection_size: Number of representatives to select. 42 | :param fraction: Fraction of total records to consider as candidates. 43 | :param min_candidates: Minimum number of candidates. 44 | :param max_candidates: Maximum number of candidates. 45 | :return: Computed candidate limit. 
46 | """ 47 | # 1) fraction of total 48 | limit = int(total * fraction) 49 | # 2) ensure enough to pick selection_size 50 | limit = max(limit, selection_size) 51 | # 3) enforce lower bound 52 | limit = max(limit, min_candidates) 53 | # 4) enforce upper bound (and never exceed the dataset) 54 | limit = min(limit, max_candidates, total) 55 | return limit 56 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Run tests and upload coverage 2 | 3 | on: 4 | push 5 | 6 | jobs: 7 | test: 8 | name: Run tests with pytest 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: ["ubuntu-latest", "windows-latest"] 13 | python-version: ["3.9", "3.10", "3.11", "3.12"] 14 | exclude: 15 | - os: windows-latest 16 | python-version: "3.9" 17 | - os: windows-latest 18 | python-version: "3.11" 19 | - os: windows-latest 20 | python-version: "3.12" 21 | fail-fast: false 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | allow-prereleases: true 31 | 32 | # Step for Windows: Create and activate a virtual environment 33 | - name: Create and activate a virtual environment (Windows) 34 | if: ${{ runner.os == 'Windows' }} 35 | run: | 36 | irm https://astral.sh/uv/install.ps1 | iex 37 | $env:Path = "C:\Users\runneradmin\.local\bin;$env:Path" 38 | uv venv .venv 39 | "VIRTUAL_ENV=.venv" | Out-File -FilePath $env:GITHUB_ENV -Append 40 | "$PWD/.venv/Scripts" | Out-File -FilePath $env:GITHUB_PATH -Append 41 | 42 | # Step for Unix: Create and activate a virtual environment 43 | - name: Create and activate a virtual environment (Unix) 44 | if: ${{ runner.os != 'Windows' }} 45 | run: | 46 | curl -LsSf https://astral.sh/uv/install.sh | sh 47 | uv venv .venv 48 | echo "VIRTUAL_ENV=.venv" >> $GITHUB_ENV 49 | echo "$PWD/.venv/bin" >> $GITHUB_PATH 50 | 51 | # Install dependencies using uv pip 52 | - name: Install dependencies 53 | run: make install-no-pre-commit 54 | 55 | # Run tests with coverage 56 | - name: Run tests under coverage 57 | run: | 58 | coverage run -m pytest 59 | coverage report 60 | 61 | # Upload results to Codecov 62 | - name: Upload results to Codecov 63 | uses: codecov/codecov-action@v4 64 | with: 65 | token: ${{ secrets.CODECOV_TOKEN }} 66 | -------------------------------------------------------------------------------- /benchmarks/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class DatasetRecord: 6 | """Dataset record.""" 7 | 8 | name: str 9 | text_name: str | None = None 10 | label_name: str | None = None 11 | sub_directory: str = "" 12 | columns: list[str] | None = None 13 | split_one: str = "train" 14 | split_two: str = "test" 15 | 16 | 17 | DATASET_DICT: dict[str, DatasetRecord] = { 18 | "bbc": DatasetRecord(name="SetFit/bbc-news", text_name="text", label_name="label_text"), 19 | "senteval_cr": DatasetRecord(name="SetFit/SentEval-CR", text_name="text", label_name="label_text"), 20 | "tweet_sentiment_extraction": DatasetRecord( 21 | name="SetFit/tweet_sentiment_extraction", text_name="text", label_name="label_text" 22 | ), 23 | "emotion": DatasetRecord(name="SetFit/emotion", text_name="text", label_name="label_text"), 24 | "amazon_counterfactual": DatasetRecord( 25 
| name="SetFit/amazon_counterfactual_en", text_name="text", label_name="label_text" 26 | ), 27 | "ag_news": DatasetRecord(name="SetFit/ag_news", text_name="text", label_name="label_text"), 28 | "enron_spam": DatasetRecord(name="SetFit/enron_spam", text_name="text", label_name="label_text"), 29 | "subj": DatasetRecord(name="SetFit/subj", text_name="text", label_name="label_text"), 30 | "sst5": DatasetRecord(name="SetFit/sst5", text_name="text", label_name="label_text"), 31 | "20_newgroups": DatasetRecord(name="SetFit/20_newsgroups", text_name="text", label_name="label_text"), 32 | "hatespeech_offensive": DatasetRecord(name="SetFit/hate_speech_offensive", text_name="text", label_name="label"), 33 | "ade": DatasetRecord(name="SetFit/ade_corpus_v2_classification", text_name="text", label_name="label"), 34 | "imdb": DatasetRecord(name="SetFit/imdb", text_name="text", label_name="label"), 35 | "massive_scenario": DatasetRecord( 36 | name="SetFit/amazon_massive_scenario_en-US", text_name="text", label_name="label" 37 | ), 38 | "student": DatasetRecord(name="SetFit/student-question-categories", text_name="text", label_name="label"), 39 | "squad_v2": DatasetRecord(name="squad_v2", columns=["question", "context"], split_two="validation"), 40 | "wikitext": DatasetRecord( 41 | name="Salesforce/wikitext", text_name="text", label_name="text", sub_directory="wikitext-103-raw-v1" 42 | ), 43 | } 44 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "semhash" 3 | description = "Fast Semantic Text Deduplication & Filtering" 4 | authors = [{name = "Thomas van Dongen", email = "thomas123@live.nl"}, { name = "Stéphan Tulkens", email = "stephantul@gmail.com"}] 5 | readme = { file = "README.md", content-type = "text/markdown" } 6 | dynamic = ["version"] 7 | license = { file = "LICENSE" } 8 | requires-python = ">=3.9" 9 | 10 | classifiers = [ 11 | "Development Status :: 4 - Beta", 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Science/Research", 14 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 15 | "Topic :: Software Development :: Libraries", 16 | "License :: OSI Approved :: MIT License", 17 | "Programming Language :: Python :: 3 :: Only", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Natural Language :: English", 23 | ] 24 | 25 | dependencies = [ 26 | "model2vec>=0.3.4", 27 | "vicinity[usearch]>=0.4.3", 28 | "frozendict", 29 | ] 30 | 31 | [build-system] 32 | requires = ["setuptools>=64", "setuptools_scm>=8"] 33 | build-backend = "setuptools.build_meta" 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | "black", 38 | "ipython", 39 | "mypy", 40 | "pre-commit", 41 | "pytest", 42 | "pytest-coverage", 43 | "ruff", 44 | ] 45 | 46 | [project.urls] 47 | "Homepage" = "https://github.com/MinishLab" 48 | "Bug Reports" = "https://github.com/MinishLab/semhash/issues" 49 | "Source" = "https://github.com/MinishLab/semhash" 50 | 51 | [tool.ruff] 52 | exclude = [".venv/"] 53 | line-length = 120 54 | target-version = "py310" 55 | 56 | [tool.ruff.lint] 57 | select = [ 58 | # Annotations: Enforce type annotations 59 | "ANN", 60 | # Complexity: Enforce a maximum cyclomatic complexity 61 | "C90", 62 | # Pydocstyle: Enforce docstrings 63 | "D", 64 | # Isort: Enforce import order 65 | "I", 66 
| # Numpy: Enforce numpy style 67 | "NPY", 68 | # Print: Forbid print statements 69 | "T20", 70 | ] 71 | ignore = [ 72 | # Allow self and cls to be untyped, and allow Any type 73 | "ANN101", "ANN102", "ANN401", 74 | # Pydocstyle ignores 75 | "D100", "D101", "D104", "D203", "D212", "D401", 76 | # Allow use of f-strings in logging 77 | "G004" 78 | ] 79 | 80 | [tool.pydoclint] 81 | style = "sphinx" 82 | exclude = "test_" 83 | allow-init-docstring = true 84 | arg-type-hints-in-docstring = false 85 | check-return-types = false 86 | require-return-section-when-returning-nothing = false 87 | 88 | [tool.mypy] 89 | python_version = "3.10" 90 | warn_unused_configs = true 91 | ignore_missing_imports = true 92 | 93 | [tool.setuptools] 94 | packages = ["semhash"] 95 | license-files = [] 96 | 97 | [tool.setuptools_scm] 98 | # can be empty if no extra settings are needed, presence enables setuptools_scm 99 | 100 | [tool.setuptools.dynamic] 101 | version = {attr = "semhash.version.__version__"} 102 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | local/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /semhash/index.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | from vicinity import Backend 7 | from vicinity.backends import AbstractBackend, get_backend_class 8 | from vicinity.datatypes import SingleQueryResult 9 | 10 | DocScore = tuple[dict[str, str], float] 11 | DocScores = list[DocScore] 12 | DictItem = list[dict[str, str]] 13 | 14 | 15 | class Index: 16 | def __init__(self, vectors: np.ndarray, items: list[DictItem], backend: AbstractBackend) -> None: 17 | """ 18 | An index that maps vectors to items. 19 | 20 | This index has an efficient backend for querying, but also explicitly stores the vectors in memory. 21 | 22 | :param vectors: The vectors of the items. 23 | :param items: The items in the index. This is a list of lists. Each sublist contains one or more dictionaries 24 | that represent records. These records are exact duplicates of each other. 25 | :param backend: The backend to use for querying. 26 | """ 27 | self.items = items 28 | self.backend = backend 29 | self.vectors = vectors 30 | 31 | @classmethod 32 | def from_vectors_and_items( 33 | cls, vectors: np.ndarray, items: list[DictItem], backend_type: Backend | str, **kwargs: Any 34 | ) -> Index: 35 | """ 36 | Load the index from vectors and items. 37 | 38 | :param vectors: The vectors of the items. 39 | :param items: The items in the index. 40 | :param backend_type: The type of backend to use. 41 | :param **kwargs: Additional arguments to pass to the backend. 42 | :return: The index. 
43 | """ 44 | backend_class = get_backend_class(backend_type) 45 | arguments = backend_class.argument_class(**kwargs) 46 | backend = backend_class.from_vectors(vectors, **arguments.dict()) 47 | 48 | return cls(vectors, items, backend) 49 | 50 | def query_threshold(self, vectors: np.ndarray, threshold: float) -> list[DocScores]: 51 | """ 52 | Query the index with a threshold. 53 | 54 | :param vectors: The vectors to query. 55 | :param threshold: The similarity threshold. 56 | :return: The query results. 57 | """ 58 | out: list[DocScores] = [] 59 | for result in self.backend.threshold(vectors, threshold=1 - threshold, max_k=100): 60 | intermediate = [] 61 | for index, distance in zip(*result): 62 | # Every item in the index contains one or more records. 63 | # These are all exact duplicates, so they get the same score. 64 | for record in self.items[index]: 65 | # The score is the cosine similarity. 66 | # The backend returns distances, so we need to convert. 67 | intermediate.append((record, 1 - distance)) 68 | out.append(intermediate) 69 | 70 | return out 71 | 72 | def query_top_k(self, vectors: np.ndarray, k: int, vectors_are_in_index: bool) -> list[SingleQueryResult]: 73 | """ 74 | Query the index with a top-k. 75 | 76 | :param vectors: The vectors to query. 77 | :param k: Maximum number of top-k records to keep. 78 | :param vectors_are_in_index: Whether the vectors are in the index. If this is set to True, we retrieve k + 1 79 | records, and do not consider the first one, as it is the query vector itself. 80 | :return: The query results. Each result is a tuple where the first element is the list of neighbor records, 81 | and the second element is a NumPy array of cosine similarity scores. 82 | """ 83 | results = [] 84 | offset = int(vectors_are_in_index) 85 | for x, y in self.backend.query(vectors=vectors, k=k + offset): 86 | # Convert returned distances to cosine similarities. 87 | similarities = 1 - y[offset:] 88 | results.append((x[offset:], similarities)) 89 | return results 90 | -------------------------------------------------------------------------------- /tests/data/test_model/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: model2vec 3 | license: mit 4 | model_name: potion-base-4m-int8 5 | tags: 6 | - embeddings 7 | - static-embeddings 8 | - sentence-transformers 9 | --- 10 | 11 | # potion-base-4m-int8 Model Card 12 | 13 | This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of a Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers. 14 | 15 | 16 | ## Installation 17 | 18 | Install model2vec using pip: 19 | ``` 20 | pip install model2vec 21 | ``` 22 | 23 | ## Usage 24 | 25 | ### Using Model2Vec 26 | 27 | The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models. 
28 | 29 | Load this model using the `from_pretrained` method: 30 | ```python 31 | from model2vec import StaticModel 32 | 33 | # Load a pretrained Model2Vec model 34 | model = StaticModel.from_pretrained("potion-base-4m-int8") 35 | 36 | # Compute text embeddings 37 | embeddings = model.encode(["Example sentence"]) 38 | ``` 39 | 40 | ### Using Sentence Transformers 41 | 42 | You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model: 43 | 44 | ```python 45 | from sentence_transformers import SentenceTransformer 46 | 47 | # Load a pretrained Sentence Transformer model 48 | model = SentenceTransformer("potion-base-4m-int8") 49 | 50 | # Compute text embeddings 51 | embeddings = model.encode(["Example sentence"]) 52 | ``` 53 | 54 | ### Distilling a Model2Vec model 55 | 56 | You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. Then, run the following code: 57 | 58 | ```python 59 | from model2vec.distill import distill 60 | 61 | # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model 62 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256) 63 | 64 | # Save the model 65 | m2v_model.save_pretrained("m2v_model") 66 | ``` 67 | 68 | ## How it works 69 | 70 | Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec. 71 | 72 | It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence. 73 | 74 | ## Additional Resources 75 | 76 | - [Model2Vec Repo](https://github.com/MinishLab/model2vec) 77 | - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) 78 | - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results) 79 | - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials) 80 | - [Website](https://minishlab.github.io/) 81 | 82 | 83 | ## Library Authors 84 | 85 | Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled). 86 | 87 | ## Citation 88 | 89 | Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work. 
90 | ``` 91 | @article{minishlab2024model2vec, 92 | author = {Tulkens, Stephan and {van Dongen}, Thomas}, 93 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings}, 94 | year = {2024}, 95 | url = {https://github.com/MinishLab/model2vec} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /benchmarks/results/train_benchmark_results.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dataset": "bbc", 4 | "original_train_size": 1225, 5 | "deduplicated_train_size": 1144, 6 | "percent_removed": 6.612244897959185, 7 | "build_time_seconds": 0.5598082079086453, 8 | "deduplication_time_seconds": 0.008702374994754791, 9 | "time_seconds": 0.5685105829034001 10 | }, 11 | { 12 | "dataset": "senteval_cr", 13 | "original_train_size": 3012, 14 | "deduplicated_train_size": 2990, 15 | "percent_removed": 0.7304116865869847, 16 | "build_time_seconds": 0.10847400000784546, 17 | "deduplication_time_seconds": 0.027519959025084972, 18 | "time_seconds": 0.13599395903293043 19 | }, 20 | { 21 | "dataset": "tweet_sentiment_extraction", 22 | "original_train_size": 27481, 23 | "deduplicated_train_size": 26695, 24 | "percent_removed": 2.860157927295226, 25 | "build_time_seconds": 1.3568968329345807, 26 | "deduplication_time_seconds": 0.41194633406121284, 27 | "time_seconds": 1.7688431669957936 28 | }, 29 | { 30 | "dataset": "emotion", 31 | "original_train_size": 16000, 32 | "deduplicated_train_size": 15695, 33 | "percent_removed": 1.9062499999999982, 34 | "build_time_seconds": 0.5511152499821037, 35 | "deduplication_time_seconds": 0.21407662506680936, 36 | "time_seconds": 0.7651918750489131 37 | }, 38 | { 39 | "dataset": "amazon_counterfactual", 40 | "original_train_size": 5000, 41 | "deduplicated_train_size": 4992, 42 | "percent_removed": 0.16000000000000458, 43 | "build_time_seconds": 0.2848535830853507, 44 | "deduplication_time_seconds": 0.048574666026979685, 45 | "time_seconds": 0.3334282491123304 46 | }, 47 | { 48 | "dataset": "ag_news", 49 | "original_train_size": 120000, 50 | "deduplicated_train_size": 106921, 51 | "percent_removed": 10.899166666666671, 52 | "build_time_seconds": 3.0319770000642166, 53 | "deduplication_time_seconds": 2.171258582966402, 54 | "time_seconds": 5.203235583030619 55 | }, 56 | { 57 | "dataset": "enron_spam", 58 | "original_train_size": 31716, 59 | "deduplicated_train_size": 20540, 60 | "percent_removed": 35.23773489721276, 61 | "build_time_seconds": 1.3818323339801282, 62 | "deduplication_time_seconds": 0.6438171250047162, 63 | "time_seconds": 2.0256494589848444 64 | }, 65 | { 66 | "dataset": "subj", 67 | "original_train_size": 8000, 68 | "deduplicated_train_size": 7990, 69 | "percent_removed": 0.12499999999999734, 70 | "build_time_seconds": 0.5059439589967951, 71 | "deduplication_time_seconds": 0.12505983305163682, 72 | "time_seconds": 0.6310037920484319 73 | }, 74 | { 75 | "dataset": "sst5", 76 | "original_train_size": 8544, 77 | "deduplicated_train_size": 8526, 78 | "percent_removed": 0.2106741573033699, 79 | "build_time_seconds": 0.4805819580797106, 80 | "deduplication_time_seconds": 0.10166720801498741, 81 | "time_seconds": 0.582249166094698 82 | }, 83 | { 84 | "dataset": "20_newgroups", 85 | "original_train_size": 11314, 86 | "deduplicated_train_size": 10684, 87 | "percent_removed": 5.568322432384654, 88 | "build_time_seconds": 0.610724583035335, 89 | "deduplication_time_seconds": 0.11600329191423953, 90 | "time_seconds": 0.7267278749495745 91 | }, 92 | { 93 | "dataset": 
"hatespeech_offensive", 94 | "original_train_size": 22783, 95 | "deduplicated_train_size": 22090, 96 | "percent_removed": 3.0417416494754823, 97 | "build_time_seconds": 0.6471997499465942, 98 | "deduplication_time_seconds": 0.2704670410603285, 99 | "time_seconds": 0.9176667910069227 100 | }, 101 | { 102 | "dataset": "ade", 103 | "original_train_size": 17637, 104 | "deduplicated_train_size": 15718, 105 | "percent_removed": 10.880535238419231, 106 | "build_time_seconds": 0.5221591669833288, 107 | "deduplication_time_seconds": 0.20764074998442084, 108 | "time_seconds": 0.7297999169677496 109 | }, 110 | { 111 | "dataset": "imdb", 112 | "original_train_size": 25000, 113 | "deduplicated_train_size": 24830, 114 | "percent_removed": 0.6800000000000028, 115 | "build_time_seconds": 1.460668999934569, 116 | "deduplication_time_seconds": 0.29758112493436784, 117 | "time_seconds": 1.7582501248689368 118 | }, 119 | { 120 | "dataset": "massive_scenario", 121 | "original_train_size": 11514, 122 | "deduplicated_train_size": 9366, 123 | "percent_removed": 18.655549765502865, 124 | "build_time_seconds": 0.35503324994351715, 125 | "deduplication_time_seconds": 0.11619104200508446, 126 | "time_seconds": 0.4712242919486016 127 | }, 128 | { 129 | "dataset": "student", 130 | "original_train_size": 117519, 131 | "deduplicated_train_size": 63856, 132 | "percent_removed": 45.66325445247151, 133 | "build_time_seconds": 2.9044899590080604, 134 | "deduplication_time_seconds": 5.895973875070922, 135 | "time_seconds": 8.800463834078982 136 | }, 137 | { 138 | "dataset": "squad_v2", 139 | "original_train_size": 130319, 140 | "deduplicated_train_size": 109698, 141 | "percent_removed": 15.823479308466148, 142 | "build_time_seconds": 6.078755749971606, 143 | "deduplication_time_seconds": 2.7270843340083957, 144 | "time_seconds": 8.805840083980002 145 | }, 146 | { 147 | "dataset": "wikitext", 148 | "original_train_size": 1801350, 149 | "deduplicated_train_size": 884645, 150 | "percent_removed": 50.88988813945097, 151 | "build_time_seconds": 39.38258587510791, 152 | "deduplication_time_seconds": 44.1503732081037, 153 | "time_seconds": 83.53295908321161 154 | } 155 | ] 156 | -------------------------------------------------------------------------------- /benchmarks/results/train_test_benchmark_results.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dataset": "bbc", 4 | "train_size": 1225, 5 | "test_size": 1000, 6 | "deduplicated_test_size": 870, 7 | "percent_removed": 13.0, 8 | "build_time_seconds": 0.5598082079086453, 9 | "deduplication_time_seconds": 0.1528247919632122, 10 | "time_seconds": 0.7126329998718575 11 | }, 12 | { 13 | "dataset": "senteval_cr", 14 | "train_size": 3012, 15 | "test_size": 753, 16 | "deduplicated_test_size": 750, 17 | "percent_removed": 0.3984063745019917, 18 | "build_time_seconds": 0.10847400000784546, 19 | "deduplication_time_seconds": 0.019297500024549663, 20 | "time_seconds": 0.12777150003239512 21 | }, 22 | { 23 | "dataset": "tweet_sentiment_extraction", 24 | "train_size": 27481, 25 | "test_size": 3534, 26 | "deduplicated_test_size": 3412, 27 | "percent_removed": 3.452178834182229, 28 | "build_time_seconds": 1.3568968329345807, 29 | "deduplication_time_seconds": 0.17268049996346235, 30 | "time_seconds": 1.529577332898043 31 | }, 32 | { 33 | "dataset": "emotion", 34 | "train_size": 16000, 35 | "test_size": 2000, 36 | "deduplicated_test_size": 1926, 37 | "percent_removed": 3.7000000000000033, 38 | "build_time_seconds": 0.5511152499821037, 39 | 
"deduplication_time_seconds": 0.10135454102419317, 40 | "time_seconds": 0.6524697910062969 41 | }, 42 | { 43 | "dataset": "amazon_counterfactual", 44 | "train_size": 5000, 45 | "test_size": 5000, 46 | "deduplicated_test_size": 4990, 47 | "percent_removed": 0.20000000000000018, 48 | "build_time_seconds": 0.2848535830853507, 49 | "deduplication_time_seconds": 0.22846354101784527, 50 | "time_seconds": 0.513317124103196 51 | }, 52 | { 53 | "dataset": "ag_news", 54 | "train_size": 120000, 55 | "test_size": 7600, 56 | "deduplicated_test_size": 6198, 57 | "percent_removed": 18.447368421052634, 58 | "build_time_seconds": 3.0319770000642166, 59 | "deduplication_time_seconds": 0.7034984159981832, 60 | "time_seconds": 3.7354754160623997 61 | }, 62 | { 63 | "dataset": "enron_spam", 64 | "train_size": 31716, 65 | "test_size": 2000, 66 | "deduplicated_test_size": 1060, 67 | "percent_removed": 47.0, 68 | "build_time_seconds": 1.3818323339801282, 69 | "deduplication_time_seconds": 0.553584959008731, 70 | "time_seconds": 1.9354172929888591 71 | }, 72 | { 73 | "dataset": "subj", 74 | "train_size": 8000, 75 | "test_size": 2000, 76 | "deduplicated_test_size": 1999, 77 | "percent_removed": 0.04999999999999449, 78 | "build_time_seconds": 0.5059439589967951, 79 | "deduplication_time_seconds": 0.11624520795885473, 80 | "time_seconds": 0.6221891669556499 81 | }, 82 | { 83 | "dataset": "sst5", 84 | "train_size": 8544, 85 | "test_size": 2210, 86 | "deduplicated_test_size": 2205, 87 | "percent_removed": 0.2262443438914019, 88 | "build_time_seconds": 0.4805819580797106, 89 | "deduplication_time_seconds": 0.11375170899555087, 90 | "time_seconds": 0.5943336670752615 91 | }, 92 | { 93 | "dataset": "20_newgroups", 94 | "train_size": 11314, 95 | "test_size": 7532, 96 | "deduplicated_test_size": 7098, 97 | "percent_removed": 5.762081784386619, 98 | "build_time_seconds": 0.610724583035335, 99 | "deduplication_time_seconds": 1.6346445409581065, 100 | "time_seconds": 2.2453691239934415 101 | }, 102 | { 103 | "dataset": "hatespeech_offensive", 104 | "train_size": 22783, 105 | "test_size": 2000, 106 | "deduplicated_test_size": 1925, 107 | "percent_removed": 3.749999999999998, 108 | "build_time_seconds": 0.6471997499465942, 109 | "deduplication_time_seconds": 0.12372829194646329, 110 | "time_seconds": 0.7709280418930575 111 | }, 112 | { 113 | "dataset": "ade", 114 | "train_size": 17637, 115 | "test_size": 5879, 116 | "deduplicated_test_size": 4952, 117 | "percent_removed": 15.76798775301922, 118 | "build_time_seconds": 0.5221591669833288, 119 | "deduplication_time_seconds": 0.28758599993307143, 120 | "time_seconds": 0.8097451669164002 121 | }, 122 | { 123 | "dataset": "imdb", 124 | "train_size": 25000, 125 | "test_size": 25000, 126 | "deduplicated_test_size": 24795, 127 | "percent_removed": 0.8199999999999985, 128 | "build_time_seconds": 1.460668999934569, 129 | "deduplication_time_seconds": 1.3489695829339325, 130 | "time_seconds": 2.8096385828685015 131 | }, 132 | { 133 | "dataset": "massive_scenario", 134 | "train_size": 11514, 135 | "test_size": 2974, 136 | "deduplicated_test_size": 2190, 137 | "percent_removed": 26.36180228648285, 138 | "build_time_seconds": 0.35503324994351715, 139 | "deduplication_time_seconds": 0.10878237499855459, 140 | "time_seconds": 0.46381562494207174 141 | }, 142 | { 143 | "dataset": "student", 144 | "train_size": 117519, 145 | "test_size": 5000, 146 | "deduplicated_test_size": 2393, 147 | "percent_removed": 52.14, 148 | "build_time_seconds": 2.9044899590080604, 149 | "deduplication_time_seconds": 
0.8721794589655474, 150 | "time_seconds": 3.776669417973608 151 | }, 152 | { 153 | "dataset": "squad_v2", 154 | "train_size": 130319, 155 | "test_size": 11873, 156 | "deduplicated_test_size": 11863, 157 | "percent_removed": 0.08422471153036737, 158 | "build_time_seconds": 6.078755749971606, 159 | "deduplication_time_seconds": 1.0497459589969367, 160 | "time_seconds": 7.1285017089685425 161 | }, 162 | { 163 | "dataset": "wikitext", 164 | "train_size": 1801350, 165 | "test_size": 4358, 166 | "deduplicated_test_size": 2139, 167 | "percent_removed": 50.91785222579165, 168 | "build_time_seconds": 39.38258587510791, 169 | "deduplication_time_seconds": 0.9325925830053166, 170 | "time_seconds": 40.31517845811322 171 | } 172 | ] 173 | -------------------------------------------------------------------------------- /benchmarks/run_benchmarks.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from time import perf_counter 4 | 5 | from datasets import load_dataset 6 | from model2vec import StaticModel 7 | 8 | from benchmarks.data import DATASET_DICT 9 | from semhash import SemHash 10 | 11 | # Set up logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def main() -> None: # noqa: C901 16 | """Run the benchmarks.""" 17 | # Prepare lists to hold benchmark results 18 | train_dedup_results = [] 19 | train_test_dedup_results = [] 20 | # Load the model and initialize SemHash 21 | 22 | model = StaticModel.from_pretrained("minishlab/potion-base-8m") 23 | 24 | for dataset_name, record in DATASET_DICT.items(): 25 | logger.info(f"Loading dataset: {dataset_name} from {record.name}") 26 | 27 | # Load train and test splits 28 | if record.sub_directory: 29 | train_ds = load_dataset(record.name, record.sub_directory, split=record.split_one) 30 | test_ds = load_dataset(record.name, record.sub_directory, split=record.split_two) 31 | else: 32 | train_ds = load_dataset(record.name, split=record.split_one) 33 | test_ds = load_dataset(record.name, split=record.split_two) 34 | 35 | # If the dataset has columns, use them 36 | if record.columns: 37 | columns = record.columns 38 | train_records = [dict(row) for row in train_ds] 39 | test_records = [dict(row) for row in test_ds] 40 | # Else, use the text_name 41 | else: 42 | train_records = train_ds[record.text_name] 43 | test_records = test_ds[record.text_name] 44 | columns = None 45 | 46 | # Build the SemHash instance 47 | build_start = perf_counter() 48 | semhash = SemHash.from_records(model=model, use_ann=True, records=train_records, columns=columns) 49 | build_end = perf_counter() 50 | build_time = build_end - build_start 51 | # Time how long it takes to deduplicate the train set 52 | train_only_start = perf_counter() 53 | deduplicated_train = semhash.self_deduplicate() 54 | train_only_end = perf_counter() 55 | 56 | train_only_dedup_time = train_only_end - train_only_start 57 | original_train_size = len(train_records) 58 | dedup_train_size = len(deduplicated_train.selected) 59 | 60 | percent_removed_train = deduplicated_train.duplicate_ratio * 100 61 | train_dedup_results.append( 62 | { 63 | "dataset": dataset_name, 64 | "original_train_size": original_train_size, 65 | "deduplicated_train_size": dedup_train_size, 66 | "percent_removed": percent_removed_train, 67 | "build_time_seconds": build_time, 68 | "deduplication_time_seconds": train_only_dedup_time, 69 | "time_seconds": train_only_dedup_time + build_time, 70 | } 71 | ) 72 | 73 | logger.info( 74 | f"[TRAIN DEDUPLICATION] Dataset: 
{dataset_name}\n" 75 | f" - Original Train Size: {original_train_size}\n" 76 | f" - Deduplicated Train Size: {dedup_train_size}\n" 77 | f" - % Removed: {percent_removed_train:.2f}\n" 78 | f" - Deduplication Time (seconds): {train_only_dedup_time:.2f}\n" 79 | f" - Build Time (seconds): {build_time:.2f}\n" 80 | f" - Total Time (seconds): {train_only_dedup_time + build_time:.2f}\n" 81 | ) 82 | 83 | # Time how long it takes to deduplicate the test set 84 | train_test_start = perf_counter() 85 | deduplicated_test = semhash.deduplicate( 86 | records=test_records, 87 | ) 88 | train_test_end = perf_counter() 89 | train_test_dedup_time = train_test_end - train_test_start 90 | original_test_size = len(test_records) 91 | deduped_test_size = len(deduplicated_test.selected) 92 | percent_removed_test = deduplicated_test.duplicate_ratio * 100 93 | 94 | train_test_dedup_results.append( 95 | { 96 | "dataset": dataset_name, 97 | "train_size": original_train_size, 98 | "test_size": original_test_size, 99 | "deduplicated_test_size": deduped_test_size, 100 | "percent_removed": percent_removed_test, 101 | "build_time_seconds": build_time, 102 | "deduplication_time_seconds": train_test_dedup_time, 103 | "time_seconds": train_test_dedup_time + build_time, 104 | } 105 | ) 106 | 107 | logger.info( 108 | f"[TRAIN/TEST DEDUPLICATION] Dataset: {dataset_name}\n" 109 | f" - Train Size: {original_train_size}\n" 110 | f" - Test Size: {original_test_size}\n" 111 | f" - Deduplicated Test Size: {deduped_test_size}\n" 112 | f" - % Removed: {percent_removed_test:.2f}\n" 113 | f" - Deduplication Time (seconds): {train_test_dedup_time:.2f}\n" 114 | f" - Build Time (seconds): {build_time:.2f}\n" 115 | f" - Total Time (seconds): {train_test_dedup_time + build_time:.2f}\n" 116 | ) 117 | 118 | # Write the results to JSON files 119 | with open("benchmarks/results/train_benchmark_results.json", "w", encoding="utf-8") as f: 120 | json.dump(train_dedup_results, f, ensure_ascii=False, indent=2) 121 | 122 | with open("benchmarks/results/train_test_benchmark_results.json", "w", encoding="utf-8") as f: 123 | json.dump(train_test_dedup_results, f, ensure_ascii=False, indent=2) 124 | 125 | # Print the train table 126 | print("### Train Deduplication Benchmark\n") # noqa T201 127 | print( # noqa T201 128 | f"| {'Dataset':<20} | {'Original Train Size':>20} | {'Deduplicated Train Size':>24} | {'% Removed':>10} | {'Deduplication Time (s)':>24} |" 129 | ) # noqa T201 130 | print("|" + "-" * 22 + "|" + "-" * 22 + "|" + "-" * 26 + "|" + "-" * 12 + "|" + "-" * 26 + "|") # noqa T201 131 | for r in train_dedup_results: 132 | print( # noqa T201 133 | f"| {r['dataset']:<20} " 134 | f"| {r['original_train_size']:>20} " 135 | f"| {r['deduplicated_train_size']:>24} " 136 | f"| {r['percent_removed']:>10.2f} " 137 | f"| {r['time_seconds']:>24.2f} |" 138 | ) 139 | 140 | print("\n") # noqa T201 141 | 142 | # Print the train/test table 143 | print("### Train/Test Deduplication Benchmark\n") # noqa T201 144 | print( # noqa T201 145 | f"| {'Dataset':<20} | {'Train Size':>12} | {'Test Size':>12} | {'Deduplicated Test Size':>24} | {'% Removed':>10} | {'Deduplication Time (s)':>24} |" 146 | ) # noqa T201 147 | print("|" + "-" * 22 + "|" + "-" * 14 + "|" + "-" * 14 + "|" + "-" * 26 + "|" + "-" * 12 + "|" + "-" * 26 + "|") # noqa T201 148 | for r in train_test_dedup_results: 149 | print( # noqa T201 150 | f"| {r['dataset']:<20} " 151 | f"| {r['train_size']:>12} " 152 | f"| {r['test_size']:>12} " 153 | f"| {r['deduplicated_test_size']:>24} " 154 | f"| 
{r['percent_removed']:>10.2f} " 155 | f"| {r['time_seconds']:>24.2f} |" 156 | ) 157 | 158 | 159 | if __name__ == "__main__": 160 | logging.basicConfig(level=logging.INFO) 161 | main() 162 | -------------------------------------------------------------------------------- /tests/test_datamodels.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import semhash 4 | import semhash.version 5 | from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates 6 | 7 | 8 | def test_deduplication_scoring() -> None: 9 | """Test the deduplication scoring.""" 10 | d = DeduplicationResult( 11 | ["a", "b", "c"], 12 | [DuplicateRecord("a", False, [("b", 0.9)]), DuplicateRecord("b", False, [("c", 0.8)])], 13 | 0.8, 14 | ) 15 | assert d.duplicate_ratio == 0.4 16 | 17 | 18 | def test_deduplication_scoring_exact() -> None: 19 | """Test the deduplication scoring.""" 20 | d = DeduplicationResult( 21 | ["a", "b", "c"], 22 | [DuplicateRecord("a", True, [("b", 0.9)]), DuplicateRecord("b", False, [("c", 0.8)])], 23 | 0.8, 24 | ) 25 | assert d.exact_duplicate_ratio == 0.2 26 | 27 | 28 | def test_deduplication_scoring_exact_empty() -> None: 29 | """Test the deduplication scoring.""" 30 | d = DeduplicationResult([], [], 0.8, columns=["text"]) 31 | assert d.exact_duplicate_ratio == 0.0 32 | 33 | 34 | def test_deduplication_scoring_empty() -> None: 35 | """Test the deduplication scoring.""" 36 | d = DeduplicationResult([], [], 0.8, columns=["text"]) 37 | assert d.duplicate_ratio == 0.0 38 | 39 | 40 | def test_rethreshold() -> None: 41 | """Test rethresholding the duplicates.""" 42 | d = DuplicateRecord("a", False, [("b", 0.9), ("c", 0.8)]) 43 | d._rethreshold(0.85) 44 | assert d.duplicates == [("b", 0.9)] 45 | 46 | 47 | def test_rethreshold_empty() -> None: 48 | """Test rethresholding the duplicates.""" 49 | d = DuplicateRecord("a", False, []) 50 | d._rethreshold(0.85) 51 | assert d.duplicates == [] 52 | 53 | 54 | def test_get_least_similar_from_duplicates() -> None: 55 | """Test getting the least similar duplicates.""" 56 | d = DeduplicationResult( 57 | ["a", "b", "c"], 58 | [DuplicateRecord("a", False, [("b", 0.9), ("c", 0.7)]), DuplicateRecord("b", False, [("c", 0.8)])], 59 | 0.8, 60 | ) 61 | result = d.get_least_similar_from_duplicates(1) 62 | assert result == [("a", "c", 0.7)] 63 | 64 | 65 | def test_get_least_similar_from_duplicates_empty() -> None: 66 | """Test getting the least similar duplicates.""" 67 | d = DeduplicationResult([], [], 0.8, columns=["text"]) 68 | assert d.get_least_similar_from_duplicates(1) == [] 69 | 70 | 71 | def test_rethreshold_deduplication_result() -> None: 72 | """Test rethresholding the duplicates.""" 73 | d = DeduplicationResult( 74 | ["a", "b", "c"], 75 | [ 76 | DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]), 77 | DuplicateRecord("e", False, [("z", 0.8)]), 78 | ], 79 | 0.8, 80 | ) 81 | d.rethreshold(0.85) 82 | assert d.filtered == [DuplicateRecord("d", False, [("x", 0.9)])] 83 | assert d.selected == ["a", "b", "c", "e"] 84 | 85 | 86 | def test_rethreshold_exception() -> None: 87 | """Test rethresholding throws an exception.""" 88 | d = DeduplicationResult( 89 | ["a", "b", "c"], 90 | [ 91 | DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]), 92 | DuplicateRecord("e", False, [("z", 0.8)]), 93 | ], 94 | 0.7, 95 | ) 96 | with pytest.raises(ValueError): 97 | d.rethreshold(0.6) 98 | 99 | 100 | def test_deprecation_deduplicated_duplicates() -> None: 101 | """Test deprecation warnings for 
deduplicated and duplicates fields.""" 102 | if semhash.version.__version__ < "0.4.0": 103 | with pytest.warns(DeprecationWarning): 104 | d = DeduplicationResult( 105 | deduplicated=["a", "b", "c"], 106 | duplicates=[ 107 | DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]), 108 | DuplicateRecord("e", False, [("z", 0.8)]), 109 | ], 110 | threshold=0.8, 111 | ) 112 | else: 113 | raise ValueError("deprecate `deduplicated` and `duplicates` fields in `DeduplicationResult`") 114 | assert d.selected == ["a", "b", "c"] 115 | assert d.filtered == [ 116 | DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]), 117 | DuplicateRecord("e", False, [("z", 0.8)]), 118 | ] 119 | 120 | 121 | def test_selected_with_duplicates_strings() -> None: 122 | """Test selected_with_duplicates for strings.""" 123 | d = DeduplicationResult( 124 | selected=["original"], 125 | filtered=[ 126 | DuplicateRecord("duplicate_1", False, [("original", 0.9)]), 127 | DuplicateRecord("duplicate_2", False, [("original", 0.8)]), 128 | ], 129 | threshold=0.8, 130 | ) 131 | 132 | expected = [ 133 | SelectedWithDuplicates( 134 | record="original", 135 | duplicates=[("duplicate_1", 0.9), ("duplicate_2", 0.8)], 136 | ) 137 | ] 138 | assert d.selected_with_duplicates == expected 139 | 140 | 141 | def test_selected_with_duplicates_dicts() -> None: 142 | """Test selected_with_duplicates for dicts.""" 143 | selected = {"id": 0, "text": "hello"} 144 | d = DeduplicationResult( 145 | selected=[selected], 146 | filtered=[ 147 | DuplicateRecord({"id": 1, "text": "hello"}, True, [(selected, 1.0)]), 148 | DuplicateRecord({"id": 2, "text": "helllo"}, False, [(selected, 0.1)]), 149 | ], 150 | threshold=0.8, 151 | columns=["text"], 152 | ) 153 | 154 | items = d.selected_with_duplicates 155 | assert len(items) == 1 156 | kept = items[0].record 157 | dups = items[0].duplicates 158 | assert kept == selected 159 | assert {r["id"] for r, _ in dups} == {1, 2} 160 | 161 | 162 | def test_selected_with_duplicates_multi_column() -> None: 163 | """Test selected_with_duplicates for multi-columns.""" 164 | selected = {"text": "hello", "text2": "world"} 165 | d = DeduplicationResult( 166 | selected=[selected], 167 | filtered=[ 168 | DuplicateRecord({"text": "hello", "text2": "world"}, True, [(selected, 1.0)]), 169 | DuplicateRecord({"text": "helllo", "text2": "world"}, False, [(selected, 0.1)]), 170 | ], 171 | threshold=0.8, 172 | columns=["text", "text2"], 173 | ) 174 | 175 | items = d.selected_with_duplicates 176 | assert len(items) == 1 177 | kept = items[0].record 178 | assert kept == selected 179 | 180 | 181 | def test_selected_with_duplicates_unhashable_values() -> None: 182 | """Test selected_with_duplicates with unhashable values in records.""" 183 | selected = {"text": "hello", "a": [1, 2, 3]} # list -> unhashable value 184 | filtered = {"text": "hello", "a": [1, 2, 3], "flag": True} 185 | 186 | d = DeduplicationResult( 187 | selected=[selected], 188 | filtered=[DuplicateRecord(filtered, exact=False, duplicates=[(selected, 1.0)])], 189 | threshold=0.8, 190 | columns=["text"], 191 | ) 192 | 193 | items = d.selected_with_duplicates 194 | assert items == [SelectedWithDuplicates(record=selected, duplicates=[(filtered, 1.0)])] 195 | 196 | 197 | def test_selected_with_duplicates_removes_internal_duplicates() -> None: 198 | """Test that selected_with_duplicates removes internal duplicates that have the same hash.""" 199 | selected = {"id": 0, "text": "hello"} 200 | filtered = {"id": 1, "text": "hello"} 201 | 202 | d = DeduplicationResult( 203 | 
selected=[selected], 204 | filtered=[ 205 | DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.95)]), 206 | DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.90)]), 207 | ], 208 | threshold=0.8, 209 | columns=["text"], 210 | ) 211 | 212 | items = d.selected_with_duplicates 213 | assert len(items) == 1 214 | 215 | selected_record = items[0].record 216 | duplicate_list = items[0].duplicates 217 | # Should keep the kept record unchanged 218 | assert selected_record == selected 219 | # The duplicate row must appear only once 220 | assert len(duplicate_list) == 1 221 | assert duplicate_list[0][0] == filtered 222 | -------------------------------------------------------------------------------- /semhash/datamodels.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import warnings 5 | from collections import defaultdict 6 | from dataclasses import dataclass, field 7 | from typing import Any, Generic, Hashable, Sequence, TypeVar 8 | 9 | from frozendict import frozendict 10 | from typing_extensions import TypeAlias 11 | 12 | from semhash.utils import to_frozendict 13 | 14 | Record = TypeVar("Record", str, dict[str, Any]) 15 | DuplicateList: TypeAlias = list[tuple[Record, float]] 16 | 17 | 18 | @dataclass 19 | class DuplicateRecord(Generic[Record]): 20 | """ 21 | A single record with its duplicates. 22 | 23 | Attributes 24 | ---------- 25 | record: The original record being deduplicated. 26 | exact: Whether the record was identified as an exact match. 27 | duplicates: List of tuples consisting of duplicate records and their associated scores. 28 | 29 | """ 30 | 31 | record: Record 32 | exact: bool 33 | duplicates: DuplicateList = field(default_factory=list) 34 | 35 | def _rethreshold(self, threshold: float) -> None: 36 | """Rethreshold the duplicates.""" 37 | self.duplicates = [(d, score) for d, score in self.duplicates if score >= threshold] 38 | 39 | 40 | @dataclass 41 | class SelectedWithDuplicates(Generic[Record]): 42 | """ 43 | A record that has been selected along with its duplicates. 44 | 45 | Attributes 46 | ---------- 47 | record: The original record being selected. 48 | duplicates: List of tuples consisting of duplicate records and their associated scores. 49 | 50 | """ 51 | 52 | record: Record 53 | duplicates: DuplicateList = field(default_factory=list) 54 | 55 | 56 | @dataclass 57 | class DeduplicationResult(Generic[Record]): 58 | """ 59 | Deduplication result. 60 | 61 | Attributes 62 | ---------- 63 | selected: List of deduplicated records after removing duplicates. 64 | filtered: List of DuplicateRecord objects containing details about duplicates of an original record. 65 | threshold: The similarity threshold used for deduplication. 66 | columns: Columns used for deduplication. 67 | deduplicated: Deprecated, use selected instead. 68 | duplicates: Deprecated, use filtered instead. 
69 | 70 | """ 71 | 72 | selected: list[Record] = field(default_factory=list) 73 | filtered: list[DuplicateRecord] = field(default_factory=list) 74 | threshold: float = field(default=0.9) 75 | columns: Sequence[str] | None = field(default=None) 76 | deduplicated: list[Record] = field(default_factory=list) # Deprecated 77 | duplicates: list[DuplicateRecord] = field(default_factory=list) # Deprecated 78 | 79 | def __post_init__(self) -> None: 80 | """Initialize deprecated fields and warn about deprecation.""" 81 | if self.deduplicated or self.duplicates: 82 | warnings.warn( 83 | "'deduplicated' and 'duplicates' fields are deprecated and will be removed in a future release. Use 'selected' and 'filtered' instead.", 84 | DeprecationWarning, 85 | stacklevel=2, 86 | ) 87 | 88 | if not self.selected and self.deduplicated: 89 | self.selected = self.deduplicated 90 | if not self.filtered and self.duplicates: 91 | self.filtered = self.duplicates 92 | if not self.deduplicated: 93 | self.deduplicated = self.selected 94 | if not self.duplicates: 95 | self.duplicates = self.filtered 96 | 97 | @property 98 | def duplicate_ratio(self) -> float: 99 | """Return the percentage of records dropped.""" 100 | if denom := len(self.selected) + len(self.filtered): 101 | return 1.0 - len(self.selected) / denom 102 | return 0.0 103 | 104 | @property 105 | def exact_duplicate_ratio(self) -> float: 106 | """Return the percentage of records dropped due to an exact match.""" 107 | if denom := len(self.selected) + len(self.filtered): 108 | return len([dup for dup in self.filtered if dup.exact]) / denom 109 | return 0.0 110 | 111 | def get_least_similar_from_duplicates(self, n: int = 1) -> list[tuple[Record, Record, float]]: 112 | """ 113 | Return the N least similar duplicate pairs. 114 | 115 | :param n: The number of least similar pairs to return. 116 | :return: A list of tuples consisting of (original_record, duplicate_record, score). 117 | """ 118 | all_pairs = [(dup.record, d, score) for dup in self.filtered for d, score in dup.duplicates] 119 | sorted_pairs = sorted(all_pairs, key=lambda x: x[2]) # Sort by score 120 | return sorted_pairs[:n] 121 | 122 | def rethreshold(self, threshold: float) -> None: 123 | """Rethreshold the duplicates.""" 124 | if self.threshold > threshold: 125 | raise ValueError("Threshold is smaller than the given value.") 126 | for dup in self.filtered: 127 | dup._rethreshold(threshold) 128 | if not dup.duplicates: 129 | self.filtered.remove(dup) 130 | self.selected.append(dup.record) 131 | self.threshold = threshold 132 | 133 | @property 134 | def selected_with_duplicates(self) -> list[SelectedWithDuplicates[Record]]: 135 | """ 136 | For every kept record, return the duplicates that were removed along with their similarity scores. 137 | 138 | :return: A list of tuples where each tuple contains a kept record 139 | and a list of its duplicates with their similarity scores. 
140 | """ 141 | 142 | def _to_hashable(record: Record) -> frozendict[str, str] | str: 143 | """Convert a record to a hashable representation.""" 144 | if isinstance(record, dict) and self.columns is not None: 145 | # Convert dict to frozendict for immutability and hashability 146 | return to_frozendict(record, set(self.columns)) 147 | return str(record) 148 | 149 | # Build a mapping from original-record to [(duplicate, score), …] 150 | buckets: defaultdict[Hashable, DuplicateList] = defaultdict(list) 151 | for duplicate_record in self.filtered: 152 | for original_record, score in duplicate_record.duplicates: 153 | buckets[_to_hashable(original_record)].append((duplicate_record.record, float(score))) 154 | 155 | result: list[SelectedWithDuplicates[Record]] = [] 156 | for selected in self.selected: 157 | # Get the list of duplicates for the selected record 158 | raw_list = buckets.get(_to_hashable(selected), []) 159 | # Ensure we don't have duplicates in the list 160 | # Use full-record canonical JSON for dicts so that unhashable values are handled correctly 161 | deduped = { 162 | ( 163 | json.dumps(rec, sort_keys=True, separators=(",", ":"), ensure_ascii=False) 164 | if isinstance(rec, dict) 165 | else rec 166 | ): (rec, score) 167 | for rec, score in raw_list 168 | } 169 | result.append(SelectedWithDuplicates(record=selected, duplicates=list(deduped.values()))) 170 | 171 | return result 172 | 173 | 174 | @dataclass 175 | class FilterResult(Generic[Record]): 176 | """ 177 | Result of filtering operations. 178 | 179 | Attributes 180 | ---------- 181 | selected: List of records that passed the filter criteria. 182 | filtered: List of records that were filtered out. 183 | scores_selected: List of scores for the selected records. 184 | scores_filtered: List of scores for the filtered records. 
185 | 186 | """ 187 | 188 | selected: list[Record] 189 | filtered: list[Record] 190 | scores_selected: list[float] = field(default_factory=list) 191 | scores_filtered: list[float] = field(default_factory=list) 192 | 193 | @property 194 | def filter_ratio(self) -> float: 195 | """Return the percentage of records filtered out.""" 196 | if denom := len(self.selected) + len(self.filtered): 197 | return len(self.filtered) / denom 198 | return 0.0 199 | 200 | @property 201 | def selected_ratio(self) -> float: 202 | """Return the percentage of records selected.""" 203 | return 1 - self.filter_ratio 204 | -------------------------------------------------------------------------------- /tests/test_semhash.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from semhash import SemHash 5 | from semhash.datamodels import FilterResult 6 | from semhash.utils import Encoder 7 | 8 | 9 | def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None: 10 | """Test single dataset deduplication.""" 11 | # No duplicates 12 | texts = [ 13 | "It's dangerous to go alone!", 14 | "The master sword can seal the darkness.", 15 | "Ganondorf has invaded Hyrule!", 16 | ] 17 | semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model) 18 | deduplicated_texts = semhash.self_deduplicate().selected 19 | 20 | assert deduplicated_texts == texts 21 | 22 | # With duplicates 23 | texts = [ 24 | "It's dangerous to go alone!", 25 | "It's dangerous to go alone!", # Exact duplicate 26 | "It's not safe to go alone!", # Semantically similar 27 | ] 28 | semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model) 29 | deduplicated_texts = semhash.self_deduplicate(0.7).selected 30 | assert deduplicated_texts == ["It's dangerous to go alone!"] 31 | 32 | 33 | def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None: 34 | """Test deduplication across two datasets.""" 35 | # No duplicates 36 | texts1 = [ 37 | "It's dangerous to go alone!", 38 | "It's a secret to everybody.", 39 | "Ganondorf has invaded Hyrule!", 40 | ] 41 | texts2 = [ 42 | "Link is the hero of time.", 43 | "Zelda is the princess of Hyrule.", 44 | "Ganon is the king of thieves.", 45 | ] 46 | semhash = SemHash.from_records(texts1, columns=None, use_ann=use_ann, model=model) 47 | deduplicated_texts = semhash.deduplicate(texts2).selected 48 | assert deduplicated_texts == texts2 49 | 50 | # With duplicates 51 | texts2 = [ 52 | "It's dangerous to go alone!", # Exact duplicate 53 | "It's risky to go alone!", # Semantically similar 54 | "Ganondorf has attacked Hyrule!", # Semantically similar 55 | ] 56 | deduplicated_texts = semhash.deduplicate(texts2, threshold=0.7).selected 57 | assert deduplicated_texts == [] 58 | 59 | 60 | def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None: 61 | """Test single dataset deduplication with multi-column records.""" 62 | records = [ 63 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 64 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, # Exact duplicate 65 | { 66 | "question": "Who is the protagonist?", 67 | "context": "In this story, Link is the hero", 68 | "answer": "Link", 69 | }, # Semantically similar 70 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 71 | ] 72 | semhash = SemHash.from_records( 73 | records, 74 | 
columns=["question", "context", "answer"], 75 | use_ann=use_ann, 76 | model=model, 77 | ) 78 | deduplicated = semhash.self_deduplicate(threshold=0.7) 79 | 80 | assert deduplicated.selected == [ 81 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 82 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 83 | ] 84 | 85 | 86 | def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None: 87 | """Test multi dataset deduplication with multi-column records.""" 88 | train_records = [ 89 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 90 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 91 | ] 92 | test_records = [ 93 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, # Exact duplicate 94 | { 95 | "question": "Who is the princess?", 96 | "context": "Zelda is the princess", 97 | "answer": "Zelda", 98 | }, # Semantically similar 99 | {"question": "What is the villain's name?", "context": "The villain is Ganon", "answer": "Ganon"}, 100 | ] 101 | semhash = SemHash.from_records( 102 | train_records, 103 | columns=["question", "context", "answer"], 104 | use_ann=use_ann, 105 | model=model, 106 | ) 107 | deduplicated = semhash.deduplicate(test_records).selected 108 | assert deduplicated == [ 109 | {"question": "What is the villain's name?", "context": "The villain is Ganon", "answer": "Ganon"} 110 | ] 111 | 112 | 113 | def test_from_records_without_columns(use_ann: bool, model: Encoder) -> None: 114 | """Test fitting without specifying columns.""" 115 | records = [ 116 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 117 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 118 | ] 119 | with pytest.raises(ValueError): 120 | SemHash.from_records(records, columns=None, use_ann=use_ann, model=model) 121 | 122 | 123 | def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -> None: 124 | """Test deduplicating with only exact duplicates.""" 125 | texts1 = [ 126 | "It's dangerous to go alone!", 127 | "It's dangerous to go alone!", 128 | "It's dangerous to go alone!", 129 | ] 130 | texts2 = [ 131 | "It's dangerous to go alone!", 132 | "It's dangerous to go alone!", 133 | "It's dangerous to go alone!", 134 | ] 135 | semhash = SemHash.from_records(texts1, use_ann=use_ann, model=model) 136 | deduplicated = semhash.self_deduplicate() 137 | assert deduplicated.selected == ["It's dangerous to go alone!"] 138 | 139 | deduplicated = semhash.deduplicate(texts2) 140 | assert deduplicated.selected == [] 141 | 142 | 143 | def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: list[str]) -> None: 144 | """Test the self_find_representative method.""" 145 | semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) 146 | result = semhash.self_find_representative( 147 | candidate_limit=5, 148 | selection_size=3, 149 | lambda_param=0.5, 150 | ) 151 | assert len(result.selected) == 3, "Expected 3 representatives" 152 | selected = {r["text"] for r in result.selected} 153 | assert selected == { 154 | "blueberry", 155 | "pineapple", 156 | "grape", 157 | }, "Expected representatives to be blueberry, pineapple, and grape" 158 | 159 | 160 | def test_find_representative(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: 
list[str]) -> None: 161 | """Test the find_representative method.""" 162 | semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) 163 | result = semhash.find_representative(records=test_texts, candidate_limit=5, selection_size=3, lambda_param=0.5) 164 | assert len(result.selected) == 3, "Expected 3 representatives" 165 | selected = {r["text"] for r in result.selected} 166 | assert selected == {"grapefruit", "banana", "apple"}, "Expected representatives to be grapefruit, banana, and apple" 167 | 168 | 169 | def test_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: 170 | """Test the filter_outliers method.""" 171 | semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) 172 | result = semhash.filter_outliers(records=test_texts, outlier_percentage=0.2) 173 | assert len(result.filtered) == 2, "Expected 2 outliers" 174 | assert len(result.selected) == len(test_texts) - 2 175 | filtered = {r["text"] for r in result.filtered} 176 | assert filtered == {"motorcycle", "plane"}, "Expected outliers to be motorcycle and plane" 177 | 178 | 179 | def test_self_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str]) -> None: 180 | """Test the self_filter_outliers method.""" 181 | semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) 182 | result = semhash.self_filter_outliers(outlier_percentage=0.1) 183 | assert len(result.filtered) == 2, "Expected 2 outliers" 184 | assert len(result.selected) == len(train_texts) - 2 185 | filtered = {r["text"] for r in result.filtered} 186 | assert filtered == {"car", "bicycle"}, "Expected outliers to be car and bicycle" 187 | 188 | 189 | def test__mmr(monkeypatch: pytest.MonkeyPatch) -> None: 190 | """Test the _mmr method.""" 191 | # Create a dummy SemHash instance 192 | semhash = SemHash(index=None, model=None, columns=["text"], was_string=True) # type: ignore 193 | # Prepare a fake ranking with three records 194 | records = ["a", "b", "c"] 195 | scores = [3.0, 2.0, 1.0] 196 | ranking = FilterResult(selected=records, filtered=[], scores_selected=scores, scores_filtered=[]) 197 | # Create dummy embeddings for the records 198 | embeddings = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]) 199 | # Monkeypatch featurize to return the dummy embeddings 200 | monkeypatch.setattr(semhash, "_featurize", lambda records, columns, model: embeddings) 201 | 202 | # Test lambda=1.0: pure relevance, should pick top 2 by score 203 | result_rel = semhash._mmr(ranking, candidate_limit=3, selection_size=2, lambda_param=1.0) 204 | assert result_rel.selected == ["a", "b"] 205 | 206 | # Test lambda=0.0: pure diversity, should first pick 'a', then pick most dissimilar: 'c' 207 | result_div = semhash._mmr(ranking, candidate_limit=3, selection_size=2, lambda_param=0.0) 208 | assert result_div.selected == ["a", "c"] 209 | 210 | 211 | def test_mmr_invalid_lambda_raises() -> None: 212 | """Test that invalid lambda values raise ValueError.""" 213 | semhash = SemHash(index=None, model=None, columns=["text"], was_string=True) # type: ignore 214 | dummy = FilterResult(selected=["x"], filtered=[], scores_selected=[0.5], scores_filtered=[]) 215 | with pytest.raises(ValueError): 216 | semhash._mmr(dummy, candidate_limit=1, selection_size=1, lambda_param=-0.1) 217 | with pytest.raises(ValueError): 218 | semhash._mmr(dummy, candidate_limit=1, selection_size=1, lambda_param=1.1) 219 | 
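
The tests above double as a compact reference for the public API: SemHash.from_records builds an index over string or multi-column dict records, self_deduplicate/deduplicate return a DeduplicationResult (whose selected, filtered, duplicate_ratio, selected_with_duplicates and rethreshold members are defined in semhash/datamodels.py), and the outlier/representative helpers return a FilterResult. The following is a minimal illustrative sketch, not part of the repository: the "minishlab/potion-base-8M" checkpoint name is an assumption (any Encoder-compatible model, such as the small static model under tests/data/test_model, should work), and keyword arguments are used exactly as they appear in the tests above.

    from model2vec import StaticModel

    from semhash import SemHash

    # Assumption: this checkpoint name is illustrative and not pinned by this
    # repository snapshot; any Encoder-compatible model works.
    model = StaticModel.from_pretrained("minishlab/potion-base-8M")

    texts = [
        "It's dangerous to go alone!",
        "It's dangerous to go alone!",   # exact duplicate
        "It's not safe to go alone!",    # near duplicate
        "Ganondorf has invaded Hyrule!",
    ]

    # Build an index over the records; use_ann toggles approximate search,
    # mirroring how the tests parametrize it.
    semhash = SemHash.from_records(records=texts, use_ann=False, model=model)

    # Deduplicate the indexed records against themselves.
    result = semhash.self_deduplicate(threshold=0.7)
    print(result.selected)               # records kept
    print(result.duplicate_ratio)        # fraction of records dropped
    print(result.exact_duplicate_ratio)  # fraction dropped as exact matches
    for item in result.selected_with_duplicates:
        print(item.record, item.duplicates)

    # Tighten the threshold after the fact without re-embedding;
    # only allowed when the new threshold is not lower than the old one.
    result.rethreshold(0.85)

    # Outlier filtering returns a FilterResult with selected/filtered splits.
    outliers = semhash.self_filter_outliers(outlier_percentage=0.2)
    print(outliers.filter_ratio)

Multi-column dict records follow the same pattern by passing columns=["question", "context", "answer"] (or similar) to from_records, as the multicolumn tests above demonstrate.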
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
