├── benchmarks ├── __init__.py ├── data.py ├── results │ ├── train_benchmark_results.json │ └── train_test_benchmark_results.json └── run_benchmarks.py ├── semhash ├── __init__.py ├── version.py ├── records.py ├── utils.py ├── index.py ├── datamodels.py └── semhash.py ├── assets └── images │ ├── semhash_logo.png │ └── semhash_logo_v2.png ├── tests ├── data │ └── test_model │ │ ├── model.safetensors │ │ ├── config.json │ │ ├── modules.json │ │ └── README.md ├── conftest.py ├── test_datamodels.py └── test_semhash.py ├── Makefile ├── CITATION.cff ├── LICENSE ├── .pre-commit-config.yaml ├── .github └── workflows │ └── ci.yaml ├── pyproject.toml ├── .gitignore └── README.md /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /semhash/__init__.py: -------------------------------------------------------------------------------- 1 | from semhash.semhash import SemHash 2 | 3 | __all__ = ["SemHash"] 4 | -------------------------------------------------------------------------------- /assets/images/semhash_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/semhash/HEAD/assets/images/semhash_logo.png -------------------------------------------------------------------------------- /semhash/version.py: -------------------------------------------------------------------------------- 1 | __version_triple__ = (0, 3, 3) 2 | __version__ = ".".join(map(str, __version_triple__)) 3 | -------------------------------------------------------------------------------- /assets/images/semhash_logo_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/semhash/HEAD/assets/images/semhash_logo_v2.png -------------------------------------------------------------------------------- 
/tests/data/test_model/model.safetensors: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinishLab/semhash/HEAD/tests/data/test_model/model.safetensors -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | clean: 2 | 3 | 4 | venv: 5 | uv venv 6 | 7 | install: venv 8 | uv sync --all-extras 9 | uv run pre-commit install 10 | 11 | install-no-pre-commit: 12 | uv pip install ".[dev]" 13 | 14 | fix: 15 | uv run pre-commit run --all-files 16 | 17 | test: 18 | uv run pytest --cov=semhash --cov-report=term-missing 19 | -------------------------------------------------------------------------------- /tests/data/test_model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "model2vec", 3 | "architectures": [ 4 | "StaticModel" 5 | ], 6 | "tokenizer_name": "baai/bge-base-en-v1.5", 7 | "apply_pca": 128, 8 | "apply_zipf": true, 9 | "hidden_dim": 128, 10 | "seq_length": 1000000, 11 | "normalize": true 12 | } 13 | -------------------------------------------------------------------------------- /tests/data/test_model/modules.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "idx": 0, 4 | "name": "0", 5 | "path": ".", 6 | "type": "sentence_transformers.models.StaticEmbedding" 7 | }, 8 | { 9 | "idx": 1, 10 | "name": "1", 11 | "path": "1_Normalize", 12 | "type": "sentence_transformers.models.Normalize" 13 | } 14 | ] 15 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use SemHash in your research, please cite it as below." 
3 | title: "SemHash: Fast Semantic Text Deduplication & Filtering" 4 | authors: 5 | - family-names: "van Dongen" 6 | given-names: "Thomas" 7 | - family-names: "Tulkens" 8 | given-names: "Stephan" 9 | doi: 10.5281/zenodo.17265942 10 | license: MIT 11 | url: "https://github.com/MinishLab/semhash" 12 | repository-code: "https://github.com/MinishLab/semhash" 13 | date-released: "2025-01-05" 14 | 15 | preferred-citation: 16 | type: software 17 | title: "SemHash: Fast Semantic Text Deduplication & Filtering" 18 | authors: 19 | - family-names: "van Dongen" 20 | given-names: "Thomas" 21 | - family-names: "Tulkens" 22 | given-names: "Stephan" 23 | year: 2025 24 | publisher: Zenodo 25 | doi: 10.5281/zenodo.17265942 26 | url: "https://github.com/MinishLab/semhash" 27 | license: MIT 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 The Minish Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-ast 8 | description: Simply check whether files parse as valid python. 9 | - id: trailing-whitespace 10 | description: Trims trailing whitespace 11 | - id: end-of-file-fixer 12 | description: Makes sure files end in a newline and only a newline. 13 | - id: check-added-large-files 14 | args: ['--maxkb=5000'] 15 | description: Prevent giant files from being committed. 16 | - id: check-case-conflict 17 | description: Check for files with names that would conflict on case-insensitive filesystems like MacOS/Windows. 
18 | - repo: https://github.com/jsh9/pydoclint 19 | rev: 0.5.3 20 | hooks: 21 | - id: pydoclint 22 | - repo: https://github.com/astral-sh/ruff-pre-commit 23 | rev: v0.4.10 24 | hooks: 25 | - id: ruff 26 | args: [ --fix ] 27 | - id: ruff-format 28 | - repo: local 29 | hooks: 30 | - id: mypy 31 | name: mypy 32 | entry: mypy 33 | language: python 34 | types: [python] 35 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from model2vec import StaticModel 3 | 4 | 5 | @pytest.fixture 6 | def model() -> StaticModel: 7 | """Load a model for testing.""" 8 | return StaticModel.from_pretrained("tests/data/test_model") 9 | 10 | 11 | @pytest.fixture(params=[True, False], ids=["use_ann=True", "use_ann=False"]) 12 | def use_ann(request: pytest.FixtureRequest) -> bool: 13 | """Whether to use approximate nearest neighbors or not.""" 14 | return request.param 15 | 16 | 17 | @pytest.fixture 18 | def train_texts() -> list[str]: 19 | """A list of train texts for testing outlier and representative filtering.""" 20 | return [ 21 | "apple", 22 | "banana", 23 | "cherry", 24 | "strawberry", 25 | "blueberry", 26 | "raspberry", 27 | "blackberry", 28 | "peach", 29 | "plum", 30 | "grape", 31 | "mango", 32 | "papaya", 33 | "pineapple", 34 | "watermelon", 35 | "orange", 36 | "lemon", 37 | "lime", 38 | "tangerine", 39 | "car", # Outlier 40 | "bicycle", # Outlier 41 | ] 42 | 43 | 44 | @pytest.fixture 45 | def test_texts() -> list[str]: 46 | """A list of test texts for testing outlier and representative filtering.""" 47 | return [ 48 | "apple", 49 | "banana", 50 | "kiwi", 51 | "fig", 52 | "apricot", 53 | "grapefruit", 54 | "pomegranate", 55 | "motorcycle", # Outlier 56 | "plane", # Outlier 57 | ] 58 | -------------------------------------------------------------------------------- /semhash/records.py: 
-------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from semhash.datamodels import DeduplicationResult, DuplicateRecord 4 | 5 | 6 | def dict_to_string(record: dict[str, str], columns: Sequence[str]) -> str: 7 | r""" 8 | Turn a record into a single string. 9 | 10 | Uses self.columns to determine the order of the text segments. 11 | Each text is cleaned by replacing '\t' with ' '. The texts are then joined by '\t'. 12 | 13 | :param record: A record to unpack. 14 | :param columns: Columns to unpack. 15 | :return: A single string representation of the record. 16 | """ 17 | return "\t".join(record.get(c, "").replace("\t", " ") for c in columns) 18 | 19 | 20 | def map_deduplication_result_to_strings(result: DeduplicationResult, columns: Sequence[str]) -> DeduplicationResult: 21 | """Convert the record and duplicates in each DuplicateRecord back to strings if self.was_string is True.""" 22 | deduplicated_str = [dict_to_string(r, columns) for r in result.selected] 23 | mapped = [] 24 | for dup_rec in result.duplicates: 25 | record_as_str = dict_to_string(dup_rec.record, columns) 26 | duplicates_as_str = [(dict_to_string(r, columns), score) for r, score in dup_rec.duplicates] 27 | mapped.append( 28 | DuplicateRecord( 29 | record=record_as_str, 30 | duplicates=duplicates_as_str, 31 | exact=dup_rec.exact, 32 | ) 33 | ) 34 | return DeduplicationResult(selected=deduplicated_str, filtered=mapped, threshold=result.threshold) 35 | 36 | 37 | def add_scores_to_records(records: list[dict[str, str]]) -> list[tuple[dict[str, str], float]]: 38 | """Add scores to records and return a DeduplicationResult.""" 39 | return [(record, 1.0) for record in records] 40 | -------------------------------------------------------------------------------- /semhash/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol, Sequence, Union 2 | 3 | import numpy as np 4 
| from frozendict import frozendict 5 | 6 | 7 | class Encoder(Protocol): 8 | """An encoder protocol for SemHash.""" 9 | 10 | def encode( 11 | self, 12 | sentences: Union[list[str], str, Sequence[str]], 13 | **kwargs: Any, 14 | ) -> np.ndarray: 15 | """ 16 | Encode a list of sentences into embeddings. 17 | 18 | :param sentences: A list of sentences to encode. 19 | :param **kwargs: Additional keyword arguments. 20 | :return: The embeddings of the sentences. 21 | """ 22 | ... # pragma: no cover 23 | 24 | 25 | def to_frozendict(record: dict[str, str], columns: set[str]) -> frozendict[str, str]: 26 | """Convert a record to a frozendict.""" 27 | return frozendict({k: record.get(k, "") for k in columns}) 28 | 29 | 30 | def compute_candidate_limit( 31 | total: int, 32 | selection_size: int, 33 | fraction: float = 0.1, 34 | min_candidates: int = 100, 35 | max_candidates: int = 1000, 36 | ) -> int: 37 | """ 38 | Compute the 'auto' candidate limit based on the total number of records. 39 | 40 | :param total: Total number of records. 41 | :param selection_size: Number of representatives to select. 42 | :param fraction: Fraction of total records to consider as candidates. 43 | :param min_candidates: Minimum number of candidates. 44 | :param max_candidates: Maximum number of candidates. 45 | :return: Computed candidate limit. 
46 | """ 47 | # 1) fraction of total 48 | limit = int(total * fraction) 49 | # 2) ensure enough to pick selection_size 50 | limit = max(limit, selection_size) 51 | # 3) enforce lower bound 52 | limit = max(limit, min_candidates) 53 | # 4) enforce upper bound (and never exceed the dataset) 54 | limit = min(limit, max_candidates, total) 55 | return limit 56 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: Run tests and upload coverage 2 | 3 | on: 4 | push 5 | 6 | jobs: 7 | test: 8 | name: Run tests with pytest 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: ["ubuntu-latest", "windows-latest"] 13 | python-version: ["3.9", "3.10", "3.11", "3.12"] 14 | exclude: 15 | - os: windows-latest 16 | python-version: "3.9" 17 | - os: windows-latest 18 | python-version: "3.11" 19 | - os: windows-latest 20 | python-version: "3.12" 21 | fail-fast: false 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | allow-prereleases: true 31 | 32 | # Step for Windows: Create and activate a virtual environment 33 | - name: Create and activate a virtual environment (Windows) 34 | if: ${{ runner.os == 'Windows' }} 35 | run: | 36 | irm https://astral.sh/uv/install.ps1 | iex 37 | $env:Path = "C:\Users\runneradmin\.local\bin;$env:Path" 38 | uv venv .venv 39 | "VIRTUAL_ENV=.venv" | Out-File -FilePath $env:GITHUB_ENV -Append 40 | "$PWD/.venv/Scripts" | Out-File -FilePath $env:GITHUB_PATH -Append 41 | 42 | # Step for Unix: Create and activate a virtual environment 43 | - name: Create and activate a virtual environment (Unix) 44 | if: ${{ runner.os != 'Windows' }} 45 | run: | 46 | curl -LsSf https://astral.sh/uv/install.sh | sh 47 | uv venv .venv 48 | echo 
"VIRTUAL_ENV=.venv" >> $GITHUB_ENV 49 | echo "$PWD/.venv/bin" >> $GITHUB_PATH 50 | 51 | # Install dependencies using uv pip 52 | - name: Install dependencies 53 | run: make install-no-pre-commit 54 | 55 | # Run tests with coverage 56 | - name: Run tests under coverage 57 | run: | 58 | coverage run -m pytest 59 | coverage report 60 | 61 | # Upload results to Codecov 62 | - name: Upload results to Codecov 63 | uses: codecov/codecov-action@v4 64 | with: 65 | token: ${{ secrets.CODECOV_TOKEN }} 66 | -------------------------------------------------------------------------------- /benchmarks/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class DatasetRecord: 6 | """Dataset record.""" 7 | 8 | name: str 9 | text_name: str | None = None 10 | label_name: str | None = None 11 | sub_directory: str = "" 12 | columns: list[str] | None = None 13 | split_one: str = "train" 14 | split_two: str = "test" 15 | 16 | 17 | DATASET_DICT: dict[str, DatasetRecord] = { 18 | "bbc": DatasetRecord(name="SetFit/bbc-news", text_name="text", label_name="label_text"), 19 | "senteval_cr": DatasetRecord(name="SetFit/SentEval-CR", text_name="text", label_name="label_text"), 20 | "tweet_sentiment_extraction": DatasetRecord( 21 | name="SetFit/tweet_sentiment_extraction", text_name="text", label_name="label_text" 22 | ), 23 | "emotion": DatasetRecord(name="SetFit/emotion", text_name="text", label_name="label_text"), 24 | "amazon_counterfactual": DatasetRecord( 25 | name="SetFit/amazon_counterfactual_en", text_name="text", label_name="label_text" 26 | ), 27 | "ag_news": DatasetRecord(name="SetFit/ag_news", text_name="text", label_name="label_text"), 28 | "enron_spam": DatasetRecord(name="SetFit/enron_spam", text_name="text", label_name="label_text"), 29 | "subj": DatasetRecord(name="SetFit/subj", text_name="text", label_name="label_text"), 30 | "sst5": DatasetRecord(name="SetFit/sst5", 
text_name="text", label_name="label_text"), 31 | "20_newgroups": DatasetRecord(name="SetFit/20_newsgroups", text_name="text", label_name="label_text"), 32 | "hatespeech_offensive": DatasetRecord(name="SetFit/hate_speech_offensive", text_name="text", label_name="label"), 33 | "ade": DatasetRecord(name="SetFit/ade_corpus_v2_classification", text_name="text", label_name="label"), 34 | "imdb": DatasetRecord(name="SetFit/imdb", text_name="text", label_name="label"), 35 | "massive_scenario": DatasetRecord( 36 | name="SetFit/amazon_massive_scenario_en-US", text_name="text", label_name="label" 37 | ), 38 | "student": DatasetRecord(name="SetFit/student-question-categories", text_name="text", label_name="label"), 39 | "squad_v2": DatasetRecord(name="squad_v2", columns=["question", "context"], split_two="validation"), 40 | "wikitext": DatasetRecord( 41 | name="Salesforce/wikitext", text_name="text", label_name="text", sub_directory="wikitext-103-raw-v1" 42 | ), 43 | } 44 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "semhash" 3 | description = "Fast Semantic Text Deduplication & Filtering" 4 | authors = [{name = "Thomas van Dongen", email = "thomas123@live.nl"}, { name = "Stéphan Tulkens", email = "stephantul@gmail.com"}] 5 | readme = { file = "README.md", content-type = "text/markdown" } 6 | dynamic = ["version"] 7 | license = { file = "LICENSE" } 8 | requires-python = ">=3.9" 9 | 10 | classifiers = [ 11 | "Development Status :: 4 - Beta", 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Science/Research", 14 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 15 | "Topic :: Software Development :: Libraries", 16 | "License :: OSI Approved :: MIT License", 17 | "Programming Language :: Python :: 3 :: Only", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python 
:: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Natural Language :: English", 23 | ] 24 | 25 | dependencies = [ 26 | "model2vec>=0.3.4", 27 | "vicinity[usearch]>=0.4.3", 28 | "frozendict", 29 | ] 30 | 31 | [build-system] 32 | requires = ["setuptools>=64", "setuptools_scm>=8"] 33 | build-backend = "setuptools.build_meta" 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | "black", 38 | "ipython", 39 | "mypy", 40 | "pre-commit", 41 | "pytest", 42 | "pytest-coverage", 43 | "ruff", 44 | ] 45 | 46 | [project.urls] 47 | "Homepage" = "https://github.com/MinishLab" 48 | "Bug Reports" = "https://github.com/MinishLab/semhash/issues" 49 | "Source" = "https://github.com/MinishLab/semhash" 50 | 51 | [tool.ruff] 52 | exclude = [".venv/"] 53 | line-length = 120 54 | target-version = "py310" 55 | 56 | [tool.ruff.lint] 57 | select = [ 58 | # Annotations: Enforce type annotations 59 | "ANN", 60 | # Complexity: Enforce a maximum cyclomatic complexity 61 | "C90", 62 | # Pydocstyle: Enforce docstrings 63 | "D", 64 | # Isort: Enforce import order 65 | "I", 66 | # Numpy: Enforce numpy style 67 | "NPY", 68 | # Print: Forbid print statements 69 | "T20", 70 | ] 71 | ignore = [ 72 | # Allow self and cls to be untyped, and allow Any type 73 | "ANN101", "ANN102", "ANN401", 74 | # Pydocstyle ignores 75 | "D100", "D101", "D104", "D203", "D212", "D401", 76 | # Allow use of f-strings in logging 77 | "G004" 78 | ] 79 | 80 | [tool.pydoclint] 81 | style = "sphinx" 82 | exclude = "test_" 83 | allow-init-docstring = true 84 | arg-type-hints-in-docstring = false 85 | check-return-types = false 86 | require-return-section-when-returning-nothing = false 87 | 88 | [tool.mypy] 89 | python_version = "3.10" 90 | warn_unused_configs = true 91 | ignore_missing_imports = true 92 | 93 | [tool.setuptools] 94 | packages = ["semhash"] 95 | license-files = [] 96 | 97 | [tool.setuptools_scm] 98 | # can be empty if no extra settings are needed, 
presence enables setuptools_scm 99 | 100 | [tool.setuptools.dynamic] 101 | version = {attr = "semhash.version.__version__"} 102 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | local/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, 
it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /semhash/index.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | from vicinity import Backend 7 | from vicinity.backends import AbstractBackend, get_backend_class 8 | from vicinity.datatypes import SingleQueryResult 9 | 10 | DocScore = tuple[dict[str, str], float] 11 | DocScores = list[DocScore] 12 | DictItem = list[dict[str, str]] 13 | 14 | 15 | class Index: 16 | def __init__(self, vectors: np.ndarray, items: list[DictItem], backend: AbstractBackend) -> None: 17 | """ 18 | An index that maps vectors to items. 19 | 20 | This index has an efficient backend for querying, but also explicitly stores the vectors in memory. 
21 | 22 | :param vectors: The vectors of the items. 23 | :param items: The items in the index. This is a list of lists. Each sublist contains one or more dictionaries 24 | that represent records. These records are exact duplicates of each other. 25 | :param backend: The backend to use for querying. 26 | """ 27 | self.items = items 28 | self.backend = backend 29 | self.vectors = vectors 30 | 31 | @classmethod 32 | def from_vectors_and_items( 33 | cls, vectors: np.ndarray, items: list[DictItem], backend_type: Backend | str, **kwargs: Any 34 | ) -> Index: 35 | """ 36 | Load the index from vectors and items. 37 | 38 | :param vectors: The vectors of the items. 39 | :param items: The items in the index. 40 | :param backend_type: The type of backend to use. 41 | :param **kwargs: Additional arguments to pass to the backend. 42 | :return: The index. 43 | """ 44 | backend_class = get_backend_class(backend_type) 45 | arguments = backend_class.argument_class(**kwargs) 46 | backend = backend_class.from_vectors(vectors, **arguments.dict()) 47 | 48 | return cls(vectors, items, backend) 49 | 50 | def query_threshold(self, vectors: np.ndarray, threshold: float) -> list[DocScores]: 51 | """ 52 | Query the index with a threshold. 53 | 54 | :param vectors: The vectors to query. 55 | :param threshold: The similarity threshold. 56 | :return: The query results. 57 | """ 58 | out: list[DocScores] = [] 59 | for result in self.backend.threshold(vectors, threshold=1 - threshold, max_k=100): 60 | intermediate = [] 61 | for index, distance in zip(*result): 62 | # Every item in the index contains one or more records. 63 | # These are all exact duplicates, so they get the same score. 64 | for record in self.items[index]: 65 | # The score is the cosine similarity. 66 | # The backend returns distances, so we need to convert. 
67 | intermediate.append((record, 1 - distance)) 68 | out.append(intermediate) 69 | 70 | return out 71 | 72 | def query_top_k(self, vectors: np.ndarray, k: int, vectors_are_in_index: bool) -> list[SingleQueryResult]: 73 | """ 74 | Query the index with a top-k. 75 | 76 | :param vectors: The vectors to query. 77 | :param k: Maximum number of top-k records to keep. 78 | :param vectors_are_in_index: Whether the vectors are in the index. If this is set to True, we retrieve k + 1 79 | records, and do not consider the first one, as it is the query vector itself. 80 | :return: The query results. Each result is a tuple where the first element is the list of neighbor records, 81 | and the second element is a NumPy array of cosine similarity scores. 82 | """ 83 | results = [] 84 | offset = int(vectors_are_in_index) 85 | for x, y in self.backend.query(vectors=vectors, k=k + offset): 86 | # Convert returned distances to cosine similarities. 87 | similarities = 1 - y[offset:] 88 | results.append((x[offset:], similarities)) 89 | return results 90 | -------------------------------------------------------------------------------- /tests/data/test_model/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: model2vec 3 | license: mit 4 | model_name: potion-base-4m-int8 5 | tags: 6 | - embeddings 7 | - static-embeddings 8 | - sentence-transformers 9 | --- 10 | 11 | # potion-base-4m-int8 Model Card 12 | 13 | This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of a Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. 
The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers. 14 | 15 | 16 | ## Installation 17 | 18 | Install model2vec using pip: 19 | ``` 20 | pip install model2vec 21 | ``` 22 | 23 | ## Usage 24 | 25 | ### Using Model2Vec 26 | 27 | The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models. 28 | 29 | Load this model using the `from_pretrained` method: 30 | ```python 31 | from model2vec import StaticModel 32 | 33 | # Load a pretrained Model2Vec model 34 | model = StaticModel.from_pretrained("potion-base-4m-int8") 35 | 36 | # Compute text embeddings 37 | embeddings = model.encode(["Example sentence"]) 38 | ``` 39 | 40 | ### Using Sentence Transformers 41 | 42 | You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model: 43 | 44 | ```python 45 | from sentence_transformers import SentenceTransformer 46 | 47 | # Load a pretrained Sentence Transformer model 48 | model = SentenceTransformer("potion-base-4m-int8") 49 | 50 | # Compute text embeddings 51 | embeddings = model.encode(["Example sentence"]) 52 | ``` 53 | 54 | ### Distilling a Model2Vec model 55 | 56 | You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. 
Then, run the following code: 57 | 58 | ```python 59 | from model2vec.distill import distill 60 | 61 | # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model 62 | m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256) 63 | 64 | # Save the model 65 | m2v_model.save_pretrained("m2v_model") 66 | ``` 67 | 68 | ## How it works 69 | 70 | Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec. 71 | 72 | It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence. 73 | 74 | ## Additional Resources 75 | 76 | - [Model2Vec Repo](https://github.com/MinishLab/model2vec) 77 | - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e) 78 | - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results) 79 | - [Model2Vec Tutorials](https://github.com/MinishLab/model2vec/tree/main/tutorials) 80 | - [Website](https://minishlab.github.io/) 81 | 82 | 83 | ## Library Authors 84 | 85 | Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled). 86 | 87 | ## Citation 88 | 89 | Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work. 
90 | ``` 91 | @article{minishlab2024model2vec, 92 | author = {Tulkens, Stephan and {van Dongen}, Thomas}, 93 | title = {Model2Vec: Fast State-of-the-Art Static Embeddings}, 94 | year = {2024}, 95 | url = {https://github.com/MinishLab/model2vec} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /benchmarks/results/train_benchmark_results.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dataset": "bbc", 4 | "original_train_size": 1225, 5 | "deduplicated_train_size": 1144, 6 | "percent_removed": 6.612244897959185, 7 | "build_time_seconds": 0.5598082079086453, 8 | "deduplication_time_seconds": 0.008702374994754791, 9 | "time_seconds": 0.5685105829034001 10 | }, 11 | { 12 | "dataset": "senteval_cr", 13 | "original_train_size": 3012, 14 | "deduplicated_train_size": 2990, 15 | "percent_removed": 0.7304116865869847, 16 | "build_time_seconds": 0.10847400000784546, 17 | "deduplication_time_seconds": 0.027519959025084972, 18 | "time_seconds": 0.13599395903293043 19 | }, 20 | { 21 | "dataset": "tweet_sentiment_extraction", 22 | "original_train_size": 27481, 23 | "deduplicated_train_size": 26695, 24 | "percent_removed": 2.860157927295226, 25 | "build_time_seconds": 1.3568968329345807, 26 | "deduplication_time_seconds": 0.41194633406121284, 27 | "time_seconds": 1.7688431669957936 28 | }, 29 | { 30 | "dataset": "emotion", 31 | "original_train_size": 16000, 32 | "deduplicated_train_size": 15695, 33 | "percent_removed": 1.9062499999999982, 34 | "build_time_seconds": 0.5511152499821037, 35 | "deduplication_time_seconds": 0.21407662506680936, 36 | "time_seconds": 0.7651918750489131 37 | }, 38 | { 39 | "dataset": "amazon_counterfactual", 40 | "original_train_size": 5000, 41 | "deduplicated_train_size": 4992, 42 | "percent_removed": 0.16000000000000458, 43 | "build_time_seconds": 0.2848535830853507, 44 | "deduplication_time_seconds": 0.048574666026979685, 45 | "time_seconds": 
0.3334282491123304 46 | }, 47 | { 48 | "dataset": "ag_news", 49 | "original_train_size": 120000, 50 | "deduplicated_train_size": 106921, 51 | "percent_removed": 10.899166666666671, 52 | "build_time_seconds": 3.0319770000642166, 53 | "deduplication_time_seconds": 2.171258582966402, 54 | "time_seconds": 5.203235583030619 55 | }, 56 | { 57 | "dataset": "enron_spam", 58 | "original_train_size": 31716, 59 | "deduplicated_train_size": 20540, 60 | "percent_removed": 35.23773489721276, 61 | "build_time_seconds": 1.3818323339801282, 62 | "deduplication_time_seconds": 0.6438171250047162, 63 | "time_seconds": 2.0256494589848444 64 | }, 65 | { 66 | "dataset": "subj", 67 | "original_train_size": 8000, 68 | "deduplicated_train_size": 7990, 69 | "percent_removed": 0.12499999999999734, 70 | "build_time_seconds": 0.5059439589967951, 71 | "deduplication_time_seconds": 0.12505983305163682, 72 | "time_seconds": 0.6310037920484319 73 | }, 74 | { 75 | "dataset": "sst5", 76 | "original_train_size": 8544, 77 | "deduplicated_train_size": 8526, 78 | "percent_removed": 0.2106741573033699, 79 | "build_time_seconds": 0.4805819580797106, 80 | "deduplication_time_seconds": 0.10166720801498741, 81 | "time_seconds": 0.582249166094698 82 | }, 83 | { 84 | "dataset": "20_newgroups", 85 | "original_train_size": 11314, 86 | "deduplicated_train_size": 10684, 87 | "percent_removed": 5.568322432384654, 88 | "build_time_seconds": 0.610724583035335, 89 | "deduplication_time_seconds": 0.11600329191423953, 90 | "time_seconds": 0.7267278749495745 91 | }, 92 | { 93 | "dataset": "hatespeech_offensive", 94 | "original_train_size": 22783, 95 | "deduplicated_train_size": 22090, 96 | "percent_removed": 3.0417416494754823, 97 | "build_time_seconds": 0.6471997499465942, 98 | "deduplication_time_seconds": 0.2704670410603285, 99 | "time_seconds": 0.9176667910069227 100 | }, 101 | { 102 | "dataset": "ade", 103 | "original_train_size": 17637, 104 | "deduplicated_train_size": 15718, 105 | "percent_removed": 
10.880535238419231, 106 | "build_time_seconds": 0.5221591669833288, 107 | "deduplication_time_seconds": 0.20764074998442084, 108 | "time_seconds": 0.7297999169677496 109 | }, 110 | { 111 | "dataset": "imdb", 112 | "original_train_size": 25000, 113 | "deduplicated_train_size": 24830, 114 | "percent_removed": 0.6800000000000028, 115 | "build_time_seconds": 1.460668999934569, 116 | "deduplication_time_seconds": 0.29758112493436784, 117 | "time_seconds": 1.7582501248689368 118 | }, 119 | { 120 | "dataset": "massive_scenario", 121 | "original_train_size": 11514, 122 | "deduplicated_train_size": 9366, 123 | "percent_removed": 18.655549765502865, 124 | "build_time_seconds": 0.35503324994351715, 125 | "deduplication_time_seconds": 0.11619104200508446, 126 | "time_seconds": 0.4712242919486016 127 | }, 128 | { 129 | "dataset": "student", 130 | "original_train_size": 117519, 131 | "deduplicated_train_size": 63856, 132 | "percent_removed": 45.66325445247151, 133 | "build_time_seconds": 2.9044899590080604, 134 | "deduplication_time_seconds": 5.895973875070922, 135 | "time_seconds": 8.800463834078982 136 | }, 137 | { 138 | "dataset": "squad_v2", 139 | "original_train_size": 130319, 140 | "deduplicated_train_size": 109698, 141 | "percent_removed": 15.823479308466148, 142 | "build_time_seconds": 6.078755749971606, 143 | "deduplication_time_seconds": 2.7270843340083957, 144 | "time_seconds": 8.805840083980002 145 | }, 146 | { 147 | "dataset": "wikitext", 148 | "original_train_size": 1801350, 149 | "deduplicated_train_size": 884645, 150 | "percent_removed": 50.88988813945097, 151 | "build_time_seconds": 39.38258587510791, 152 | "deduplication_time_seconds": 44.1503732081037, 153 | "time_seconds": 83.53295908321161 154 | } 155 | ] 156 | -------------------------------------------------------------------------------- /benchmarks/results/train_test_benchmark_results.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "dataset": "bbc", 
4 | "train_size": 1225, 5 | "test_size": 1000, 6 | "deduplicated_test_size": 870, 7 | "percent_removed": 13.0, 8 | "build_time_seconds": 0.5598082079086453, 9 | "deduplication_time_seconds": 0.1528247919632122, 10 | "time_seconds": 0.7126329998718575 11 | }, 12 | { 13 | "dataset": "senteval_cr", 14 | "train_size": 3012, 15 | "test_size": 753, 16 | "deduplicated_test_size": 750, 17 | "percent_removed": 0.3984063745019917, 18 | "build_time_seconds": 0.10847400000784546, 19 | "deduplication_time_seconds": 0.019297500024549663, 20 | "time_seconds": 0.12777150003239512 21 | }, 22 | { 23 | "dataset": "tweet_sentiment_extraction", 24 | "train_size": 27481, 25 | "test_size": 3534, 26 | "deduplicated_test_size": 3412, 27 | "percent_removed": 3.452178834182229, 28 | "build_time_seconds": 1.3568968329345807, 29 | "deduplication_time_seconds": 0.17268049996346235, 30 | "time_seconds": 1.529577332898043 31 | }, 32 | { 33 | "dataset": "emotion", 34 | "train_size": 16000, 35 | "test_size": 2000, 36 | "deduplicated_test_size": 1926, 37 | "percent_removed": 3.7000000000000033, 38 | "build_time_seconds": 0.5511152499821037, 39 | "deduplication_time_seconds": 0.10135454102419317, 40 | "time_seconds": 0.6524697910062969 41 | }, 42 | { 43 | "dataset": "amazon_counterfactual", 44 | "train_size": 5000, 45 | "test_size": 5000, 46 | "deduplicated_test_size": 4990, 47 | "percent_removed": 0.20000000000000018, 48 | "build_time_seconds": 0.2848535830853507, 49 | "deduplication_time_seconds": 0.22846354101784527, 50 | "time_seconds": 0.513317124103196 51 | }, 52 | { 53 | "dataset": "ag_news", 54 | "train_size": 120000, 55 | "test_size": 7600, 56 | "deduplicated_test_size": 6198, 57 | "percent_removed": 18.447368421052634, 58 | "build_time_seconds": 3.0319770000642166, 59 | "deduplication_time_seconds": 0.7034984159981832, 60 | "time_seconds": 3.7354754160623997 61 | }, 62 | { 63 | "dataset": "enron_spam", 64 | "train_size": 31716, 65 | "test_size": 2000, 66 | "deduplicated_test_size": 1060, 67 
| "percent_removed": 47.0, 68 | "build_time_seconds": 1.3818323339801282, 69 | "deduplication_time_seconds": 0.553584959008731, 70 | "time_seconds": 1.9354172929888591 71 | }, 72 | { 73 | "dataset": "subj", 74 | "train_size": 8000, 75 | "test_size": 2000, 76 | "deduplicated_test_size": 1999, 77 | "percent_removed": 0.04999999999999449, 78 | "build_time_seconds": 0.5059439589967951, 79 | "deduplication_time_seconds": 0.11624520795885473, 80 | "time_seconds": 0.6221891669556499 81 | }, 82 | { 83 | "dataset": "sst5", 84 | "train_size": 8544, 85 | "test_size": 2210, 86 | "deduplicated_test_size": 2205, 87 | "percent_removed": 0.2262443438914019, 88 | "build_time_seconds": 0.4805819580797106, 89 | "deduplication_time_seconds": 0.11375170899555087, 90 | "time_seconds": 0.5943336670752615 91 | }, 92 | { 93 | "dataset": "20_newgroups", 94 | "train_size": 11314, 95 | "test_size": 7532, 96 | "deduplicated_test_size": 7098, 97 | "percent_removed": 5.762081784386619, 98 | "build_time_seconds": 0.610724583035335, 99 | "deduplication_time_seconds": 1.6346445409581065, 100 | "time_seconds": 2.2453691239934415 101 | }, 102 | { 103 | "dataset": "hatespeech_offensive", 104 | "train_size": 22783, 105 | "test_size": 2000, 106 | "deduplicated_test_size": 1925, 107 | "percent_removed": 3.749999999999998, 108 | "build_time_seconds": 0.6471997499465942, 109 | "deduplication_time_seconds": 0.12372829194646329, 110 | "time_seconds": 0.7709280418930575 111 | }, 112 | { 113 | "dataset": "ade", 114 | "train_size": 17637, 115 | "test_size": 5879, 116 | "deduplicated_test_size": 4952, 117 | "percent_removed": 15.76798775301922, 118 | "build_time_seconds": 0.5221591669833288, 119 | "deduplication_time_seconds": 0.28758599993307143, 120 | "time_seconds": 0.8097451669164002 121 | }, 122 | { 123 | "dataset": "imdb", 124 | "train_size": 25000, 125 | "test_size": 25000, 126 | "deduplicated_test_size": 24795, 127 | "percent_removed": 0.8199999999999985, 128 | "build_time_seconds": 1.460668999934569, 
129 | "deduplication_time_seconds": 1.3489695829339325, 130 | "time_seconds": 2.8096385828685015 131 | }, 132 | { 133 | "dataset": "massive_scenario", 134 | "train_size": 11514, 135 | "test_size": 2974, 136 | "deduplicated_test_size": 2190, 137 | "percent_removed": 26.36180228648285, 138 | "build_time_seconds": 0.35503324994351715, 139 | "deduplication_time_seconds": 0.10878237499855459, 140 | "time_seconds": 0.46381562494207174 141 | }, 142 | { 143 | "dataset": "student", 144 | "train_size": 117519, 145 | "test_size": 5000, 146 | "deduplicated_test_size": 2393, 147 | "percent_removed": 52.14, 148 | "build_time_seconds": 2.9044899590080604, 149 | "deduplication_time_seconds": 0.8721794589655474, 150 | "time_seconds": 3.776669417973608 151 | }, 152 | { 153 | "dataset": "squad_v2", 154 | "train_size": 130319, 155 | "test_size": 11873, 156 | "deduplicated_test_size": 11863, 157 | "percent_removed": 0.08422471153036737, 158 | "build_time_seconds": 6.078755749971606, 159 | "deduplication_time_seconds": 1.0497459589969367, 160 | "time_seconds": 7.1285017089685425 161 | }, 162 | { 163 | "dataset": "wikitext", 164 | "train_size": 1801350, 165 | "test_size": 4358, 166 | "deduplicated_test_size": 2139, 167 | "percent_removed": 50.91785222579165, 168 | "build_time_seconds": 39.38258587510791, 169 | "deduplication_time_seconds": 0.9325925830053166, 170 | "time_seconds": 40.31517845811322 171 | } 172 | ] 173 | -------------------------------------------------------------------------------- /benchmarks/run_benchmarks.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from time import perf_counter 4 | 5 | from datasets import load_dataset 6 | from model2vec import StaticModel 7 | 8 | from benchmarks.data import DATASET_DICT 9 | from semhash import SemHash 10 | 11 | # Set up logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def main() -> None: # noqa: C901 16 | """Run the benchmarks.""" 17 | # 
    # Prepare lists to hold benchmark results
    train_dedup_results = []
    train_test_dedup_results = []
    # Load the model and initialize SemHash

    model = StaticModel.from_pretrained("minishlab/potion-base-8m")

    for dataset_name, record in DATASET_DICT.items():
        logger.info(f"Loading dataset: {dataset_name} from {record.name}")

        # Load train and test splits
        if record.sub_directory:
            train_ds = load_dataset(record.name, record.sub_directory, split=record.split_one)
            test_ds = load_dataset(record.name, record.sub_directory, split=record.split_two)
        else:
            train_ds = load_dataset(record.name, split=record.split_one)
            test_ds = load_dataset(record.name, split=record.split_two)

        # If the dataset has columns, use them (multi-column dict records)
        if record.columns:
            columns = record.columns
            train_records = [dict(row) for row in train_ds]
            test_records = [dict(row) for row in test_ds]
        # Else, use the text_name (single text column, plain string records)
        else:
            train_records = train_ds[record.text_name]
            test_records = test_ds[record.text_name]
            columns = None

        # Build the SemHash instance; index-build time is reported separately
        # from deduplication time in both result tables below.
        build_start = perf_counter()
        semhash = SemHash.from_records(model=model, use_ann=True, records=train_records, columns=columns)
        build_end = perf_counter()
        build_time = build_end - build_start
        # Time how long it takes to deduplicate the train set
        train_only_start = perf_counter()
        deduplicated_train = semhash.self_deduplicate()
        train_only_end = perf_counter()

        train_only_dedup_time = train_only_end - train_only_start
        original_train_size = len(train_records)
        dedup_train_size = len(deduplicated_train.selected)

        # duplicate_ratio is a fraction in [0, 1]; convert to a percentage
        percent_removed_train = deduplicated_train.duplicate_ratio * 100
        train_dedup_results.append(
            {
                "dataset": dataset_name,
                "original_train_size": original_train_size,
                "deduplicated_train_size": dedup_train_size,
                "percent_removed": percent_removed_train,
                "build_time_seconds": build_time,
                "deduplication_time_seconds": train_only_dedup_time,
                "time_seconds": train_only_dedup_time + build_time,
            }
        )

        logger.info(
            f"[TRAIN DEDUPLICATION] Dataset: {dataset_name}\n"
            f" - Original Train Size: {original_train_size}\n"
            f" - Deduplicated Train Size: {dedup_train_size}\n"
            f" - % Removed: {percent_removed_train:.2f}\n"
            f" - Deduplication Time (seconds): {train_only_dedup_time:.2f}\n"
            f" - Build Time (seconds): {build_time:.2f}\n"
            f" - Total Time (seconds): {train_only_dedup_time + build_time:.2f}\n"
        )

        # Time how long it takes to deduplicate the test set
        # (reuses the index built from the train records above)
        train_test_start = perf_counter()
        deduplicated_test = semhash.deduplicate(
            records=test_records,
        )
        train_test_end = perf_counter()
        train_test_dedup_time = train_test_end - train_test_start
        original_test_size = len(test_records)
        deduped_test_size = len(deduplicated_test.selected)
        percent_removed_test = deduplicated_test.duplicate_ratio * 100

        train_test_dedup_results.append(
            {
                "dataset": dataset_name,
                "train_size": original_train_size,
                "test_size": original_test_size,
                "deduplicated_test_size": deduped_test_size,
                "percent_removed": percent_removed_test,
                "build_time_seconds": build_time,
                "deduplication_time_seconds": train_test_dedup_time,
                "time_seconds": train_test_dedup_time + build_time,
            }
        )

        logger.info(
            f"[TRAIN/TEST DEDUPLICATION] Dataset: {dataset_name}\n"
            f" - Train Size: {original_train_size}\n"
            f" - Test Size: {original_test_size}\n"
            f" - Deduplicated Test Size: {deduped_test_size}\n"
            f" - % Removed: {percent_removed_test:.2f}\n"
            f" - Deduplication Time (seconds): {train_test_dedup_time:.2f}\n"
            f" - Build Time (seconds): {build_time:.2f}\n"
            f" - Total Time (seconds): {train_test_dedup_time + build_time:.2f}\n"
        )

    # Write the results to JSON files
    # NOTE(review): paths are relative to the working directory — assumes the
    # script is run from the repository root; confirm against CI invocation.
    with open("benchmarks/results/train_benchmark_results.json", "w", encoding="utf-8") as f:
        json.dump(train_dedup_results, f, ensure_ascii=False, indent=2)

    with open("benchmarks/results/train_test_benchmark_results.json", "w", encoding="utf-8") as f:
        json.dump(train_test_dedup_results, f, ensure_ascii=False, indent=2)

    # Print the train table (markdown-formatted, for pasting into the README)
    print("### Train Deduplication Benchmark\n")  # noqa T201
    print(  # noqa T201
        f"| {'Dataset':<20} | {'Original Train Size':>20} | {'Deduplicated Train Size':>24} | {'% Removed':>10} | {'Deduplication Time (s)':>24} |"
    )  # noqa T201
    print("|" + "-" * 22 + "|" + "-" * 22 + "|" + "-" * 26 + "|" + "-" * 12 + "|" + "-" * 26 + "|")  # noqa T201
    for r in train_dedup_results:
        print(  # noqa T201
            f"| {r['dataset']:<20} "
            f"| {r['original_train_size']:>20} "
            f"| {r['deduplicated_train_size']:>24} "
            f"| {r['percent_removed']:>10.2f} "
            f"| {r['time_seconds']:>24.2f} |"
        )

    print("\n")  # noqa T201

    # Print the train/test table
    print("### Train/Test Deduplication Benchmark\n")  # noqa T201
    print(  # noqa T201
        f"| {'Dataset':<20} | {'Train Size':>12} | {'Test Size':>12} | {'Deduplicated Test Size':>24} | {'% Removed':>10} | {'Deduplication Time (s)':>24} |"
    )  # noqa T201
    print("|" + "-" * 22 + "|" + "-" * 14 + "|" + "-" * 14 + "|" + "-" * 26 + "|" + "-" * 12 + "|" + "-" * 26 + "|")  # noqa T201
    for r in train_test_dedup_results:
        print(  # noqa T201
            f"| {r['dataset']:<20} "
            f"| {r['train_size']:>12} "
            f"| {r['test_size']:>12} "
            f"| {r['deduplicated_test_size']:>24} "
            f"| {r['percent_removed']:>10.2f} "
            f"| {r['time_seconds']:>24.2f} |"
        )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
import pytest

import semhash
import semhash.version
from semhash.datamodels import DeduplicationResult, DuplicateRecord, SelectedWithDuplicates


def test_deduplication_scoring() -> None:
    """Test the deduplication scoring."""
    # 3 selected + 2 filtered -> 1 - 3/5 == 0.4 of the records were dropped.
    d = DeduplicationResult(
        ["a", "b", "c"],
        [DuplicateRecord("a", False, [("b", 0.9)]), DuplicateRecord("b", False, [("c", 0.8)])],
        0.8,
    )
    assert d.duplicate_ratio == 0.4


def test_deduplication_scoring_exact() -> None:
    """Test the deduplication scoring."""
    # Only one of the two filtered records is exact -> 1/5 == 0.2.
    d = DeduplicationResult(
        ["a", "b", "c"],
        [DuplicateRecord("a", True, [("b", 0.9)]), DuplicateRecord("b", False, [("c", 0.8)])],
        0.8,
    )
    assert d.exact_duplicate_ratio == 0.2


def test_deduplication_scoring_exact_empty() -> None:
    """Test the deduplication scoring."""
    # Empty result must not divide by zero.
    d = DeduplicationResult([], [], 0.8, columns=["text"])
    assert d.exact_duplicate_ratio == 0.0


def test_deduplication_scoring_empty() -> None:
    """Test the deduplication scoring."""
    d = DeduplicationResult([], [], 0.8, columns=["text"])
    assert d.duplicate_ratio == 0.0


def test_rethreshold() -> None:
    """Test rethresholding the duplicates."""
    # Raising the threshold to 0.85 drops the 0.8-scored duplicate.
    d = DuplicateRecord("a", False, [("b", 0.9), ("c", 0.8)])
    d._rethreshold(0.85)
    assert d.duplicates == [("b", 0.9)]


def test_rethreshold_empty() -> None:
    """Test rethresholding the duplicates."""
    d = DuplicateRecord("a", False, [])
    d._rethreshold(0.85)
    assert d.duplicates == []


def test_get_least_similar_from_duplicates() -> None:
    """Test getting the least similar duplicates."""
    d = DeduplicationResult(
        ["a", "b", "c"],
        [DuplicateRecord("a", False, [("b", 0.9), ("c", 0.7)]), DuplicateRecord("b", False, [("c", 0.8)])],
        0.8,
    )
    # The globally lowest-scored pair is (a, c, 0.7).
    result = d.get_least_similar_from_duplicates(1)
    assert result == [("a", "c", 0.7)]


def test_get_least_similar_from_duplicates_empty() -> None:
    """Test getting the least similar duplicates."""
    d = DeduplicationResult([], [], 0.8, columns=["text"])
    assert d.get_least_similar_from_duplicates(1) == []


def test_rethreshold_deduplication_result() -> None:
    """Test rethresholding the duplicates."""
    d = DeduplicationResult(
        ["a", "b", "c"],
        [
            DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]),
            DuplicateRecord("e", False, [("z", 0.8)]),
        ],
        0.8,
    )
    d.rethreshold(0.85)
    # "e" loses all duplicates and is moved back into selected.
    assert d.filtered == [DuplicateRecord("d", False, [("x", 0.9)])]
    assert d.selected == ["a", "b", "c", "e"]


def test_rethreshold_exception() -> None:
    """Test rethresholding throws an exception."""
    d = DeduplicationResult(
        ["a", "b", "c"],
        [
            DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]),
            DuplicateRecord("e", False, [("z", 0.8)]),
        ],
        0.7,
    )
    # Rethresholding only makes sense with a stricter (higher) threshold.
    with pytest.raises(ValueError):
        d.rethreshold(0.6)


def test_deprecation_deduplicated_duplicates() -> None:
    """Test deprecation warnings for deduplicated and duplicates fields."""
    # Compare the numeric version triple instead of the version string:
    # lexicographic string comparison mis-orders versions
    # (e.g. "0.10.0" < "0.4.0" is True as strings).
    if semhash.version.__version_triple__ < (0, 4, 0):
        with pytest.warns(DeprecationWarning):
            d = DeduplicationResult(
                deduplicated=["a", "b", "c"],
                duplicates=[
                    DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]),
                    DuplicateRecord("e", False, [("z", 0.8)]),
                ],
                threshold=0.8,
            )
    else:
        raise ValueError("deprecate `deduplicated` and `duplicates` fields in `DeduplicationResult`")
    # The deprecated fields must be mirrored into the new ones.
    assert d.selected == ["a", "b", "c"]
    assert d.filtered == [
        DuplicateRecord("d", False, [("x", 0.9), ("y", 0.8)]),
        DuplicateRecord("e", False, [("z", 0.8)]),
    ]


def test_selected_with_duplicates_strings() -> None:
    """Test selected_with_duplicates for strings."""
    d = DeduplicationResult(
        selected=["original"],
        filtered=[
            DuplicateRecord("duplicate_1", False, [("original", 0.9)]),
            DuplicateRecord("duplicate_2", False, [("original", 0.8)]),
        ],
        threshold=0.8,
    )

    expected = [
        SelectedWithDuplicates(
            record="original",
            duplicates=[("duplicate_1", 0.9), ("duplicate_2", 0.8)],
        )
    ]
    assert d.selected_with_duplicates == expected


def test_selected_with_duplicates_dicts() -> None:
    """Test selected_with_duplicates for dicts."""
    selected = {"id": 0, "text": "hello"}
    d = DeduplicationResult(
        selected=[selected],
        filtered=[
            DuplicateRecord({"id": 1, "text": "hello"}, True, [(selected, 1.0)]),
            DuplicateRecord({"id": 2, "text": "helllo"}, False, [(selected, 0.1)]),
        ],
        threshold=0.8,
        columns=["text"],
    )

    items = d.selected_with_duplicates
    assert len(items) == 1
    kept = items[0].record
    dups = items[0].duplicates
    assert kept == selected
    assert {r["id"] for r, _ in dups} == {1, 2}


def test_selected_with_duplicates_multi_column() -> None:
    """Test selected_with_duplicates for multi-columns."""
    selected = {"text": "hello", "text2": "world"}
    d = DeduplicationResult(
        selected=[selected],
        filtered=[
            DuplicateRecord({"text": "hello", "text2": "world"}, True, [(selected, 1.0)]),
            DuplicateRecord({"text": "helllo", "text2": "world"}, False, [(selected, 0.1)]),
        ],
        threshold=0.8,
        columns=["text", "text2"],
    )

    items = d.selected_with_duplicates
    assert len(items) == 1
    kept = items[0].record
    assert kept == selected


def test_selected_with_duplicates_unhashable_values() -> None:
    """Test selected_with_duplicates with unhashable values in records."""
    selected = {"text": "hello", "a": [1, 2, 3]}  # list -> unhashable value
    filtered = {"text": "hello", "a": [1, 2, 3], "flag": True}

    d = DeduplicationResult(
        selected=[selected],
        filtered=[DuplicateRecord(filtered, exact=False, duplicates=[(selected, 1.0)])],
        threshold=0.8,
        columns=["text"],
    )

    items = d.selected_with_duplicates
    assert items == [SelectedWithDuplicates(record=selected, duplicates=[(filtered, 1.0)])]


def test_selected_with_duplicates_removes_internal_duplicates() -> None:
    """Test that selected_with_duplicates removes internal duplicates that have the same hash."""
    selected = {"id": 0, "text": "hello"}
    filtered = {"id": 1, "text": "hello"}

    d = DeduplicationResult(
        selected=[selected],
        filtered=[
            DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.95)]),
            DuplicateRecord(filtered, exact=False, duplicates=[(selected, 0.90)]),
        ],
        threshold=0.8,
        columns=["text"],
    )

    items = d.selected_with_duplicates
    assert len(items) == 1

    selected_record = items[0].record
    duplicate_list = items[0].duplicates
    # Should keep the kept record unchanged
    assert selected_record == selected
    # The duplicate row must appear only once
    assert len(duplicate_list) == 1
    assert duplicate_list[0][0] == filtered
TypeVar 8 | 9 | from frozendict import frozendict 10 | from typing_extensions import TypeAlias 11 | 12 | from semhash.utils import to_frozendict 13 | 14 | Record = TypeVar("Record", str, dict[str, Any]) 15 | DuplicateList: TypeAlias = list[tuple[Record, float]] 16 | 17 | 18 | @dataclass 19 | class DuplicateRecord(Generic[Record]): 20 | """ 21 | A single record with its duplicates. 22 | 23 | Attributes 24 | ---------- 25 | record: The original record being deduplicated. 26 | exact: Whether the record was identified as an exact match. 27 | duplicates: List of tuples consisting of duplicate records and their associated scores. 28 | 29 | """ 30 | 31 | record: Record 32 | exact: bool 33 | duplicates: DuplicateList = field(default_factory=list) 34 | 35 | def _rethreshold(self, threshold: float) -> None: 36 | """Rethreshold the duplicates.""" 37 | self.duplicates = [(d, score) for d, score in self.duplicates if score >= threshold] 38 | 39 | 40 | @dataclass 41 | class SelectedWithDuplicates(Generic[Record]): 42 | """ 43 | A record that has been selected along with its duplicates. 44 | 45 | Attributes 46 | ---------- 47 | record: The original record being selected. 48 | duplicates: List of tuples consisting of duplicate records and their associated scores. 49 | 50 | """ 51 | 52 | record: Record 53 | duplicates: DuplicateList = field(default_factory=list) 54 | 55 | 56 | @dataclass 57 | class DeduplicationResult(Generic[Record]): 58 | """ 59 | Deduplication result. 60 | 61 | Attributes 62 | ---------- 63 | selected: List of deduplicated records after removing duplicates. 64 | filtered: List of DuplicateRecord objects containing details about duplicates of an original record. 65 | threshold: The similarity threshold used for deduplication. 66 | columns: Columns used for deduplication. 67 | deduplicated: Deprecated, use selected instead. 68 | duplicates: Deprecated, use filtered instead. 
69 | 70 | """ 71 | 72 | selected: list[Record] = field(default_factory=list) 73 | filtered: list[DuplicateRecord] = field(default_factory=list) 74 | threshold: float = field(default=0.9) 75 | columns: Sequence[str] | None = field(default=None) 76 | deduplicated: list[Record] = field(default_factory=list) # Deprecated 77 | duplicates: list[DuplicateRecord] = field(default_factory=list) # Deprecated 78 | 79 | def __post_init__(self) -> None: 80 | """Initialize deprecated fields and warn about deprecation.""" 81 | if self.deduplicated or self.duplicates: 82 | warnings.warn( 83 | "'deduplicated' and 'duplicates' fields are deprecated and will be removed in a future release. Use 'selected' and 'filtered' instead.", 84 | DeprecationWarning, 85 | stacklevel=2, 86 | ) 87 | 88 | if not self.selected and self.deduplicated: 89 | self.selected = self.deduplicated 90 | if not self.filtered and self.duplicates: 91 | self.filtered = self.duplicates 92 | if not self.deduplicated: 93 | self.deduplicated = self.selected 94 | if not self.duplicates: 95 | self.duplicates = self.filtered 96 | 97 | @property 98 | def duplicate_ratio(self) -> float: 99 | """Return the percentage of records dropped.""" 100 | if denom := len(self.selected) + len(self.filtered): 101 | return 1.0 - len(self.selected) / denom 102 | return 0.0 103 | 104 | @property 105 | def exact_duplicate_ratio(self) -> float: 106 | """Return the percentage of records dropped due to an exact match.""" 107 | if denom := len(self.selected) + len(self.filtered): 108 | return len([dup for dup in self.filtered if dup.exact]) / denom 109 | return 0.0 110 | 111 | def get_least_similar_from_duplicates(self, n: int = 1) -> list[tuple[Record, Record, float]]: 112 | """ 113 | Return the N least similar duplicate pairs. 114 | 115 | :param n: The number of least similar pairs to return. 116 | :return: A list of tuples consisting of (original_record, duplicate_record, score). 
117 | """ 118 | all_pairs = [(dup.record, d, score) for dup in self.filtered for d, score in dup.duplicates] 119 | sorted_pairs = sorted(all_pairs, key=lambda x: x[2]) # Sort by score 120 | return sorted_pairs[:n] 121 | 122 | def rethreshold(self, threshold: float) -> None: 123 | """Rethreshold the duplicates.""" 124 | if self.threshold > threshold: 125 | raise ValueError("Threshold is smaller than the given value.") 126 | for dup in self.filtered: 127 | dup._rethreshold(threshold) 128 | if not dup.duplicates: 129 | self.filtered.remove(dup) 130 | self.selected.append(dup.record) 131 | self.threshold = threshold 132 | 133 | @property 134 | def selected_with_duplicates(self) -> list[SelectedWithDuplicates[Record]]: 135 | """ 136 | For every kept record, return the duplicates that were removed along with their similarity scores. 137 | 138 | :return: A list of tuples where each tuple contains a kept record 139 | and a list of its duplicates with their similarity scores. 140 | """ 141 | 142 | def _to_hashable(record: Record) -> frozendict[str, str] | str: 143 | """Convert a record to a hashable representation.""" 144 | if isinstance(record, dict) and self.columns is not None: 145 | # Convert dict to frozendict for immutability and hashability 146 | return to_frozendict(record, set(self.columns)) 147 | return str(record) 148 | 149 | # Build a mapping from original-record to [(duplicate, score), …] 150 | buckets: defaultdict[Hashable, DuplicateList] = defaultdict(list) 151 | for duplicate_record in self.filtered: 152 | for original_record, score in duplicate_record.duplicates: 153 | buckets[_to_hashable(original_record)].append((duplicate_record.record, float(score))) 154 | 155 | result: list[SelectedWithDuplicates[Record]] = [] 156 | for selected in self.selected: 157 | # Get the list of duplicates for the selected record 158 | raw_list = buckets.get(_to_hashable(selected), []) 159 | # Ensure we don't have duplicates in the list 160 | # Use full-record canonical JSON 
for dicts so that unhashable values are handled correctly 161 | deduped = { 162 | ( 163 | json.dumps(rec, sort_keys=True, separators=(",", ":"), ensure_ascii=False) 164 | if isinstance(rec, dict) 165 | else rec 166 | ): (rec, score) 167 | for rec, score in raw_list 168 | } 169 | result.append(SelectedWithDuplicates(record=selected, duplicates=list(deduped.values()))) 170 | 171 | return result 172 | 173 | 174 | @dataclass 175 | class FilterResult(Generic[Record]): 176 | """ 177 | Result of filtering operations. 178 | 179 | Attributes 180 | ---------- 181 | selected: List of records that passed the filter criteria. 182 | filtered: List of records that were filtered out. 183 | scores_selected: List of scores for the selected records. 184 | scores_filtered: List of scores for the filtered records. 185 | 186 | """ 187 | 188 | selected: list[Record] 189 | filtered: list[Record] 190 | scores_selected: list[float] = field(default_factory=list) 191 | scores_filtered: list[float] = field(default_factory=list) 192 | 193 | @property 194 | def filter_ratio(self) -> float: 195 | """Return the percentage of records filtered out.""" 196 | if denom := len(self.selected) + len(self.filtered): 197 | return len(self.filtered) / denom 198 | return 0.0 199 | 200 | @property 201 | def selected_ratio(self) -> float: 202 | """Return the percentage of records selected.""" 203 | return 1 - self.filter_ratio 204 | -------------------------------------------------------------------------------- /tests/test_semhash.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from semhash import SemHash 5 | from semhash.datamodels import FilterResult 6 | from semhash.utils import Encoder 7 | 8 | 9 | def test_single_dataset_deduplication(use_ann: bool, model: Encoder) -> None: 10 | """Test single dataset deduplication.""" 11 | # No duplicates 12 | texts = [ 13 | "It's dangerous to go alone!", 14 | "The master sword can seal 
the darkness.", 15 | "Ganondorf has invaded Hyrule!", 16 | ] 17 | semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model) 18 | deduplicated_texts = semhash.self_deduplicate().selected 19 | 20 | assert deduplicated_texts == texts 21 | 22 | # With duplicates 23 | texts = [ 24 | "It's dangerous to go alone!", 25 | "It's dangerous to go alone!", # Exact duplicate 26 | "It's not safe to go alone!", # Semantically similar 27 | ] 28 | semhash = SemHash.from_records(records=texts, use_ann=use_ann, model=model) 29 | deduplicated_texts = semhash.self_deduplicate(0.7).selected 30 | assert deduplicated_texts == ["It's dangerous to go alone!"] 31 | 32 | 33 | def test_multi_dataset_deduplication(use_ann: bool, model: Encoder) -> None: 34 | """Test deduplication across two datasets.""" 35 | # No duplicates 36 | texts1 = [ 37 | "It's dangerous to go alone!", 38 | "It's a secret to everybody.", 39 | "Ganondorf has invaded Hyrule!", 40 | ] 41 | texts2 = [ 42 | "Link is the hero of time.", 43 | "Zelda is the princess of Hyrule.", 44 | "Ganon is the king of thieves.", 45 | ] 46 | semhash = SemHash.from_records(texts1, columns=None, use_ann=use_ann, model=model) 47 | deduplicated_texts = semhash.deduplicate(texts2).selected 48 | assert deduplicated_texts == texts2 49 | 50 | # With duplicates 51 | texts2 = [ 52 | "It's dangerous to go alone!", # Exact duplicate 53 | "It's risky to go alone!", # Semantically similar 54 | "Ganondorf has attacked Hyrule!", # Semantically similar 55 | ] 56 | deduplicated_texts = semhash.deduplicate(texts2, threshold=0.7).selected 57 | assert deduplicated_texts == [] 58 | 59 | 60 | def test_single_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None: 61 | """Test single dataset deduplication with multi-column records.""" 62 | records = [ 63 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 64 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": 
"Link"}, # Exact duplicate 65 | { 66 | "question": "Who is the protagonist?", 67 | "context": "In this story, Link is the hero", 68 | "answer": "Link", 69 | }, # Semantically similar 70 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 71 | ] 72 | semhash = SemHash.from_records( 73 | records, 74 | columns=["question", "context", "answer"], 75 | use_ann=use_ann, 76 | model=model, 77 | ) 78 | deduplicated = semhash.self_deduplicate(threshold=0.7) 79 | 80 | assert deduplicated.selected == [ 81 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 82 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 83 | ] 84 | 85 | 86 | def test_multi_dataset_deduplication_multicolumn(use_ann: bool, model: Encoder) -> None: 87 | """Test multi dataset deduplication with multi-column records.""" 88 | train_records = [ 89 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 90 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 91 | ] 92 | test_records = [ 93 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, # Exact duplicate 94 | { 95 | "question": "Who is the princess?", 96 | "context": "Zelda is the princess", 97 | "answer": "Zelda", 98 | }, # Semantically similar 99 | {"question": "What is the villain's name?", "context": "The villain is Ganon", "answer": "Ganon"}, 100 | ] 101 | semhash = SemHash.from_records( 102 | train_records, 103 | columns=["question", "context", "answer"], 104 | use_ann=use_ann, 105 | model=model, 106 | ) 107 | deduplicated = semhash.deduplicate(test_records).selected 108 | assert deduplicated == [ 109 | {"question": "What is the villain's name?", "context": "The villain is Ganon", "answer": "Ganon"} 110 | ] 111 | 112 | 113 | def test_from_records_without_columns(use_ann: bool, model: Encoder) -> None: 114 | 
"""Test fitting without specifying columns.""" 115 | records = [ 116 | {"question": "What is the hero's name?", "context": "The hero is Link", "answer": "Link"}, 117 | {"question": "Who is the princess?", "context": "The princess is Zelda", "answer": "Zelda"}, 118 | ] 119 | with pytest.raises(ValueError): 120 | SemHash.from_records(records, columns=None, use_ann=use_ann, model=model) 121 | 122 | 123 | def test_deduplicate_with_only_exact_duplicates(use_ann: bool, model: Encoder) -> None: 124 | """Test deduplicating with only exact duplicates.""" 125 | texts1 = [ 126 | "It's dangerous to go alone!", 127 | "It's dangerous to go alone!", 128 | "It's dangerous to go alone!", 129 | ] 130 | texts2 = [ 131 | "It's dangerous to go alone!", 132 | "It's dangerous to go alone!", 133 | "It's dangerous to go alone!", 134 | ] 135 | semhash = SemHash.from_records(texts1, use_ann=use_ann, model=model) 136 | deduplicated = semhash.self_deduplicate() 137 | assert deduplicated.selected == ["It's dangerous to go alone!"] 138 | 139 | deduplicated = semhash.deduplicate(texts2) 140 | assert deduplicated.selected == [] 141 | 142 | 143 | def test_self_find_representative(use_ann: bool, model: Encoder, train_texts: list[str]) -> None: 144 | """Test the self_find_representative method.""" 145 | semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model) 146 | result = semhash.self_find_representative( 147 | candidate_limit=5, 148 | selection_size=3, 149 | lambda_param=0.5, 150 | ) 151 | assert len(result.selected) == 3, "Expected 3 representatives" 152 | selected = {r["text"] for r in result.selected} 153 | assert selected == { 154 | "blueberry", 155 | "pineapple", 156 | "grape", 157 | }, "Expected representatives to be blueberry, pineapple, and grape" 158 | 159 | 160 | def test_find_representative(use_ann: bool, model: Encoder, train_texts: list[str], test_texts: list[str]) -> None: 161 | """Test the find_representative method.""" 162 | semhash = 
def test_self_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[str]) -> None:
    """Test the self_filter_outliers method."""
    semhash = SemHash.from_records(records=train_texts, use_ann=use_ann, model=model)
    result = semhash.self_filter_outliers(outlier_percentage=0.1)

    # 10% of the train set should be flagged, and selected + filtered must partition it.
    assert len(result.filtered) == 2, "Expected 2 outliers"
    assert len(result.selected) == len(train_texts) - 2

    outlier_texts = {record["text"] for record in result.filtered}
    assert outlier_texts == {"car", "bicycle"}, "Expected outliers to be car and bicycle"
def test_mmr_invalid_lambda_raises() -> None:
    """Test that invalid lambda values raise ValueError."""
    semhash = SemHash(index=None, model=None, columns=["text"], was_string=True)  # type: ignore
    dummy = FilterResult(selected=["x"], filtered=[], scores_selected=[0.5], scores_filtered=[])
    # lambda_param must lie in [0, 1]; values on either side are rejected.
    for bad_lambda in (-0.1, 1.1):
        with pytest.raises(ValueError):
            semhash._mmr(dummy, candidate_limit=1, selection_size=1, lambda_param=bad_lambda)

4 | SemHash logo
5 | Fast Semantic Text Deduplication & Filtering 6 |

7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 |

15 | Package version 16 | Supported Python versions 17 | 18 | Downloads 19 | 20 | 21 | Codecov 22 | 23 | 24 | Join Discord 25 | 26 | 27 | License - MIT 28 | 29 |

30 | 31 | 32 | 33 | [Quickstart](#quickstart) • 34 | [Main Features](#main-features) • 35 | [Usage](#usage) • 36 | [Benchmarks](#benchmarks) 37 | 38 |
39 | 40 | 41 | SemHash is a lightweight and flexible tool for deduplicating datasets, filtering outliers, and finding representative samples using semantic similarity. It combines fast embedding generation from [Model2Vec](https://github.com/MinishLab/model2vec) with efficient ANN-based similarity search through [Vicinity](https://github.com/MinishLab/vicinity). 42 | 43 | SemHash supports both single-dataset deduplication & filtering (e.g., cleaning up a train set by removing duplicates and outliers) and multi-dataset deduplication & filtering (e.g., ensuring no overlap between a test set and a train set). It works with simple datasets, such as text lists, and more complex ones, like multi-column QA datasets. Additionally, it includes functions to inspect deduplication results, making it easier to understand and refine your data cleaning process. 44 | 45 | ## Quickstart 46 | 47 | Install the package with: 48 | ```bash 49 | pip install semhash 50 | ``` 51 | 52 | Deduplicate a single dataset, filter outliers, and find representative samples with the following code (note: the examples assume you have `datasets` installed, which you can install with `pip install datasets`): 53 | 54 | ```python 55 | from datasets import load_dataset 56 | from semhash import SemHash 57 | 58 | # Load a dataset to deduplicate 59 | texts = load_dataset("ag_news", split="train")["text"] 60 | 61 | # Initialize a SemHash instance 62 | semhash = SemHash.from_records(records=texts) 63 | 64 | # Deduplicate the texts 65 | deduplicated_texts = semhash.self_deduplicate().selected 66 | 67 | # Filter outliers 68 | filtered_texts = semhash.self_filter_outliers().selected 69 | 70 | # Find representative texts 71 | representative_texts = semhash.self_find_representative().selected 72 | ``` 73 | 74 | Or, deduplicate across two datasets, filter outliers, and find representative samples with the following code (e.g., eliminating train/test leakage): 75 | 76 | ```python 77 | from datasets import load_dataset 
78 | from semhash import SemHash 79 | 80 | # Load two datasets to deduplicate 81 | train_texts = load_dataset("ag_news", split="train")["text"] 82 | test_texts = load_dataset("ag_news", split="test")["text"] 83 | 84 | # Initialize a SemHash instance with the training data 85 | semhash = SemHash.from_records(records=train_texts) 86 | 87 | # Deduplicate the test data against the training data, optionally with a specific threshold 88 | deduplicated_test_texts = semhash.deduplicate(records=test_texts, threshold=0.9).selected 89 | 90 | # Filter outliers from the test data against the training data, 91 | # optionally with a specific percentage 92 | filtered_test_texts = semhash.filter_outliers(records=test_texts, outlier_percentage=0.1).selected 93 | 94 | # Find representative texts in the test data against the training data, 95 | # optionally with a specific selection size 96 | representative_test_texts = semhash.find_representative( 97 | records=test_texts, selection_size=10).selected 98 | 99 | 100 | ``` 101 | 102 | Or, deduplicate multi-column dataset, filter outliers, and find representative samples with the following code (e.g., deduplicating a QA dataset): 103 | 104 | ```python 105 | from datasets import load_dataset 106 | from semhash import SemHash 107 | 108 | # Load the dataset 109 | dataset = load_dataset("squad_v2", split="train") 110 | 111 | # Convert the dataset to a list of dictionaries 112 | records = [dict(row) for row in dataset] 113 | 114 | # Initialize SemHash with the columns to deduplicate 115 | semhash = SemHash.from_records(records=records, columns=["question", "context"]) 116 | 117 | # Deduplicate the records 118 | deduplicated_records = semhash.self_deduplicate().selected 119 | 120 | # Filter outliers from the records 121 | filtered_texts = semhash.self_filter_outliers().selected 122 | 123 | # Find representative texts in the records 124 | representative_texts = semhash.self_find_representative().selected 125 | ``` 126 | 127 | The `deduplicate` 
and `self_deduplicate` functions return a [DeduplicationResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L30). This object stores the deduplicated corpus, a set of duplicate objects (along with the objects that caused the duplication), and several useful functions to further inspect the deduplication result. Examples of how these functions can be used can be found in the [usage](#usage) section.

The `filter_outliers`, `self_filter_outliers`, `find_representative`, and `self_find_representative` functions return a [FilterResult](https://github.com/MinishLab/semhash/blob/main/semhash/datamodels.py#L106). This object stores the found outliers/representative samples.

For both the `DeduplicationResult` and `FilterResult` objects, you can easily view the selected records with the `selected` attribute, and the filtered records with the `filtered` attribute (e.g. to view outliers: `outliers = semhash.self_filter_outliers().filtered`).

## Main Features

- **Fast**: SemHash uses [model2vec](https://github.com/MinishLab/model2vec) to embed texts and [vicinity](https://github.com/MinishLab/vicinity) to perform similarity search, making it extremely fast.
- **Scalable**: SemHash can deduplicate & filter large datasets with millions of records thanks to the ANN backends in Vicinity.
- **Flexible**: SemHash can be used to deduplicate & filter a single dataset or across two datasets, and can also be used to deduplicate & filter multi-column datasets (such as QA datasets).
- **Lightweight**: SemHash is a lightweight package with minimal dependencies, making it easy to install and use.
- **Explainable**: Easily inspect the duplicates and what caused them with the `DeduplicationResult` object. You can also view the lowest similarity duplicates to find the right threshold for deduplication for your dataset.
139 | 140 | ## Usage 141 | 142 | The following examples show the various ways you can use SemHash to deduplicate datasets, filter outliers, and find representative samples. These examples assume you have the `datasets` library installed, which you can install with `pip install datasets`. 143 | 144 |
145 | Deduplicate, filter outliers, and find representative samples on a single dataset 146 |
147 | 148 | The following code snippet shows how to deduplicate a single dataset, filter outliers, and find representative samples using SemHash (in this example, the train split of the [AG News dataset](https://huggingface.co/datasets/fancyzhx/ag_news)): 149 | 150 | ```python 151 | from datasets import load_dataset 152 | from semhash import SemHash 153 | 154 | # Load a dataset to deduplicate 155 | texts = load_dataset("ag_news", split="train")["text"] 156 | 157 | # Initialize a SemHash instance 158 | semhash = SemHash.from_records(records=texts) 159 | 160 | # Deduplicate the texts 161 | deduplicated_texts = semhash.self_deduplicate().selected 162 | 163 | # Filter outliers 164 | filtered_texts = semhash.self_filter_outliers().selected 165 | 166 | # Find representative texts 167 | representative_texts = semhash.self_find_representative().selected 168 | ``` 169 |
170 | 171 |
172 | Deduplicate, filter outliers, and find representative samples across two datasets 173 |

The following code snippet shows how to deduplicate across two datasets, filter outliers, and find representative samples using SemHash (in this example, the train/test split of the [AG News dataset](https://huggingface.co/datasets/fancyzhx/ag_news)):

```python
from datasets import load_dataset
from semhash import SemHash

# Load two datasets to deduplicate
train_texts = load_dataset("ag_news", split="train")["text"]
test_texts = load_dataset("ag_news", split="test")["text"]

# Initialize a SemHash instance with the training data
semhash = SemHash.from_records(records=train_texts)

# Deduplicate the test data against the training data
deduplicated_test_texts = semhash.deduplicate(records=test_texts).selected

# Filter outliers from the test data
filtered_test_texts = semhash.filter_outliers(records=test_texts).selected

# Find representative texts in the test data
representative_test_texts = semhash.find_representative(records=test_texts).selected
```

202 | 203 |
204 | Deduplicate, filter outliers, and find representative samples on multi-column datasets 205 |
206 | 207 | The following code snippet shows how to deduplicate multi-column datasets, filter outliers, and find representative samples using SemHash (in this example, the train split of the QA dataset [SQuAD 2.0](https://huggingface.co/datasets/rajpurkar/squad_v2), which consists of questions, contexts, and answers): 208 | 209 | ```python 210 | from datasets import load_dataset 211 | from semhash import SemHash 212 | 213 | # Load the dataset 214 | dataset = load_dataset("squad_v2", split="train") 215 | 216 | # Convert the dataset to a list of dictionaries 217 | records = [dict(row) for row in dataset] 218 | 219 | # Initialize SemHash with the columns to deduplicate 220 | semhash = SemHash.from_records(records=records, columns=["question", "context"]) 221 | 222 | # Deduplicate the records 223 | deduplicated_records = semhash.self_deduplicate().selected 224 | 225 | # Filter outliers from the records 226 | filtered_records = semhash.self_filter_outliers().selected 227 | 228 | # Find representative samples in the records 229 | representative_records = semhash.self_find_representative().selected 230 | ``` 231 | 232 |
233 | 234 |
235 | DeduplicationResult functionality 236 |
237 | 238 | The `DeduplicationResult` object returned by the `deduplicate` and `self_deduplicate` functions contains several useful functions to inspect the deduplication result. The following code snippet shows how to use these functions: 239 | 240 | ```python 241 | from datasets import load_dataset 242 | from semhash import SemHash 243 | 244 | # Load a dataset to deduplicate 245 | texts = load_dataset("ag_news", split="train")["text"] 246 | 247 | # Initialize a SemHash instance 248 | semhash = SemHash.from_records(records=texts) 249 | 250 | # Deduplicate the texts 251 | deduplication_result = semhash.self_deduplicate() 252 | 253 | # Check the deduplicated texts 254 | deduplication_result.selected 255 | # Check the duplicates 256 | deduplication_result.filtered 257 | # See what percentage of the texts were duplicates 258 | deduplication_result.duplicate_ratio 259 | # See what percentage of the texts were exact duplicates 260 | deduplication_result.exact_duplicate_ratio 261 | 262 | # Get the least similar text from the duplicates. This is useful for finding the right threshold for deduplication. 263 | least_similar = deduplication_result.get_least_similar_from_duplicates() 264 | 265 | # Rethreshold the duplicates. This allows you to instantly rethreshold the duplicates with a new threshold without having to re-deduplicate the texts. 266 | deduplication_result.rethreshold(0.95) 267 | 268 | # View selected records along with their duplicates. 269 | # This is the opposite of the `filtered` attribute, which shows for every duplicate the record that caused it. 270 | deduplication_result.selected_with_duplicates 271 | ``` 272 | 273 |
274 | 275 |
276 | Using custom encoders 277 |
278 | 279 | The following code snippet shows how to use a custom encoder with SemHash: 280 | 281 | ```python 282 | from datasets import load_dataset 283 | from model2vec import StaticModel 284 | from semhash import SemHash 285 | 286 | # Load a dataset to deduplicate 287 | texts = load_dataset("ag_news", split="train")["text"] 288 | 289 | # Load an embedding model (in this example, a multilingual model) 290 | model = StaticModel.from_pretrained("minishlab/M2V_multilingual_output") 291 | 292 | # Initialize a SemHash with the model and custom encoder 293 | semhash = SemHash.from_records(records=texts, model=model) 294 | 295 | # Deduplicate the texts 296 | deduplicated_texts = semhash.self_deduplicate() 297 | ``` 298 | 299 | Any encoder can be used that adheres to our [encoder protocol](https://github.com/MinishLab/semhash/blob/main/semhash/utils.py). For example, any [sentence-transformers](https://github.com/UKPLab/sentence-transformers) model can be used as an encoder: 300 | 301 | ```python 302 | from datasets import load_dataset 303 | from semhash import SemHash 304 | from sentence_transformers import SentenceTransformer 305 | 306 | # Load a dataset to deduplicate 307 | texts = load_dataset("ag_news", split="train")["text"] 308 | 309 | # Load a sentence-transformers model 310 | model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") 311 | 312 | # Initialize a SemHash with the model and custom encoder 313 | semhash = SemHash.from_records(records=texts, model=model) 314 | 315 | # Deduplicate the texts 316 | deduplicated_texts = semhash.self_deduplicate() 317 | ``` 318 | 319 |
320 | 321 | 322 | 323 | 324 |
325 | Using custom ANN backends 326 |
327 | 328 | The following code snippet shows how to use a custom ANN backend and custom args with SemHash: 329 | 330 | ```python 331 | from datasets import load_dataset 332 | from semhash import SemHash 333 | from vicinity import Backend 334 | 335 | # Load a dataset to deduplicate 336 | texts = load_dataset("ag_news", split="train")["text"] 337 | 338 | # Initialize a SemHash with the model and custom ann backend and custom args 339 | semhash = SemHash.from_records(records=texts, ann_backend=Backend.FAISS, nlist=50) 340 | 341 | # Deduplicate the texts 342 | deduplicated_texts = semhash.self_deduplicate() 343 | ``` 344 | 345 | For the full list of supported ANN backends and args, see the [Vicinity docs](https://github.com/MinishLab/vicinity/tree/main?tab=readme-ov-file#supported-backends). 346 | 347 |
348 | 349 | 350 |
351 | Using Pandas DataFrames 352 |
353 | 354 | You can easily use Pandas DataFrames with SemHash. The following code snippet shows how to deduplicate a Pandas DataFrame: 355 | 356 | ```python 357 | import pandas as pd 358 | from datasets import load_dataset 359 | from semhash import SemHash 360 | 361 | # Load a dataset as a pandas dataframe 362 | dataframe = load_dataset("ag_news", split="train").to_pandas() 363 | 364 | # Convert the dataframe to a list of dictionaries 365 | dataframe = dataframe.to_dict(orient="records") 366 | 367 | # Initialize a SemHash instance with the columns to deduplicate 368 | semhash = SemHash.from_records(records=dataframe, columns=["text"]) 369 | 370 | # Deduplicate the texts 371 | deduplicated_records = semhash.self_deduplicate().selected 372 | 373 | # Convert the deduplicated records back to a pandas dataframe 374 | deduplicated_dataframe = pd.DataFrame(deduplicated_records) 375 | ``` 376 | 377 |
378 | 379 | NOTE: By default, we use the ANN (approximate-nearest neighbors) backend for deduplication. We recommend keeping this since the recall for smaller datasets is ~100%, and it's needed for larger datasets (>1M samples) since these will take too long to deduplicate without ANN. If you want to use the flat/exact-matching backend, you can set `use_ann=False` in the SemHash constructor: 380 | 381 | ```python 382 | semhash = SemHash.from_records(records=texts, use_ann=False) 383 | ``` 384 | 385 | 386 | 387 | ## Benchmarks 388 | 389 | We've benchmarked SemHash on a variety of datasets to measure the deduplication performance and speed. The benchmarks were run with the following setup: 390 | - The benchmarks were all run on CPU 391 | - The benchmarks were all run with `use_ann=True` 392 | - The used encoder is the default encoder ([potion-base-8M](https://huggingface.co/minishlab/potion-base-8M)). 393 | - The timings include the encoding time, index building time, and deduplication time. 
394 | ### Train Deduplication Benchmark 395 | 396 | | Dataset | Original Train Size | Deduplicated Train Size | % Removed | Deduplication Time (s) | 397 | |----------------------|----------------------|--------------------------|------------|--------------------------| 398 | | bbc | 1225 | 1144 | 6.61 | 0.57 | 399 | | senteval_cr | 3012 | 2990 | 0.73 | 0.14 | 400 | | tweet_sentiment_extraction | 27481 | 26695 | 2.86 | 1.77 | 401 | | emotion | 16000 | 15695 | 1.91 | 0.77 | 402 | | amazon_counterfactual | 5000 | 4992 | 0.16 | 0.33 | 403 | | ag_news | 120000 | 106921 | 10.90 | 5.20 | 404 | | enron_spam | 31716 | 20540 | 35.24 | 2.03 | 405 | | subj | 8000 | 7990 | 0.12 | 0.63 | 406 | | sst5 | 8544 | 8526 | 0.21 | 0.58 | 407 | | 20_newgroups | 11314 | 10684 | 5.57 | 0.73 | 408 | | hatespeech_offensive | 22783 | 22090 | 3.04 | 0.92 | 409 | | ade | 17637 | 15718 | 10.88 | 0.73 | 410 | | imdb | 25000 | 24830 | 0.68 | 1.76 | 411 | | massive_scenario | 11514 | 9366 | 18.66 | 0.47 | 412 | | student | 117519 | 63856 | 45.66 | 8.80 | 413 | | squad_v2 | 130319 | 109698 | 15.82 | 8.81 | 414 | | wikitext | 1801350 | 884645 | 50.89 | 83.53 | 415 | 416 | 417 | ### Train/Test Deduplication Benchmark 418 | 419 | | Dataset | Train Size | Test Size | Deduplicated Test Size | % Removed | Deduplication Time (s) | 420 | |----------------------|--------------|--------------|--------------------------|------------|--------------------------| 421 | | bbc | 1225 | 1000 | 870 | 13.00 | 0.71 | 422 | | senteval_cr | 3012 | 753 | 750 | 0.40 | 0.13 | 423 | | tweet_sentiment_extraction | 27481 | 3534 | 3412 | 3.45 | 1.53 | 424 | | emotion | 16000 | 2000 | 1926 | 3.70 | 0.65 | 425 | | amazon_counterfactual | 5000 | 5000 | 4990 | 0.20 | 0.51 | 426 | | ag_news | 120000 | 7600 | 6198 | 18.45 | 3.74 | 427 | | enron_spam | 31716 | 2000 | 1060 | 47.00 | 1.94 | 428 | | subj | 8000 | 2000 | 1999 | 0.05 | 0.62 | 429 | | sst5 | 8544 | 2210 | 2205 | 0.23 | 0.59 | 430 | | 20_newgroups | 11314 | 7532 | 7098 | 
5.76 | 2.25 | 431 | | hatespeech_offensive | 22783 | 2000 | 1925 | 3.75 | 0.77 | 432 | | ade | 17637 | 5879 | 4952 | 15.77 | 0.81 | 433 | | imdb | 25000 | 25000 | 24795 | 0.82 | 2.81 | 434 | | massive_scenario | 11514 | 2974 | 2190 | 26.36 | 0.46 | 435 | | student | 117519 | 5000 | 2393 | 52.14 | 3.78 | 436 | | squad_v2 | 130319 | 11873 | 11863 | 0.08 | 7.13 | 437 | | wikitext | 1801350 | 4358 | 2139 | 50.92 | 40.32 | 438 | 439 | 440 | As can be seen, SemHash is extremely fast, and scales to large datasets with millions of records. There are some notable examples of train/test leakage, such as `enron_spam` and `student`, where the test dataset contains a significant amount of semantic overlap with the training dataset. 441 | 442 | ### Reproducing the Benchmarks 443 | 444 | To run the benchmarks yourself, you can use the following command (assuming you have the `datasets` library installed): 445 | 446 | ```bash 447 | python -m benchmarks.run_benchmarks 448 | ``` 449 | Optionally, the datasets can be updated in the [data.py](https://github.com/MinishLab/semhash/blob/main/benchmarks/data.py) file. 
from __future__ import annotations

from collections import defaultdict
from math import ceil
from typing import Any, Generic, Literal, Sequence

import numpy as np
from frozendict import frozendict
from model2vec import StaticModel
from vicinity import Backend

from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record
from semhash.index import Index
from semhash.records import add_scores_to_records, map_deduplication_result_to_strings
from semhash.utils import Encoder, compute_candidate_limit, to_frozendict


class SemHash(Generic[Record]):
    """
    Semantic deduplication and filtering over a fitted embedding index.

    A SemHash instance wraps a vector index built from a reference dataset
    (see :meth:`from_records`) and offers deduplication (`deduplicate`,
    `self_deduplicate`), representative selection via MMR
    (`find_representative`, `self_find_representative`) and outlier
    filtering (`filter_outliers`, `self_filter_outliers`).
    """

    def __init__(self, index: Index, model: Encoder, columns: Sequence[str], was_string: bool) -> None:
        """
        Initialize SemHash.

        :param index: An index.
        :param model: A model to use for featurization.
        :param columns: Columns of the records.
        :param was_string: Whether the records were strings. Used for mapping back to strings.
        """
        self.index = index
        self.model = model
        self.columns = columns
        self._was_string = was_string
        # Lazily-populated cache for _self_rank_by_average_similarity results.
        self._ranking_cache: FilterResult | None = None

    @staticmethod
    def _featurize(
        records: Sequence[dict[str, str]],
        columns: Sequence[str],
        model: Encoder,
    ) -> np.ndarray:
        """
        Featurize a list of records using the model.

        Each column is encoded separately and the per-column embeddings are
        concatenated along the feature axis.

        :param records: A list of records.
        :param columns: Columns to featurize.
        :param model: An Encoder model.
        :return: The embeddings of the records.
        """
        # Extract the embeddings for each column across all records
        embeddings_per_col = []
        for col in columns:
            col_texts = [r[col] for r in records]
            col_emb = model.encode(col_texts)
            embeddings_per_col.append(np.asarray(col_emb))

        return np.concatenate(embeddings_per_col, axis=1)

    @classmethod
    def _remove_exact_duplicates(
        cls,
        records: Sequence[dict[str, str]],
        columns: Sequence[str],
        reference_records: list[list[dict[str, str]]] | None = None,
    ) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
        """
        Remove exact duplicates based on the unpacked string representation of each record.

        If reference_records is None, the function will only check for duplicates within the records list.

        :param records: A list of records to check for exact duplicates.
        :param columns: Columns to unpack.
        :param reference_records: A list of records to compare against. These are already unpacked
        :return: A list of deduplicated records and a list of duplicates.
        """
        deduplicated = []
        duplicates = []

        column_set = set(columns)
        # Build a seen set from reference_records if provided
        seen: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list)
        if reference_records is not None:
            for record_set in reference_records:
                key = to_frozendict(record_set[0], column_set)
                seen[key] = list(record_set)
        # Without a reference set we deduplicate records against each other;
        # with one, records are only compared to the reference.
        in_one_set = reference_records is None

        for record in records:
            frozen_record = frozendict({k: v for k, v in record.items() if k in column_set})
            if duplicated_records := seen.get(frozen_record):
                duplicates.append((record, duplicated_records))
            else:
                deduplicated.append(record)
                # Only add current documents to seen if no reference set is used
                if in_one_set:
                    seen[frozen_record].append(record)

        return deduplicated, duplicates

    @classmethod
    def from_records(
        cls,
        records: Sequence[Record],
        columns: Sequence[str] | None = None,
        use_ann: bool = True,
        model: Encoder | None = None,
        ann_backend: Backend | str = Backend.USEARCH,
        **kwargs: Any,
    ) -> SemHash:
        """
        Initialize a SemHash instance from records.

        This removes exact duplicates, featurizes the records, and fits a vicinity index.

        :param records: A list of records (strings or dictionaries).
        :param columns: Columns to featurize if records are dictionaries.
        :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
        :param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M).
        :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
        :param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
        :return: A SemHash instance with a fitted vicinity index.
        :raises ValueError: If columns are not provided for dictionary records.
        """
        # NOTE(review): assumes `records` is non-empty — records[0] raises IndexError otherwise.
        if columns is None and isinstance(records[0], dict):
            raise ValueError("Columns must be specified when passing dictionaries.")

        if isinstance(records[0], str):
            # If records are strings, convert to dictionaries with a single column
            columns = ["text"]
            dict_records: list[dict[str, str]] = [{"text": record} for record in records]
            was_string = True
        else:
            dict_records = list(records)
            was_string = False

        # If no model is provided, load the default model
        if model is None:
            model = StaticModel.from_pretrained("minishlab/potion-base-8M")

        # Remove exact duplicates
        deduplicated_records, duplicates = cls._remove_exact_duplicates(dict_records, columns)

        # Group the exact duplicates under the frozen key of their kept record.
        col_set = set(columns)
        duplicate_map = defaultdict(list)
        for x, _ in duplicates:
            frozen_record = to_frozendict(x, col_set)
            duplicate_map[frozen_record].append(x)

        # Each index item is a cluster: the kept record followed by its exact duplicates.
        items: list[list[dict[str, str]]] = []
        for record in deduplicated_records:
            i = [record]
            # NOTE(review): set(columns) here equals col_set computed above.
            frozen_record = to_frozendict(record, set(columns))
            i.extend(duplicate_map[frozen_record])
            items.append(i)

        # Create embeddings and unpack records
        embeddings = cls._featurize(deduplicated_records, columns, model)

        # Build the Vicinity index
        backend = ann_backend if use_ann else Backend.BASIC
        index = Index.from_vectors_and_items(
            vectors=embeddings,
            items=items,
            backend_type=backend,
            **kwargs,
        )

        return cls(index=index, columns=columns, model=model, was_string=was_string)

    def deduplicate(
        self,
        records: Sequence[Record],
        threshold: float = 0.9,
    ) -> DeduplicationResult:
        """
        Perform deduplication against the fitted index.

        This method assumes you have already fit on a reference dataset (e.g., a train set) with from_records.
        It will remove any items from 'records' that are similar above a certain threshold
        to any item in the fitted dataset.

        :param records: A new set of records (e.g., test set) to deduplicate against the fitted dataset.
        :param threshold: Similarity threshold for deduplication.
        :return: A DeduplicationResult holding the deduplicated records and the detected duplicates.
        """
        dict_records = self._validate_if_strings(records)

        # Remove exact duplicates before embedding
        dict_records, exact_duplicates = self._remove_exact_duplicates(
            records=dict_records, columns=self.columns, reference_records=self.index.items
        )
        duplicate_records = []
        for record, duplicates in exact_duplicates:
            # Exact matches get synthetic scores attached so they share the
            # same shape as near-duplicate results.
            duplicated_with_score = add_scores_to_records(duplicates)
            duplicate_record = DuplicateRecord(record=record, duplicates=duplicated_with_score, exact=True)
            duplicate_records.append(duplicate_record)

        # If no records are left after removing exact duplicates, return early
        if not dict_records:
            return DeduplicationResult(
                deduplicated=[], duplicates=duplicate_records, threshold=threshold, columns=self.columns
            )

        # Compute embeddings for the new records
        embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model)
        # Query the fitted index
        results = self.index.query_threshold(embeddings, threshold=threshold)

        deduplicated_records = []
        for record, similar_items in zip(dict_records, results):
            if not similar_items:
                # No duplicates found, keep this record
                deduplicated_records.append(record)
            else:
                duplicate_records.append(
                    DuplicateRecord(
                        record=record,
                        duplicates=[(item, score) for item, score in similar_items],
                        exact=False,
                    )
                )

        result = DeduplicationResult(
            deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns
        )

        if self._was_string:
            # Convert records back to strings if the records were originally strings
            return map_deduplication_result_to_strings(result, columns=self.columns)

        return result

    def self_deduplicate(
        self,
        threshold: float = 0.9,
    ) -> DeduplicationResult:
        """
        Deduplicate within the same dataset. This can be used to remove duplicates from a single dataset.

        :param threshold: Similarity threshold for deduplication.
        :return: A DeduplicationResult holding the deduplicated records and the detected duplicates.
        """
        # Query the fitted index
        results = self.index.query_threshold(self.index.vectors, threshold=threshold)
        column_set = set(self.columns)

        duplicate_records = []

        deduplicated_records = []
        seen_items: set[frozendict[str, str]] = set()
        for item, similar_items in zip(self.index.items, results):
            # Item is a list of records which are exact duplicates of each other.
            # So if an item has more than one record, it contains exact duplicates.
            # Crucially, we should count each instance separately.
            record, *duplicates = item
            # We need to compare all duplicates to all _items_.
            # The first item in a list of duplicates is not duplicated, because otherwise
            # we would remove the whole cluster. But it is a duplicate for the other items.

            # Iterate from index 1.
            for index, curr_record in enumerate(duplicates, 1):
                # The use of indexing is intentional here: we want to exclude this
                # exact object (positionally), not any record that merely has the
                # same values. Comparing with != or is could drop many items.
                items_to_keep = item[:index] + item[index + 1 :]
                items_with_score = add_scores_to_records(items_to_keep)
                duplicate_records.append(DuplicateRecord(record=curr_record, duplicates=items_with_score, exact=True))

            # If we don't see any similar_items, we know the record is not a duplicate.
            # in rare cases, the item itself might not be a duplicate of itself.
            if not similar_items:
                deduplicated_records.append(record)
                continue
            items, _ = zip(*similar_items)
            frozen_items = [to_frozendict(item, column_set) for item in items]
            # similar_items includes 'record' itself
            # If we've seen any of these items before, this is a duplicate cluster.
            if any(item in seen_items for item in frozen_items):
                duplicate_records.append(
                    DuplicateRecord(
                        record=record,
                        duplicates=[(item, score) for item, score in similar_items if item != record],
                        exact=False,
                    )
                )
                continue
            # This is the first time we see this cluster of similar items
            deduplicated_records.append(record)
            # Mark all items in this cluster as seen
            seen_items.update(frozen_items)

        result = DeduplicationResult(
            deduplicated=deduplicated_records, duplicates=duplicate_records, threshold=threshold, columns=self.columns
        )

        if self._was_string:
            # Convert records back to strings if the records were originally strings
            return map_deduplication_result_to_strings(result, columns=self.columns)

        return result

    def _validate_if_strings(self, records: Sequence[dict[str, str] | str]) -> Sequence[dict[str, str]]:
        """
        Validate if the records are strings.

        If the records are strings, they are converted to dictionaries with a single column.

        :param records: The records to validate. Assumed non-empty (records[0] is inspected).
        :return: The records as a list of dictionaries.
        :raises ValueError: If the records are strings but were not originally strings.
        :raises ValueError: If the records are not all strings or dictionaries.
        """
        if isinstance(records[0], str):
            if not self._was_string:
                raise ValueError("Records were not originally strings, but you passed strings.")
            dict_records = [{"text": record} for record in records if isinstance(record, str)]
        else:
            dict_records = [record for record in records if isinstance(record, dict)]
            # A length mismatch means the sequence mixed dicts with non-dicts.
            if len(dict_records) != len(records):
                raise ValueError("Records must be either strings or dictionaries.")
        return dict_records

    def find_representative(
        self,
        records: Sequence[Record],
        selection_size: int = 10,
        candidate_limit: int | Literal["auto"] = "auto",
        lambda_param: float = 0.5,
    ) -> FilterResult:
        """
        Find representative samples from a given set of records against the fitted index.

        First, the records are ranked using average similarity.
        Then, the top candidates are re-ranked using Maximal Marginal Relevance (MMR)
        to select a diverse set of representatives.

        :param records: The records to rank and select representatives from.
        :param selection_size: Number of representatives to select.
        :param candidate_limit: Number of top candidates to consider for MMR reranking.
            Defaults to "auto", which calculates the limit based on the total number of records.
        :param lambda_param: Trade-off parameter between relevance (1.0) and diversity (0.0). Must be between 0 and 1.
        :return: A FilterResult with the diversified candidates.
        """
        ranking = self._rank_by_average_similarity(records)
        if candidate_limit == "auto":
            candidate_limit = compute_candidate_limit(total=len(ranking.selected), selection_size=selection_size)
        return self._mmr(ranking, candidate_limit, selection_size, lambda_param)

    def self_find_representative(
        self,
        selection_size: int = 10,
        candidate_limit: int | Literal["auto"] = "auto",
        lambda_param: float = 0.5,
    ) -> FilterResult:
        """
        Find representative samples from the fitted dataset.

        First, the records are ranked using average similarity.
        Then, the top candidates are re-ranked using Maximal Marginal Relevance (MMR)
        to select a diverse set of representatives.

        :param selection_size: Number of representatives to select.
        :param candidate_limit: Number of top candidates to consider for MMR reranking.
            Defaults to "auto", which calculates the limit based on the total number of records.
        :param lambda_param: Trade-off parameter between relevance (1.0) and diversity (0.0). Must be between 0 and 1.
        :return: A FilterResult with the diversified representatives.
        """
        ranking = self._self_rank_by_average_similarity()
        if candidate_limit == "auto":
            candidate_limit = compute_candidate_limit(total=len(ranking.selected), selection_size=selection_size)
        return self._mmr(ranking, candidate_limit, selection_size, lambda_param)

    def filter_outliers(
        self,
        records: Sequence[Record],
        outlier_percentage: float = 0.1,
    ) -> FilterResult:
        """
        Filter outliers in a given set of records against the fitted dataset.

        This method ranks the records by their average similarity and filters the bottom
        outlier_percentage of records as outliers.

        :param records: A sequence of records to find outliers in.
        :param outlier_percentage: The percentage (between 0 and 1) of records to consider outliers.
        :return: A FilterResult where 'selected' contains the inliers and 'filtered' contains the outliers.
        :raises ValueError: If outlier_percentage is not between 0 and 1.
        """
        if outlier_percentage < 0.0 or outlier_percentage > 1.0:
            raise ValueError("outlier_percentage must be between 0 and 1")
        ranking = self._rank_by_average_similarity(records)
        # ceil so any nonzero percentage flags at least one record.
        outlier_count = ceil(len(ranking.selected) * outlier_percentage)
        if outlier_count == 0:
            # If the outlier count is 0, return an empty selection
            return FilterResult(
                selected=[], filtered=ranking.selected, scores_selected=[], scores_filtered=ranking.scores_selected
            )

        # Ranking is sorted by descending similarity, so the tail holds the outliers.
        outlier_records = ranking.selected[-outlier_count:]
        outlier_scores = ranking.scores_selected[-outlier_count:]
        inlier_records = ranking.selected[:-outlier_count]
        inlier_scores = ranking.scores_selected[:-outlier_count]

        return FilterResult(
            selected=inlier_records,
            filtered=outlier_records,
            scores_selected=inlier_scores,
            scores_filtered=outlier_scores,
        )

    def self_filter_outliers(
        self,
        outlier_percentage: float = 0.1,
    ) -> FilterResult:
        """
        Filter outliers in the fitted dataset.

        The method ranks the records stored in the index and filters the bottom outlier_percentage
        of records as outliers.

        :param outlier_percentage: The percentage (between 0 and 1) of records to consider as outliers.
        :return: A FilterResult where 'selected' contains the inliers and 'filtered' contains the outliers.
        :raises ValueError: If outlier_percentage is not between 0 and 1.
        """
        if outlier_percentage < 0.0 or outlier_percentage > 1.0:
            raise ValueError("outlier_percentage must be between 0 and 1")
        ranking = self._self_rank_by_average_similarity()
        # ceil so any nonzero percentage flags at least one record.
        outlier_count = ceil(len(ranking.selected) * outlier_percentage)
        if outlier_count == 0:
            # If the outlier count is 0, return an empty selection
            return FilterResult(
                selected=[], filtered=ranking.selected, scores_selected=[], scores_filtered=ranking.scores_selected
            )

        # Ranking is sorted by descending similarity, so the tail holds the outliers.
        outlier_records = ranking.selected[-outlier_count:]
        outlier_scores = ranking.scores_selected[-outlier_count:]
        inlier_records = ranking.selected[:-outlier_count]
        inlier_scores = ranking.scores_selected[:-outlier_count]

        return FilterResult(
            selected=inlier_records,
            filtered=outlier_records,
            scores_selected=inlier_scores,
            scores_filtered=outlier_scores,
        )

    def _rank_by_average_similarity(
        self,
        records: Sequence[Record],
    ) -> FilterResult:
        """
        Rank a given set of records based on the average cosine similarity of the neighbors in the fitted index.

        :param records: A sequence of records.
        :return: A FilterResult containing the ranking (records sorted and their average similarity scores).
        """
        dict_records = self._validate_if_strings(records)
        embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model)
        # k=100 neighbors per record; the average over them is the relevance score.
        results = self.index.query_top_k(embeddings, k=100, vectors_are_in_index=False)

        # Compute the average similarity for each record.
        sorted_scores = sorted(
            ((record, np.mean(sims)) for record, (_, sims) in zip(dict_records, results)),
            key=lambda x: x[1],
            reverse=True,
        )
        selected, scores_selected = zip(*sorted_scores)

        return FilterResult(
            selected=list(selected),
            filtered=[],
            scores_selected=list(scores_selected),
            scores_filtered=[],
        )

    def _self_rank_by_average_similarity(
        self,
    ) -> FilterResult:
        """
        Rank the records stored in the fitted index based on the average cosine similarity of the neighbors.

        The result is cached on the instance, since the index does not change after fitting.

        :return: A FilterResult containing the ranking.
        """
        if self._ranking_cache is not None:
            return self._ranking_cache

        # Each index item is a cluster; rank on its first (kept) record.
        dict_records = [record[0] for record in self.index.items]
        results = self.index.query_top_k(self.index.vectors, k=100, vectors_are_in_index=True)

        # Compute the average similarity for each record.
        sorted_scores = sorted(
            ((record, np.mean(sims)) for record, (_, sims) in zip(dict_records, results)),
            key=lambda x: x[1],
            reverse=True,
        )
        selected, scores_selected = zip(*sorted_scores)

        ranking = FilterResult(
            selected=list(selected),
            filtered=[],
            scores_selected=list(scores_selected),
            scores_filtered=[],
        )
        self._ranking_cache = ranking
        return ranking

    def _mmr(
        self,
        ranked_results: FilterResult,
        candidate_limit: int,
        selection_size: int,
        lambda_param: float,
    ) -> FilterResult:
        """
        Perform Maximal Marginal Relevance (MMR) re-ranking on the top candidates from a FilterResult.

        This function first slices the ranking, then computes embeddings for the candidates,
        normalizes them, and finally performs MMR re-ranking to obtain a diverse subset of representatives.

        :param ranked_results: A FilterResult containing the ranking of records.
        :param candidate_limit: Number of top candidates to consider from the ranking.
        :param selection_size: Number of final candidates to select using MMR.
        :param lambda_param: Weight balancing relevance and diversity (between 0 and 1).
        :return: A FilterResult containing the selected (diversified) representatives along with
            their MMR scores in `scores_selected`. The remaining candidates are placed in filtered.
        :raises ValueError: If lambda_param is not between 0 and 1.
        """
        if not (0.0 <= lambda_param <= 1.0):
            raise ValueError("lambda_param must be between 0 and 1")

        # Slice the top candidates from the ranking.
        candidate_records = ranked_results.selected[:candidate_limit]
        candidate_relevance = ranked_results.scores_selected[:candidate_limit]

        # Compute embeddings for candidate records.
        embeddings = self._featurize(records=candidate_records, columns=self.columns, model=self.model)

        # Normalize embeddings for cosine similarity (epsilon guards against zero vectors).
        normalized_embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-16)

        # Package candidates as tuples: (record, baseline_relevance, normalized_embedding)
        candidates = [
            (record, relevance, normalized_embedding)
            for record, relevance, normalized_embedding in zip(
                candidate_records, candidate_relevance, normalized_embeddings
            )
        ]

        # If no candidates, return empty result.
        if not candidates:
            return FilterResult(selected=[], filtered=[], scores_selected=[], scores_filtered=[])

        # Initialize selected set with the most relevant candidate.
        first_record, first_relevance, first_embedding = candidates[0]
        selected_records = [first_record]
        selected_scores = [first_relevance]
        selected_embeddings = [first_embedding]
        remaining_candidates = candidates[1:]

        # Iteratively select candidates using the MMR criterion.
        while remaining_candidates and len(selected_records) < selection_size:
            # Build arrays for the remaining candidates.
            embeddings_remaining = np.vstack([emb for (_, _, emb) in remaining_candidates])
            relevances_remaining = np.array([rel for (_, rel, _) in remaining_candidates])

            # Build array of embeddings for the already selected set.
            embeddings_selected = np.vstack(selected_embeddings)

            # Compute cosine similarities between selected and remaining candidates.
            similarity_matrix = embeddings_remaining.dot(embeddings_selected.T)
            max_similarity = similarity_matrix.max(axis=1)

            # Compute MMR scores for all remaining candidates:
            # high relevance is rewarded, high similarity to the selected set is penalized.
            mmr_scores = lambda_param * relevances_remaining - (1.0 - lambda_param) * max_similarity

            # Choose the candidate with the highest MMR score.
            best_index = int(np.argmax(mmr_scores))
            record, _, embedding = remaining_candidates.pop(best_index)

            selected_records.append(record)
            selected_scores.append(mmr_scores[best_index])
            selected_embeddings.append(embedding)

        # Whatever is left in remaining_candidates is filtered out.
        filtered_records = [rec for (rec, _, _) in remaining_candidates]
        filtered_scores = [rel for (_, rel, _) in remaining_candidates]

        return FilterResult(
            selected=selected_records,
            filtered=filtered_records,
            scores_selected=selected_scores,
            scores_filtered=filtered_scores,
        )