├── .Dockerfile
├── .dockerignore
├── .gitignore
├── LICENSE
├── README.md
├── assets
├── TINYVECTORLOGO.png
└── results.png
├── pyproject.toml
├── requirements.txt
├── server
├── __main__.py
├── cache
│ └── README.md
├── logs
│ └── README.md
├── models
│ └── model_response.py
└── utils
│ ├── generate_token.py
│ └── util_pydantic.py
├── setup.py
├── tests
├── README.md
├── __init__.py
├── api
│ └── test_endpoint.py
├── db
│ ├── __init__.py
│ └── test_create_db.py
└── models
│ ├── __init__.py
│ └── test_table_metadata.py
└── tinyvector
├── LICENSE
├── MANIFEST.in
├── __init__.py
├── database.py
└── types
└── model_db.py
/.Dockerfile:
--------------------------------------------------------------------------------
# Use an official Python runtime as a parent image
FROM python:3.8-slim-buster

# Set the working directory in the container to /app
WORKDIR /app

# COPY is preferred over ADD for local files (no implicit tar/URL handling)
COPY . /app

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Generate a JWT secret at build time and persist it in the image.
# NOTE(review): a secret baked into an image layer is readable by anyone with
# the image; prefer passing JWT_SECRET at `docker run -e JWT_SECRET=...` time.
RUN python -c "import secrets; print(secrets.token_urlsafe())" > jwt_secret.txt

# Make port 5000 available to the world outside this container
EXPOSE 5000

