├── perception ├── py.typed ├── __init__.py ├── testing │ ├── videos │ │ ├── rgb.m4v │ │ ├── v1.m4v │ │ ├── v2.m4v │ │ ├── v2s.mov │ │ ├── expected_tmk.json.gz │ │ └── README.md │ ├── images │ │ ├── image1.jpg │ │ ├── image2.jpg │ │ ├── image3.jpg │ │ ├── image4.jpg │ │ ├── image5.jpg │ │ ├── image6.jpg │ │ ├── image7.jpg │ │ ├── image8.jpg │ │ ├── image9.jpg │ │ ├── image10.jpg │ │ └── README.md │ ├── logos │ │ ├── logoipsum.png │ │ └── README.md │ └── __init__.py ├── utils.py ├── hashers │ ├── video │ │ ├── __init__.py │ │ ├── framewise.py │ │ └── tmk.py │ ├── image │ │ ├── __init__.py │ │ ├── dhash.py │ │ ├── pdq.py │ │ ├── average.py │ │ ├── opencv.py │ │ ├── wavelet.py │ │ └── phash.py │ ├── __init__.py │ └── hasher.py ├── benchmarking │ ├── __init__.py │ ├── image_transforms.py │ ├── extensions.pyx │ ├── image.py │ ├── video_transforms.py │ └── video.py ├── approximate_deduplication │ ├── serve.py │ ├── debug.py │ └── __init__.py └── extensions.pyx ├── .dockerignore ├── .gitattributes ├── poetry.toml ├── .git-blame-ignore-revs ├── docs ├── api │ ├── tools.rst │ ├── index.rst │ ├── experimental.rst │ ├── benchmarking.rst │ └── hashers.rst ├── examples │ ├── index.rst │ ├── detecting_csam.rst │ └── deduplication.rst ├── requirements.txt ├── index.rst └── conf.py ├── tests ├── images │ ├── chair.png │ ├── chair3.png │ ├── chair-tall.png │ └── chair-square.png ├── test_tmk.py ├── test_hashers.py ├── test_benchmarking.py ├── test_tools.py └── test_local_descriptor_deduplication.py ├── .github ├── dependabot.yaml └── workflows │ ├── ci.yaml │ ├── codeql-analysis.yml │ └── release.yaml ├── MANIFEST.in ├── .readthedocs.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── Makefile ├── pyproject.toml ├── CHANGELOG.md ├── README.md ├── CODE_OF_CONDUCT.md └── LICENSE /perception/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | notebooks 2 | .venv/ 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | perception/_version.py export-subst 2 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = true 3 | in-project = true 4 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Format with black 2 | 6c03f96a9335e548685ece233474125fe453c262 -------------------------------------------------------------------------------- /docs/api/tools.rst: -------------------------------------------------------------------------------- 1 | 2 | Tools 3 | ***** 4 | 5 | 6 | .. 
automodule:: perception.tools 7 | :members: -------------------------------------------------------------------------------- /tests/images/chair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/tests/images/chair.png -------------------------------------------------------------------------------- /tests/images/chair3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/tests/images/chair3.png -------------------------------------------------------------------------------- /perception/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import metadata 2 | 3 | __version__ = metadata.version("perception") 4 | -------------------------------------------------------------------------------- /tests/images/chair-tall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/tests/images/chair-tall.png -------------------------------------------------------------------------------- /tests/images/chair-square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/tests/images/chair-square.png -------------------------------------------------------------------------------- /perception/testing/videos/rgb.m4v: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/videos/rgb.m4v -------------------------------------------------------------------------------- /perception/testing/videos/v1.m4v: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/videos/v1.m4v -------------------------------------------------------------------------------- /perception/testing/videos/v2.m4v: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/videos/v2.m4v -------------------------------------------------------------------------------- /perception/testing/videos/v2s.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/videos/v2s.mov -------------------------------------------------------------------------------- /perception/utils.py: -------------------------------------------------------------------------------- 1 | def flatten(list_of_lists): 2 | return [item for sublist in list_of_lists for item in sublist] 3 | -------------------------------------------------------------------------------- /perception/testing/images/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image1.jpg -------------------------------------------------------------------------------- /perception/testing/images/image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image2.jpg 
-------------------------------------------------------------------------------- /perception/testing/images/image3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image3.jpg -------------------------------------------------------------------------------- /perception/testing/images/image4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image4.jpg -------------------------------------------------------------------------------- /perception/testing/images/image5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image5.jpg -------------------------------------------------------------------------------- /perception/testing/images/image6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image6.jpg -------------------------------------------------------------------------------- /perception/testing/images/image7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image7.jpg -------------------------------------------------------------------------------- /perception/testing/images/image8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image8.jpg -------------------------------------------------------------------------------- /perception/testing/images/image9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image9.jpg -------------------------------------------------------------------------------- /perception/testing/images/image10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/images/image10.jpg -------------------------------------------------------------------------------- /perception/testing/logos/logoipsum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/logos/logoipsum.png -------------------------------------------------------------------------------- /perception/testing/logos/README.md: -------------------------------------------------------------------------------- 1 | # Sample Logos 2 | These logos were obtained from free sources. 
3 | 4 | - [LogoIpsum](https://logoipsum.com/) -------------------------------------------------------------------------------- /perception/testing/videos/expected_tmk.json.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thorn-oss/perception/HEAD/perception/testing/videos/expected_tmk.json.gz -------------------------------------------------------------------------------- /perception/hashers/video/__init__.py: -------------------------------------------------------------------------------- 1 | from .framewise import FramewiseHasher 2 | from .tmk import TMKL1, TMKL2 3 | 4 | __all__ = ["FramewiseHasher", "TMKL1", "TMKL2"] 5 | -------------------------------------------------------------------------------- /docs/api/index.rst: -------------------------------------------------------------------------------- 1 | API 2 | *** 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | hashers 9 | benchmarking 10 | tools 11 | experimental -------------------------------------------------------------------------------- /docs/examples/index.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ******** 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | deduplication 9 | detecting_csam 10 | benchmarking -------------------------------------------------------------------------------- /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | # Check for updates to GitHub Actions every week. 7 | interval: "weekly" 8 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include perception/testing/images/* 2 | include perception/testing/videos/* 3 | include perception/testing/logos/* 4 | include perception/**/*.pyx 5 | include perception/*.pyx 6 | include perception/py.typed 7 | exclude tests/* 8 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-autodoc-typehints==1.6.0 2 | sphinx-autobuild==0.7.1 3 | sphinx==1.8.3 4 | sphinx_rtd_theme==0.4.3 5 | m2r==0.2.1 6 | opencv-contrib-python-headless 7 | tqdm 8 | imgaug 9 | ffmpeg-python 10 | typing-extensions 11 | faiss-cpu 12 | aiohttp 13 | python-json-logger 14 | networkit 15 | -------------------------------------------------------------------------------- /perception/testing/videos/README.md: -------------------------------------------------------------------------------- 1 | Video from https://www.youtube.com/watch?v=84Er4LnWXtI under Creative Commons Attribution License. 2 | 3 | Notes 4 | - v1 is a fairly short, slow moving video 5 | - v2 is a longer but faster-paced video 6 | - v2s is the same as v2 but with a snippet removed in the middle (simulates a scene or cut) -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | # Build documentation in the docs/ directory with Sphinx 4 | sphinx: 5 | configuration: docs/conf.py 6 | 7 | formats: all 8 | 9 | # Installs the package and the docs requirements. 
10 | python: 11 | version: 3.9 12 | install: 13 | - requirements: docs/requirements.txt 14 | - method: pip 15 | path: . 16 | system_packages: true 17 | -------------------------------------------------------------------------------- /perception/hashers/image/__init__.py: -------------------------------------------------------------------------------- 1 | from .average import AverageHash 2 | from .dhash import DHash 3 | from .opencv import BlockMean, ColorMoment, MarrHildreth 4 | from .phash import PHash, PHashF, PHashU8 5 | from .wavelet import WaveletHash 6 | 7 | __all__ = [ 8 | "AverageHash", 9 | "PHash", 10 | "WaveletHash", 11 | "MarrHildreth", 12 | "BlockMean", 13 | "ColorMoment", 14 | "DHash", 15 | "PHashF", 16 | "PHashU8", 17 | ] 18 | -------------------------------------------------------------------------------- /docs/api/experimental.rst: -------------------------------------------------------------------------------- 1 | Experimental 2 | ************ 3 | 4 | This module contains experimental functionality that may not be ready for production use. 5 | 6 | Approximate Nearest Neighbors 7 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 8 | 9 | .. automodule:: perception.experimental.ann 10 | :members: 11 | :imported-members: 12 | 13 | Local Descriptor Deduplication 14 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 15 | 16 | .. automodule:: perception.experimental.local_descriptor_deduplication 17 | :members: deduplicate, validate_match, pairs_to_clusters 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # MacOS stuff 2 | .DS_Store 3 | 4 | # Python artifacts 5 | *.egg-info 6 | 7 | # Cache 8 | .mypy_cache 9 | .pytest_cache 10 | __pycache__ 11 | .ipynb_checkpoints 12 | dist 13 | 14 | # Any temporary images or CSV files 15 | notebooks 16 | 17 | # Local environment 18 | .venv 19 | .python-version 20 | 21 | # Coverage file 22 | .coverage 23 | 24 | # Versioneer artifacts 25 | /versioneer.pyc 26 | 27 | # Build artifacts 28 | /build 29 | 30 | # Docs build artifacts 31 | /docs/_build 32 | 33 | # Remove .vscode folder 34 | .vscode 35 | 36 | # Extension artifacts 37 | *.c 38 | *.cpp 39 | *.so 40 | debug-image* 41 | -------------------------------------------------------------------------------- /perception/benchmarking/__init__.py: -------------------------------------------------------------------------------- 1 | from perception.benchmarking import video_transforms 2 | from perception.benchmarking import video 3 | from perception.benchmarking import image 4 | from perception.benchmarking.image import ( 5 | BenchmarkImageDataset, 6 | BenchmarkImageTransforms, 7 | ) 8 | from perception.benchmarking.video import ( 9 | BenchmarkVideoDataset, 10 | BenchmarkVideoTransforms, 11 | ) 12 | from perception.benchmarking.common import BenchmarkHashes 13 | 14 | __all__ = [ 15 | "BenchmarkImageDataset", 16 | "BenchmarkImageTransforms", 17 | "BenchmarkVideoDataset", 18 | "BenchmarkVideoTransforms", 19 | "BenchmarkHashes", 20 | "video_transforms", 21 | "video", 22 | "image", 23 | ] 24 | -------------------------------------------------------------------------------- /perception/hashers/__init__.py: -------------------------------------------------------------------------------- 1 | from .hasher import ImageHasher, VideoHasher 2 | from .image.average import AverageHash 3 | from .image.dhash import DHash 4 | from .image.opencv import BlockMean, ColorMoment, MarrHildreth 5 | from .image.phash import PHash, 
PHashF, PHashU8 6 | from .image.wavelet import WaveletHash 7 | from .video.framewise import FramewiseHasher 8 | from .video.tmk import TMKL1, TMKL2 9 | 10 | 11 | __all__ = [ 12 | "ImageHasher", 13 | "VideoHasher", 14 | "AverageHash", 15 | "PHash", 16 | "WaveletHash", 17 | "MarrHildreth", 18 | "BlockMean", 19 | "ColorMoment", 20 | "DHash", 21 | "FramewiseHasher", 22 | "TMKL1", 23 | "TMKL2", 24 | "PHashU8", 25 | "PHashF", 26 | ] 27 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | - repo: https://github.com/psf/black 12 | rev: 24.8.0 13 | hooks: 14 | - id: black 15 | language_version: python3 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | # Ruff version. 18 | rev: v0.3.0 19 | hooks: 20 | # Run the linter. 21 | - id: ruff 22 | args: [ --fix ] 23 | - repo: https://github.com/pre-commit/mirrors-mypy 24 | rev: v1.8.0 25 | hooks: 26 | - id: mypy 27 | -------------------------------------------------------------------------------- /docs/api/benchmarking.rst: -------------------------------------------------------------------------------- 1 | Benchmarking 2 | ************ 3 | 4 | .. autoclass:: perception.benchmarking.BenchmarkImageDataset 5 | :members: 6 | :inherited-members: 7 | 8 | .. autoclass:: perception.benchmarking.BenchmarkImageTransforms 9 | :members: 10 | :inherited-members: 11 | 12 | .. autoclass:: perception.benchmarking.BenchmarkVideoDataset 13 | :members: 14 | :inherited-members: 15 | 16 | .. autoclass:: perception.benchmarking.BenchmarkVideoTransforms 17 | :members: 18 | :inherited-members: 19 | 20 | .. autoclass:: perception.benchmarking.BenchmarkHashes 21 | :members: 22 | :inherited-members: 23 | 24 | Video Transforms 25 | ================ 26 | 27 | Transforming videos can be more complex, so we provide the following 28 | tools for transforming videos. 29 | 30 | .. automodule:: perception.benchmarking.video_transforms 31 | :members: get_simple_transform, get_black_frame_padding_transform, get_slideshow_transform 32 | -------------------------------------------------------------------------------- /perception/hashers/image/dhash.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | from ..hasher import ImageHasher 4 | 5 | 6 | class DHash(ImageHasher): 7 | """A hash based on the differences between adjacent pixels. 8 | Implementation based on that of 9 | `ImageHash `_. 10 | """ 11 | 12 | dtype = "bool" 13 | distance_metric = "hamming" 14 | 15 | def __init__(self, hash_size=8): 16 | assert hash_size > 1, "Hash size must be greater than 1." 
17 | self.hash_size = hash_size 18 | self.hash_length = hash_size * hash_size 19 | 20 | def _compute(self, image): 21 | image = cv2.resize( 22 | image, 23 | dsize=(self.hash_size + 1, self.hash_size), 24 | interpolation=cv2.INTER_AREA, 25 | ) 26 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 27 | previous = image[:, :-1] 28 | current = image[:, 1:] 29 | difference = previous > current 30 | return difference.flatten() 31 | -------------------------------------------------------------------------------- /docs/examples/detecting_csam.rst: -------------------------------------------------------------------------------- 1 | Detecting Child Sexual Abuse Material 2 | ************************************* 3 | 4 | Using `perception` and a subscription to Thorn's Safer service, 5 | you can easily check for child sexual abuse material against a database of known bad content 6 | **without** having to send any images to a third party. You do this by sending compact, irreversible 7 | image hashes to get matches with a high degree of precision. We support matching using 8 | 16x16 PHash hashes and md5 hashes. 9 | 10 | See usage example below. Please contact info@getsafer.io to discuss Thorn's Safer service 11 | and subscription options and visit `getsafer.io `_ to learn more. 12 | 13 | .. code-block:: python 14 | 15 | from perception import tools 16 | matcher = tools.SaferMatcher( 17 | api_key='YOUR_API_KEY', 18 | url='MATCHING_SERVICE_URL' 19 | ) 20 | matches = matcher.match(['myfile.jpg']) 21 | 22 | In some cases, you may have a username/password instead of an API key, in which case 23 | you can pass those instead (see API documentation for details). -------------------------------------------------------------------------------- /docs/api/hashers.rst: -------------------------------------------------------------------------------- 1 | Hashers 2 | ******* 3 | 4 | All hashers from the :code:`Hasher` class. 5 | 6 | .. autoclass:: perception.hashers.hasher.Hasher 7 | :members: 8 | 9 | Images 10 | ~~~~~~ 11 | 12 | All image hashers inherit from the :code:`ImageHasher` class. 13 | 14 | .. autoclass:: perception.hashers.hasher.ImageHasher 15 | :members: 16 | 17 | The following image hash functions are included in the package. 18 | 19 | .. automodule:: perception.hashers.image 20 | :members: 21 | :imported-members: 22 | 23 | 24 | Videos 25 | ~~~~~~ 26 | 27 | All video hashers inherit from the :code:`VideoHasher` class. 28 | 29 | .. autoclass:: perception.hashers.hasher.VideoHasher 30 | :members: 31 | 32 | The following video hash functions are included in the package. 33 | 34 | .. automodule:: perception.hashers.video 35 | :members: 36 | :imported-members: 37 | 38 | Tools 39 | ~~~~~ 40 | 41 | These utility functions are only used by the hashers but are documented 42 | here for completeness. 43 | 44 | .. automodule:: perception.hashers.tools 45 | :members: -------------------------------------------------------------------------------- /perception/hashers/image/pdq.py: -------------------------------------------------------------------------------- 1 | import pdqhash 2 | 3 | from ..hasher import ImageHasher 4 | 5 | 6 | class PDQHash(ImageHasher): 7 | """The Facebook PDQ hash. Based on the original implementation located at 8 | the `official repository `_. 
9 | """ 10 | 11 | distance_metric = "hamming" 12 | dtype = "bool" 13 | hash_length = 256 14 | 15 | def _compute(self, image): 16 | return pdqhash.compute(image)[0] > 0 17 | 18 | def _compute_with_quality(self, image): 19 | hash_vector, quality = pdqhash.compute(image) 20 | return hash_vector > 0, quality 21 | 22 | def _compute_isometric(self, image): 23 | hash_vectors, _ = pdqhash.compute_dihedral(image) 24 | names = ["r0", "r90", "r180", "r270", "fv", "fh", "r90fv", "r90fh"] 25 | return dict(zip(names, hash_vectors)) 26 | 27 | 28 | class PDQHashF(PDQHash): 29 | dtype = "float32" 30 | distance_metric = "euclidean" 31 | hash_length = 256 32 | 33 | def _compute(self, image): 34 | return pdqhash.compute_float(image)[0] 35 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | TEST_SCOPE?=tests/ 2 | 3 | .PHONY: build build-wheel build-sdist init-project init test lint_check type_check format format_check precommit 4 | 5 | init-project: 6 | poetry install -E benchmarking -E matching -E experimental 7 | 8 | init: init-project 9 | poetry run pre-commit install 10 | 11 | test: 12 | poetry run pytest $(TEST_SCOPE) 13 | 14 | lint_check: 15 | poetry run ruff check perception tests 16 | 17 | type_check: 18 | poetry run mypy perception 19 | 20 | format: 21 | poetry run black . 22 | 23 | format_check: 24 | poetry run black --check . || (echo '\nUnexpected format.' && exit 1) 25 | 26 | precommit: 27 | poetry check 28 | make lint_check 29 | make type_check 30 | make format_check 31 | make test 32 | 33 | build-wheel: 34 | @poetry run pip -q install repairwheel 35 | @poetry self add -q "poetry-dynamic-versioning[plugin]" 36 | @poetry build --format="wheel" --output="dist-tmp" 37 | @poetry run repairwheel -o dist dist-tmp/*.whl 38 | @find dist -name "*.whl" -type f | sed -n "s/\(.*\)\.linux.*\.whl$$/& \1.whl/p" | xargs -r -n 2 mv # Fix wheel name 39 | @rm -rf dist-tmp 40 | 41 | build-sdist: 42 | @poetry self add -q "poetry-dynamic-versioning[plugin]" 43 | @poetry build --format="sdist" --output="dist" 44 | 45 | build: build-wheel build-sdist 46 | -------------------------------------------------------------------------------- /perception/hashers/image/average.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | from .. import tools 4 | from ..hasher import ImageHasher 5 | 6 | 7 | class AverageHash(ImageHasher): 8 | """Computes a simple hash comparing the intensity of each 9 | pixel in a resized version of the image to the mean. 10 | Implementation based on that of 11 | `ImageHash `_.""" 12 | 13 | distance_metric = "hamming" 14 | dtype = "bool" 15 | 16 | def __init__(self, hash_size=8): 17 | assert hash_size >= 2, "Hash size must be greater than or equal to 2." 
18 | self.hash_size = hash_size 19 | self.hash_length = hash_size * hash_size 20 | 21 | def _compute(self, image): 22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 23 | image = cv2.resize( 24 | image, dsize=(self.hash_size, self.hash_size), interpolation=cv2.INTER_AREA 25 | ) 26 | diff = image > image.mean() 27 | return diff.flatten() 28 | 29 | def _compute_isometric_from_hash(self, vector): 30 | return { 31 | transform_name: diff.flatten() 32 | for transform_name, diff in tools.get_isometric_transforms( 33 | vector.reshape(self.hash_size, self.hash_size, 1), require_color=False 34 | ).items() 35 | } 36 | -------------------------------------------------------------------------------- /perception/testing/images/README.md: -------------------------------------------------------------------------------- 1 | # Sample images 2 | These images were obtained from Wikimedia Commons. 3 | 4 | - [Image 1](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:ADAC-Zentrale,_Munich,_March_2017-05.jpg) 5 | - [Image 2](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:Two-tailed_pasha_(Charaxes_jasius_jasius)_Greece.jpg) 6 | - [Image 3](https://commons.wikimedia.org/wiki/Main_Page#/media/File:Escolta_presidencial,_Plaza_de_Armas,_Lima,_Per%C3%BA,_2015-07-28,_DD_40.JPG) 7 | - [Image 4](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:Iglesia_de_Ntra._Sra._de_la_Junquera,_Luesma,_Zaragoza,_Espa%C3%B1a,_2017-01-04,_DD_60.jpg) 8 | - [Image 5](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:Bahrain_Fort_March_2015.JPG) 9 | - [Image 6](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:ET_Gondar_asv2018-02_img18_Fasil_Ghebbi.jpg) 10 | - [Image 7](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:M%C3%BCnster,_Beresa,_Mercedes-Benz_C-Klasse_Cabrio_--_2018_--_1757.jpg) 11 | - [Image 8](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:Panoramic_sunset_in_Conques_02.jpg) 12 | - [Image 9](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:Catedral_de_San_Basilio,_Mosc%C3%BA,_Rusia,_2016-10-03,_DD_05-06_HDR.jpg) 13 | - [Image 10](https://commons.wikimedia.org/wiki/Commons:Picture_of_the_day#/media/File:Tupolev_Tu-160_overflying_Moscow_fix.jpg) -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - "**" 6 | tags-ignore: 7 | - v* 8 | jobs: 9 | test: 10 | strategy: 11 | matrix: 12 | python-version: ["3.10", "3.11", "3.12"] 13 | os: ["ubuntu-latest", "windows-latest", "macos-latest", "macos-13"] # macOS 13 is the latest version with the old architecture 14 | runs-on: ${{ matrix.os }} 15 | steps: 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Setup Poetry 21 | uses: abatilo/actions-poetry@v3 22 | - name: Setup FFMPEG 23 | uses: FedericoCarboni/setup-ffmpeg@v3 24 | if: ${{ ! 
startsWith(matrix.os, 'macos') }} 25 | - name: Setup Dependencies with Homebrew 26 | if: startsWith(matrix.os, 'macos') 27 | run: | 28 | brew install llvm ffmpeg 29 | echo "CC=$(brew --prefix)/opt/llvm/bin/clang" >> $GITHUB_ENV 30 | echo "CXX=$(brew --prefix)/opt/llvm/bin/clang++" >> $GITHUB_ENV 31 | - name: checkout 32 | uses: actions/checkout@v4 33 | - uses: actions/cache@v4 34 | name: Cache the venv 35 | with: 36 | path: ./.venv 37 | key: venv-${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('poetry.lock') }} 38 | - name: Setup Project 39 | run: make init-project 40 | - name: Run precommit 41 | run: make precommit 42 | -------------------------------------------------------------------------------- /tests/test_tmk.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | from pathlib import Path 4 | from typing import cast 5 | import platform 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | from perception.hashers.video import tmk 11 | 12 | TEST_FILES = Path("perception") / "testing" / "videos" 13 | 14 | 15 | def test_tmk_parity(): 16 | if platform.machine() == "arm64": 17 | pytest.xfail("TMK is not supported on ARM64") 18 | 19 | hasher = tmk.TMKL2() 20 | with gzip.open(TEST_FILES / "expected_tmk.json.gz", "rt", encoding="utf8") as f: 21 | expected_output = json.load(f) 22 | expected_output = {k: np.array(v) for k, v in expected_output.items()} 23 | 24 | output = [] 25 | 26 | for filepath in [ 27 | "perception/testing/videos/v1.m4v", 28 | "perception/testing/videos/v2.m4v", 29 | ]: 30 | hash_value: np.ndarray = cast( 31 | np.ndarray, hasher.compute(filepath=filepath, hash_format="vector") 32 | ) 33 | output.append(hash_value.reshape((4, 64, -1))) 34 | 35 | # Verify the hashes are the same 36 | for o, t in zip(output, expected_output["hashes"]): 37 | np.testing.assert_allclose(o.reshape(*t.shape), t) 38 | 39 | # Verify the pair-wise scores are the same 40 | offsets = np.arange(-5, 5) 41 | for normalization in ["feat", "feat_freq", "matrix"]: 42 | score = hasher._score_pair( 43 | output[0], output[1], offsets=offsets, normalization=normalization 44 | ) 45 | np.testing.assert_allclose(score, expected_output[normalization]) 46 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | perception 2 | ========== 3 | 4 | :code:`perception` provides flexible, well-documented, and comprehensively tested tooling for perceptual hashing 5 | research, development, and production use. It provides a common wrapper around existing, popular perceptual hashes 6 | (such as those implemented by `ImageHash `_) 7 | along with tools to compare their performance and use them for common tasks. 8 | 9 | Perceptual hashes are used to create compact image "fingerprints" which are invariant to small alterations to 10 | the original image. Typically, the representations are compact enough that they are irreversible, which makes 11 | them useful for deduplication and detecting abusive content while preserving the privacy of content owners. 12 | 13 | Installation 14 | ************ 15 | 16 | You can install :code:`perception` using pip. You must install OpenCV separately (e.g., with :code:`pip install opencv-python`). 17 | 18 | .. 
code-block:: bash 19 | 20 | # Install from PyPi 21 | pip install perception 22 | 23 | # Install from GitHub 24 | pip install git+https://github.com/thorn-oss/perception.git#egg=perception 25 | 26 | To install with the necessary dependencies for benchmarking, use: 27 | 28 | .. code-block:: bash 29 | 30 | # Install from PyPi 31 | pip install perception[benchmarking] 32 | 33 | # Install from GitHub 34 | pip install opencv-python git+https://github.com/thorn-oss/perception.git#egg=perception[benchmarking] 35 | 36 | Getting Started 37 | *************** 38 | 39 | Please see the examples for code snippets for common use cases. 40 | 41 | 42 | .. toctree:: 43 | :maxdepth: 2 44 | :caption: Contents: 45 | 46 | examples/index 47 | api/index 48 | 49 | -------------------------------------------------------------------------------- /perception/benchmarking/image_transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def apply_watermark(watermark, alpha: float = 1.0, size: float = 1.0): 6 | """Apply a watermark to the bottom right of 7 | images. Based on the work provided at 8 | https://www.pyimagesearch.com/2016/04/25/watermarking-images-with-opencv-and-python/ 9 | 10 | Args: 11 | watermark: The watermark to overlay 12 | alpha: The strength of the overlay 13 | size: The maximum proportion of the image 14 | taken by the watermark. 15 | """ 16 | assert watermark.shape[-1] == 4, "Watermark must have an alpha channel." 17 | 18 | # Why do we have to do this? It's not clear. But the process doesn't work 19 | # without it. 20 | (B, G, R, A) = cv2.split(watermark) 21 | B = cv2.bitwise_and(B, B, mask=A) 22 | G = cv2.bitwise_and(G, G, mask=A) 23 | R = cv2.bitwise_and(R, R, mask=A) 24 | watermark = cv2.merge([B, G, R, A]) 25 | 26 | def transform(image): 27 | # Add alpha channel 28 | (h, w) = image.shape[:2] 29 | wh, ww = watermark.shape[:2] 30 | scale = size * min(h / wh, w / ww) 31 | image = np.dstack([image, np.ones((h, w), dtype="uint8") * 255]) 32 | # Construct an overlay that is the same size as the input. 33 | overlay = np.zeros((h, w, 4), dtype="uint8") 34 | scaled = cv2.resize(watermark, (int(scale * ww), int(scale * wh))) 35 | sh, sw = scaled.shape[:2] 36 | overlay[max(h - sh, 0) :, max(w - sw, 0) : w] = scaled 37 | # Blend the two images together using transparent overlays 38 | output = image.copy() 39 | cv2.addWeighted(overlay, alpha, output, 1.0, 0, output) 40 | return cv2.cvtColor(output, cv2.COLOR_RGBA2RGB) 41 | 42 | return transform 43 | -------------------------------------------------------------------------------- /perception/hashers/image/opencv.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from ..hasher import ImageHasher 5 | 6 | 7 | class OpenCVHasher(ImageHasher): 8 | allow_parallel = False 9 | 10 | def __init__(self): 11 | if not hasattr(cv2, "img_hash"): 12 | raise RuntimeError( 13 | "You do not appear to have opencv-contrib installed. It is required for pure OpenCV hashers." 14 | ) 15 | 16 | 17 | class MarrHildreth(OpenCVHasher): 18 | """A wrapper around OpenCV's Marr-Hildreth hash. 
19 | See `paper `_ for details.""" 20 | 21 | dtype = "bool" 22 | distance_metric = "hamming" 23 | hash_length = 576 24 | 25 | def __init__(self): 26 | super().__init__() 27 | self.hasher = cv2.img_hash.MarrHildrethHash.create() 28 | 29 | def _compute(self, image): 30 | return np.unpackbits(self.hasher.compute(image)[0]) 31 | 32 | 33 | class ColorMoment(OpenCVHasher): 34 | """A wrapper around OpenCV's Color Moments hash. 35 | See `paper `_ for details.""" 36 | 37 | dtype = "float32" 38 | distance_metric = "euclidean" 39 | hash_length = 42 40 | 41 | def __init__(self): 42 | super().__init__() 43 | self.hasher = cv2.img_hash.ColorMomentHash.create() 44 | 45 | def _compute(self, image): 46 | return 10000 * self.hasher.compute(image)[0] 47 | 48 | 49 | class BlockMean(OpenCVHasher): 50 | """A wrapper around OpenCV's Block Mean hash. 51 | See `paper `_ for details.""" 52 | 53 | dtype = "bool" 54 | distance_metric = "hamming" 55 | hash_length = 968 56 | 57 | def __init__(self): 58 | super().__init__() 59 | self.hasher = cv2.img_hash.BlockMeanHash.create(1) 60 | 61 | def _compute(self, image): 62 | # https://stackoverflow.com/questions/54762896/why-cv2-norm-hamming-gives-different-value-than-actual-hamming-distance 63 | return np.unpackbits(self.hasher.compute(image)[0]) -------------------------------------------------------------------------------- /perception/hashers/image/wavelet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import pywt 4 | 5 | from ..hasher import ImageHasher 6 | 7 | 8 | class WaveletHash(ImageHasher): 9 | """Similar to PHash but using wavelets instead of DCT. 10 | Implementation based on that of 11 | `ImageHash `_. 12 | """ 13 | 14 | distance_metric = "hamming" 15 | dtype = "bool" 16 | 17 | def __init__(self, hash_size=8, image_scale=None, mode="haar"): 18 | assert hash_size & (hash_size - 1) == 0, "Hash size must be a power of 2." 19 | if image_scale is not None: 20 | assert ( 21 | image_scale & (image_scale - 1) == 0 22 | ), "Image scale must be a power of 2." 23 | assert ( 24 | image_scale >= hash_size 25 | ), "Image scale must be greater than or equal to the hash size."
26 | self.hash_size = hash_size 27 | self.image_scale = image_scale 28 | self.mode = mode 29 | self.hash_length = hash_size * hash_size 30 | 31 | def _compute(self, image): 32 | if self.image_scale is None: 33 | image_scale = max(2 ** int(np.log2(min(image.shape[:2]))), self.hash_size) 34 | else: 35 | image_scale = self.image_scale 36 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 37 | image = cv2.resize( 38 | image, dsize=(image_scale, image_scale), interpolation=cv2.INTER_AREA 39 | ) 40 | image = np.float32(image) / 255 41 | 42 | ll_max_level = int(np.log2(image_scale)) 43 | level = int(np.log2(self.hash_size)) 44 | dwt_level = ll_max_level - level 45 | 46 | if self.mode == "haar": 47 | coeffs = pywt.wavedec2(image, "haar", level=ll_max_level) 48 | coeffs = list(coeffs) 49 | coeffs[0] *= 0 50 | image = pywt.waverec2(coeffs, "haar") 51 | 52 | coeffs = pywt.wavedec2(image, self.mode, level=dwt_level) 53 | dwt_low = coeffs[0] 54 | 55 | # Subtract median and compute hash 56 | med = np.median(dwt_low) 57 | diff = dwt_low > med 58 | 59 | return diff.flatten() 60 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "Perception" 3 | version = "0.0.0" 4 | description = "Perception provides flexible, well-documented, and comprehensively tested tooling for perceptual hashing research, development, and production use." 5 | authors = ["Thorn "] 6 | license = "Apache License 2.0" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | Cython = "^3" 12 | numpy = "^1.26" 13 | opencv-contrib-python-headless = "^4.10" 14 | pandas = "*" 15 | pdqhash = "*" 16 | Pillow = "*" 17 | pywavelets = "^1.5.0" 18 | tqdm = "*" 19 | validators = ">=0.22, <1.0" 20 | scipy = "*" 21 | 22 | # Benchmarking Extras 23 | matplotlib = { version = "*", optional = true } 24 | imgaug = { version = "*", optional = true } 25 | tabulate = { version = "*", optional = true } 26 | scikit-learn = { version = "*", optional = true } 27 | ffmpeg-python = { version = "*", optional = true } 28 | 29 | # Matching Extras 30 | aiohttp = { version = "*", optional = true } 31 | python-json-logger = { version = "*", optional = true } 32 | rich = "^13.7.0" 33 | 34 | # Experimental Extras 35 | networkit = { version = "^11", optional = true } 36 | faiss-cpu = { version = "^1.8.0.post1", optional = true } 37 | 38 | [tool.poetry.extras] 39 | benchmarking = [ 40 | "matplotlib", 41 | "scipy", 42 | "imgaug", 43 | "tabulate", 44 | "scikit-learn", 45 | "ffmpeg-python", 46 | ] 47 | matching = ["aiohttp", "python-json-logger"] 48 | experimental = ["networkit", "faiss-cpu"] 49 | 50 | 51 | [tool.poetry.group.dev.dependencies] 52 | black = "^24" 53 | coverage = "*" 54 | ipython = "*" 55 | mypy = "*" 56 | pandas-stubs = "*" 57 | pre-commit = "*" 58 | pytest = "*" 59 | pytest-cov = "*" 60 | ruff = "*" 61 | types-pillow = "*" 62 | types-tqdm = "*" 63 | twine = "*" 64 | 65 | 66 | [tool.poetry.build] 67 | script = "build.py" 68 | generate-setup-file = true 69 | 70 | [tool.mypy] 71 | exclude = ["/tests/"] 72 | check_untyped_defs = true 73 | ignore_missing_imports = true 74 | 75 | [tool.poetry-dynamic-versioning] 76 | enable = true 77 | 78 | [build-system] 79 | requires = [ 80 | "poetry-core", 81 | "poetry-dynamic-versioning", 82 | "numpy", 83 | "Cython", 84 | "setuptools", 85 | "wheel", 86 | ] 87 | build-backend = "poetry_dynamic_versioning.backend" 88 | 
-------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [0.4.0] - 2020-10-17 8 | This release switches from using false positive rates in benchmarking to reporting precision, which is more intuitive. 9 | 10 | ### Breaking changes 11 | All references to fpr_threshold now refer to precision_threshold. 12 | 13 | ### Bug fixes 14 | The PDQHash hasher now correctly returns the hash vector instead of the (vector, quality) tuple. 15 | 16 | ## [0.3.0] - 2020-04-27 17 | This release adds significantly more support for video. 18 | 19 | ### Breaking changes 20 | - Previously, `read_video` returned `(frame, index, timestamp)` tuples where `index` reflected the index of the yielded frame (i.e., it always increased by exactly 1). It now reflects the index of the frame in the original video. This means that, if the requested framerate is higher than the encoded video framerate, this index may repeat the same value, indicating that we have repeated the same frame. 21 | 22 | ### Enhancements 23 | - We now include a `SimpleSceneDetection` hasher that can wrap other video hashers using scene detection. 24 | - `compute_metrics` is much faster now for integer-valued hashes that use a euclidean distance metric. 25 | - We now include an unsigned 8-bit integer version of `PHash`, called `PHashU8`. This provides a useful framewise hasher for averaging across frames (e.g., using TMK) while being more compact than `PHashF`. 26 | - We include more thorough support for benchmarking video hashes. 27 | 28 | ### Bug fixes 29 | - When using `hasher.vector_to_string` with hashers that return multiple hashes, the `hash_format` argument was not respected. 30 | - The `compute_threshold_recall` and `show_histograms` functions did not work properly when `grouping=[]`. 31 | 32 | ## [0.2.0] - 2019-12-20 33 | This release adds more support for hashing videos (including TMK L1 and TMK L2). As part of that, it also includes a re-factor to separate `benchmarking.BenchmarkDataset` and `benchmarking.BenchmarkTransforms` into image and video variants. 34 | 35 | ## [0.1.0] - 2019-11-04 36 | Initial release -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | # 15 | 16 | # -- Project information ----------------------------------------------------- 17 | project = "perception" 18 | copyright = "2019, thorn" 19 | author = "thorn" 20 | 21 | # The short X.Y version 22 | version = "" 23 | # The full version, including alpha/beta/rc tags 24 | release = "" 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "sphinx.ext.autodoc", 33 | "sphinx.ext.imgmath", 34 | "sphinx.ext.napoleon", 35 | "sphinx_autodoc_typehints", 36 | "m2r", 37 | ] 38 | 39 | # The suffix(es) of source filenames. 40 | # You can specify multiple suffix as a list of string: 41 | # 42 | # source_suffix = ['.rst', '.md'] 43 | source_suffix = ".rst" 44 | 45 | # The master toctree document. 46 | master_doc = "index" 47 | 48 | # The language for content autogenerated by Sphinx. Refer to documentation 49 | # for a list of supported languages. 50 | # 51 | # This is also used if you do content translation via gettext catalogs. 52 | # Usually you set "language" from the command line for these cases. 53 | language = None 54 | 55 | # List of patterns, relative to source directory, that match files and 56 | # directories to ignore when looking for source files. 57 | # This pattern also affects html_static_path and html_extra_path. 58 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 59 | 60 | # The name of the Pygments (syntax highlighting) style to use. 61 | pygments_style = None 62 | 63 | html_theme = "sphinx_rtd_theme" 64 | 65 | html_theme_options = {"navigation_depth": 4, "collapse_navigation": False} 66 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | 21 | 22 | jobs: 23 | analyze: 24 | name: Analyze 25 | runs-on: ubuntu-latest 26 | permissions: 27 | actions: read 28 | contents: read 29 | security-events: write 30 | 31 | strategy: 32 | fail-fast: false 33 | matrix: 34 | language: ["python"] 35 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 36 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 37 | 38 | steps: 39 | - name: Checkout repository 40 | uses: actions/checkout@v4 41 | 42 | # Initializes the CodeQL tools for scanning. 43 | - name: Initialize CodeQL 44 | uses: github/codeql-action/init@v3 45 | with: 46 | languages: ${{ matrix.language }} 47 | # If you wish to specify custom queries, you can do so here or in a config file. 48 | # By default, queries listed here will override any specified in a config file. 
49 | # Prefix the list here with "+" to use these queries and those in the config file. 50 | 51 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 52 | # queries: security-extended,security-and-quality 53 | 54 | 55 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 56 | # If this step fails, then you should remove it and run the build manually (see below) 57 | - name: Autobuild 58 | uses: github/codeql-action/autobuild@v3 59 | 60 | # ℹ️ Command-line programs to run using the OS shell. 61 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 62 | 63 | # If the Autobuild fails above, remove it and uncomment the following three lines. 64 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 65 | 66 | # - run: | 67 | # echo "Run, Build Application using script" 68 | # ./location_of_script_within_repo/buildscript.sh 69 | 70 | - name: Perform CodeQL Analysis 71 | uses: github/codeql-action/analyze@v3 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # perception ![ci](https://github.com/thorn-oss/perception/workflows/ci/badge.svg) 2 | 3 | `perception` provides flexible, well-documented, and comprehensively tested tooling for perceptual hashing research, development, and production use. See [the documentation](https://perception.thorn.engineering/en/latest/) for details. 4 | 5 | ## Background 6 | 7 | `perception` was initially developed at [Thorn](https://www.thorn.org) as part of our work to eliminate child sexual abuse material from the internet. For more information on the issue, check out [our CEO's TED talk](https://www.thorn.org/blog/time-is-now-eliminate-csam/). 8 | 9 | ## Getting Started 10 | 11 | ### Installation 12 | 13 | `pip install perception` 14 | 15 | ### Hashing 16 | 17 | Hashing with different functions is simple with `perception`. 18 | 19 | ```python 20 | from perception import hashers 21 | 22 | file1, file2 = 'test1.jpg', 'test2.jpg' 23 | hasher = hashers.PHash() 24 | hash1, hash2 = hasher.compute(file1), hasher.compute(file2) 25 | distance = hasher.compute_distance(hash1, hash2) 26 | ``` 27 | 28 | ### Examples 29 | 30 | See below for end-to-end examples for common use cases for perceptual hashes. 
31 | 32 | - [Detecting child sexual abuse material](https://perception.thorn.engineering/en/latest/examples/detecting_csam.html) 33 | - [Deduplicating media](https://perception.thorn.engineering/en/latest/examples/deduplication.html) 34 | - [Benchmarking perceptual hashes](https://perception.thorn.engineering/en/latest/examples/benchmarking.html) 35 | 36 | ## Supported Hashing Algorithms 37 | 38 | `perception` currently ships with: 39 | 40 | - pHash (DCT hash) (`perception.hashers.PHash`) 41 | - Facebook's PDQ Hash (`perception.hashers.PDQ`) 42 | - dHash (difference hash) (`perception.hashers.DHash`) 43 | - aHash (average hash) (`perception.hashers.AverageHash`) 44 | - Marr-Hildreth (`perception.hashers.MarrHildreth`) 45 | - Color Moment (`perception.hashers.ColorMoment`) 46 | - Block Mean (`perception.hashers.BlockMean`) 47 | - wHash (wavelet hash) (`perception.hashers.WaveletHash`) 48 | 49 | ## Contributing 50 | 51 | To work on the project, start by doing the following. 52 | 53 | ```bash 54 | # Install local dependencies for 55 | # code completion, etc. 56 | make init 57 | ``` 58 | - To do a (close to) comprehensive check before committing code, you can use `make precommit`. 59 | 60 | To implement new features, please first file an issue proposing your change for discussion. 61 | 62 | To report problems, please file an issue with sample code, expected results, actual results, and a complete traceback. 63 | 64 | ## Alternatives 65 | 66 | There are other packages worth checking out to see if they meet your needs for perceptual hashing. Here are some 67 | examples. 68 | 69 | - [dedupe](https://github.com/dedupeio/dedupe) 70 | - [imagededup](https://idealo.github.io/imagededup/) 71 | - [ImageHash](https://github.com/JohannesBuchner/imagehash) 72 | - [PhotoHash](https://github.com/bunchesofdonald/photohash) 73 | 74 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: release 2 | on: 3 | release: 4 | types: [published] 5 | workflow_dispatch: 6 | 7 | jobs: 8 | build-wheels: 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | python-version: ["3.10", "3.11", "3.12"] 13 | os: ["ubuntu-latest", "windows-latest", "macos-latest", "macos-13"] # macOS 13 is the latest version with the old architecture 14 | name: Build for ${{ matrix.os }} on Python ${{ matrix.python-version }} 15 | steps: 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Setup Poetry 21 | uses: abatilo/actions-poetry@v3 22 | - name: Setup FFMPEG 23 | uses: FedericoCarboni/setup-ffmpeg@v3 24 | if: ${{ !
startsWith(matrix.os, 'macos') }} 25 | - name: Setup Dependencies with Homebrew 26 | if: startsWith(matrix.os, 'macos') 27 | run: | 28 | brew install llvm ffmpeg 29 | echo "CC=$(brew --prefix)/opt/llvm/bin/clang" >> $GITHUB_ENV 30 | echo "CXX=$(brew --prefix)/opt/llvm/bin/clang++" >> $GITHUB_ENV 31 | - uses: actions/checkout@v4 32 | with: 33 | # Full clone for version calculation 34 | fetch-depth: 0 35 | - name: Build Project 36 | run: make build-wheel 37 | - uses: actions/upload-artifact@v4 38 | with: 39 | name: package-wheels-${{ matrix.os }}-${{ matrix.python-version }} 40 | path: dist/* 41 | 42 | build-sdist: 43 | runs-on: ubuntu-latest 44 | name: Build sdist 45 | steps: 46 | - name: Set up Python 47 | uses: actions/setup-python@v5 48 | with: 49 | python-version: "3.12" 50 | - name: Setup Poetry 51 | uses: abatilo/actions-poetry@v3 52 | - uses: actions/checkout@v4 53 | with: 54 | # Full clone for version calculation 55 | fetch-depth: 0 56 | - name: Build Project 57 | run: make build-sdist 58 | - uses: actions/upload-artifact@v4 59 | with: 60 | name: package-sdist 61 | path: dist/* 62 | 63 | publish: 64 | needs: [build-wheels, build-sdist] 65 | runs-on: ubuntu-latest 66 | if: ${{ github.repository_owner == 'thorn-oss' && github.event_name == 'release' }} 67 | steps: 68 | - uses: actions/checkout@v4 69 | with: 70 | # Full clone for version calculation 71 | fetch-depth: 0 72 | - uses: actions/setup-python@v5 73 | with: 74 | python-version: "3.12" 75 | - name: Setup Poetry 76 | uses: abatilo/actions-poetry@v3 77 | - name: Setup Dynamic Versioning 78 | run: poetry self add "poetry-dynamic-versioning[plugin]" 79 | - name: Download wheels 80 | uses: actions/download-artifact@v4 81 | with: 82 | path: dist 83 | pattern: package-* 84 | merge-multiple: true 85 | - name: Load PyPI Token 86 | uses: 1password/load-secrets-action@v2 87 | with: 88 | # Export loaded secrets as environment variables 89 | export-env: true 90 | env: 91 | OP_SERVICE_ACCOUNT_TOKEN: ${{ secrets.OP_SERVICE_ACCOUNT_TOKEN }} 92 | POETRY_PYPI_TOKEN_PYPI: op://data-science-oss/perception-pypi-api-key/secret/value 93 | - name: Publish package 94 | run: poetry publish -n 95 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at conduct@thorn.org. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /perception/hashers/image/phash.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import scipy.fftpack 4 | 5 | from .. 
import tools
6 | from ..hasher import ImageHasher
7 | 
8 | 
9 | class PHash(ImageHasher):
10 |     """Also known as the DCT hash, a hash based on discrete cosine transforms of images.
11 |     See `complete paper `_ for
12 |     details. Implementation based on that of
13 |     `ImageHash `_.
14 | 
15 |     Args:
16 |         hash_size: The number of DCT elements to retain (the hash length
17 |             will be hash_size * hash_size).
18 |         highfreq_factor: The multiple of the hash size to resize the input
19 |             image to before computing the DCT.
20 |         exclude_first_term: Whether to exclude the first term of the DCT.
21 |         freq_shift: The number of DCT low frequency elements to skip.
22 |     """
23 | 
24 |     distance_metric = "hamming"
25 |     dtype = "bool"
26 | 
27 |     def __init__(
28 |         self, hash_size=8, highfreq_factor=4, exclude_first_term=False, freq_shift=0
29 |     ):
30 |         assert hash_size >= 2, "Hash size must be greater than or equal to 2"
31 |         assert (
32 |             freq_shift <= highfreq_factor * hash_size - hash_size
33 |         ), "Frequency shift is too large for this hash size / highfreq_factor combination."
34 |         self.hash_size = hash_size
35 |         self.highfreq_factor = highfreq_factor
36 |         self.exclude_first_term = exclude_first_term
37 |         self.hash_length = hash_size * hash_size
38 |         self.freq_shift = freq_shift
39 |         if exclude_first_term:
40 |             self.hash_length -= 1
41 | 
42 |     def _compute_dct(self, image):
43 |         img_size = self.hash_size * self.highfreq_factor
44 |         image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
45 |         image = cv2.resize(
46 |             image, dsize=(img_size, img_size), interpolation=cv2.INTER_AREA
47 |         )
48 |         dct = scipy.fftpack.dct(scipy.fftpack.dct(image, axis=0), axis=1)
49 |         return dct[
50 |             self.freq_shift : self.hash_size + self.freq_shift,
51 |             self.freq_shift : self.hash_size + self.freq_shift,
52 |         ]
53 | 
54 |     def _dct_to_hash(self, dct):
55 |         dct = dct.flatten()
56 |         if self.exclude_first_term:
57 |             dct = dct[1:]
58 |         return dct > np.median(dct)
59 | 
60 |     def _compute(self, image):
61 |         dct = self._compute_dct(image)
62 |         return self._dct_to_hash(dct)
63 | 
64 |     def _compute_isometric(self, image):
65 |         return {
66 |             transform_name: self._dct_to_hash(dct)
67 |             for transform_name, dct in tools.get_isometric_dct_transforms(
68 |                 self._compute_dct(image)
69 |             ).items()
70 |         }
71 | 
72 | 
73 | class PHashF(PHash):
74 |     """A real-valued version of PHash. It
75 |     returns the raw 32-bit floats in the DCT.
76 |     For a more compact approach, see PHashU8."""
77 | 
78 |     dtype = "float32"
79 |     distance_metric = "euclidean"
80 | 
81 |     def _dct_to_hash(self, dct):
82 |         dct = dct.flatten()
83 |         if self.exclude_first_term:
84 |             dct = dct[1:]
85 |         if (dct == 0).all():
86 |             return None
87 |         return dct
88 | 
89 | 
90 | class PHashU8(PHash):
91 |     """A real-valued version of PHash. 
It 92 | uses minimum / maximum scaling to convert 93 | DCT values to unsigned 8-bit integers (more 94 | compact than the 32-bit floats used by PHashF at 95 | the cost of precision).""" 96 | 97 | dtype = "uint8" 98 | distance_metric = "euclidean" 99 | 100 | def _dct_to_hash(self, dct): 101 | dct = dct.flatten() 102 | if self.exclude_first_term: 103 | dct = dct[1:] 104 | if (dct == 0).all(): 105 | return None 106 | min_value = dct.min() 107 | max_value = dct.max() 108 | dct = np.uint8(255 * (dct - min_value) / (max_value - min_value)) 109 | return dct 110 | -------------------------------------------------------------------------------- /perception/hashers/video/framewise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .. import tools 4 | from ..hasher import ImageHasher, VideoHasher 5 | 6 | 7 | class FramewiseHasher(VideoHasher): 8 | """A hasher that simply returns frame-wise hashes at some 9 | regular interval with some minimum inter-frame distance threshold.""" 10 | 11 | returns_multiple = True 12 | 13 | def __init__( 14 | self, 15 | frame_hasher: ImageHasher, 16 | interframe_threshold: float, 17 | frames_per_second: int = 15, 18 | quality_threshold: float | None = None, 19 | ): 20 | self.hash_length = frame_hasher.hash_length 21 | self.frames_per_second = frames_per_second 22 | self.frame_hasher = frame_hasher 23 | self.distance_metric = frame_hasher.distance_metric 24 | if self.distance_metric == "hamming" and interframe_threshold > 1: 25 | raise ValueError( 26 | "Hamming distance is always between 0 and 1 but " 27 | f"`interframe_threshold` was set to {interframe_threshold}." 28 | ) 29 | self.dtype = frame_hasher.dtype 30 | self.interframe_threshold = interframe_threshold 31 | self.quality_threshold = quality_threshold 32 | 33 | def process_frame(self, frame, frame_index, frame_timestamp, state=None): 34 | if self.quality_threshold is None: 35 | current = self.frame_hasher.compute(frame, hash_format="vector") 36 | else: 37 | current, quality = self.frame_hasher.compute_with_quality( 38 | frame, hash_format="vector" 39 | ) 40 | if quality < self.quality_threshold: 41 | return state or {"previous": None, "hashes": []} 42 | assert isinstance(current, np.ndarray) # help type checking below 43 | if state is None or state["previous"] is None: 44 | # We keep a separate reference to the previous hash instead of using 45 | # the last entry in the hashes list because `compute_batches` may 46 | # clear the hashes list but we still want to be able to compare 47 | # the final entry. 48 | state = { 49 | "previous": current, 50 | "hashes": [current], 51 | } 52 | else: 53 | if ( 54 | self.frame_hasher.compute_distance(current, state["previous"]) 55 | > self.interframe_threshold 56 | ): 57 | state["hashes"].append(current) 58 | return state 59 | 60 | def compute_batches( 61 | self, filepath: str, batch_size: int, errors="raise", hash_format="base64" 62 | ): 63 | """Compute hashes for a video in batches. 64 | 65 | Args: 66 | filepath: Path to video file 67 | batch_size: The batch size to use for returning hashes 68 | errors: One of "raise", "ignore", or "warn". Passed 69 | to perception.hashers.tools.read_video. 
70 | hash_format: The format in which to return hashes 71 | """ 72 | 73 | def format_batch(hashes): 74 | return [ 75 | ( 76 | self.vector_to_string(vector, hash_format=hash_format) 77 | if hash_format != "vector" 78 | else vector 79 | ) 80 | for vector in hashes 81 | ] 82 | 83 | state = None 84 | for frame, frame_index, frame_timestamp in tools.read_video( 85 | filepath=filepath, frames_per_second=self.frames_per_second, errors=errors 86 | ): 87 | state = self.process_frame( 88 | frame=frame, 89 | frame_index=frame_index, 90 | frame_timestamp=frame_timestamp, 91 | state=state, 92 | ) 93 | if state is not None and len(state["hashes"]) > batch_size: 94 | yield format_batch(state["hashes"]) 95 | state["hashes"] = [] 96 | if state is not None and state["hashes"]: 97 | yield format_batch(state["hashes"]) 98 | 99 | def hash_from_final_state(self, state): 100 | if state is None: 101 | return [] 102 | return state["hashes"] 103 | -------------------------------------------------------------------------------- /perception/benchmarking/extensions.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | 3 | import cython 4 | import numpy as np 5 | from cython.parallel import parallel, prange 6 | 7 | cimport numpy as np 8 | from libc.math cimport sqrt 9 | from libc.stdlib cimport abort, free, malloc 10 | 11 | 12 | cdef extern from "limits.h": 13 | int INT_MAX 14 | 15 | ctypedef np.uint8_t uint8 16 | 17 | @cython.boundscheck(False) 18 | @cython.wraparound(False) 19 | def compute_euclidean_metrics(int[:, :] X_noop, int[:, :] X_tran, uint8[:, :] mask): 20 | """Compute the positive / negative distance metrics between two sets of vectors 21 | using euclidean distance. This function obtains the necessary metrics roughly 22 | 10x faster than using scipy.spatial.distance.cdist and numpy functions. 23 | 24 | Args: 25 | X_noop: The vectors for the noop hashes with shape (N, K) 26 | X_tran: The vectors for the transformed instances with shape (M, K) 27 | mask: A (M, N) array indicating whether noop n corresponds to transform m 28 | 29 | Returns: 30 | distances: An M by 2 array with the closest false positive and closest 31 | true positive for each transform. 32 | indexes: An M by 2 array with the index for the closest false positive 33 | noop and the closest true positive noop. 34 | """ 35 | 36 | cdef Py_ssize_t n_noop = X_noop.shape[0] 37 | cdef Py_ssize_t d_noop = X_noop.shape[1] 38 | cdef Py_ssize_t n_tran = X_tran.shape[0] 39 | cdef Py_ssize_t d_tran = X_tran.shape[1] 40 | cdef Py_ssize_t n_mask_tran = mask.shape[0] 41 | cdef Py_ssize_t n_mask_noop = mask.shape[1] 42 | cdef Py_ssize_t i_mask_tran 43 | cdef Py_ssize_t i_mask_noop 44 | cdef int n_pos 45 | 46 | cdef int current_distance 47 | cdef int current_closest_fp 48 | cdef int current_closest_tp 49 | cdef int[:] x 50 | cdef int[:] y 51 | cdef uint8 is_pos 52 | cdef Py_ssize_t i_noop, i_tran, i_d 53 | cdef Py_ssize_t i_closest_fp = 0 54 | cdef Py_ssize_t i_closest_tp = 1 55 | cdef Py_ssize_t i_closest_fp_idx = 0 56 | cdef Py_ssize_t i_closest_tp_idx = 1 57 | cdef int * local_buf 58 | cdef size_t size = 5 59 | cdef float NAN 60 | NAN = float("NaN") 61 | 62 | assert d_noop == d_tran, "Dimensionality of vectors must match." 63 | assert n_mask_tran == n_tran, "Dimension 0 of mask must correspond to n_transforms." 64 | assert n_mask_noop == n_noop, "Dimension 1 of mask must correspond to n_noops." 
65 |     for i_mask_tran in range(n_mask_tran):
66 |         n_pos = 0
67 |         for i_mask_noop in range(n_mask_noop):
68 |             if mask[i_mask_tran, i_mask_noop] == True:
69 |                 n_pos += 1
70 |         assert n_pos > 0, "All transforms must have at least one positive noop."
71 |         assert n_pos < n_mask_noop, "All transforms must have at least one negative noop."
72 | 
73 |     distances = np.zeros((n_tran, 2), dtype=np.float32)
74 |     indexes = np.zeros((n_tran, 2), dtype=np.int32)
75 | 
76 |     cdef np.float32_t[:, :] distances_view = distances
77 |     cdef int[:, :] indexes_view = indexes
78 | 
79 |     with nogil, parallel():
80 |         local_buf = <int *> malloc(sizeof(int) * size)
81 |         if local_buf is NULL:
82 |             abort()
83 |         for i_tran in prange(n_tran):
84 |             local_buf[1] = INT_MAX  # Smallest false positive distance
85 |             local_buf[2] = INT_MAX  # Smallest true positive distance
86 |             local_buf[3] = 0  # Smallest false positive index
87 |             local_buf[4] = 0  # Smallest true positive index
88 |             for i_noop in range(n_noop):
89 |                 local_buf[0] = 0  # Current distance
90 |                 is_pos = mask[i_tran, i_noop] == True
91 |                 for i_d in range(d_noop):
92 |                     local_buf[0] += (X_noop[i_noop, i_d] - X_tran[i_tran, i_d]) ** 2
93 |                 if is_pos and (local_buf[0] < local_buf[2]):
94 |                     local_buf[2] = local_buf[0]
95 |                     local_buf[4] = i_noop
96 |                 if not is_pos and (local_buf[0] < local_buf[1]):
97 |                     local_buf[1] = local_buf[0]
98 |                     local_buf[3] = i_noop
99 |             # I do not think that an int can ever actually be
100 |             # greater than INT_MAX but we'll leave the check in.
101 |             if local_buf[1] < INT_MAX:
102 |                 distances_view[i_tran, i_closest_fp] = sqrt(local_buf[1])
103 |             else:
104 |                 distances_view[i_tran, i_closest_fp] = NAN
105 |             if local_buf[2] < INT_MAX:
106 |                 distances_view[i_tran, i_closest_tp] = sqrt(local_buf[2])
107 |             else:
108 |                 distances_view[i_tran, i_closest_tp] = NAN
109 |             indexes_view[i_tran, i_closest_fp_idx] = local_buf[3]
110 |             indexes_view[i_tran, i_closest_tp_idx] = local_buf[4]
111 |         free(local_buf)
112 |     return distances, indexes
113 | 
--------------------------------------------------------------------------------
/tests/test_hashers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import string
3 | 
4 | import pytest
5 | 
6 | from perception import hashers, testing
7 | from perception.hashers.image.pdq import PDQHash
8 | 
9 | TEST_IMAGES = [os.path.join("tests", "images", f"image{n}.jpg") for n in range(1, 11)]
10 | 
11 | 
12 | # The PDQ hash isometric computation is inexact. See
13 | # https://github.com/faustomorales/pdqhash-python/blob/master/tests/test_compute.py
14 | # for details. 
15 | @pytest.mark.parametrize( 16 | "hasher_class,pil_opencv_threshold,transform_threshold,opencv_hasher", 17 | [ 18 | (hashers.AverageHash, 0.1, 0.1, False), 19 | (hashers.WaveletHash, 0.1, 0.1, False), 20 | (hashers.PHash, 0.1, 0.1, False), 21 | (PDQHash, 0.1, 0.15, False), 22 | (hashers.DHash, 0.1, 0.1, False), 23 | (hashers.MarrHildreth, 0.1, 0.1, True), 24 | (hashers.BlockMean, 0.1, 0.1, True), 25 | (hashers.ColorMoment, 10, 0.1, True), 26 | ], 27 | ) 28 | def test_image_hashing_common( 29 | hasher_class, pil_opencv_threshold, transform_threshold, opencv_hasher 30 | ): 31 | testing.test_image_hasher_integrity( 32 | hasher=hasher_class(), 33 | pil_opencv_threshold=pil_opencv_threshold, 34 | transform_threshold=transform_threshold, 35 | opencv_hasher=opencv_hasher, 36 | ) 37 | 38 | 39 | def test_video_hashing_common(): 40 | testing.test_video_hasher_integrity( 41 | hasher=hashers.FramewiseHasher( 42 | frame_hasher=hashers.PHash(hash_size=16), 43 | interframe_threshold=0.1, 44 | frames_per_second=1, 45 | ) 46 | ) 47 | 48 | 49 | def test_video_reading(): 50 | # We should get one red, one green, and one blue frame 51 | for frame, _, timestamp in hashers.tools.read_video( 52 | filepath="perception/testing/videos/rgb.m4v", frames_per_second=0.5 53 | ): 54 | assert timestamp in [0.0, 2.0, 4.0] 55 | channel = int(timestamp / 2) 56 | assert frame[:, :, channel].min() > 220 57 | for other in [0, 1, 2]: 58 | if other == channel: 59 | continue 60 | assert frame[:, :, other].max() < 20 61 | 62 | 63 | def test_common_framerate(): 64 | assert hashers.tools.get_common_framerates( 65 | dict(zip(["a", "b", "c"], [1 / 3, 1 / 2, 1 / 5])) 66 | ) == {1.0: ("a", "b", "c")} 67 | assert hashers.tools.get_common_framerates( 68 | dict(zip(["a", "b", "c"], [1 / 3, 1 / 6, 1 / 9])) 69 | ) == {1 / 3: ("a", "b", "c")} 70 | assert hashers.tools.get_common_framerates( 71 | dict(zip(["a", "b", "c", "d", "e"], [1 / 3, 1 / 2, 1 / 5, 1 / 7, 1 / 11])) 72 | ) == {1.0: ("a", "b", "c", "d", "e")} 73 | assert hashers.tools.get_common_framerates( 74 | dict(zip(string.ascii_lowercase[:6], [10, 5, 3, 1 / 3, 1 / 6, 1 / 9])) 75 | ) == {3.0: ("c", "d", "e", "f"), 10.0: ("a", "b")} 76 | assert hashers.tools.get_common_framerates(dict(zip(["a", "b"], [100, 1]))) == { 77 | 100: ("a", "b") 78 | } 79 | 80 | 81 | def test_synchronized_hashing(): 82 | video_hashers = { 83 | "phashframewise": hashers.FramewiseHasher( 84 | frame_hasher=hashers.PHash(hash_size=16), 85 | frames_per_second=1, 86 | interframe_threshold=0.2, 87 | ), 88 | "tmkl2": hashers.TMKL2(frames_per_second=15), 89 | "tmkl1": hashers.TMKL1(frames_per_second=15), 90 | } 91 | 92 | for filepath in [ 93 | "perception/testing/videos/v1.m4v", 94 | "perception/testing/videos/v2.m4v", 95 | ]: 96 | # Ensure synchronized hashing 97 | hashes1 = { 98 | hasher_name: hasher.compute(filepath) 99 | for hasher_name, hasher in video_hashers.items() 100 | } 101 | hashes2 = hashers.tools.compute_synchronized_video_hashes( 102 | filepath=filepath, hashers=video_hashers 103 | ) 104 | assert hashes1 == hashes2 105 | 106 | 107 | def test_hex_b64_conversion(): 108 | b64_string = ( 109 | """ 110 | CFFRABrAaRKCDQigEBIGwAhNBdIISgVZBxQYAgP4fwYNUR0oBgYCPwwIDSqTAmIH 111 | FRQhCiT/IT9DpHIeIx4cA2hQcBTwISovFkspMxz/MzdnljeCOEs4LnBYNHHBMC4x 112 | EC8mPxLaLkI/dywmNk1lMXoqJyCLSyg7BxwRSgTmIlI/LwsrP04hTCMtBSxaGAFB 113 | """.replace( 114 | "\n", "" 115 | ) 116 | .replace(" ", "") 117 | .strip() 118 | ) 119 | hex_string = ( 120 | """ 121 | 085151001ac06912820d08a0101206c0084d05d2084a05590714180203f87f06 122 | 
0d511d280606023f0c080d2a930262071514210a24ff213f43a4721e231e1c03 123 | 68507014f0212a2f164b29331cff333767963782384b382e70583471c1302e31 124 | 102f263f12da2e423f772c26364d65317a2a27208b4b283b071c114a04e62252 125 | 3f2f0b2b3f4e214c232d052c5a180141 126 | """.replace( 127 | "\n", "" 128 | ) 129 | .replace(" ", "") 130 | .strip() 131 | ) 132 | assert ( 133 | hashers.tools.hex_to_b64(hex_string, dtype="uint8", hash_length=144) 134 | == b64_string 135 | ) 136 | assert ( 137 | hashers.tools.b64_to_hex(b64_string, dtype="uint8", hash_length=144) 138 | == hex_string 139 | ) 140 | -------------------------------------------------------------------------------- /perception/approximate_deduplication/serve.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import functools 3 | import json 4 | import logging 5 | import typing 6 | 7 | import aiohttp.web 8 | import numpy as np 9 | from pythonjsonlogger import jsonlogger 10 | 11 | import perception.hashers.tools as pht 12 | 13 | from .index import ApproximateNearestNeighbors 14 | 15 | 16 | def is_similarity_valid(data, index: ApproximateNearestNeighbors): 17 | """Validates input to the similarity endpoint.""" 18 | hash_format = data.get("hash_format", "base64") 19 | expected_string_length = pht.get_string_length( 20 | hash_length=index.hash_length, dtype=index.dtype, hash_format=hash_format 21 | ) 22 | return ( 23 | isinstance(data, dict) 24 | and "queries" in data 25 | and isinstance(data["queries"], list) 26 | and all(isinstance(x.get("hash", None), str) for x in data["queries"]) 27 | and hash_format in ["hex", "base64"] 28 | and all( 29 | len(x.get("hash", None)) == expected_string_length for x in data["queries"] 30 | ) 31 | ) 32 | 33 | 34 | async def similarity(request): 35 | """Responds to a vector similarity query of the form: 36 | 37 | ``` 38 | { 39 | "queries": [{"id": str, "hash": "base64_encoded_hash1"}, ...], 40 | "k": int, 41 | "threshold": float, 42 | "hash_format": "base64" 43 | } 44 | ``` 45 | 46 | with information about similar vectors in the index in the form: 47 | 48 | ``` 49 | { 50 | "queries": [{"id": str, "matches": [{"metadata": {json metadata}, "distance": float},...],...] 
51 | } 52 | ``` 53 | """ 54 | try: 55 | request_data = await request.json() 56 | except json.JSONDecodeError: 57 | return aiohttp.web.json_response({"reason": "Malformed JSON"}, status=400) 58 | 59 | index = request.app["index"] 60 | try: 61 | assert is_similarity_valid(request_data, index) 62 | except Exception: 63 | return aiohttp.web.json_response({"reason": "Invalid JSON request"}, status=400) 64 | 65 | async with request.app["query_semaphore"]: 66 | matches = await asyncio.get_event_loop().run_in_executor( 67 | None, 68 | functools.partial( 69 | index.search, 70 | queries=request_data["queries"], 71 | threshold=request_data.get( 72 | "threshold", request.app["default_threshold"] 73 | ), 74 | threshold_func=request.app["default_threshold_func"], 75 | k=request_data.get("k", request.app["default_k"]), 76 | hash_format=request_data.get("hash_format", "base64"), 77 | ), 78 | ) 79 | matches = json.loads(json.dumps({"queries": matches})) 80 | 81 | return aiohttp.web.json_response(matches) 82 | 83 | 84 | def get_logger(name, log_level): 85 | logger = logging.Logger(name=name, level=log_level) 86 | handler = logging.StreamHandler() 87 | handler.setFormatter( 88 | jsonlogger.JsonFormatter( 89 | "%(asctime)s:%(levelname)s:%(name)s:%(message)s%(exc_info)" 90 | ) 91 | ) 92 | logger.addHandler(handler) 93 | return logger 94 | 95 | 96 | async def serve( 97 | index: ApproximateNearestNeighbors, 98 | default_threshold: int | None = None, 99 | default_threshold_func: typing.Callable[[np.ndarray], np.ndarray] | None = None, 100 | default_k: int = 1, 101 | concurrency: int = 2, 102 | log_level=logging.INFO, 103 | host="localhost", 104 | port=8080, 105 | ): 106 | """Serve an index as a web API. This function does not block. 107 | If you wish to use the function in a blocking manner, you can 108 | do something like 109 | 110 | .. code-block:: python 111 | 112 | loop = asyncio.get_event_loop() 113 | loop.run_until_complete(serve(...)) 114 | loop.run_forever() 115 | 116 | You can query the API with something like: 117 | 118 | .. 
code-block:: bash
119 | 
120 |         curl --header "Content-Type: application/json" \\
121 |           --request POST \\
122 |           --data '{"queries": [{"hash": "", "id": "bar"}], "threshold": 1200}' \\
123 |           http://localhost:8080/v1/similarity
124 | 
125 |     Args:
126 |         index: The underlying index
127 |         default_threshold: The default threshold for matches
128 |         default_k: The default number of nearest neighbors to look for
129 |         concurrency: The number of concurrent requests served
130 |         log_level: The log level to use for the logger
131 |         host: The host for the service
132 |         port: The port for the service
133 |     """
134 |     logger = get_logger(name="serve", log_level=log_level)
135 |     logger.info("Initializing web service")
136 |     app = aiohttp.web.Application()
137 |     app.router.add_post("/v1/similarity", similarity, name="similarity")
138 | 
139 |     # Store globals in the application object
140 |     app["default_threshold"] = default_threshold
141 |     app["logger"] = logger
142 |     app["default_k"] = default_k
143 |     app["default_threshold_func"] = default_threshold_func
144 |     app["index"] = index
145 |     app["query_semaphore"] = asyncio.Semaphore(concurrency)
146 |     logger.info("Entering web service listener loop.")
147 |     runner = aiohttp.web.AppRunner(app, logger=logger)
148 |     await runner.setup()
149 |     site = aiohttp.web.TCPSite(runner, host, port)
150 |     await site.start()
151 |     return site
152 | 
--------------------------------------------------------------------------------
/docs/examples/deduplication.rst:
--------------------------------------------------------------------------------
1 | Media Deduplication
2 | *******************
3 | 
4 | Perceptual hashes can be used to deduplicate sets of images. Below we provide two examples (one simple, one larger scale).
5 | 
6 | **For most use cases, we recommend using PHash with** :code:`hash_size=16` **and
7 | with 0.2 as the distance threshold as in the example below.** You may wish to adjust
8 | this threshold up or down based on your tolerance for false negatives / positives.
9 | 
10 | In practice, deduplicating in memory on your machine by the methods below may be impractical.
11 | For larger-scale applications, you may wish to use tools like
12 | `FAISS `_,
13 | `Annoy `_, or databases with
14 | functionality for querying based on distance such as
15 | `MemSQL `_.
16 | 
17 | For the supported hashers, below are our recommended thresholds with expected false positive rates of <1%.
18 | 
19 | ====================== ===========
20 | hasher                 threshold
21 | ====================== ===========
22 | ahash (hash_size=16)   0.008
23 | blockmean              0.008
24 | dhash (hash_size=16)   0.07
25 | marrhildreth           0.1
26 | pdq                    0.2
27 | phash (hash_size=16)   0.2
28 | wavelet (hash_size=16) 0.02
29 | ====================== ===========
30 | 
31 | Simple example
32 | ==============
33 | 
34 | In this example, we download a ZIP file containing 18 images. One of the images is duplicated
35 | twice and another image is duplicated once.
36 | 
37 | .. 
code-block:: python 38 | 39 | import os 40 | import glob 41 | import zipfile 42 | import urllib.request 43 | 44 | import tabulate 45 | import pandas as pd 46 | 47 | from perception import tools, hashers 48 | 49 | urllib.request.urlretrieve( 50 | "https://thorn-perception.s3.amazonaws.com/thorn-perceptual-deduplication-example.zip", 51 | "thorn-perceptual-deduplication-example.zip" 52 | ) 53 | 54 | with zipfile.ZipFile('thorn-perceptual-deduplication-example.zip') as f: 55 | f.extractall('.') 56 | 57 | filepaths = glob.glob('thorn-perceptual-deduplication-example/*.jpg') 58 | duplicate_pairs = tools.deduplicate(files=filepaths, hashers=[(hashers.PHash(hash_size=16), 0.2)]) 59 | print(tabulate.tabulate(pd.DataFrame(duplicate_pairs), showindex=False, headers=['file1', 'file2'], tablefmt='rst')) 60 | 61 | # Now we can do whatever we want with the duplicates. We could just delete 62 | # the first entry in each pair or manually verify the pairs to ensure they 63 | # are, in fact duplicates. 64 | 65 | 66 | =============================================== =============================================== 67 | file1 file2 68 | =============================================== =============================================== 69 | thorn-perceptual-deduplication-example/309b.jpg thorn-perceptual-deduplication-example/309.jpg 70 | thorn-perceptual-deduplication-example/309b.jpg thorn-perceptual-deduplication-example/309a.jpg 71 | thorn-perceptual-deduplication-example/309a.jpg thorn-perceptual-deduplication-example/309.jpg 72 | thorn-perceptual-deduplication-example/315a.jpg thorn-perceptual-deduplication-example/315.jpg 73 | =============================================== =============================================== 74 | 75 | Real-world example 76 | ================== 77 | 78 | In the example below, we use the 79 | `Caltech 256 Categories `_ dataset. Like 80 | most other public image datasets, it contains a handful of duplicates in some categories. 81 | 82 | The code below will: 83 | 84 | 1. Download the dataset 85 | 2. Group all the filepaths by category (the dataset is provided in folders) 86 | 3. Within each group, find duplicates using PHash. We will compare not just the 87 | original images, but also the 8 isometric transformations for each image. 88 | 89 | .. code-block:: python 90 | 91 | import os 92 | import tarfile 93 | from glob import glob 94 | import urllib.request 95 | 96 | import tqdm 97 | 98 | from perception import hashers, tools 99 | 100 | urllib.request.urlretrieve( 101 | "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", 102 | "256_ObjectCategories.tar" 103 | ) 104 | with tarfile.open('256_ObjectCategories.tar') as tfile: 105 | tfile.extractall() 106 | 107 | files = glob('256_ObjectCategories/**/*.jpg') 108 | 109 | # To reduce the number of pairwise comparisons, 110 | # we can deduplicate within each image category 111 | # (i.e., we don't need to compare images of 112 | # butterflies with images of chess boards). 113 | filepath_group = [ 114 | ( 115 | filepath, 116 | os.path.normpath(filepath).split(os.sep)[-2] 117 | ) for filepath in files 118 | ] 119 | groups = list(set([group for _, group in filepath_group])) 120 | 121 | # We consider any pair of images with a PHash distance of < 0.2 as 122 | # as a duplicate. 
123 | comparison_hashers = [(hashers.PHash(hash_size=16), 0.2)] 124 | 125 | duplicate_pairs = [] 126 | 127 | for current_group in groups: 128 | current_filepaths = [ 129 | filepath for filepath, group in filepath_group if group == current_group 130 | ] 131 | current_duplicate_pairs = tools.deduplicate( 132 | files=current_filepaths, 133 | hashers=comparison_hashers, 134 | isometric=True, 135 | progress=tqdm.tqdm 136 | ) 137 | duplicate_pairs.extend(current_duplicate_pairs) 138 | 139 | # Now we can do whatever we want with the duplicates. We could just delete 140 | # the first entry in each pair or manually verify the pairs to ensure they 141 | # are, in fact duplicates. 142 | 143 | Video deduplication 144 | =================== 145 | 146 | Video deduplication requires more thought depending on your tolerance for false positives and 147 | how important temporal relationships are. Below is one example approach for deduplicating a 148 | group of videos by taking frames from each video that are sufficiently different from each other 149 | (to avoid keeping too many) and then using them all to find 150 | pairs of videos that have matching frames. 151 | 152 | .. code-block:: python 153 | 154 | import urllib.request 155 | import zipfile 156 | 157 | import glob 158 | import tqdm 159 | 160 | import perception.hashers 161 | 162 | # Download some example videos. 163 | urllib.request.urlretrieve( 164 | "https://thorn-perception.s3.amazonaws.com/thorn-perceptual-video-deduplication-example.zip", 165 | "thorn-perceptual-video-deduplication-example.zip" 166 | ) 167 | 168 | with zipfile.ZipFile('thorn-perceptual-video-deduplication-example.zip') as f: 169 | f.extractall('.') 170 | 171 | frame_hasher = hashers.PHash(hash_size=16) 172 | 173 | hasher = perception.hashers.FramewiseHasher(frames_per_second=1, 174 | frame_hasher=frame_hasher, 175 | interframe_threshold=50, 176 | quality_threshold=90) 177 | 178 | # Set a threshold for matching frames within videos and across videos. 179 | filepaths = glob.glob('thorn-perceptual-video-deduplication-example/*.m4v') + \ 180 | glob.glob('thorn-perceptual-video-deduplication-example/*.gif') 181 | 182 | # Returns a list of dicts with a "filepath" and "hash" key. "hash" contains a 183 | # list of hashes. 184 | hashes = hasher.compute_parallel(filepaths=filepaths, progress=tqdm.tqdm) 185 | 186 | 187 | # Flatten the hashes into a list of (filepath, hash) tuples. 
188 | hashes_flattened = perception.tools.flatten([ 189 | [(hash_group['filepath'], hash_string) for hash_string in hash_group['hash']] 190 | for hash_group in hashes 191 | ]) 192 | 193 | duplicates = perception.tools.deduplicate_hashes( 194 | hashes=hashes_flattened, 195 | threshold=50, 196 | hasher=hasher 197 | ) -------------------------------------------------------------------------------- /perception/benchmarking/image.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import uuid 4 | import warnings 5 | 6 | import cv2 7 | import imgaug 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | from ..hashers import tools 12 | from ..hashers.hasher import ImageHasher 13 | from ..tools import deduplicate, flatten 14 | from .common import BenchmarkDataset, BenchmarkHashes, BenchmarkTransforms 15 | 16 | log = logging.getLogger(__name__) 17 | 18 | 19 | class BenchmarkImageTransforms(BenchmarkTransforms): 20 | def compute_hashes( 21 | self, hashers: dict[str, ImageHasher], max_workers: int = 5 22 | ) -> BenchmarkHashes: 23 | """Compute hashes for a series of files given some set of hashers. 24 | 25 | Args: 26 | hashers: A dictionary of hashers. 27 | max_workers: Maximum number of workers for parallel hash 28 | computation. 29 | 30 | Returns: 31 | metrics: A BenchmarkHashes object. 32 | """ 33 | hashsets = [] 34 | filepaths = self._df["filepath"] 35 | for hasher_name, hasher in hashers.items(): 36 | hash_dicts = hasher.compute_parallel( 37 | filepaths, 38 | progress=tqdm, 39 | progress_desc=f"Computing hashes for {hasher_name}", 40 | max_workers=max_workers, 41 | ) 42 | if not hasher.returns_multiple: 43 | hashes_df = pd.DataFrame.from_records(hash_dicts) 44 | else: 45 | hash_groups = [ 46 | hash_dict["hash"] if hash_dict["error"] is None else [None] 47 | for hash_dict in hash_dicts 48 | ] 49 | hash_group_sizes = [len(hash_group) for hash_group in hash_groups] 50 | current_hashes = flatten(hash_groups) 51 | current_filepaths = flatten( 52 | [ 53 | [hash_dict["filepath"]] * hash_group_size 54 | for hash_dict, hash_group_size in zip( 55 | hash_dicts, hash_group_sizes 56 | ) 57 | ] 58 | ) 59 | current_errors = flatten( 60 | [ 61 | [hash_dict["error"]] * hash_group_size 62 | for hash_dict, hash_group_size in zip( 63 | hash_dicts, hash_group_sizes 64 | ) 65 | ] 66 | ) 67 | hashes_df = pd.DataFrame( 68 | { 69 | "error": current_errors, 70 | "filepath": current_filepaths, 71 | "hash": current_hashes, 72 | } 73 | ) 74 | hashset = hashes_df.assign( 75 | hasher_name=hasher_name, 76 | hasher_hash_length=hasher.hash_length, 77 | hasher_dtype=hasher.dtype, 78 | hasher_distance_metric=hasher.distance_metric, 79 | ) 80 | hashset = hashset.merge(self._df, on="filepath") 81 | hashsets.append(hashset) 82 | return BenchmarkHashes(pd.concat(hashsets, sort=True)) 83 | 84 | 85 | class BenchmarkImageDataset(BenchmarkDataset): 86 | def deduplicate( 87 | self, hasher: ImageHasher, threshold=0.001, isometric=False 88 | ) -> tuple["BenchmarkImageDataset", set[tuple[str, str]]]: 89 | """Remove duplicate files from dataset. 90 | 91 | Args: 92 | files: A list of file paths 93 | hasher: A hasher to use for finding a duplicate 94 | threshold: The threshold required for a match 95 | isometric: Whether to compute the rotated versions of the images 96 | 97 | Returns: 98 | A list where each entry is a list of files that are 99 | duplicates of each other. We keep only the last entry. 
100 | """ 101 | pairs: set[tuple[str, str]] = set() 102 | for _, group in tqdm( 103 | self._df.groupby(["category"]), desc="Deduplicating categories." 104 | ): 105 | pairs = pairs.union( 106 | set( 107 | deduplicate( 108 | files=group["filepath"].tolist(), 109 | hashers=[(hasher, threshold)], 110 | isometric=isometric, 111 | ) 112 | ) 113 | ) 114 | removed = [pair[0] for pair in pairs] 115 | return ( 116 | BenchmarkImageDataset(self._df[~self._df["filepath"].isin(removed)].copy()), 117 | pairs, 118 | ) 119 | 120 | def transform( 121 | self, 122 | transforms: dict[str, imgaug.augmenters.meta.Augmenter], 123 | storage_dir: str, 124 | errors: str = "raise", 125 | ) -> BenchmarkImageTransforms: 126 | """Prepare files to be used as part of benchmarking run. 127 | 128 | Args: 129 | transforms: A dictionary of transformations. The only required 130 | key is `noop` which determines how the original, untransformed 131 | image is saved. For a true copy, simply make the `noop` key 132 | `imgaug.augmenters.Noop()`. 133 | storage_dir: A directory to store all the images along with 134 | their transformed counterparts. 135 | errors: How to handle errors reading files. If "raise", exceptions are 136 | raised. If "warn", the error is printed as a warning. 137 | 138 | Returns: 139 | transforms: A BenchmarkImageTransforms object 140 | """ 141 | assert ( 142 | "noop" in transforms 143 | ), "You must provide a no-op transform such as `lambda img: img`." 144 | 145 | os.makedirs(storage_dir, exist_ok=True) 146 | 147 | files = self._df.copy() 148 | files["guid"] = [uuid.uuid4() for n in range(len(files))] 149 | 150 | def apply_transform(files, transform_name): 151 | transform = transforms[transform_name] 152 | transformed_arr = [] 153 | for _, row in tqdm( 154 | files.iterrows(), 155 | desc=f"Creating files for {transform_name}", 156 | total=len(files), 157 | ): 158 | filepath, guid, category = row[["filepath", "guid", "category"]] 159 | try: 160 | image = tools.read(filepath) 161 | except Exception as exception: 162 | message = f"An error occurred reading {filepath}." 163 | if errors == "raise": 164 | raise exception 165 | warnings.warn(message, UserWarning) 166 | continue 167 | try: 168 | transformed = transform(image=image) 169 | except Exception as e: 170 | raise RuntimeError( 171 | f"An exception occurred while processing {filepath} " 172 | f"with transform {transform_name}." 
173 | ) from e 174 | transformed_path = os.path.join( 175 | storage_dir, f"{guid}_{transform_name}.jpg" 176 | ) 177 | cv2.imwrite( 178 | transformed_path, cv2.cvtColor(transformed, cv2.COLOR_RGB2BGR) 179 | ) 180 | transformed_arr.append( 181 | { 182 | "guid": guid, 183 | "transform_name": transform_name, 184 | "input_filepath": filepath, 185 | "filepath": transformed_path, 186 | "category": category, 187 | } 188 | ) 189 | return pd.DataFrame.from_records(transformed_arr) 190 | 191 | results = [apply_transform(files, transform_name="noop")] 192 | 193 | for transform_name in transforms.keys(): 194 | if transform_name == "noop": 195 | continue 196 | results.append(apply_transform(results[0], transform_name=transform_name)) 197 | benchmark_transforms = BenchmarkImageTransforms( 198 | df=pd.concat(results, axis=0, ignore_index=True) 199 | ) 200 | benchmark_transforms.save(storage_dir) 201 | return benchmark_transforms 202 | -------------------------------------------------------------------------------- /perception/benchmarking/video_transforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import ffmpeg 5 | 6 | from ..hashers.tools import read_video 7 | 8 | 9 | def probe(filepath): 10 | """Get the output of ffprobe.""" 11 | return ffmpeg.probe(filepath) 12 | 13 | 14 | def sanitize_output_filepath(input_filepath, output_filepath, output_ext=None): 15 | """Get a suitable output filepath with an extension based on 16 | an input filepath. 17 | 18 | Args: 19 | input_filepath: The filepath for the source file. 20 | output_filepath: The filepath for the output file. 21 | output_ext: A new extension to add (e.g., '.gif') 22 | """ 23 | _, input_ext = os.path.splitext(input_filepath) 24 | if not output_filepath.lower().endswith(output_ext or input_ext): 25 | output_filepath += output_ext or input_ext 26 | return output_filepath 27 | 28 | 29 | def get_simple_transform( 30 | width: str | int = -1, 31 | height: str | int = -1, 32 | pad: str | None = None, 33 | codec: str | None = None, 34 | clip_pct: tuple[float, float] | None = None, 35 | clip_s: tuple[float, float] | None = None, 36 | sar=None, 37 | fps=None, 38 | output_ext=None, 39 | ): 40 | """Resize to a specific size and re-encode. 41 | 42 | Args: 43 | width: The target width (-1 to maintain aspect ratio) 44 | height: The target height (-1 to maintain aspect ratio) 45 | pad: An ffmpeg pad argument provided as a string. 46 | codec: The codec for encoding the video. 47 | fps: The new frame rate for the video. 48 | clip_pct: The video start and end in percentages of video duration. 49 | clip_s: The video start and end in seconds (used over clip_pct if both 50 | are provided). 51 | sar: Whether to make all videos have a common sample aspect 52 | ratio (i.e., for all square pixels, set this to '1/1'). 53 | output_ext: The extension to use when re-encoding (used to select 54 | video format). It should include the leading '.'. 
55 | """ 56 | 57 | def transform(input_filepath, output_filepath): 58 | output_filepath = sanitize_output_filepath( 59 | input_filepath, output_filepath, output_ext 60 | ) 61 | data = None 62 | if codec is None: 63 | data = data or probe(input_filepath) 64 | output_codec = [s for s in data["streams"] if s["codec_type"] == "video"][ 65 | 0 66 | ]["codec_name"] 67 | else: 68 | output_codec = codec 69 | format_kwargs = {"codec:v": output_codec} 70 | if clip_pct is not None or clip_s is not None: 71 | pct_start, pct_end, pos_start, pos_end = None, None, None, None 72 | if clip_pct is not None: 73 | pct_start, pct_end = clip_pct 74 | if clip_s is not None: 75 | pos_start, pos_end = clip_s 76 | if pct_start is not None: 77 | assert 0 <= pct_start <= 1, "Start position must be between 0 and 1." 78 | if pct_end is not None: 79 | assert 0 <= pct_end <= 1, "End position must be between 0 and 1." 80 | if pct_start is not None and pct_end is not None: 81 | assert pct_start < pct_end, "End must be greater than start." 82 | if (pct_start is not None and pos_start is None) or ( 83 | pct_end is not None and pos_end is None 84 | ): 85 | # We only want to get the duration for the video if we need 86 | # it. 87 | data = data or probe(input_filepath) 88 | duration = float(data["streams"][0]["duration"]) 89 | if pct_start is not None or pos_start is not None: 90 | format_kwargs["ss"] = pos_start or pct_start * duration # type: ignore 91 | if pct_end is not None or pos_end is not None: 92 | format_kwargs["t"] = pos_end or pct_end * duration # type: ignore 93 | stream = ffmpeg.input(input_filepath) 94 | if not (width == -1 and height == -1): 95 | stream = stream.filter("scale", width, height) 96 | if pad is not None: 97 | stream = stream.filter("pad", *pad.split(":")) 98 | if fps is not None: 99 | stream = stream.filter("fps", fps) 100 | if sar is not None: 101 | stream = stream.filter("setsar", sar) 102 | stream = stream.output(output_filepath, **format_kwargs).overwrite_output() 103 | ffmpeg.run(stream) 104 | if os.path.isfile(output_filepath): 105 | return output_filepath 106 | return None 107 | 108 | return transform 109 | 110 | 111 | def get_slideshow_transform( 112 | frame_input_rate, frame_output_rate, max_frames=None, offset=0 113 | ): 114 | """Get a slideshow transform to create slideshows from 115 | videos. 116 | 117 | Args: 118 | frame_input_rate: The rate at which frames will be sampled 119 | from the source video (e.g., a rate of 1 means we collect 120 | one frame per second of the input video). 121 | frame_output_rate: The rate at which the sampled frames are played 122 | in the slideshow (e.g., a rate of 0.5 means each frame will 123 | appear for 2 seconds). 124 | max_frames: The maximum number of frames to write. 125 | offset: The number of seconds to wait before beginning the slide show. 
126 | """ 127 | 128 | def transform(input_filepath, output_filepath): 129 | output_filepath = sanitize_output_filepath( 130 | input_filepath, output_filepath, output_ext=".avi" 131 | ) 132 | writer = None 133 | frame_count = 0 134 | try: 135 | for frame, _, timestamp in read_video( 136 | filepath=input_filepath, frames_per_second=frame_input_rate 137 | ): 138 | if timestamp < offset: 139 | continue 140 | if writer is None: 141 | writer = cv2.VideoWriter( 142 | filename=output_filepath, 143 | fourcc=cv2.VideoWriter_fourcc(*"MJPG"), # type: ignore[attr-defined] 144 | fps=frame_output_rate, 145 | frameSize=tuple(frame.shape[:2][::-1]), 146 | isColor=True, 147 | ) 148 | writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) 149 | frame_count += 1 150 | if max_frames is not None and frame_count >= max_frames: 151 | break 152 | finally: 153 | if writer is not None: 154 | writer.release() 155 | if os.path.isfile(output_filepath): 156 | return output_filepath 157 | return None 158 | 159 | return transform 160 | 161 | 162 | def get_black_frame_padding_transform(duration_s=0, duration_pct=0): 163 | """Get a transform that adds black frames at the start and end 164 | of a video. 165 | 166 | Args: 167 | duration_s: The duration of the black frames in seconds. 168 | duration_pct: The duration of the black frames 169 | as a percentage of video duration. If both duration_s 170 | and duration_pct are provided, the maximum value 171 | is used. 172 | """ 173 | 174 | def transform(input_filepath, output_filepath): 175 | output_filepath = sanitize_output_filepath(input_filepath, output_filepath) 176 | stream = next( 177 | stream 178 | for stream in probe(input_filepath)["streams"] 179 | if stream["codec_type"] == "video" 180 | ) 181 | assert stream["sample_aspect_ratio"] == "1:1", "SAR is not 1:1." 182 | width = stream["width"] 183 | height = stream["height"] 184 | duration = max(duration_s, duration_pct * float(stream["duration"])) 185 | ffmpeg.input(input_filepath).output( 186 | output_filepath, 187 | vf=( 188 | "color=c=black:s={width}x{height}:d={duration} [pre] ; " 189 | "color=c=black:s={width}x{height}:d={duration} [post] ; " 190 | "[pre] [in] [post] concat=n=3" 191 | ).format(width=width, height=height, duration=duration), 192 | fps_mode="vfr", 193 | ).overwrite_output().run() 194 | if os.path.isfile(output_filepath): 195 | return output_filepath 196 | return None 197 | 198 | return transform 199 | -------------------------------------------------------------------------------- /perception/hashers/video/tmk.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import warnings 3 | 4 | import numpy as np 5 | import scipy.special 6 | 7 | from ..hasher import ImageHasher, VideoHasher 8 | from ..image.phash import PHashF 9 | 10 | 11 | class TMKL2(VideoHasher): 12 | """The TMK L2 video hashing algorithm.""" 13 | 14 | dtype = "float32" 15 | distance_metric = "custom" 16 | 17 | def __init__( 18 | self, 19 | frame_hasher: ImageHasher | None = None, 20 | frames_per_second: int = 15, 21 | normalization: str = "matrix", 22 | ): 23 | if platform.machine() == "arm64": 24 | warnings.warn("TMK is not supported on ARM64") 25 | 26 | T = np.array([2731, 4391, 9767, 14653]).astype("float32") 27 | m = 32 28 | if frame_hasher is None: 29 | frame_hasher = PHashF(hash_size=16, exclude_first_term=True, freq_shift=1) 30 | self.frames_per_second = frames_per_second 31 | assert frame_hasher.dtype != "bool", "This hasher requires real valued hashes." 
32 | 33 | # Beta parameter of the modified Bessel function of the first kind 34 | self.beta = 32 35 | 36 | # Number of Fourier coefficients per period 37 | self.m = m 38 | 39 | # The periods with shape (T, ) 40 | self.T = T # (T) 41 | 42 | # The Fourier coefficients with shape (T, m, 1) 43 | self.ms = 2 * np.pi * np.arange(0, self.m).astype("float32") # (m) 44 | self.ms_normed = (self.ms[np.newaxis,] / self.T.reshape(-1, 1)).reshape( 45 | len(self.T), self.m, 1 46 | ) # (T, m, 1) 47 | 48 | # The weights with shape (T, 2m, 1) 49 | a = np.array( 50 | [ 51 | (scipy.special.iv(0, self.beta) - np.exp(-self.beta)) 52 | / (2 * np.sinh(self.beta)) 53 | ] 54 | + [ 55 | scipy.special.iv(i, self.beta) / np.sinh(self.beta) 56 | for i in range(1, self.m) 57 | ] 58 | ) 59 | a = a.reshape(1, -1).repeat(repeats=len(self.T), axis=0) 60 | a = np.sqrt(a) 61 | self.a = a[..., np.newaxis] 62 | 63 | # The frame-wise hasher 64 | self.frame_hasher = frame_hasher 65 | 66 | self.hash_length = self.T.shape[0] * 2 * self.m * self.frame_hasher.hash_length 67 | 68 | self.normalization = normalization 69 | 70 | def process_frame(self, frame, frame_index, frame_timestamp, state=None): 71 | if state is None: 72 | state = {"features": [], "timestamps": []} 73 | state["features"].append(self.frame_hasher.compute(frame, hash_format="vector")) 74 | state["timestamps"].append(frame_timestamp) 75 | return state 76 | 77 | def hash_from_final_state(self, state): 78 | timestamps = np.array(state["timestamps"]) 79 | features = np.array(state["features"]).reshape( 80 | (1, 1, timestamps.shape[0], self.frame_hasher.hash_length) 81 | ) 82 | x = self.ms_normed * timestamps 83 | yw1 = np.sin(x) * self.a 84 | yw2 = np.cos(x) * self.a 85 | yw = np.concatenate([yw1, yw2], axis=1)[..., np.newaxis] # (T, 2m, t, 1) 86 | y = (yw * features).sum(axis=2) # (T, 2m, d) 87 | return y.flatten() 88 | 89 | def _compute_distance(self, vector1, vector2): 90 | shape = (len(self.T), 2 * self.m, self.frame_hasher.hash_length) 91 | return 1 - self._score_pair( 92 | fv_a=vector1.reshape(shape), 93 | fv_b=vector2.reshape(shape), 94 | offsets=None, 95 | normalization=self.normalization, 96 | ) 97 | 98 | def _score_pair(self, fv_a, fv_b, offsets=None, normalization="matrix"): 99 | eps = 1e-8 100 | 101 | if offsets is None: 102 | offsets = np.array([0]) 103 | 104 | assert normalization in [ 105 | "feat", 106 | "freq", 107 | "feat_freq", 108 | "matrix", 109 | ], "Invalid normalization" 110 | 111 | if "feat" in normalization: 112 | a_xp = np.concatenate([self.a, self.a], axis=1) # (T, 2m, 1) 113 | fv_a_0 = fv_a / a_xp 114 | fv_b_0 = fv_b / a_xp 115 | norm_a = np.sqrt(np.sum(fv_a_0**2, axis=2, keepdims=True) + eps) + eps 116 | norm_b = np.sqrt(np.sum(fv_b_0**2, axis=2, keepdims=True) + eps) + eps 117 | fv_a = fv_a / norm_a 118 | fv_b = fv_b / norm_b 119 | 120 | if "freq" in normalization: 121 | norm_a, norm_b = ( 122 | np.sqrt((fv**2).sum(axis=1, keepdims=True) / self.m + eps) + eps 123 | for fv in [fv_a, fv_b] 124 | ) 125 | fv_a = fv_a / norm_a 126 | fv_b = fv_b / norm_b 127 | 128 | if normalization == "matrix": 129 | norm_a, norm_b = ( 130 | np.sqrt(np.sum(fv**2, axis=(1, 2)) + eps)[..., np.newaxis] + eps 131 | for fv in [fv_a, fv_b] 132 | ) # (T, 1) 133 | 134 | fv_a_sin, fv_b_sin = (fv[:, : self.m] for fv in [fv_a, fv_b]) # (T, m, d) 135 | fv_a_cos, fv_b_cos = (fv[:, self.m :] for fv in [fv_a, fv_b]) # (T, m, d) 136 | ms = self.ms.reshape(-1, 1) # (m, 1) 137 | dot_sin_sin, dot_sin_cos, dot_cos_cos, dot_cos_sin = ( 138 | np.sum(p, axis=2, keepdims=True) 139 | for 
p in [ 140 | fv_a_sin * fv_b_sin, 141 | fv_a_sin * fv_b_cos, 142 | fv_a_cos * fv_b_cos, 143 | fv_a_cos * fv_b_sin, 144 | ] 145 | ) # (T, m, 1) 146 | delta = ( 147 | ms.reshape(1, -1, 1) * offsets.reshape(1, -1) / self.T.reshape((-1, 1, 1)) 148 | ) 149 | cos_delta = np.cos(delta) # (T, m, delta) 150 | sin_delta = np.sin(delta) # (T, m, delta) 151 | dots = ( 152 | dot_sin_sin * cos_delta 153 | + dot_sin_cos * sin_delta 154 | + dot_cos_cos * cos_delta 155 | - dot_cos_sin * sin_delta 156 | ).sum(axis=1) 157 | if normalization == "matrix": 158 | dots = dots / (norm_a * norm_b) 159 | if normalization == "freq": 160 | dots = dots / self.m # (T, m, delta) 161 | elif normalization in ["feat", "feat_freq"]: 162 | dots = dots / 512 163 | return dots.mean(axis=0) 164 | 165 | 166 | class TMKL1(VideoHasher): 167 | """The TMK L1 video hashing algorithm.""" 168 | 169 | def __init__( 170 | self, 171 | frame_hasher: ImageHasher | None = None, 172 | frames_per_second: int = 15, 173 | dtype="float32", 174 | distance_metric="cosine", 175 | norm=2, 176 | quality_threshold=None, 177 | ): 178 | if frame_hasher is None: 179 | frame_hasher = PHashF(hash_size=16, exclude_first_term=True, freq_shift=1) 180 | self.hash_length = frame_hasher.hash_length 181 | self.frames_per_second = frames_per_second 182 | assert frame_hasher.dtype != "bool", "This hasher requires real valued hashes." 183 | self.frame_hasher = frame_hasher 184 | self.norm = norm 185 | self.dtype = dtype or self.frame_hasher.dtype 186 | self.distance_metric = distance_metric or self.frame_hasher.distance_metric 187 | self.quality_threshold = quality_threshold 188 | 189 | def process_frame(self, frame, frame_index, frame_timestamp, state=None): 190 | if state is None: 191 | state = {"sum": np.zeros(self.frame_hasher.hash_length), "frame_count": 0} 192 | if self.quality_threshold is None: 193 | hash_vector = self.frame_hasher.compute(frame, hash_format="vector") 194 | else: 195 | hash_vector, quality = self.frame_hasher.compute_with_quality( 196 | frame, hash_format="vector" 197 | ) 198 | if quality < self.quality_threshold: 199 | return state 200 | assert isinstance(hash_vector, np.ndarray) # help type checking below 201 | if hash_vector is not None: 202 | state["sum"] += hash_vector.astype(np.float32) 203 | state["frame_count"] += 1 204 | return state 205 | 206 | def hash_from_final_state(self, state): 207 | if state["frame_count"] == 0: 208 | return None 209 | average_vector = state["sum"] / state["frame_count"] 210 | if self.norm is not None: 211 | return ( 212 | average_vector / np.linalg.norm(average_vector, ord=self.norm) 213 | ).astype(self.frame_hasher.dtype) 214 | return average_vector.astype(self.frame_hasher.dtype) 215 | -------------------------------------------------------------------------------- /perception/benchmarking/video.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | import typing 4 | import uuid 5 | 6 | import pandas as pd 7 | import tqdm 8 | 9 | from ..hashers import VideoHasher, tools 10 | from ..tools import flatten 11 | from .common import BenchmarkDataset, BenchmarkHashes, BenchmarkTransforms 12 | 13 | 14 | def _process_row(row, hashers, framerates): 15 | error = None 16 | try: 17 | assert not pd.isnull(row["filepath"]), "No filepath provided." 
18 | hashes = tools.compute_synchronized_video_hashes( 19 | filepath=row["filepath"], 20 | hashers=hashers, 21 | framerates=framerates, 22 | hash_format="base64", 23 | ) 24 | except Exception as exception: 25 | error = str(exception) 26 | hashes = { 27 | hasher_name: [None] if hasher.returns_multiple else None 28 | for hasher_name, hasher in hashers.items() 29 | } 30 | base_dict = { 31 | "guid": row["guid"], 32 | "filepath": row["filepath"], 33 | "error": error, 34 | "category": row["category"], 35 | "transform_name": row["transform_name"], 36 | "input_filepath": row["input_filepath"], 37 | } 38 | hash_dicts = [] 39 | for hasher_name, hasher in hashers.items(): 40 | base_hash_dict = { 41 | "hasher_name": hasher_name, 42 | "hasher_dtype": hasher.dtype, 43 | "hasher_distance_metric": hasher.distance_metric, 44 | "hasher_hash_length": hasher.hash_length, 45 | } 46 | if not hasher.returns_multiple: 47 | hash_dicts.append( 48 | { 49 | **{ 50 | "hash": hashes[hasher_name], 51 | }, 52 | **base_hash_dict, 53 | } 54 | ) 55 | else: 56 | for hash_value in hashes[hasher_name]: 57 | hash_dicts.append( 58 | { 59 | **{ 60 | "hash": hash_value, 61 | }, 62 | **base_hash_dict, 63 | } 64 | ) 65 | return [{**hash_dict, **base_dict} for hash_dict in hash_dicts] 66 | 67 | 68 | class BenchmarkVideoDataset(BenchmarkDataset): 69 | def transform( 70 | self, 71 | transforms: dict[str, typing.Callable], 72 | storage_dir: str, 73 | errors: str = "raise", 74 | ): 75 | """Prepare files to be used as part of benchmarking run. 76 | 77 | Args: 78 | transforms: A dictionary of transformations. The only required 79 | key is `noop` which determines how the original, untransformed 80 | video is saved. Each transform should be a callable function with 81 | that accepts an `input_filepath` and `output_filepath` argument and 82 | it should return the `output_filepath` (which may have a different 83 | extension appended by the transform function). 84 | storage_dir: A directory to store all the videos along with 85 | their transformed counterparts. 86 | errors: How to handle errors reading files. If "raise", exceptions are 87 | raised. If "warn", the error is printed as a warning. 88 | 89 | Returns: 90 | transforms: A BenchmarkVideoTransforms object 91 | """ 92 | assert "noop" in transforms, "You must provide a no-op transform." 93 | 94 | os.makedirs(storage_dir, exist_ok=True) 95 | 96 | files = self._df.copy() 97 | files["guid"] = [uuid.uuid4() for n in range(len(files))] 98 | 99 | def apply_transform_to_file(input_filepath, guid, transform_name, category): 100 | if input_filepath is None: 101 | # This can happen if the noop transform did not yield 102 | # a file. We don't want to drop the records so we 103 | # keep them. 
104 | return { 105 | "guid": guid, 106 | "error": "No source file provided", 107 | "transform_name": transform_name, 108 | "input_filepath": input_filepath, 109 | "filepath": None, 110 | "category": category, 111 | } 112 | try: 113 | output_filepath = transforms[transform_name]( 114 | input_filepath, 115 | output_filepath=os.path.join( 116 | storage_dir, f"{guid}_{transform_name}" 117 | ), 118 | ) 119 | error = None 120 | except Exception as e: 121 | output_filepath = None 122 | error = str(e) 123 | return { 124 | "guid": guid, 125 | "error": error, 126 | "transform_name": transform_name, 127 | "input_filepath": input_filepath, 128 | "filepath": output_filepath, 129 | "category": category, 130 | } 131 | 132 | def apply_transform_to_files(files, transform_name): 133 | return pd.DataFrame.from_records( 134 | [ 135 | apply_transform_to_file( 136 | input_filepath=row["filepath"], 137 | guid=row["guid"], 138 | transform_name=transform_name, 139 | category=row["category"], 140 | ) 141 | for _, row in tqdm.tqdm( 142 | files.iterrows(), 143 | desc=f"Creating files for {transform_name}", 144 | total=len(files), 145 | ) 146 | ] 147 | ) 148 | 149 | results = [apply_transform_to_files(files, transform_name="noop")] 150 | for transform_name in transforms.keys(): 151 | if transform_name == "noop": 152 | continue 153 | results.append( 154 | apply_transform_to_files(results[0], transform_name=transform_name) 155 | ) 156 | benchmark_transforms = BenchmarkVideoTransforms( 157 | df=pd.concat(results, axis=0, ignore_index=True) 158 | ) 159 | benchmark_transforms.save(storage_dir) 160 | return benchmark_transforms 161 | 162 | 163 | class BenchmarkVideoTransforms(BenchmarkTransforms): 164 | expected_columns = [ 165 | "filepath", 166 | "category", 167 | "transform_name", 168 | "input_filepath", 169 | "guid", 170 | "error", 171 | ] 172 | 173 | def compute_hashes( 174 | self, hashers: dict[str, VideoHasher], max_workers: int = 5 175 | ) -> BenchmarkHashes: 176 | """Compute hashes for a series of files given some set of hashers. 177 | 178 | Args: 179 | hashers: A dictionary of hashers. 180 | max_workers: Maximum number of workers for parallel hash 181 | computation. 182 | 183 | Returns: 184 | hashes: A BenchmarkHashes object. 
185 | """ 186 | id_rates = { 187 | hasher_name: hasher.frames_per_second 188 | for hasher_name, hasher in hashers.items() 189 | if hasher.frames_per_second is not None 190 | } 191 | if id_rates: 192 | framerates = tools.get_common_framerates( 193 | { 194 | hasher_name: hasher.frames_per_second 195 | for hasher_name, hasher in hashers.items() 196 | if hasher.frames_per_second is not None 197 | } 198 | ) 199 | else: 200 | framerates = {} 201 | 202 | with concurrent.futures.ProcessPoolExecutor( 203 | max_workers=max_workers 204 | ) as executor: 205 | futures = [ 206 | executor.submit( 207 | _process_row, row=row, framerates=framerates, hashers=hashers 208 | ) 209 | for index, row in self._df.iterrows() 210 | ] 211 | return BenchmarkHashes( 212 | pd.DataFrame.from_records( 213 | flatten( 214 | [ 215 | future.result() 216 | for future in tqdm.tqdm( 217 | concurrent.futures.as_completed(futures), 218 | desc="Computing hashes.", 219 | total=len(self._df), 220 | ) 221 | ] 222 | ) 223 | ) 224 | ) 225 | -------------------------------------------------------------------------------- /perception/approximate_deduplication/debug.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | import perception.local_descriptor_deduplication as ldd 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | # Set a fixed size for drawing, we don't have the real descriptor size. 12 | KEYPOINT_SIZE: int = 8 13 | 14 | 15 | def vizualize_pair( 16 | features_1, 17 | features_2, 18 | ratio: float, 19 | match_metadata=None, 20 | local_path_col: str | None = None, 21 | sanitized: bool = False, 22 | include_all_points=False, 23 | circle_size=KEYPOINT_SIZE, 24 | ): 25 | """Given two rows from a reference df vizualize their overlap. 26 | 27 | Currently recalcs overlap using cv2 default logic. 28 | 29 | Args: 30 | features_1: The row from a reference df for one image. 31 | features_2: The row from a reference df for the other image. 32 | ratio: Value for ratio test, suggest re-using value from matching. 33 | match_metadata: metadata returned from matching, if None will redo brute force matching. 34 | local_path_col: column in df with path to the image. If None will 35 | use the index: features_1.name and features_2.name 36 | sanitized: if True images themselves will not be rendered, only the points. 37 | include_all_points: if True will draw all points, not just matched points. 38 | circle_size: size of the circle to draw around keypoints. 39 | Returns: 40 | An image of the two images concatted together and matching keypoints drawn. 41 | """ 42 | # Set a fixed size for drawing, we don't have the real descriptor size. 
43 | if local_path_col is not None: 44 | features_1_path = features_1[local_path_col] 45 | features_2_path = features_2[local_path_col] 46 | else: 47 | features_1_path = features_1.name 48 | features_2_path = features_2.name 49 | 50 | img1 = np.zeros( 51 | (features_1.dimensions[1], features_1.dimensions[0], 1), dtype="uint8" 52 | ) 53 | img2 = np.zeros( 54 | (features_2.dimensions[1], features_2.dimensions[0], 1), dtype="uint8" 55 | ) 56 | 57 | if not sanitized: 58 | try: 59 | img1 = ldd.load_and_preprocess( 60 | features_1_path, max_size=max(features_1.dimensions), grayscale=False 61 | ) 62 | except Exception: 63 | LOGGER.warning("Failed to load image %s", features_1_path) 64 | try: 65 | img2 = ldd.load_and_preprocess( 66 | features_2_path, max_size=max(features_2.dimensions), grayscale=False 67 | ) 68 | except Exception: 69 | LOGGER.warning("Failed to load image %s", features_2_path) 70 | 71 | if match_metadata is not None: 72 | img_matched = viz_match_data( 73 | features_1, 74 | features_2, 75 | img1, 76 | img2, 77 | match_metadata, 78 | include_all_points=include_all_points, 79 | circle_size=circle_size, 80 | ) 81 | else: 82 | LOGGER.warning( 83 | """No match_metadata provided, recalculating match points, 84 | won't match perception match points.""" 85 | ) 86 | img_matched = viz_brute_force(features_1, features_2, img1, img2, ratio=ratio) 87 | 88 | return img_matched 89 | 90 | 91 | def viz_match_data( 92 | features_1, 93 | features_2, 94 | img1, 95 | img2, 96 | match_metadata, 97 | include_all_points=False, 98 | circle_size=KEYPOINT_SIZE, 99 | ): 100 | """Given match data viz matching points. 101 | 102 | Args: 103 | features_1: The row from a reference df for one image. 104 | features_2: The row from a reference df for the other image. 105 | img1: cv2 of first image 106 | img2: cv2 of second image 107 | match_metadata: metadata returned from matching, if None will redo 108 | brute force matching. 109 | include_all_points: if True will draw all points, not just matched points. 110 | circle_size: size of the circle to draw around keypoints. 111 | Returns: 112 | cv2 img with matching keypoints drawn. 113 | """ 114 | # NOTE: could refactor to put matches in to correct format and use: cv2.drawMatchesKnn, 115 | # but python docs on necessary class not clear. 
116 | 117 | # Pad img1 or img2 vertically with black pixels to match the height of the other image 118 | if img1.shape[0] > img2.shape[0]: 119 | img2 = np.pad( 120 | img2, 121 | ((0, img1.shape[0] - img2.shape[0]), (0, 0), (0, 0)), 122 | mode="constant", 123 | constant_values=0, 124 | ) 125 | elif img1.shape[0] < img2.shape[0]: 126 | img1 = np.pad( 127 | img1, 128 | ((0, img2.shape[0] - img1.shape[0]), (0, 0), (0, 0)), 129 | mode="constant", 130 | constant_values=0, 131 | ) 132 | # draw two images h concat: 133 | img_matched = np.concatenate((img1, img2), axis=1) 134 | 135 | overlay = img_matched.copy() 136 | 137 | if include_all_points: 138 | # draw all points in kp_1 139 | for k in features_1["keypoints"]: 140 | new_color = ( 141 | random.randint(0, 255), 142 | random.randint(0, 255), 143 | random.randint(0, 255), 144 | ) 145 | # Draw semi transparent circle 146 | cv2.circle(img_matched, (int(k[0]), int(k[1])), circle_size, new_color, 1) 147 | 148 | # draw all points in kp_2 149 | for k in features_2["keypoints"]: 150 | new_color = ( 151 | random.randint(0, 255), 152 | random.randint(0, 255), 153 | random.randint(0, 255), 154 | ) 155 | cv2.circle( 156 | img_matched, 157 | (int(k[0] + features_1.dimensions[0]), int(k[1])), 158 | circle_size, 159 | new_color, 160 | 1, 161 | ) 162 | 163 | # draw lines between matching points 164 | for i in range(len(match_metadata["final_matched_b_pts"])): 165 | new_color = ( 166 | random.randint(0, 255), 167 | random.randint(0, 255), 168 | random.randint(0, 255), 169 | ) 170 | a_pt = ( 171 | int(match_metadata["final_matched_a_pts"][i][0]), 172 | int(match_metadata["final_matched_a_pts"][i][1]), 173 | ) 174 | b_pt = ( 175 | int(match_metadata["final_matched_b_pts"][i][0] + features_1.dimensions[0]), 176 | int(match_metadata["final_matched_b_pts"][i][1]), 177 | ) 178 | cv2.circle(img_matched, a_pt, circle_size, new_color, 1) 179 | cv2.circle(img_matched, b_pt, circle_size, new_color, 1) 180 | cv2.line( 181 | img_matched, 182 | a_pt, 183 | b_pt, 184 | new_color, 185 | 1, 186 | ) 187 | 188 | # Re-overlay original image to add some transparency effect to lines and circles. 189 | alpha = 0.4 # Transparency factor. 190 | # Following line overlays transparent rectangle over the image 191 | img_matched = cv2.addWeighted(overlay, alpha, img_matched, 1 - alpha, 0) 192 | 193 | return img_matched 194 | 195 | 196 | def viz_brute_force(features_1, features_2, img1, img2, ratio: float): 197 | """ 198 | Given two rows from a reference df vizualize their overlap. 199 | 200 | NOTE: It redoes matching using cv2 bruteforce, so will not match the same 201 | as the perception matching code. 202 | 203 | Args: 204 | features_1: The row from a reference df for one image. 205 | features_2: The row from a reference df for the other image. 206 | img1: cv2 of first image 207 | img2: cv2 of second image 208 | ratio: Value for ratio test, suggest re-using value from matching. 209 | 210 | Returns: 211 | An image of the two images concatted together and matching keypoints drawn. 
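    Example:
        An illustrative sketch only; it assumes ``reference_df`` was built with
        ``perception.local_descriptor_deduplication.build_reference_df`` and it
        reuses the same loading helper as ``vizualize_pair`` above:

            import perception.local_descriptor_deduplication as ldd

            features_1 = reference_df.iloc[0]
            features_2 = reference_df.iloc[1]
            img1 = ldd.load_and_preprocess(
                features_1.name, max_size=max(features_1.dimensions), grayscale=False
            )
            img2 = ldd.load_and_preprocess(
                features_2.name, max_size=max(features_2.dimensions), grayscale=False
            )
            img_matched = viz_brute_force(features_1, features_2, img1, img2, ratio=0.5)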
212 | """ 213 | # Convert numpy keypoints to cv2.KeyPoints 214 | kp1_fixed = [] 215 | for k in features_1["keypoints"]: 216 | kp1_fixed.append(cv2.KeyPoint(k[0], k[1], KEYPOINT_SIZE)) 217 | 218 | kp2_fixed = [] 219 | for k in features_2["keypoints"]: 220 | kp2_fixed.append(cv2.KeyPoint(k[0], k[1], KEYPOINT_SIZE)) 221 | brute_force_matcher = cv2.BFMatcher() 222 | kn_matches = brute_force_matcher.knnMatch( 223 | features_1["descriptors"], features_2["descriptors"], k=2 224 | ) 225 | # Apply ratio test 226 | good = [] 227 | for nearest_match, next_nearest_match in kn_matches: 228 | if nearest_match.distance < ratio * next_nearest_match.distance: 229 | good.append([nearest_match]) 230 | img_matched = cv2.drawMatchesKnn( # type: ignore[call-overload] 231 | img1, 232 | kp1_fixed, 233 | img2, 234 | kp2_fixed, 235 | good, 236 | None, 237 | flags=cv2.DrawMatchesFlags_DRAW_RICH_KEYPOINTS, 238 | ) 239 | return img_matched 240 | -------------------------------------------------------------------------------- /perception/testing/__init__.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import math 3 | import typing 4 | from contextlib import ExitStack 5 | from importlib import resources 6 | 7 | import cv2 8 | import numpy as np 9 | import pandas as pd 10 | import pytest 11 | from PIL import Image 12 | 13 | from .. import hashers, tools 14 | 15 | SIZES = {"float32": 32, "uint8": 8, "bool": 1} 16 | 17 | 18 | def get_low_detail_image(): 19 | v = np.arange(0, 50, 1) 20 | v = np.concatenate([v, v[::-1]])[np.newaxis,] 21 | image = np.matmul(v.T, v) 22 | image = (image * 255 / image.max()).astype("uint8") 23 | image = image[..., np.newaxis].repeat(repeats=3, axis=2) 24 | image[:, 50:] = 0 25 | image[50:] = 0 26 | return image 27 | 28 | 29 | LOW_DETAIL_IMAGE = get_low_detail_image() 30 | 31 | file_manager = ExitStack() 32 | atexit.register(file_manager.close) 33 | 34 | DEFAULT_TEST_IMAGES = [ 35 | str( 36 | file_manager.enter_context( 37 | resources.as_file( 38 | resources.files("perception") / "testing" / "images" / f"image{n}.jpg" 39 | ) 40 | ) 41 | ) 42 | for n in range(1, 11) 43 | ] 44 | DEFAULT_TEST_LOGOS = [ 45 | str( 46 | file_manager.enter_context( 47 | resources.as_file( 48 | resources.files("perception") / "testing" / "logos" / "logoipsum.png" 49 | ) 50 | ) 51 | ) 52 | ] 53 | DEFAULT_TEST_VIDEOS = [ 54 | str( 55 | file_manager.enter_context( 56 | resources.as_file( 57 | resources.files("perception") / "testing" / "videos" / f"v{n}.m4v" 58 | ) 59 | ) 60 | ) 61 | for n in range(1, 3) 62 | ] + [ 63 | str( 64 | file_manager.enter_context( 65 | resources.as_file( 66 | resources.files("perception") / "testing" / "videos" / "v2s.mov" 67 | ) 68 | ) 69 | ) 70 | ] 71 | 72 | 73 | @typing.no_type_check 74 | def test_opencv_hasher(hasher: hashers.ImageHasher, image1: str, image2: str): 75 | # For OpenCV hashers we make sure the distance we compute 76 | # is the same as inside OpenCV 77 | f1 = image1 78 | f2 = image2 79 | opencv_distance = hasher.hasher.compare( 80 | hasher.hasher.compute(hashers.tools.read(f1)), 81 | hasher.hasher.compute(hashers.tools.read(f2)), 82 | ) 83 | if hasher.distance_metric == "hamming": 84 | opencv_distance /= hasher.hash_length 85 | np.testing.assert_approx_equal( 86 | opencv_distance, 87 | hasher.compute_distance(hasher.compute(f1), hasher.compute(f2)), 88 | significant=4, 89 | ) 90 | 91 | 92 | def hash_dicts_to_df(hash_dicts, returns_multiple): 93 | assert all( 94 | h["error"] is None for h in hash_dicts 95 | ), "An error was 
found in the hash dictionaries" 96 | if returns_multiple: 97 | return pd.DataFrame( 98 | { 99 | "filepath": tools.flatten( 100 | [[h["filepath"]] * len(h["hash"]) for h in hash_dicts] 101 | ), 102 | "hash": tools.flatten([h["hash"] for h in hash_dicts]), 103 | } 104 | ).assign(error=None) 105 | return pd.DataFrame.from_records(hash_dicts).assign(error=None) 106 | 107 | 108 | def test_hasher_parallelization(hasher, test_filepaths): 109 | filepaths_10x = test_filepaths * 10 110 | if not hasher.allow_parallel: 111 | with pytest.warns(UserWarning, match="cannot be used in parallel"): 112 | hashes_parallel_dicts = hasher.compute_parallel(filepaths=filepaths_10x) 113 | else: 114 | hashes_parallel_dicts = hasher.compute_parallel(filepaths=filepaths_10x) 115 | hashes_sequential_dicts = [ 116 | {"filepath": filepath, "hash": hasher.compute(filepath), "error": None} 117 | for filepath in filepaths_10x 118 | ] 119 | hashes_parallel = hash_dicts_to_df( 120 | hashes_parallel_dicts, returns_multiple=hasher.returns_multiple 121 | ).sort_values(["filepath", "hash"]) 122 | hashes_sequential = hash_dicts_to_df( 123 | hashes_sequential_dicts, returns_multiple=hasher.returns_multiple 124 | ).sort_values(["filepath", "hash"]) 125 | assert (hashes_sequential.hash.values == hashes_parallel.hash.values).all() 126 | assert (hashes_sequential.filepath.values == hashes_parallel.filepath.values).all() 127 | 128 | 129 | def test_video_hasher_integrity( 130 | hasher: hashers.VideoHasher, test_videos: list[str] = DEFAULT_TEST_VIDEOS 131 | ): 132 | test_hasher_parallelization(hasher, test_videos) 133 | 134 | 135 | def test_image_hasher_integrity( 136 | hasher: hashers.ImageHasher, 137 | pil_opencv_threshold: float, 138 | transform_threshold: float, 139 | test_images: list[str] = DEFAULT_TEST_IMAGES, 140 | opencv_hasher: bool = False, 141 | ): 142 | """Test to ensure a hasher works correctly. 143 | 144 | Args: 145 | hasher: The hasher to test. 146 | test_images: A list of filepaths to images to use for testing. 147 | pil_opencv_threshold: The hash distance permitted for an image 148 | when loaded with OpenCV vs. PIL. 149 | transform_threshold: The permitted error in isometric transform 150 | hashes. 151 | opencv_hasher: Whether the hasher is an OpenCV hasher. Used to 152 | determine whether to check for consistent distances. 153 | """ 154 | assert len(test_images) >= 2, "You must provide at least two test images." 155 | image1 = test_images[0] 156 | image2 = test_images[1] 157 | hash1_1 = str(hasher.compute(image1)) # str() games for mypy, not proud 158 | hash1_2 = str(hasher.compute(Image.open(image1))) 159 | hash1_3 = str(hasher.compute(cv2.cvtColor(cv2.imread(image1), cv2.COLOR_BGR2RGB))) 160 | 161 | hash2_1 = str(hasher.compute(image2)) 162 | 163 | # There is a small distance because PIL and OpenCV read 164 | # JPEG images a little differently (e.g., libjpeg-turbo vs. libjpeg) 165 | assert hasher.compute_distance(hash1_1, hash1_2) < pil_opencv_threshold 166 | assert hasher.compute_distance(hash1_1, hash2_1) > pil_opencv_threshold 167 | assert hasher.compute_distance(hash1_1, hash1_3) == 0 168 | 169 | # Ensure the conversion to and from vectors works for both base64 and hex. 
170 | assert hasher.vector_to_string(hasher.string_to_vector(hash2_1)) == hash2_1 171 | assert ( 172 | hasher.vector_to_string( 173 | hasher.string_to_vector( 174 | str( 175 | hasher.vector_to_string( 176 | hasher.string_to_vector(hash2_1), hash_format="hex" 177 | ) 178 | ), 179 | hash_format="hex", 180 | ) 181 | ) 182 | == hash2_1 183 | ) 184 | 185 | # Ensure parallelization works properly. 186 | test_hasher_parallelization(hasher=hasher, test_filepaths=test_images) 187 | 188 | # Ensure the isometric hashes computation work properly 189 | for image in test_images: 190 | transforms = hashers.tools.get_isometric_transforms(image) 191 | hashes_exp = { 192 | key: str(hasher.compute(value)) for key, value in transforms.items() 193 | } 194 | hashes_act = hasher.compute_isometric(transforms["r0"]) 195 | for transform_name in hashes_exp.keys(): 196 | assert ( 197 | hasher.compute_distance( 198 | hashes_exp[transform_name], hashes_act[transform_name] 199 | ) 200 | < transform_threshold 201 | ) 202 | 203 | # Verify that hashes are the correct length. 204 | hash_bits = hasher.hash_length * SIZES[hasher.dtype] 205 | 206 | words_base64 = math.ceil(hash_bits / 6) # Base64 uses 8 bits for every 6 bits 207 | words_base64 += ( 208 | 0 if words_base64 % 4 == 0 else 4 - (words_base64 % 4) 209 | ) # Base64 always uses multiples of four 210 | assert len(hash2_1) == words_base64 211 | 212 | words_hex = 2 * math.ceil(hash_bits / 8) # Hex uses 16 bits for every 8 bits 213 | words_hex += ( 214 | 0 if words_hex % 2 == 0 else 1 215 | ) # Two characters for every one character. 216 | assert ( 217 | len( 218 | str( 219 | hasher.vector_to_string( 220 | hasher.string_to_vector(hash2_1), hash_format="hex" 221 | ) 222 | ) 223 | ) 224 | == words_hex 225 | ) 226 | 227 | # Verify that low quality images yield zero quality 228 | image = np.zeros((100, 100, 3)).astype("uint8") # type: ignore 229 | _, quality = hasher.compute_with_quality(image) 230 | assert quality == 0 231 | 232 | # Verify that high quality images yield high quality 233 | # scores. 
234 | assert ( 235 | min(hasher.compute_with_quality(filepath)[1] for filepath in test_images) == 100 236 | ) 237 | 238 | # Verify that medium quality images yield medium quality 239 | _, quality = hasher.compute_with_quality(LOW_DETAIL_IMAGE) 240 | assert 0 < quality < 100 241 | 242 | if opencv_hasher: 243 | test_opencv_hasher(hasher, image1, image2) 244 | -------------------------------------------------------------------------------- /tests/test_benchmarking.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import shutil 4 | import tempfile 5 | 6 | import numpy as np 7 | import pytest 8 | from imgaug import augmenters as iaa 9 | from scipy import spatial 10 | 11 | from perception import benchmarking, hashers, testing 12 | from perception.benchmarking import video_transforms 13 | from perception.benchmarking.image import BenchmarkImageDataset 14 | from perception.benchmarking.video import BenchmarkVideoDataset 15 | 16 | files = testing.DEFAULT_TEST_IMAGES 17 | dataset = BenchmarkImageDataset.from_tuples([(fn, i % 2) for i, fn in enumerate(files)]) 18 | 19 | 20 | def test_deduplicate(): 21 | tempdir = tempfile.TemporaryDirectory() 22 | new_file = os.path.join(tempdir.name, "dup_file.jpg") 23 | shutil.copy(files[0], new_file) 24 | duplicated_files = files + [new_file] 25 | deduplicated, duplicates = BenchmarkImageDataset.from_tuples( 26 | [(fn, i % 2) for i, fn in enumerate(duplicated_files)] 27 | ).deduplicate(hasher=hashers.AverageHash(), threshold=1e-2) 28 | assert len(duplicates) == 1 29 | assert len(deduplicated._df) == len(files) 30 | 31 | 32 | def test_bad_dataset(): 33 | bad_files = files + ["tests/images/nonexistent.jpg"] 34 | bad_dataset = BenchmarkImageDataset.from_tuples( 35 | [(fn, i % 2) for i, fn in enumerate(bad_files)] 36 | ) 37 | transforms = { 38 | "blur0.05": iaa.GaussianBlur(0.05), 39 | "noop": iaa.Resize(size=(256, 256)), 40 | } 41 | with pytest.raises(Exception): 42 | transformed = bad_dataset.transform( 43 | transforms=transforms, storage_dir="/tmp/transforms", errors="raise" 44 | ) 45 | with pytest.warns(UserWarning, match="occurred reading"): 46 | transformed = bad_dataset.transform( 47 | transforms=transforms, storage_dir="/tmp/transforms", errors="warn" 48 | ) 49 | assert len(transformed._df) == len(files) * 2 50 | 51 | 52 | def test_benchmark_dataset(): 53 | assert len(dataset._df) == len(files) 54 | assert len(dataset.filter(category=[0])._df) == len(files) / 2 55 | with pytest.warns(UserWarning, match="Did not find"): 56 | assert len(dataset.filter(category=[3])._df) == 0 57 | 58 | dataset.save("/tmp/dataset.zip") 59 | dataset.save("/tmp/dataset_folder") 60 | o1 = BenchmarkImageDataset.load("/tmp/dataset.zip") 61 | o2 = BenchmarkImageDataset.load("/tmp/dataset_folder") 62 | o3 = BenchmarkImageDataset.load("/tmp/dataset.zip") 63 | 64 | for opened in [o1, o2, o3]: 65 | assert ( 66 | opened._df["filepath"].apply(os.path.basename) 67 | == dataset._df["filepath"].apply(os.path.basename) 68 | ).all() 69 | 70 | 71 | def test_benchmark_transforms(): 72 | transformed = dataset.transform( 73 | transforms={ 74 | "blur0.05": iaa.GaussianBlur(0.05), 75 | "noop": iaa.Resize(size=(256, 256)), 76 | }, 77 | storage_dir="/tmp/transforms", 78 | ) 79 | 80 | assert len(transformed._df) == len(files) * 2 81 | 82 | hashes = transformed.compute_hashes(hashers={"pdna": hashers.PHash()}) 83 | tr = hashes.compute_threshold_recall().reset_index() 84 | 85 | hashes._metrics = None 86 | hashes._df.at[0, "hash"] = None 
87 | with pytest.warns(UserWarning, match="invalid / empty hashes"): 88 | hashes.compute_threshold_recall() 89 | 90 | assert (tr[tr["transform_name"] == "noop"]["recall"] == 100.0).all() 91 | 92 | # This is a charting function but we execute it just to make sure 93 | # it runs without error. 94 | hashes.show_histograms() 95 | 96 | 97 | def convert_hash_string_to_vector(hash_string): 98 | buff = base64.decodebytes(hash_string.encode("utf-8")) 99 | return np.frombuffer(buff, dtype=np.uint8) 100 | 101 | 102 | def test_video_benchmark_dataset(): 103 | video_dataset = BenchmarkVideoDataset.from_tuples( 104 | files=[ 105 | ("perception/testing/videos/v1.m4v", "category1"), 106 | ("perception/testing/videos/v2.m4v", "category1"), 107 | ("perception/testing/videos/v1.m4v", "category2"), 108 | ("perception/testing/videos/v2.m4v", "category2"), 109 | ] 110 | ) 111 | transforms = { 112 | "noop": video_transforms.get_simple_transform(width=128, sar="1/1"), 113 | "gif": video_transforms.get_simple_transform(codec="gif", output_ext=".gif"), 114 | "clip1s": video_transforms.get_simple_transform(clip_s=(1, None)), 115 | "blackpad": video_transforms.get_black_frame_padding_transform(duration_s=1), 116 | "slideshow": video_transforms.get_slideshow_transform( 117 | frame_input_rate=1, frame_output_rate=1 118 | ), 119 | } 120 | transformed = video_dataset.transform( 121 | storage_dir=tempfile.TemporaryDirectory().name, transforms=transforms 122 | ) 123 | assert len(transformed._df) == len(transforms) * len(video_dataset._df) 124 | assert transformed._df["filepath"].isnull().sum() == 0 125 | 126 | # We will compute hashes for each of the transformed 127 | # videos and check the results for correctness. 128 | phash_framewise_hasher = hashers.FramewiseHasher( 129 | frame_hasher=hashers.PHash(), interframe_threshold=-1, frames_per_second=2 130 | ) 131 | hashes = transformed.compute_hashes( 132 | hashers={"phashframewise": phash_framewise_hasher} 133 | ) 134 | 135 | guid = hashes._df.guid.iloc[0] 136 | df = hashes._df[hashes._df["guid"] == guid] 137 | clip1s = df[(df.transform_name == "clip1s")] 138 | noop = df[(df.transform_name == "noop")] 139 | blackpad = df[(df.transform_name == "blackpad")] 140 | slideshow = df[(df.transform_name == "slideshow")] 141 | 142 | # We should have dropped two hashes from the beginning 143 | # on the clipped video. 144 | assert len(clip1s) == len(noop) - 2 145 | 146 | # The first hash from the clipped video should be the 147 | # same as the third hash from the original 148 | np.testing.assert_allclose( 149 | convert_hash_string_to_vector(clip1s.hash.iloc[0]), 150 | convert_hash_string_to_vector(noop.hash.iloc[2]), 151 | rtol=0.2, 152 | ) 153 | 154 | # The black padding adds four hashes (two on either side). 155 | assert len(blackpad) == len(noop) + 4 156 | 157 | # A black frame should yield all zeros for PHash 158 | assert phash_framewise_hasher.string_to_vector(blackpad.iloc[0].hash).sum() == 0 159 | 160 | # The slideshow hashes should be the same as the noop 161 | # hashes for every other hash. 162 | # Note: this is a weird test structure now because the original test, which was 163 | # assert (noop.hash.values[::2] == slideshow.hash.values[::2]).all() 164 | # kept failing because of 1 bit difference in 1 hash. This is keeps the same 165 | # spirit, but is more complex with a little leniency. We suspect the difference is 166 | # due to some versioning. So might be worthwhile to try replacing the test with the 167 | # original one occasionally. 
168 | noop_hash_vectors = [ 169 | convert_hash_string_to_vector(h) for h in noop.hash.values[::2] 170 | ] 171 | slideshow_hash_vectors = [ 172 | convert_hash_string_to_vector(h) for h in slideshow.hash.values[::2] 173 | ] 174 | total_missed_bits = 0 175 | for noop_vector, slideshow_vector in zip(noop_hash_vectors, slideshow_hash_vectors): 176 | for n in range(0, len(noop_vector)): 177 | if noop_vector[n] != slideshow_vector[n]: 178 | total_missed_bits += 1 179 | assert total_missed_bits <= 2 180 | 181 | # Every second hash in the slideshow should be the same as the 182 | # previous one. 183 | for n in range(0, 10, 2): 184 | assert slideshow.hash.values[n] == slideshow.hash.values[n + 1] 185 | 186 | 187 | def test_euclidean_extension(): 188 | 189 | # This function plainly inplements the process of computing 190 | # the closest positive and negative examples and their indexes. 191 | def compute_euclidean_metrics_py(X_noop, X_transformed, mask): 192 | distance_matrix = spatial.distance.cdist( 193 | XA=X_transformed, XB=X_noop, metric="euclidean" 194 | ) 195 | pos = np.ma.masked_array(distance_matrix, np.logical_not(mask)) 196 | neg = np.ma.masked_array(distance_matrix, mask) 197 | distances = np.concatenate( 198 | [neg.min(axis=1).data[np.newaxis], pos.min(axis=1).data[np.newaxis]], axis=0 199 | ).T 200 | indexes = np.concatenate( 201 | [neg.argmin(axis=1)[np.newaxis], pos.argmin(axis=1)[np.newaxis]] 202 | ).T 203 | return distances, indexes 204 | 205 | X_noop = np.random.uniform(low=0, high=255, size=(5, 144)).astype("int32") 206 | X_trans = np.random.uniform(low=0, high=255, size=(10, 144)).astype("int32") 207 | mask = np.array([True, False] * 5 * 5).reshape(10, 5) 208 | 209 | distances, indexes = benchmarking.common.extensions.compute_euclidean_metrics( 210 | X_noop, X_trans, mask 211 | ) 212 | distances_py, indexes_py = compute_euclidean_metrics_py(X_noop, X_trans, mask) 213 | 214 | assert (indexes_py == indexes).all() 215 | np.testing.assert_allclose(distances, distances_py) 216 | -------------------------------------------------------------------------------- /tests/test_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from perception import hashers, testing, tools 9 | 10 | 11 | def test_deduplicate(): 12 | directory = tempfile.TemporaryDirectory() 13 | original = testing.DEFAULT_TEST_IMAGES[0] 14 | duplicate = os.path.join(directory.name, "image1.jpg") 15 | shutil.copy(original, duplicate) 16 | pairs = tools.deduplicate( 17 | files=[ 18 | testing.DEFAULT_TEST_IMAGES[0], 19 | testing.DEFAULT_TEST_IMAGES[1], 20 | duplicate, 21 | ], 22 | hashers=[(hashers.PHash(hash_size=16), 0.25)], 23 | ) 24 | assert len(pairs) == 1 25 | file1, file2 = pairs[0] 26 | assert ((file1 == duplicate) and (file2 == original)) or ( 27 | (file1 == original) and (file2 == duplicate) 28 | ) 29 | 30 | 31 | def test_deduplicate_u8(): 32 | # This test verifies that extensions.compute_euclidean_pairwise_duplicates 33 | # works properly. 
34 | directory = tempfile.TemporaryDirectory() 35 | original = testing.DEFAULT_TEST_IMAGES[0] 36 | duplicate = os.path.join(directory.name, "image1.jpg") 37 | shutil.copy(original, duplicate) 38 | pairs = tools.deduplicate( 39 | files=[ 40 | testing.DEFAULT_TEST_IMAGES[0], 41 | testing.DEFAULT_TEST_IMAGES[1], 42 | duplicate, 43 | ], 44 | hashers=[(hashers.PHashU8(hash_size=16), 10)], 45 | ) 46 | assert len(pairs) == 1 47 | file1, file2 = pairs[0] 48 | assert ((file1 == duplicate) and (file2 == original)) or ( 49 | (file1 == original) and (file2 == duplicate) 50 | ) 51 | 52 | 53 | def test_deduplicate_hashes_multiple(): 54 | # This test verifies that deduplicate_hashes functions properly 55 | # when there is more than one hash for a file. 56 | directory = tempfile.TemporaryDirectory() 57 | original = testing.DEFAULT_TEST_IMAGES[0] 58 | duplicate = os.path.join(directory.name, "image1.jpg") 59 | hasher = hashers.PHashU8(hash_size=16) 60 | shutil.copy(original, duplicate) 61 | hashes = [ 62 | (0, hasher.compute(original)), 63 | (1, hasher.compute(duplicate)), 64 | (1, hasher.compute(duplicate)), 65 | (1, hasher.compute(duplicate)), 66 | (2, hasher.compute(testing.DEFAULT_TEST_IMAGES[1])), 67 | ] 68 | pairs = tools.deduplicate_hashes( 69 | hashes=hashes, # type: ignore[arg-type] 70 | threshold=10, 71 | hash_format="base64", 72 | hash_length=hasher.hash_length, 73 | distance_metric="euclidean", 74 | hash_dtype="uint8", 75 | ) 76 | assert len(pairs) == 1 77 | file1, file2 = pairs[0] 78 | assert ((file1 == 0) and (file2 == 1)) or ((file1 == 1) and (file2 == 0)) 79 | 80 | 81 | def test_compute_euclidean_pairwise_duplicates(): 82 | # The purpose of this test is to verify that the handling of 83 | # deduplication with files that have multiple hashes works 84 | # properly. This is particularly important for video where 85 | # we are likely to have many hashes. 86 | X = np.array( 87 | [ 88 | # File 1 89 | [0, 0, 0], 90 | [1, 1, 1], 91 | [2, 2, 2], 92 | # File 2 93 | [1, 1, 1], 94 | [2, 2, 2], 95 | [3, 3, 3], 96 | # File 3 97 | [3, 3, 3], 98 | [4, 4, 4], 99 | # File 4 100 | [5, 5, 5], 101 | [6, 6, 6], 102 | ] 103 | ) 104 | 105 | # Use grouped files. 106 | counts = np.array([3, 3, 2, 2]) 107 | expected = np.array( 108 | [[2 / 3, 2 / 3], [0, 0], [0, 0], [1 / 3, 1 / 2], [0, 0], [0, 0]] 109 | ) 110 | actual = tools.extensions.compute_euclidean_pairwise_duplicates( 111 | X=X.astype("int32"), 112 | threshold=1, 113 | counts=counts.astype("uint32"), 114 | compute_overlap=True, 115 | ) 116 | assert (expected == actual).all() 117 | 118 | # Use without computing overlap. 119 | expected = np.array([[2, 2], [0, 0], [0, 0], [1, 1], [0, 0], [0, 0]]) 120 | actual = tools.extensions.compute_euclidean_pairwise_duplicates( 121 | X=X.astype("int32"), 122 | threshold=1, 123 | counts=counts.astype("uint32"), 124 | compute_overlap=False, 125 | ) 126 | assert (expected == actual).all() 127 | 128 | # Use ungrouped files. 
129 | X = np.array( 130 | [ 131 | # File 1 132 | [0, 0, 0], 133 | [1, 1, 1], 134 | [2, 2, 2], 135 | [1, 1, 1], 136 | ] 137 | ) 138 | expected = np.array([[0, 0], [0, 0], [0, 0], [0, 0], [1, 1], [0, 0]]) 139 | actual = tools.extensions.compute_euclidean_pairwise_duplicates( 140 | X=X.astype("int32"), threshold=1, compute_overlap=True 141 | ) 142 | assert (expected == actual).all() 143 | 144 | 145 | def test_api_is_over_https(): 146 | matcher_https = tools.SaferMatcher(api_key="foo", url="https://www.example.com/") 147 | assert matcher_https 148 | 149 | if "SAFER_MATCHING_SERVICE_DEV_ALLOW_HTTP" in os.environ: 150 | del os.environ["SAFER_MATCHING_SERVICE_DEV_ALLOW_HTTP"] 151 | with pytest.raises(ValueError): 152 | tools.SaferMatcher(api_key="foo", url="http://www.example.com/") 153 | 154 | os.environ["SAFER_MATCHING_SERVICE_DEV_ALLOW_HTTP"] = "1" 155 | matcher_http_with_escape_hatch = tools.SaferMatcher( 156 | api_key="foo", url="http://www.example.com/" 157 | ) 158 | assert matcher_http_with_escape_hatch 159 | 160 | 161 | def test_unletterbox(): 162 | image = hashers.tools.read(testing.DEFAULT_TEST_IMAGES[0]) 163 | padded = np.zeros((image.shape[0] + 100, image.shape[1] + 50, 3), dtype="uint8") 164 | padded[50 : 50 + image.shape[0], 25 : 25 + image.shape[1]] = image 165 | result = hashers.tools.unletterbox(padded) 166 | assert result is not None 167 | (x1, x2), (y1, y2) = result 168 | assert y1 == 50 169 | assert y2 == 50 + image.shape[0] 170 | assert x1 == 25 171 | assert x2 == 25 + image.shape[1] 172 | 173 | 174 | def test_unletterbox_crop(): 175 | image = hashers.tools.read(testing.DEFAULT_TEST_IMAGES[0]) 176 | padded = np.zeros((image.shape[0] + 100, image.shape[1] + 50, 3), dtype="uint8") 177 | padded[50 : 50 + image.shape[0], 25 : 25 + image.shape[1]] = image 178 | cropped_image = hashers.tools.unletterbox_crop(padded) 179 | assert cropped_image is not None 180 | assert image.shape[0] == cropped_image.shape[0] 181 | assert image.shape[1] == cropped_image.shape[1] 182 | 183 | 184 | def test_unletterbox_crop_meaningful_pixels(): 185 | """Test the value of .5 min_fraction_meaningful_pixels in unletterbox().""" 186 | image = hashers.tools.read(testing.DEFAULT_TEST_IMAGES[0]) 187 | h, w, _ = image.shape 188 | 189 | # make tall skinny images with lots of padding around the content 190 | # so its below min_fraction_meaningful_pixels threshold 191 | padding_size = int(5 * h) 192 | 193 | padded = np.r_[ 194 | np.zeros((padding_size, w, 3)), image, np.zeros((padding_size, w, 3)) 195 | ] 196 | assert None is hashers.tools.unletterbox_crop( 197 | padded, min_fraction_meaningful_pixels=0.5 198 | ) 199 | 200 | 201 | def test_unletterbox_color(): 202 | image = hashers.tools.read(testing.DEFAULT_TEST_IMAGES[0]) 203 | padded = np.zeros((image.shape[0] + 100, image.shape[1] + 50, 3), dtype="uint8") 204 | padded[:, :] = (200, 0, 200) 205 | padded[50 : 50 + image.shape[0], 25 : 25 + image.shape[1]] = image 206 | # Should not unletterbox since not black. 
207 | results = hashers.tools.unletterbox(padded, only_remove_black=True) 208 | assert results is not None 209 | (x1, x2), (y1, y2) = results 210 | assert y1 == 0 211 | assert y2 == padded.shape[0] 212 | assert x1 == 0 213 | assert x2 == padded.shape[1] 214 | 215 | # Should unletterbox color: 216 | results = hashers.tools.unletterbox(padded, only_remove_black=False) 217 | assert results is not None 218 | (x1, x2), (y1, y2) = results 219 | assert y1 == 50 220 | assert y2 == 50 + image.shape[0] 221 | assert x1 == 25 222 | assert x2 == 25 + image.shape[1] 223 | 224 | 225 | def test_unletterbox_aspect_ratio(): 226 | """Test the value of .1 in unletterbox().""" 227 | image = hashers.tools.read(testing.DEFAULT_TEST_IMAGES[0]) 228 | h, w, z = image.shape 229 | 230 | # make tall skinny images with non-trivial content just below and 231 | # above 10% threshold 232 | base = int(4.5 * h) # 2 * base + h = 100% 233 | h_fail, h_pass = base + 10, base - 10 234 | 235 | padded = np.r_[np.zeros((h_fail, w, 3)), image, np.zeros((h_fail, w, 3))] 236 | assert None is hashers.tools.unletterbox(padded) 237 | 238 | padded = np.r_[np.zeros((h_pass, w, 3)), image, np.zeros((h_pass, w, 3))] 239 | 240 | results = hashers.tools.unletterbox(padded) 241 | assert results is not None 242 | (x1, x2), (y1, y2) = results 243 | 244 | assert y1 == h_pass 245 | assert y2 == h_pass + image.shape[0] 246 | assert x1 == 0 247 | assert x2 == image.shape[1] 248 | 249 | 250 | def test_unletterbox_noblackbars(): 251 | image = hashers.tools.read(testing.DEFAULT_TEST_IMAGES[0]) 252 | 253 | results = hashers.tools.unletterbox(image) 254 | assert results is not None 255 | (x1, x2), (y1, y2) = results 256 | assert x1 == 0 257 | assert y1 == 0 258 | assert x2 == image.shape[1] 259 | assert y2 == image.shape[0] 260 | 261 | 262 | def test_ffmpeg_video(): 263 | """Check that the FFMPEG video parsing code provides substantially similar 264 | results to the OpenCV approach (which also uses FFMPEG under the hood but 265 | also has different frame selection logic).""" 266 | frames_per_second = 2.3 267 | for filepath in testing.DEFAULT_TEST_VIDEOS: 268 | filename = os.path.basename(filepath) 269 | for (frame1, index1, timestamp1), (frame2, index2, timestamp2) in zip( 270 | hashers.tools.read_video_to_generator_ffmpeg( 271 | filepath, frames_per_second=frames_per_second 272 | ), 273 | hashers.tools.read_video_to_generator( 274 | filepath, frames_per_second=frames_per_second 275 | ), 276 | ): 277 | diff = np.abs(frame1.astype("int32") - frame2.astype("int32")).flatten() 278 | assert index1 == index2, f"Index mismatch for {filename}" 279 | np.testing.assert_allclose( 280 | timestamp1, timestamp2 281 | ), f"Timestamp mismatch for {filename}" 282 | assert np.percentile(diff, 75) < 25, f"Frame mismatch for {filename}" 283 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | https://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2019 Thorn 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | https://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 
-------------------------------------------------------------------------------- /tests/test_local_descriptor_deduplication.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import cv2 5 | import imgaug 6 | import pandas as pd 7 | import pytest 8 | 9 | 10 | import perception.benchmarking.image as pb 11 | import perception.benchmarking.image_transforms as pbit 12 | import perception.approximate_deduplication as ad 13 | import perception.local_descriptor_deduplication as ldd 14 | import perception.hashers.tools as pht 15 | import perception.testing as pt 16 | from perception.approximate_deduplication.debug import vizualize_pair 17 | 18 | # Params for object level matching. 19 | OBJECT_MATCH_PARAMS = { 20 | "strong_match_threshold": 0.3, # Ideally something close to 95% precision. 21 | "ratio": 0.5, 22 | "coarse_pct_probe": 0.1, 23 | "minimum_coarse_overlap": 0.001, 24 | "coarse_threshold": 100.0, 25 | "minimum_validation_match": 0.04, 26 | "minimum_validation_intersection": 0.04, 27 | "minimum_validation_inliers": 6, 28 | } 29 | 30 | 31 | @pytest.mark.parametrize("hasher", [ldd.SIFT(), ldd.AKAZE()]) 32 | def test_deduplication(hasher): 33 | tdir = tempfile.TemporaryDirectory() 34 | watermark = cv2.cvtColor( 35 | cv2.imread(pt.DEFAULT_TEST_LOGOS[0], cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA 36 | ) 37 | transformed = pb.BenchmarkImageDataset.from_tuples( 38 | files=[(filepath, "test") for filepath in pt.DEFAULT_TEST_IMAGES] 39 | ).transform( 40 | transforms={ 41 | "noop": lambda image: image, 42 | "pad": imgaug.augmenters.Pad(percent=0.1), 43 | "crop": imgaug.augmenters.Crop(percent=0.1), 44 | "watermark": pbit.apply_watermark(watermark, alpha=1, size=0.8), 45 | }, 46 | storage_dir=tdir.name, 47 | ) 48 | df = transformed._df.set_index("filepath") 49 | pairs = ldd.deduplicate( 50 | filepaths_or_reference_df=df.index, max_workers=2, hasher=hasher 51 | ) # Test throws errors if unset. 
52 | 53 | clustered = ( 54 | pd.DataFrame( 55 | ad.pairs_to_clusters(ids=df.index, pairs=pairs, strictness="component") 56 | ) 57 | .set_index("id") 58 | .merge(df, left_index=True, right_index=True) 59 | .reset_index() 60 | ) 61 | print("test2") 62 | n_clusters = clustered["cluster"].nunique() 63 | n_transforms = clustered["transform_name"].nunique() 64 | perfect = ( 65 | clustered.groupby("cluster") 66 | .apply( 67 | lambda g: g["guid"].nunique() == 1 68 | and g["transform_name"].nunique() == n_transforms 69 | ) 70 | .sum() 71 | ) 72 | 73 | tainted = clustered.groupby("cluster")["guid"].nunique().gt(1).sum() 74 | pct_perfect = perfect / n_clusters 75 | pct_tainted = tainted / n_clusters 76 | assert pct_tainted == 0 77 | assert pct_perfect > 0.1 78 | 79 | 80 | @pytest.mark.parametrize("hasher", [ldd.SIFT(), ldd.AKAZE()]) 81 | def test_deduplication_across_sets(hasher): 82 | tdir = tempfile.TemporaryDirectory() 83 | watermark = cv2.cvtColor( 84 | cv2.imread(pt.DEFAULT_TEST_LOGOS[0], cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA 85 | ) 86 | transformed = pb.BenchmarkImageDataset.from_tuples( 87 | files=[(filepath, "test") for filepath in pt.DEFAULT_TEST_IMAGES] 88 | ).transform( 89 | transforms={ 90 | "noop": lambda image: image, 91 | "pad": imgaug.augmenters.Pad(percent=0.1), 92 | "crop": imgaug.augmenters.Crop(percent=0.1), 93 | "watermark": pbit.apply_watermark(watermark, alpha=1, size=0.8), 94 | }, 95 | storage_dir=tdir.name, 96 | ) 97 | 98 | df = transformed._df.set_index("filepath") 99 | query_images = list(df[df.transform_name == "noop"].index.values) 100 | images_to_match_to = list(df[~(df.transform_name == "noop")].index.values) 101 | 102 | pairs = ldd.deduplicate( 103 | filepaths_or_reference_df=images_to_match_to, 104 | query_filepaths_or_df=query_images, 105 | max_workers=2, 106 | hasher=hasher, 107 | ) # Test throws errors if unset. 108 | 109 | assert len(pairs) >= 20, "Wrong # of pairs." 110 | only_one_noop = [p for p in pairs if (("noop" in p[0]) != ("noop" in p[1]))] 111 | assert len(only_one_noop) == len( 112 | pairs 113 | ), "All pairs must be between a noop and non-noop file" 114 | 115 | 116 | @pytest.mark.parametrize("hasher", [ldd.SIFT(), ldd.AKAZE()]) 117 | def test_validation_for_overlapping_case(hasher): 118 | tdir = tempfile.TemporaryDirectory() 119 | # Each image will have the center of the other 120 | # pasted in the top left corner. 121 | image1 = pht.read(pt.DEFAULT_TEST_IMAGES[0]) 122 | image2 = pht.read(pt.DEFAULT_TEST_IMAGES[1]) 123 | image1[:100, :100] = image2[100:200, 100:200] 124 | image2[:100, :100] = image1[100:200, 100:200] 125 | fp1 = os.path.join(tdir.name, "test1.jpg") 126 | fp2 = os.path.join(tdir.name, "test2.jpg") 127 | cv2.imwrite(fp1, image1[..., ::-1]) 128 | cv2.imwrite(fp2, image2[..., ::-1]) 129 | descriptor1 = ldd.generate_image_descriptors(fp1, hasher) 130 | descriptor2 = ldd.generate_image_descriptors(fp2, hasher) 131 | assert descriptor1 is not None 132 | assert descriptor2 is not None 133 | 134 | # These images should not match. 
135 | assert not hasher.validate_match(descriptor1=descriptor1, descriptor2=descriptor2)[ 136 | 0 137 | ] 138 | 139 | 140 | @pytest.mark.parametrize("hasher", [ldd.SIFT(), ldd.AKAZE()]) 141 | def test_handling_bad_file_case(caplog, hasher): 142 | tdir = tempfile.TemporaryDirectory() 143 | missing_file = os.path.join(tdir.name, "missing-file") 144 | bad_file_handle = tempfile.NamedTemporaryFile() 145 | bad_file = bad_file_handle.name 146 | transformed = pb.BenchmarkImageDataset.from_tuples( 147 | files=[(filepath, "test") for filepath in pt.DEFAULT_TEST_IMAGES] 148 | ).transform( 149 | transforms={ 150 | "noop": lambda image: image, 151 | }, 152 | storage_dir=tdir.name, 153 | ) 154 | df = transformed._df.set_index("filepath") 155 | df.loc[missing_file] = df.iloc[0] 156 | df.loc[bad_file] = df.iloc[0] 157 | pairs = ldd.deduplicate(filepaths_or_reference_df=df.index, hasher=hasher) 158 | clustered = ( 159 | pd.DataFrame( 160 | ad.pairs_to_clusters(ids=df.index, pairs=pairs, strictness="component") 161 | ) 162 | .set_index("id") 163 | .merge(df, left_index=True, right_index=True) 164 | .reset_index() 165 | ) 166 | 167 | assert bad_file not in clustered.index 168 | assert missing_file not in clustered.index 169 | 170 | bad_file_error = next( 171 | record for record in caplog.records if bad_file in record.message 172 | ) 173 | assert bad_file_error 174 | assert bad_file_error.levelname == "ERROR" 175 | 176 | missing_file_warning = next( 177 | record for record in caplog.records if missing_file in record.message 178 | ) 179 | assert missing_file_warning 180 | assert missing_file_warning.levelname == "WARNING" 181 | 182 | 183 | def test_handling_hasher_mismatch(): 184 | tdir = tempfile.TemporaryDirectory() 185 | transformed = pb.BenchmarkImageDataset.from_tuples( 186 | files=[(filepath, "test") for filepath in pt.DEFAULT_TEST_IMAGES] 187 | ).transform( 188 | transforms={ 189 | "noop": lambda image: image, 190 | }, 191 | storage_dir=tdir.name, 192 | ) 193 | df = transformed._df.set_index("filepath") 194 | reference_df = ldd.build_reference_df(filepaths=df.index, hasher=ldd.SIFT()) 195 | query_df = ldd.build_reference_df(filepaths=df.index, hasher=ldd.AKAZE()) 196 | with pytest.raises(AssertionError): 197 | ldd.deduplicate(reference_df, query_df) 198 | 199 | 200 | def test_viz_pair(): 201 | object_sift = ldd.SIFT( 202 | max_features=256, 203 | ratio=OBJECT_MATCH_PARAMS["ratio"], 204 | threshold=OBJECT_MATCH_PARAMS["coarse_threshold"], 205 | overlap=OBJECT_MATCH_PARAMS["minimum_coarse_overlap"], 206 | validation_match=OBJECT_MATCH_PARAMS["minimum_validation_match"], 207 | validation_inliers=OBJECT_MATCH_PARAMS["minimum_validation_inliers"], 208 | validation_intersection=OBJECT_MATCH_PARAMS["minimum_validation_intersection"], 209 | ) 210 | filepaths = [ 211 | "tests/images/chair.png", 212 | "tests/images/chair3.png", 213 | "tests/images/chair-square.png", 214 | "tests/images/chair-tall.png", 215 | ] 216 | reference_df = ldd.build_reference_df( 217 | filepaths=filepaths, 218 | hasher=object_sift, 219 | min_features=10, 220 | max_size=1000, 221 | show_progress=False, 222 | ) 223 | pairs = ldd.deduplicate( 224 | filepaths_or_reference_df=reference_df, 225 | hasher=object_sift, 226 | max_size=1000, 227 | min_features=10, 228 | verbose=True, 229 | ) 230 | row = pairs[0] 231 | viz_img = vizualize_pair( 232 | reference_df.loc[row[0]], 233 | reference_df.loc[row[1]], 234 | 0.5, 235 | match_metadata=row[2], 236 | sanitized=False, 237 | ) 238 | viz_img = cv2.cvtColor(viz_img, cv2.COLOR_RGB2BGR) 239 | 
cv2.imwrite("tests/images/debug-image.png", viz_img) 240 | 241 | 242 | def test_viz_pair_symmetry(): 243 | # This test catches a regression where if the smaller image was the query one LDD would swap 244 | # points during distance calculation, but not unswap points before returning them. 245 | object_sift = ldd.SIFT( 246 | max_features=256, 247 | ratio=OBJECT_MATCH_PARAMS["ratio"], 248 | threshold=OBJECT_MATCH_PARAMS["coarse_threshold"], 249 | overlap=OBJECT_MATCH_PARAMS["minimum_coarse_overlap"], 250 | validation_match=OBJECT_MATCH_PARAMS["minimum_validation_match"], 251 | validation_inliers=OBJECT_MATCH_PARAMS["minimum_validation_inliers"], 252 | validation_intersection=OBJECT_MATCH_PARAMS["minimum_validation_intersection"], 253 | ) 254 | filepaths = [ 255 | "tests/images/chair.png", 256 | "tests/images/chair3.png", 257 | ] 258 | reference_df = ldd.build_reference_df( 259 | filepaths=filepaths, 260 | hasher=object_sift, 261 | min_features=10, 262 | max_size=1000, 263 | show_progress=False, 264 | ) 265 | pairs = ldd.deduplicate( 266 | filepaths_or_reference_df=filepaths[:1], 267 | query_filepaths_or_df=filepaths[1:], 268 | hasher=object_sift, 269 | max_size=1000, 270 | min_features=10, 271 | verbose=True, 272 | ) 273 | row = pairs[0] 274 | viz_img = vizualize_pair( 275 | reference_df.loc[row[0]], 276 | reference_df.loc[row[1]], 277 | 0.5, 278 | match_metadata=row[2], 279 | sanitized=False, 280 | ) 281 | viz_img = cv2.cvtColor(viz_img, cv2.COLOR_RGB2BGR) 282 | cv2.imwrite("tests/images/debug-image-symmetry-1.png", viz_img) 283 | 284 | # Swap order of ref and query files. 285 | pairs = ldd.deduplicate( 286 | filepaths_or_reference_df=filepaths[1:], 287 | query_filepaths_or_df=filepaths[:1], 288 | hasher=object_sift, 289 | max_size=1000, 290 | min_features=10, 291 | verbose=True, 292 | ) 293 | row = pairs[0] 294 | viz_img = vizualize_pair( 295 | reference_df.loc[row[0]], 296 | reference_df.loc[row[1]], 297 | 0.5, 298 | match_metadata=row[2], 299 | sanitized=False, 300 | ) 301 | viz_img = cv2.cvtColor(viz_img, cv2.COLOR_RGB2BGR) 302 | cv2.imwrite("tests/images/debug-image-symmetry-2.png", viz_img) 303 | -------------------------------------------------------------------------------- /perception/approximate_deduplication/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os.path as op 4 | import typing 5 | 6 | import faiss 7 | import networkit as nk 8 | import numpy as np 9 | import tqdm 10 | import typing_extensions 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | DEFAULT_PCT_PROBE = 0 14 | 15 | 16 | # For faiss training on datasets larger than 50,000 vectors, we take a random sub-sample. 17 | TRAIN_LARGE_SIZE: int = 50_000 18 | 19 | 20 | class ClusterAssignment(typing_extensions.TypedDict): 21 | cluster: int 22 | id: typing.Any 23 | 24 | 25 | def build_index( 26 | X: np.ndarray, 27 | pct_probe: float = DEFAULT_PCT_PROBE, 28 | approximate: bool = True, 29 | use_gpu: bool = True, 30 | ): 31 | """Buid a FAISS index from a reference dataframe. 32 | 33 | Args: 34 | X: The vectors to add to the index. 35 | pct_probe: The minimum fraction of nearest lists to search. If 36 | the product of pct_probe and the number of lists is less 37 | than 1, one list will be searched. 38 | approximate: Whether to build an approximate or exact index. 39 | 40 | Returns: 41 | An (index, lookup) tuple where the lookup returns the filepath 42 | for a given entry in the index. 
43 | """ 44 | if X is None: 45 | return None 46 | X = X.astype("float32") 47 | d = X.shape[1] 48 | if approximate: 49 | ntotal = X.shape[0] 50 | nlist = int(max(min(4 * np.sqrt(ntotal), ntotal / 39), 1)) 51 | quantizer = faiss.IndexFlatL2(d) 52 | index = faiss.IndexIVFFlat(quantizer, d, nlist) 53 | gpu = False 54 | if use_gpu: 55 | try: 56 | res = faiss.StandardGpuResources() 57 | index = faiss.index_cpu_to_gpu(res, 0, index) 58 | gpu = True 59 | except AttributeError: 60 | LOGGER.info("Building approximate FAISS index on CPU.") 61 | 62 | if X.shape[0] > TRAIN_LARGE_SIZE: 63 | # Take random sample of 50,000 or 39 points per centroid. 64 | # 39 points per centroid is the min for for not getting warnings. 65 | # https://github.com/facebookresearch/faiss/wiki/FAQ#can-i-ignore-warning-clustering-xxx-points-to-yyy-centroids 66 | sample_size = max(39 * nlist, TRAIN_LARGE_SIZE) 67 | index.train(X[np.random.choice(X.shape[0], sample_size, replace=False)]) 68 | else: 69 | index.train(X) 70 | 71 | batch_size = 10_000 72 | for i in range(0, X.shape[0], batch_size): 73 | index.add(X[i : i + batch_size]) 74 | if gpu: 75 | index = faiss.index_gpu_to_cpu(index) 76 | nprobe = max(math.ceil(pct_probe * nlist), 1) 77 | faiss.ParameterSpace().set_index_parameter(index, "nprobe", nprobe) 78 | else: 79 | index = faiss.IndexFlat(d) 80 | index.add(X) 81 | return index 82 | 83 | 84 | def compute_euclidean_pairwise_duplicates_approx( 85 | X, 86 | counts, 87 | threshold, 88 | minimum_overlap, 89 | Y=None, 90 | y_counts=None, 91 | pct_probe=0.1, 92 | use_gpu: bool = True, 93 | faiss_cache_path: str | None = None, 94 | show_progress: bool = False, 95 | ): 96 | """Provides the same result as perception.extensions.compute_pairwise_duplicates_simple 97 | but uses an approximate search instead of an exhaustive search, which can dramatically reduce 98 | processing time. 99 | 100 | Args: 101 | X: An array of vectors to compute pairs for. 102 | Y: if provided we search in X for Y vectors. 103 | counts: A list of counts of vectors for separate files in the 104 | in the vectors (should add up to the length of X) 105 | threshold: The threshold for a match as a euclidean distance. 106 | minimum_overlap: The minimum overlap between two files to qualify as a match. 107 | pct_probe: The minimum percentage of sublists to search for matches. The larger the 108 | value, the more exhaustive the search. 109 | faiss_cache_path: If provided load any existing faiss index from this path, and if 110 | it does not exist then save the generated faiss index to the path. 111 | show_progress: Whether or not to show a progress bar while computing pairs 112 | Returns: 113 | A list of pairs of matching file indexes. 114 | """ 115 | assert ( 116 | counts.sum() == X.shape[0] 117 | ), "Length of counts incompatible with vectors shape." 118 | assert (Y is None) == ( 119 | y_counts is None 120 | ), "Must provide both or neither for y, y_counts." 121 | if X.dtype != "float32": 122 | # Only make the copy if we have to. 123 | X = X.astype("float32") 124 | 125 | if Y is not None and Y.dtype != "float32": 126 | # Only make the copy if we have to. 
127 | Y = Y.astype("float32") 128 | 129 | lookup_ = [] 130 | for idx, count in enumerate(counts): 131 | lookup_.extend([idx] * count) 132 | lookup = np.array(lookup_) 133 | 134 | if faiss_cache_path is not None and op.exists(faiss_cache_path): 135 | LOGGER.debug("Loading cached FAISS index from %s", faiss_cache_path) 136 | index = faiss.read_index(faiss_cache_path) 137 | assert ( 138 | X.shape[0] == index.ntotal 139 | ), "Cached FAISS index does not match provided X." 140 | else: 141 | LOGGER.debug("Building FAISS index.") 142 | index = build_index(X=X, pct_probe=pct_probe, approximate=True, use_gpu=use_gpu) 143 | if faiss_cache_path is not None: 144 | faiss.write_index(index, faiss_cache_path) 145 | 146 | LOGGER.debug("FAISS index ready, start aprox search") 147 | pairs = [] 148 | 149 | # Only use y_counts if present. 150 | if y_counts is None: 151 | iterator_counts = counts 152 | M = X 153 | else: 154 | iterator_counts = y_counts 155 | M = Y 156 | 157 | for end, length, query in tqdm.tqdm( 158 | zip(iterator_counts.cumsum(), iterator_counts, range(len(iterator_counts))), 159 | total=len(iterator_counts), 160 | disable=not show_progress, 161 | desc="Vectors", 162 | ): 163 | if length == 0: 164 | continue 165 | Xq = M[end - length : end] 166 | lims, _, idxs = index.range_search(Xq, threshold**2) 167 | lims = lims.astype("int32") 168 | matched = [ 169 | match 170 | for match in np.unique(lookup[list(set(idxs))]) # type: ignore 171 | if match != query 172 | or Y is not None # Protect self matches if Y is not present. 173 | ] 174 | query_in_match: typing.Mapping[int, set] = {m: set() for m in matched} 175 | match_in_query: typing.Mapping[int, set] = {m: set() for m in matched} 176 | for query_idx in range(length): 177 | for match_idx in idxs[lims[query_idx] : lims[query_idx + 1]]: 178 | match = lookup[match_idx] 179 | if ( 180 | match == query and Y is None 181 | ): # Protect self matches if Y is not present. 182 | continue 183 | match_in_query[match].add(match_idx) 184 | query_in_match[match].add(query_idx) 185 | for match in matched: 186 | overlap = min( 187 | [ 188 | len(query_in_match[match]) / length, 189 | len(match_in_query[match]) / counts[match], 190 | ] 191 | ) 192 | if overlap >= minimum_overlap and overlap > 0: 193 | if Y is None: 194 | pairs.append(tuple(sorted([query, match]))) 195 | else: 196 | pairs.append(tuple([query, match])) 197 | return list(set(pairs)) 198 | 199 | 200 | def pairs_to_clusters( 201 | ids: typing.Iterable[str], 202 | pairs: typing.Iterable[tuple[str, str]], 203 | strictness: typing_extensions.Literal[ 204 | "clique", "community", "component" 205 | ] = "clique", 206 | max_clique_batch_size: int = 1000, 207 | ) -> list[ClusterAssignment]: 208 | """Given a list of pairs of matching files, compute sets 209 | of cliques where all files in a clique are connected. 210 | Args: 211 | ids: A list of node ids (e.g., filepaths). 212 | pairs: A list of pairs of node ids, each pair is assumed to have an edge 213 | strictness: The level at which groups will be clustered. "component" 214 | means that all clusters will be connected components. "community" 215 | will select clusters of files within components that are clustered 216 | together. "clique" will result in clusters where every file is 217 | connected to every other file. 218 | max_clique_batch_size: The maximum batch size for identifying 219 | cliques. 220 | Returns: 221 | A list of cluster assignments (dicts with id and cluster 222 | entries). 
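Example (a minimal sketch with made-up ids and pairs; cluster numbering may vary)::

    ids = ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]
    pairs = [("a.jpg", "b.jpg"), ("b.jpg", "c.jpg")]
    assignments = pairs_to_clusters(ids=ids, pairs=pairs, strictness="component")
    # With "component" strictness, a.jpg, b.jpg, and c.jpg share one cluster
    # and d.jpg gets its own, e.g.:
    # [{"id": "a.jpg", "cluster": 0}, {"id": "b.jpg", "cluster": 0},
    #  {"id": "c.jpg", "cluster": 0}, {"id": "d.jpg", "cluster": 1}]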
223 | """ 224 | assert strictness in ["component", "community", "clique"], "Invalid strictness." 225 | list_ids = list(ids) 226 | id_to_node_map = {v: i for i, v in enumerate(list_ids)} 227 | node_to_id_map = {v: k for k, v in id_to_node_map.items()} 228 | 229 | LOGGER.debug("Building graph.") 230 | graph = nk.Graph(len(list_ids)) 231 | node_pairs = {(id_to_node_map[pair[0]], id_to_node_map[pair[1]]) for pair in pairs} 232 | for node_pair in node_pairs: 233 | graph.addEdge(node_pair[0], node_pair[1]) 234 | 235 | assignments: list[ClusterAssignment] = [] 236 | cluster_index = 0 237 | cc_query = nk.components.ConnectedComponents(graph) 238 | cc_query.run() 239 | components = cc_query.getComponents() 240 | 241 | for component in components: 242 | LOGGER.debug("Got component with size: %s", len(component)) 243 | if strictness == "component": 244 | assignments.extend( 245 | [{"id": node_to_id_map[n], "cluster": cluster_index} for n in component] 246 | ) 247 | cluster_index += 1 248 | continue 249 | # Map between node values for a connected component 250 | component_node_map = dict(enumerate(component)) 251 | cc_sub_graph = nk.graphtools.subgraphFromNodes(graph, component, compact=True) 252 | algo = nk.community.PLP(cc_sub_graph) 253 | algo.run() 254 | communities = algo.getPartition() 255 | community_map = communities.subsetSizeMap() 256 | for community, size in community_map.items(): 257 | LOGGER.debug("Got community with size: %s", size) 258 | community_members = list( 259 | communities.getMembers(community) 260 | ) # Need to do this to do batching. 261 | community_members = [component_node_map[i] for i in community_members] 262 | if strictness == "community": 263 | assignments.extend( 264 | [ 265 | {"id": node_to_id_map[n], "cluster": cluster_index} 266 | for n in community_members 267 | ] 268 | ) 269 | cluster_index += 1 270 | continue 271 | 272 | for start in range(0, len(community_members), max_clique_batch_size): 273 | community_nodes = community_members[ 274 | start : start + max_clique_batch_size 275 | ] 276 | LOGGER.debug("Creating subgraph with %s nodes.", len(community_nodes)) 277 | # Map between node values for a community 278 | community_node_map = dict(enumerate(community_nodes)) 279 | subgraph = nk.graphtools.subgraphFromNodes( 280 | graph, community_nodes, compact=True 281 | ) 282 | 283 | while subgraph.numberOfNodes() > 0: 284 | LOGGER.debug("Subgraph size: %s", subgraph.numberOfNodes()) 285 | clique = nk.clique.MaximalCliques(subgraph, maximumOnly=True) 286 | clique.run() 287 | clique_members = clique.getCliques()[0] 288 | assignments.extend( 289 | [ 290 | { 291 | "id": node_to_id_map[community_node_map[n]], 292 | "cluster": cluster_index, 293 | } 294 | for n in clique_members 295 | ] 296 | ) 297 | cluster_index += 1 298 | for n in clique_members: 299 | subgraph.removeNode(n) 300 | 301 | return assignments 302 | -------------------------------------------------------------------------------- /perception/extensions.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # cython: language=c++ 3 | 4 | import math 5 | import sys 6 | 7 | import cython 8 | import numpy as np 9 | from cython.parallel import parallel, prange 10 | 11 | cimport numpy as np 12 | from libc.stdlib cimport abort, free, malloc 13 | from libcpp cimport bool as cppbool 14 | from libcpp.vector cimport vector 15 | 16 | 17 | cdef extern from "limits.h": 18 | int INT_MAX 19 | 20 | ctypedef np.uint8_t uint8 21 | 22 | @cython.boundscheck(False) 23 | 
@cython.wraparound(False) 24 | def compute_euclidean_pairwise_duplicates(int[:, :] X, float threshold, counts: np.uint32_t[:] = None, compute_overlap=False): 25 | """Find the pairwise overlap within an array of vectors, where there may be multiple 26 | vectors for the same file. This function is faster than using scipy.spatial.distance 27 | because it computes distances in parallel, avoids computing full distances when they're 28 | not necessary, skips computing distances for pairs of hashes that are for the 29 | same file, and skips computing distances for vectors if both have already been matched. 30 | 31 | Args: 32 | X: The vectors with shape (N, D). Vectors for the same file need to be 33 | supplied sequentially so that we can use the counts argument 34 | to determine which vectors are for the same file. 35 | counts: For each file, the number of sequential vectors in X. If not 36 | provided, each vector is assumed to be for a different file (i.e., 37 | this is equivalent to `counts = np.ones(N)`). 38 | compute_overlap: If True, the values returned will be divided by the number 39 | of hashes in each file. If False, the raw duplicate counts will 40 | be returned. 41 | 42 | Returns: 43 | duplicates: An array of shape (M!/(2*((M-2)!)), 2) indicating 44 | the fraction of vectors for each file found in another file. 45 | The indexing matches that of scipy.spatial.pdist. M is the number of files. 46 | So if M = 4, the array will represent comparisons of the file indexes as follows: 47 | [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]. So (assuming compute_overlap=True), 48 | a possible return would be [(1.0, 1.0), (0, 0), (0, 0), (0.66, 1.0), (0, 0), (0.5, 0.25)] 49 | which means that: 50 | 51 | - There was 100% overlap between file 0 and file 1 52 | - 66% of file 1 was in file 2 and 100% of file 2 was in file 1 53 | - 50% of file 2 was in file 3 and 25% of file 3 was in file 2 54 | """ 55 | if counts is None: 56 | counts = np.ones(X.shape[0], dtype=np.uint32) 57 | cdef Py_ssize_t n = X.shape[0] 58 | cdef Py_ssize_t m = counts.shape[0] 59 | cdef Py_ssize_t d = X.shape[1] 60 | n_pairs_python = int(math.factorial(m)/(2*math.factorial(m-2))) 61 | assert n_pairs_python < sys.maxsize, 'Too many files were provided for deduplication.' 62 | cdef Py_ssize_t n_pairs = n_pairs_python 63 | cdef Py_ssize_t max_counts = np.max(counts) 64 | cdef int compute_overlap_int = 0 65 | if compute_overlap: 66 | compute_overlap_int = 1 67 | # i_1 is the index of file1, i_2 is the index of file2, i_d is the 68 | # index of the vector dimension we're on, i_i is used to compute 69 | # the starting index in the flattened vector in the different threads. 70 | # i_1_subhash is the index of the hash on file1, i_2_subhash is 71 | # the index of the hash on file2. 72 | cdef Py_ssize_t i_1, i_2, i_d, i_i, i_1_sub, i_2_sub, i_1_offset 73 | duplicate_arr = np.zeros((n_pairs, 2), dtype=np.double) 74 | cdef double[:, :] duplicate = duplicate_arr 75 | offsets_arr = np.zeros(m, dtype=np.int32) 76 | cdef np.int32_t[:] offsets = offsets_arr 77 | for i_1 in range(m): 78 | for i_i in range(i_1): 79 | offsets[i_1] += counts[i_i] 80 | # local_buf will contain distance, flattened array offset, index_offset_1, index_offset_2 81 | cdef size_t local_buf_size = 4 82 | cdef float threshold2 = threshold ** 2 83 | with nogil, parallel(): 84 | local_buf = malloc(sizeof(np.uint64_t) * local_buf_size) 85 | 86 | # An array of flags indicating whether a vector in file 1 was 87 | # matched.
88 | matched_1 = malloc(sizeof(int) * max_counts) 89 | 90 | # An array of flags indicating whether a vector in file 2 was 91 | # matched. 92 | matched_2 = malloc(sizeof(int) * max_counts) 93 | if local_buf is NULL or matched_1 is NULL or matched_2 is NULL: 94 | abort() 95 | # Iterate over all of the files. 96 | for i_1 in prange(m-1): 97 | local_buf[1] = 0 98 | local_buf[2] = offsets[i_1] 99 | # Compute the index of the output vector 100 | # where we will count the number of duplicates. 101 | for i_i in range(i_1): 102 | local_buf[1] += m - i_i - 1 103 | # Iterate over all the other files to compare. 104 | for i_2 in range(i_1 + 1, m): 105 | local_buf[3] = offsets[i_2] 106 | # Initialize all match flags to zero for 107 | # both file 1 and file 2. 108 | for i_1_sub in range(counts[i_1]): 109 | matched_1[i_1_sub] = 0 110 | for i_2_sub in range(counts[i_2]): 111 | matched_2[i_2_sub] = 0 112 | # Iterate over all the hashes in file1 113 | for i_1_sub in range(counts[i_1]): 114 | # Iterate over all the hashes in file2 115 | for i_2_sub in range(counts[i_2]): 116 | local_buf[0] = 0 117 | if matched_1[i_1_sub] == 1 and matched_2[i_2_sub] == 1: 118 | # Both the vectors in this pair have already been matched, so 119 | # there is nothing to gain from this comparison. 120 | continue 121 | for i_d in range(d): 122 | local_buf[0] += (X[local_buf[2] + i_1_sub, i_d] - X[local_buf[3] + i_2_sub, i_d]) ** 2 123 | if local_buf[0] > threshold2: 124 | # If we're already beyond the distance threshold, 125 | # we don't need to continue computing squared 126 | # distances. 127 | break 128 | if local_buf[0] < threshold2: 129 | # A match was found. Set flags for both vectors 130 | # to 1. 131 | matched_1[i_1_sub] = 1 132 | matched_2[i_2_sub] = 1 133 | # Add up the number of matches for file 1. 134 | for i_1_sub in range(counts[i_1]): 135 | duplicate[local_buf[1], 0] += matched_1[i_1_sub] 136 | # Add up the number of matches for file 2. 137 | for i_2_sub in range(counts[i_2]): 138 | duplicate[local_buf[1], 1] += matched_2[i_2_sub] 139 | # Divide by the total number of vectors for each file. 140 | if compute_overlap_int: 141 | duplicate[local_buf[1], 0] /= counts[i_1] 142 | duplicate[local_buf[1], 1] /= counts[i_2] 143 | # Advance to the next pair index. 144 | local_buf[1] += 1 145 | free(local_buf) 146 | free(matched_1) 147 | free(matched_2) 148 | return duplicate_arr 149 | 150 | 151 | @cython.boundscheck(False) 152 | @cython.wraparound(False) 153 | def compute_euclidean_pairwise_duplicates_simple(int[:, :] X, float threshold, np.uint32_t[:] counts = None, float minimum_overlap = 0): 154 | """Find the pairwise overlap within an array of vectors, where there may be multiple 155 | vectors for the same file. This function is similar to compute_euclidean_pairwise_duplicates 156 | but uses much less memory. 157 | 158 | Args: 159 | X: The vectors with shape (N, D). Vectors for the same file need to be 160 | supplied sequentially so that we can use the counts argument 161 | to determine which vectors are for the same file. 162 | threshold: The maximum distance between to vectors to allow for 163 | a match. 164 | counts: For each of the M files, the number of sequential vectors in X. 165 | If not provided, each vector is assumed to be for a different file (i.e., 166 | this is equivalent to `counts = np.ones(N)` which also implies M == N). 167 | Otherwise, assumed to have length M. The counts should add up to N. 168 | minimum_overlap: The minimum overlap between two groups of hashes to 169 | call it a match. 
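For example (illustrative numbers), with counts = [4, 6] and minimum_overlap = 0.5, files 0 and 1 are only reported as a pair if at least 2 of file 0's vectors and at least 3 of file 1's vectors have a match in the other file.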
170 | 171 | Returns: 172 | pairs: Pairs of indexes that met the matching criteria. 173 | """ 174 | if counts is None: 175 | counts_arr = np.ones(X.shape[0], dtype=np.uint32) 176 | counts = counts_arr 177 | cdef Py_ssize_t n = X.shape[0] 178 | cdef Py_ssize_t m = counts.shape[0] 179 | cdef Py_ssize_t d = X.shape[1] 180 | n_pairs_python = int(math.factorial(m)/(2*math.factorial(m-2))) 181 | assert n_pairs_python < sys.maxsize, 'Too many files were provided for deduplication.' 182 | cdef Py_ssize_t n_pairs = n_pairs_python 183 | cdef Py_ssize_t max_counts = np.max(counts) 184 | # i_1 is the index of file1, i_2 is the index of file2, i_d is the 185 | # index of the vector dimension we're on, i_i is used to compute 186 | # the starting index in the flattened vector in the different threads. 187 | # i_1_subhash is the index of the hash on file1, i_2_subhash is 188 | # the index of the hash on file2. 189 | cdef Py_ssize_t i_1, i_2, i_d, i_i, i_1_sub, i_2_sub 190 | cdef vector[cppbool] duplicate 191 | duplicate.resize(n_pairs) 192 | offsets_arr = np.zeros(m, dtype=np.uint64) 193 | cdef np.uint64_t[:] offsets = offsets_arr 194 | cdef np.int32_t expected_n = 0 195 | for i_1 in range(m): 196 | for i_i in range(i_1): 197 | offsets[i_1] += counts[i_i] 198 | expected_n += counts[i_1] 199 | assert expected_n == n, "Provided value for counts is inconsistent with X." 200 | # local_buf will contain: 201 | # distance, flattened array offset, 202 | # index_offset_1, index_offset_2, early termination flag 203 | cdef size_t local_buf_size = 5 204 | cdef float threshold2 = threshold ** 2 205 | with nogil, parallel(): 206 | local_buf = malloc(sizeof(np.uint64_t) * local_buf_size) 207 | 208 | # An array of flags indicating whether a vector in file 1 was 209 | # matched. 210 | matched_1 = malloc(sizeof(int) * max_counts) 211 | 212 | # An array of flags indicating whether a vector in file 2 was 213 | # matched. 214 | matched_2 = malloc(sizeof(int) * max_counts) 215 | 216 | # Pair overlap and minimum required overlap 217 | overlap = malloc(sizeof(float) * 4) 218 | 219 | if local_buf is NULL or matched_1 is NULL or matched_2 is NULL or overlap is NULL: 220 | abort() 221 | # Iterate over all of the files. 222 | for i_1 in prange(m-1): 223 | local_buf[1] = 0 224 | local_buf[2] = offsets[i_1] 225 | # Compute the index of the output vector 226 | # where we will count the number of duplicates. 227 | for i_i in range(i_1): 228 | local_buf[1] += m - i_i - 1 229 | # Iterate over all the other files to compare. 230 | for i_2 in range(i_1 + 1, m): 231 | # Set the current and minimum overlaps 232 | overlap[0] = 0 233 | overlap[1] = 0 234 | overlap[2] = minimum_overlap * counts[i_1] 235 | overlap[3] = minimum_overlap * counts[i_2] 236 | local_buf[3] = offsets[i_2] 237 | 238 | # Set early termination flag. 239 | local_buf[4] = 0 240 | 241 | # Initialize all match flags to zero for 242 | # both file 1 and file 2.
243 | for i_1_sub in range(counts[i_1]): 244 | matched_1[i_1_sub] = 0 245 | for i_2_sub in range(counts[i_2]): 246 | matched_2[i_2_sub] = 0 247 | # Iterate over all the hashes in file1 248 | for i_1_sub in range(counts[i_1]): 249 | # Stop early if there's no way to get enough 250 | # matches from i1 to i2 251 | if overlap[0] + counts[i_1] - i_1_sub < overlap[2]: 252 | break 253 | # Stop early if we've already reached the minimum overlap 254 | if overlap[0] >= overlap[2] and overlap[1] >= overlap[3] and overlap[0] > 0 and overlap[1] > 0: 255 | break 256 | 257 | # Iterate over all the hashes in file2 258 | for i_2_sub in range(counts[i_2]): 259 | local_buf[0] = 0 260 | if matched_1[i_1_sub] == 1 and matched_2[i_2_sub] == 1: 261 | # Both the vectors in this pair have already been matched, so 262 | # there is nothing to gain from this comparison. 263 | continue 264 | for i_d in range(d): 265 | local_buf[0] += (X[local_buf[2] + i_1_sub, i_d] - X[local_buf[3] + i_2_sub, i_d]) ** 2 266 | if local_buf[0] > threshold2: 267 | # If we're already beyond the distance threshold, 268 | # we don't need to continue computing squared 269 | # distances. 270 | break 271 | if local_buf[0] < threshold2: 272 | # A match was found. Set flags for both vectors 273 | # to 1 and increment the overlap. 274 | if matched_1[i_1_sub] != 1: 275 | overlap[0] += 1 276 | if matched_2[i_2_sub] != 1: 277 | overlap[1] += 1 278 | matched_1[i_1_sub] = 1 279 | matched_2[i_2_sub] = 1 280 | if overlap[0] >= overlap[2] and overlap[1] >= overlap[3] and overlap[0] > 0 and overlap[1] > 0: 281 | duplicate[local_buf[1]] = 1 282 | local_buf[1] += 1 283 | free(matched_1) 284 | free(matched_2) 285 | free(overlap) 286 | free(local_buf) 287 | cdef int n_duplicates = 0 288 | cdef Py_ssize_t i_offset = 0 289 | for i_offset in range(n_pairs): 290 | if duplicate[i_offset] > 0: 291 | n_duplicates += 1 292 | pairs_arr = np.zeros((n_duplicates, 2), dtype=np.int32) 293 | cdef np.int32_t[:, :] pairs = pairs_arr 294 | i_offset = 0 295 | cdef Py_ssize_t pair_offset = 0 296 | for i_1 in range(m-1): 297 | # Compute the index of the output vector 298 | # where we will count the number of duplicates. 299 | for i_2 in range(i_1 + 1, m): 300 | if duplicate[i_offset] > 0: 301 | pairs[pair_offset][0] = i_1 302 | pairs[pair_offset][1] = i_2 303 | pair_offset += 1 304 | i_offset += 1 305 | return pairs_arr 306 | -------------------------------------------------------------------------------- /perception/hashers/hasher.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import typing 3 | import warnings 4 | from abc import ABC, abstractmethod 5 | from logging import warning 6 | 7 | import numpy as np 8 | import scipy.spatial 9 | import tqdm 10 | 11 | from perception.hashers import tools 12 | 13 | 14 | class Hasher(ABC): 15 | """All hashers implement a common set of methods from 16 | the Hasher base class. 17 | """ 18 | 19 | #: The metric to use when computing distance between two hashes. All hashers 20 | #: must supply this parameter. 21 | distance_metric: str 22 | 23 | #: The numpy type to use when converting from string to array form. 24 | #: All hashers must supply this parameter. 
25 | dtype: str 26 | 27 | #: Indicates the length of the hash vector 28 | hash_length: int 29 | 30 | #: Whether or not this hash returns multiple values 31 | returns_multiple: bool = False 32 | 33 | #: Indicates whether the hashes can be computed in parallel 34 | allow_parallel: bool = True 35 | 36 | def string_to_vector(self, hash_string: str, hash_format: str = "base64"): 37 | """Convert hash string to vector. 38 | 39 | Args: 40 | hash_string: The input hash string 41 | hash_format: One of 'base64' or 'hex' 42 | """ 43 | return tools.string_to_vector( 44 | hash_string, 45 | dtype=self.dtype, 46 | hash_length=self.hash_length, 47 | hash_format=hash_format, 48 | ) 49 | 50 | def vector_to_string( 51 | self, vector: np.ndarray, hash_format: str = "base64" 52 | ) -> str | None: 53 | """Convert vector to hash string. 54 | 55 | Args: 56 | vector: Input vector 57 | hash_format: One of 'base64' or 'hex' 58 | """ 59 | return tools.vector_to_string(vector, dtype=self.dtype, hash_format=hash_format) 60 | 61 | def compute_distance( 62 | self, 63 | hash1: np.ndarray | str, 64 | hash2: np.ndarray | str, 65 | hash_format="base64", 66 | ): 67 | """Compute the distance between two hashes. 68 | 69 | Args: 70 | hash1: The first hash or vector 71 | hash2: The second hash or vector 72 | hash_format: If either or both of the hashes are hash strings, 73 | what format the string is encoded in. 74 | """ 75 | hash1 = ( 76 | self.string_to_vector(hash1, hash_format=hash_format) 77 | if isinstance(hash1, str) 78 | else hash1 79 | ) # makes mypy happy 80 | hash2 = ( 81 | self.string_to_vector(hash2, hash_format=hash_format) 82 | if isinstance(hash2, str) 83 | else hash2 84 | ) 85 | 86 | if self.distance_metric == "sqeuclidean": 87 | return scipy.spatial.distance.sqeuclidean( 88 | hash1.astype("float32"), hash2.astype("float32") 89 | ) 90 | if self.distance_metric == "euclidean": 91 | return scipy.spatial.distance.euclidean( 92 | hash1.astype("float32"), hash2.astype("float32") 93 | ) 94 | if self.distance_metric == "hamming": 95 | return scipy.spatial.distance.hamming(hash1, hash2) 96 | if self.distance_metric == "cosine": 97 | return scipy.spatial.distance.cosine( 98 | hash1.astype("float32"), hash2.astype("float32") 99 | ) 100 | if self.distance_metric == "custom": 101 | return self._compute_distance(hash1, hash2) 102 | raise NotImplementedError( 103 | f"Distance metric: {self.distance_metric} not supported." 104 | ) 105 | 106 | def _compute_distance(self, vector1, vector2): 107 | raise ValueError("Called a custom distance function but it is not implemented.") 108 | 109 | @typing.no_type_check 110 | def compute_parallel( 111 | self, 112 | filepaths: list[str], 113 | progress: tqdm.tqdm | None = None, 114 | progress_desc: str | None = None, 115 | max_workers: int = 5, 116 | isometric: bool = False, 117 | ): 118 | """Compute hashes in a parallelized fashion. 119 | 120 | Args: 121 | filepaths: A list of paths to images or videos (depending on the hasher). 122 | progress: A tqdm-like wrapper for reporting progress. If None, 123 | progress is not reported. 124 | progress_desc: The title of the progress bar. 125 | max_workers: The maximum number of workers 126 | isometric: Whether to compute all eight isometric transforms for 127 | each image. 128 | """ 129 | if not self.allow_parallel and max_workers != 1: 130 | warnings.warn( 131 | message="This hash cannot be used in parallel. 
Setting max_workers to 1.", 132 | category=UserWarning, 133 | ) 134 | max_workers = 1 135 | assert all( 136 | isinstance(p, str) for p in filepaths 137 | ), "All images should be provided as paths." 138 | 139 | if isinstance(self, VideoHasher) and isometric: 140 | raise ValueError("Computing isometric hashes for videos is not supported.") 141 | 142 | # We can use a with statement to ensure threads are cleaned up promptly 143 | records = [] 144 | if isinstance(self, VideoHasher): 145 | executor_class = concurrent.futures.ProcessPoolExecutor 146 | else: 147 | executor_class = concurrent.futures.ThreadPoolExecutor 148 | with executor_class(max_workers=max_workers) as executor: 149 | # Start the load operations and mark each future with its filepath 150 | compute: typing.Callable = ( 151 | self.compute_isometric if isometric else self.compute 152 | ) 153 | future_to_path: dict = { 154 | executor.submit(compute, path): path for path in filepaths 155 | } 156 | generator = concurrent.futures.as_completed(future_to_path) 157 | if progress is not None: 158 | generator = progress( 159 | generator, total=len(filepaths), desc=progress_desc 160 | ) 161 | for future in generator: 162 | path = future_to_path[future] 163 | try: 164 | hash_value = future.result() 165 | except Exception as exc: 166 | records.append({"filepath": path, "hash": None, "error": str(exc)}) 167 | else: 168 | records.append( 169 | {"filepath": path, "hash": hash_value, "error": None} 170 | ) 171 | return records 172 | 173 | 174 | class ImageHasher(Hasher): 175 | @abstractmethod 176 | def _compute(self, image: np.ndarray) -> np.ndarray: 177 | """Compute hash from an image. 178 | 179 | Args: 180 | image: A numpy array representing an image as 181 | of shape (H, W, 3) where channels are ordered 182 | as RGB or a filepath to an image. 183 | """ 184 | 185 | def compute_isometric_from_hash(self, hash_string_or_vector, hash_format="base64"): 186 | """For supported hashes, obtain the hashes for the dihedral transformations 187 | of the original image. 
They are provided in the following order: 188 | 189 | - Vertical flip 190 | - Horizontal flip 191 | - 180 degree rotation 192 | - 90 degree rotation 193 | - 90 degree rotation and vertical flip 194 | - 90 degree rotation and horizontal flip 195 | - 270 degree rotation 196 | 197 | Args: 198 | hash_string_or_vector: The hash string or vector 199 | hash_format: One 'base64' or 'hex' 200 | """ 201 | if not hasattr(self, "_compute_isometric_from_hash"): 202 | raise NotImplementedError("This hasher does not support hash rotation.") 203 | rotations = self._compute_isometric_from_hash( # type: ignore 204 | hash_string_or_vector 205 | if isinstance(hash_string_or_vector, np.ndarray) 206 | else self.string_to_vector(hash_string_or_vector, hash_format=hash_format) 207 | ) 208 | return { 209 | transform_name: self.vector_to_string(vector, hash_format=hash_format) 210 | for transform_name, vector in rotations.items() 211 | } 212 | 213 | def compute_isometric(self, image: tools.ImageInputType): 214 | image = tools.to_image_array(image) 215 | if hasattr(self, "_compute_isometric"): 216 | hashes = self._compute_isometric(image) # type: ignore 217 | elif hasattr(self, "_compute_isometric_from_hash"): 218 | hashes = self._compute_isometric_from_hash( # type: ignore 219 | self._compute(image) 220 | ) 221 | else: 222 | transforms = tools.get_isometric_transforms(image) 223 | for name, transform in transforms.items(): 224 | transforms[name] = self._compute(transform) 225 | hashes = transforms 226 | return { 227 | transform_name: self.vector_to_string(vector) 228 | for transform_name, vector in hashes.items() 229 | } 230 | 231 | def compute( 232 | self, image: tools.ImageInputType, hash_format="base64" 233 | ) -> np.ndarray | str | None | list[str | None]: 234 | """Compute a hash from an image. 235 | 236 | Args: 237 | image: An image represented as a filepath, a PIL image object, 238 | or as an np.ndarray object. If it is an np.ndarray object, 239 | it must be in RGB color order (note the OpenCV default is 240 | BGR). 241 | hash_format: One 'base64', 'hex', or 'vector' 242 | """ 243 | vector = self._compute(tools.to_image_array(image)) 244 | if hash_format == "vector": 245 | # Take care of this separately because we took out `vector` 246 | # as valid return type to vector_to_string(). 247 | # The .tolist() might seem unnecessary for the 248 | # ndarray `vector` but downstream expects a list and it 249 | # stays consistent with original, so keeping for now. 250 | # return (vector.tolist() if self.returns_multiple 251 | # else vector) 252 | return vector # should iterate the same as vector.tolist() 253 | if self.returns_multiple: 254 | return [self.vector_to_string(v, hash_format=hash_format) for v in vector] 255 | return self.vector_to_string(vector, hash_format=hash_format) 256 | 257 | def compute_with_quality( 258 | self, image: tools.ImageInputType, hash_format="base64" 259 | ) -> tuple[ 260 | (np.ndarray | str | None | list[str | None]), 261 | int, 262 | ]: 263 | """Compute hash and hash quality from image. 264 | 265 | Args: 266 | image: An image represented as a filepath, a PIL image object, 267 | or as an np.ndarray object. If it is an np.ndarray object, 268 | it must be in RGB color order (note the OpenCV default is 269 | BGR). 
270 | hash_format: One 'base64', 'hex', or 'vector' 271 | 272 | Returns: 273 | A tuple of (hash, quality) 274 | """ 275 | vector, quality = self._compute_with_quality(tools.to_image_array(image)) 276 | if hash_format == "vector": 277 | return vector, quality 278 | if self.returns_multiple: 279 | return ( 280 | [self.vector_to_string(v, hash_format=hash_format) for v in vector], 281 | quality, 282 | ) 283 | return (self.vector_to_string(vector, hash_format=hash_format), quality) 284 | 285 | def _compute_with_quality(self, image: np.ndarray) -> tuple[np.ndarray, int]: 286 | return self._compute(image), tools.compute_quality(image) 287 | 288 | 289 | class VideoHasher(Hasher): 290 | 291 | #: The frame rate at which videos are read 292 | frames_per_second: float = 1 293 | 294 | @abstractmethod 295 | def process_frame( 296 | self, 297 | frame: np.ndarray, 298 | frame_index: int | None, 299 | frame_timestamp: float | None, 300 | state: dict | None = None, 301 | ) -> dict: 302 | """Called for each frame in the video. For all 303 | but the first frame, a state is provided recording the state from 304 | the previous frame. 305 | 306 | Args: 307 | frame: The current frame as an RGB ndarray 308 | frame_index: The current frame index 309 | frame_timestamp: The current frame timestamp 310 | state: The state from the last call to process_frame 311 | """ 312 | 313 | @abstractmethod 314 | def hash_from_final_state(self, state: dict) -> np.ndarray: 315 | """Called after all frames have been processed. Returns the final 316 | feature vector. 317 | 318 | Args: 319 | state: The state dictionary at the end of processing. 320 | """ 321 | 322 | def compute( 323 | self, 324 | filepath, 325 | errors="raise", 326 | hash_format="base64", 327 | scenes=None, 328 | **kwargs, 329 | ): 330 | """Compute a hash for a video at a given filepath. All 331 | other arguments are passed to perception.hashers.tools.read_video. 332 | 333 | Args: 334 | filepath: Path to video file 335 | errors: One of "raise", "ignore", or "warn". Passed 336 | to perception.hashers.tools.read_video. 337 | hash_format: One of "vector", "base64", or "hex" 338 | max_duration: The maximum length of the video to hash. 339 | max_size: The maximum size of frames to queue 340 | scenes: An array used to pass scene info back to wrapper 341 | functions 342 | """ 343 | frame_timestamp, state = None, None 344 | # Iterate through the video, aggregating scene info in the state 345 | # dict 346 | for frame, frame_index, frame_timestamp in tools.read_video( 347 | filepath=filepath, 348 | frames_per_second=self.frames_per_second, 349 | errors=errors, 350 | **kwargs, 351 | ): 352 | state = self.process_frame( 353 | frame=frame, 354 | frame_index=frame_index, 355 | frame_timestamp=frame_timestamp, 356 | state=state, 357 | ) 358 | 359 | if state is None: 360 | if errors == "raise": 361 | raise ValueError( 362 | f"Video processing failed for {filepath}, State is None." 363 | ) 364 | if errors == "warn": 365 | warning(f"Video processing failed for {filepath}, State is None.") 366 | 367 | return None 368 | 369 | # Persist the final timestamp in the state to allow us to pass along 370 | # duration 371 | state["end"] = frame_timestamp 372 | vectors = self.hash_from_final_state(state=state) 373 | if scenes is not None: 374 | scenes += state.get("scenes", []) 375 | if hash_format == "vector": 376 | # Take care of this separately because we took out `vector` 377 | # as valid return type to vector_to_string(). 
378 | # The .tolist() might seem unnecessary for the 379 | # ndarray `vector` but downstream expects a list and it 380 | # stays consistent with original, so keeping for now. 381 | # return (vector.tolist() if self.returns_multiple 382 | # else vector) 383 | return vectors # should iterate the same as vector.tolist() 384 | if self.returns_multiple: 385 | return [self.vector_to_string(v, hash_format=hash_format) for v in vectors] 386 | return self.vector_to_string(vectors, hash_format=hash_format) 387 | --------------------------------------------------------------------------------
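A minimal usage sketch for the hasher interface above (the image paths are placeholders; PHash is used here as a representative image hasher from perception.hashers):

    from perception import hashers

    hasher = hashers.PHash()
    hash1 = hasher.compute("image1.jpg")  # base64-encoded hash string by default
    hash2 = hasher.compute("image2.jpg")
    distance = hasher.compute_distance(hash1, hash2)

    # compute_parallel returns one record per file with "filepath", "hash", and "error" keys.
    records = hasher.compute_parallel(["image1.jpg", "image2.jpg"], max_workers=2)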