├── MANIFEST.in ├── requirements.txt ├── test ├── imgs │ ├── no_face.png │ ├── joe_biden.jpeg │ ├── barak_obama.jpeg │ ├── joe_biden_2.jpeg │ └── narendra_modi.jpeg └── test.py ├── facedb ├── __init__.py ├── query.py ├── db_tools.py ├── db_models.py └── db.py ├── .github └── workflows │ └── test.yml ├── LICENSE ├── setup.py ├── docs └── code_documentation.md ├── .gitignore ├── CODE_OF_CONDUCT.md └── README.md /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-include *.py,*.md,*.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chromadb 2 | pillow 3 | opencv-python 4 | face-recognition 5 | -------------------------------------------------------------------------------- /test/imgs/no_face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shhossain/FaceDB/HEAD/test/imgs/no_face.png -------------------------------------------------------------------------------- /facedb/__init__.py: -------------------------------------------------------------------------------- 1 | from facedb.db import FaceDB 2 | from facedb.db_models import FaceResult, FaceResults -------------------------------------------------------------------------------- /test/imgs/joe_biden.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shhossain/FaceDB/HEAD/test/imgs/joe_biden.jpeg -------------------------------------------------------------------------------- /test/imgs/barak_obama.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shhossain/FaceDB/HEAD/test/imgs/barak_obama.jpeg -------------------------------------------------------------------------------- /test/imgs/joe_biden_2.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shhossain/FaceDB/HEAD/test/imgs/joe_biden_2.jpeg -------------------------------------------------------------------------------- /test/imgs/narendra_modi.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shhossain/FaceDB/HEAD/test/imgs/narendra_modi.jpeg -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test: 13 | strategy: 14 | matrix: 15 | python-version: [3.8, 3.11] 16 | os: [ubuntu-latest, windows-latest, macos-latest] 17 | 18 | runs-on: ${{ matrix.os }} 19 | 20 | steps: 21 | - name: Checkout repository 22 | uses: actions/checkout@v2 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install -r requirements.txt # If you have any requirements 33 | 34 | - name: Run tests 35 | run: python test/test.py ${{ secrets.TEST_KEY }} 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 sifat 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | 
furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import sys 3 | 4 | version = "0.0.13" 5 | description = "A python package for face database management" 6 | 7 | with open("README.md", encoding="utf-8") as f: 8 | long_description = f.read() 9 | 10 | name = "FaceDB" 11 | author = "sifat (shhossain)" 12 | 13 | with open("requirements.txt") as f: 14 | required = f.read().splitlines() 15 | 16 | if sys.version_info < (3, 8): 17 | required.append("typing_extensions") 18 | 19 | keywords = ["python", "face", "recognition"] 20 | 21 | classifiers = [ 22 | "Development Status :: 3 - Alpha", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Education", 25 | "Intended Audience :: Science/Research", 26 | "License :: OSI Approved :: MIT License", 27 | "Programming Language :: Python :: 3", 28 | "Programming Language :: Python :: 3.6", 29 | "Programming Language :: Python :: 3.7", 30 | "Programming Language :: Python :: 3.8", 31 | "Programming Language :: Python :: 3.9", 32 | "Programming Language :: Python :: 3.10", 33 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 34 | "Topic :: 
Software Development :: Libraries :: Python Modules", 35 | "Topic :: Text Processing :: Linguistic", 36 | "Topic :: Utilities", 37 | "Operating System :: OS Independent", 38 | ] 39 | 40 | projects_links = { 41 | "Documentation": "https://github.com/shhossain/facedb", 42 | "Source": "https://github.com/shhossain/facedb", 43 | "Bug Tracker": "https://github.com/shhossain/facedb/issues", 44 | } 45 | 46 | 47 | setup( 48 | name=name, 49 | version=version, 50 | description=description, 51 | long_description=long_description, 52 | long_description_content_type="text/markdown", 53 | author=author, 54 | url="https://github.com/shhossain/facedb", 55 | project_urls=projects_links, 56 | packages=find_packages(), 57 | install_requires=required, 58 | keywords=keywords, 59 | classifiers=classifiers, 60 | python_requires=">=3.7", 61 | include_package_data=True, 62 | zip_safe=False, 63 | ) 64 | -------------------------------------------------------------------------------- /docs/code_documentation.md: -------------------------------------------------------------------------------- 1 | # FaceDB object initialization 2 | 3 | ### FaceResult 4 | 5 | #### Parameters 6 | ``` 7 | id 8 | name 9 | distance 10 | embedding 11 | img 12 | ``` 13 | #### Methods 14 | 15 | ##### show_img() 16 | Convenient way to see the image that is in your FaceResults object. Open a matplotlib window 17 | 18 | ### FaceResults 19 | 20 | #### Parameters 21 | Note that the following parameters are only going to be accessile if you just have and FaceResult in your object FaceResults 22 | ``` 23 | id 24 | name 25 | distance 26 | embedding 27 | img 28 | ``` 29 | #### Methods 30 | 31 | ##### show_img() 32 | Convenient way to see the images that is in your FaceResults object. Open a matplotlib window 33 | 34 | # Main features of FaceDB 35 | 36 | ### FaceDB.add(name:str,img=None,embedding=None,id=None,check_similar=True,save_just_face=False,**kwargs: Additional metadata for the face.) 
37 | Adds a new face to your FaceDB database. Extra keyword arguments are stored as metadata for the face. 38 | Example: 39 | ``` 40 | db.add("Nelson Mandela", img="mandela.jpg", profession="Politician", country="South Africa") 41 | db.add("Barack Obama", img="obama.jpg", profession="Politician", country="USA") 42 | db.add("Einstein", img="einstein.jpg", profession="Scientist", country="Germany") 43 | ``` 44 | 45 | ### FaceDB.add_many(embeddings=None,imgs=None,metadata=None,ids=None,names=None,check_similar=True) 46 | Adds several new faces to your FaceDB database in a single call. 47 | Example: 48 | ``` 49 | files = glob("faces/*.jpg") # Suppose you have a folder of images named after the people in them 50 | imgs = [] 51 | names = [] 52 | for file in files: 53 | imgs.append(file) 54 | names.append(Path(file).stem) 55 | 56 | ids, failed_indexes = db.add_many( 57 | imgs=imgs, 58 | names=names, 59 | ) 60 | ``` 61 | 62 | ### FaceDB.recognize(img=None, embedding=None, include=None, threshold=None, top_k=1) -> False|None|FaceResults 63 | Tries to identify the person in the picture. 64 | Example: 65 | ``` 66 | result = db.recognize(img="your_image.jpg") 67 | ``` 68 | ### FaceDB.all(include=None) -> FaceResults 69 | Retrieves information about all faces in the database. 70 | Example: 71 | ``` 72 | results = db.all(include='name') 73 | # Or with a list 74 | results = db.all(include=['name', 'img']) 75 | ``` 76 | ### FaceDB.all().df -> pd.DataFrame 77 | A convenient way to get your results as a pandas DataFrame. 78 | Example: 79 | ``` 80 | df = db.all().df 81 | ``` 82 | 83 | ### FaceDB.search(embedding,include) -> FaceResults 84 | Searches for faces similar to the one you provide. 85 | Example: 86 | ``` 87 | results = db.search(img="your_image.jpg") 88 | ``` 89 | ### FaceDB.get_all(include) -> FaceResults 90 | Gets all the faces in the database, with the fields selected by `include`.
91 | Example: 92 | ``` 93 | results = db.get_all(include=['name', 'img']) 94 | ``` 95 | ### FaceDB.update(id, name=None, embedding=None, img=None, only_face=False) 96 | Updates an existing entry in your database. 97 | Example: 98 | ``` 99 | db.update(id=face_id, name="John Doe", img="john_doe.jpg") 100 | ``` 101 | ### FaceDB.delete(id) 102 | Deletes one entry from the database. 103 | Example: 104 | ``` 105 | db.delete(face_id) 106 | ``` 107 | ### FaceDB.count() -> int 108 | Counts the number of faces in the database. 109 | Example: 110 | ``` 111 | count = db.count() 112 | ``` 113 | ### FaceDB.query(embeddings,top_k=1,include: Optional[List[Literal["embeddings", "metadatas"]]] = None,where=None) -> FaceResults 114 | Queries the database by metadata and returns the matching faces. Filters can be passed as keyword arguments or as a `where` dict. 115 | Example: 116 | ``` 117 | results = db.query(name="Nelson Mandela") 118 | ``` 119 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | *data*/ 3 | *.ipynb 4 | 5 | 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and 
enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | hossain0338@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. 
Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 
123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FaceDB - A Face Recognition Database 2 | 3 | FaceDB is a Python library that provides an easy-to-use interface for face recognition and face database management. It allows you to perform face recognition tasks, such as face matching and face searching, and manage a database of faces efficiently. FaceDB supports two popular face recognition frameworks: DeepFace and face_recognition. 4 | 5 | ## Links 6 | [PyPI](https://pypi.org/project/facedb/) 7 | [GitHub](https://github.com/shhossain/facedb) 8 | 9 | ## Installation 10 | 11 | FaceDB can be installed using pip: 12 | 13 | ```bash 14 | pip install facedb 15 | ``` 16 | 17 | FaceDB can use either face_recognition or DeepFace as its face recognition backend. Install the one you want to use. 18 | For face_recognition: 19 | 20 | ```bash 21 | pip install face_recognition 22 | ``` 23 | 24 | For DeepFace: 25 | 26 | ```bash 27 | pip install deepface 28 | ``` 29 | 30 | ## Simple Usage 31 | 32 | The example below creates a chromadb database at the given `path` (here `facedata`, relative to the current directory).
33 | 34 | ```python 35 | # Import the FaceDB library 36 | from facedb import FaceDB 37 | 38 | # Create a FaceDB instance 39 | db = FaceDB( 40 | path="facedata", 41 | ) 42 | 43 | # Add a new face to the database 44 | face_id = db.add("John Doe", img="john_doe.jpg") 45 | 46 | # Recognize a face 47 | result = db.recognize(img="new_face.jpg") 48 | 49 | # Check whether the recognized face is the one we just added 50 | if result and result["id"] == face_id: 51 | print("Recognized as John Doe") 52 | else: 53 | print("Unknown face") 54 | ``` 55 | 56 | ## Advanced Usage 57 | 58 | To use Pinecone as the database backend, install the pinecone client first. 59 | 60 | ```bash 61 | pip install pinecone 62 | ``` 63 | 64 | ```python 65 | import os 66 | 67 | os.environ["PINECONE_API_KEY"] = "YOUR_API_KEY" 68 | 69 | db = FaceDB( 70 | path="facedata", 71 | metric='euclidean', 72 | database_backend='pinecone', 73 | index_name='faces', 74 | embedding_dim=128, 75 | module='face_recognition', 76 | ) 77 | 78 | # This will create a pinecone index with name 'faces' in your environment if it doesn't exist 79 | 80 | # add multiple faces 81 | from glob import glob 82 | from pathlib import Path 83 | 84 | files = glob("faces/*.jpg") # Suppose you have a folder of images named after the people in them 85 | imgs = [] 86 | names = [] 87 | for file in files: 88 | imgs.append(file) 89 | names.append(Path(file).stem) 90 | 91 | ids, failed_indexes = db.add_many( 92 | imgs=imgs, 93 | names=names, 94 | ) 95 | 96 | unknown_face = "unknown_face.jpg" 97 | result = db.recognize(img=unknown_face, include=['name']) 98 | if result: 99 | print(f"Recognized as {result['name']}") 100 | else: 101 | print("Unknown face") 102 | 103 | 104 | # Include img in the result 105 | result = db.recognize(img=unknown_face, include=['img']) 106 | if result: 107 | result.show_img() 108 | 109 | # You can also use show_img() for multiple results 110 | results = db.all(include='name') 111 | results.show_img() # make sure you have
matplotlib installed 112 | 113 | # or 114 | img = result['img'] # cv2 image (numpy array) 115 | 116 | # Include embedding in the result 117 | result = db.recognize(img=unknown_face, include=['embedding']) 118 | if result: 119 | print(result['embedding']) 120 | 121 | 122 | # Search for similar faces 123 | results = db.search(img=unknown_face, top_k=5, include=['name'])[0] 124 | 125 | for result in results: 126 | print(f"Found {result['name']} with distance {result['distance']}") 127 | 128 | # or search for multiple faces 129 | multi_results = db.search(img=[img1, img2], top_k=5, include=['name']) 130 | 131 | for results in multi_results: 132 | for result in results: 133 | print(f"Found {result['name']} with distance {result['distance']}") 134 | 135 | # get all faces 136 | faces = db.get_all(include=['name', 'img']) 137 | 138 | # Update a face 139 | db.update(face_id, name="John Doe", img="john_doe.jpg", metadata1="metadata1", metadata2="metadata2") 140 | 141 | # Delete a face 142 | db.delete(face_id) 143 | 144 | # Count the number of faces in the database 145 | count = db.count() 146 | 147 | # Get pandas dataframe of all faces 148 | df = db.all().df 149 | ``` 150 | 151 | ## Simple Querying 152 | 153 | ```python 154 | 155 | # First add some faces to the database 156 | db.add("Nelson Mandela", img="mandela.jpg", profession="Politician", country="South Africa") 157 | db.add("Barack Obama", img="obama.jpg", profession="Politician", country="USA") 158 | db.add("Einstein", img="einstein.jpg", profession="Scientist", country="Germany") 159 | 160 | # Query the database by name 161 | results = db.query(name="Nelson Mandela") 162 | 163 | # Query the database by profession 164 | results = db.query(profession="Politician") 165 | ``` 166 | ## If you don't have an API key 167 | 168 | You can follow the official pinecone tutorial : https://docs.pinecone.io/docs/new-api 169 | It's easy to use and to understand, don't worry. 
170 | 171 | ## Advanced Querying 172 | 173 | You can use following operators in queries: 174 | 175 | - $eq - Equal to (number, string, boolean) 176 | - $ne - Not equal to (number, string, boolean) 177 | - $gt - Greater than (number) 178 | - $lt - Less than (number) 179 | - $in - In array (string or number) 180 | - $regex - Regex match (string) 181 | 182 | ```python 183 | results = db.query( 184 | profession={"$eq": "Politician"}, 185 | country={"$in": ["USA", "South Africa"]}, 186 | ) 187 | # or write in a single json 188 | results = db.query( 189 | where={ 190 | "profession": {"$eq": "Politician"}, 191 | "country": {"$in": ["USA", "South Africa"]}, 192 | } 193 | ) 194 | 195 | # you can use show_img(), df, query to further filter the results 196 | results.show_img() 197 | results.df 198 | results.query(name="Nelson Mandela") 199 | 200 | ``` 201 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import sys 4 | from pathlib import Path 5 | 6 | current_dir = Path(__file__).parent 7 | sys.path.append(str(current_dir.parent)) 8 | 9 | from facedb import ( 10 | FaceDB, 11 | ) 12 | 13 | 14 | class TestFaceDBChroma(unittest.TestCase): 15 | @classmethod 16 | def setUpClass(cls): 17 | cls.db = FaceDB( 18 | path="facedata", 19 | metric="euclidean", 20 | database_backend="chromadb", 21 | embedding_dim=128, 22 | module="face_recognition", 23 | ) 24 | 25 | def test_add_many(self): 26 | files = [ 27 | current_dir / "imgs" / "joe_biden.jpeg", 28 | current_dir / "imgs" / "no_face.png", 29 | current_dir / "imgs" / "narendra_modi.jpeg", 30 | ] 31 | imgs = [] 32 | names = [] 33 | for file in files: 34 | imgs.append(file) 35 | names.append(Path(file).stem) 36 | 37 | ids, failed_indexes = self.db.add_many(imgs=imgs, names=names) 38 | 39 | self.assertEqual(len(failed_indexes), 1) 40 | self.assertEqual(len(ids), 2) 41 | 42 | 
def _recognized_id(self, img):
    """Return the id of the stored face that ``self.db`` recognizes in ``img``.

    Fails the running test with a clear message when nothing is recognized,
    instead of letting ``None.id`` raise an ``AttributeError``.
    """
    result = self.db.recognize(img=img, include=["name"])
    self.assertTrue(result, f"no stored face recognized in {img}")
    return result.id  # type: ignore

def test_recognize_known_face(self):
    """A second photo of an enrolled person is recognized under that person's name."""
    known_face = str(current_dir / "imgs" / "joe_biden_2.jpeg")
    result = self.db.recognize(img=known_face, include=["name"])
    # assertTrue (not an `if result:` guard) so an empty result fails instead of passing silently.
    self.assertTrue(result, f"no stored face recognized in {known_face}")
    self.assertIn("joe_biden", result["name"])  # type: ignore

def test_recognize_unknown_face(self):
    """A face that was never added is not recognized (``recognize`` returns None)."""
    unknown_face = current_dir / "imgs" / "barak_obama.jpeg"
    result = self.db.recognize(img=unknown_face, include=["name"])
    self.assertIsNone(result)

def test_update(self):
    """Renaming the recognized face is reflected in the next ``recognize`` call."""
    img = current_dir / "imgs" / "joe_biden_2.jpeg"
    idx = self._recognized_id(img)
    self.db.update(id=idx, name="joe_biden_2")

    result = self.db.recognize(img=img, include=["name"])
    self.assertTrue(result, f"face {idx!r} not recognized after update")
    self.assertIn("joe_biden_2", result["name"])  # type: ignore

def test_get(self):
    """``get`` by id returns the stored record.

    Expects the name ``joe_biden_2``, which only ``test_update`` assigns, so
    this test assumes ``test_update`` already ran (the ``__main__`` suite adds
    the tests in that order).
    """
    img = current_dir / "imgs" / "joe_biden_2.jpeg"
    idx = self._recognized_id(img)
    result = self.db.get(id=idx, include=["name"])
    self.assertTrue(result, f"no record returned for id {idx!r}")
    self.assertIn("joe_biden_2", result["name"])  # type: ignore

def test_delete(self):
    """After ``delete``, ``get`` no longer returns a record for that id."""
    img = current_dir / "imgs" / "joe_biden_2.jpeg"
    idx = self._recognized_id(img)
    self.db.delete(id=idx)
    result = self.db.get(id=idx, include=["name"])
    # A missing id may come back as None or as an empty result; accept either.
    self.assertTrue(
        result is None or len(result) == 0,
        f"face {idx!r} still present after delete: {result!r}",
    )

def test_search(self):
    """Searching by a precomputed embedding returns a (non-None) result."""
    img = current_dir / "imgs" / "joe_biden_2.jpeg"
    emb = self.db.embedding_func(img)
    result = self.db.search(embedding=emb, include=["name"])
    self.assertIsNotNone(result)

def test_query(self):
    """A metadata query on ``name`` returns the enrolled ``narendra_modi`` record."""
    results = self.db.query(name="narendra_modi", include=["name"])
    # An empty (or None) result must fail here rather than skip the name check.
    self.assertTrue(results, "query(name='narendra_modi') returned no results")
    self.assertEqual(results[0]["name"], "narendra_modi")  # type: ignore
94 | @classmethod 95 | def tearDownClass(cls): 96 | cls.db.delete_all() 97 | 98 | 99 | class TestFaceDBPinecone(unittest.TestCase): 100 | @classmethod 101 | def setUpClass(cls): 102 | cls.db = FaceDB( 103 | path="facedata", 104 | metric="euclidean", 105 | database_backend="pinecone", 106 | embedding_dim=128, 107 | module="face_recognition", 108 | pinecone_settings={ 109 | "index_name": "test-face-db", 110 | }, 111 | ) 112 | 113 | def test_add_many(self): 114 | files = [ 115 | current_dir / "imgs" / "joe_biden.jpeg", 116 | current_dir / "imgs" / "no_face.png", 117 | current_dir / "imgs" / "narendra_modi.jpeg", 118 | ] 119 | imgs = [] 120 | names = [] 121 | for file in files: 122 | imgs.append(file) 123 | names.append(Path(file).stem) 124 | 125 | ids, failed_indexes = self.db.add_many(imgs=imgs, names=names) 126 | 127 | print( 128 | f"Failed indexes: {failed_indexes}\n" 129 | f"IDs: {ids}" 130 | ) 131 | print(self.db.all(include=["name"])) 132 | 133 | self.assertEqual(len(failed_indexes), 1) 134 | self.assertEqual(len(ids), 2) 135 | 136 | def test_recognize_known_face(self): 137 | known_face = str(current_dir / "imgs" / "joe_biden_2.jpeg") 138 | result = self.db.recognize(img=known_face, include=["name"]) 139 | self.assertIsNotNone(result) 140 | if result: 141 | self.assertIn("joe_biden", result["name"]) # type: ignore 142 | 143 | def test_recognize_unknown_face(self): 144 | unknown_face = current_dir / "imgs" / "barak_obama.jpeg" 145 | result = self.db.recognize(img=unknown_face, include=["name"]) 146 | self.assertIsNone(result) 147 | 148 | def test_update(self): 149 | img = current_dir / "imgs" / "joe_biden_2.jpeg" 150 | idx = self.db.recognize(img=img, include=["name"]).id # type: ignore 151 | self.db.update(id=idx, name="joe_biden_2") 152 | 153 | result = self.db.recognize(img=img, include=["name"]) 154 | self.assertIsNotNone(result) 155 | if result: 156 | self.assertIn("joe_biden_2", result["name"]) # type: ignore 157 | 158 | def test_get(self): 159 | img = 
current_dir / "imgs" / "joe_biden_2.jpeg" 160 | idx = self.db.recognize(img=img, include=["name"]).id # type: ignore 161 | result = self.db.get(id=idx, include=["name"]) 162 | self.assertIsNotNone(result) 163 | if result: 164 | self.assertIn("joe_biden_2", result["name"]) # type: ignore 165 | 166 | def test_delete(self): 167 | img = current_dir / "imgs" / "joe_biden_2.jpeg" 168 | idx = self.db.recognize(img=img, include=["name"]).id # type: ignore 169 | self.db.delete(id=idx) 170 | result = self.db.get(id=idx, include=["name"]) 171 | if result is None: 172 | self.assertIsNone(result) 173 | else: 174 | self.assertEqual(len(result), 0) 175 | 176 | def test_search(self): 177 | img = current_dir / "imgs" / "joe_biden_2.jpeg" 178 | emb = self.db.embedding_func(img) 179 | result = self.db.search(embedding=emb, include=["name"]) 180 | self.assertIsNotNone(result) 181 | 182 | def test_query(self): 183 | results = self.db.query(name="narendra_modi", include=["name"]) 184 | self.assertIsNotNone(results) 185 | if results: 186 | self.assertEqual(results[0]["name"], "narendra_modi") # type: ignore 187 | 188 | @classmethod 189 | def tearDownClass(cls): 190 | cls.db.delete_all() 191 | 192 | 193 | if __name__ == "__main__": 194 | suite = unittest.TestSuite() 195 | suite.addTest(TestFaceDBChroma("test_add_many")) 196 | suite.addTest(TestFaceDBChroma("test_recognize_known_face")) 197 | suite.addTest(TestFaceDBChroma("test_recognize_unknown_face")) 198 | suite.addTest(TestFaceDBChroma("test_update")) 199 | suite.addTest(TestFaceDBChroma("test_get")) 200 | suite.addTest(TestFaceDBChroma("test_search")) 201 | suite.addTest(TestFaceDBChroma("test_delete")) 202 | suite.addTest(TestFaceDBChroma("test_query")) 203 | 204 | # get api_key and env from sys 205 | api_key = sys.argv[1] 206 | 207 | os.environ["PINECONE_API_KEY"] = api_key 208 | 209 | suite.addTest(TestFaceDBPinecone("test_add_many")) 210 | suite.addTest(TestFaceDBPinecone("test_recognize_known_face")) 211 | 
suite.addTest(TestFaceDBPinecone("test_recognize_unknown_face")) 212 | suite.addTest(TestFaceDBPinecone("test_update")) 213 | suite.addTest(TestFaceDBPinecone("test_get")) 214 | suite.addTest(TestFaceDBPinecone("test_search")) 215 | suite.addTest(TestFaceDBPinecone("test_delete")) 216 | suite.addTest(TestFaceDBPinecone("test_query")) 217 | 218 | # Run the test suite 219 | runner = unittest.TextTestRunner() 220 | result = runner.run(suite) 221 | -------------------------------------------------------------------------------- /facedb/query.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from mongoquery (https://github.com/kapouille/mongoquery) 3 | 4 | mongoquery provides a straightforward API to match Python objects against 5 | MongoDB Query Language queries. 6 | """ 7 | 8 | import re 9 | from collections.abc import Sequence, Mapping 10 | from six import string_types 11 | 12 | 13 | try: 14 | string_type = basestring # type: ignore 15 | except NameError: 16 | string_type = str 17 | 18 | 19 | try: 20 | regex_type = re.Pattern 21 | except AttributeError: 22 | regex_type = re._pattern_type # type: ignore 23 | 24 | 25 | class QueryError(Exception): 26 | """Query error exception""" 27 | 28 | pass 29 | 30 | 31 | class _Undefined(object): 32 | # pylint: disable=too-few-public-methods 33 | pass 34 | 35 | 36 | def is_non_string_sequence(entry): 37 | """Returns True if entry is a Python sequence iterable, and not a string""" 38 | return isinstance(entry, Sequence) and not isinstance(entry, string_type) 39 | 40 | 41 | class Query(object): 42 | """The Query class is used to match an object against a MongoDB-like query""" 43 | 44 | # pylint: disable=too-few-public-methods 45 | def __init__(self, definition): 46 | self._definition = definition 47 | 48 | def match(self, entry): 49 | """Matches the entry object against the query specified on instanciation""" 50 | return self._match(self._definition, entry) 51 | 52 | def 
_match(self, condition, entry): 53 | if isinstance(condition, Mapping): 54 | return all( 55 | self._process_condition(sub_operator, sub_condition, entry) 56 | for sub_operator, sub_condition in condition.items() 57 | ) 58 | if is_non_string_sequence(entry): 59 | return condition in entry 60 | return condition == entry 61 | 62 | def _extract(self, entry, path): 63 | if not path: 64 | return entry 65 | if entry is None: 66 | return entry 67 | if is_non_string_sequence(entry): 68 | try: 69 | index = int(path[0]) 70 | return self._extract(entry[index], path[1:]) 71 | except ValueError: 72 | return [self._extract(item, path) for item in entry] 73 | elif isinstance(entry, Mapping) and path[0] in entry: 74 | return self._extract(entry[path[0]], path[1:]) 75 | else: 76 | return _Undefined() 77 | 78 | def _path_exists(self, operator, condition, entry): 79 | keys_list = list(operator.split(".")) 80 | for i, k in enumerate(keys_list): 81 | if isinstance(entry, Sequence) and not k.isdigit(): 82 | for elem in entry: 83 | operator = ".".join(keys_list[i:]) 84 | if self._path_exists(operator, condition, elem) == condition: 85 | return condition 86 | return not condition 87 | elif isinstance(entry, Sequence): 88 | k = int(k) 89 | try: 90 | entry = entry[k] 91 | except (TypeError, IndexError, KeyError): 92 | return not condition 93 | return condition 94 | 95 | def _process_condition(self, operator, condition, entry): 96 | if isinstance(condition, Mapping) and "$exists" in condition: 97 | if isinstance(operator, string_types) and operator.find(".") != -1: 98 | return self._path_exists(operator, condition["$exists"], entry) 99 | elif condition["$exists"] != (operator in entry): 100 | return False 101 | elif tuple(condition.keys()) == ("$exists",): 102 | return True 103 | if isinstance(operator, string_type): 104 | if operator.startswith("$"): 105 | try: 106 | return getattr(self, "_" + operator[1:])(condition, entry) 107 | except AttributeError: 108 | raise QueryError("{!r} operator 
isn't supported".format(operator)) 109 | else: 110 | try: 111 | extracted_data = self._extract(entry, operator.split(".")) 112 | except IndexError: 113 | extracted_data = _Undefined() 114 | else: 115 | if operator not in entry: 116 | return False 117 | extracted_data = entry[operator] 118 | return self._match(condition, extracted_data) 119 | 120 | @staticmethod 121 | def _not_implemented(*_): 122 | raise NotImplementedError 123 | 124 | @staticmethod 125 | def _noop(*_): 126 | return True 127 | 128 | @staticmethod 129 | def _eq(condition, entry): 130 | try: 131 | return entry == condition 132 | except TypeError: 133 | return False 134 | 135 | @staticmethod 136 | def _gt(condition, entry): 137 | try: 138 | return entry > condition 139 | except TypeError: 140 | return False 141 | 142 | @staticmethod 143 | def _gte(condition, entry): 144 | try: 145 | return entry >= condition 146 | except TypeError: 147 | return False 148 | 149 | @staticmethod 150 | def _in(condition, entry): 151 | if is_non_string_sequence(condition): 152 | for elem in condition: 153 | if is_non_string_sequence(entry) and elem in entry: 154 | return True 155 | elif not is_non_string_sequence(entry) and elem == entry: 156 | return True 157 | return False 158 | else: 159 | raise TypeError("condition must be a list") 160 | 161 | @staticmethod 162 | def _lt(condition, entry): 163 | try: 164 | return entry < condition 165 | except TypeError: 166 | return False 167 | 168 | @staticmethod 169 | def _lte(condition, entry): 170 | try: 171 | return entry <= condition 172 | except TypeError: 173 | return False 174 | 175 | @staticmethod 176 | def _ne(condition, entry): 177 | return entry != condition 178 | 179 | def _nin(self, condition, entry): 180 | return not self._in(condition, entry) 181 | 182 | def _and(self, condition, entry): 183 | if isinstance(condition, Sequence): 184 | return all(self._match(sub_condition, entry) for sub_condition in condition) 185 | raise QueryError( 186 | "$and has been attributed 
incorrect argument {!r}".format(condition) 187 | ) 188 | 189 | def _nor(self, condition, entry): 190 | if isinstance(condition, Sequence): 191 | return all( 192 | not self._match(sub_condition, entry) for sub_condition in condition 193 | ) 194 | raise QueryError( 195 | "$nor has been attributed incorrect argument {!r}".format(condition) 196 | ) 197 | 198 | def _not(self, condition, entry): 199 | return not self._match(condition, entry) 200 | 201 | def _or(self, condition, entry): 202 | if isinstance(condition, Sequence): 203 | return any(self._match(sub_condition, entry) for sub_condition in condition) 204 | raise QueryError( 205 | "$or has been attributed incorrect argument {!r}".format(condition) 206 | ) 207 | 208 | @staticmethod 209 | def _type(condition, entry): 210 | # TODO: further validation to ensure the right type 211 | # rather than just checking 212 | bson_type = { 213 | 1: float, 214 | 2: string_type, 215 | 3: Mapping, 216 | 4: Sequence, 217 | 5: bytearray, 218 | 7: string_type, # object id (uuid) 219 | 8: bool, 220 | 9: string_type, # date (UTC datetime) 221 | 10: type(None), 222 | 11: regex_type, # regex, 223 | 13: string_type, # Javascript 224 | 15: string_type, # JavaScript (with scope) 225 | 16: int, # 32-bit integer 226 | 17: int, # Timestamp 227 | 18: int, # 64-bit integer 228 | } 229 | bson_alias = { 230 | "double": 1, 231 | "string": 2, 232 | "object": 3, 233 | "array": 4, 234 | "binData": 5, 235 | "objectId": 7, 236 | "bool": 8, 237 | "date": 9, 238 | "null": 10, 239 | "regex": 11, 240 | "javascript": 13, 241 | "javascriptWithScope": 15, 242 | "int": 16, 243 | "timestamp": 17, 244 | "long": 18, 245 | } 246 | 247 | if condition == "number": 248 | return any( 249 | [ 250 | isinstance(entry, bson_type[bson_alias[alias]]) 251 | for alias in ["double", "int", "long"] 252 | ] 253 | ) 254 | 255 | # resolves bson alias, or keeps original condition value 256 | condition = bson_alias.get(condition, condition) 257 | 258 | if condition not in bson_type: 
259 | raise QueryError( 260 | "$type has been used with unknown type {!r}".format(condition) 261 | ) 262 | 263 | return isinstance(entry, bson_type.get(condition)) # type: ignore 264 | 265 | _exists = _noop 266 | 267 | @staticmethod 268 | def _mod(condition, entry): 269 | return entry % condition[0] == condition[1] 270 | 271 | @staticmethod 272 | def _regex(condition, entry): 273 | if not isinstance(entry, string_type): 274 | return False 275 | # If the caller has supplied a compiled regex, assume options are already 276 | # included. 277 | if isinstance(condition, regex_type): 278 | return bool(re.search(condition, entry)) 279 | try: 280 | regex = re.match(r"\A/(.+)/([imsx]{,4})\Z", condition, flags=re.DOTALL) 281 | except TypeError: 282 | raise QueryError( 283 | "{!r} is not a regular expression " 284 | "and should be a string".format(condition) 285 | ) 286 | 287 | flags = 0 288 | if regex: 289 | options = regex.group(2) 290 | for option in options: 291 | flags |= getattr(re, option.upper()) 292 | exp = regex.group(1) 293 | else: 294 | exp = condition 295 | 296 | try: 297 | match = re.search(exp, entry, flags=flags) 298 | except Exception as error: 299 | raise QueryError( 300 | "{!r} failed to execute with error {!r}".format(condition, error) 301 | ) 302 | return bool(match) 303 | 304 | _options = _text = _where = _not_implemented 305 | 306 | def _all(self, condition, entry): 307 | return all(self._match(item, entry) for item in condition) 308 | 309 | def _elemMatch(self, condition, entry): 310 | # pylint: disable=invalid-name 311 | if not isinstance(entry, Sequence): 312 | return False 313 | return any( 314 | all( 315 | self._process_condition(sub_operator, sub_condition, element) 316 | for sub_operator, sub_condition in condition.items() 317 | ) 318 | for element in entry 319 | ) 320 | 321 | @staticmethod 322 | def _size(condition, entry): 323 | if not isinstance(condition, int): 324 | raise QueryError( 325 | "$size has been attributed incorrect argument 
{!r}".format(condition) 326 | ) 327 | 328 | if is_non_string_sequence(entry): 329 | return len(entry) == condition 330 | 331 | return False 332 | 333 | _comment = _noop 334 | -------------------------------------------------------------------------------- /facedb/db_tools.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import sqlite3 3 | 4 | try: 5 | from typing import Optional, Union, List, Callable, Literal, Tuple 6 | except ImportError: 7 | from typing_extensions import Optional, Union, List, Callable, Literal, Tuple 8 | 9 | import cv2 10 | import numpy as np 11 | from PIL import Image 12 | import io 13 | from pathlib import Path 14 | import warnings 15 | import pprint 16 | import json 17 | import requests 18 | 19 | 20 | def l2_normalize(x): 21 | return x / np.sqrt(np.sum(np.multiply(x, x))) 22 | 23 | 24 | def is_none_or_empty(x): 25 | if x is None: 26 | return True 27 | else: 28 | return len(x) == 0 29 | 30 | 31 | def img_to_cv2(img): 32 | if isinstance(img, str): 33 | if "https://" in img or "http://" in img: 34 | img = requests.get(img).content 35 | img = np.frombuffer(img, np.uint8) 36 | img = cv2.imdecode(img, cv2.IMREAD_COLOR) 37 | else: 38 | img = cv2.imread(img) 39 | elif isinstance(img, Path): 40 | img = cv2.imread(str(img)) 41 | elif isinstance(img, np.ndarray): 42 | pass 43 | elif isinstance(img, Image.Image): 44 | img = np.array(img) 45 | elif isinstance(img, io.BytesIO): 46 | img = np.frombuffer(img.getvalue(), np.uint8) 47 | img = cv2.imdecode(img, cv2.IMREAD_COLOR) 48 | elif isinstance(img, bytes): 49 | img = np.frombuffer(img, np.uint8) 50 | img = cv2.imdecode(img, cv2.IMREAD_COLOR) 51 | else: 52 | raise TypeError(f"Unknown type of img: {type(img)}") 53 | return img 54 | 55 | 56 | def img_to_bytes(img): 57 | img = img_to_cv2(img) 58 | return cv2.imencode(".jpg", img)[1].tobytes() 59 | 60 | 61 | def many_vectors(obj): 62 | if isinstance(obj, list): 63 | if len(obj) == 0: 64 
| return False 65 | elif isinstance(obj[0], list): 66 | return True 67 | elif isinstance(obj[0], np.ndarray): 68 | return True 69 | return False 70 | 71 | 72 | def is_list_of_img(obj): 73 | if isinstance(obj, list): 74 | return True 75 | return False 76 | 77 | 78 | def is_2d(x): 79 | if isinstance(x, list): 80 | x = np.array(x) 81 | if len(x.shape) == 2: 82 | return True 83 | return False 84 | 85 | 86 | def convert_shape(x): 87 | if isinstance(x, list): 88 | x = np.array(x) 89 | if len(x.shape) == 3: 90 | x = x.squeeze() 91 | return x 92 | 93 | 94 | def get_embeddings( 95 | imgs: Optional[Union[str, List[str], np.ndarray, List[np.ndarray]]] = None, # type: ignore 96 | embeddings: Optional[Union[List[List[float]], List[np.ndarray]]] = None, # type: ignore 97 | embedding_func: Optional[ 98 | Callable[[Union[str, np.ndarray]], Union[List[float], np.ndarray]] 99 | ] = None, 100 | raise_error: bool = True, 101 | ) -> List[List[float]]: 102 | if embeddings is None: 103 | if imgs is None: 104 | if raise_error: 105 | raise ValueError("imgs and embeddings cannot be both None") 106 | else: 107 | return [] 108 | 109 | elif not is_list_of_img(imgs): 110 | imgs = [imgs] # type: ignore 111 | 112 | if embedding_func is None: 113 | if raise_error: 114 | raise ValueError("embedding_func cannot be None") 115 | else: 116 | return [] 117 | 118 | embeddings = [] 119 | for img in imgs: # type: ignore 120 | embeds = embedding_func(img) 121 | if is_none_or_empty(embeds): 122 | continue 123 | for embed in embeds: 124 | embeddings.append(embed) # type: ignore 125 | 126 | embeddings = convert_shape(embeddings).tolist() 127 | 128 | return embeddings # type: ignore 129 | 130 | 131 | def get_include(default=None, include=None): 132 | if include is None: 133 | include = [] 134 | elif isinstance(include, str): 135 | include = [include] 136 | 137 | sincludes = [default] if default else [] 138 | if "embedding" in include: 139 | sincludes.append("embeddings") 140 | 141 | if include: 142 | 
sincludes.append("metadatas") 143 | 144 | return sincludes, include 145 | 146 | 147 | fthresholds = { 148 | "pinecone": { 149 | "cosine": {"value": 0.07, "operator": "le", "direction": "negative"}, 150 | "cosine_l2": { 151 | "value": 0.07, 152 | "operator": "le", 153 | "direction": "negative", 154 | }, 155 | "dotproduct": { 156 | "value": -0.8, 157 | "operator": "ge", 158 | "direction": "positive", 159 | }, 160 | "dotproduct_l2": { 161 | "value": 0.07, 162 | "operator": "le", 163 | "direction": "negative", 164 | }, 165 | "euclidean": { 166 | "value": 0.72, 167 | "operator": "ge", 168 | "direction": "positive", 169 | }, 170 | "euclidean_l2": { 171 | "value": 0.85, 172 | "operator": "ge", 173 | "direction": "positive", 174 | }, 175 | }, 176 | "chromadb": { 177 | "cosine": {"value": 0.06, "operator": "le", "direction": "negative"}, 178 | "cosine_l2": { 179 | "value": 0.07, 180 | "operator": "le", 181 | "direction": "negative", 182 | }, 183 | "ip": {"value": -1.1, "operator": "ge", "direction": "positive"}, 184 | "ip_l2": {"value": 0.07, "operator": "le", "direction": "negative"}, 185 | "l2": {"value": 0.27, "operator": "le", "direction": "negative"}, 186 | "l2_l2": {"value": 0.14, "operator": "le", "direction": "negative"}, 187 | }, 188 | } 189 | 190 | 191 | def time_now(): 192 | return datetime.now().strftime("%m-%d-%Y-%I-%M-%S-%p") 193 | 194 | 195 | def get_model_dimension(module, model_name): 196 | if module == "face_recognition": 197 | return 128 198 | elif module == "deepface": 199 | dim_map = { 200 | "VGG-Face": 2622, 201 | "Facenet": 128, 202 | "Facenet512": 512, 203 | "OpenFace": 128, 204 | "DeepFace": 8631, 205 | "DeepID": 160, 206 | "Dlib": 128, 207 | "ArcFace": 512, 208 | "Ensemble": 8631, 209 | } 210 | return dim_map[model_name] 211 | else: 212 | raise ValueError(f"Unknown module: {module}") 213 | 214 | 215 | class FailedImageIndex(int): 216 | def __new__(cls, value, failed_reason): 217 | obj = super().__new__(cls, value) 218 | obj.failed_reason = 
failed_reason 219 | return obj 220 | 221 | def __init__(self, value, failed_reason): 222 | self.failed_reason = failed_reason 223 | self.idx = value 224 | 225 | def __repr__(self): 226 | return f"{self.idx} ({self.failed_reason})" 227 | 228 | 229 | class FailedImageIndexList(list): 230 | def __init__(self, *args, **kwargs): 231 | super().__init__(*args, **kwargs) 232 | self.failed_reasons = [] 233 | 234 | def append(self, value, failed_reason): 235 | super().append(FailedImageIndex(value, failed_reason)) 236 | self.failed_reasons.append(failed_reason) 237 | 238 | def __repr__(self): 239 | txt = "FailedImageIndexList:\n" 240 | for i, j in zip(self, self.failed_reasons): 241 | txt += f"{i} ({j})\n" 242 | return txt 243 | 244 | def __str__(self): 245 | return self.__repr__() 246 | 247 | 248 | class Rect(dict): 249 | def __init__(self, x, y, w, h): 250 | super().__init__(x=x, y=y, width=w, height=h) 251 | 252 | self.x = x 253 | self.y = y 254 | self.w = w 255 | self.h = h 256 | 257 | @property 258 | def width(self): 259 | return self.w 260 | 261 | @width.setter 262 | def width(self, value): 263 | self.w = value 264 | 265 | @property 266 | def height(self): 267 | return self.h 268 | 269 | @height.setter 270 | def height(self, value): 271 | self.h = value 272 | 273 | @classmethod 274 | def from_json(cls, json): 275 | if "width" in json: 276 | return cls(json["x"], json["y"], json["width"], json["height"]) 277 | else: 278 | return cls(json["x"], json["y"], json["w"], json["h"]) 279 | 280 | def to(self, module): 281 | if module == "face_recognition": 282 | return (self.y, self.x + self.w, self.y + self.h, self.x) 283 | else: 284 | return [self.x, self.y, self.w, self.h] 285 | 286 | def to_json(self): 287 | return json.dumps(self) 288 | 289 | def __len__(self): 290 | return 4 291 | 292 | def __repr__(self): 293 | return f"" 294 | 295 | def __str__(self): 296 | return pprint.pformat(self.to_json()) 297 | 298 | def __getitem__(self, key): 299 | if key == 0: 300 | return 
self.x 301 | elif key == 1: 302 | return self.y 303 | elif key == 2: 304 | return self.width 305 | elif key == 3: 306 | return self.height 307 | else: 308 | raise IndexError("Rect index out of range") 309 | 310 | def __setitem__(self, key, value): 311 | if key == 0: 312 | self.x = value 313 | elif key == 1: 314 | self.y = value 315 | elif key == 2: 316 | self.width = value 317 | elif key == 3: 318 | self.height = value 319 | else: 320 | raise IndexError("Rect index out of range") 321 | 322 | def __iter__(self): 323 | return iter([self.x, self.y, self.width, self.height]) 324 | 325 | def __eq__(self, other): 326 | if isinstance(other, Rect): 327 | return all( 328 | [ 329 | self.x == other.x, 330 | self.y == other.y, 331 | self.width == other.width, 332 | self.height == other.height, 333 | ] 334 | ) 335 | elif isinstance(other, list) or isinstance(other, tuple): 336 | return all( 337 | [ 338 | self.x == other[0], 339 | self.y == other[1], 340 | self.width == other[2], 341 | self.height == other[3], 342 | ] 343 | ) 344 | else: 345 | return False 346 | 347 | def __ne__(self, other): 348 | return not self.__eq__(other) 349 | 350 | 351 | class ImgDB: 352 | def __init__(self, db_path): 353 | self.db_path = db_path 354 | self.conn = sqlite3.connect(db_path, check_same_thread=False) 355 | self.cursor = self.conn.cursor() 356 | self.create_table() 357 | 358 | def delete_all(self): 359 | self.cursor.execute("""DELETE FROM img""") 360 | self.conn.commit() 361 | 362 | def __del__(self): 363 | self.conn.close() 364 | 365 | # img, id (str) 366 | def create_table(self): 367 | self.cursor.execute( 368 | """CREATE TABLE IF NOT EXISTS img ( 369 | img_id TEXT PRIMARY KEY, 370 | img BLOB 371 | )""" 372 | ) 373 | self.conn.commit() 374 | 375 | def _add(self, *, img_id, img): 376 | img = img_to_bytes(img) 377 | self.cursor.execute("""INSERT INTO img VALUES (?, ?)""", (img_id, img)) 378 | 379 | def add(self, *, img_id: Union[str, list], img): 380 | if isinstance(img, list): 381 | if 
len(img) != len(img_id): 382 | raise ValueError("Length of `img` and `img_id` must be same.") 383 | for i, j in zip(img_id, img): 384 | self._add(img_id=i, img=j) 385 | else: 386 | self._add(img_id=img_id, img=img) 387 | self.conn.commit() 388 | 389 | def add_rects(self, *, img, img_ids: List[str], rects: List[Rect], zoom_out=0.25): 390 | img = img_to_cv2(img) 391 | img_h, img_w = img.shape[:2] 392 | for rect in rects: 393 | rect.x = max(0, rect.x - int(rect.w * zoom_out)) 394 | rect.y = max(0, rect.y - int(rect.h * zoom_out)) 395 | rect.w = min(img_w - rect.x, rect.w + int(rect.w * zoom_out)) 396 | rect.h = min(img_h - rect.y, rect.h + int(rect.h * zoom_out)) 397 | 398 | if len(img_ids) != len(rects): 399 | raise ValueError("Length of `img_ids` and `rects` must be same.") 400 | 401 | for img_id, rect in zip(img_ids, rects): 402 | self._add( 403 | img_id=img_id, 404 | img=img[rect.y : rect.y + rect.h, rect.x : rect.x + rect.w], 405 | ) 406 | self.conn.commit() 407 | 408 | def get(self, img_id): 409 | self.cursor.execute("""SELECT img FROM img WHERE img_id=?""", (img_id,)) 410 | img = self.cursor.fetchone() 411 | if img is None: 412 | return None 413 | 414 | try: 415 | img = cv2.imdecode(np.frombuffer(img[0], np.uint8), cv2.IMREAD_COLOR) 416 | except Exception as e: 417 | warnings.warn(f"Error in decoding image: {e}") 418 | return None 419 | 420 | return img 421 | 422 | def delete(self, img_id): 423 | self.cursor.execute("""DELETE FROM img WHERE img_id=?""", (img_id,)) 424 | self.conn.commit() 425 | 426 | def update(self, *, img_id, img): 427 | img = img_to_bytes(img) 428 | self.cursor.execute("""UPDATE img SET img=? 
WHERE img_id=?""", (img, img_id)) 429 | self.conn.commit() 430 | 431 | def auto(self, *, img_id, img): 432 | img = img_to_bytes(img) 433 | self.cursor.execute( 434 | """INSERT OR REPLACE INTO img VALUES (?, ?)""", (img_id, img) 435 | ) 436 | self.conn.commit() 437 | -------------------------------------------------------------------------------- /facedb/db_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import pprint 4 | from pathlib import Path 5 | import os 6 | from facedb.query import Query 7 | from math import ceil 8 | import warnings 9 | 10 | try: 11 | from typing import Literal, Optional, Union, List 12 | except ImportError: 13 | from typing_extensions import Literal, Optional, Union, List 14 | 15 | # pinecone = None 16 | 17 | PINECONE_IMPORTED = False 18 | try: 19 | from pinecone import Pinecone, Index, ServerlessSpec 20 | 21 | PINECONE_IMPORTED = True 22 | except ImportError: 23 | pass 24 | 25 | CHROMADB_IMPORTED = False 26 | try: 27 | import chromadb 28 | 29 | CHROMADB_IMPORTED = True 30 | except ImportError: 31 | pass 32 | 33 | if not PINECONE_IMPORTED and not CHROMADB_IMPORTED: 34 | raise ImportError("Please install `pinecone` or `chromadb` to use this module.") 35 | 36 | 37 | def many_vectors(obj): 38 | if isinstance(obj, list): 39 | if len(obj) == 0: 40 | return False 41 | elif isinstance(obj[0], list): 42 | return True 43 | elif isinstance(obj[0], np.ndarray): 44 | return True 45 | return False 46 | 47 | 48 | def calculate_confidence(dis, threshold, direction, assume=80): 49 | if direction == "positive": 50 | return ceil((assume / (threshold)) * (dis)) 51 | elif direction == "negative": 52 | return ceil((assume / (1 - threshold)) * (1 - dis)) 53 | 54 | 55 | class FaceResults(List["FaceResult"]): 56 | def __init__(self, *args, **kw): 57 | super().__init__(*args, **kw) 58 | 59 | if len(self) == 1: 60 | self.name = self[0].name # type: ignore 61 | self.id = self[0].id 
62 | self.distance = self[0].distance 63 | self.embedding = self[0].embedding 64 | self.img = self[0].img 65 | self.kw = self[0].kw 66 | 67 | for i in self.kw: 68 | setattr(self, i, self.kw[i]) 69 | 70 | def query(self, *args, **kw): 71 | q = {} 72 | for i in args: 73 | q.update(i) 74 | q.update(kw) 75 | 76 | query = Query(q) 77 | results = list(filter(query.match, self)) # type: ignore 78 | return FaceResults(results) 79 | 80 | def query_generator(self, query): 81 | query = Query(query) 82 | result = filter(query.match, self) # type: ignore 83 | for i in result: 84 | yield i 85 | 86 | def __getitem__(self, key): 87 | if isinstance(key, int): 88 | return super().__getitem__(key) 89 | elif isinstance(key, str): 90 | if key == "name": 91 | return self.name 92 | elif key == "id": 93 | return self.id 94 | elif key == "distance": 95 | return self.distance 96 | elif key == "embedding": 97 | return self.embedding 98 | elif key == "img": 99 | return self.img 100 | else: 101 | return self.kw[key] 102 | 103 | def __repr__(self): 104 | txt = f"FaceResults [\n" 105 | ct = 0 106 | for i in self: 107 | t = f"id={i['id']} name={i['name']}" 108 | if i.get("confidence"): 109 | t += f" confidence={i['confidence']}%" 110 | t += ",\n" 111 | txt += t 112 | if ct == 5: 113 | txt += f"... 
{len(self) - ct} more" 114 | break 115 | ct += 1 116 | if txt.endswith(",\n"): 117 | txt = txt[:-2] 118 | 119 | txt += " ]" 120 | return txt 121 | 122 | def __str__(self): 123 | return self.__repr__() 124 | 125 | @property 126 | def df(self): 127 | import pandas as pd 128 | 129 | data = [] 130 | for i in self: 131 | d = {} 132 | for k in i: 133 | val = i[k] 134 | if val is not None: 135 | if k == "embedding": 136 | val = f"Embedding({len(val)} dim)" 137 | elif k == "img": 138 | val = "Image(cv2 image)" 139 | d[k] = val 140 | data.append(d) 141 | 142 | return pd.DataFrame(data) 143 | 144 | def show_img(self, limit=10, page=1, img_size=(100, 100)): 145 | if len(self) == 0: 146 | print("No image available") 147 | return 148 | 149 | if page < 1: 150 | page = 1 151 | elif page > len(self) // limit + 1: 152 | page = len(self) // limit + 1 153 | 154 | images = [] 155 | for i in range((page - 1) * limit, min(page * limit, len(self))): 156 | images.append({"name": self[i].name, "image": self[i].img}) # type: ignore 157 | 158 | fixed_size = (200, 200) 159 | for item in images: 160 | if item["image"] is None: 161 | # create a blank image white color 162 | img = np.zeros((fixed_size[0], fixed_size[1], 3), dtype=np.uint8) 163 | # put not img available text in middle 164 | cv2.putText( 165 | img, 166 | "No image available", 167 | (10, 100), 168 | cv2.FONT_HERSHEY_DUPLEX, 169 | 0.5, 170 | (255, 255, 255), 171 | 1, 172 | ) 173 | item["image"] = img 174 | else: 175 | # resize image 176 | item["image"] = cv2.resize(item["image"], fixed_size) 177 | 178 | # Define text settings 179 | font = cv2.FONT_HERSHEY_DUPLEX 180 | font_scale = 0.3 181 | font_thickness = 1 182 | 183 | # Create a blank canvas 184 | num_rows = 2 185 | num_cols = 2 186 | canvas_width = fixed_size[0] * num_cols 187 | canvas_height = fixed_size[1] * num_rows 188 | canvas = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8) 189 | 190 | # Populate the canvas with images and their names 191 | for i in 
range(num_rows): 192 | for j in range(num_cols): 193 | index = i * num_cols + j 194 | if index < len(images): 195 | y_offset = i * fixed_size[1] 196 | x_offset = j * fixed_size[0] 197 | 198 | # Add the image to the canvas 199 | canvas[ 200 | y_offset : y_offset + fixed_size[1], 201 | x_offset : x_offset + fixed_size[0], 202 | ] = images[index]["image"] 203 | 204 | # Add a blurred background for the name 205 | name = images[index]["name"] 206 | 207 | # add name in bottom left corner 208 | text_size, _ = cv2.getTextSize( 209 | name, font, font_scale, font_thickness 210 | ) 211 | text_x = x_offset + 5 212 | text_y = y_offset + fixed_size[1] - 20 213 | 214 | avg_color = np.average(images[index]["image"], axis=(0, 1)) 215 | text_color = ( 216 | (255, 255, 255) if np.mean(avg_color) < 128 else (0, 0, 0) 217 | ) 218 | 219 | # Create a region of interest (ROI) for the text background 220 | roi = canvas[ 221 | text_y : text_y + text_size[1] + 5, 222 | text_x : text_x + text_size[0] + 5, 223 | ] 224 | 225 | # Apply Gaussian blur to the ROI 226 | roi = cv2.GaussianBlur(roi, (15, 15), 0) 227 | 228 | # Place the blurred ROI back onto the canvas 229 | canvas[ 230 | text_y : text_y + text_size[1] + 5, 231 | text_x : text_x + text_size[0] + 5, 232 | ] = roi 233 | 234 | # Add the image name as text annotation in the bottom right corner 235 | cv2.putText( 236 | canvas, 237 | name, 238 | (text_x, text_y + text_size[1]), 239 | font, 240 | font_scale, 241 | text_color, 242 | font_thickness, 243 | ) 244 | 245 | # Display the canvas 246 | import matplotlib.pyplot as plt 247 | 248 | plt.imshow(canvas) 249 | plt.show() 250 | 251 | 252 | class FaceResult(dict): 253 | def __init__(self, id, name=None, distance=None, embedding=None, img=None, **kw): 254 | kw["id"] = id 255 | kw["name"] = name 256 | kw["distance"] = distance 257 | kw["embedding"] = embedding 258 | kw["img"] = img 259 | 260 | self.id = id 261 | self.name = name 262 | self.distance = distance 263 | self.embedding = embedding 264 
| self.img = img 265 | 266 | self.kw = kw 267 | 268 | for i in kw: 269 | setattr(self, i, kw[i]) 270 | 271 | super().__init__(**kw) 272 | 273 | def __repr__(self): 274 | txt = f"FaceResult(id={self.id}, name={self.name}" 275 | if self.get("confidence"): 276 | txt += f", confidence={self['confidence']}%" 277 | txt += ")" 278 | return txt 279 | 280 | def __str__(self): 281 | result = {} 282 | for key in self: 283 | val = self[key] 284 | if val is not None: 285 | if key == "embedding": 286 | val = f"Embedding({len(val)} dim)" 287 | elif key == "img": 288 | val = "Image(cv2 image)" 289 | result[key] = val 290 | 291 | return pprint.pformat(result) 292 | 293 | def show_img(self): 294 | if self.get("img") is None: 295 | print("No image available") 296 | return 297 | else: 298 | import matplotlib.pyplot as plt 299 | 300 | if self.img is None: 301 | print("No image available. Include `img` in `include` to get the image") 302 | return 303 | 304 | plt.imshow(self.img) # type: ignore 305 | plt.title(self.name) # type: ignore 306 | plt.show() 307 | 308 | 309 | class BaseDB: 310 | def __init__(self, path): 311 | self.path = Path(path) 312 | self.path.mkdir(exist_ok=True) 313 | 314 | def __repr__(self): 315 | return f"{self.__class__.__name__}({self.path})" 316 | 317 | def __str__(self): 318 | return self.__repr__() 319 | 320 | def add(self, ids, embeddings, metadatas=None): 321 | if isinstance(ids, str): 322 | ids = [ids] 323 | if not many_vectors(embeddings): 324 | embeddings = [embeddings] 325 | 326 | assert len(ids) == len( 327 | embeddings 328 | ), "ids and embeddings must have the same length" 329 | if metadatas: 330 | assert len(ids) == len( 331 | metadatas 332 | ), "ids and metadatas must have the same length" 333 | 334 | if not isinstance(metadatas, list): 335 | metadatas = [metadatas] 336 | 337 | assert isinstance(metadatas[0], dict), "metadatas must be a list of dict" 338 | 339 | return self._add(ids, embeddings, metadatas) 340 | 341 | def _add(self, ids, embeddings, 
metadatas=None): 342 | raise NotImplementedError 343 | 344 | def delete(self, ids): 345 | if isinstance(ids, str): 346 | ids = [ids] 347 | return self._delete(ids) 348 | 349 | def _delete(self, ids): 350 | raise NotImplementedError 351 | 352 | def update(self, ids, embeddings=None, metadatas=None): 353 | if isinstance(ids, str): 354 | ids = [ids] 355 | if embeddings: 356 | if not many_vectors(embeddings): 357 | embeddings = [embeddings] 358 | 359 | assert len(ids) == len( 360 | embeddings 361 | ), "ids and embeddings must have the same length" 362 | if metadatas: 363 | if not isinstance(metadatas, list): 364 | metadatas = [metadatas] 365 | 366 | assert len(ids) == len( 367 | metadatas 368 | ), "ids and metadatas must have the same length" 369 | assert isinstance(metadatas[0], dict), "metadatas must be a list of dict" 370 | 371 | return self._update(ids, embeddings, metadatas) 372 | 373 | def _update(self, ids, embeddings=None, metadatas=None): 374 | raise NotImplementedError 375 | 376 | def query(self, embeddings, top_k=1, include=None, where=None): 377 | raise NotImplementedError 378 | 379 | def _get(self, ids, include=None, where=None): 380 | raise NotImplementedError 381 | 382 | def get(self, ids, include=None, where=None): 383 | if isinstance(ids, str): 384 | ids = [ids] 385 | return self._get(ids, include, where) 386 | 387 | def parser(self, result, imgdb, include=None, query=True) -> List[FaceResults]: 388 | raise NotImplementedError 389 | 390 | def all(self, include=None): 391 | raise NotImplementedError 392 | 393 | def count(self) -> int: 394 | raise NotImplementedError 395 | 396 | def __len__(self) -> int: 397 | return self.count() 398 | 399 | 400 | class PineconeDB(BaseDB): 401 | 402 | def __init__( 403 | self, 404 | dimension: int, 405 | spec=None, 406 | pinecone_client=None, 407 | index_name=None, 408 | metric="cosine", 409 | api_key=None, 410 | **kw, 411 | ): 412 | assert metric in [ 413 | "cosine", 414 | "euclidean", 415 | "dotproduct", 416 | ], 
"metric must be cosine, euclidean, or dotproduct" 417 | 418 | assert index_name is not None, "index_name must be provided" 419 | assert dimension is not None, "dimension must be provided" 420 | 421 | self.index_name = index_name 422 | self.pc: Pinecone = None # type: ignore 423 | self.dimension: int = dimension 424 | self.spec = spec 425 | 426 | if not spec: 427 | self.spec = ServerlessSpec( 428 | cloud="aws", 429 | region="us-east-1", 430 | ) 431 | 432 | if not pinecone_client: 433 | api_key = api_key or os.environ.get("PINECONE_API_KEY", None) 434 | # environment = kw.get( 435 | # "environment", 436 | # ) or os.environ.get("PINECONE_ENVIRONMENT", None) # Deprecated 437 | 438 | if "environment" in kw: 439 | warnings.warn( 440 | "`environment` is deprecated. Only use `api_key` instead.", 441 | ) 442 | del kw["environment"] 443 | 444 | if api_key is None: 445 | raise ConnectionError( 446 | "Pinecone api_key is not provided. Please provide an api_key or set it as an environment variable." 447 | ) 448 | try: 449 | self.pc = Pinecone( 450 | api_key=api_key, 451 | **kw, 452 | ) 453 | except Exception as e: 454 | raise Exception( 455 | f"Failed to initialize pinecone. Please check your api_key and environment. Error: {e}" 456 | ) 457 | else: 458 | assert isinstance( 459 | pinecone_client, Pinecone # type: ignore 460 | ), f"pinecone_client must be an instance of Pinecone. Got {type(pinecone_client)}" 461 | self.pc = pinecone_client 462 | 463 | idx_infos_raw = self.pc.list_indexes() 464 | self.idx_infos = {i["name"]: i for i in idx_infos_raw} 465 | 466 | self.index = self.get_index(index_name, dimension, metric) 467 | self.index_info: dict = self.idx_infos[self.index_name] 468 | 469 | assert ( 470 | self.index_info["dimension"] == self.dimension # type: ignore 471 | ), f"dimension must be the same as the index. 
`{self.index_name}` has dimension of `{self.index_info.dimension}` but got `{self.dimension}`" 472 | 473 | assert ( 474 | self.index_info.metric == metric # type: ignore 475 | ), f"metric must be the same as the index. `{self.index_name}` has metric of `{self.index_info.metric}` but got `{metric}`" 476 | 477 | def count(self): 478 | return self.index.describe_index_stats().get("total_vector_count", -1) 479 | 480 | def get_index(self, index_name, dimension, metric): 481 | if index_name in self.idx_infos: 482 | return self.pc.Index(index_name) 483 | else: 484 | self.pc.create_index( 485 | name=index_name, dimension=dimension, metric=metric, spec=self.spec 486 | ) 487 | idx_infos_raw = self.pc.list_indexes() 488 | self.idx_infos = {i["name"]: i for i in idx_infos_raw} 489 | return self.pc.Index(index_name) 490 | 491 | def _add(self, ids, embeddings, metadatas=None): 492 | vectors = [] 493 | for i in range(len(ids)): 494 | data = {"id": ids[i], "values": embeddings[i]} 495 | if metadatas: 496 | data["metadata"] = metadatas[i] 497 | vectors.append(data) 498 | 499 | return self.index.upsert(vectors=vectors) 500 | 501 | def _delete(self, ids): 502 | return self.index.delete(ids) 503 | 504 | def delete_all(self): 505 | return self.index.delete(delete_all=True) 506 | 507 | def _update(self, ids, embeddings=None, metadatas=None): 508 | res = [] 509 | for i in range(len(ids)): 510 | data = { 511 | "id": ids[i], 512 | "values": None, 513 | "set_metadata": None, 514 | } 515 | 516 | if embeddings: 517 | data["values"] = embeddings[i] 518 | if metadatas: 519 | data["set_metadata"] = metadatas[i] 520 | 521 | res.append(self.index.update(**data)) 522 | 523 | return res 524 | 525 | def query( 526 | self, 527 | embeddings, 528 | top_k=1, 529 | include: Optional[List[Literal["embeddings", "metadatas"]]] = None, 530 | where=None, 531 | ): 532 | params = { 533 | "top_k": top_k, 534 | } 535 | 536 | if include is not None: 537 | if isinstance(include, str): 538 | include = [include] # 
type: ignore 539 | 540 | for i in include: # type: ignore 541 | if i == "embeddings": 542 | params["include_values"] = True 543 | elif i == "metadatas": 544 | params["include_metadata"] = True 545 | 546 | if where: 547 | params["filter"] = where 548 | 549 | if many_vectors(embeddings): 550 | res = [] 551 | for i in range(len(embeddings)): 552 | res.append(self.index.query(vector=embeddings[i], **params).to_dict()) # type: ignore 553 | return res 554 | else: 555 | return self.index.query(vector=embeddings, **params).to_dict() # type: ignore 556 | 557 | def parser( 558 | self, result, imgdb, include=None, query=True, threshold=None 559 | ) -> Union[FaceResults, List[FaceResults]]: 560 | if isinstance(result, dict): 561 | result = [result] 562 | 563 | if include is None: 564 | include = [] 565 | elif isinstance(include, str): 566 | include = [include] 567 | 568 | results: List[FaceResults] = [] 569 | 570 | for i in range(len(result)): 571 | rs = [] 572 | for j, r in enumerate(result[i]["matches"]): 573 | data = { 574 | "id": r["id"], 575 | } 576 | if "score" in r: 577 | data["distance"] = 1 - r["score"] 578 | if threshold: 579 | _, direction, value = threshold 580 | data["confidence"] = calculate_confidence( 581 | data["distance"], value, direction 582 | ) 583 | 584 | for k in include: 585 | if k[:9] == "embedding": 586 | data["embedding"] = r["values"] 587 | elif k[:3] == "img": 588 | data["img"] = imgdb.get(r["id"]) 589 | elif k[:8] == "distance": 590 | continue 591 | 592 | # add all metadata keys 593 | if "metadata" in r and r["metadata"]: 594 | for key in r["metadata"]: 595 | if key == "values" or key == "id" or key == "score": 596 | continue 597 | data[key] = r["metadata"][key] 598 | 599 | rs.append(FaceResult(**data)) 600 | if rs: 601 | results.append(FaceResults(rs)) 602 | else: 603 | results.append(None) # type: ignore 604 | 605 | if not query: 606 | return results[0] 607 | 608 | return results 609 | 610 | def get(self, ids, include=None, where=None): 611 | if 
isinstance(ids, str): 612 | ids = [ids] 613 | 614 | return self.query( 615 | embeddings=[0] * self.dimension, 616 | top_k=len(ids), 617 | where={"id": {"$in": ids}}, 618 | include=include, 619 | ) 620 | 621 | def all(self, include=None): 622 | get_metadata = None 623 | get_values = None 624 | if include: 625 | if "metadatas" in include: 626 | get_metadata = True 627 | if "embeddings" in include: 628 | get_values = True 629 | 630 | stats = self.index.describe_index_stats() 631 | namespace_map = stats["namespaces"] 632 | ret = [] 633 | for namespace in namespace_map: 634 | vector_count = namespace_map[namespace]["vector_count"] 635 | res = self.index.query( 636 | vector=[0 for _ in range(self.dimension)], 637 | top_k=vector_count, 638 | namespace=namespace, 639 | include_values=get_values, 640 | include_metadata=get_metadata, 641 | ) 642 | ret.extend(res["matches"]) 643 | 644 | return {"matches": ret} 645 | 646 | 647 | class ChromaDB(BaseDB): 648 | def __init__( 649 | self, path=None, client=None, metric="cosine", collection_name="faces" 650 | ): 651 | assert metric in [ 652 | "cosine", 653 | "l2", 654 | "ip", 655 | ], "chromadb only support cosine, l2, and ip metric" 656 | 657 | if path is None: 658 | path = "data" 659 | if client is None: 660 | self.client = chromadb.PersistentClient(path) 661 | else: 662 | self.client = client 663 | 664 | self.path = Path(path) 665 | 666 | self.db = self.client.get_or_create_collection( 667 | name=collection_name, 668 | metadata={ 669 | "hnsw:space": metric, 670 | }, 671 | ) 672 | 673 | self.metric = metric 674 | self.collection_name = collection_name 675 | 676 | def add(self, ids, embeddings, metadatas=None): 677 | return self.db.add(ids=ids, embeddings=embeddings, metadatas=metadatas) 678 | 679 | def delete(self, ids): 680 | return self.db.delete(ids) 681 | 682 | def update(self, ids, embeddings=None, metadatas=None): 683 | return self.db.update(ids=ids, embeddings=embeddings, metadatas=metadatas) 684 | 685 | def query( 686 | 
self, 687 | embeddings, 688 | top_k=1, 689 | include: Optional[List[Literal["embeddings", "metadatas", "distances"]]] = [ 690 | "distances" 691 | ], 692 | where=None, 693 | ): 694 | return self.db.query( 695 | query_embeddings=embeddings, n_results=top_k, include=include or ["distances"], where=where # type: ignore 696 | ) 697 | 698 | def get(self, ids, include=None, where=None): 699 | return self.db.get(ids=ids, include=include or ["metadatas"], where=where) 700 | 701 | def count(self) -> int: 702 | return self.db.count() 703 | 704 | def all(self, include=None): 705 | return self.db.get(include=include or ["metadatas"]) 706 | 707 | def query_parser( 708 | self, result, imgdb, include=["distances"], threshold=None 709 | ) -> List[FaceResults]: 710 | results: List[FaceResults] = [] 711 | for i in range(len(result["ids"])): 712 | rs = [] 713 | for j, id in enumerate(result["ids"][i]): 714 | data = {"id": id} 715 | if result["distances"]: 716 | data["distance"] = result["distances"][i][j] 717 | if threshold: 718 | _, direction, value = threshold 719 | data["confidence"] = calculate_confidence( 720 | data["distance"], value, direction 721 | ) 722 | 723 | for k in include: 724 | if k[:9] == "embedding": 725 | if result["embeddings"]: 726 | data["embedding"] = result["embeddings"][i][j] 727 | elif k[:3] == "img": 728 | data["img"] = imgdb.get(id) 729 | 730 | # add all metadata keys 731 | if ( 732 | result["metadatas"] 733 | and result["metadatas"][i] 734 | and result["metadatas"][i][j] 735 | ): 736 | for key in result["metadatas"][i][j]: 737 | data[key] = result["metadatas"][i][j][key] 738 | rs.append(FaceResult(**data)) 739 | if rs: 740 | results.append(FaceResults(rs)) 741 | else: 742 | results.append(None) # type: ignore 743 | return results 744 | 745 | def get_parser(self, result, imgdb, include=["metadatas"]) -> FaceResults: 746 | if include is None: 747 | include = [] 748 | elif isinstance(include, str): 749 | include = [include] 750 | 751 | results: 
List[FaceResult] = [] 752 | for i, id in enumerate(result["ids"]): 753 | data = {"id": id} 754 | for k in include: 755 | if k[:9] == "embedding": 756 | data["embedding"] = result["embeddings"][i] 757 | elif k[:3] == "img": 758 | data["img"] = imgdb.get(id) 759 | 760 | # add all metadata keys 761 | if result["metadatas"] and result["metadatas"][i]: 762 | for key in result["metadatas"][i]: 763 | data[key] = result["metadatas"][i][key] 764 | 765 | results.append(FaceResult(**data)) 766 | 767 | return FaceResults(results) 768 | 769 | def parser( 770 | self, result, imgdb, include=None, query=True, threshold=None 771 | ) -> Union[FaceResults, List[FaceResults]]: 772 | if query: 773 | return self.query_parser(result, imgdb, include, threshold) 774 | else: 775 | return self.get_parser(result, imgdb, include) 776 | 777 | def delete_all(self): 778 | self.client.delete_collection(self.collection_name) 779 | -------------------------------------------------------------------------------- /facedb/db.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from tqdm.auto import tqdm 4 | import cv2 5 | import warnings 6 | import threading 7 | import os 8 | 9 | DeepFace = None 10 | deepface_distance = None 11 | face_recognition = None 12 | 13 | from facedb.db_tools import ( 14 | get_embeddings, 15 | get_include, 16 | ImgDB, 17 | Rect, 18 | img_to_cv2, 19 | is_list_of_img, 20 | is_2d, 21 | time_now, 22 | get_model_dimension, 23 | l2_normalize, 24 | is_none_or_empty, 25 | Union, 26 | Literal, 27 | Optional, 28 | List, 29 | Tuple, 30 | fthresholds, 31 | FailedImageIndexList, 32 | ) 33 | 34 | from facedb.db_models import FaceResults, PineconeDB, ChromaDB 35 | 36 | from pathlib import Path 37 | 38 | import_lock = threading.Lock() 39 | 40 | 41 | def load_module(module: Literal["deepface", "face_recognition"]): 42 | with import_lock: 43 | if module == "deepface": 44 | global DeepFace 45 | global deepface_distance 46 | 
if DeepFace is None: 47 | try: 48 | from deepface import DeepFace 49 | from deepface.modules import verification as deepface_distance 50 | except ImportError: 51 | raise ImportError( 52 | "Please install `deepface` to use `deepface` module." 53 | ) 54 | elif module == "face_recognition": 55 | global face_recognition 56 | if face_recognition is None: 57 | try: 58 | import face_recognition 59 | except ImportError: 60 | raise ImportError( 61 | "Please install `face_recognition` to use `face_recognition` module." 62 | ) 63 | else: 64 | raise ValueError( 65 | "Currently only `deepface` and `face_recognition` are supported." 66 | ) 67 | 68 | 69 | def create_deepface_embedding_func( 70 | model_name, 71 | detector_backend, 72 | enforce_detection, 73 | align, 74 | normalization, 75 | l2_normalization, 76 | ): 77 | def embedding_func(img, enforce_detection=enforce_detection, **kw): 78 | try: 79 | result = DeepFace.represent( # type: ignore 80 | img, 81 | model_name=model_name, 82 | detector_backend=detector_backend, 83 | enforce_detection=enforce_detection, 84 | align=align, 85 | normalization=normalization, 86 | ) 87 | except ValueError: 88 | return None 89 | 90 | result = [i["embedding"] for i in result] 91 | if l2_normalization: 92 | result = l2_normalize(result) 93 | 94 | if is_none_or_empty(result): 95 | return None 96 | 97 | return result 98 | 99 | return embedding_func 100 | 101 | 102 | def create_face_recognition_embedding_func( 103 | model, 104 | num_jitters, 105 | l2_normalization, 106 | ): 107 | def embedding_func(img, know_face_locations=None, **kw): 108 | img = img_to_cv2(img) 109 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 110 | 111 | result = face_recognition.face_encodings( # type: ignore 112 | img, 113 | num_jitters=num_jitters, 114 | model=model, 115 | known_face_locations=know_face_locations, 116 | ) 117 | if is_none_or_empty(result): 118 | return None 119 | 120 | if l2_normalization: 121 | result = l2_normalize(result) 122 | return result 123 | 124 | 
return embedding_func 125 | 126 | 127 | def create_deepface_extract_faces_func( 128 | extract_faces_detector_backend, 129 | enforce_detection, 130 | align, 131 | ): 132 | def extract_faces(img): 133 | try: 134 | result = DeepFace.extract_faces( # type: ignore 135 | img, 136 | detector_backend=extract_faces_detector_backend, 137 | enforce_detection=enforce_detection, 138 | align=align, 139 | ) 140 | except ValueError: 141 | return [] 142 | 143 | if result is None: 144 | return [] 145 | 146 | return [Rect(**i["facial_area"]) for i in result] 147 | 148 | return extract_faces 149 | 150 | 151 | def create_face_recognition_extract_faces_func(extract_face_model="hog"): 152 | def extract_faces(img): 153 | img = img_to_cv2(img) 154 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 155 | 156 | result = face_recognition.face_locations(img, model=extract_face_model) # type: ignore 157 | if result is None: 158 | return [] 159 | 160 | rects = [] 161 | for i in result: 162 | y, xw, yh, x = i 163 | w = xw - x # type: ignore 164 | h = yh - y # type: ignore 165 | rects.append(Rect(x, y, w, h)) 166 | return rects 167 | 168 | return extract_faces 169 | 170 | 171 | metric_map = { 172 | "pinecone": { 173 | "cosine": "cosine", 174 | "euclidean": "euclidean", 175 | "dot": "dotproduct", 176 | }, 177 | "chromadb": { 178 | "cosine": "cosine", 179 | "euclidean": "l2", 180 | "dot": "ip", 181 | }, 182 | } 183 | 184 | 185 | class FaceDB: 186 | 187 | def __init__( 188 | self, 189 | *, 190 | path: str = "facedata", 191 | metric: Literal["cosine", "euclidean", "dot"] = "euclidean", 192 | embedding_func=None, 193 | embedding_dim: Optional[int] = None, 194 | l2_normalization: bool = True, 195 | module: Literal["deepface", "face_recognition"] = "face_recognition", 196 | database_backend: Literal["chromadb", "pinecone"] = "chromadb", 197 | index_name: Optional[str] = None, 198 | pinecone_settings: dict = {}, 199 | face_recognition_settings: dict = {}, 200 | deepface_settings: dict = {}, 201 | **kw, 202 | ): 
203 | """ 204 | Initialize the FaceDB instance for face recognition. 205 | 206 | Args: 207 | path (str, optional): The path to store data. Defaults to "facedata". 208 | metric (str, optional): The distance metric to use for similarity. Defaults to "euclidean". 209 | embedding_func (callable, optional): Custom embedding function. Defaults to None. 210 | embedding_dim (int, optional): The dimension of face embeddings. Defaults to selected automatically. 211 | l2_normalization (bool, optional): Whether to perform L2 normalization on embeddings(increases accuracy). Defaults to True. 212 | module (str, optional): The face recognition module to use. Defaults to "face_recognition" (DeepFace not optimized). 213 | database_backend (str, optional): The database backend to use(ChromaDB or Pinecone). Defaults to "chromadb". 214 | 215 | pinecone_settings (dict, optional): Additional settings to pass to the Pinecone client. 216 | client (PineconeClient, optional): The Pinecone client to use. Defaults to HTTPClient. 217 | index_name (str, optional): The name of the Pinecone index to use. 218 | api_key (str, optional): Can be passed as `pinecone_api_key` or environment variable `PINECONE_API_KEY`. 219 | spec (str, optional): The Pinecone spec to use. Defaults to ServerlessSpec(cloud="aws",region="us-east-1"). 220 | Rest of the keyword arguments are passed to the pinecone client directly. 221 | 222 | face_recognition_settings (dict, optional): Additional settings to pass to the face recognition module. 223 | model (str, optional): Model size. Defaults to "small". 224 | num_jitters (int, optional): Number of jitter samples. Defaults to 1. 225 | extract_face_model (str, optional): Face detection model. Defaults to "hog". 226 | 227 | deepface_settings (dict, optional): Additional settings to pass to the DeepFace module. 228 | model_name (str, optional): Model name. Defaults to "Facenet512". 229 | detector_backend (str, optional): Face detection backend. Defaults to "ssd". 
230 | enforce_detection (bool, optional): Whether to enforce face detection. Defaults to True. 231 | normalization (bool, optional): Whether to normalize face embeddings. Defaults to True. 232 | extract_face_backend (str, optional): Face detection backend. Defaults to "ssd". 233 | enforce_detection (bool, optional): Whether to enforce face detection. Defaults to True. 234 | align (bool, optional): Whether to align faces. Defaults to True. 235 | 236 | 237 | 238 | Examples: 239 | >>> from facedb import FaceDB 240 | >>> facedb = FaceDB() 241 | >>> facedb.add("elon_musk", img="elon_musk.jpg") 242 | >>> facedb.add("jeff_bezos", img="jeff_bezos.jpg") 243 | >>> facedb.recognize(img="elon_musk_2.jpg") # returns FaceResults 244 | """ 245 | 246 | path = Path(path) 247 | 248 | assert metric in [ 249 | "cosine", 250 | "euclidean", 251 | "dot", 252 | ], "Supported metrics are `cosine`, `euclidean` and `dot`." 253 | assert module in [ 254 | "deepface", 255 | "face_recognition", 256 | ], "Supported modules are `deepface` and `face_recognition`." 257 | assert database_backend in [ 258 | "chromadb", 259 | "pinecone", 260 | ], "Supported database backends are `chromadb` and `pinecone`." 261 | 262 | if module == "deepface": 263 | warnings.warn( 264 | "Deepface module is not calibrated for vector database. Use `face_recognition` instead." 265 | ) 266 | 267 | os.environ["DB_BACKEND"] = database_backend 268 | 269 | if index_name: 270 | warnings.warn( 271 | """`index_name` is deprecated. Use `pinecone_settings` instead. 
272 | For example: 273 | ```pinecone_settings = {index_name = 'my_index'}``` 274 | """ 275 | ) 276 | pinecone_settings["index_name"] = index_name 277 | 278 | self.metric = metric 279 | self.embedding_func: Callable = embedding_func # type: ignore 280 | self.extract_faces: Callable = None # type: ignore 281 | self.module = module 282 | self.db_backend = database_backend 283 | self.l2_normalization = l2_normalization 284 | self.deepface_model_name = kw.get("model_name", "Facenet") 285 | 286 | if database_backend == "chromadb": 287 | self.db = ChromaDB( 288 | path=str(path), 289 | client=kw.pop("client", None), 290 | metric=metric_map[database_backend][metric], 291 | collection_name=kw.pop("collection_name", "facedb"), 292 | ) 293 | 294 | elif database_backend == "pinecone": 295 | self.db = PineconeDB( 296 | pinecone_client=pinecone_settings.pop("client", None), 297 | index_name=pinecone_settings.pop("index_name", None), 298 | metric=metric_map[database_backend][metric], 299 | dimension=embedding_dim 300 | or get_model_dimension(module, self.deepface_model_name), 301 | api_key=pinecone_settings.pop("pinecone_api_key", None), 302 | spec=pinecone_settings.pop("spec", None), 303 | **pinecone_settings, 304 | ) 305 | 306 | if embedding_func is None: 307 | load_module(module) 308 | if module == "deepface": 309 | self.embedding_func = create_deepface_embedding_func( 310 | model_name=deepface_settings.pop("model_name", "Facenet"), 311 | detector_backend=deepface_settings.pop("detector_backend", "ssd"), 312 | enforce_detection=deepface_settings.pop("enforce_detection", True), 313 | align=deepface_settings.pop("align", True), 314 | normalization=deepface_settings.pop("normalization", "base"), 315 | l2_normalization=l2_normalization, 316 | ) 317 | self.extract_faces = create_deepface_extract_faces_func( 318 | extract_faces_detector_backend=deepface_settings.pop( 319 | "extract_face_backend", 320 | "ssd", 321 | ), 322 | enforce_detection=deepface_settings.pop("enforce_detection", 
True), 323 | align=deepface_settings.pop("align", True), 324 | ) 325 | elif module == "face_recognition": 326 | self.embedding_func = create_face_recognition_embedding_func( 327 | model=face_recognition_settings.pop("model", "small"), 328 | num_jitters=face_recognition_settings.pop("num_jitters", 1), 329 | l2_normalization=l2_normalization, 330 | ) 331 | self.extract_faces = create_face_recognition_extract_faces_func( 332 | extract_face_model=face_recognition_settings.pop( 333 | "extract_face_model", "hog" 334 | ), 335 | ) 336 | else: 337 | raise ValueError( 338 | "Currently only `deepface` and `face_recognition` are supported." 339 | ) 340 | else: 341 | self.embedding_func = embedding_func 342 | 343 | if not path.exists(): 344 | path.mkdir(parents=True) 345 | 346 | self.imgdb = ImgDB(db_path=str(path / "img.db")) 347 | self.threshold = self.get_threshold() 348 | 349 | def __len__(self): 350 | return self.db.count() 351 | 352 | def count(self): 353 | """ 354 | Get the number of faces in the database (alias for __len__). 355 | 356 | Returns: 357 | int: The number of faces in the database. 358 | """ 359 | return self.db.count() 360 | 361 | def get_threshold(self) -> Tuple[str, str, float]: 362 | """ 363 | Get the similarity threshold for the database. 364 | 365 | Returns: 366 | Tuple[str, str, float]: The similarity threshold. 
367 | """ 368 | 369 | metric = self.metric 370 | if self.module == "deepface": 371 | if metric == "euclidean" and self.l2_normalization: 372 | metric = "euclidean_l2" 373 | 374 | threshold = deepface_distance.find_threshold( # type: ignore 375 | self.deepface_model_name, metric 376 | ) 377 | return "le", "negative", threshold 378 | 379 | elif self.module == "face_recognition": 380 | metric = metric_map[self.db_backend][metric] 381 | if self.l2_normalization: 382 | metric_threshold = fthresholds[self.db_backend][metric + "_l2"] 383 | else: 384 | metric_threshold = fthresholds[self.db_backend][metric] 385 | 386 | return ( 387 | metric_threshold["operator"], 388 | metric_threshold["direction"], 389 | metric_threshold["value"], 390 | ) 391 | 392 | else: 393 | raise ValueError( 394 | "Currently only `deepface` and `face_recognition` are supported." 395 | ) 396 | 397 | def _is_match(self, distance, threshold=None): 398 | if threshold is None or threshold == 80: 399 | op, _, threshold = self.threshold 400 | else: 401 | op, _, thrs = self.threshold 402 | threshold = max(10, threshold) 403 | thrs = thrs / (threshold / 100) 404 | threshold = thrs 405 | 406 | if op == "le": 407 | return distance <= threshold 408 | elif op == "ge": 409 | return distance >= threshold 410 | elif op == "eq": 411 | return distance == threshold 412 | elif op == "l": 413 | return distance < threshold 414 | elif op == "g": 415 | return distance > threshold 416 | elif op == "ne": 417 | return distance != threshold 418 | else: 419 | raise ValueError("Invalid operator.") 420 | 421 | def get_faces(self, img, *, zoom_out=0.25, only_rect=False) -> Union[None, list]: 422 | """ 423 | Extract faces from an image. 424 | 425 | Args: 426 | img: The input image. 427 | zoom_out (float, optional): Zoom factor for the extracted faces. Defaults to 0.25. 428 | only_rect (bool, optional): Whether to return only the face rectangles. Defaults to False. 
429 | 430 | Returns: 431 | Union[None, list]: A list of extracted faces or face rectangles. 432 | """ 433 | img = img_to_cv2(img) 434 | rects = self.extract_faces(img) 435 | img_h, img_w = img.shape[:2] 436 | for rect in rects: 437 | rect.x = max(0, rect.x - int(rect.w * zoom_out)) 438 | rect.y = max(0, rect.y - int(rect.h * zoom_out)) 439 | rect.w = min(img_w - rect.x, rect.w + int(rect.w * zoom_out)) 440 | rect.h = min(img_h - rect.y, rect.h + int(rect.h * zoom_out)) 441 | if rects: 442 | if only_rect: 443 | return rects 444 | 445 | return [ 446 | img[rect.y : rect.y + rect.h, rect.x : rect.x + rect.w] 447 | for rect in rects 448 | ] 449 | return None 450 | 451 | def check_similar(self, embeddings, threshold=None) -> list: 452 | """ 453 | Check for similar faces in the database. 454 | 455 | Args: 456 | embeddings: Face embeddings to compare. 457 | threshold (float, optional): The similarity threshold. Defaults to 80. 458 | 459 | Returns: 460 | list: List of id(if it is match) or false 461 | """ 462 | embeddings = get_embeddings( 463 | imgs=None, 464 | embeddings=embeddings, 465 | ) 466 | 467 | result = self.db.query( 468 | embeddings=embeddings, 469 | top_k=1, 470 | include=None, 471 | ) 472 | 473 | results = self.db.parser(result, imgdb=self.imgdb, include=["distances"]) 474 | rs = [] 475 | for result in results: 476 | if result is None: 477 | rs.append(False) 478 | elif self._is_match(result["distance"], threshold): 479 | rs.append(result["id"]) 480 | else: 481 | rs.append(False) 482 | return rs 483 | 484 | def recognize( 485 | self, *, img=None, embedding=None, include=None, threshold=None, top_k=1 486 | ): 487 | """ 488 | Recognize a face from an image or embedding. 489 | 490 | Args: 491 | img: The input image. 492 | embedding: Face embeddings for recognition. 493 | include (list, optional): List of information to include in the result. Defaults to None. 494 | threshold (float, optional): The similarity threshold. Defaults to None. 
495 | top_k (int, optional): Number of top results to return. Defaults to 1. 496 | 497 | Returns: 498 | False: If no face is found in the image. 499 | None: If no match is found. 500 | FaceResults: If a match is found. 501 | """ 502 | 503 | single = False 504 | size = 0 505 | if embedding is not None: 506 | if not is_2d(embedding): 507 | single = True 508 | else: 509 | size = len(embedding) 510 | elif img is not None: 511 | if not is_list_of_img(img): 512 | single = True 513 | else: 514 | size = len(img) 515 | 516 | embedding = get_embeddings( 517 | embeddings=embedding, 518 | imgs=img, 519 | embedding_func=self.embedding_func, 520 | ) 521 | 522 | if is_none_or_empty(embedding): 523 | if size == 0: 524 | return False 525 | return [False] * size 526 | 527 | rincludes, include = get_include(default="distances", include=include) 528 | result = self.db.query( 529 | embeddings=embedding, 530 | top_k=top_k, 531 | include=rincludes, 532 | ) 533 | 534 | results = self.db.parser(result, imgdb=self.imgdb, include=include, threshold=self.threshold) # type: ignore 535 | res = [] 536 | for result in results: 537 | if result is None: 538 | res.append(None) 539 | elif self._is_match(result["distance"], threshold): 540 | res.append(result) 541 | else: 542 | res.append(None) 543 | 544 | if single and res: 545 | return res[0] 546 | 547 | return res 548 | 549 | def add( 550 | self, 551 | name, 552 | img=None, 553 | embedding=None, 554 | id=None, 555 | check_similar=True, 556 | save_just_face=False, 557 | **metadata, 558 | ) -> str: 559 | """ 560 | Add a new face to the database. 561 | 562 | Args: 563 | name (str): The name of the person associated with the face. 564 | img: The input image. 565 | embedding: Face embeddings for the new face. 566 | id (str, optional): The unique ID for the face. Defaults to None. 567 | check_similar (bool, optional): Whether to check for similar faces. Defaults to True. 568 | save_just_face (bool, optional): Whether to save only the face region. 
Defaults to False. 569 | **metadata: Additional metadata for the face. 570 | 571 | Returns: 572 | str: The ID of the added face. 573 | 574 | Raises: 575 | ValueError: If no face is found in the image. 576 | """ 577 | embedding = get_embeddings( 578 | embeddings=embedding, 579 | imgs=img, 580 | embedding_func=self.embedding_func, 581 | ) 582 | 583 | if is_none_or_empty(embedding): 584 | raise ValueError("No face found in the img.") 585 | 586 | if check_similar: 587 | result = self.check_similar(embeddings=embedding)[0] 588 | if result: 589 | print( 590 | "Similar face already exists. If you want to add anyway, set `check_similar` to `False`." 591 | ) 592 | return result 593 | 594 | metadata["name"] = name 595 | idx = id or re.sub(r"[\W]", "-", name) + "-" + time_now() 596 | if img is not None: 597 | if save_just_face: 598 | img = self.get_faces(img) 599 | if img is None: 600 | raise ValueError("No face found in the img.") 601 | img = img[0] 602 | 603 | self.imgdb.add(img_id=idx, img=img) 604 | 605 | self.db.add( 606 | ids=[idx], 607 | embeddings=embedding, 608 | metadatas=[ 609 | { 610 | **metadata, 611 | } 612 | ], 613 | ) 614 | 615 | return idx 616 | 617 | def add_many( 618 | self, 619 | *, 620 | embeddings=None, 621 | imgs=None, 622 | metadata=None, 623 | ids=None, 624 | names=None, 625 | check_similar=True, 626 | ) -> Tuple[list, list]: 627 | """ 628 | Add multiple faces to the database. 629 | 630 | Args: 631 | embeddings: List of face embeddings to add. 632 | imgs: List of input images containing faces. 633 | metadata: List of metadata for the faces. 634 | ids (list, optional): List of unique IDs for the faces. Defaults to None. 635 | names (list, optional): List of names associated with the faces. Defaults to None. 636 | check_similar (bool, optional): Whether to check for similar faces. Defaults to True. 637 | 638 | Returns: 639 | tuple: A tuple containing lists of added IDs and failed IDs. 
640 | """ 641 | faces = [] 642 | failed = FailedImageIndexList() 643 | metadata_posible = metadata is not None 644 | if embeddings is None: 645 | if metadata_posible: 646 | warnings.warn( 647 | "Without embeddings, add metadata may not work as expected." 648 | ) 649 | if imgs is None: 650 | raise ValueError("Either `embeddings` or `imgs` must be provided.") 651 | 652 | if names is not None: 653 | if len(imgs) != len(names): 654 | raise ValueError("`imgs` length and `names` length must be same") 655 | cnames = [re.sub(r"[\W]", "-", name) for name in names] 656 | idxs = ids or [f"{name}-{time_now()}" for name in cnames] 657 | else: 658 | names = [f"face_{i}-{time_now()}" for i in range(len(imgs))] 659 | idxs = ids or [f"faceid_{i}-{time_now()}" for i in range(len(imgs))] 660 | 661 | for i, img in enumerate(tqdm(imgs, desc="Extracting faces")): 662 | try: 663 | rects: List[Rect] = self.get_faces(img, only_rect=True) # type: ignore 664 | if is_none_or_empty(rects): 665 | print(f"No face found in the img {i}. Skipping.") 666 | failed.append(i, failed_reason="No face found in the img") 667 | continue 668 | 669 | result = self.embedding_func( 670 | img, 671 | know_face_locations=[r.to(self.module) for r in rects], 672 | enforce_detection=False, 673 | ) 674 | 675 | if is_none_or_empty(result): 676 | print(f"No face found in the img {i}. 
Skipping.") 677 | failed.append(i, failed_reason="No face found in the img") 678 | continue 679 | 680 | try: 681 | idx = idxs[i] 682 | except IndexError: 683 | raise IndexError("`ids` length and `imgs` length must be same") 684 | 685 | if len(result) > 1: 686 | # metadata_posible = False 687 | for j, embedding in enumerate(result): 688 | name = f"{names[i]}_{j}" 689 | faces.append( 690 | { 691 | "id": f"{idx}_{j}", 692 | "name": name, 693 | "embedding": embedding, 694 | "img": img, 695 | "rect": rects[j], 696 | "index": i, 697 | } 698 | ) 699 | elif len(result) > 0: 700 | res = { 701 | "id": idx, 702 | "name": names[i], 703 | "embedding": result[0], 704 | "img": img, 705 | "rect": rects[0], 706 | "index": i, 707 | } 708 | if metadata_posible: 709 | try: 710 | res.update(metadata[i]) 711 | except IndexError: 712 | pass 713 | faces.append(res) 714 | else: 715 | print(f"No face found in the img {i}. Skipping.") 716 | failed.append(i, failed_reason="No face found in the img") 717 | continue 718 | except Exception as e: 719 | print(f"Error in img {i}. 
Skipping.", e) 720 | failed.append(i, failed_reason=str(e)) 721 | continue 722 | 723 | result = None 724 | else: 725 | idxs = ids or [f"faceid_{i}-{time_now()}" for i in range(len(embeddings))] 726 | for i, embedding in enumerate(embeddings): 727 | try: 728 | idx = idxs[i] 729 | except IndexError: 730 | raise IndexError( 731 | "`ids` length and `embeddings` length must be same" 732 | ) 733 | 734 | res = { 735 | "id": idx, 736 | "embedding": embedding, 737 | "index": i, 738 | } 739 | 740 | if names is not None: 741 | try: 742 | res["name"] = names[i] 743 | except IndexError: 744 | raise IndexError( 745 | "`names` length and `embeddings` length must be same" 746 | ) 747 | else: 748 | res["name"] = f"face_{i}-{time_now()}" 749 | 750 | if imgs is not None: 751 | try: 752 | rects = self.extract_faces(imgs[i]) 753 | except IndexError: 754 | raise IndexError( 755 | "`imgs` length and `embeddings` length must be same" 756 | ) 757 | 758 | if len(rects) > 1: 759 | print("Multiple faces found in the img. Taking first face.") 760 | 761 | res["img"] = imgs[i] 762 | res["rect"] = rects[0] 763 | 764 | if metadata_posible: 765 | try: 766 | res.update(metadata[i]) 767 | except IndexError: 768 | raise IndexError( 769 | "`metadatas` length and `embeddings` length must be same" 770 | ) 771 | 772 | faces.append(res) 773 | 774 | if not faces: 775 | return [], failed 776 | 777 | embedding = None 778 | embeddings = None 779 | 780 | if check_similar: 781 | res = self.check_similar(embeddings=[i["embedding"] for i in faces]) 782 | for i, r in enumerate(res): 783 | if r: 784 | print( 785 | f"Similar face {r} already exists. If you want to add anyway, set `check_similar` to `False`." 
786 | ) 787 | failed.append( 788 | faces[i]["index"], 789 | failed_reason=f"Similar face {r} already exists.", 790 | ) 791 | faces[i] = None 792 | 793 | res = None 794 | 795 | # remove None faces 796 | faces = [i for i in faces if i is not None] 797 | if not faces: 798 | return [], failed 799 | 800 | metadata = [] 801 | added_img = False 802 | ids = [i["id"] for i in faces] 803 | for i in faces: 804 | data = {"id": i["id"]} 805 | if "img" in i: 806 | added_img = True 807 | img = img_to_cv2(i["img"]) 808 | rect = i["rect"] 809 | img = img[rect.y : rect.y + rect.h, rect.x : rect.x + rect.w] 810 | self.imgdb._add( 811 | img_id=i["id"], 812 | img=img, 813 | ) 814 | 815 | for j in i: 816 | if j == "rect": 817 | data["rect"] = i[j].to_json() 818 | 819 | elif j not in ["id", "embedding", "img"]: 820 | data[j] = i[j] 821 | 822 | metadata.append(data) 823 | 824 | embeddings = [f["embedding"] for f in faces] 825 | if isinstance(embeddings[0], np.ndarray): 826 | embeddings = [e.tolist() for e in embeddings] 827 | 828 | faces = None 829 | 830 | self.db.add( 831 | ids=ids, 832 | embeddings=embeddings, 833 | metadatas=metadata, 834 | ) 835 | 836 | if added_img: 837 | self.imgdb.conn.commit() 838 | 839 | print(f"Added {len(ids)} faces.") 840 | print(f"Failed to add {len(failed)} faces.") 841 | 842 | return ids, failed 843 | 844 | def search( 845 | self, *, embedding=None, img=None, include=None, top_k=1 846 | ) -> List[FaceResults]: 847 | """ 848 | Search for similar faces in the database. 849 | 850 | Args: 851 | embedding: Face embeddings for searching. 852 | img: The input image for searching. 853 | include (list, optional): List of information to include in the result. Defaults to None. 854 | top_k (int, optional): Number of top results to return. Defaults to 1. 855 | 856 | Returns: 857 | list: List of search results. 
858 | """ 859 | embedding = get_embeddings( 860 | embeddings=embedding, 861 | imgs=img, 862 | embedding_func=self.embedding_func, 863 | ) 864 | 865 | if is_none_or_empty(embedding): 866 | return [] 867 | 868 | sincludes, include = get_include(default="distances", include=include) 869 | 870 | result = self.db.query( 871 | embeddings=embedding, 872 | top_k=top_k, 873 | include=sincludes, 874 | ) 875 | 876 | return self.db.parser(result, imgdb=self.imgdb, include=include, threshold=self.threshold) # type: ignore 877 | 878 | def query( 879 | self, 880 | *, 881 | embedding=None, 882 | img=None, 883 | name=None, 884 | include=None, 885 | top_k=1, 886 | **search_params, 887 | ) -> Union[List[FaceResults], FaceResults]: 888 | """ 889 | Query the database for faces based on specified parameters. 890 | 891 | Args: 892 | embedding: Face embeddings for querying. 893 | img: The input image for querying. 894 | name (str, optional): The name associated with the face. Defaults to None. 895 | include (list, optional): List of information to include in the result. Defaults to None. 896 | top_k (int, optional): Number of top results to return. Defaults to 1. 897 | **search_params: Additional search parameters. 898 | 899 | Returns: 900 | Union[List[FaceResults], FaceResults]: Query results. 
901 | """ 902 | params = { 903 | "embeddings": None, 904 | "top_k": top_k, 905 | "where": {}, 906 | "include": None, 907 | } 908 | 909 | params["include"], include = get_include(default=None, include=include) 910 | 911 | params["embeddings"] = get_embeddings( 912 | embeddings=embedding, 913 | imgs=img, 914 | embedding_func=self.embedding_func, 915 | raise_error=False, 916 | ) 917 | 918 | if not params["embeddings"]: 919 | params["embeddings"] = None 920 | 921 | if name is not None: 922 | params["where"]["name"] = name 923 | 924 | if search_params.get("where", None) is None: 925 | for key, value in search_params.items(): 926 | params["where"][key] = value 927 | else: 928 | params["where"] = search_params["where"] 929 | 930 | if not params["where"]: 931 | params["where"] = None 932 | 933 | if params["embeddings"]: 934 | result = self.db.query( 935 | **params, 936 | ) 937 | return self.db.parser(result, imgdb=self.imgdb, include=include, threshold=self.threshold) # type: ignore 938 | 939 | elif params["where"] is not None: 940 | return self.all(include=["name"]).query(**params["where"]) 941 | else: 942 | raise ValueError("Either embedding, img or name must be provided") 943 | 944 | def get(self, id, include=None): 945 | """ 946 | Retrieve information about a specific face from the database. 947 | 948 | Args: 949 | id (str): The ID of the face to retrieve. 950 | include (list, optional): List of information to include in the result. Defaults to None. 951 | 952 | Returns: 953 | FaceResults: Information about the retrieved face. 954 | """ 955 | sincludes, include = get_include(default="metadatas", include=include) 956 | result = self.db.get(ids=id, include=sincludes) 957 | return self.db.parser(result, imgdb=self.imgdb, include=include, query=False) 958 | 959 | def update( 960 | self, id, name=None, embedding=None, img=None, only_face=False, **metadata 961 | ): 962 | """ 963 | Update information for a specific face in the database. 
964 | 965 | Args: 966 | id (str): The ID of the face to update. 967 | name (str, optional): The new name associated with the face. Defaults to None. 968 | embedding: New face embeddings for the face. 969 | img: New input image for the face. 970 | only_face (bool, optional): Whether to update only the face region. Defaults to False. 971 | **metadata: Additional metadata to update. 972 | 973 | Returns: 974 | None 975 | 976 | Raises: 977 | ValueError: If id is not found. 978 | """ 979 | result = self.get(id=id, include=["name"]) 980 | if not result: 981 | raise ValueError(f"Face with id {id} not found.") 982 | 983 | data = {} 984 | data_update = False 985 | if name is not None: 986 | data["name"] = name 987 | data_update = True 988 | else: 989 | data["name"] = result["name"] # type: ignore 990 | 991 | if metadata: 992 | data_update = True 993 | for key, value in metadata.items(): 994 | data[key] = value 995 | 996 | if not data_update: 997 | data = None 998 | 999 | if img is not None: 1000 | if only_face: 1001 | img = self.get_faces(img) 1002 | if img is None: 1003 | raise ValueError("No face found in the img.") 1004 | img = img[0] 1005 | 1006 | self.imgdb.auto(img_id=id, img=img) 1007 | 1008 | self.db.update( 1009 | ids=id, 1010 | embeddings=embedding, 1011 | metadatas=data, 1012 | ) 1013 | 1014 | def delete(self, id): 1015 | """ 1016 | Delete a face from the database. 1017 | 1018 | Args: 1019 | id (str) or ids (list) of face id. 1020 | 1021 | Returns: 1022 | None 1023 | """ 1024 | 1025 | self.imgdb.delete(id) 1026 | self.db.delete(ids=id) 1027 | 1028 | def delete_all(self): 1029 | """ 1030 | Delete all faces from the database. This action is irreversible. Use with caution. 1031 | """ 1032 | self.db.delete_all() 1033 | self.imgdb.delete_all() 1034 | 1035 | def all(self, include=None) -> FaceResults: 1036 | """ 1037 | Retrieve information about all faces in the database. 1038 | 1039 | Args: 1040 | include (list, optional): List of information to include in the result. 
Defaults to None. 1041 | 1042 | Returns: 1043 | FaceResults: Information about all faces in the database. 1044 | """ 1045 | dincludes, include = get_include(default=None, include=include) 1046 | result = self.db.all(include=dincludes) 1047 | return self.db.parser(result, imgdb=self.imgdb, include=include, query=False) # type: ignore 1048 | --------------------------------------------------------------------------------