├── .Dockerfile
├── .dockerignore
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── TINYVECTORLOGO.png
│   └── results.png
├── pyproject.toml
├── requirements.txt
├── server
│   ├── __main__.py
│   ├── cache
│   │   └── README.md
│   ├── logs
│   │   └── README.md
│   ├── models
│   │   └── model_response.py
│   └── utils
│       ├── generate_token.py
│       └── util_pydantic.py
├── setup.py
├── tests
│   ├── README.md
│   ├── __init__.py
│   ├── api
│   │   └── test_endpoint.py
│   ├── db
│   │   ├── __init__.py
│   │   └── test_create_db.py
│   └── models
│       ├── __init__.py
│       └── test_table_metadata.py
└── tinyvector
    ├── LICENSE
    ├── MANIFEST.in
    ├── __init__.py
    ├── database.py
    └── types
        └── model_db.py

--------------------------------------------------------------------------------
/.Dockerfile:
--------------------------------------------------------------------------------
 1 | # Use an official Python runtime as a parent image
 2 | # (3.9+ is required for the PEP 585 annotations in tinyvector/types/model_db.py)
 3 | FROM python:3.11-slim
 4 | 
 5 | # Set the working directory in the container to /app
 6 | WORKDIR /app
 7 | 
 8 | # Copy the current directory contents into the container at /app
 9 | COPY . /app
10 | 
11 | # Install any needed packages specified in requirements.txt
12 | RUN pip install --no-cache-dir -r requirements.txt
13 | 
14 | # Generate JWT secret
15 | RUN python -c "import secrets; print(secrets.token_urlsafe())" > jwt_secret.txt
16 | 
17 | # Make port 5000 available to the world outside this container
18 | EXPOSE 5000
19 | 
20 | # Run gunicorn when the container launches. ENV cannot evaluate shell commands,
21 | # so the generated secret is read and exported at start-up instead.
22 | CMD ["sh", "-c", "JWT_SECRET=$(cat jwt_secret.txt) exec gunicorn -w 4 -b 0.0.0.0:5000 server.__main__:app"]

--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
 1 | # Ignore all logs
 2 | *.log
 3 | 
 4 | # Ignore all .pyc files
 5 | *.pyc
 6 | 
 7 | # Ignore data directory
 8 | /data/*
 9 | 
10 | # Ignore Git and GitHub files
11 | .git/
12 | .github/
13 | 
14 | # Ignore environment files
15 | .env

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/
161 | 
162 | 
163 | *.parquet
164 | *.db
165 | 
166 | logs
167 | cache

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 Will DePue
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | <div align="center">
 2 | 
 3 | <img src="assets/TINYVECTORLOGO.png" alt="tinyvector logo" width="80%">
 4 | 
 5 | <b>tinyvector - the tiny, least-dumb, speedy vector embedding database.</b> <br>
 6 | No, you don't need a vector database. You need tinyvector.
 7 | 
 8 | <br>
 9 | 
10 | <i>In pre-release: prod-ready by late July. Still in development, not ready!</i>
11 | 
12 | </div>
13 | 
14 | 
15 | ## Features
16 | - __Tiny__: It's in the name. It's just a Flask server, a SQLite DB, and Numpy indexes. Extremely easy to customize, under 500 lines of code.
17 | - __Fast__: Tinyvector will be comparable in speed to advanced vector databases on small to medium datasets.
18 | - __Vertically Scales__: Tinyvector stores all indexes in memory for fast querying. Very easy to scale up to 100 million+ vector dimensions without issue.
19 | - __Open Source__: MIT Licensed, free forever.
20 | 
21 | ### Soon
22 | - __Powerful Queries__: Tinyvector is being upgraded with full SQL querying functionality, something missing from most other databases.
23 | - __Integrated Models__: Soon you won't have to bring your own vectors; just generate them on the server automatically. Will support SBert, Hugging Face models, OpenAI, Cohere, etc.
24 | - __Python/JS Client__: We'll add a comprehensive Python and JavaScript package for easy integration with tinyvector in the next two weeks.
25 | 
26 | ## Versions
27 | 
28 | 🦀 tinyvector in Rust: [tinyvector-rs](https://github.com/m1guelpf/tinyvector-rs)
29 | 🐍 tinyvector in Python: [tinyvector](https://github.com/0hq/tinyvector)
30 | 
31 | ## We're better than ...
32 | 
33 | In most cases, vector databases are overkill for something simple like:
34 | 1. Using embeddings to chat with your documents. Most document search is nowhere close to the scale you'd need to justify accelerating search speed with [HNSW](https://github.com/nmslib/hnswlib) or [FAISS](https://github.com/facebookresearch/faiss).
35 | 2. Doing search for your website or store. Unless you're selling 1,000,000 items, you don't need Pinecone.
36 | 3. Performing complex search queries on a very large database. Even if you have 2 million embeddings, this might still be the better option due to vector databases struggling with complex filtering. Tinyvector doesn't support metadata/filtering just yet, but it's very easy for you to add that yourself.
37 | 
38 | ## Usage
39 | 
40 | ```
41 | # Run the server manually:
42 | pip install -r requirements.txt
43 | python -m server
44 | 
45 | # Run tests:
46 | pip install pytest pytest-mock
47 | pytest
48 | ```
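
Once the server is up, a typical create → insert → query round trip looks roughly like the sketch below. It assumes the defaults in `server/__main__.py` — port 5234, and no `Authorization` header because `python -m server` runs Flask in debug mode, which skips token checks. The `requests` dependency, table name, and vectors are illustrative, not part of the repo:

```python
import requests  # illustrative dependency, not in requirements.txt

BASE = "http://localhost:5234"  # default port set in server/__main__.py

# Create a 3-dimensional table backed by a mutable brute-force index
requests.post(f"{BASE}/create_table", json={
    "table_name": "docs", "dimension": 3, "index_type": "brute_force",
    "normalize": True, "allow_index_updates": True, "use_uuid": False,
})

# Insert one vector with a caller-supplied ID (allowed because use_uuid=False)
requests.post(f"{BASE}/insert", json={
    "table_name": "docs", "id": "doc-1",
    "embedding": [0.1, 0.2, 0.3], "content": "hello tinyvector",
})

# Ask for the single nearest neighbour of a query vector
print(requests.post(f"{BASE}/query", json={
    "table_name": "docs", "query": [0.1, 0.2, 0.3], "k": 1,
}).json())
```

Outside debug mode, each request additionally needs a raw JWT in the `Authorization` header (see `server/utils/generate_token.py`).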
49 | 
50 | ## Embeddings?
51 | 
52 | What are embeddings?
53 | 
54 | > As simple as possible: Embeddings are a way to compare similar things, in the same way humans compare similar things, by converting text into a small list of numbers. Similar pieces of text will have similar numbers; very different pieces of text will have very different numbers.
55 | 
56 | Read OpenAI's [explanation](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
57 | 
58 | 
59 | ## Get involved
60 | 
61 | tinyvector is going to be growing a lot (don't worry, it will still be tiny). Feel free to make a PR and contribute. If you have questions, just mention [@willdepue](https://twitter.com/willdepue).
62 | 
63 | Some ideas for first pulls:
64 | 
65 | - Add metadata and allow querying/filtering. This is especially important since a lot of vector databases literally don't have a WHERE clause lol (or just an extremely weak one). Not a problem here. [Read more about this.](https://www.pinecone.io/learn/vector-search-filtering)
66 | - Rethink SQLite and choose something else. NoSQL feels like a better fit for embeddings?
67 | - Add embedding functions for easily adding text (sentence transformers, OpenAI, Cohere, etc.)
68 | - Let's start GPU accelerating with a Pytorch index. GPUs are great at matmuls -> NN search with a fused kernel. Let's put 32 million vectors on a single GPU.
69 | - Help write unit and integration tests.
70 | - See all [active issues](https://github.com/0hq/tinyvector/issues)!
71 | 
72 | ### Known Issues
73 | ```
74 | # Major bugs:
75 | Data corruption SQLite error? Stored vectors end up changing. Replicate by creating a table, inserting vectors, creating an index, and then experimenting until an error happens. Dims end up mismatched (most likely the blob or norm functions, though that doesn't explain why the stored data is changing).
76 | PCA is not tested, and neither is the immutable brute-force index.
77 | ```
78 | 
79 | 
80 | ## License
81 | 
82 | [MIT](./LICENSE)

--------------------------------------------------------------------------------
/assets/TINYVECTORLOGO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/assets/TINYVECTORLOGO.png

--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/assets/results.png

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | pythonpath = [
3 |     "."
4 | ]

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | blinker==1.6.2
 2 | click==8.1.4
 3 | Flask==2.3.2
 4 | flask-pydantic-spec==0.4.5
 5 | gunicorn==20.1.0
 6 | inflection==0.5.1
 7 | itsdangerous==2.1.2
 8 | Jinja2==3.1.2
 9 | MarkupSafe==2.1.3
10 | numpy==1.25.1
11 | pydantic==1.10.11
12 | PyJWT==2.7.0
13 | python-dotenv==1.0.0
14 | typing_extensions==4.7.1
15 | Werkzeug==2.3.6
16 | tinyvector==0.1.0

--------------------------------------------------------------------------------
/server/__main__.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | import os
 3 | from functools import wraps
 4 | 
 5 | import jwt
 6 | import numpy as np
 7 | from dotenv import load_dotenv
 8 | from flask import Flask, g, jsonify, request
 9 | from flask_pydantic_spec import FlaskPydanticSpec, Response
10 | from pydantic import BaseModel
11 | 
12 | from server.models.model_response import ErrorMessage
13 | from server.utils.util_pydantic import pydantic_to_dict
14 | from tinyvector import DB
15 | from tinyvector.types.model_db import (DatabaseInfo, IndexCreationBody,
16 |                                        IndexDeletionBody, ItemInsertionBody,
17 |                                        TableCreationBody, TableDeletionBody,
18 |                                        TableMetadata, TableQueryObject,
19 |                                        TableQueryResult)
20 | 
21 | # if logs directory does not exist, create it
22 | if not os.path.exists("logs"):
23 |     os.makedirs("logs")
24 | 
25 | # if cache directory does not exist, create it
26 | if not os.path.exists("cache"):
27 |     os.makedirs("cache")
28 | 
29 | logging.basicConfig(
30 |     filename="logs/server.log",
31 |     level=logging.INFO,
32 |     format="%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s",
33 | )
34 | load_dotenv()
35 | 
36 | # If JWT_SECRET is not set, the application will run in debug mode (which skips token authentication)
37 | if os.getenv("JWT_SECRET") is None:
38 |     os.environ["FLASK_ENV"] = "development"
39 | 
40 | 
41 | app = Flask(__name__)
42 | api = FlaskPydanticSpec(
43 |     "Tiny Vector Database",
44 | )
45 | 
46 | stream_handler = logging.StreamHandler() 47 | stream_handler.setLevel(logging.DEBUG) 48 | formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 49 | stream_handler.setFormatter(formatter) 50 | app.logger.addHandler(stream_handler) 51 | 52 | DATABASE_PATH = os.environ.get("DATABASE_PATH", "cache/database.db") 53 | 54 | 55 | def get_db(): 56 | if "db" not in g: 57 | g.db = DB(DATABASE_PATH) 58 | return g.db 59 | 60 | 61 | @app.teardown_appcontext 62 | def close_db(exception): 63 | _ = g.pop("db", None) 64 | 65 | 66 | def token_required(f): 67 | """ 68 | Decorator function to enforce token authentication for specific routes. 69 | """ 70 | 71 | @wraps(f) 72 | def decorator(*args, **kwargs): 73 | token = None 74 | 75 | if app.debug: 76 | return f(*args, **kwargs) 77 | 78 | if "Authorization" in request.headers: 79 | token = request.headers["Authorization"] 80 | 81 | if not token: 82 | app.logger.warning("Token is missing!") 83 | return jsonify({"message": "Token is missing!"}), 401 84 | 85 | JWT_SECRET = os.getenv("JWT_SECRET") 86 | if JWT_SECRET is None: 87 | raise Exception("JWT_SECRET is not set!") 88 | 89 | try: 90 | _ = jwt.decode(token, JWT_SECRET, algorithms=["HS256"]) 91 | except jwt.ExpiredSignatureError: 92 | app.logger.error("Token is expired!") 93 | return jsonify({"message": "Token is expired!"}), 401 94 | except jwt.InvalidTokenError: 95 | app.logger.error("Token is invalid!") 96 | return jsonify({"message": "Token is invalid!"}), 401 97 | 98 | return f(*args, **kwargs) 99 | 100 | return decorator 101 | 102 | 103 | class SuccessMessage(BaseModel): 104 | status: str = "success" 105 | 106 | 107 | @app.route("/status", methods=["GET"]) 108 | @api.validate(body=None, resp=Response(HTTP_200=SuccessMessage), tags=["API"]) 109 | @pydantic_to_dict 110 | def status(): 111 | """ 112 | Route to check the status of the application. 113 | """ 114 | app.logger.info("Status check performed") 115 | return SuccessMessage(status="success"), 200 116 | 117 | 118 | @app.route("/info", methods=["GET"]) 119 | @token_required 120 | @api.validate( 121 | body=None, resp=Response(HTTP_200=DatabaseInfo, HTTP_400=ErrorMessage), tags=["API"] 122 | ) 123 | @pydantic_to_dict 124 | def info(): 125 | """ 126 | Route to get information about the database. 127 | """ 128 | try: 129 | info = get_db().info() 130 | app.logger.info("Info retrieved successfully") 131 | return info, 200 132 | except Exception as e: 133 | app.logger.error(f"Error while retrieving info: {str(e)}") 134 | return ErrorMessage(error=f"Error while retrieving info: {str(e)}"), 400 135 | 136 | 137 | @app.route("/create_table", methods=["POST"]) 138 | @token_required 139 | @api.validate( 140 | body=TableCreationBody, 141 | resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage), 142 | tags=["DB"], 143 | ) 144 | @pydantic_to_dict 145 | def create_table(): 146 | """ 147 | Route to create a table in the database. 148 | If use_uuid is True, the table will use UUIDs as IDs, and the IDs provided in the insert route are not allowed. 149 | If use_uuid is False, the table will require strings as IDs. 
150 |     """
151 | 
152 |     data = request.get_json()
153 |     body = TableCreationBody(**data)
154 |     try:
155 |         get_db().create_table_and_index(body)
156 |         app.logger.info(f"Table {body.table_name} created successfully")
157 |         return (
158 |             SuccessMessage(status=f"Table {body.table_name} created successfully"),
159 |             200,
160 |         )
161 |     except Exception as e:
162 |         app.logger.error(f"Error while creating table {body.table_name}: {str(e)}")
163 |         return ErrorMessage(error=str(e)), 400
164 | 
165 | 
166 | @app.route("/delete_table", methods=["DELETE"])
167 | @api.validate(
168 |     body=TableDeletionBody,
169 |     resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
170 |     tags=["DB"],
171 | )
172 | @token_required
173 | @pydantic_to_dict
174 | def delete_table():
175 |     """
176 |     Route to permanently delete a table and its data from the database.
177 |     This will also delete the index associated with the table.
178 |     """
179 |     data = request.get_json()
180 |     body = TableDeletionBody(**data)
181 |     table_name = body.table_name
182 |     try:
183 |         get_db().delete_table(table_name)
184 |         app.logger.info(f"Table {table_name} deleted successfully")
185 |         return SuccessMessage(status=f"Table {table_name} deleted successfully"), 200
186 |     except Exception as e:
187 |         app.logger.error(f"Error while deleting table {table_name}: {str(e)}")
188 |         return (
189 |             ErrorMessage(error=f"Error while deleting table {table_name}: {str(e)}"),
190 |             400,
191 |         )
192 | 
193 | 
194 | @app.route("/insert", methods=["POST"])
195 | @api.validate(
196 |     body=ItemInsertionBody,
197 |     resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
198 |     tags=["DB"],
199 | )
200 | @token_required
201 | @pydantic_to_dict
202 | def insert():
203 |     """
204 |     Route to insert an item into a table in the database.
205 |     Requires a previously generated vector embedding of the right dimension (set when creating the table).
206 |     If use_uuid was set to True when creating the table, the ID will be generated automatically. Providing an ID in the request will result in an error.
207 |     If use_uuid was set to False, the ID must be provided as a string.
208 |     defer_index_update can be set to True to stop the index from being updated after the insert. This only works for the brute-force index, as other indexes can't be efficiently updated after creation.
209 |     """
210 |     data = request.get_json()
211 |     body = ItemInsertionBody(**data)
212 | 
213 |     table_name = body.table_name
214 |     id = body.id
215 |     embedding = body.embedding
216 |     content = body.content
217 |     defer_index_update = body.defer_index_update
218 |     try:
219 |         embedding = np.array(embedding)
220 |         get_db().insert(table_name, id, embedding, content, defer_index_update)
221 |         app.logger.info(f"Item {id} inserted successfully into table {table_name}")
222 |         return (
223 |             SuccessMessage(
224 |                 status=f"Item {id} inserted successfully into table {table_name}"
225 |             ),
226 |             200,
227 |         )
228 |     except Exception as e:
229 |         app.logger.error(f"Error while inserting item {id}: {str(e)}")
230 |         return (
231 |             ErrorMessage(error=f"Error while inserting item {id}: {str(e)}"),
232 |             400,
233 |         )
234 | 
235 | 
236 | @app.route("/query", methods=["POST"])
237 | @token_required
238 | @api.validate(
239 |     body=TableQueryObject,
240 |     resp=Response(HTTP_200=TableQueryResult, HTTP_400=ErrorMessage),
241 |     tags=["DB"],
242 | )
243 | @pydantic_to_dict
244 | def query():
245 |     """
246 |     Route to perform a query on a table in the database.
247 |     Requires a previously generated query vector embedding of the right dimension (set when creating the table).
248 |     K is the number of items to return.
249 |     """
250 |     body = TableQueryObject(**request.get_json())
251 | 
252 |     table_name = body.table_name
253 |     query = body.query
254 |     k = body.k
255 |     try:
256 |         query = np.array(query)
257 |         items = get_db().query(table_name, query, k)
258 |         app.logger.info(f"Query performed successfully on table {table_name}")
259 |         app.logger.debug(items)
260 |         return TableQueryResult(items=items), 200
261 |     except Exception as e:
262 |         app.logger.error(
263 |             f"Error while performing query on table {table_name}: {str(e)}"
264 |         )
265 |         return (
266 |             ErrorMessage(
267 |                 error=f"Error while performing query on table {table_name}: {str(e)}"
268 |             ),
269 |             400,
270 |         )
271 | 
272 | 
273 | @app.route("/create_index", methods=["POST"])
274 | @api.validate(
275 |     body=IndexCreationBody,
276 |     tags=["DB"],
277 |     resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
278 | )
279 | @token_required
280 | @pydantic_to_dict
281 | def create_index():
282 |     """
283 |     Route to create an index on a table in the database.
284 |     Index type can be 'brute_force' or 'pca'.
285 |     The PCA index is a dimensionality-reduction index; n_components is the number of components to keep and must be less than the dimension of the table. See README for more details.
286 |     If normalize is True, cosine similarity will be used. If False, dot product will be used.
287 |     If allow_index_updates is True, the index will be updated after each insert. This only works for the brute-force index, as other indexes can't be efficiently updated after creation.
288 |     If you want to update index contents for a non-updatable index (PCA, others), the recommended method is to delete the index and create a new one.
289 |     """
290 |     body = IndexCreationBody(**request.get_json())
291 | 
292 |     table_name = body.table_name
293 |     index_type = body.index_type
294 |     normalize = body.normalize
295 |     allow_index_updates = body.allow_index_updates
296 |     n_components = body.n_components
297 | 
298 |     try:
299 |         get_db().create_index(
300 |             table_name, index_type, normalize, allow_index_updates, n_components
301 |         )
302 |         app.logger.info(f"Index created successfully on table {table_name}")
303 |         return (
304 |             SuccessMessage(status=f"Index created successfully on table {table_name}"),
305 |             200,
306 |         )
307 |     except Exception as e:
308 |         app.logger.error(f"Error while creating index on table {table_name}: {str(e)}")
309 |         return (
310 |             ErrorMessage(
311 |                 error=f"Error while creating index on table {table_name}: {str(e)}"
312 |             ),
313 |             400,
314 |         )
315 | 
316 | 
317 | @app.route("/delete_index", methods=["DELETE"])
318 | @token_required
319 | @api.validate(
320 |     body=IndexDeletionBody,
321 |     tags=["DB"],
322 |     resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
323 | )
324 | @pydantic_to_dict
325 | def delete_index():
326 |     """
327 |     Route to delete an index from a table in the database.
328 |     This will not delete the table or its data.
329 |     """
330 |     body = IndexDeletionBody(**request.get_json())
331 |     table_name = body.table_name
332 |     try:
333 |         get_db().delete_index(table_name)
334 |         app.logger.info(f"Index deleted successfully on table {table_name}")
335 |         return (
336 |             SuccessMessage(status=f"Index deleted successfully on table {table_name}"),
337 |             200,
338 |         )
339 |     except Exception as e:
340 |         app.logger.error(f"Error while deleting index on table {table_name}: {str(e)}")
341 |         return (
342 |             ErrorMessage(
343 |                 error=f"Error while deleting index on table {table_name}: {str(e)}"
344 |             ),
345 |             400,
346 |         )
347 | 
348 | 
349 | if __name__ == "__main__":
350 |     app.logger.info("Starting server...")
351 |     api.register(app)
352 |     PORT = 5234
353 |     app.logger.info(
354 |         f"\nSuccessfully generated documentation :) \n\n- Redoc: http://localhost:{PORT}/apidoc/redoc \n- Swagger: http://localhost:{PORT}/apidoc/swagger"
355 |     )
356 |     app.run(host="0.0.0.0", port=PORT, debug=True)

--------------------------------------------------------------------------------
/server/cache/README.md:
--------------------------------------------------------------------------------
1 | # Database File Cache
2 | 
3 | Store database files, index caches, and more here.

--------------------------------------------------------------------------------
/server/logs/README.md:
--------------------------------------------------------------------------------
1 | # App Logs
2 | 
3 | Logs are stored in .log format following the Python logger format.
4 | Logs are also piped to stdout for console viewing.

--------------------------------------------------------------------------------
/server/models/model_response.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | 
3 | 
4 | class SuccessMessage(BaseModel):
5 |     status: str = "success"
6 | 
7 | 
8 | class ErrorMessage(BaseModel):
9 |     error: str = ""

--------------------------------------------------------------------------------
/server/utils/generate_token.py:
--------------------------------------------------------------------------------
 1 | import datetime
 2 | import os
 3 | import secrets
 4 | 
 5 | import jwt
 6 | from dotenv import load_dotenv
 7 | 
 8 | load_dotenv()
 9 | 
10 | def generate_jwt_secret():
11 |     # Generate a random URL-safe secret and append it to the .env file
12 |     jwt_secret = secrets.token_urlsafe(64)
13 | 
14 |     with open('.env', 'a') as file:
15 |         file.write(f'\nJWT_SECRET={jwt_secret}')
16 | 
17 |     return jwt_secret
18 | 
19 | def generate_token(user_id):
20 |     # Generate a token with user_id as payload and a 30-day expiry
21 |     payload = {
22 |         'user_id': user_id,
23 |         'exp': datetime.datetime.utcnow() + datetime.timedelta(days=30)
24 |     }
25 | 
26 |     token = jwt.encode(payload, os.getenv('JWT_SECRET'), algorithm='HS256')
27 | 
28 |     return token
29 | 
30 | if __name__ == '__main__':
31 |     os.environ['JWT_SECRET'] = generate_jwt_secret()
32 |     print(generate_token('root'))

--------------------------------------------------------------------------------
/server/utils/util_pydantic.py:
--------------------------------------------------------------------------------
 1 | from functools import wraps
 2 | 
 3 | from pydantic import BaseModel
 4 | 
 5 | 
 6 | def pydantic_to_dict(f):
 7 |     @wraps(f)
 8 |     def decorated_function(*args, **kwargs):
 9 |         result, status_code = f(*args, **kwargs)
10 |         if isinstance(result, BaseModel):
11 |             return result.dict(), status_code
12 |         return result, status_code
13 | 
14 |     return decorated_function
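
Taken together, the two utilities above give a working auth flow. A minimal sketch, assuming the server is already running on its default port 5234 with `JWT_SECRET` set in the environment; the localhost URL and the `requests` package are illustrative, not part of the repo:

```python
import requests  # illustrative dependency, not in requirements.txt

from server.utils.generate_token import generate_token

# Mint a 30-day token signed with JWT_SECRET from the environment / .env
token = generate_token("root")

# token_required reads the raw token from the Authorization header
# (no "Bearer " prefix), so pass it through unchanged
resp = requests.get(
    "http://localhost:5234/info",
    headers={"Authorization": token},
)
print(resp.status_code, resp.json())
```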
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import find_packages, setup
 2 | 
 3 | setup(
 4 |     name='tinyvector',
 5 |     version='0.1.1',
 6 |     author='Will DePue',
 7 |     author_email='will@depue.net',
 8 |     license='MIT',
 9 |     description='the tiny, least-dumb, speedy vector embedding database.',
10 |     long_description=open('README.md').read(),
11 |     long_description_content_type='text/markdown',
12 |     url='https://github.com/0hq/tinyvector',
13 |     packages=find_packages(include=['tinyvector', 'tinyvector.*']),
14 |     install_requires=[
15 |         'numpy',
16 |         'pydantic',
17 |         'psutil',
18 |         'scikit-learn',
19 |     ],
20 | )

--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | API, database, and model tests live here. Run them with `pytest`.

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/tests/__init__.py

--------------------------------------------------------------------------------
/tests/api/test_endpoint.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import os
 3 | import random
 4 | import time
 5 | from unittest.mock import patch
 6 | 
 7 | import jwt as pyjwt
 8 | import pytest
 9 | 
10 | from server.__main__ import SuccessMessage, app
11 | from server.models.model_response import ErrorMessage
12 | from tinyvector.database import DB
13 | from tinyvector.types.model_db import (DatabaseInfo, ItemInsertionBody,
14 |                                        TableCreationBody, TableDeletionBody,
15 |                                        TableMetadata, TableQueryObject)
16 | 
17 | test_table = "test_table"
18 | jwt_secret = "testing_config"
19 | db_path = "test_endpoint_db.db"
20 | new_table = TableMetadata(
21 |     table_name=test_table,
22 |     allow_index_updates=True,
23 |     dimension=42,
24 |     index_type="brute_force",
25 |     is_index_active=True,
26 |     normalize=True,
27 |     use_uuid=False,
28 | )
29 | 
30 | 
31 | def generate_client(app):
32 |     os.environ["JWT_SECRET"] = jwt_secret
33 |     encoded_jwt = pyjwt.encode(
34 |         {"payload": "random payload"}, jwt_secret, algorithm="HS256"
35 |     )
36 |     headers = {"Authorization": encoded_jwt}
37 | 
38 |     client = app.test_client()
39 |     client.environ_base.update(headers)
40 |     return client, headers
41 | 
42 | 
43 | @pytest.fixture(scope="session")
44 | def create_test_db():
45 |     test_db = DB(path=db_path)
46 |     test_db.create_table_and_index(new_table)
47 |     yield test_db
48 |     os.remove(db_path)
49 | 
50 | 
51 | def test_token_authentication(create_test_db, mocker):
52 |     """
53 |     This validates that protected endpoints reject missing, invalid, and expired tokens.
 54 |     """
 55 |     # Mock the DB object
 56 |     mocker.patch("server.__main__.get_db", return_value=create_test_db)
 57 | 
 58 |     test_client, headers = generate_client(app)
 59 | 
 60 |     # /info
 61 |     with test_client as client:
 62 |         response = client.get("/info")
 63 |         data = json.loads(response.data)
 64 | 
 65 |         assert response.status_code == 401
 66 |         assert data == {"message": "Token is missing!"}
 67 | 
 68 |         response = client.get("/info", headers={"Authorization": "random_token"})
 69 |         data = json.loads(response.data)
 70 |         assert response.status_code == 401
 71 |         assert data == {"message": "Token is invalid!"}
 72 | 
 73 |         expired_jwt = pyjwt.encode(
 74 |             {"payload": "random payload", "exp": int(time.time()) - 10},
 75 |             jwt_secret,
 76 |             algorithm="HS256",
 77 |         )
 78 |         response = client.get("/info", headers={"Authorization": expired_jwt})
 79 |         data = json.loads(response.data)
 80 |         assert response.status_code == 401
 81 |         assert data == {"message": "Token is expired!"}
 82 | 
 83 | 
 84 | def test_info(create_test_db, mocker):
 85 |     """
 86 |     This validates that the info endpoint correctly captures the state of the DB at a specific snapshot in time.
 87 |     """
 88 |     # Mock the DB object
 89 |     mocker.patch("server.__main__.get_db", return_value=create_test_db)
 90 | 
 91 |     test_client, headers = generate_client(app)
 92 | 
 93 |     # /info
 94 |     with test_client as client:
 95 |         response = client.get("/info", headers=headers)
 96 |         data = json.loads(response.data)
 97 | 
 98 |         assert response.status_code == 200
 99 | 
100 |         expected_response = DatabaseInfo(
101 |             num_tables=1,
102 |             num_indexes=1,
103 |             tables={},
104 |             indexes=["test_table brute_force_mutable"],
105 |         )
106 |         expected_response.tables[test_table] = new_table
107 | 
108 |         # Compare against the exact state we expect DB.info() to report
109 |         assert data == expected_response.dict()
110 | 
111 | 
112 | def test_status_endpoint(create_test_db, mocker):
113 |     """
114 |     This validates that the status endpoint reports a healthy server without requiring a token.
115 |     """
116 |     # Mock the DB object
117 |     mocker.patch("server.__main__.get_db", return_value=create_test_db)
118 | 
119 |     test_client, headers = generate_client(app)
120 | 
121 |     # /status
122 |     with test_client as client:
123 |         response = client.get("/status")
124 |         data = json.loads(response.data)
125 | 
126 |         assert response.status_code == 200
127 | 
128 |         # The status route returns a plain success message
129 |         assert data == SuccessMessage(status="success").dict()
130 | 
131 | 
132 | def test_create_endpoint(create_test_db, mocker):
133 |     """
134 |     This validates that we can successfully create a table object and have the database be updated and populated with the new table.
135 |     """
136 |     # Mock the DB object
137 |     mocker.patch("server.__main__.get_db", return_value=create_test_db)
138 | 
139 |     test_client, headers = generate_client(app)
140 | 
141 |     # /create_table, then /info
142 |     with test_client as client:
143 |         table_2 = TableCreationBody(
144 |             table_name="table_testing_2",
145 |             allow_index_updates=True,
146 |             dimension=42,
147 |             index_type="brute_force",
148 |             is_index_active=True,
149 |             normalize=True,
150 |             use_uuid=False,
151 |         )
152 | 
153 |         response = client.post(
154 |             "/create_table", json=table_2.dict(), headers=headers
155 |         )
156 | 
157 |         data = json.loads(response.data)
158 | 
159 |         assert response.status_code == 200
160 |         assert (
161 |             data
162 |             == SuccessMessage(
163 |                 status=f"Table {table_2.table_name} created successfully"
164 |             ).dict()
165 |         )
166 | 
167 |         response = client.get("/info", headers=headers)
168 |         data = json.loads(response.data)
169 | 
170 |         assert response.status_code == 200
171 |         assert data["num_tables"] == 2
172 |         assert data["num_indexes"] == 2
173 |         assert data["tables"][table_2.table_name] == table_2.dict()
174 | 
175 |         # We then create a new DB instance, which loads metadata from scratch from the sqlite3 database, and validate that it holds the same data
176 | 
177 |         new_db = DB(path=db_path, debug=True)
178 |         new_instance_data = new_db.info()
179 |         assert new_instance_data.dict() == data
180 | 
181 | 
182 | def test_delete_endpoint(create_test_db, mocker):
183 |     """
184 |     This validates the delete endpoint
185 |     """
186 |     # Mock the DB object
187 |     mocker.patch("server.__main__.get_db", return_value=create_test_db)
188 | 
189 |     test_client, headers = generate_client(app)
190 | 
191 |     with test_client as client:
192 |         # We previously created a new table in an earlier test; make sure those changes have been applied.
193 | 
194 |         response = client.get("/info", headers=headers)
195 |         body = json.loads(response.data)
196 |         assert response.status_code == 200
197 |         assert body["num_tables"] == 2
198 |         assert body["num_indexes"] == 2
199 | 
200 |         # Try to delete an invalid table
201 |         body = TableDeletionBody(
202 |             table_name="fake_table",
203 |         )
204 | 
205 |         response = client.delete(
206 |             "/delete_table", json=body.dict(), headers=headers
207 |         )
208 |         data = json.loads(response.data)
209 |         assert response.status_code == 400
210 |         assert (
211 |             data
212 |             == ErrorMessage(
213 |                 error="Error while deleting table fake_table: Table fake_table does not exist"
214 |             ).dict()
215 |         )
216 | 
217 |         # Try to delete a valid table
218 |         body = TableDeletionBody(
219 |             table_name="table_testing_2",
220 |         )
221 |         response = client.delete(
222 |             "/delete_table", json=body.dict(), headers=headers
223 |         )
224 |         data = json.loads(response.data)
225 |         assert response.status_code == 200
226 |         assert (
227 |             data
228 |             == SuccessMessage(
229 |                 status=f"Table {body.table_name} deleted successfully"
230 |             ).dict()
231 |         )
232 | 
233 |         # Validate that the table has been deleted
234 |         response = client.get("/info", headers=headers)
235 |         data = json.loads(response.data)
236 | 
237 |         assert response.status_code == 200
238 |         assert data["num_tables"] == 1
239 |         assert data["num_indexes"] == 1
240 | 
241 |         # We then create a new DB instance, which loads metadata from scratch from the sqlite3 database, and validate that it holds the same data
242 | 
243 |         new_db = DB(path=db_path, debug=True)
244 |         new_instance_data = new_db.info()
245 |         assert new_instance_data.dict() == data
246 | 
247 | 
248 | def test_insert_endpoint(create_test_db, mocker):
249 |     """
250 |     This validates the insert and query endpoints
251 |     """
252 |     # Mock the DB object
253 |     mocker.patch("server.__main__.get_db", return_value=create_test_db)
254 | 
255 |     test_client, headers = generate_client(app)
256 | 
257 |     with test_client as client:
258 |         values = 200
259 |         for i in range(values):
260 |             item = ItemInsertionBody(
261 |                 table_name=test_table,
262 |                 id=f"Item {i}",
263 |                 embedding=[random.randint(0, values) for i in range(42)],
264 |                 content=f"Item {i} content",
265 |                 defer_index_update=False,
266 |             )
267 |             response = client.post("/insert", json=item.dict(), headers=headers)
268 |             data = json.loads(response.data)
269 |             assert response.status_code == 200
270 |             assert (
271 |                 data["status"]
272 |                 == f"Item Item {i} inserted successfully into table {test_table}"
273 |             )
274 | 
275 |         # We now validate that the table holds all `values` items by calling the /query endpoint
276 |         body = TableQueryObject(
277 |             table_name=test_table, query=[0 for i in range(42)], k=values
278 |         )
279 |         response = client.post("/query", json=body.dict(), headers=headers)
280 |         data = json.loads(response.data)
281 | 
282 |         assert response.status_code == 200
283 |         assert len(data["items"]) == values

--------------------------------------------------------------------------------
/tests/db/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/tests/db/__init__.py

--------------------------------------------------------------------------------
/tests/db/test_create_db.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sqlite3
 3 | 
 4 | import pytest
 5 | 
 6 | from tinyvector.database import DB
 7 | from tinyvector.types.model_db import (DatabaseInfo, TableCreationBody,
 8 |                                        TableMetadata)
 9 | 
10 | 
11 | @pytest.fixture
12 | def create_db(tmp_path):
13 |     db_path = tmp_path / "test.db"
14 |     table_name = "test_table"
15 | 
16 |     db_config = TableCreationBody(
17 |         table_name=table_name,
18 |         allow_index_updates=False,
19 |         dimension=42,
20 |         index_type="brute_force",
21 |         is_index_active=True,
22 |         normalize=True,
23 |         use_uuid=False,
24 |     )
25 |     test_db = DB(db_path)
26 |     test_db.create_table_and_index(db_config)
27 |     yield test_db, table_name, db_config, db_path
28 |     os.remove(db_path)
29 | 
30 | 
31 | def test_create_table_again_raises_error(create_db):
32 |     test_db, table_name, db_config, db_path = create_db
33 |     with pytest.raises(ValueError, match=f"Table {table_name} already exists"):
34 |         test_db.create_table_and_index(db_config)
35 | 
36 | 
37 | def test_loads_correct_metadata_on_startup(create_db):
38 |     test_db, table_name, db_config, db_path = create_db
39 |     test_db_2 = DB(db_path)
40 |     info = test_db_2.info()
42 |     assert table_name in info.tables, f"Table {table_name} not found"
43 |     assert TableMetadata(**db_config.dict()) == info.tables[table_name]
44 |     assert info.num_tables == 1, "Number of tables should be 1"
45 | 
46 | 
47 | def test_delete_table(create_db):
48 |     test_db, table_name, db_config, db_path = create_db
49 |     test_db.delete_table(table_name)
50 |     info = test_db.info()
51 |     assert table_name not in info.tables, f"Table {table_name} should not be found"
52 |     assert info.num_tables == 0, "Number of tables should be 0"
53 | 
54 | 
55 | def test_create_index(create_db):
56 |     test_db, table_name, db_config, db_path = create_db
57 |     # TODO
58 |     pass

--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/tests/models/__init__.py

--------------------------------------------------------------------------------
/tests/models/test_table_metadata.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic import ValidationError
3 | 
4 | from tinyvector.types.model_db import TableMetadata
5 | 
6 | 
7 | def test_invalid_table_metadata():
8 |     """
9 |     We raise an error when allow_index_updates is True and the index type is PCA, since we don't support this combination.
10 | """ 11 | with pytest.raises(ValidationError) as exc_info: 12 | TableMetadata( 13 | table_name="example", 14 | allow_index_updates=True, 15 | dimension=3, 16 | index_type="pca", 17 | is_index_active=True, 18 | normalize=True, 19 | use_uuid=1, 20 | ) 21 | -------------------------------------------------------------------------------- /tinyvector/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Will DePue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tinyvector/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md -------------------------------------------------------------------------------- /tinyvector/__init__.py: -------------------------------------------------------------------------------- 1 | from .database import DB 2 | -------------------------------------------------------------------------------- /tinyvector/database.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import uuid 3 | from abc import ABC, abstractmethod 4 | from typing import cast 5 | 6 | import numpy as np 7 | import psutil 8 | from sklearn.decomposition import PCA 9 | 10 | from .types.model_db import DatabaseInfo, TableCreationBody, TableMetadata 11 | 12 | 13 | def norm_np(datum): 14 | """ 15 | Normalize a NumPy array along the specified axis using L2 normalization. 16 | """ 17 | axis = None if datum.ndim == 1 else 1 18 | return datum / (np.linalg.norm(datum, axis=axis, keepdims=True) + 1e-6) 19 | 20 | 21 | def norm_python(datum): 22 | """ 23 | Normalize a Python list or ndarray using L2 normalization. 24 | """ 25 | datum = np.array(datum) 26 | result = norm_np(datum) 27 | return result.tolist() 28 | 29 | 30 | def array_to_blob(array): 31 | """ 32 | Convert a NumPy array to a binary blob. 33 | """ 34 | # Validate that the array is a numpy array of type float32 35 | if not isinstance(array, np.ndarray): 36 | raise ValueError("Expected a numpy array") 37 | return array.tobytes() 38 | 39 | 40 | def blob_to_array(blob): 41 | """ 42 | Convert a binary blob to a NumPy array. 43 | """ 44 | result = np.frombuffer(blob, dtype=np.float32) 45 | return result 46 | 47 | 48 | class AbstractIndex(ABC): 49 | """ 50 | Abstract base class for database indexes. 
 51 |     """
 52 | 
 53 |     def __init__(self, table_name, type, num_vectors, dimension):
 54 |         self.table_name = table_name
 55 |         self.type = type
 56 |         self.num_vectors = num_vectors
 57 |         self.dimension = dimension
 58 | 
 59 |     def __len__(self):
 60 |         return self.num_vectors
 61 | 
 62 |     def __str__(self):
 63 |         return self.table_name + " " + self.type
 64 | 
 65 |     def __repr__(self):
 66 |         return self.__str__()
 67 | 
 68 |     @abstractmethod
 69 |     def add_vector(self, id, embedding):
 70 |         """
 71 |         Add a vector with the specified ID and embedding to the index.
 72 |         """
 73 |         pass
 74 | 
 75 |     @abstractmethod
 76 |     def get_similarity(self, query, k):
 77 |         """
 78 |         Get the top-k most similar vectors to the given query vector.
 79 |         """
 80 |         pass
 81 | 
 82 | 
 83 | class BruteForceIndexImmutable(AbstractIndex):
 84 |     """
 85 |     Brute-force index implementation that is immutable (does not support vector addition).
 86 |     Faster than the mutable version because it uses NumPy arrays instead of Python lists.
 87 |     Search complexity is O(n + k log k) where n is the number of vectors in the index.
 88 |     """
 89 | 
 90 |     def __init__(self, table_name, dimension, normalize, embeddings, ids):
 91 |         super().__init__(
 92 |             table_name, "brute_force_immutable", len(embeddings), dimension
 93 |         )
 94 |         # TODO: Validate that this normalization is the right fit; either way, store a NumPy matrix so get_similarity can use `@`.
 95 |         self.embeddings = norm_np(np.array(embeddings)) if normalize else np.array(embeddings)
 96 |         self.ids = ids
 97 |         self.normalize = normalize
 98 | 
 99 |     def add_vector(self, id, embedding):
100 |         """
101 |         Add a vector to the index (not supported in the immutable version).
102 |         """
103 |         raise NotImplementedError(
104 |             "This index does not support vector addition. Please use BruteForceIndexMutable if updates are required."
105 |         )
106 | 
107 |     def get_similarity(self, query, k):
108 |         """
109 |         Get the top-k most similar vectors to the query vector.
110 |         """
111 |         if len(query) != self.dimension:
112 |             raise ValueError(
113 |                 f"Expected query of dimension {self.dimension}, got {len(query)}"
114 |             )
115 | 
116 |         query_normalized = norm_np(query) if self.normalize else query
120 |         scores = query_normalized @ self.embeddings.T
121 |         arg_k = k if k < len(scores) else len(scores) - 1
122 |         partitioned_indices = np.argpartition(-scores, kth=arg_k)[:k]
123 |         top_k_indices = partitioned_indices[np.argsort(-scores[partitioned_indices])]
124 | 
125 |         return [
126 |             {"id": self.ids[i], "embedding": self.embeddings[i], "score": scores[i]}
127 |             for i in top_k_indices
128 |         ]
129 | 
130 | 
131 | class BruteForceIndexMutable(AbstractIndex):
132 |     """
133 |     Brute-force index implementation that is mutable (supports vector addition).
134 |     Slower than the immutable version because it uses Python lists instead of NumPy arrays.
135 |     Search complexity is O(n + k log k) where n is the number of vectors in the index.
136 |     """
137 | 
138 |     def __init__(self, table_name, dimension, normalize, embeddings, ids):
139 |         super().__init__(table_name, "brute_force_mutable", len(embeddings), dimension)
140 | 
141 |         if len(embeddings) > 0 and dimension != len(embeddings[0]):
142 |             raise ValueError(
143 |                 f"Expected embeddings of dimension {self.dimension}, got {len(embeddings[0])}"
144 |             )
145 |         # Store plain Python lists so add_vector can append cheaply
146 |         self.embeddings = norm_python(embeddings) if normalize else [np.asarray(e).tolist() for e in embeddings]
147 |         self.ids = ids
148 |         self.normalize = normalize
149 | 
150 |     def add_vector(self, id, embedding):
151 |         """
152 |         Add a vector to the index.
153 |         """
154 |         if len(embedding) != self.dimension:
155 |             raise ValueError(
156 |                 f"Expected embedding of dimension {self.dimension}, got {len(embedding)}"
157 |             )
158 |         self.ids.append(id)
159 |         self.embeddings.append(
160 |             norm_python(embedding) if self.normalize else embedding.tolist()
161 |         )
162 |         self.num_vectors += 1
163 | 
164 |     def get_similarity(self, query, k):
165 |         """
166 |         Get the top-k most similar vectors to the query vector.
167 |         """
168 |         if len(query) != self.dimension:
169 |             raise ValueError(
170 |                 f"Expected query of dimension {self.dimension}, got {len(query)}"
171 |             )
172 | 
173 |         query_normalized = norm_np(query) if self.normalize else query
174 |         dataset_vectors = np.array(self.embeddings)
175 |         scores = query_normalized @ dataset_vectors.T
176 |         arg_k = k if k < len(scores) else len(scores) - 1
177 |         partitioned_indices = np.argpartition(-scores, kth=arg_k)[:k]
178 |         top_k_indices = partitioned_indices[np.argsort(-scores[partitioned_indices])]
179 | 
180 |         return [
181 |             {"id": self.ids[i], "embedding": self.embeddings[i], "score": scores[i]}
182 |             for i in top_k_indices
183 |         ]
184 | 
185 | 
186 | class PCAIndex(AbstractIndex):
187 |     """
188 |     Index implementation using PCA for dimensionality reduction.
189 |     Not mutable (does not support vector addition) and the fastest of all indexes.
190 |     This index performs a dimensionality reduction on the input vectors using PCA. This can considerably speed up the search process, as the number of dimensions is reduced from the original dimension to the number of components specified in the constructor. The downside is that results may not be as accurate as with the brute-force index, but they can be surprisingly good for many applications.
191 |     Search complexity is O(n + k log k) where n is the number of vectors in the index.
192 |     Indexing time can be slow on startup for large datasets.
193 |     """
194 | 
195 |     def __init__(self, table_name, dimension, n_components, normalize, embeddings, ids):
196 |         super().__init__(
197 |             table_name, "pca", len(embeddings), n_components
198 |         )  # Initialize the name/len attributes in the parent class
199 |         self.ids = ids
200 |         self.pca = PCA(n_components)
201 |         self.original_dimension = dimension
202 |         self.embeddings = self.pca.fit_transform(embeddings)
203 |         self.embeddings = norm_np(self.embeddings) if normalize else self.embeddings
204 |         self.normalize = normalize
205 | 
206 |     def add_vector(self, id, embedding):
207 |         """
208 |         Add a vector to the index (not supported in the PCA index).
209 |         """
210 |         raise NotImplementedError(
211 |             "This index does not support vector addition. Please use BruteForceIndexMutable if updates are required."
212 |         )
213 | 
214 |     def get_similarity(self, query, k):
215 |         """
216 |         Get the top-k most similar vectors to the query vector.
217 | Applies the previously calculated PCA model to the query vector before searching. 218 | """ 219 | if len(query) != self.original_dimension: 220 | raise ValueError( 221 | f"Expected query of dimension {self.original_dimension}, got {len(query)}" 222 | ) 223 | 224 | dataset_vectors = self.embeddings 225 | transformed_query = self.pca.transform([query])[0] 226 | query_normalized = ( 227 | norm_np(transformed_query) if self.normalize else transformed_query 228 | ) 229 | 230 | scores = query_normalized @ dataset_vectors.T 231 | arg_k = k if k < len(scores) else len(scores) - 1 232 | partitioned_indices = np.argpartition(-scores, kth=arg_k)[:k] 233 | top_k_indices = partitioned_indices[np.argsort(-scores[partitioned_indices])] 234 | 235 | return [ 236 | {"id": self.ids[i], "embedding": self.embeddings[i], "score": scores[i]} 237 | for i in top_k_indices 238 | ] 239 | 240 | 241 | class DB: 242 | """ 243 | Database class for managing tables and indexes. 244 | """ 245 | 246 | def __init__(self, path, debug=False): 247 | self.conn = sqlite3.connect(path, check_same_thread=False) 248 | self.c = self.conn.cursor() 249 | self.table_metadata = {} 250 | self.indexes = {} 251 | 252 | self.path = path 253 | self.debug = debug 254 | self._init_db() 255 | 256 | def info(self): 257 | """ 258 | Get information about all tables in the database. 259 | """ 260 | 261 | res = DatabaseInfo( 262 | num_indexes=len(self.indexes), 263 | num_tables=len(self.table_metadata), 264 | tables={}, 265 | indexes=[str(index) for index in self.indexes.values()], 266 | ) 267 | res.tables = self.table_metadata 268 | return res 269 | 270 | def _init_db(self): 271 | """ 272 | Initialize the database and load metadata. 273 | """ 274 | self.c.execute( 275 | "CREATE TABLE IF NOT EXISTS table_metadata (table_name TEXT PRIMARY KEY, dimension INTEGER, index_type TEXT, normalize BOOLEAN, allow_index_updates BOOLEAN, is_active BOOLEAN, use_uuid BOOLEAN)" 276 | ) 277 | self.conn.commit() 278 | self._load_metadata() 279 | 280 | def _load_metadata(self): 281 | """ 282 | Load table metadata from the database. 283 | At the moment, all indexes are rebuilt on startup. This will be changed in the future. 284 | """ 285 | self.table_metadata = {} 286 | select_query = "SELECT * FROM table_metadata" 287 | 288 | for row in self.c.execute(select_query).fetchall(): 289 | ( 290 | table_name, 291 | dimension, 292 | index_type, 293 | normalize, 294 | allow_index_updates, 295 | is_index_active, 296 | use_uuid, 297 | ) = row 298 | self.table_metadata[table_name] = TableMetadata( 299 | table_name=table_name, 300 | dimension=dimension, 301 | index_type=index_type, 302 | normalize=normalize, 303 | allow_index_updates=allow_index_updates, 304 | is_index_active=is_index_active, 305 | use_uuid=use_uuid, 306 | ) 307 | 308 | if is_index_active: 309 | try: 310 | self.create_index( 311 | table_name, 312 | index_type, 313 | normalize, 314 | allow_index_updates, 315 | dimension, 316 | ) 317 | 318 | except Exception as e: 319 | print( 320 | f"Error loading index for table {table_name}: {e}. Clearing index..." 321 | ) 322 | del self.table_metadata[table_name] 323 | if self.indexes.get(table_name) is not None: 324 | del self.indexes[table_name] 325 | 326 | def create_index( 327 | self, 328 | table_name, 329 | index_type, 330 | normalize, 331 | allow_index_updates=None, 332 | n_components=None, 333 | ): 334 | """ 335 | Create an index on the specified table. 
336 |         """
337 | 
338 |         if psutil.virtual_memory().available < 0.1 * psutil.virtual_memory().total:
339 |             raise MemoryError("System is running out of memory")
340 | 
341 |         def get_data(select_query):
342 |             self.c.execute(select_query)
343 |             rows = self.c.fetchall()
344 |             ids = [row[0] for row in rows]
345 |             embeddings = [blob_to_array(row[1]) for row in rows]
346 |             return ids, embeddings
347 | 
348 |         if self.table_metadata.get(table_name) is None:
349 |             raise ValueError(
350 |                 f"Table {table_name} does not exist. Create the table first."
351 |             )
352 | 
353 |         if self.indexes.get(table_name) is not None:
354 |             raise ValueError(
355 |                 f"Index for table {table_name} already exists. Delete the index first if you want to rebuild it."
356 |             )
357 | 
358 |         dimension = self.table_metadata[table_name].dimension
359 |         if index_type == "brute_force":
360 |             ids, embeddings = get_data(f"SELECT * FROM {table_name}")
361 |             if allow_index_updates:
362 |                 self.indexes[table_name] = BruteForceIndexMutable(
363 |                     table_name, dimension, normalize, embeddings, ids
364 |                 )
365 |             else:
366 |                 self.indexes[table_name] = BruteForceIndexImmutable(
367 |                     table_name, dimension, normalize, embeddings, ids
368 |                 )
369 |         elif index_type == "pca":
370 |             if n_components is None:
371 |                 raise ValueError("n_components must be specified for PCA index")
372 |             ids, embeddings = get_data(f"SELECT * FROM {table_name}")
373 |             self.indexes[table_name] = PCAIndex(
374 |                 table_name, dimension, n_components, normalize, embeddings, ids
375 |             )
376 |         else:
377 |             raise ValueError(f"Unknown index type {index_type}")
378 | 
379 |         # Update metadata
380 |         self.table_metadata[table_name].index_type = index_type
381 |         self.table_metadata[table_name].is_index_active = True
382 |         self.table_metadata[table_name].allow_index_updates = allow_index_updates
383 |         self.table_metadata[table_name].normalize = normalize
384 |         self.c.execute(
385 |             "UPDATE table_metadata SET index_type = ?, is_active = ?, allow_index_updates = ?, normalize = ? WHERE table_name = ?",
386 |             (index_type, True, allow_index_updates, normalize, table_name),
387 |         )
388 |         self.conn.commit()
389 | 
390 |     def create_table_and_index(self, table_config: TableCreationBody):
391 |         """
392 |         Creates a new table and index in the database.
393 |         """
394 | 
395 |         table_name = table_config.table_name
396 | 
397 |         # We first validate that the table does not exist
398 |         if table_name in self.table_metadata:
399 |             raise ValueError(f"Table {table_name} already exists")
400 | 
401 |         # We create the table
402 |         self.c.execute(
403 |             f"CREATE TABLE {table_name} (id TEXT PRIMARY KEY, embedding BLOB, content TEXT)"
404 |         )
405 | 
406 |         # We persist the table metadata, then build the index on top of it
409 |         self.update_table_metadata(table_config)
410 | 
411 |         self.create_index(
412 |             table_config.table_name,
413 |             table_config.index_type,
414 |             table_config.normalize,
415 |             table_config.allow_index_updates,
416 |             table_config.dimension,
417 |         )
418 | 
426 | 
427 |     def update_table_metadata(self, table_config: TableCreationBody):
428 |         query = """
429 |             INSERT INTO table_metadata (
430 |                 table_name, dimension, index_type, normalize, allow_index_updates, is_active, use_uuid
431 |             ) VALUES (?, ?, ?, ?, ?, ?, ?)
432 | """ 433 | values = ( 434 | table_config.table_name, 435 | table_config.dimension, 436 | table_config.index_type, 437 | table_config.normalize, 438 | table_config.allow_index_updates, 439 | table_config.is_index_active, 440 | table_config.use_uuid, 441 | ) 442 | with self.conn: 443 | cursor = self.conn.cursor() 444 | cursor.execute(query, values) 445 | self.conn.commit() 446 | 447 | self.table_metadata[table_config.table_name] = table_config 448 | return 449 | 450 | def create_table(self, table_name, dimension, use_uuid): 451 | """ 452 | Create a new table in the database. 453 | """ 454 | if self.table_metadata.get(table_name) is not None: 455 | raise ValueError(f"Table {table_name} already exists") 456 | self.c.execute( 457 | f"CREATE TABLE {table_name} (id TEXT PRIMARY KEY, embedding BLOB, content TEXT)" 458 | ) 459 | self.c.execute( 460 | "INSERT INTO table_metadata VALUES (?, ?, ?, ?, ?, ?, ?)", 461 | (table_name, dimension, None, None, None, False, use_uuid), 462 | ) 463 | self.conn.commit() 464 | 465 | # Update metadata 466 | self.table_metadata[table_name] = { 467 | "dimension": dimension, 468 | "index_type": None, 469 | "normalize": None, 470 | "allow_index_updates": None, 471 | "is_index_active": False, 472 | "use_uuid": use_uuid, 473 | } 474 | 475 | def delete_table(self, table_name): 476 | """ 477 | Delete a table from the database. 478 | """ 479 | if self.table_metadata.get(table_name) is None: 480 | raise ValueError(f"Table {table_name} does not exist") 481 | self.c.execute(f"DROP TABLE {table_name}") 482 | self.c.execute("DELETE FROM table_metadata WHERE table_name = ?", (table_name,)) 483 | self.conn.commit() 484 | if self.indexes.get(table_name) is not None: 485 | del self.indexes[table_name] 486 | self.table_metadata.pop(table_name) 487 | 488 | def delete_index(self, table_name): 489 | """ 490 | Delete an index from a table. 491 | """ 492 | if self.indexes.get(table_name) is None: 493 | raise ValueError(f"Index for table {table_name} does not exist") 494 | del self.indexes[table_name] 495 | self.table_metadata[table_name]["is_index_active"] = False 496 | self.table_metadata[table_name]["index_type"] = None 497 | self.table_metadata[table_name]["normalize"] = None 498 | self.table_metadata[table_name]["allow_index_updates"] = None 499 | self.c.execute( 500 | "UPDATE table_metadata SET index_type = NULL, normalize = NULL, allow_index_updates = NULL, is_active = 0 WHERE table_name = ?", 501 | (table_name,), 502 | ) 503 | self.conn.commit() 504 | 505 | def insert(self, table_name, id, embedding, content, defer_index_update=False): 506 | """ 507 | Insert a vector into a table. 508 | """ 509 | if psutil.virtual_memory().available < 0.1 * psutil.virtual_memory().total: 510 | raise MemoryError("System is running out of memory") 511 | 512 | # First validate that table exists 513 | if table_name not in self.table_metadata: 514 | raise ValueError(f"Table {table_name} does not exist") 515 | 516 | table_config: TableMetadata = self.table_metadata[table_name] 517 | 518 | if table_config.use_uuid and id is not None: 519 | raise ValueError( 520 | "This table uses auto-generated UUIDs. Do not provide an ID." 521 | ) 522 | elif table_config.use_uuid: 523 | id = str(uuid.uuid4()) # Generate a unique ID using the uuid library. 524 | elif id is None: 525 | raise ValueError("This table uses custom IDs. 
526 | 
527 |         insert_query = f"INSERT INTO {table_name} VALUES (?, ?, ?)"
528 |         self.c.execute(insert_query, (id, array_to_blob(embedding), content))
529 |         self.conn.commit()
530 | 
531 |         if (
532 |             table_config.is_index_active is True
533 |             and table_config.allow_index_updates is True
534 |             and defer_index_update is False
535 |         ):
536 |             self.indexes[table_name].add_vector(id, embedding)
537 | 
538 |     def query(self, table_name, query, k):
539 |         """
540 |         Perform a query on a table.
541 |         """
542 |         if self.table_metadata.get(table_name) is None:
543 |             raise ValueError(f"Table {table_name} does not exist")
544 | 
545 |         table_config = self.table_metadata[table_name]
546 |         table_config = cast(TableMetadata, table_config)
547 | 
548 |         if table_config.is_index_active is False:
549 |             raise ValueError(f"Index for table {table_name} does not exist")
550 | 
551 |         items = self.indexes[table_name].get_similarity(np.array(query), k)
552 | 
553 |         # Get content from the DB in a single query
554 |         ids = [item["id"] for item in items]
555 |         placeholder = ", ".join(["?"] * len(ids))  # One placeholder per ID
556 |         self.c.execute(
557 |             f"SELECT id, content FROM {table_name} WHERE id IN ({placeholder})", ids
558 |         )
559 |         content_dict = {row_id: row_content for row_id, row_content in self.c.fetchall()}
560 | 
561 |         # Add the content to the items
562 |         for item in items:
563 |             item["content"] = content_dict.get(item["id"])
564 | 
565 |         return items
566 | 
--------------------------------------------------------------------------------
/tinyvector/types/model_db.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Dict, Optional
3 | 
4 | from pydantic import BaseModel, validator
5 | 
6 | 
7 | class TableCreationBody(BaseModel):
8 |     table_name: str
9 |     dimension: int
10 |     use_uuid: bool = False
11 |     normalize: bool = True
12 |     index_type: str = "brute_force"
13 |     allow_index_updates: bool = True
14 |     is_index_active: bool = True
15 | 
16 | 
17 | class TableDeletionBody(BaseModel):
18 |     table_name: str
19 | 
20 | 
21 | class IndexType(str, Enum):
22 |     BRUTE_FORCE = "brute_force"
23 |     PCA = "pca"
24 | 
25 | 
26 | class IndexDeletionBody(BaseModel):
27 |     table_name: str
28 | 
29 | 
30 | class IndexCreationBody(BaseModel):
31 |     table_name: str
32 |     index_type: IndexType
33 |     normalize: bool = True
34 |     allow_index_updates: bool = False
35 |     n_components: Optional[int] = None
36 | 
37 | 
38 | class ItemInsertionBody(BaseModel):
39 |     table_name: str
40 |     id: Optional[str] = None
41 |     embedding: list[float]
42 |     content: Optional[str] = None
43 |     defer_index_update: bool = False
44 | 
45 | 
46 | class TableQueryObject(BaseModel):
47 |     table_name: str
48 |     query: list[float]
49 |     k: int
50 | 
51 | 
52 | class TableQueryResultInstance(BaseModel):
53 |     content: Optional[str]
54 |     embedding: list[float] | list[int]
55 |     id: str
56 |     score: float
57 | 
58 | 
59 | class TableQueryResult(BaseModel):
60 |     items: list[TableQueryResultInstance]
61 | 
62 | 
63 | class TableMetadata(BaseModel):
64 |     table_name: str
65 |     allow_index_updates: Optional[bool]
66 |     dimension: int
67 |     index_type: Optional[IndexType]
68 |     is_index_active: bool
69 |     normalize: Optional[bool]
70 |     use_uuid: bool
71 | 
72 |     @validator('*', check_fields=False)
73 |     def check_index_update_and_type(cls, value, values, config, field):
74 |         if 'allow_index_updates' in values and 'index_type' in values:
75 |             if values['allow_index_updates'] and values['index_type'] == "pca":
76 |                 raise ValueError(
77 |                     "PCA index does not support updates. Please set allow_index_updates=False."
78 |                 )
79 |         return value
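# Illustrative note: the validator above rejects a PCA index combined with
# allow_index_updates=True. For example, TableMetadata(table_name="t",
# dimension=2, index_type=IndexType.PCA, allow_index_updates=True,
# is_index_active=True, normalize=True, use_uuid=False) raises a
# ValidationError, because a PCAIndex is fit once at build time.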
80 | 
81 | 
82 | class DatabaseInfo(BaseModel):
83 |     indexes: list[str]
84 |     num_indexes: int
85 |     num_tables: int
86 |     tables: Dict[str, TableMetadata]
87 | 
--------------------------------------------------------------------------------
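A minimal end-to-end usage sketch (illustrative only, not a file in this repository; the database path, table name, IDs, and toy vectors are assumptions):

# Hypothetical example exercising DB with TableCreationBody's defaults
# (mutable brute-force index, normalize=True, caller-supplied string IDs).
from tinyvector.database import DB
from tinyvector.types.model_db import TableCreationBody

db = DB("example.db")  # assumed path
db.create_table_and_index(TableCreationBody(table_name="docs", dimension=3))

# Insert two toy vectors; the mutable brute-force index is updated in place.
db.insert("docs", "a", [0.1, 0.2, 0.3], "first document")
db.insert("docs", "b", [0.9, 0.1, 0.0], "second document")

# query() returns the top-k items, each with id, embedding, score, and content.
for item in db.query("docs", [0.1, 0.2, 0.3], k=1):
    print(item["id"], item["score"], item["content"])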