# Run gunicorn when the container launches.
# Bug fix: `ENV JWT_SECRET=$(cat jwt_secret.txt)` stored the literal string
# "$(cat jwt_secret.txt)" — ENV performs no command substitution. The secret is
# now read at container start (an externally supplied JWT_SECRET wins), and
# gunicorn is bound explicitly to the exposed port 5000 (its default bind is
# 127.0.0.1:8000, which would not match EXPOSE and be unreachable).
CMD ["sh", "-c", "JWT_SECRET=${JWT_SECRET:-$(cat jwt_secret.txt)} exec gunicorn -w 4 -b 0.0.0.0:5000 server.__main__:app"]
24 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # Ignore all logs
2 | *.log
3 |
4 | # Ignore all .pyc files
5 | *.pyc
6 |
7 | # Ignore data directory
8 | /data/*
9 |
10 | # Ignore Git and GitHub files
11 | .git/
12 | .github/
13 |
14 | # Ignore environment files
15 | .env
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 | *.parquet
164 | *.db
165 |
166 | logs
167 | cache
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Will DePue
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | tinyvector - the tiny, least-dumb, speedy vector embedding database.
8 | No, you don't need a vector database. You need tinyvector.
9 |
10 |
11 | In pre-release: prod-ready by late-July. Still in development, not ready!
12 |
13 |
14 |
15 | ## Features
16 | - __Tiny__: It's in the name. It's just a Flask server, SQLite DB, and Numpy indexes. Extremely easy to customize, under 500 lines of code.
17 | - __Fast__: Tinyvector will have comparable speed to advanced vector databases when it comes to speed on small to medium datasets.
18 | - __Vertically Scales__: Tinyvector stores all indexes in memory for fast querying. Very easy to scale up to 100 million+ vector dimensions without issue.
19 | - __Open Source__: MIT Licensed, free forever.
20 |
21 | ### Soon
22 | - __Powerful Queries__: Tinyvector is being upgraded with full SQL querying functionality, something missing from most other databases.
23 | - __Integrated Models__: Soon you won't have to bring your own vectors, just generate them on the server automatically. Will support SBert, Hugging Face models, OpenAI, Cohere, etc.
24 | - __Python/JS Client__: We'll add a comprehensive Python and Javascript package for easy integration with tinyvector in the next two weeks.
25 |
26 | ## Versions
27 |
28 | 🦀 tinyvector in Rust: [tinyvector-rs](https://github.com/m1guelpf/tinyvector-rs)
29 | 🐍 tinyvector in Python: [tinyvector](https://github.com/0hq/tinyvector)
30 |
31 | ## We're better than ...
32 |
33 | In most cases, most vector databases are overkill for something simple like:
34 | 1. Using embeddings to chat with your documents. Most document search is nowhere close to what you'd need to justify accelerating search speed with [HNSW](https://github.com/nmslib/hnswlib) or [FAISS](https://github.com/facebookresearch/faiss).
35 | 2. Doing search for your website or store. Unless you're selling 1,000,000 items, you don't need Pinecone.
36 | 3. Performing complex search queries on a very large database. Even if you have 2 million embeddings, this might still be the better option due to vector databases struggling with complex filtering. Tinyvector doesn't support metadata/filtering just yet, but it's very easy for you to add that yourself.
37 |
38 | ## Usage
39 |
40 | ```
41 | // Run the server manually:
42 | pip install -r requirements.txt
43 | python -m server
44 |
45 | // Run tests:
46 | pip install pytest pytest-mock
47 | pytest
48 | ```
49 |
50 | ## Embeddings?
51 |
52 | What are embeddings?
53 |
54 | > As simple as possible: Embeddings are a way to compare similar things, in the same way humans compare similar things, by converting text into a small list of numbers. Similar pieces of text will have similar numbers, different ones have very different numbers.
55 |
56 | Read OpenAI's [explanation](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings).
57 |
58 |
59 | ## Get involved
60 |
61 | tinyvector is going to be growing a lot (don't worry, will still be tiny). Feel free to make a PR and contribute. If you have questions, just mention [@willdepue](https://twitter.com/willdepue).
62 |
63 | Some ideas for first pulls:
64 |
65 | - Add metadata and allow querying/filtering. This is especially important since a lot of vector databases literally don't have a WHERE clause lol (or just an extremely weak one). Not a problem here. [Read more about this.](https://www.pinecone.io/learn/vector-search-filtering)
66 | - Rethinking SQLite and choosing something. NOSQL feels fitting for embeddings?
67 | - Add embedding functions for easy adding text (sentence transformers, OpenAI, Cohere, etc.)
68 | - Let's start GPU accelerating with a Pytorch index. GPUs are great at matmuls -> NN search with a fused kernel. Let's put 32 million vectors on a single GPU.
69 | - Help write unit and integration tests.
70 | - See all [active issues](https://github.com/0hq/tinyvector/issues)!
71 |
72 | ### Known Issues
73 | ```
74 | # Major bugs:
75 | Data corruption SQLite error? Stored vectors end up changing. Replicate by creating a table, inserting vectors, creating an index and then screwing around till an error happens. Dims end up unmatched (might be the blob functions or the norm functions most likely, but doesn't explain why the database is changing).
76 | PCA is not tested, neither is immutable Brute Force index.
77 | ```
78 |
79 |
80 | ## License
81 |
82 | [MIT](./LICENSE)
83 |
--------------------------------------------------------------------------------
/assets/TINYVECTORLOGO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/assets/TINYVECTORLOGO.png
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/assets/results.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | pythonpath = [
3 | "."
4 | ]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | blinker==1.6.2
2 | click==8.1.4
3 | Flask==2.3.2
4 | flask-pydantic-spec==0.4.5
5 | gunicorn==20.1.0
6 | inflection==0.5.1
7 | itsdangerous==2.1.2
8 | Jinja2==3.1.2
9 | MarkupSafe==2.1.3
10 | numpy==1.25.1
11 | pydantic==1.10.11
12 | PyJWT==2.7.0
13 | python-dotenv==1.0.0
14 | typing_extensions==4.7.1
15 | Werkzeug==2.3.6
16 | tinyvector==0.1.0
--------------------------------------------------------------------------------
/server/__main__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from functools import wraps
4 |
5 | import jwt
6 | import numpy as np
7 | from dotenv import load_dotenv
8 | from flask import Flask, g, jsonify, request
9 | from flask_pydantic_spec import FlaskPydanticSpec, Response
10 | from pydantic import BaseModel
11 |
12 | from server.models.model_response import ErrorMessage
13 | from server.utils.util_pydantic import pydantic_to_dict
14 | from tinyvector import DB
15 | from tinyvector.types.model_db import (DatabaseInfo, IndexCreationBody,
16 | IndexDeletionBody, ItemInsertionBody,
17 | TableCreationBody, TableDeletionBody,
18 | TableMetadata, TableQueryObject,
19 | TableQueryResult)
20 |
# if logs directory does not exist, create it
if not os.path.exists("logs"):
    os.makedirs("logs")

# if cache directory does not exist, create it
if not os.path.exists("cache"):
    os.makedirs("cache")

# Process-wide file logging; server.log lives in the logs/ directory
# created just above.
logging.basicConfig(
    filename="logs/server.log",
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s",
)
load_dotenv()

# If JWT_SECRET is not set, the application will run in debug mode
# (token_required skips authentication entirely when app.debug is true).
if os.getenv("JWT_SECRET") is None:
    os.environ["FLASK_ENV"] = "development"


app = Flask(__name__)
api = FlaskPydanticSpec(
    "Tiny Vector Database",
)

# Mirror app logs to stdout in addition to the file handler configured above.
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
stream_handler.setFormatter(formatter)
app.logger.addHandler(stream_handler)

# Database file location; override via the DATABASE_PATH environment variable.
DATABASE_PATH = os.environ.get("DATABASE_PATH", "cache/database.db")
53 |
54 |
def get_db():
    """Return the request-scoped database handle, opening it lazily on first use."""
    db = getattr(g, "db", None)
    if db is None:
        db = g.db = DB(DATABASE_PATH)
    return db
59 |
60 |
@app.teardown_appcontext
def close_db(exception):
    """Drop the request-scoped database handle when the app context tears down."""
    g.pop("db", None)
64 |
65 |
def token_required(f):
    """
    Decorator function to enforce token authentication for specific routes.

    Reads the raw token from the `Authorization` header and verifies it as an
    HS256 JWT against the JWT_SECRET environment variable. Responds 401 with a
    JSON message when the token is missing, expired, or invalid; raises if
    JWT_SECRET is unset outside debug mode.
    """

    @wraps(f)
    def decorator(*args, **kwargs):
        token = None

        # Debug mode bypasses authentication entirely (module setup forces
        # FLASK_ENV=development when JWT_SECRET is not configured).
        if app.debug:
            return f(*args, **kwargs)

        # NOTE(review): the header value is used verbatim; a "Bearer <token>"
        # prefix would fail validation — clients must send the bare JWT.
        if "Authorization" in request.headers:
            token = request.headers["Authorization"]

        if not token:
            app.logger.warning("Token is missing!")
            return jsonify({"message": "Token is missing!"}), 401

        JWT_SECRET = os.getenv("JWT_SECRET")
        if JWT_SECRET is None:
            raise Exception("JWT_SECRET is not set!")

        try:
            # Only signature and expiry are checked; the payload is discarded.
            _ = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            app.logger.error("Token is expired!")
            return jsonify({"message": "Token is expired!"}), 401
        except jwt.InvalidTokenError:
            app.logger.error("Token is invalid!")
            return jsonify({"message": "Token is invalid!"}), 401

        return f(*args, **kwargs)

    return decorator
101 |
102 |
class SuccessMessage(BaseModel):
    """Standard success response body.

    NOTE(review): duplicates server.models.model_response.SuccessMessage —
    consider importing that model instead of redefining it here.
    """

    status: str = "success"
105 |
106 |
@app.route("/status", methods=["GET"])
@api.validate(body=None, resp=Response(HTTP_200=SuccessMessage), tags=["API"])
@pydantic_to_dict
def status():
    """
    Route to check the status of the application.
    """
    app.logger.info("Status check performed")
    # Default field value is "success", matching the explicit kwarg it replaces.
    return SuccessMessage(), 200
116 |
117 |
@app.route("/info", methods=["GET"])
@token_required
@api.validate(
    body=None, resp=Response(HTTP_200=DatabaseInfo, HTTP_400=ErrorMessage), tags=["API"]
)
@pydantic_to_dict
def info():
    """
    Route to get information about the database.
    """
    try:
        database_info = get_db().info()
    except Exception as e:
        app.logger.error(f"Error while retrieving info: {str(e)}")
        return ErrorMessage(error=f"Error while retrieving info: {str(e)}"), 400
    app.logger.info("Info retrieved successfully")
    return database_info, 200
135 |
136 |
@app.route("/create_table", methods=["POST"])
@token_required
@api.validate(
    body=TableCreationBody,
    resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
    tags=["DB"],
)
@pydantic_to_dict
def create_table():
    """
    Route to create a table in the database.
    If use_uuid is True, the table will use UUIDs as IDs, and the IDs provided in the insert route are not allowed.
    If use_uuid is False, the table will require strings as IDs.
    """
    body = TableCreationBody(**request.get_json())

    try:
        get_db().create_table_and_index(body)
    except Exception as e:
        app.logger.error(f"Error while creating table {body.table_name}: {str(e)}")
        return ErrorMessage(error=str(e)), 400

    app.logger.info(f"Table {body.table_name} created successfully")
    return (
        SuccessMessage(status=f"Table {body.table_name} created successfully"),
        200,
    )
164 |
165 |
@app.route("/delete_table", methods=["DELETE"])
@api.validate(
    body=TableDeletionBody,
    resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
    tags=["DB"],
)
@token_required
@pydantic_to_dict
def delete_table():
    """
    Route to permanently delete a table and its data from the database.
    This will also delete the index associated with the table.
    """
    data = request.get_json()
    body = TableDeletionBody(**data)
    table_name = body.table_name
    try:
        get_db().delete_table(table_name)
        app.logger.info(f"Table {table_name} deleted successfully")
        return SuccessMessage(status=f"Table {table_name} deleted successfully"), 200
    except Exception as e:
        app.logger.error(f"Error while deleting table {table_name}: {str(e)}")
        # Bug fix: ErrorMessage declares an `error` field, not `status`; the
        # old keyword was silently dropped by pydantic, so clients received an
        # empty error string.
        return (
            ErrorMessage(error=f"Error while deleting table {table_name}: {str(e)}"),
            400,
        )
192 |
193 |
@app.route("/insert", methods=["POST"])
@api.validate(
    body=ItemInsertionBody,
    resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
    tags=["DB"],
)
@token_required
@pydantic_to_dict
def insert():
    """
    Route to insert an item into a table in the database.
    Requires a previously generated vector embedding of the right dimension (set when creating the table).
    If use_uuid was set to True when creating the table, the ID will be generated automatically. Providing an ID in the request will result in an error.
    If use_uuid was set to False, the ID must be provided as a string.
    Defer index update can be set to True to stop the index from being updated after the insert. This only works for brute force index, as other indexes can't be efficiently updated after creation.
    """
    data = request.get_json()
    body = ItemInsertionBody(**data)

    table_name = body.table_name
    id = body.id
    embedding = body.embedding
    content = body.content
    defer_index_update = body.defer_index_update
    try:
        embedding = np.array(embedding)
        get_db().insert(table_name, id, embedding, content, defer_index_update)
        app.logger.info(f"Item {id} inserted successfully into table {table_name}")
        return (
            SuccessMessage(
                status=f"Item {id} inserted successfully into table {table_name}"
            ),
            200,
        )
    except Exception as e:
        app.logger.error(f"Error while inserting item {id}: {str(e)}")
        # Bug fix: ErrorMessage declares an `error` field, not `status`; the
        # old keyword was silently dropped by pydantic, so clients received an
        # empty error string.
        return (
            ErrorMessage(error=f"Error while inserting item {id}: {str(e)}"),
            400,
        )
234 |
235 |
@app.route("/query", methods=["POST"])
@token_required
@api.validate(
    body=TableQueryObject,
    resp=Response(HTTP_200=TableQueryResult, HTTP_400=ErrorMessage),
    tags=["DB"],
)
@pydantic_to_dict
def query():
    """
    Route to perform a query on a table in the database.
    Requires a previously generated query vector embedding of the right dimension (set when creating the table).
    K is the number of items to return.
    """
    body = TableQueryObject(**request.get_json())

    table_name = body.table_name
    query = body.query
    k = body.k
    try:
        query = np.array(query)
        items = get_db().query(table_name, query, k)
        app.logger.info(f"Query performed successfully on table {table_name}")
        # Removed stray debug `print(items)` — results are already logged and
        # returned in the response body.
        return TableQueryResult(items=items), 200
    except Exception as e:
        app.logger.error(
            f"Error while performing query on table {table_name}: {str(e)}"
        )
        return (
            ErrorMessage(
                error=f"Error while performing query on table {table_name}: {str(e)}"
            ),
            400,
        )
271 |
272 |
@app.route("/create_index", methods=["POST"])
@api.validate(
    body=IndexCreationBody,
    tags=["DB"],
    resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
)
@token_required
@pydantic_to_dict
def create_index():
    """
    Route to create an index on a table in the database.

    Index type can be 'brute_force' or 'pca'. PCA is a dimensionality
    reduction index; n_components is the number of components to keep and
    must be less than the table's dimension (see README for details).
    If normalize is True, cosine similarity is used; otherwise dot product.
    If allow_index_updates is True the index is updated after each insert —
    brute force only, since other index types can't be efficiently updated
    after creation. To refresh a non-updatable index (PCA, others), the
    recommended approach is to delete it and create a new one.
    """
    body = IndexCreationBody(**request.get_json())
    table_name = body.table_name

    try:
        get_db().create_index(
            table_name,
            body.index_type,
            body.normalize,
            body.allow_index_updates,
            body.n_components,
        )
    except Exception as e:
        app.logger.error(f"Error while creating index on table {table_name}: {str(e)}")
        return (
            ErrorMessage(
                error=f"Error while creating index on table {table_name}: {str(e)}"
            ),
            400,
        )

    app.logger.info(f"Index created successfully on table {table_name}")
    return (
        SuccessMessage(status=f"Index created successfully on table {table_name}"),
        200,
    )
315 |
316 |
@app.route("/delete_index", methods=["DELETE"])
@token_required
@api.validate(
    body=IndexDeletionBody,
    tags=["DB"],
    resp=Response(HTTP_200=SuccessMessage, HTTP_400=ErrorMessage),
)
@pydantic_to_dict
def delete_index():
    """
    Route to delete an index from a table in the database.
    This will not delete the table or its data.
    """
    body = IndexDeletionBody(**request.get_json())
    table_name = body.table_name

    try:
        get_db().delete_index(table_name)
    except Exception as e:
        app.logger.error(f"Error while deleting index on table {table_name}: {str(e)}")
        return (
            ErrorMessage(
                error=f"Error while deleting index on table {table_name}: {str(e)}"
            ),
            400,
        )

    app.logger.info(f"Index deleted successfully on table {table_name}")
    return (
        SuccessMessage(status=f"Index deleted successfully on table {table_name}"),
        200,
    )
347 |
348 |
if __name__ == "__main__":
    app.logger.info("Starting server...")
    # Mount the generated API docs (/apidoc/redoc and /apidoc/swagger).
    api.register(app)
    PORT = 5234
    app.logger.info(
        f"\nSuccesfully Generated Documentation :) \n\n- Redoc: http://localhost:{PORT}/apidoc/redoc \n- Swagger: http://localhost:{PORT}/apidoc/swagger"
    )
    # app.run() blocks until the server shuts down. The stray `get_db()` call
    # that used to follow it ran only after shutdown, outside any application
    # context, where flask.g is unavailable — it has been removed.
    app.run(host="0.0.0.0", port=PORT, debug=True)
358 |
--------------------------------------------------------------------------------
/server/cache/README.md:
--------------------------------------------------------------------------------
1 | # Database File Cache
2 |
3 | Store database files, index caches, and more here.
4 |
--------------------------------------------------------------------------------
/server/logs/README.md:
--------------------------------------------------------------------------------
1 | # App Logs
2 |
3 | Logs are stored in .log format following the Python logger format.
4 | Logs are also piped to stdout for console viewing.
5 |
--------------------------------------------------------------------------------
/server/models/model_response.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 |
3 |
class SuccessMessage(BaseModel):
    """Response body for successful operations; `status` carries a human-readable message."""

    status: str = "success"
7 |
class ErrorMessage(BaseModel):
    """Response body for failed operations; `error` carries the failure description."""

    error: str = ""
--------------------------------------------------------------------------------
/server/utils/generate_token.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import secrets
4 |
5 | import jwt
6 | from dotenv import load_dotenv
7 |
8 | load_dotenv()
9 |
def generate_jwt_secret():
    """Create a new URL-safe random secret and append it to the local .env file."""
    # 64 bytes of entropy from the CSPRNG.
    new_secret = secrets.token_urlsafe(64)

    with open('.env', 'a') as env_file:
        env_file.write(f'\nJWT_SECRET={new_secret}')

    return new_secret
18 |
def generate_token(user_id):
    """
    Return a signed HS256 JWT for *user_id* with a 30-day expiry.

    The signing key is read from the JWT_SECRET environment variable
    (loaded from .env at import time).
    """
    # Use an aware UTC datetime: datetime.utcnow() is naive (and deprecated
    # in newer Pythons); PyJWT converts the aware datetime to a POSIX
    # timestamp for the `exp` claim, so the token semantics are unchanged.
    payload = {
        'user_id': user_id,
        'exp': datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=30)
    }

    token = jwt.encode(payload, os.getenv('JWT_SECRET'), algorithm='HS256')

    return token
29 |
if __name__ == '__main__':
    generate_jwt_secret()
    # Print the token so the operator can actually use it — previously the
    # generated token was computed and immediately discarded.
    print(generate_token('root'))
--------------------------------------------------------------------------------
/server/utils/util_pydantic.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 |
3 | from pydantic import BaseModel
4 |
5 |
def pydantic_to_dict(f):
    """Decorator: turn a `(BaseModel, status)` return pair into `(dict, status)`.

    Non-model results pass through unchanged.
    """

    @wraps(f)
    def decorated_function(*args, **kwargs):
        result, status_code = f(*args, **kwargs)
        payload = result.dict() if isinstance(result, BaseModel) else result
        return payload, status_code

    return decorated_function
15 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

# Read the long description with an explicit encoding and close the handle
# (the old bare open() leaked the file object and relied on platform encoding).
with open('README.md', encoding='utf-8') as readme:
    long_description = readme.read()

setup(
    name='tinyvector',
    version='0.1.1',
    author='Will DePue',
    author_email='will@depue.net',
    license='MIT',
    description='the tiny, least-dumb, speedy vector embedding database.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/0hq/tinyvector',
    # Bug fix: the distributable package in this repo is `tinyvector`, not
    # `core` — the old include pattern matched nothing, producing an empty
    # distribution.
    packages=find_packages(include=['tinyvector', 'tinyvector.*']),
    install_requires=[
        'numpy',
        'pydantic',
        'psutil',
        'scikit-learn',
    ],
)
21 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | Unit tests coming soon.
2 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/tests/__init__.py
--------------------------------------------------------------------------------
/tests/api/test_endpoint.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import time
5 | from unittest.mock import patch
6 |
7 | import jwt as pyjwt
8 | import pytest
9 |
10 | from server.__main__ import SuccessMessage, app
11 | from server.models.model_response import ErrorMessage
12 | from tinyvector.database import DB
13 | from tinyvector.types.model_db import (DatabaseInfo, ItemInsertionBody,
14 | TableCreationBody, TableDeletionBody,
15 | TableMetadata, TableQueryObject)
16 |
# Shared test configuration: one table spec reused by every test below.
# jwt_secret doubles as the HS256 signing key on both client and server side.
test_table = "test_table"
jwt_secret = "testing_config"
db_path = "test_endpoint_db.db"
new_table = TableMetadata(
    table_name=test_table,
    allow_index_updates=True,
    dimension=42,
    index_type="brute_force",
    is_index_active=True,
    normalize=True,
    use_uuid=False,
)
29 |
30 |
def generate_client(app):
    """
    Build a Flask test client plus an Authorization header holding a JWT
    signed with the test secret. The same secret is exported via the
    JWT_SECRET env var so the server-side check validates against it.
    """
    os.environ["JWT_SECRET"] = jwt_secret
    encoded_jwt = pyjwt.encode(
        {"payload": "random payload"}, jwt_secret, algorithm="HS256"
    )
    headers = {"Authorization": encoded_jwt}

    client = app.test_client()
    client.environ_base.update(headers)
    return client, headers
41 |
42 |
@pytest.fixture(scope="session")
def create_test_db():
    """Session-scoped fixture: a real on-disk DB with one pre-created table;
    the backing file is removed at session teardown."""
    test_db = DB(path=db_path)
    test_db.create_table_and_index(new_table)
    yield test_db
    os.remove(db_path)
50 |
def test_token_authentication(create_test_db, mocker):
    """
    Validates the token_required decorator on /info: a missing, malformed,
    and expired token must each be rejected with 401 and a descriptive
    JSON message. (Docstring corrected — it was copy-pasted from test_info.)
    """
    # Mock the DB object
    mocker.patch("server.__main__.get_db", return_value=create_test_db)

    test_client, headers = generate_client(app)

    # /info
    with test_client as client:
        # No Authorization header at all.
        response = client.get("/info")
        data = json.loads(response.data)

        assert response.status_code == 401
        assert data == {"message": "Token is missing!"}

        # A header that is not a valid JWT.
        response = client.get("/info", headers={"Authorization": "random_token"})
        data = json.loads(response.data)
        assert response.status_code == 401
        assert data == {"message": "Token is invalid!"}

        # A correctly signed JWT whose exp claim is already in the past.
        expired_jwt = pyjwt.encode(
            {"payload": "random payload", "exp": int(time.time()) - 10},
            jwt_secret,
            algorithm="HS256",
        )
        response = client.get("/info", headers={"Authorization": expired_jwt})
        data = json.loads(response.data)
        assert response.status_code == 401
        assert data == {"message": "Token is expired!"}
82 |
83 |
def test_info(create_test_db, mocker):
    """
    Validates that /info reports a snapshot of the DB state: table count,
    index count, and per-table metadata matching what the fixture created.
    """
    # Mock the DB object
    mocker.patch("server.__main__.get_db", return_value=create_test_db)

    test_client, headers = generate_client(app)

    # /info with valid auth headers.
    with test_client as client:
        response = client.get("/info", headers=headers)
        data = json.loads(response.data)

        assert response.status_code == 200

        # Expected snapshot: exactly the one table/index the session fixture
        # created ("test_table" backed by a mutable brute-force index).
        expected_response = DatabaseInfo(
            num_tables=1,
            num_indexes=1,
            tables={},
            indexes=["test_table brute_force_mutable"],
        )
        expected_response.tables[test_table] = new_table

        # Compare serialized forms so pydantic models and raw JSON line up.
        assert data == expected_response.dict()
110 |
111 |
def test_status_endpoint(create_test_db, mocker):
    """
    Validates that the /status health check responds with a success message.
    """
    # Patch the server-side DB accessor to hand back the session fixture DB.
    mocker.patch("server.__main__.get_db", return_value=create_test_db)

    client_ctx, _headers = generate_client(app)

    with client_ctx as client:
        response = client.get("/status")
        payload = json.loads(response.data)

        assert response.status_code == 200
        assert payload == SuccessMessage(status="success").dict()
130 |
131 |
def test_create_endpoint(create_test_db, mocker):
    """
    Validates that we can successfully create a table object and have the database be updated and populated with the new table.
    """
    # Mock the DB object
    mocker.patch("server.__main__.get_db", return_value=create_test_db)

    test_client, headers = generate_client(app)

    # /create_table followed by /info to observe the new table.
    with test_client as client:
        table_2 = TableCreationBody(
            table_name="table_testing_2",
            allow_index_updates=True,
            dimension=42,
            index_type="brute_force",
            is_index_active=True,
            normalize=True,
            use_uuid=False,
        )

        response = client.post(
            "/create_table", json=table_2.dict(), headers=headers
        )

        data = json.loads(response.data)

        assert response.status_code == 200
        assert (
            data
            == SuccessMessage(
                status=f"Table {table_2.table_name} created successfully"
            ).dict()
        )

        # NOTE: the session-scoped fixture already created one table, so the
        # totals below count both it and table_2.
        response = client.get("/info", headers=headers)
        data = json.loads(response.data)

        assert response.status_code == 200
        assert data["num_tables"] == 2
        assert data["num_indexes"] == 2
        assert data["tables"][table_2.table_name] == table_2.dict()

        # We then create a new db instance which loads metadata from scratch based on sqlite3 database and validate that it has the same data

        new_db = DB(path=db_path, debug=True)
        new_instance_data = new_db.info()
        assert new_instance_data.dict() == data
180 |
181 |
def test_delete_endpoint(create_test_db, mocker):
    """
    Validates the /delete_table endpoint: deleting an unknown table returns
    400 with an error message; deleting an existing table returns 200 and is
    reflected both in /info and in a freshly-loaded DB instance.
    """
    # Mock the DB object
    mocker.patch("server.__main__.get_db", return_value=create_test_db)

    test_client, headers = generate_client(app)

    with test_client as client:
        # Order-dependent: test_create_endpoint (same session fixture) already
        # added "table_testing_2"; confirm that state before deleting.

        response = client.get("/info", headers=headers)
        body = json.loads(response.data)
        assert response.status_code == 200
        assert body["num_tables"] == 2
        assert body["num_indexes"] == 2

        # Try to delete an invalid table
        body = TableDeletionBody(
            table_name="fake_table",
        )

        response = client.delete(
            "/delete_table", json=body.dict(), headers=headers
        )
        data = json.loads(response.data)
        assert response.status_code == 400
        assert (
            data
            == ErrorMessage(
                status=f"Error while deleting table fake_table: Table fake_table does not exist"
            ).dict()
        )

        # Try to delete a valid table
        body = TableDeletionBody(
            table_name="table_testing_2",
        )
        response = client.delete(
            "/delete_table", json=body.dict(), headers=headers
        )
        data = json.loads(response.data)
        assert response.status_code == 200
        assert (
            data
            == SuccessMessage(
                status=f"Table {body.table_name} deleted successfully"
            ).dict()
        )

        # Validate that the table has been deleted
        response = client.get("/info", headers=headers)
        data = json.loads(response.data)

        assert response.status_code == 200
        assert data["num_tables"] == 1
        assert data["num_indexes"] == 1

        # We then create a new db instance which loads metadata from scratch based on sqlite3 database and validate that it has the same data

        new_db = DB(path=db_path, debug=True)
        new_instance_data = new_db.info()
        assert new_instance_data.dict() == data
246 |
247 |
def test_insert_endpoint(create_test_db, mocker):
    """
    Validates the /insert endpoint: every inserted item is acknowledged and
    all of them are subsequently retrievable through /query.
    """
    # Mock the DB object
    mocker.patch("server.__main__.get_db", return_value=create_test_db)

    test_client, headers = generate_client(app)

    with test_client as client:
        values = 200
        for i in range(values):
            # Each item carries a random 42-dim embedding, matching the
            # dimension configured on the fixture table.
            item = ItemInsertionBody(
                table_name=test_table,
                id=f"Item {i}",
                embedding=[random.randint(0, values) for i in range(42)],
                content=f"Item {i} content",
                defer_index_update=False,
            )
            response = client.post("/insert", json=item.dict(), headers=headers)
            data = json.loads(response.data)
            assert response.status_code == 200
            assert (
                data["status"]
                == f"Item Item {i} inserted successfully into table {test_table}"
            )

        # We now validate that the table holds all `values` items by asking
        # /query for the top-k with k == values.
        body = TableQueryObject(
            table_name=test_table, query=[0 for i in range(42)], k=values
        )
        response = client.post("/query", json=body.dict(), headers=headers)
        data = json.loads(response.data)

        assert response.status_code == 200
        assert len(data["items"]) == values
284 |
--------------------------------------------------------------------------------
/tests/db/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/tests/db/__init__.py
--------------------------------------------------------------------------------
/tests/db/test_create_db.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sqlite3
3 |
4 | import pytest
5 |
6 | from tinyvector.database import DB
7 | from tinyvector.types.model_db import (DatabaseInfo, TableCreationBody,
8 | TableMetadata)
9 |
10 |
@pytest.fixture
def create_db(tmp_path):
    # Function-scoped fixture: builds a fresh DB inside pytest's tmp dir with
    # one brute-force-indexed table, then removes the file on teardown.
    db_path = tmp_path / "test.db"
    table_name = "test_table"

    db_config = TableCreationBody(
        table_name=table_name,
        allow_index_updates=False,
        dimension=42,
        index_type="brute_force",
        is_index_active=True,
        normalize=True,
        use_uuid=False,
    )
    test_db = DB(db_path)
    test_db.create_table_and_index(db_config)
    yield test_db, table_name, db_config, db_path
    os.remove(db_path)
29 |
30 |
def test_create_table_again_raises_error(create_db):
    """A second create of an already-registered table must raise ValueError."""
    db, name, config, _path = create_db
    with pytest.raises(ValueError, match=f"Table {name} already exists"):
        db.create_table_and_index(config)
35 |
36 |
def test_loads_correct_metadata_on_startup(create_db):
    """
    A brand-new DB instance pointed at the same file must rebuild identical
    table metadata purely from the persisted table_metadata rows.
    """
    test_db, table_name, db_config, db_path = create_db
    # Second instance on the same file exercises the _load_metadata path.
    test_db_2 = DB(db_path)
    info = test_db_2.info()
    # (Removed a leftover debug print of the info snapshot.)
    assert table_name in info.tables, f"Table {table_name} not found"
    assert TableMetadata(**db_config.dict()) == info.tables[table_name]
    assert info.num_tables == 1, "Number of tables should be 1"
45 |
46 |
def test_delete_table(create_db):
    """Deleting a table must remove it from metadata and the table count."""
    db, table_name, _config, _path = create_db
    db.delete_table(table_name)
    snapshot = db.info()
    assert table_name not in snapshot.tables, f"Table {table_name} should not be found"
    assert snapshot.num_tables == 0, "Number of tables should be 0"
53 |
54 |
def test_create_index(create_db):
    # Placeholder: index re-creation after delete_index is not covered yet.
    test_db, table_name, db_config, db_path = create_db
    # TODO
    pass
59 |
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0hq/tinyvector/3d216419c28bd59aa042cbe836dc7a1a60e8a049/tests/models/__init__.py
--------------------------------------------------------------------------------
/tests/models/test_table_metadata.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydantic import ValidationError
3 |
4 | from tinyvector.types.model_db import TableMetadata
5 |
6 |
def test_invalid_table_metadata():
    """
    TableMetadata must reject allow_index_updates=True combined with the
    "pca" index type, since PCA indexes do not support updates.
    """
    # The exception object itself is not inspected, so no `as exc_info`
    # binding (the previous one was unused).
    with pytest.raises(ValidationError):
        TableMetadata(
            table_name="example",
            allow_index_updates=True,
            dimension=3,
            index_type="pca",
            is_index_active=True,
            normalize=True,
            use_uuid=1,  # pydantic coerces truthy ints to bool
        )
21 |
--------------------------------------------------------------------------------
/tinyvector/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Will DePue
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tinyvector/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
--------------------------------------------------------------------------------
/tinyvector/__init__.py:
--------------------------------------------------------------------------------
1 | from .database import DB
2 |
--------------------------------------------------------------------------------
/tinyvector/database.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | import uuid
3 | from abc import ABC, abstractmethod
4 | from typing import cast
5 |
6 | import numpy as np
7 | import psutil
8 | from sklearn.decomposition import PCA
9 |
10 | from .types.model_db import DatabaseInfo, TableCreationBody, TableMetadata
11 |
12 |
def norm_np(datum):
    """
    L2-normalize a NumPy array.

    A 1-D input is normalized as a single vector; 2-D input is normalized
    row-wise. A small epsilon keeps zero vectors from dividing by zero.
    """
    if datum.ndim == 1:
        reduce_axis = None
    else:
        reduce_axis = 1
    lengths = np.linalg.norm(datum, axis=reduce_axis, keepdims=True)
    return datum / (lengths + 1e-6)
19 |
20 |
def norm_python(datum):
    """
    L2-normalize a Python list or ndarray and return plain Python lists
    (row-wise for 2-D input, whole-vector for 1-D).
    """
    arr = np.array(datum)
    reduce_axis = None if arr.ndim == 1 else 1
    normalized = arr / (np.linalg.norm(arr, axis=reduce_axis, keepdims=True) + 1e-6)
    return normalized.tolist()
28 |
29 |
def array_to_blob(array):
    """
    Serialize a NumPy array to a binary blob of raw float32 bytes.

    blob_to_array always decodes float32, so any other dtype must be cast
    before serialization; previously e.g. float64 arrays were written
    verbatim and silently decoded as garbage on the next startup.

    Raises:
        ValueError: if `array` is not a numpy array.
    """
    if not isinstance(array, np.ndarray):
        raise ValueError("Expected a numpy array")
    # astype(copy=False) is a no-op for arrays that are already float32.
    return array.astype(np.float32, copy=False).tobytes()
38 |
39 |
def blob_to_array(blob):
    """
    Deserialize a binary blob of raw float32 bytes into a NumPy array.
    """
    return np.frombuffer(blob, dtype=np.float32)
46 |
47 |
class AbstractIndex(ABC):
    """
    Abstract base class shared by every index implementation.

    Holds the identifying attributes (table name, index type string, vector
    count, dimensionality) and declares the add/search interface.
    """

    def __init__(self, table_name, type, num_vectors, dimension):
        self.table_name = table_name
        self.type = type
        self.num_vectors = num_vectors
        self.dimension = dimension

    def __len__(self):
        # An index's length is the number of vectors it holds.
        return self.num_vectors

    def __str__(self):
        # E.g. "my_table brute_force_mutable" — surfaced by DB.info().
        return f"{self.table_name} {self.type}"

    def __repr__(self):
        return str(self)

    @abstractmethod
    def add_vector(self, id, embedding):
        """Add a vector with the specified ID and embedding to the index."""

    @abstractmethod
    def get_similarity(self, query, k):
        """Return the top-k entries most similar to the given query vector."""
81 |
82 |
class BruteForceIndexImmutable(AbstractIndex):
    """
    Brute-force index implementation that is immutable (does not support vector addition).
    Faster than the mutable version because it uses NumPy arrays instead of Python lists.
    Search complexity is O(n + k log k) where n is the number of vectors in the index.
    """

    def __init__(self, table_name, dimension, normalize, embeddings, ids):
        """
        Args:
            table_name: name of the backing SQL table.
            dimension: expected dimensionality of every vector.
            normalize: whether to L2-normalize stored vectors and queries.
            embeddings: sequence of vectors (list of arrays or 2-D array).
            ids: parallel sequence of row IDs.
        """
        super().__init__(
            table_name, "brute_force_immutable", len(embeddings), dimension
        )
        # Always materialize a NumPy matrix: previously a plain list was kept
        # when normalize was False, which broke the `.T` in get_similarity.
        matrix = np.asarray(embeddings)
        self.embeddings = norm_np(matrix) if normalize else matrix
        self.ids = ids
        self.normalize = normalize

    def add_vector(self, id, embedding):
        """
        Add a vector to the index (not supported in the immutable version).
        """
        raise NotImplementedError(
            "This index does not support vector addition. Please use BruteForceIndexMutable if updates are required."
        )

    def get_similarity(self, query, k):
        """
        Return the top-k entries most similar (by dot product) to `query`.
        """
        if len(query) != self.dimension:
            raise ValueError(
                f"Expected query of dimension {self.dimension}, got {len(query)}"
            )

        query_normalized = norm_np(query) if self.normalize else query
        # (Removed a leftover `pdb.set_trace()` debugger breakpoint here.)
        scores = query_normalized @ self.embeddings.T
        # argpartition requires kth < len(scores); clamp when k is too large.
        arg_k = k if k < len(scores) else len(scores) - 1
        partitioned_indices = np.argpartition(-scores, kth=arg_k)[:k]
        # Order the selected candidates by descending score.
        top_k_indices = partitioned_indices[np.argsort(-scores[partitioned_indices])]

        return [
            {"id": self.ids[i], "embedding": self.embeddings[i], "score": scores[i]}
            for i in top_k_indices
        ]
129 |
130 |
class BruteForceIndexMutable(AbstractIndex):
    """
    Brute-force index implementation that is mutable (supports vector addition).
    Slower than the immutable version because it uses Python lists instead of NumPy arrays.
    Search complexity is O(n + k log k) where n is the number of vectors in the index.
    """

    def __init__(self, table_name, dimension, normalize, embeddings, ids):
        """
        Args:
            table_name: name of the backing SQL table.
            dimension: expected dimensionality of every vector.
            normalize: whether to L2-normalize stored vectors and queries.
            embeddings: sequence of vectors (plain lists or NumPy arrays).
            ids: parallel sequence of row IDs.
        """
        super().__init__(table_name, "brute_force_mutable", len(embeddings), dimension)

        if len(embeddings) > 0 and dimension != len(embeddings[0]):
            raise ValueError(
                f"Expected embeddings of dimension {self.dimension}, got {len(embeddings[0])}"
            )
        # (Removed a leftover debug print of the whole embedding matrix.)
        # np.asarray(...).tolist() accepts both plain lists and ndarrays; the
        # previous direct `.tolist()` crashed on list input when normalize
        # was False.
        self.embeddings = (
            norm_python(embeddings) if normalize else np.asarray(embeddings).tolist()
        )
        self.ids = ids
        self.normalize = normalize

    def add_vector(self, id, embedding):
        """
        Add a single vector to the index.

        Raises:
            ValueError: if the embedding's dimensionality does not match.
        """
        if len(embedding) != self.dimension:
            raise ValueError(
                f"Expected embedding of dimension {self.dimension}, got {len(embedding)}"
            )
        self.ids.append(id)
        self.embeddings.append(
            norm_python(embedding) if self.normalize else np.asarray(embedding).tolist()
        )
        self.num_vectors += 1

    def get_similarity(self, query, k):
        """
        Return the top-k entries most similar (by dot product) to `query`.
        """
        if len(query) != self.dimension:
            raise ValueError(
                f"Expected query of dimension {self.dimension}, got {len(query)}"
            )

        query_normalized = norm_np(query) if self.normalize else query
        dataset_vectors = np.array(self.embeddings)
        scores = query_normalized @ dataset_vectors.T
        # argpartition requires kth < len(scores); clamp when k is too large.
        arg_k = k if k < len(scores) else len(scores) - 1
        partitioned_indices = np.argpartition(-scores, kth=arg_k)[:k]
        top_k_indices = partitioned_indices[np.argsort(-scores[partitioned_indices])]

        return [
            {"id": self.ids[i], "embedding": self.embeddings[i], "score": scores[i]}
            for i in top_k_indices
        ]
184 |
185 |
class PCAIndex(AbstractIndex):
    """
    Index that reduces dimensionality with PCA before brute-force search.

    Immutable (vectors cannot be added after construction) and the fastest
    index here: searching happens in the reduced n_components-dimensional
    space, at some cost in accuracy versus plain brute force — often
    surprisingly little for many applications. Fitting the PCA model can make
    startup slow for large datasets.
    Search complexity is O(n + k log k) for n indexed vectors.
    """

    def __init__(self, table_name, dimension, n_components, normalize, embeddings, ids):
        # The parent stores name/type/count; it sees the *reduced* dimension,
        # while the original dimension is kept separately for query checks.
        super().__init__(table_name, "pca", len(embeddings), n_components)
        self.ids = ids
        self.pca = PCA(n_components)
        self.original_dimension = dimension
        reduced = self.pca.fit_transform(embeddings)
        self.embeddings = norm_np(reduced) if normalize else reduced
        self.normalize = normalize

    def add_vector(self, id, embedding):
        """Unsupported: a fitted PCA projection cannot absorb new vectors."""
        raise NotImplementedError(
            "This index does not support vector addition. Please use BruteForceIndexMutable if updates are required."
        )

    def get_similarity(self, query, k):
        """
        Return the top-k entries most similar to `query`, after projecting
        the query through the fitted PCA model.
        """
        if len(query) != self.original_dimension:
            raise ValueError(
                f"Expected query of dimension {self.original_dimension}, got {len(query)}"
            )

        projected = self.pca.transform([query])[0]
        search_query = norm_np(projected) if self.normalize else projected

        scores = search_query @ self.embeddings.T
        cutoff = k if k < len(scores) else len(scores) - 1
        candidates = np.argpartition(-scores, kth=cutoff)[:k]
        ranked = candidates[np.argsort(-scores[candidates])]

        return [
            {"id": self.ids[i], "embedding": self.embeddings[i], "score": scores[i]}
            for i in ranked
        ]
239 |
240 |
class DB:
    """
    Database class for managing tables and indexes.

    Wraps a single SQLite connection. Per-table metadata lives in the
    `table_metadata` SQL table and is mirrored in `self.table_metadata`;
    in-memory indexes are rebuilt from the stored vectors on startup.

    NOTE(review): table names are interpolated into SQL with f-strings
    (identifiers cannot be bound as parameters in sqlite), so callers must
    not pass untrusted table names to these methods.
    """

    def __init__(self, path, debug=False):
        """
        Open (or create) the SQLite database at `path` and load metadata.

        check_same_thread=False lets server worker threads share the one
        connection; sqlite serializes access internally.
        """
        self.conn = sqlite3.connect(path, check_same_thread=False)
        self.c = self.conn.cursor()
        self.table_metadata = {}
        self.indexes = {}

        self.path = path
        self.debug = debug
        self._init_db()

    def info(self):
        """
        Return a DatabaseInfo snapshot of all tables and active indexes.
        """
        res = DatabaseInfo(
            num_indexes=len(self.indexes),
            num_tables=len(self.table_metadata),
            tables={},
            indexes=[str(index) for index in self.indexes.values()],
        )
        # Assigned after construction so the live metadata mapping is shared
        # rather than re-validated/copied by pydantic.
        res.tables = self.table_metadata
        return res

    def _init_db(self):
        """
        Create the metadata table if needed, then load existing metadata.
        """
        self.c.execute(
            "CREATE TABLE IF NOT EXISTS table_metadata (table_name TEXT PRIMARY KEY, dimension INTEGER, index_type TEXT, normalize BOOLEAN, allow_index_updates BOOLEAN, is_active BOOLEAN, use_uuid BOOLEAN)"
        )
        self.conn.commit()
        self._load_metadata()

    def _load_metadata(self):
        """
        Load table metadata from the database.
        At the moment, all indexes are rebuilt on startup. This will be changed in the future.
        """
        self.table_metadata = {}
        select_query = "SELECT * FROM table_metadata"

        # Column order matches the CREATE TABLE statement in _init_db.
        for row in self.c.execute(select_query).fetchall():
            (
                table_name,
                dimension,
                index_type,
                normalize,
                allow_index_updates,
                is_index_active,
                use_uuid,
            ) = row
            self.table_metadata[table_name] = TableMetadata(
                table_name=table_name,
                dimension=dimension,
                index_type=index_type,
                normalize=normalize,
                allow_index_updates=allow_index_updates,
                is_index_active=is_index_active,
                use_uuid=use_uuid,
            )

            if is_index_active:
                try:
                    # `dimension` doubles as n_components so a rebuilt PCA
                    # index keeps the table's full dimensionality.
                    self.create_index(
                        table_name,
                        index_type,
                        normalize,
                        allow_index_updates,
                        dimension,
                    )

                except Exception as e:
                    # Best-effort startup: a table whose index cannot be
                    # rebuilt is dropped from the in-memory view entirely.
                    print(
                        f"Error loading index for table {table_name}: {e}. Clearing index..."
                    )
                    del self.table_metadata[table_name]
                    if self.indexes.get(table_name) is not None:
                        del self.indexes[table_name]

    def create_index(
        self,
        table_name,
        index_type,
        normalize,
        allow_index_updates=None,
        n_components=None,
    ):
        """
        Build an in-memory index over all vectors currently stored in
        `table_name` and persist the index settings in table_metadata.

        Args:
            table_name: existing table to index.
            index_type: "brute_force" or "pca".
            normalize: L2-normalize stored vectors and queries.
            allow_index_updates: brute_force only — pick the mutable index.
            n_components: required for "pca"; target dimensionality.

        Raises:
            MemoryError: if available system memory is below 10%.
            ValueError: unknown table or index type, or index already exists.
        """
        # Refuse to build when the host is nearly out of memory: the whole
        # index is materialized in RAM.
        if psutil.virtual_memory().available < 0.1 * psutil.virtual_memory().total:
            raise MemoryError("System is running out of memory")

        def get_data(select_query):
            # Rows are (id, embedding_blob, content); content is unused here.
            self.c.execute(select_query)
            rows = self.c.fetchall()
            ids = [row[0] for row in rows]
            embeddings = [blob_to_array(row[1]) for row in rows]
            return ids, embeddings

        if self.table_metadata.get(table_name) is None:
            raise ValueError(
                f"Table {table_name} does not exist. Create the table first."
            )

        if self.indexes.get(table_name) is not None:
            raise ValueError(
                f"Index for table {table_name} already exists. Delete the index first if you want to rebuild it."
            )

        # NOTE(review): attribute access below assumes a TableMetadata-style
        # entry; tables created via create_table() store plain dicts — confirm
        # those are upgraded before being indexed through this path.
        dimension = self.table_metadata[table_name].dimension
        if index_type == "brute_force":
            ids, embeddings = get_data(f"SELECT * FROM {table_name}")
            if allow_index_updates:
                self.indexes[table_name] = BruteForceIndexMutable(
                    table_name, dimension, normalize, embeddings, ids
                )
            else:
                self.indexes[table_name] = BruteForceIndexImmutable(
                    table_name, dimension, normalize, embeddings, ids
                )
        elif index_type == "pca":
            if n_components is None:
                raise ValueError("n_components must be specified for PCA index")
            ids, embeddings = get_data(f"SELECT * FROM {table_name}")
            self.indexes[table_name] = PCAIndex(
                table_name, dimension, n_components, normalize, embeddings, ids
            )
        else:
            raise ValueError(f"Unknown index type {index_type}")

        # Mirror the new index settings both in memory and on disk.
        self.table_metadata[table_name].index_type = index_type
        self.table_metadata[table_name].is_index_active = True
        self.table_metadata[table_name].allow_index_updates = allow_index_updates
        self.table_metadata[table_name].normalize = normalize
        self.c.execute(
            "UPDATE table_metadata SET index_type = ?, is_active = ?, allow_index_updates = ?, normalize = ? WHERE table_name = ?",
            (index_type, True, allow_index_updates, normalize, table_name),
        )
        self.conn.commit()

    def create_table_and_index(self, table_config: TableCreationBody):
        """
        Create a new table plus its index from a single creation request.

        Raises:
            ValueError: if the table already exists.
        """
        table_name = table_config.table_name

        # We first validate that the table does not exist
        if table_name in self.table_metadata:
            raise ValueError(f"Table {table_name} already exists")

        # We create the table
        self.c.execute(
            f"CREATE TABLE {table_name} (id TEXT PRIMARY KEY, embedding BLOB, content TEXT)"
        )

        # Persist and register the metadata first: create_index() requires
        # the entry to exist. (A previous version also assigned the
        # TableMetadata *class* object into the mapping here — dead code.)
        self.update_table_metadata(table_config)

        # table_config.dimension doubles as n_components for PCA indexes.
        self.create_index(
            table_config.table_name,
            table_config.index_type,
            table_config.normalize,
            table_config.allow_index_updates,
            table_config.dimension,
        )

    def update_table_metadata(self, table_config: TableCreationBody):
        """
        Insert the table's metadata row and mirror it in memory.
        """
        query = """
        INSERT INTO table_metadata (
            table_name, dimension, index_type, normalize, allow_index_updates, is_active, use_uuid
        ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        values = (
            table_config.table_name,
            table_config.dimension,
            table_config.index_type,
            table_config.normalize,
            table_config.allow_index_updates,
            table_config.is_index_active,
            table_config.use_uuid,
        )
        # The connection context manager commits on success.
        with self.conn:
            cursor = self.conn.cursor()
            cursor.execute(query, values)

        # Store a TableMetadata instance, matching what _load_metadata builds
        # on startup (previously the raw TableCreationBody was stored, so the
        # in-memory type differed depending on how the table was registered).
        self.table_metadata[table_config.table_name] = TableMetadata(
            **table_config.dict()
        )
        return

    def create_table(self, table_name, dimension, use_uuid):
        """
        Create a new table without an index (index columns stay NULL).
        """
        if self.table_metadata.get(table_name) is not None:
            raise ValueError(f"Table {table_name} already exists")
        self.c.execute(
            f"CREATE TABLE {table_name} (id TEXT PRIMARY KEY, embedding BLOB, content TEXT)"
        )
        self.c.execute(
            "INSERT INTO table_metadata VALUES (?, ?, ?, ?, ?, ?, ?)",
            (table_name, dimension, None, None, None, False, use_uuid),
        )
        self.conn.commit()

        # Index-less tables cannot be represented by the TableMetadata model
        # (its fields are non-optional), so a plain dict is stored here;
        # delete_index copes with both representations.
        self.table_metadata[table_name] = {
            "dimension": dimension,
            "index_type": None,
            "normalize": None,
            "allow_index_updates": None,
            "is_index_active": False,
            "use_uuid": use_uuid,
        }

    def delete_table(self, table_name):
        """
        Drop a table, its metadata row, and any in-memory index.

        Raises:
            ValueError: if the table does not exist.
        """
        if self.table_metadata.get(table_name) is None:
            raise ValueError(f"Table {table_name} does not exist")
        self.c.execute(f"DROP TABLE {table_name}")
        self.c.execute("DELETE FROM table_metadata WHERE table_name = ?", (table_name,))
        self.conn.commit()
        if self.indexes.get(table_name) is not None:
            del self.indexes[table_name]
        self.table_metadata.pop(table_name)

    def delete_index(self, table_name):
        """
        Delete a table's in-memory index and clear its index metadata.

        Raises:
            ValueError: if no index exists for the table.
        """
        if self.indexes.get(table_name) is None:
            raise ValueError(f"Index for table {table_name} does not exist")
        del self.indexes[table_name]

        # Metadata entries are TableMetadata models on the normal path but
        # plain dicts for tables made via create_table(); the previous code
        # always subscripted, which raised TypeError on pydantic models.
        cleared = {
            "is_index_active": False,
            "index_type": None,
            "normalize": None,
            "allow_index_updates": None,
        }
        meta = self.table_metadata[table_name]
        if isinstance(meta, dict):
            meta.update(cleared)
        else:
            for field_name, value in cleared.items():
                setattr(meta, field_name, value)

        self.c.execute(
            "UPDATE table_metadata SET index_type = NULL, normalize = NULL, allow_index_updates = NULL, is_active = 0 WHERE table_name = ?",
            (table_name,),
        )
        self.conn.commit()

    def insert(self, table_name, id, embedding, content, defer_index_update=False):
        """
        Insert a vector row into a table and (optionally) its live index.

        Args:
            table_name: target table (must exist).
            id: row ID; must be None for use_uuid tables.
            embedding: numpy array of the vector (see array_to_blob).
            content: payload text stored alongside the vector.
            defer_index_update: skip the in-memory index update when True.

        Raises:
            MemoryError: if available system memory is below 10%.
            ValueError: unknown table, or an ID-policy violation.
        """
        if psutil.virtual_memory().available < 0.1 * psutil.virtual_memory().total:
            raise MemoryError("System is running out of memory")

        # First validate that table exists
        if table_name not in self.table_metadata:
            raise ValueError(f"Table {table_name} does not exist")

        table_config: TableMetadata = self.table_metadata[table_name]

        if table_config.use_uuid and id is not None:
            raise ValueError(
                "This table uses auto-generated UUIDs. Do not provide an ID."
            )
        elif table_config.use_uuid:
            id = str(uuid.uuid4())  # Generate a unique ID using the uuid library.
        elif id is None:
            raise ValueError("This table uses custom IDs. Please provide an ID.")

        insert_query = f"INSERT INTO {table_name} VALUES (?, ?, ?)"
        self.c.execute(insert_query, (id, array_to_blob(embedding), content))
        self.conn.commit()

        # Keep the live index in sync unless the caller defers (bulk loads)
        # or the index type does not support updates.
        if (
            table_config.is_index_active is True
            and table_config.allow_index_updates is True
            and defer_index_update is False
        ):
            self.indexes[table_name].add_vector(id, embedding)

    def query(self, table_name, query, k):
        """
        Return the top-k most similar items from `table_name` for `query`.

        Raises:
            ValueError: if the table or its index does not exist.
        """
        if self.table_metadata.get(table_name) is None:
            raise ValueError(f"Table {table_name} does not exist")

        table_config = self.table_metadata[table_name]
        table_config = cast(TableMetadata, table_config)

        if table_config.is_index_active is False:
            raise ValueError(f"Index for table {table_name} does not exist")

        items = self.indexes[table_name].get_similarity(np.array(query), k)

        # Guard the empty case: "IN ()" is a sqlite syntax error.
        if not items:
            return items

        # Fetch the content for all hits in one round trip.
        ids = [item["id"] for item in items]
        placeholder = ", ".join(["?"] * len(ids))  # Generate placeholders for query
        self.c.execute(
            f"SELECT id, content FROM {table_name} WHERE id IN ({placeholder})", ids
        )
        content_dict = {id: content for id, content in self.c.fetchall()}

        # Add the content to the items
        for item in items:
            item["content"] = content_dict.get(item["id"])

        return items
566 |
--------------------------------------------------------------------------------
/tinyvector/types/model_db.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Dict, Optional
3 |
4 | from pydantic import BaseModel, validator
5 |
6 |
class TableCreationBody(BaseModel):
    # Request body for /create_table; defaults describe a normalized,
    # mutable brute-force index built immediately.
    table_name: str
    dimension: int
    use_uuid: bool = False
    normalize: bool = True
    index_type: str = "brute_force"
    allow_index_updates: bool = True
    is_index_active: bool = True
15 |
16 |
class TableDeletionBody(BaseModel):
    # Request body for /delete_table.
    table_name: str
19 |
20 |
class IndexType(str, Enum):
    # str mixin makes members compare equal to their plain string values.
    BRUTE_FORCE = "brute_force"
    PCA = "pca"
24 |
25 |
class IndexDeletionBody(BaseModel):
    # Request body for deleting a table's index.
    table_name: str
28 |
29 |
class IndexCreationBody(BaseModel):
    # Request body for building an index on an existing table.
    table_name: str
    index_type: IndexType
    normalize: bool = True
    allow_index_updates: bool = False
    # Required when index_type is "pca" (target dimensionality).
    n_components: Optional[int] = None
36 |
37 |
class ItemInsertionBody(BaseModel):
    """Request body for /insert."""

    table_name: str
    # Must be omitted for tables configured with use_uuid.
    id: Optional[str] = None
    # Embeddings are real-valued; the previous list[int] rejected float
    # payloads even though vectors are stored as float32. Ints still coerce.
    embedding: list[float]
    content: Optional[str] = None
    # When True, the row is written to SQL but the live index is not updated.
    defer_index_update: bool = False
44 |
45 |
class TableQueryObject(BaseModel):
    """Request body for /query."""

    table_name: str
    # Query vectors are real-valued; the previous list[int] rejected float
    # queries. Ints still coerce to float.
    query: list[float]
    # Number of nearest items to return.
    k: int
50 |
51 |
class TableQueryResultInstance(BaseModel):
    """One scored hit returned by /query."""

    # Content may be absent: rows can be inserted with content=None, and
    # DB.query attaches None when no content row is found — the previous
    # bare `str` made such results fail validation.
    content: Optional[str] = None
    embedding: list[float] | list[int]
    id: str
    score: float
57 |
58 |
class TableQueryResult(BaseModel):
    # Response body for /query: hits ordered by descending score.
    items: list[TableQueryResultInstance]
61 |
62 |
class TableMetadata(BaseModel):
    # Canonical per-table configuration, mirrored in the `table_metadata`
    # SQL table (see DB._init_db / DB._load_metadata).
    table_name: str
    allow_index_updates: bool
    dimension: int
    index_type: IndexType
    is_index_active: bool
    normalize: bool
    use_uuid: bool

    @validator('*', check_fields=False)
    def check_index_update_and_type(cls, value, values, config, field):
        # Wildcard validator: runs once per field. The cross-field condition
        # can only be evaluated once both allow_index_updates and index_type
        # appear in `values` (i.e. have already been validated), hence the
        # membership guard.
        if 'allow_index_updates' in values and 'index_type' in values:
            if values['allow_index_updates'] and values['index_type'] == "pca":
                raise ValueError(
                    "PCA index does not support updates. Please set allow_index_updates=False."
                )
        return value
80 |
81 |
class DatabaseInfo(BaseModel):
    # Snapshot returned by DB.info(): "<table_name> <index_type>" strings
    # for each active index, plus the per-table metadata mapping.
    indexes: list[str]
    num_indexes: int
    num_tables: int
    tables: Dict[str, TableMetadata]
87 |
--------------------------------------------------------------------------------