├── .github
│   └── workflows
│       └── go.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── collection.go
├── collection_test.go
├── db.go
├── db_test.go
├── document.go
├── document_test.go
├── embed_cohere.go
├── embed_compat.go
├── embed_ollama.go
├── embed_ollama_test.go
├── embed_openai.go
├── embed_openai_test.go
├── embed_vertex.go
├── examples
│   ├── README.md
│   ├── minimal
│   │   ├── README.md
│   │   ├── go.mod
│   │   └── main.go
│   ├── rag-wikipedia-ollama
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── dbpedia_sample.jsonl
│   │   ├── go.mod
│   │   ├── go.sum
│   │   ├── llm.go
│   │   └── main.go
│   ├── s3-export-import
│   │   ├── README.md
│   │   ├── go.mod
│   │   ├── go.sum
│   │   └── main.go
│   ├── semantic-search-arxiv-openai
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── go.mod
│   │   └── main.go
│   └── webassembly
│       ├── README.md
│       └── index.html
├── fixtures_test.go
├── go.mod
├── go.sum
├── persistence.go
├── persistence_test.go
├── query.go
├── query_test.go
├── vector.go
└── wasm
    └── main.go

/.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | 11 | jobs: 12 | 13 | lint: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | # We make use of the `slices` feature, only available in 1.21 and newer 18 | go-version: [ '1.21', '1.22', '1.23' ] 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - name: Set up Go 24 | uses: actions/setup-go@v5 25 | with: 26 | go-version: ${{ matrix.go-version }} 27 | 28 | - name: Lint 29 | uses: golangci/golangci-lint-action@v6 30 | with: 31 | version: v1.60.3 32 | args: --verbose 33 | 34 | build: 35 | runs-on: ubuntu-latest 36 | strategy: 37 | matrix: 38 | # We make use of the `slices` feature, only available in 1.21 and newer 39 | go-version: [ '1.21', '1.22', '1.23' ] 40 | 41 | steps: 42 | - uses: actions/checkout@v4 43 | 44 | - name: Set up Go 45 | uses: actions/setup-go@v5 46 | with: 47 | go-version: ${{ matrix.go-version }} 48 | 49 | - name: Build 50 | run: go build -v ./... 51 | 52 | - name: Test 53 | run: go test -v -race ./... 54 | 55 | examples: 56 | runs-on: ubuntu-latest 57 | strategy: 58 | matrix: 59 | # We make use of the `slices` feature, only available in 1.21 and newer 60 | go-version: [ '1.21', '1.22', '1.23' ] 61 | 62 | steps: 63 | - uses: actions/checkout@v4 64 | 65 | - name: Set up Go 66 | uses: actions/setup-go@v5 67 | with: 68 | go-version: ${{ matrix.go-version }} 69 | 70 | - name: Build minimal example 71 | run: | 72 | cd examples/minimal 73 | go build -v ./... 74 | 75 | - name: Build RAG Wikipedia Ollama 76 | run: | 77 | cd examples/rag-wikipedia-ollama 78 | go build -v ./... 79 | 80 | - name: Semantic search arXiv OpenAI 81 | run: | 82 | cd examples/semantic-search-arxiv-openai 83 | go build -v ./... 84 | 85 | - name: S3 export/import 86 | run: | 87 | cd examples/s3-export-import 88 | go build -v ./... 
89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Mozilla Public License Version 2.0 2 | ================================== 3 | 4 | 1. Definitions 5 | -------------- 6 | 7 | 1.1. "Contributor" 8 | means each individual or legal entity that creates, contributes to 9 | the creation of, or owns Covered Software. 10 | 11 | 1.2. "Contributor Version" 12 | means the combination of the Contributions of others (if any) used 13 | by a Contributor and that particular Contributor's Contribution. 14 | 15 | 1.3. "Contribution" 16 | means Covered Software of a particular Contributor. 17 | 18 | 1.4. "Covered Software" 19 | means Source Code Form to which the initial Contributor has attached 20 | the notice in Exhibit A, the Executable Form of such Source Code 21 | Form, and Modifications of such Source Code Form, in each case 22 | including portions thereof. 23 | 24 | 1.5. "Incompatible With Secondary Licenses" 25 | means 26 | 27 | (a) that the initial Contributor has attached the notice described 28 | in Exhibit B to the Covered Software; or 29 | 30 | (b) that the Covered Software was made available under the terms of 31 | version 1.1 or earlier of the License, but not also under the 32 | terms of a Secondary License. 33 | 34 | 1.6. "Executable Form" 35 | means any form of the work other than Source Code Form. 36 | 37 | 1.7. "Larger Work" 38 | means a work that combines Covered Software with other material, in 39 | a separate file or files, that is not Covered Software. 40 | 41 | 1.8. "License" 42 | means this document. 43 | 44 | 1.9. "Licensable" 45 | means having the right to grant, to the maximum extent possible, 46 | whether at the time of the initial grant or subsequently, any and 47 | all of the rights conveyed by this License. 48 | 49 | 1.10. "Modifications" 50 | means any of the following: 51 | 52 | (a) any file in Source Code Form that results from an addition to, 53 | deletion from, or modification of the contents of Covered 54 | Software; or 55 | 56 | (b) any new file in Source Code Form that contains any Covered 57 | Software. 58 | 59 | 1.11. "Patent Claims" of a Contributor 60 | means any patent claim(s), including without limitation, method, 61 | process, and apparatus claims, in any patent Licensable by such 62 | Contributor that would be infringed, but for the grant of the 63 | License, by the making, using, selling, offering for sale, having 64 | made, import, or transfer of either its Contributions or its 65 | Contributor Version. 66 | 67 | 1.12. 
"Secondary License" 68 | means either the GNU General Public License, Version 2.0, the GNU 69 | Lesser General Public License, Version 2.1, the GNU Affero General 70 | Public License, Version 3.0, or any later versions of those 71 | licenses. 72 | 73 | 1.13. "Source Code Form" 74 | means the form of the work preferred for making modifications. 75 | 76 | 1.14. "You" (or "Your") 77 | means an individual or a legal entity exercising rights under this 78 | License. For legal entities, "You" includes any entity that 79 | controls, is controlled by, or is under common control with You. For 80 | purposes of this definition, "control" means (a) the power, direct 81 | or indirect, to cause the direction or management of such entity, 82 | whether by contract or otherwise, or (b) ownership of more than 83 | fifty percent (50%) of the outstanding shares or beneficial 84 | ownership of such entity. 85 | 86 | 2. License Grants and Conditions 87 | -------------------------------- 88 | 89 | 2.1. Grants 90 | 91 | Each Contributor hereby grants You a world-wide, royalty-free, 92 | non-exclusive license: 93 | 94 | (a) under intellectual property rights (other than patent or trademark) 95 | Licensable by such Contributor to use, reproduce, make available, 96 | modify, display, perform, distribute, and otherwise exploit its 97 | Contributions, either on an unmodified basis, with Modifications, or 98 | as part of a Larger Work; and 99 | 100 | (b) under Patent Claims of such Contributor to make, use, sell, offer 101 | for sale, have made, import, and otherwise transfer either its 102 | Contributions or its Contributor Version. 103 | 104 | 2.2. Effective Date 105 | 106 | The licenses granted in Section 2.1 with respect to any Contribution 107 | become effective for each Contribution on the date the Contributor first 108 | distributes such Contribution. 109 | 110 | 2.3. Limitations on Grant Scope 111 | 112 | The licenses granted in this Section 2 are the only rights granted under 113 | this License. No additional rights or licenses will be implied from the 114 | distribution or licensing of Covered Software under this License. 115 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 116 | Contributor: 117 | 118 | (a) for any code that a Contributor has removed from Covered Software; 119 | or 120 | 121 | (b) for infringements caused by: (i) Your and any other third party's 122 | modifications of Covered Software, or (ii) the combination of its 123 | Contributions with other software (except as part of its Contributor 124 | Version); or 125 | 126 | (c) under Patent Claims infringed by Covered Software in the absence of 127 | its Contributions. 128 | 129 | This License does not grant any rights in the trademarks, service marks, 130 | or logos of any Contributor (except as may be necessary to comply with 131 | the notice requirements in Section 3.4). 132 | 133 | 2.4. Subsequent Licenses 134 | 135 | No Contributor makes additional grants as a result of Your choice to 136 | distribute the Covered Software under a subsequent version of this 137 | License (see Section 10.2) or under the terms of a Secondary License (if 138 | permitted under the terms of Section 3.3). 139 | 140 | 2.5. Representation 141 | 142 | Each Contributor represents that the Contributor believes its 143 | Contributions are its original creation(s) or it has sufficient rights 144 | to grant the rights to its Contributions conveyed by this License. 145 | 146 | 2.6. 
Fair Use 147 | 148 | This License is not intended to limit any rights You have under 149 | applicable copyright doctrines of fair use, fair dealing, or other 150 | equivalents. 151 | 152 | 2.7. Conditions 153 | 154 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted 155 | in Section 2.1. 156 | 157 | 3. Responsibilities 158 | ------------------- 159 | 160 | 3.1. Distribution of Source Form 161 | 162 | All distribution of Covered Software in Source Code Form, including any 163 | Modifications that You create or to which You contribute, must be under 164 | the terms of this License. You must inform recipients that the Source 165 | Code Form of the Covered Software is governed by the terms of this 166 | License, and how they can obtain a copy of this License. You may not 167 | attempt to alter or restrict the recipients' rights in the Source Code 168 | Form. 169 | 170 | 3.2. Distribution of Executable Form 171 | 172 | If You distribute Covered Software in Executable Form then: 173 | 174 | (a) such Covered Software must also be made available in Source Code 175 | Form, as described in Section 3.1, and You must inform recipients of 176 | the Executable Form how they can obtain a copy of such Source Code 177 | Form by reasonable means in a timely manner, at a charge no more 178 | than the cost of distribution to the recipient; and 179 | 180 | (b) You may distribute such Executable Form under the terms of this 181 | License, or sublicense it under different terms, provided that the 182 | license for the Executable Form does not attempt to limit or alter 183 | the recipients' rights in the Source Code Form under this License. 184 | 185 | 3.3. Distribution of a Larger Work 186 | 187 | You may create and distribute a Larger Work under terms of Your choice, 188 | provided that You also comply with the requirements of this License for 189 | the Covered Software. If the Larger Work is a combination of Covered 190 | Software with a work governed by one or more Secondary Licenses, and the 191 | Covered Software is not Incompatible With Secondary Licenses, this 192 | License permits You to additionally distribute such Covered Software 193 | under the terms of such Secondary License(s), so that the recipient of 194 | the Larger Work may, at their option, further distribute the Covered 195 | Software under the terms of either this License or such Secondary 196 | License(s). 197 | 198 | 3.4. Notices 199 | 200 | You may not remove or alter the substance of any license notices 201 | (including copyright notices, patent notices, disclaimers of warranty, 202 | or limitations of liability) contained within the Source Code Form of 203 | the Covered Software, except that You may alter any license notices to 204 | the extent required to remedy known factual inaccuracies. 205 | 206 | 3.5. Application of Additional Terms 207 | 208 | You may choose to offer, and to charge a fee for, warranty, support, 209 | indemnity or liability obligations to one or more recipients of Covered 210 | Software. However, You may do so only on Your own behalf, and not on 211 | behalf of any Contributor. You must make it absolutely clear that any 212 | such warranty, support, indemnity, or liability obligation is offered by 213 | You alone, and You hereby agree to indemnify every Contributor for any 214 | liability incurred by such Contributor as a result of warranty, support, 215 | indemnity or liability terms You offer. 
You may include additional 216 | disclaimers of warranty and limitations of liability specific to any 217 | jurisdiction. 218 | 219 | 4. Inability to Comply Due to Statute or Regulation 220 | --------------------------------------------------- 221 | 222 | If it is impossible for You to comply with any of the terms of this 223 | License with respect to some or all of the Covered Software due to 224 | statute, judicial order, or regulation then You must: (a) comply with 225 | the terms of this License to the maximum extent possible; and (b) 226 | describe the limitations and the code they affect. Such description must 227 | be placed in a text file included with all distributions of the Covered 228 | Software under this License. Except to the extent prohibited by statute 229 | or regulation, such description must be sufficiently detailed for a 230 | recipient of ordinary skill to be able to understand it. 231 | 232 | 5. Termination 233 | -------------- 234 | 235 | 5.1. The rights granted under this License will terminate automatically 236 | if You fail to comply with any of its terms. However, if You become 237 | compliant, then the rights granted under this License from a particular 238 | Contributor are reinstated (a) provisionally, unless and until such 239 | Contributor explicitly and finally terminates Your grants, and (b) on an 240 | ongoing basis, if such Contributor fails to notify You of the 241 | non-compliance by some reasonable means prior to 60 days after You have 242 | come back into compliance. Moreover, Your grants from a particular 243 | Contributor are reinstated on an ongoing basis if such Contributor 244 | notifies You of the non-compliance by some reasonable means, this is the 245 | first time You have received notice of non-compliance with this License 246 | from such Contributor, and You become compliant prior to 30 days after 247 | Your receipt of the notice. 248 | 249 | 5.2. If You initiate litigation against any entity by asserting a patent 250 | infringement claim (excluding declaratory judgment actions, 251 | counter-claims, and cross-claims) alleging that a Contributor Version 252 | directly or indirectly infringes any patent, then the rights granted to 253 | You by any and all Contributors for the Covered Software under Section 254 | 2.1 of this License shall terminate. 255 | 256 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all 257 | end user license agreements (excluding distributors and resellers) which 258 | have been validly granted by You or Your distributors under this License 259 | prior to termination shall survive termination. 260 | 261 | ************************************************************************ 262 | * * 263 | * 6. Disclaimer of Warranty * 264 | * ------------------------- * 265 | * * 266 | * Covered Software is provided under this License on an "as is" * 267 | * basis, without warranty of any kind, either expressed, implied, or * 268 | * statutory, including, without limitation, warranties that the * 269 | * Covered Software is free of defects, merchantable, fit for a * 270 | * particular purpose or non-infringing. The entire risk as to the * 271 | * quality and performance of the Covered Software is with You. * 272 | * Should any Covered Software prove defective in any respect, You * 273 | * (not any Contributor) assume the cost of any necessary servicing, * 274 | * repair, or correction. This disclaimer of warranty constitutes an * 275 | * essential part of this License. 
No use of any Covered Software is * 276 | * authorized under this License except under this disclaimer. * 277 | * * 278 | ************************************************************************ 279 | 280 | ************************************************************************ 281 | * * 282 | * 7. Limitation of Liability * 283 | * -------------------------- * 284 | * * 285 | * Under no circumstances and under no legal theory, whether tort * 286 | * (including negligence), contract, or otherwise, shall any * 287 | * Contributor, or anyone who distributes Covered Software as * 288 | * permitted above, be liable to You for any direct, indirect, * 289 | * special, incidental, or consequential damages of any character * 290 | * including, without limitation, damages for lost profits, loss of * 291 | * goodwill, work stoppage, computer failure or malfunction, or any * 292 | * and all other commercial damages or losses, even if such party * 293 | * shall have been informed of the possibility of such damages. This * 294 | * limitation of liability shall not apply to liability for death or * 295 | * personal injury resulting from such party's negligence to the * 296 | * extent applicable law prohibits such limitation. Some * 297 | * jurisdictions do not allow the exclusion or limitation of * 298 | * incidental or consequential damages, so this exclusion and * 299 | * limitation may not apply to You. * 300 | * * 301 | ************************************************************************ 302 | 303 | 8. Litigation 304 | ------------- 305 | 306 | Any litigation relating to this License may be brought only in the 307 | courts of a jurisdiction where the defendant maintains its principal 308 | place of business and such litigation shall be governed by laws of that 309 | jurisdiction, without reference to its conflict-of-law provisions. 310 | Nothing in this Section shall prevent a party's ability to bring 311 | cross-claims or counter-claims. 312 | 313 | 9. Miscellaneous 314 | ---------------- 315 | 316 | This License represents the complete agreement concerning the subject 317 | matter hereof. If any provision of this License is held to be 318 | unenforceable, such provision shall be reformed only to the extent 319 | necessary to make it enforceable. Any law or regulation which provides 320 | that the language of a contract shall be construed against the drafter 321 | shall not be used to construe this License against a Contributor. 322 | 323 | 10. Versions of the License 324 | --------------------------- 325 | 326 | 10.1. New Versions 327 | 328 | Mozilla Foundation is the license steward. Except as provided in Section 329 | 10.3, no one other than the license steward has the right to modify or 330 | publish new versions of this License. Each version will be given a 331 | distinguishing version number. 332 | 333 | 10.2. Effect of New Versions 334 | 335 | You may distribute the Covered Software under the terms of the version 336 | of the License under which You originally received the Covered Software, 337 | or under the terms of any subsequent version published by the license 338 | steward. 339 | 340 | 10.3. Modified Versions 341 | 342 | If you create software not governed by this License, and you want to 343 | create a new license for such software, you may create and use a 344 | modified version of this License if you rename the license and remove 345 | any references to the name of the license steward (except to note that 346 | such modified license differs from this License). 347 | 348 | 10.4. 
Distributing Source Code Form that is Incompatible With Secondary 349 | Licenses 350 | 351 | If You choose to distribute Source Code Form that is Incompatible With 352 | Secondary Licenses under the terms of this version of the License, the 353 | notice described in Exhibit B of this License must be attached. 354 | 355 | Exhibit A - Source Code Form License Notice 356 | ------------------------------------------- 357 | 358 | This Source Code Form is subject to the terms of the Mozilla Public 359 | License, v. 2.0. If a copy of the MPL was not distributed with this 360 | file, You can obtain one at https://mozilla.org/MPL/2.0/. 361 | 362 | If it is not possible or desirable to put the notice in a particular 363 | file, then You may include the notice in a location (such as a LICENSE 364 | file in a relevant directory) where a recipient would be likely to look 365 | for such a notice. 366 | 367 | You may add additional accurate notices of copyright ownership. 368 | 369 | Exhibit B - "Incompatible With Secondary Licenses" Notice 370 | --------------------------------------------------------- 371 | 372 | This Source Code Form is "Incompatible With Secondary Licenses", as 373 | defined by the Mozilla Public License, v. 2.0. 374 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chromem-go 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/philippgille/chromem-go.svg)](https://pkg.go.dev/github.com/philippgille/chromem-go) 4 | [![Build status](https://github.com/philippgille/chromem-go/actions/workflows/go.yml/badge.svg)](https://github.com/philippgille/chromem-go/actions/workflows/go.yml) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/philippgille/chromem-go)](https://goreportcard.com/report/github.com/philippgille/chromem-go) 6 | [![GitHub Releases](https://img.shields.io/github/release/philippgille/chromem-go.svg)](https://github.com/philippgille/chromem-go/releases) 7 | 8 | Embeddable vector database for Go with Chroma-like interface and zero third-party dependencies. In-memory with optional persistence. 9 | 10 | Because `chromem-go` is embeddable it enables you to add retrieval augmented generation (RAG) and similar embeddings-based features into your Go app *without having to run a separate database*. Like when using SQLite instead of PostgreSQL/MySQL/etc. 11 | 12 | It's *not* a library to connect to Chroma and also not a reimplementation of it in Go. It's a database on its own. 13 | 14 | The focus is not scale (millions of documents) or number of features, but simplicity and performance for the most common use cases. On a mid-range 2020 Intel laptop CPU you can query 1,000 documents in 0.3 ms and 100,000 documents in 40 ms, with very few and small memory allocations. See [Benchmarks](#benchmarks) for details. 15 | 16 | > ⚠️ The project is in beta, under heavy construction, and may introduce breaking changes in releases before `v1.0.0`. All changes are documented in the [`CHANGELOG`](./CHANGELOG.md). 17 | 18 | ## Contents 19 | 20 | 1. [Use cases](#use-cases) 21 | 2. [Interface](#interface) 22 | 3. [Features + Roadmap](#features) 23 | 4. [Installation](#installation) 24 | 5. [Usage](#usage) 25 | 6. [Benchmarks](#benchmarks) 26 | 7. [Development](#development) 27 | 8. [Motivation](#motivation) 28 | 9. 
[Related projects](#related-projects) 29 | 30 | ## Use cases 31 | 32 | With a vector database you can do various things: 33 | 34 | - Retrieval augmented generation (RAG), question answering (Q&A) 35 | - Text and code search 36 | - Recommendation systems 37 | - Classification 38 | - Clustering 39 | 40 | Let's look at the RAG use case in more detail: 41 | 42 | ### RAG 43 | 44 | The knowledge of large language models (LLMs) - even the ones with 30 billion, 70 billion parameters and more - is limited. They don't know anything about what happened after their training ended, they don't know anything about data they were not trained with (like your company's intranet, Jira / bug tracker, wiki or other kinds of knowledge bases), and even the data they *do* know they often can't reproduce it *exactly*, but start to *hallucinate* instead. 45 | 46 | Fine-tuning an LLM can help a bit, but it's more meant to improve the LLMs reasoning about specific topics, or reproduce the style of written text or code. Fine-tuning does *not* add knowledge *1:1* into the model. Details are lost or mixed up. And knowledge cutoff (about anything that happened after the fine-tuning) isn't solved either. 47 | 48 | => A vector database can act as the up-to-date, precise knowledge for LLMs: 49 | 50 | 1. You store relevant documents that you want the LLM to know in the database. 51 | 2. The database stores the *embeddings* alongside the documents, which you can either provide or can be created by specific "embedding models" like OpenAI's `text-embedding-3-small`. 52 | - `chromem-go` can do this for you and supports multiple embedding providers and models out-of-the-box. 53 | 3. Later, when you want to talk to the LLM, you first send the question to the vector DB to find *similar*/*related* content. This is called "nearest neighbor search". 54 | 4. In the question to the LLM, you provide this content alongside your question. 55 | 5. The LLM can take this up-to-date precise content into account when answering. 56 | 57 | Check out the [example code](examples) to see it in action! 58 | 59 | ## Interface 60 | 61 | Our original inspiration was the [Chroma](https://www.trychroma.com/) interface, whose core API is the following (taken from their [README](https://github.com/chroma-core/chroma/blob/0.4.21/README.md)): 62 | 63 |
Chroma core interface 64 | 65 | ```python 66 | import chromadb 67 | # setup Chroma in-memory, for easy prototyping. Can add persistence easily! 68 | client = chromadb.Client() 69 | 70 | # Create collection. get_collection, get_or_create_collection, delete_collection also available! 71 | collection = client.create_collection("all-my-documents") 72 | 73 | # Add docs to the collection. Can also update and delete. Row-based API coming soon! 74 | collection.add( 75 | documents=["This is document1", "This is document2"], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well 76 | metadatas=[{"source": "notion"}, {"source": "google-docs"}], # filter on these! 77 | ids=["doc1", "doc2"], # unique for each doc 78 | ) 79 | 80 | # Query/search 2 most similar results. You can also .get by id 81 | results = collection.query( 82 | query_texts=["This is a query document"], 83 | n_results=2, 84 | # where={"metadata_field": "is_equal_to_this"}, # optional filter 85 | # where_document={"$contains":"search_string"} # optional filter 86 | ) 87 | ``` 88 | 89 |
90 | 91 | Our Go library exposes the same interface: 92 | 93 |
chromem-go equivalent 94 | 95 | ```go 96 | package main 97 | 98 | import "github.com/philippgille/chromem-go" 99 | 100 | func main() { 101 | // Set up chromem-go in-memory, for easy prototyping. Can add persistence easily! 102 | // We call it DB instead of client because there's no client-server separation. The DB is embedded. 103 | db := chromem.NewDB() 104 | 105 | // Create collection. GetCollection, GetOrCreateCollection, DeleteCollection also available! 106 | collection, _ := db.CreateCollection("all-my-documents", nil, nil) 107 | 108 | // Add docs to the collection. Update and delete will be added in the future. 109 | // Can be multi-threaded with AddConcurrently()! 110 | // We're showing the Chroma-like method here, but more Go-idiomatic methods are also available! 111 | _ = collection.Add(ctx, 112 | []string{"doc1", "doc2"}, // unique ID for each doc 113 | nil, // We handle embedding automatically. You can skip that and add your own embeddings as well. 114 | []map[string]string{{"source": "notion"}, {"source": "google-docs"}}, // Filter on these! 115 | []string{"This is document1", "This is document2"}, 116 | ) 117 | 118 | // Query/search 2 most similar results. You can also get by ID. 119 | results, _ := collection.Query(ctx, 120 | "This is a query document", 121 | 2, 122 | map[string]string{"metadata_field": "is_equal_to_this"}, // optional filter 123 | map[string]string{"$contains": "search_string"}, // optional filter 124 | ) 125 | } 126 | ``` 127 | 128 |
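The example above passes `nil` as the embedding function, which makes `chromem-go` fall back to OpenAI (an `OPENAI_API_KEY` env var is then required, see the Quickstart below). You can also bring your own embedding backend by passing any function that satisfies `chromem.EmbeddingFunc`. A minimal sketch, with a constant vector standing in for a real model call:

```go
package main

import (
	"context"

	"github.com/philippgille/chromem-go"
)

func main() {
	db := chromem.NewDB()

	// Any function with this signature satisfies chromem.EmbeddingFunc.
	embed := func(ctx context.Context, text string) ([]float32, error) {
		// A real implementation would call an embedding model here.
		// The constant vector only keeps this sketch self-contained.
		return []float32{-0.1, 0.1, 0.2}, nil
	}

	collection, err := db.CreateCollection("all-my-documents", nil, embed)
	if err != nil {
		panic(err)
	}
	_ = collection
}
```

The vectors you return don't have to be normalized; `chromem-go` normalizes them when documents are added to the collection.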
129 | 130 | Initially `chromem-go` started with just the four core methods, but we added more over time. We intentionally don't want to cover 100% of Chroma's API surface though. 131 | We're providing some alternative methods that are more Go-idiomatic instead. 132 | 133 | For the full interface see the Godoc: 134 | 135 | ## Features 136 | 137 | - [X] Zero dependencies on third party libraries 138 | - [X] Embeddable (like SQLite, i.e. no client-server model, no separate DB to maintain) 139 | - [X] Multithreaded processing (when adding and querying documents), making use of Go's native concurrency features 140 | - [X] Experimental WebAssembly binding 141 | - Embedding creators: 142 | - Hosted: 143 | - [X] [OpenAI](https://platform.openai.com/docs/guides/embeddings/embedding-models) (default) 144 | - [X] [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings) 145 | - [X] [GCP Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings) 146 | - [X] [Cohere](https://cohere.com/models/embed) 147 | - [X] [Mistral](https://docs.mistral.ai/platform/endpoints/#embedding-models) 148 | - [X] [Jina](https://jina.ai/embeddings) 149 | - [X] [mixedbread.ai](https://www.mixedbread.ai/) 150 | - Local: 151 | - [X] [Ollama](https://github.com/ollama/ollama) 152 | - [X] [LocalAI](https://github.com/mudler/LocalAI) 153 | - Bring your own (implement [`chromem.EmbeddingFunc`](https://pkg.go.dev/github.com/philippgille/chromem-go#EmbeddingFunc)) 154 | - You can also pass existing embeddings when adding documents to a collection, instead of letting `chromem-go` create them 155 | - Similarity search: 156 | - [X] Exhaustive nearest neighbor search using cosine similarity (sometimes also called exact search or brute-force search or FLAT index) 157 | - Filters: 158 | - [X] Document filters: `$contains`, `$not_contains` 159 | - [X] Metadata filters: Exact matches 160 | - Storage: 161 | - [X] In-memory 162 | - [X] Optional immediate persistence (writes one file for each added collection and document, encoded as [gob](https://go.dev/blog/gob), optionally gzip-compressed) 163 | - [X] Backups: Export and import of the entire DB to/from a single file (encoded as [gob](https://go.dev/blog/gob), optionally gzip-compressed and AES-GCM encrypted) 164 | - Includes methods for generic `io.Writer`/`io.Reader` so you can plug S3 buckets and other blob storage, see [examples/s3-export-import](examples/s3-export-import) for example code 165 | - Data types: 166 | - [X] Documents (text) 167 | 168 | ### Roadmap 169 | 170 | - Performance: 171 | - Use SIMD for dot product calculation on supported CPUs (draft PR: [#48](https://github.com/philippgille/chromem-go/pull/48)) 172 | - Add [roaring bitmaps](https://github.com/RoaringBitmap/roaring) to speed up full text filtering 173 | - Embedding creators: 174 | - Add an `EmbeddingFunc` that downloads and shells out to [llamafile](https://github.com/Mozilla-Ocho/llamafile) 175 | - Similarity search: 176 | - Approximate nearest neighbor search with index (ANN) 177 | - Hierarchical Navigable Small World (HNSW) 178 | - Inverted file flat (IVFFlat) 179 | - Filters: 180 | - Operators (`$and`, `$or` etc.) 181 | - Storage: 182 | - JSON as second encoding format 183 | - Write-ahead log (WAL) as second file format 184 | - Optional remote storage (S3, PostgreSQL, ...) 
185 | - Data types: 186 | - Images 187 | - Videos 188 | 189 | ## Installation 190 | 191 | `go get github.com/philippgille/chromem-go@latest` 192 | 193 | ## Usage 194 | 195 | See the Godoc for a reference: 196 | 197 | For full, working examples, using the vector database for retrieval augmented generation (RAG) and semantic search and using either OpenAI or locally running the embeddings model and LLM (in Ollama), see the [example code](examples). 198 | 199 | ### Quickstart 200 | 201 | This is taken from the ["minimal" example](examples/minimal): 202 | 203 | ```go 204 | package main 205 | 206 | import ( 207 | "context" 208 | "fmt" 209 | "runtime" 210 | 211 | "github.com/philippgille/chromem-go" 212 | ) 213 | 214 | func main() { 215 | ctx := context.Background() 216 | 217 | db := chromem.NewDB() 218 | 219 | // Passing nil as embedding function leads to OpenAI being used and requires 220 | // "OPENAI_API_KEY" env var to be set. Other providers are supported as well. 221 | // For example pass `chromem.NewEmbeddingFuncOllama(...)` to use Ollama. 222 | c, err := db.CreateCollection("knowledge-base", nil, nil) 223 | if err != nil { 224 | panic(err) 225 | } 226 | 227 | err = c.AddDocuments(ctx, []chromem.Document{ 228 | { 229 | ID: "1", 230 | Content: "The sky is blue because of Rayleigh scattering.", 231 | }, 232 | { 233 | ID: "2", 234 | Content: "Leaves are green because chlorophyll absorbs red and blue light.", 235 | }, 236 | }, runtime.NumCPU()) 237 | if err != nil { 238 | panic(err) 239 | } 240 | 241 | res, err := c.Query(ctx, "Why is the sky blue?", 1, nil, nil) 242 | if err != nil { 243 | panic(err) 244 | } 245 | 246 | fmt.Printf("ID: %v\nSimilarity: %v\nContent: %v\n", res[0].ID, res[0].Similarity, res[0].Content) 247 | } 248 | ``` 249 | 250 | Output: 251 | 252 | ```text 253 | ID: 1 254 | Similarity: 0.6833369 255 | Content: The sky is blue because of Rayleigh scattering. 256 | ``` 257 | 258 | ## Benchmarks 259 | 260 | Benchmarked on 2024-03-17 with: 261 | 262 | - Computer: Framework Laptop 13 (first generation, 2021) 263 | - CPU: 11th Gen Intel Core i5-1135G7 (2020) 264 | - Memory: 32 GB 265 | - OS: Fedora Linux 39 266 | - Kernel: 6.7 267 | 268 | ```console 269 | $ go test -benchmem -run=^$ -bench . 
270 | goos: linux 271 | goarch: amd64 272 | pkg: github.com/philippgille/chromem-go 273 | cpu: 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz 274 | BenchmarkCollection_Query_NoContent_100-8 13164 90276 ns/op 5176 B/op 95 allocs/op 275 | BenchmarkCollection_Query_NoContent_1000-8 2142 520261 ns/op 13558 B/op 141 allocs/op 276 | BenchmarkCollection_Query_NoContent_5000-8 561 2150354 ns/op 47096 B/op 173 allocs/op 277 | BenchmarkCollection_Query_NoContent_25000-8 120 9890177 ns/op 211783 B/op 208 allocs/op 278 | BenchmarkCollection_Query_NoContent_100000-8 30 39574238 ns/op 810370 B/op 232 allocs/op 279 | BenchmarkCollection_Query_100-8 13225 91058 ns/op 5177 B/op 95 allocs/op 280 | BenchmarkCollection_Query_1000-8 2226 519693 ns/op 13552 B/op 140 allocs/op 281 | BenchmarkCollection_Query_5000-8 550 2128121 ns/op 47108 B/op 173 allocs/op 282 | BenchmarkCollection_Query_25000-8 100 10063260 ns/op 211705 B/op 205 allocs/op 283 | BenchmarkCollection_Query_100000-8 30 39404005 ns/op 810295 B/op 229 allocs/op 284 | PASS 285 | ok github.com/philippgille/chromem-go 28.402s 286 | ``` 287 | 288 | ## Development 289 | 290 | - Build: `go build ./...` 291 | - Test: `go test -v -race -count 1 ./...` 292 | - Benchmark: 293 | - `go test -benchmem -run=^$ -bench .` (add `> bench.out` or similar to write to a file) 294 | - With profiling: `go test -benchmem -run ^$ -cpuprofile cpu.out -bench .` 295 | - (profiles: `-cpuprofile`, `-memprofile`, `-blockprofile`, `-mutexprofile`) 296 | - Compare benchmarks: 297 | 1. Install `benchstat`: `go install golang.org/x/perf/cmd/benchstat@latest` 298 | 2. Compare two benchmark results: `benchstat before.out after.out` 299 | 300 | ## Motivation 301 | 302 | In December 2023, when I wanted to play around with retrieval augmented generation (RAG) in a Go program, I looked for a vector database that could be embedded in the Go program, just like you would embed SQLite in order to not require any separate DB setup and maintenance. I was surprised when I didn't find any, given the abundance of embedded key-value stores in the Go ecosystem. 303 | 304 | At the time most of the popular vector databases like Pinecone, Qdrant, Milvus, Chroma, Weaviate and others were not embeddable at all or only in Python or JavaScript/TypeScript. 305 | 306 | Then I found [@eliben](https://github.com/eliben)'s [blog post](https://eli.thegreenplace.net/2023/retrieval-augmented-generation-in-go/) and [example code](https://github.com/eliben/code-for-blog/tree/eda87b87dad9ed8bd45d1c8d6395efba3741ed39/2023/go-rag-openai) which showed that with very little Go code you could create a very basic PoC of a vector database. 307 | 308 | That's when I decided to build my own vector database, embeddable in Go, inspired by the ChromaDB interface. ChromaDB stood out for being embeddable (in Python), and by showing its core API in 4 commands on their README and on the landing page of their website. 309 | 310 | ## Related projects 311 | 312 | - Shoutout to [@eliben](https://github.com/eliben) whose [blog post](https://eli.thegreenplace.net/2023/retrieval-augmented-generation-in-go/) and [example code](https://github.com/eliben/code-for-blog/tree/eda87b87dad9ed8bd45d1c8d6395efba3741ed39/2023/go-rag-openai) inspired me to start this project! 313 | - [Chroma](https://github.com/chroma-core/chroma): Looking at Pinecone, Qdrant, Milvus, Weaviate and others, Chroma stood out by showing its core API in 4 commands on their README and on the landing page of their website. 
It was also putting the most emphasis on its embeddability (in Python). 314 | - The big, full-fledged client-server-based vector databases for maximum scale and performance: 315 | - [Pinecone](https://www.pinecone.io/): Closed source 316 | - [Qdrant](https://github.com/qdrant/qdrant): Written in Rust, not embeddable in Go 317 | - [Milvus](https://github.com/milvus-io/milvus): Written in Go and C++, but not embeddable as of December 2023 318 | - [Weaviate](https://github.com/weaviate/weaviate): Written in Go, but not embeddable in Go as of March 2024 (only in Python and JavaScript/TypeScript and that's experimental) 319 | - Some non-specialized SQL, NoSQL and Key-Value databases added support for storing vectors and (some of them) querying based on similarity: 320 | - [pgvector](https://github.com/pgvector/pgvector) extension for [PostgreSQL](https://www.postgresql.org/): Client-server model 321 | - [Redis](https://github.com/redis/redis) ([1](https://redis.io/docs/interact/search-and-query/query/vector-search/), [2](https://redis.io/docs/interact/search-and-query/advanced-concepts/vectors/)): Client-server model 322 | - [sqlite-vss](https://github.com/asg017/sqlite-vss) extension for [SQLite](https://www.sqlite.org/): Embedded, but the [Go bindings](https://github.com/asg017/sqlite-vss/tree/8fc44301843029a13a474d1f292378485e1fdd62/bindings/go) require CGO. There's a [CGO-free Go library](https://gitlab.com/cznic/sqlite) for SQLite, but then it's without the vector search extension. 323 | - [DuckDB](https://github.com/duckdb/duckdb) has a function to calculate cosine similarity ([1](https://duckdb.org/docs/sql/functions/nested)): Embedded, but the Go bindings use CGO 324 | - [MongoDB](https://github.com/mongodb/mongo)'s cloud platform offers a vector search product ([1](https://www.mongodb.com/products/platform/atlas-vector-search)): Client-server model 325 | - Some libraries for vector similarity search: 326 | - [Faiss](https://github.com/facebookresearch/faiss): Written in C++; 3rd party Go bindings use CGO 327 | - [Annoy](https://github.com/spotify/annoy): Written in C++; Go bindings use CGO ([1](https://github.com/spotify/annoy/blob/2be37c9e015544be2cf60c431f0cccc076151a2d/README_GO.rst)) 328 | - [USearch](https://github.com/unum-cloud/usearch): Written in C++; Go bindings use CGO 329 | - Some orchestration libraries, inspired by the Python library [LangChain](https://github.com/langchain-ai/langchain), but with no or only rudimentary embedded vector DB: 330 | - [LangChain Go](https://github.com/tmc/langchaingo) 331 | - [LinGoose](https://github.com/henomis/lingoose) 332 | - [GoLC](https://github.com/hupe1980/golc) 333 | -------------------------------------------------------------------------------- /collection.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "maps" 8 | "path/filepath" 9 | "slices" 10 | "sync" 11 | ) 12 | 13 | // Collection represents a collection of documents. 14 | // It also has a configured embedding function, which is used when adding documents 15 | // that don't have embeddings yet. 16 | type Collection struct { 17 | Name string 18 | 19 | metadata map[string]string 20 | documents map[string]*Document 21 | documentsLock sync.RWMutex 22 | embed EmbeddingFunc 23 | 24 | persistDirectory string 25 | compress bool 26 | 27 | // ⚠️ When adding fields here, consider adding them to the persistence struct 28 | // versions in [DB.Export] and [DB.Import] as well! 
29 | } 30 | 31 | // NegativeMode represents the mode to use for the negative text. 32 | // See QueryOptions for more information. 33 | type NegativeMode string 34 | 35 | const ( 36 | // NEGATIVE_MODE_FILTER filters out results based on the similarity between the 37 | // negative embedding and the document embeddings. 38 | // NegativeFilterThreshold controls the threshold for filtering. Documents with 39 | // similarity above the threshold will be removed from the results. 40 | NEGATIVE_MODE_FILTER NegativeMode = "filter" 41 | 42 | // NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding. 43 | // This is the default behavior. 44 | NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract" 45 | 46 | // The default threshold for the negative filter. 47 | DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5 48 | ) 49 | 50 | // QueryOptions represents the options for a query. 51 | type QueryOptions struct { 52 | // The text to search for. 53 | QueryText string 54 | 55 | // The embedding of the query to search for. It must be created 56 | // with the same embedding model as the document embeddings in the collection. 57 | // The embedding will be normalized if it's not the case yet. 58 | // If both QueryText and QueryEmbedding are set, QueryEmbedding will be used. 59 | QueryEmbedding []float32 60 | 61 | // The number of results to return. 62 | NResults int 63 | 64 | // Conditional filtering on metadata. 65 | Where map[string]string 66 | 67 | // Conditional filtering on documents. 68 | WhereDocument map[string]string 69 | 70 | // Negative is the negative query options. 71 | // They can be used to exclude certain results from the query. 72 | Negative NegativeQueryOptions 73 | } 74 | 75 | type NegativeQueryOptions struct { 76 | // Mode is the mode to use for the negative text. 77 | Mode NegativeMode 78 | 79 | // Text is the text to exclude from the results. 80 | Text string 81 | 82 | // Embedding is the embedding of the negative text. It must be created 83 | // with the same embedding model as the document embeddings in the collection. 84 | // The embedding will be normalized if it's not the case yet. 85 | // If both Text and Embedding are set, Embedding will be used. 86 | Embedding []float32 87 | 88 | // FilterThreshold is the threshold for the negative filter. Used when Mode is NEGATIVE_MODE_FILTER. 89 | FilterThreshold float32 90 | } 91 | 92 | // We don't export this yet to keep the API surface to the bare minimum. 93 | // Users create collections via [Client.CreateCollection]. 94 | func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) { 95 | // We copy the metadata to avoid data races in case the caller modifies the 96 | // map after creating the collection while we range over it. 97 | m := make(map[string]string, len(metadata)) 98 | for k, v := range metadata { 99 | m[k] = v 100 | } 101 | 102 | c := &Collection{ 103 | Name: name, 104 | 105 | metadata: m, 106 | documents: make(map[string]*Document), 107 | embed: embed, 108 | } 109 | 110 | // Persistence 111 | if dbDir != "" { 112 | safeName := hash2hex(name) 113 | c.persistDirectory = filepath.Join(dbDir, safeName) 114 | c.compress = compress 115 | return c, c.persistMetadata() 116 | } 117 | 118 | return c, nil 119 | } 120 | 121 | // Add embeddings to the datastore. 122 | // 123 | // - ids: The ids of the embeddings you wish to add 124 | // - embeddings: The embeddings to add. 
If nil, embeddings will be computed based 125 | // on the contents using the embeddingFunc set for the Collection. Optional. 126 | // - metadatas: The metadata to associate with the embeddings. When querying, 127 | // you can filter on this metadata. Optional. 128 | // - contents: The contents to associate with the embeddings. 129 | // 130 | // This is a Chroma-like method. For a more Go-idiomatic one, see [Collection.AddDocuments]. 131 | func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error { 132 | return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1) 133 | } 134 | 135 | // AddConcurrently is like Add, but adds embeddings concurrently. 136 | // This is mostly useful when you don't pass any embeddings, so they have to be created. 137 | // Upon error, concurrently running operations are canceled and the error is returned. 138 | // 139 | // This is a Chroma-like method. For a more Go-idiomatic one, see [Collection.AddDocuments]. 140 | func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error { 141 | if len(ids) == 0 { 142 | return errors.New("ids are empty") 143 | } 144 | if len(embeddings) == 0 && len(contents) == 0 { 145 | return errors.New("either embeddings or contents must be filled") 146 | } 147 | if len(embeddings) != 0 { 148 | if len(embeddings) != len(ids) { 149 | return errors.New("ids and embeddings must have the same length") 150 | } 151 | } else { 152 | // Assign empty slice, so we can simply access via index later 153 | embeddings = make([][]float32, len(ids)) 154 | } 155 | if len(metadatas) != 0 { 156 | if len(ids) != len(metadatas) { 157 | return errors.New("when metadatas is not empty it must have the same length as ids") 158 | } 159 | } else { 160 | // Assign empty slice, so we can simply access via index later 161 | metadatas = make([]map[string]string, len(ids)) 162 | } 163 | if len(contents) != 0 { 164 | if len(contents) != len(ids) { 165 | return errors.New("ids and contents must have the same length") 166 | } 167 | } else { 168 | // Assign empty slice, so we can simply access via index later 169 | contents = make([]string, len(ids)) 170 | } 171 | if concurrency < 1 { 172 | return errors.New("concurrency must be at least 1") 173 | } 174 | 175 | // Convert Chroma-style parameters into a slice of documents. 176 | docs := make([]Document, 0, len(ids)) 177 | for i, id := range ids { 178 | docs = append(docs, Document{ 179 | ID: id, 180 | Metadata: metadatas[i], 181 | Embedding: embeddings[i], 182 | Content: contents[i], 183 | }) 184 | } 185 | 186 | return c.AddDocuments(ctx, docs, concurrency) 187 | } 188 | 189 | // AddDocuments adds documents to the collection with the specified concurrency. 190 | // If the documents don't have embeddings, they will be created using the collection's 191 | // embedding function. 192 | // Upon error, concurrently running operations are canceled and the error is returned. 193 | func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error { 194 | if len(documents) == 0 { 195 | // TODO: Should this be a no-op instead? 196 | return errors.New("documents slice is nil or empty") 197 | } 198 | if concurrency < 1 { 199 | return errors.New("concurrency must be at least 1") 200 | } 201 | // For other validations we rely on AddDocument. 
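	//
	// The code below fans the documents out to goroutines: a buffered channel
	// serves as a semaphore that limits the number of concurrent AddDocument
	// calls to `concurrency`, and only the first error is kept, which also
	// cancels the shared context so goroutines that haven't started yet exit
	// early.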
202 | 203 | var sharedErr error 204 | sharedErrLock := sync.Mutex{} 205 | ctx, cancel := context.WithCancelCause(ctx) 206 | defer cancel(nil) 207 | setSharedErr := func(err error) { 208 | sharedErrLock.Lock() 209 | defer sharedErrLock.Unlock() 210 | // Another goroutine might have already set the error. 211 | if sharedErr == nil { 212 | sharedErr = err 213 | // Cancel the operation for all other goroutines. 214 | cancel(sharedErr) 215 | } 216 | } 217 | 218 | var wg sync.WaitGroup 219 | semaphore := make(chan struct{}, concurrency) 220 | for _, doc := range documents { 221 | wg.Add(1) 222 | go func(doc Document) { 223 | defer wg.Done() 224 | 225 | // Don't even start if another goroutine already failed. 226 | if ctx.Err() != nil { 227 | return 228 | } 229 | 230 | // Wait here while $concurrency other goroutines are creating documents. 231 | semaphore <- struct{}{} 232 | defer func() { <-semaphore }() 233 | 234 | err := c.AddDocument(ctx, doc) 235 | if err != nil { 236 | setSharedErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err)) 237 | return 238 | } 239 | }(doc) 240 | } 241 | 242 | wg.Wait() 243 | 244 | return sharedErr 245 | } 246 | 247 | // AddDocument adds a document to the collection. 248 | // If the document doesn't have an embedding, it will be created using the collection's 249 | // embedding function. 250 | func (c *Collection) AddDocument(ctx context.Context, doc Document) error { 251 | if doc.ID == "" { 252 | return errors.New("document ID is empty") 253 | } 254 | if len(doc.Embedding) == 0 && doc.Content == "" { 255 | return errors.New("either document embedding or content must be filled") 256 | } 257 | 258 | // We copy the metadata to avoid data races in case the caller modifies the 259 | // map after creating the document while we range over it. 260 | m := make(map[string]string, len(doc.Metadata)) 261 | for k, v := range doc.Metadata { 262 | m[k] = v 263 | } 264 | 265 | // Create embedding if they don't exist, otherwise normalize if necessary 266 | if len(doc.Embedding) == 0 { 267 | embedding, err := c.embed(ctx, doc.Content) 268 | if err != nil { 269 | return fmt.Errorf("couldn't create embedding of document: %w", err) 270 | } 271 | doc.Embedding = embedding 272 | } else { 273 | if !isNormalized(doc.Embedding) { 274 | doc.Embedding = normalizeVector(doc.Embedding) 275 | } 276 | } 277 | 278 | c.documentsLock.Lock() 279 | // We don't defer the unlock because we want to do it earlier. 280 | c.documents[doc.ID] = &doc 281 | c.documentsLock.Unlock() 282 | 283 | // Persist the document 284 | if c.persistDirectory != "" { 285 | docPath := c.getDocPath(doc.ID) 286 | err := persistToFile(docPath, doc, c.compress, "") 287 | if err != nil { 288 | return fmt.Errorf("couldn't persist document to %q: %w", docPath, err) 289 | } 290 | } 291 | 292 | return nil 293 | } 294 | 295 | // GetByID returns a document by its ID. 296 | // The returned document is a copy of the original document, so it can be safely 297 | // modified without affecting the collection. 
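//
// A usage sketch (assuming ctx and a collection c that contains a document
// with ID "1"):
//
//	doc, err := c.GetByID(ctx, "1")
//	if err != nil {
//		// Handle the empty-ID or not-found case.
//	}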
298 | func (c *Collection) GetByID(ctx context.Context, id string) (Document, error) { 299 | if id == "" { 300 | return Document{}, errors.New("document ID is empty") 301 | } 302 | 303 | c.documentsLock.RLock() 304 | defer c.documentsLock.RUnlock() 305 | 306 | doc, ok := c.documents[id] 307 | if ok { 308 | // Clone the document 309 | res := *doc 310 | // Above copies the simple fields, but we need to copy the slices and maps 311 | res.Metadata = maps.Clone(doc.Metadata) 312 | res.Embedding = slices.Clone(doc.Embedding) 313 | 314 | return res, nil 315 | } 316 | 317 | return Document{}, fmt.Errorf("document with ID '%v' not found", id) 318 | } 319 | 320 | // Delete removes document(s) from the collection. 321 | // 322 | // - where: Conditional filtering on metadata. Optional. 323 | // - whereDocument: Conditional filtering on documents. Optional. 324 | // - ids: The ids of the documents to delete. If empty, all documents are deleted. 325 | func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error { 326 | // must have at least one of where, whereDocument or ids 327 | if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 { 328 | return fmt.Errorf("must have at least one of where, whereDocument or ids") 329 | } 330 | 331 | if len(c.documents) == 0 { 332 | return nil 333 | } 334 | 335 | for k := range whereDocument { 336 | if !slices.Contains(supportedFilters, k) { 337 | return errors.New("unsupported whereDocument operator") 338 | } 339 | } 340 | 341 | var docIDs []string 342 | 343 | c.documentsLock.Lock() 344 | defer c.documentsLock.Unlock() 345 | 346 | if where != nil || whereDocument != nil { 347 | // metadata + content filters 348 | filteredDocs := filterDocs(c.documents, where, whereDocument) 349 | for _, doc := range filteredDocs { 350 | docIDs = append(docIDs, doc.ID) 351 | } 352 | } else { 353 | docIDs = ids 354 | } 355 | 356 | // No-op if no docs are left 357 | if len(docIDs) == 0 { 358 | return nil 359 | } 360 | 361 | for _, docID := range docIDs { 362 | delete(c.documents, docID) 363 | 364 | // Remove the document from disk 365 | if c.persistDirectory != "" { 366 | docPath := c.getDocPath(docID) 367 | err := removeFile(docPath) 368 | if err != nil { 369 | return fmt.Errorf("couldn't remove document at %q: %w", docPath, err) 370 | } 371 | } 372 | } 373 | 374 | return nil 375 | } 376 | 377 | // Count returns the number of documents in the collection. 378 | func (c *Collection) Count() int { 379 | c.documentsLock.RLock() 380 | defer c.documentsLock.RUnlock() 381 | return len(c.documents) 382 | } 383 | 384 | // Result represents a single result from a query. 385 | type Result struct { 386 | ID string 387 | Metadata map[string]string 388 | Embedding []float32 389 | Content string 390 | 391 | // The cosine similarity between the query and the document. 392 | // The higher the value, the more similar the document is to the query. 393 | // The value is in the range [-1, 1]. 394 | Similarity float32 395 | } 396 | 397 | // Query performs an exhaustive nearest neighbor search on the collection. 398 | // 399 | // - queryText: The text to search for. Its embedding will be created using the 400 | // collection's embedding function. 401 | // - nResults: The maximum number of results to return. Must be > 0. 402 | // There can be fewer results if a filter is applied. 403 | // - where: Conditional filtering on metadata. Optional. 404 | // - whereDocument: Conditional filtering on documents. Optional. 
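//
// A usage sketch (assuming ctx and a populated collection c; the filter
// values are made up for illustration):
//
//	res, err := c.Query(ctx, "Why is the sky blue?", 2,
//		map[string]string{"category": "science"},   // metadata must match exactly
//		map[string]string{"$contains": "Rayleigh"}, // document content filter
//	)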
405 | func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { 406 | if queryText == "" { 407 | return nil, errors.New("queryText is empty") 408 | } 409 | 410 | queryVector, err := c.embed(ctx, queryText) 411 | if err != nil { 412 | return nil, fmt.Errorf("couldn't create embedding of query: %w", err) 413 | } 414 | 415 | return c.QueryEmbedding(ctx, queryVector, nResults, where, whereDocument) 416 | } 417 | 418 | // QueryWithOptions performs an exhaustive nearest neighbor search on the collection. 419 | // 420 | // - options: The options for the query. See [QueryOptions] for more information. 421 | func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) ([]Result, error) { 422 | if options.QueryText == "" && len(options.QueryEmbedding) == 0 { 423 | return nil, errors.New("QueryText and QueryEmbedding options are empty") 424 | } 425 | 426 | var err error 427 | queryVector := options.QueryEmbedding 428 | if len(queryVector) == 0 { 429 | queryVector, err = c.embed(ctx, options.QueryText) 430 | if err != nil { 431 | return nil, fmt.Errorf("couldn't create embedding of query: %w", err) 432 | } 433 | } 434 | 435 | negativeFilterThreshold := options.Negative.FilterThreshold 436 | negativeVector := options.Negative.Embedding 437 | if len(negativeVector) == 0 && options.Negative.Text != "" { 438 | negativeVector, err = c.embed(ctx, options.Negative.Text) 439 | if err != nil { 440 | return nil, fmt.Errorf("couldn't create embedding of negative: %w", err) 441 | } 442 | } 443 | 444 | if len(negativeVector) != 0 { 445 | if !isNormalized(negativeVector) { 446 | negativeVector = normalizeVector(negativeVector) 447 | } 448 | 449 | if options.Negative.Mode == NEGATIVE_MODE_SUBTRACT { 450 | queryVector = subtractVector(queryVector, negativeVector) 451 | queryVector = normalizeVector(queryVector) 452 | } else if options.Negative.Mode == NEGATIVE_MODE_FILTER { 453 | if negativeFilterThreshold == 0 { 454 | negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD 455 | } 456 | } else { 457 | return nil, fmt.Errorf("unsupported negative mode: %q", options.Negative.Mode) 458 | } 459 | } 460 | 461 | result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument) 462 | if err != nil { 463 | return nil, err 464 | } 465 | 466 | return result, nil 467 | } 468 | 469 | // QueryEmbedding performs an exhaustive nearest neighbor search on the collection. 470 | // 471 | // - queryEmbedding: The embedding of the query to search for. It must be created 472 | // with the same embedding model as the document embeddings in the collection. 473 | // The embedding will be normalized if it's not the case yet. 474 | // - nResults: The maximum number of results to return. Must be > 0. 475 | // There can be fewer results if a filter is applied. 476 | // - where: Conditional filtering on metadata. Optional. 477 | // - whereDocument: Conditional filtering on documents. Optional. 478 | func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { 479 | return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument) 480 | } 481 | 482 | // queryEmbedding performs an exhaustive nearest neighbor search on the collection. 
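// The negativeEmbeddings and negativeFilterThreshold parameters come from
// NegativeQueryOptions (see QueryWithOptions and the NegativeMode constants)
// and are passed through to the similarity search.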
483 | func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { 484 | if len(queryEmbedding) == 0 { 485 | return nil, errors.New("queryEmbedding is empty") 486 | } 487 | if nResults <= 0 { 488 | return nil, errors.New("nResults must be > 0") 489 | } 490 | c.documentsLock.RLock() 491 | defer c.documentsLock.RUnlock() 492 | if nResults > len(c.documents) { 493 | return nil, errors.New("nResults must be <= the number of documents in the collection") 494 | } 495 | 496 | if len(c.documents) == 0 { 497 | return nil, nil 498 | } 499 | 500 | // Validate whereDocument operators 501 | for k := range whereDocument { 502 | if !slices.Contains(supportedFilters, k) { 503 | return nil, errors.New("unsupported operator") 504 | } 505 | } 506 | 507 | // Filter docs by metadata and content 508 | filteredDocs := filterDocs(c.documents, where, whereDocument) 509 | 510 | // No need to continue if the filters got rid of all documents 511 | if len(filteredDocs) == 0 { 512 | return nil, nil 513 | } 514 | 515 | // Normalize embedding if not the case yet. We only support cosine similarity 516 | // for now and all documents were already normalized when added to the collection. 517 | if !isNormalized(queryEmbedding) { 518 | queryEmbedding = normalizeVector(queryEmbedding) 519 | } 520 | 521 | // If the filtering already reduced the number of documents to fewer than nResults, 522 | // we only need to find the most similar docs among the filtered ones. 523 | resLen := nResults 524 | if len(filteredDocs) < nResults { 525 | resLen = len(filteredDocs) 526 | } 527 | 528 | // For the remaining documents, get the most similar docs. 529 | nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, negativeEmbeddings, negativeFilterThreshold, filteredDocs, resLen) 530 | if err != nil { 531 | return nil, fmt.Errorf("couldn't get most similar docs: %w", err) 532 | } 533 | 534 | res := make([]Result, 0, len(nMaxDocs)) 535 | for i := 0; i < len(nMaxDocs); i++ { 536 | res = append(res, Result{ 537 | ID: nMaxDocs[i].docID, 538 | Metadata: c.documents[nMaxDocs[i].docID].Metadata, 539 | Embedding: c.documents[nMaxDocs[i].docID].Embedding, 540 | Content: c.documents[nMaxDocs[i].docID].Content, 541 | Similarity: nMaxDocs[i].similarity, 542 | }) 543 | } 544 | 545 | return res, nil 546 | } 547 | 548 | // getDocPath generates the path to the document file. 
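// The resulting layout is roughly as follows (a sketch; the placeholder names
// are not from the docs):
//
//	<persistDirectory>/<hex hash of the document ID>.gob     // uncompressed
//	<persistDirectory>/<hex hash of the document ID>.gob.gz  // when compression is enabled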
549 | func (c *Collection) getDocPath(docID string) string { 550 | safeID := hash2hex(docID) 551 | docPath := filepath.Join(c.persistDirectory, safeID) 552 | docPath += ".gob" 553 | if c.compress { 554 | docPath += ".gz" 555 | } 556 | return docPath 557 | } 558 | 559 | // persistMetadata persists the collection metadata to disk 560 | func (c *Collection) persistMetadata() error { 561 | // Persist name and metadata 562 | metadataPath := filepath.Join(c.persistDirectory, metadataFileName) 563 | metadataPath += ".gob" 564 | if c.compress { 565 | metadataPath += ".gz" 566 | } 567 | pc := struct { 568 | Name string 569 | Metadata map[string]string 570 | }{ 571 | Name: c.Name, 572 | Metadata: c.metadata, 573 | } 574 | err := persistToFile(metadataPath, pc, c.compress, "") 575 | if err != nil { 576 | return err 577 | } 578 | 579 | return nil 580 | } 581 | -------------------------------------------------------------------------------- /collection_test.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "math/rand" 7 | "os" 8 | "slices" 9 | "strconv" 10 | "testing" 11 | ) 12 | 13 | func TestCollection_Add(t *testing.T) { 14 | ctx := context.Background() 15 | name := "test" 16 | metadata := map[string]string{"foo": "bar"} 17 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 18 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 19 | return vectors, nil 20 | } 21 | 22 | // Create collection 23 | db := NewDB() 24 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 25 | if err != nil { 26 | t.Fatal("expected no error, got", err) 27 | } 28 | if c == nil { 29 | t.Fatal("expected collection, got nil") 30 | } 31 | 32 | // Add documents 33 | 34 | ids := []string{"1", "2"} 35 | embeddings := [][]float32{vectors, vectors} 36 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} 37 | contents := []string{"hello world", "hallo welt"} 38 | 39 | tt := []struct { 40 | name string 41 | ids []string 42 | embeddings [][]float32 43 | metadatas []map[string]string 44 | contents []string 45 | }{ 46 | { 47 | name: "No embeddings", 48 | ids: ids, 49 | embeddings: nil, 50 | metadatas: metadatas, 51 | contents: contents, 52 | }, 53 | { 54 | name: "With embeddings", 55 | ids: ids, 56 | embeddings: embeddings, 57 | metadatas: metadatas, 58 | contents: contents, 59 | }, 60 | { 61 | name: "With embeddings but no contents", 62 | ids: ids, 63 | embeddings: embeddings, 64 | metadatas: metadatas, 65 | contents: nil, 66 | }, 67 | } 68 | 69 | for _, tc := range tt { 70 | t.Run(tc.name, func(t *testing.T) { 71 | err = c.Add(ctx, ids, nil, metadatas, contents) 72 | if err != nil { 73 | t.Fatal("expected nil, got", err) 74 | } 75 | 76 | // Check documents 77 | if len(c.documents) != 2 { 78 | t.Fatal("expected 2, got", len(c.documents)) 79 | } 80 | for i, id := range ids { 81 | doc, ok := c.documents[id] 82 | if !ok { 83 | t.Fatal("expected document, got nil") 84 | } 85 | if doc.ID != id { 86 | t.Fatal("expected", id, "got", doc.ID) 87 | } 88 | if len(doc.Metadata) != 1 { 89 | t.Fatal("expected 1, got", len(doc.Metadata)) 90 | } 91 | if !slices.Equal(doc.Embedding, vectors) { 92 | t.Fatal("expected", vectors, "got", doc.Embedding) 93 | } 94 | if doc.Content != contents[i] { 95 | t.Fatal("expected", contents[i], "got", doc.Content) 96 | } 97 | } 98 | // Metadata can't be accessed with the loop's i 99 | if c.documents[ids[0]].Metadata["foo"] != 
"bar" { 100 | t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"]) 101 | } 102 | if c.documents[ids[1]].Metadata["a"] != "b" { 103 | t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"]) 104 | } 105 | }) 106 | } 107 | } 108 | 109 | func TestCollection_Add_Error(t *testing.T) { 110 | ctx := context.Background() 111 | name := "test" 112 | metadata := map[string]string{"foo": "bar"} 113 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 114 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 115 | return vectors, nil 116 | } 117 | 118 | // Create collection 119 | db := NewDB() 120 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 121 | if err != nil { 122 | t.Fatal("expected no error, got", err) 123 | } 124 | if c == nil { 125 | t.Fatal("expected collection, got nil") 126 | } 127 | 128 | // Add documents, provoking errors 129 | ids := []string{"1", "2"} 130 | embeddings := [][]float32{vectors, vectors} 131 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} 132 | contents := []string{"hello world", "hallo welt"} 133 | 134 | // Empty IDs 135 | err = c.Add(ctx, []string{}, embeddings, metadatas, contents) 136 | if err == nil { 137 | t.Fatal("expected error, got nil") 138 | } 139 | // Empty embeddings and contents (both at the same time!) 140 | err = c.Add(ctx, ids, [][]float32{}, metadatas, []string{}) 141 | if err == nil { 142 | t.Fatal("expected error, got nil") 143 | } 144 | // Bad embeddings length 145 | err = c.Add(ctx, ids, [][]float32{vectors}, metadatas, contents) 146 | if err == nil { 147 | t.Fatal("expected error, got nil") 148 | } 149 | // Bad metadatas length 150 | err = c.Add(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents) 151 | if err == nil { 152 | t.Fatal("expected error, got nil") 153 | } 154 | // Bad contents length 155 | err = c.Add(ctx, ids, embeddings, metadatas, []string{"hello world"}) 156 | if err == nil { 157 | t.Fatal("expected error, got nil") 158 | } 159 | } 160 | 161 | func TestCollection_AddConcurrently(t *testing.T) { 162 | ctx := context.Background() 163 | name := "test" 164 | metadata := map[string]string{"foo": "bar"} 165 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 166 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 167 | return vectors, nil 168 | } 169 | 170 | // Create collection 171 | db := NewDB() 172 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 173 | if err != nil { 174 | t.Fatal("expected no error, got", err) 175 | } 176 | if c == nil { 177 | t.Fatal("expected collection, got nil") 178 | } 179 | 180 | // Add documents 181 | 182 | ids := []string{"1", "2"} 183 | embeddings := [][]float32{vectors, vectors} 184 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} 185 | contents := []string{"hello world", "hallo welt"} 186 | 187 | tt := []struct { 188 | name string 189 | ids []string 190 | embeddings [][]float32 191 | metadatas []map[string]string 192 | contents []string 193 | }{ 194 | { 195 | name: "No embeddings", 196 | ids: ids, 197 | embeddings: nil, 198 | metadatas: metadatas, 199 | contents: contents, 200 | }, 201 | { 202 | name: "With embeddings", 203 | ids: ids, 204 | embeddings: embeddings, 205 | metadatas: metadatas, 206 | contents: contents, 207 | }, 208 | { 209 | name: "With embeddings but no contents", 210 | ids: ids, 211 | embeddings: embeddings, 212 | metadatas: metadatas, 213 | contents: nil, 
214 | }, 215 | } 216 | 217 | for _, tc := range tt { 218 | t.Run(tc.name, func(t *testing.T) { 219 | err = c.AddConcurrently(ctx, ids, nil, metadatas, contents, 2) 220 | if err != nil { 221 | t.Fatal("expected nil, got", err) 222 | } 223 | 224 | // Check documents 225 | if len(c.documents) != 2 { 226 | t.Fatal("expected 2, got", len(c.documents)) 227 | } 228 | for i, id := range ids { 229 | doc, ok := c.documents[id] 230 | if !ok { 231 | t.Fatal("expected document, got nil") 232 | } 233 | if doc.ID != id { 234 | t.Fatal("expected", id, "got", doc.ID) 235 | } 236 | if len(doc.Metadata) != 1 { 237 | t.Fatal("expected 1, got", len(doc.Metadata)) 238 | } 239 | if !slices.Equal(doc.Embedding, vectors) { 240 | t.Fatal("expected", vectors, "got", doc.Embedding) 241 | } 242 | if doc.Content != contents[i] { 243 | t.Fatal("expected", contents[i], "got", doc.Content) 244 | } 245 | } 246 | // Metadata can't be accessed with the loop's i 247 | if c.documents[ids[0]].Metadata["foo"] != "bar" { 248 | t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"]) 249 | } 250 | if c.documents[ids[1]].Metadata["a"] != "b" { 251 | t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"]) 252 | } 253 | }) 254 | } 255 | } 256 | 257 | func TestCollection_AddConcurrently_Error(t *testing.T) { 258 | ctx := context.Background() 259 | name := "test" 260 | metadata := map[string]string{"foo": "bar"} 261 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 262 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 263 | return vectors, nil 264 | } 265 | 266 | // Create collection 267 | db := NewDB() 268 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 269 | if err != nil { 270 | t.Fatal("expected no error, got", err) 271 | } 272 | if c == nil { 273 | t.Fatal("expected collection, got nil") 274 | } 275 | 276 | // Add documents, provoking errors 277 | ids := []string{"1", "2"} 278 | embeddings := [][]float32{vectors, vectors} 279 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} 280 | contents := []string{"hello world", "hallo welt"} 281 | // Empty IDs 282 | err = c.AddConcurrently(ctx, []string{}, embeddings, metadatas, contents, 2) 283 | if err == nil { 284 | t.Fatal("expected error, got nil") 285 | } 286 | // Empty embeddings and contents (both at the same time!) 
287 | err = c.AddConcurrently(ctx, ids, [][]float32{}, metadatas, []string{}, 2) 288 | if err == nil { 289 | t.Fatal("expected error, got nil") 290 | } 291 | // Bad embeddings length 292 | err = c.AddConcurrently(ctx, ids, [][]float32{vectors}, metadatas, contents, 2) 293 | if err == nil { 294 | t.Fatal("expected error, got nil") 295 | } 296 | // Bad metadatas length 297 | err = c.AddConcurrently(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents, 2) 298 | if err == nil { 299 | t.Fatal("expected error, got nil") 300 | } 301 | // Bad contents length 302 | err = c.AddConcurrently(ctx, ids, embeddings, metadatas, []string{"hello world"}, 2) 303 | if err == nil { 304 | t.Fatal("expected error, got nil") 305 | } 306 | // Bad concurrency 307 | err = c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 0) 308 | if err == nil { 309 | t.Fatal("expected error, got nil") 310 | } 311 | } 312 | 313 | func TestCollection_QueryError(t *testing.T) { 314 | // Create collection 315 | db := NewDB() 316 | name := "test" 317 | metadata := map[string]string{"foo": "bar"} 318 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 319 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 320 | return vectors, nil 321 | } 322 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 323 | if err != nil { 324 | t.Fatal("expected no error, got", err) 325 | } 326 | if c == nil { 327 | t.Fatal("expected collection, got nil") 328 | } 329 | // Add a document 330 | err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"}) 331 | if err != nil { 332 | t.Fatal("expected nil, got", err) 333 | } 334 | 335 | tt := []struct { 336 | name string 337 | query func() error 338 | expErr string 339 | }{ 340 | { 341 | name: "Empty query", 342 | query: func() error { 343 | _, err := c.Query(context.Background(), "", 1, nil, nil) 344 | return err 345 | }, 346 | expErr: "queryText is empty", 347 | }, 348 | { 349 | name: "Negative limit", 350 | query: func() error { 351 | _, err := c.Query(context.Background(), "foo", -1, nil, nil) 352 | return err 353 | }, 354 | expErr: "nResults must be > 0", 355 | }, 356 | { 357 | name: "Zero limit", 358 | query: func() error { 359 | _, err := c.Query(context.Background(), "foo", 0, nil, nil) 360 | return err 361 | }, 362 | expErr: "nResults must be > 0", 363 | }, 364 | { 365 | name: "Limit greater than number of documents", 366 | query: func() error { 367 | _, err := c.Query(context.Background(), "foo", 2, nil, nil) 368 | return err 369 | }, 370 | expErr: "nResults must be <= the number of documents in the collection", 371 | }, 372 | { 373 | name: "Bad content filter", 374 | query: func() error { 375 | _, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"}) 376 | return err 377 | }, 378 | expErr: "unsupported operator", 379 | }, 380 | } 381 | 382 | for _, tc := range tt { 383 | t.Run(tc.name, func(t *testing.T) { 384 | err := tc.query() 385 | if err == nil { 386 | t.Fatal("expected error, got nil") 387 | } else if err.Error() != tc.expErr { 388 | t.Fatal("expected", tc.expErr, "got", err) 389 | } 390 | }) 391 | } 392 | } 393 | 394 | func TestCollection_Get(t *testing.T) { 395 | ctx := context.Background() 396 | 397 | // Create collection 398 | db := NewDB() 399 | name := "test" 400 | metadata := map[string]string{"foo": "bar"} 401 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 402 | 
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 403 | return vectors, nil 404 | } 405 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 406 | if err != nil { 407 | t.Fatal("expected no error, got", err) 408 | } 409 | if c == nil { 410 | t.Fatal("expected collection, got nil") 411 | } 412 | 413 | // Add documents 414 | ids := []string{"1", "2"} 415 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} 416 | contents := []string{"hello world", "hallo welt"} 417 | err = c.Add(context.Background(), ids, nil, metadatas, contents) 418 | if err != nil { 419 | t.Fatal("expected nil, got", err) 420 | } 421 | 422 | // Get by ID 423 | doc, err := c.GetByID(ctx, ids[0]) 424 | if err != nil { 425 | t.Fatal("expected nil, got", err) 426 | } 427 | // Check fields 428 | if doc.ID != ids[0] { 429 | t.Fatal("expected", ids[0], "got", doc.ID) 430 | } 431 | if len(doc.Metadata) != 1 { 432 | t.Fatal("expected 1, got", len(doc.Metadata)) 433 | } 434 | if !slices.Equal(doc.Embedding, vectors) { 435 | t.Fatal("expected", vectors, "got", doc.Embedding) 436 | } 437 | if doc.Content != contents[0] { 438 | t.Fatal("expected", contents[0], "got", doc.Content) 439 | } 440 | 441 | // Check error 442 | _, err = c.GetByID(ctx, "3") 443 | if err == nil { 444 | t.Fatal("expected error, got nil") 445 | } 446 | } 447 | 448 | func TestCollection_Count(t *testing.T) { 449 | // Create collection 450 | db := NewDB() 451 | name := "test" 452 | metadata := map[string]string{"foo": "bar"} 453 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 454 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 455 | return vectors, nil 456 | } 457 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 458 | if err != nil { 459 | t.Fatal("expected no error, got", err) 460 | } 461 | if c == nil { 462 | t.Fatal("expected collection, got nil") 463 | } 464 | 465 | // Add documents 466 | ids := []string{"1", "2"} 467 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} 468 | contents := []string{"hello world", "hallo welt"} 469 | err = c.Add(context.Background(), ids, nil, metadatas, contents) 470 | if err != nil { 471 | t.Fatal("expected nil, got", err) 472 | } 473 | 474 | // Check count 475 | if c.Count() != 2 { 476 | t.Fatal("expected 2, got", c.Count()) 477 | } 478 | } 479 | 480 | func TestCollection_Delete(t *testing.T) { 481 | // Create persistent collection 482 | tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*") 483 | if err != nil { 484 | t.Fatal("expected no error, got", err) 485 | } 486 | db, err := NewPersistentDB(tmpdir, false) 487 | if err != nil { 488 | t.Fatal("expected no error, got", err) 489 | } 490 | name := "test" 491 | metadata := map[string]string{"foo": "bar"} 492 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 493 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 494 | return vectors, nil 495 | } 496 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 497 | if err != nil { 498 | t.Fatal("expected no error, got", err) 499 | } 500 | if c == nil { 501 | t.Fatal("expected collection, got nil") 502 | } 503 | 504 | // Add documents 505 | ids := []string{"1", "2", "3", "4"} 506 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}, {"foo": "bar"}, {"e": "f"}} 507 | contents := []string{"hello world", "hallo welt", "bonjour le monde", "hola mundo"} 508 | err = c.Add(context.Background(), ids, 
nil, metadatas, contents) 509 | if err != nil { 510 | t.Fatal("expected nil, got", err) 511 | } 512 | 513 | // Check count 514 | if c.Count() != 4 { 515 | t.Fatal("expected 4 documents, got", c.Count()) 516 | } 517 | 518 | // Check number of files in the persist directory 519 | d, err := os.ReadDir(c.persistDirectory) 520 | if err != nil { 521 | t.Fatal("expected nil, got", err) 522 | } 523 | if len(d) != 5 { // 4 documents + 1 metadata file 524 | t.Fatal("expected 4 document files + 1 metadata file in persist_dir, got", len(d)) 525 | } 526 | 527 | checkCount := func(expected int) { 528 | // Check count 529 | if c.Count() != expected { 530 | t.Fatalf("expected %d documents, got %d", expected, c.Count()) 531 | } 532 | 533 | // Check number of files in the persist directory 534 | d, err = os.ReadDir(c.persistDirectory) 535 | if err != nil { 536 | t.Fatal("expected nil, got", err) 537 | } 538 | if len(d) != expected+1 { // 3 document + 1 metadata file 539 | t.Fatalf("expected %d document files + 1 metadata file in persist_dir, got %d", expected, len(d)) 540 | } 541 | } 542 | 543 | // Test 1 - Remove document by ID: should delete one document 544 | err = c.Delete(context.Background(), nil, nil, "4") 545 | if err != nil { 546 | t.Fatal("expected nil, got", err) 547 | } 548 | checkCount(3) 549 | 550 | // Test 2 - Remove document by metadata 551 | err = c.Delete(context.Background(), map[string]string{"foo": "bar"}, nil) 552 | if err != nil { 553 | t.Fatal("expected nil, got", err) 554 | } 555 | 556 | checkCount(1) 557 | 558 | // Test 3 - Remove document by content 559 | err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"}) 560 | if err != nil { 561 | t.Fatal("expected nil, got", err) 562 | } 563 | 564 | checkCount(0) 565 | } 566 | 567 | // Global var for assignment in the benchmark to avoid compiler optimizations. 
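// Without such a sink the compiler could notice that the query result is never
// used and elide the benchmarked call, skewing the measurement; assigning the
// last result to a package-level variable keeps the work observable.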
568 | var globalRes []Result 569 | 570 | func BenchmarkCollection_Query_NoContent_100(b *testing.B) { 571 | benchmarkCollection_Query(b, 100, false) 572 | } 573 | 574 | func BenchmarkCollection_Query_NoContent_1000(b *testing.B) { 575 | benchmarkCollection_Query(b, 1000, false) 576 | } 577 | 578 | func BenchmarkCollection_Query_NoContent_5000(b *testing.B) { 579 | benchmarkCollection_Query(b, 5000, false) 580 | } 581 | 582 | func BenchmarkCollection_Query_NoContent_25000(b *testing.B) { 583 | benchmarkCollection_Query(b, 25000, false) 584 | } 585 | 586 | func BenchmarkCollection_Query_NoContent_100000(b *testing.B) { 587 | benchmarkCollection_Query(b, 100_000, false) 588 | } 589 | 590 | func BenchmarkCollection_Query_100(b *testing.B) { 591 | benchmarkCollection_Query(b, 100, true) 592 | } 593 | 594 | func BenchmarkCollection_Query_1000(b *testing.B) { 595 | benchmarkCollection_Query(b, 1000, true) 596 | } 597 | 598 | func BenchmarkCollection_Query_5000(b *testing.B) { 599 | benchmarkCollection_Query(b, 5000, true) 600 | } 601 | 602 | func BenchmarkCollection_Query_25000(b *testing.B) { 603 | benchmarkCollection_Query(b, 25000, true) 604 | } 605 | 606 | func BenchmarkCollection_Query_100000(b *testing.B) { 607 | benchmarkCollection_Query(b, 100_000, true) 608 | } 609 | 610 | // n is number of documents in the collection 611 | func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { 612 | ctx := context.Background() 613 | 614 | // Seed to make deterministic 615 | r := rand.New(rand.NewSource(42)) 616 | 617 | d := 1536 // dimensions, same as text-embedding-3-small 618 | // Random query vector 619 | qv := make([]float32, d) 620 | for j := 0; j < d; j++ { 621 | qv[j] = r.Float32() 622 | } 623 | // The document embeddings are normalized, so the query must be normalized too. 624 | qv = normalizeVector(qv) 625 | 626 | // Create collection 627 | db := NewDB() 628 | name := "test" 629 | embeddingFunc := func(_ context.Context, text string) ([]float32, error) { 630 | return nil, errors.New("embedding func not expected to be called") 631 | } 632 | c, err := db.CreateCollection(name, nil, embeddingFunc) 633 | if err != nil { 634 | b.Fatal("expected no error, got", err) 635 | } 636 | if c == nil { 637 | b.Fatal("expected collection, got nil") 638 | } 639 | 640 | // Add documents 641 | for i := 0; i < n; i++ { 642 | // Random embedding 643 | v := make([]float32, d) 644 | for j := 0; j < d; j++ { 645 | v[j] = r.Float32() 646 | } 647 | v = normalizeVector(v) 648 | 649 | // Add document with some metadata and content depending on parameter. 650 | // When providing embeddings, the embedding func is not called. 651 | is := strconv.Itoa(i) 652 | doc := Document{ 653 | ID: is, 654 | Metadata: map[string]string{"i": is, "foo": "bar" + is}, 655 | Embedding: v, 656 | } 657 | if withContent { 658 | // Let's say we embed 500 tokens, that's ~375 words, ~1875 characters 659 | doc.Content = randomString(r, 1875) 660 | } 661 | 662 | if err := c.AddDocument(ctx, doc); err != nil { 663 | b.Fatal("expected nil, got", err) 664 | } 665 | } 666 | 667 | b.ResetTimer() 668 | 669 | // Query 670 | var res []Result 671 | for i := 0; i < b.N; i++ { 672 | res, err = c.QueryEmbedding(ctx, qv, 10, nil, nil) 673 | } 674 | if err != nil { 675 | b.Fatal("expected nil, got", err) 676 | } 677 | globalRes = res 678 | } 679 | 680 | // randomString returns a random string of length n using lowercase letters and space. 
681 | func randomString(r *rand.Rand, n int) string { 682 | // We add 5 spaces to get roughly one space every 5 characters 683 | characters := []rune("abcdefghijklmnopqrstuvwxyz ") 684 | 685 | b := make([]rune, n) 686 | for i := range b { 687 | b[i] = characters[r.Intn(len(characters))] 688 | } 689 | return string(b) 690 | } 691 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "context" 5 | "math/rand" 6 | "os" 7 | "path/filepath" 8 | "reflect" 9 | "slices" 10 | "testing" 11 | ) 12 | 13 | func TestNewPersistentDB(t *testing.T) { 14 | t.Run("Create directory", func(t *testing.T) { 15 | r := rand.New(rand.NewSource(rand.Int63())) 16 | randString := randomString(r, 10) 17 | path := filepath.Join(os.TempDir(), randString) 18 | defer os.RemoveAll(path) 19 | 20 | // Path shouldn't exist yet 21 | if _, err := os.Stat(path); !os.IsNotExist(err) { 22 | t.Fatal("expected path to not exist, got", err) 23 | } 24 | 25 | db, err := NewPersistentDB(path, false) 26 | if err != nil { 27 | t.Fatal("expected no error, got", err) 28 | } 29 | if db == nil { 30 | t.Fatal("expected DB, got nil") 31 | } 32 | 33 | // Path should exist now 34 | if _, err := os.Stat(path); err != nil { 35 | t.Fatal("expected path to exist, got", err) 36 | } 37 | }) 38 | t.Run("Existing directory", func(t *testing.T) { 39 | path, err := os.MkdirTemp(os.TempDir(), "") 40 | if err != nil { 41 | t.Fatal("couldn't create temp dir:", err) 42 | } 43 | defer os.RemoveAll(path) 44 | 45 | db, err := NewPersistentDB(path, false) 46 | if err != nil { 47 | t.Fatal("expected no error, got", err) 48 | } 49 | if db == nil { 50 | t.Fatal("expected DB, got nil") 51 | } 52 | }) 53 | } 54 | 55 | func TestNewPersistentDB_Errors(t *testing.T) { 56 | t.Run("Path is an existing file", func(t *testing.T) { 57 | f, err := os.CreateTemp(os.TempDir(), "") 58 | if err != nil { 59 | t.Fatal("couldn't create temp file:", err) 60 | } 61 | defer os.RemoveAll(f.Name()) 62 | 63 | _, err = NewPersistentDB(f.Name(), false) 64 | if err == nil { 65 | t.Fatal("expected error, got nil") 66 | } 67 | }) 68 | } 69 | 70 | func TestDB_ImportExport(t *testing.T) { 71 | r := rand.New(rand.NewSource(rand.Int63())) 72 | randString := randomString(r, 10) 73 | path := filepath.Join(os.TempDir(), randString) 74 | defer os.RemoveAll(path) 75 | 76 | // Values in the collection 77 | name := "test" 78 | metadata := map[string]string{"foo": "bar"} 79 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 80 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 81 | return vectors, nil 82 | } 83 | 84 | tt := []struct { 85 | name string 86 | filePath string 87 | compress bool 88 | encryptionKey string 89 | }{ 90 | { 91 | name: "gob", 92 | filePath: path + ".gob", 93 | compress: false, 94 | encryptionKey: "", 95 | }, 96 | { 97 | name: "gob compressed", 98 | filePath: path + ".gob.gz", 99 | compress: true, 100 | encryptionKey: "", 101 | }, 102 | { 103 | name: "gob compressed encrypted", 104 | filePath: path + ".gob.gz.enc", 105 | compress: true, 106 | encryptionKey: randomString(r, 32), 107 | }, 108 | { 109 | name: "gob encrypted", 110 | filePath: path + ".gob.enc", 111 | compress: false, 112 | encryptionKey: randomString(r, 32), 113 | }, 114 | } 115 | 116 | for _, tc := range tt { 117 | t.Run(tc.name, func(t *testing.T) { 118 | // Create DB, can just be 
in-memory 119 | origDB := NewDB() 120 | 121 | // Create collection 122 | c, err := origDB.CreateCollection(name, metadata, embeddingFunc) 123 | if err != nil { 124 | t.Fatal("expected no error, got", err) 125 | } 126 | if c == nil { 127 | t.Fatal("expected collection, got nil") 128 | } 129 | // Add document 130 | doc := Document{ 131 | ID: name, 132 | Metadata: metadata, 133 | Embedding: vectors, 134 | Content: "test", 135 | } 136 | err = c.AddDocument(context.Background(), doc) 137 | if err != nil { 138 | t.Fatal("expected no error, got", err) 139 | } 140 | 141 | // Export 142 | err = origDB.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey) 143 | if err != nil { 144 | t.Fatal("expected no error, got", err) 145 | } 146 | 147 | newDB := NewDB() 148 | 149 | // Import 150 | err = newDB.ImportFromFile(tc.filePath, tc.encryptionKey) 151 | if err != nil { 152 | t.Fatal("expected no error, got", err) 153 | } 154 | 155 | // Check expectations 156 | // We have to reset the embed function, but otherwise the DB objects 157 | // should be deep equal. 158 | c.embed = nil 159 | if !reflect.DeepEqual(origDB, newDB) { 160 | t.Fatalf("expected DB %+v, got %+v", origDB, newDB) 161 | } 162 | }) 163 | } 164 | } 165 | 166 | func TestDB_ImportExportSpecificCollections(t *testing.T) { 167 | r := rand.New(rand.NewSource(rand.Int63())) 168 | randString := randomString(r, 10) 169 | path := filepath.Join(os.TempDir(), randString) 170 | filePath := path + ".gob" 171 | defer os.RemoveAll(path) 172 | 173 | // Values in the collection 174 | name := "test" 175 | name2 := "test2" 176 | metadata := map[string]string{"foo": "bar"} 177 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 178 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 179 | return vectors, nil 180 | } 181 | 182 | // Create DB, can just be in-memory 183 | origDB := NewDB() 184 | 185 | // Create collections 186 | c, err := origDB.CreateCollection(name, metadata, embeddingFunc) 187 | if err != nil { 188 | t.Fatal("expected no error, got", err) 189 | } 190 | 191 | c2, err := origDB.CreateCollection(name2, metadata, embeddingFunc) 192 | if err != nil { 193 | t.Fatal("expected no error, got", err) 194 | } 195 | 196 | // Add documents 197 | doc := Document{ 198 | ID: name, 199 | Metadata: metadata, 200 | Embedding: vectors, 201 | Content: "test", 202 | } 203 | 204 | doc2 := Document{ 205 | ID: name2, 206 | Metadata: metadata, 207 | Embedding: vectors, 208 | Content: "test2", 209 | } 210 | 211 | err = c.AddDocument(context.Background(), doc) 212 | if err != nil { 213 | t.Fatal("expected no error, got", err) 214 | } 215 | 216 | err = c2.AddDocument(context.Background(), doc2) 217 | if err != nil { 218 | t.Fatal("expected no error, got", err) 219 | } 220 | 221 | // Export only one of the two collections 222 | err = origDB.ExportToFile(filePath, false, "", name2) 223 | if err != nil { 224 | t.Fatal("expected no error, got", err) 225 | } 226 | 227 | dir := filepath.Join(path, randomString(r, 10)) 228 | defer os.RemoveAll(dir) 229 | 230 | // Instead of importing to an in-memory DB we use a persistent one to cover the behavior of immediate persistent files being created for the imported data 231 | newPDB, err := NewPersistentDB(dir, false) 232 | if err != nil { 233 | t.Fatal("expected no error, got", err) 234 | } 235 | 236 | err = newPDB.ImportFromFile(filePath, "") 237 | if err != nil { 238 | t.Fatal("expected no error, got", err) 239 | } 240 | 241 | if len(newPDB.collections) != 1 { 
242 | t.Fatalf("expected 1 collection, got %d", len(newPDB.collections)) 243 | } 244 | 245 | // Make sure that the imported documents are actually persisted on disk 246 | for _, col := range newPDB.collections { 247 | for _, d := range col.documents { 248 | _, err = os.Stat(col.getDocPath(d.ID)) 249 | if err != nil { 250 | t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err) 251 | } 252 | } 253 | } 254 | 255 | // Now export both collections and import them into the same persistent DB (overwriting the one we just imported) 256 | filePath2 := filepath.Join(path, "2.gob") 257 | err = origDB.ExportToFile(filePath2, false, "") 258 | if err != nil { 259 | t.Fatal("expected no error, got", err) 260 | } 261 | 262 | err = newPDB.ImportFromFile(filePath2, "") 263 | if err != nil { 264 | t.Fatal("expected no error, got", err) 265 | } 266 | 267 | if len(newPDB.collections) != 2 { 268 | t.Fatalf("expected 2 collections, got %d", len(newPDB.collections)) 269 | } 270 | 271 | // Make sure that the imported documents are actually persisted on disk 272 | for _, col := range newPDB.collections { 273 | for _, d := range col.documents { 274 | _, err = os.Stat(col.getDocPath(d.ID)) 275 | if err != nil { 276 | t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err) 277 | } 278 | } 279 | } 280 | } 281 | 282 | func TestDB_CreateCollection(t *testing.T) { 283 | // Values in the collection 284 | name := "test" 285 | metadata := map[string]string{"foo": "bar"} 286 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 287 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 288 | return vectors, nil 289 | } 290 | 291 | db := NewDB() 292 | 293 | t.Run("OK", func(t *testing.T) { 294 | c, err := db.CreateCollection(name, metadata, embeddingFunc) 295 | if err != nil { 296 | t.Fatal("expected no error, got", err) 297 | } 298 | if c == nil { 299 | t.Fatal("expected collection, got nil") 300 | } 301 | 302 | // Check expectations 303 | 304 | // DB should have one collection now 305 | if len(db.collections) != 1 { 306 | t.Fatal("expected 1 collection, got", len(db.collections)) 307 | } 308 | // The collection should be the one we just created 309 | c2, ok := db.collections[name] 310 | if !ok { 311 | t.Fatal("expected collection", name, "not found") 312 | } 313 | // Check the embedding function first, then the rest with DeepEqual 314 | gotVectors, err := c.embed(context.Background(), "test") 315 | if err != nil { 316 | t.Fatal("expected no error, got", err) 317 | } 318 | if !slices.Equal(gotVectors, vectors) { 319 | t.Fatal("expected vectors", vectors, "got", gotVectors) 320 | } 321 | c.embed, c2.embed = nil, nil 322 | if !reflect.DeepEqual(c, c2) { 323 | t.Fatalf("expected collection %+v, got %+v", c, c2) 324 | } 325 | }) 326 | 327 | t.Run("NOK - Empty name", func(t *testing.T) { 328 | _, err := db.CreateCollection("", metadata, embeddingFunc) 329 | if err == nil { 330 | t.Fatal("expected error, got nil") 331 | } 332 | }) 333 | } 334 | 335 | func TestDB_ListCollections(t *testing.T) { 336 | // Values in the collection 337 | name := "test" 338 | metadata := map[string]string{"foo": "bar"} 339 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 340 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 341 | return vectors, nil 342 | } 343 | 344 | // Create initial collection 345 | db := NewDB() 346 | orig, err := 
db.CreateCollection(name, metadata, embeddingFunc) 347 | if err != nil { 348 | t.Fatal("expected no error, got", err) 349 | } 350 | 351 | // List collections 352 | res := db.ListCollections() 353 | 354 | // Check expectations 355 | 356 | // Should've returned a map with one collection 357 | if len(res) != 1 { 358 | t.Fatal("expected 1 collection, got", len(res)) 359 | } 360 | // The collection should be the one we just created 361 | c, ok := res[name] 362 | if !ok { 363 | t.Fatal("expected collection", name, "not found") 364 | } 365 | // Check the embedding function first, then the rest with DeepEqual 366 | gotVectors, err := c.embed(context.Background(), "test") 367 | if err != nil { 368 | t.Fatal("expected no error, got", err) 369 | } 370 | if !slices.Equal(gotVectors, vectors) { 371 | t.Fatal("expected vectors", vectors, "got", gotVectors) 372 | } 373 | orig.embed, c.embed = nil, nil 374 | if !reflect.DeepEqual(orig, c) { 375 | t.Fatalf("expected collection %+v, got %+v", orig, c) 376 | } 377 | 378 | // And it should be a copy. Adding a value here should not reflect on the DB's 379 | // collection. 380 | res["foo"] = &Collection{} 381 | if len(db.ListCollections()) != 1 { 382 | t.Fatal("expected 1 collection, got", len(db.ListCollections())) 383 | } 384 | } 385 | 386 | func TestDB_GetCollection(t *testing.T) { 387 | // Values in the collection 388 | name := "test" 389 | metadata := map[string]string{"foo": "bar"} 390 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 391 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 392 | return vectors, nil 393 | } 394 | 395 | // Create initial collection 396 | db := NewDB() 397 | orig, err := db.CreateCollection(name, metadata, embeddingFunc) 398 | if err != nil { 399 | t.Fatal("expected no error, got", err) 400 | } 401 | 402 | // Get collection 403 | c := db.GetCollection(name, nil) 404 | 405 | // Check the embedding function first, then the rest with DeepEqual 406 | gotVectors, err := c.embed(context.Background(), "test") 407 | if err != nil { 408 | t.Fatal("expected no error, got", err) 409 | } 410 | if !slices.Equal(gotVectors, vectors) { 411 | t.Fatal("expected vectors", vectors, "got", gotVectors) 412 | } 413 | orig.embed, c.embed = nil, nil 414 | if !reflect.DeepEqual(orig, c) { 415 | t.Fatalf("expected collection %+v, got %+v", orig, c) 416 | } 417 | } 418 | 419 | func TestDB_GetOrCreateCollection(t *testing.T) { 420 | // Values in the collection 421 | name := "test" 422 | metadata := map[string]string{"foo": "bar"} 423 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 424 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 425 | return vectors, nil 426 | } 427 | 428 | t.Run("Get", func(t *testing.T) { 429 | // Create initial collection 430 | db := NewDB() 431 | // Create collection so that the GetOrCreateCollection() call below only 432 | // gets it. 433 | orig, err := db.CreateCollection(name, metadata, embeddingFunc) 434 | if err != nil { 435 | t.Fatal("expected no error, got", err) 436 | } 437 | 438 | // Call GetOrCreateCollection() with the same name to only get it. We pass 439 | // nil for the metadata and embeddingFunc, so we can check that the returned 440 | // collection is the original one, and not a new one. 
441 | c, err := db.GetOrCreateCollection(name, nil, nil) 442 | if err != nil { 443 | t.Fatal("expected no error, got", err) 444 | } 445 | if c == nil { 446 | t.Fatal("expected collection, got nil") 447 | } 448 | 449 | // Check the embedding function first, then the rest with DeepEqual 450 | gotVectors, err := c.embed(context.Background(), "test") 451 | if err != nil { 452 | t.Fatal("expected no error, got", err) 453 | } 454 | if !slices.Equal(gotVectors, vectors) { 455 | t.Fatal("expected vectors", vectors, "got", gotVectors) 456 | } 457 | orig.embed, c.embed = nil, nil 458 | if !reflect.DeepEqual(orig, c) { 459 | t.Fatalf("expected collection %+v, got %+v", orig, c) 460 | } 461 | }) 462 | 463 | t.Run("Create", func(t *testing.T) { 464 | // Create initial collection 465 | db := NewDB() 466 | 467 | // Call GetOrCreateCollection() 468 | c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc) 469 | if err != nil { 470 | t.Fatal("expected no error, got", err) 471 | } 472 | if c == nil { 473 | t.Fatal("expected collection, got nil") 474 | } 475 | 476 | // Check like we check CreateCollection() 477 | c2, ok := db.collections[name] 478 | if !ok { 479 | t.Fatal("expected collection", name, "not found") 480 | } 481 | gotVectors, err := c.embed(context.Background(), "test") 482 | if err != nil { 483 | t.Fatal("expected no error, got", err) 484 | } 485 | if !slices.Equal(gotVectors, vectors) { 486 | t.Fatal("expected vectors", vectors, "got", gotVectors) 487 | } 488 | c.embed, c2.embed = nil, nil 489 | if !reflect.DeepEqual(c, c2) { 490 | t.Fatalf("expected collection %+v, got %+v", c, c2) 491 | } 492 | }) 493 | } 494 | 495 | func TestDB_DeleteCollection(t *testing.T) { 496 | // Values in the collection 497 | name := "test" 498 | metadata := map[string]string{"foo": "bar"} 499 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 500 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 501 | return vectors, nil 502 | } 503 | 504 | // Create initial collection 505 | db := NewDB() 506 | // We ignore the return value. CreateCollection is tested elsewhere. 507 | _, err := db.CreateCollection(name, metadata, embeddingFunc) 508 | if err != nil { 509 | t.Fatal("expected no error, got", err) 510 | } 511 | 512 | // Delete collection 513 | if err := db.DeleteCollection(name); err != nil { 514 | t.Fatal("expected no error, got", err) 515 | } 516 | 517 | // Check expectations 518 | // We don't have access to the documents field, but we can rely on DB.ListCollections() 519 | // because it's tested elsewhere. 520 | if len(db.ListCollections()) != 0 { 521 | t.Fatal("expected 0 collections, got", len(db.ListCollections())) 522 | } 523 | // Also check internally 524 | if len(db.collections) != 0 { 525 | t.Fatal("expected 0 collections, got", len(db.collections)) 526 | } 527 | } 528 | 529 | func TestDB_Reset(t *testing.T) { 530 | // Values in the collection 531 | name := "test" 532 | metadata := map[string]string{"foo": "bar"} 533 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 534 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 535 | return vectors, nil 536 | } 537 | 538 | // Create initial collection 539 | db := NewDB() 540 | // We ignore the return value. CreateCollection is tested elsewhere. 
541 | _, err := db.CreateCollection(name, metadata, embeddingFunc) 542 | if err != nil { 543 | t.Fatal("expected no error, got", err) 544 | } 545 | 546 | // Reset DB 547 | if err := db.Reset(); err != nil { 548 | t.Fatal("expected no error, got", err) 549 | } 550 | 551 | // Check expectations 552 | // We don't have access to the documents field, but we can rely on DB.ListCollections() 553 | // because it's tested elsewhere. 554 | if len(db.ListCollections()) != 0 { 555 | t.Fatal("expected 0 collections, got", len(db.ListCollections())) 556 | } 557 | // Also check internally 558 | if len(db.collections) != 0 { 559 | t.Fatal("expected 0 collections, got", len(db.collections)) 560 | } 561 | } 562 | -------------------------------------------------------------------------------- /document.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | // Document represents a single document. 9 | type Document struct { 10 | ID string 11 | Metadata map[string]string 12 | Embedding []float32 13 | Content string 14 | 15 | // ⚠️ When adding unexported fields here, consider adding a persistence struct 16 | // version of this in [DB.Export] and [DB.Import]. 17 | } 18 | 19 | // NewDocument creates a new document, including its embeddings. 20 | // Metadata is optional. 21 | // If the embeddings are not provided, they are created using the embedding function. 22 | // You can leave the content empty if you only want to store embeddings. 23 | // If embeddingFunc is nil, the default embedding function is used. 24 | // 25 | // If you want to create a document without embeddings, for example to let [Collection.AddDocuments] 26 | // create them concurrently, you can create a document with `chromem.Document{...}` 27 | // instead of using this constructor. 
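//
// An illustrative sketch of using the constructor (the ID, metadata, content
// and the `collection` variable are made-up examples, not taken from the docs):
//
//	doc, err := chromem.NewDocument(ctx, "sky-1",
//		map[string]string{"category": "facts"},
//		nil, // no embedding provided, so it's created by the embedding func
//		"The sky is blue because of Rayleigh scattering.",
//		nil, // nil embeddingFunc falls back to NewEmbeddingFuncDefault()
//	)
//	if err != nil {
//		// handle error
//	}
//	_ = collection.AddDocument(ctx, doc)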
28 | func NewDocument(ctx context.Context, id string, metadata map[string]string, embedding []float32, content string, embeddingFunc EmbeddingFunc) (Document, error) { 29 | if id == "" { 30 | return Document{}, errors.New("id is empty") 31 | } 32 | if len(embedding) == 0 && content == "" { 33 | return Document{}, errors.New("either embedding or content must be filled") 34 | } 35 | if embeddingFunc == nil { 36 | embeddingFunc = NewEmbeddingFuncDefault() 37 | } 38 | 39 | if len(embedding) == 0 { 40 | var err error 41 | embedding, err = embeddingFunc(ctx, content) 42 | if err != nil { 43 | return Document{}, err 44 | } 45 | } 46 | 47 | return Document{ 48 | ID: id, 49 | Metadata: metadata, 50 | Embedding: embedding, 51 | Content: content, 52 | }, nil 53 | } 54 | -------------------------------------------------------------------------------- /document_test.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestDocument_New(t *testing.T) { 10 | ctx := context.Background() 11 | id := "test" 12 | metadata := map[string]string{"foo": "bar"} 13 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 14 | content := "hello world" 15 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { 16 | return vectors, nil 17 | } 18 | 19 | tt := []struct { 20 | name string 21 | id string 22 | metadata map[string]string 23 | vectors []float32 24 | content string 25 | embeddingFunc EmbeddingFunc 26 | }{ 27 | { 28 | name: "No embedding", 29 | id: id, 30 | metadata: metadata, 31 | vectors: nil, 32 | content: content, 33 | embeddingFunc: embeddingFunc, 34 | }, 35 | { 36 | name: "With embedding", 37 | id: id, 38 | metadata: metadata, 39 | vectors: vectors, 40 | content: content, 41 | embeddingFunc: embeddingFunc, 42 | }, 43 | } 44 | 45 | for _, tc := range tt { 46 | t.Run(tc.name, func(t *testing.T) { 47 | // Create document 48 | d, err := NewDocument(ctx, id, metadata, vectors, content, embeddingFunc) 49 | if err != nil { 50 | t.Fatal("expected no error, got", err) 51 | } 52 | // We can compare with DeepEqual after removing the embedding function 53 | d.Embedding = nil 54 | exp := Document{ 55 | ID: id, 56 | Metadata: metadata, 57 | Content: content, 58 | } 59 | if !reflect.DeepEqual(exp, d) { 60 | t.Fatalf("expected %+v, got %+v", exp, d) 61 | } 62 | }) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /embed_cohere.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "strings" 12 | "sync" 13 | ) 14 | 15 | type EmbeddingModelCohere string 16 | 17 | const ( 18 | EmbeddingModelCohereMultilingualV2 EmbeddingModelCohere = "embed-multilingual-v2.0" 19 | EmbeddingModelCohereEnglishLightV2 EmbeddingModelCohere = "embed-english-light-v2.0" 20 | EmbeddingModelCohereEnglishV2 EmbeddingModelCohere = "embed-english-v2.0" 21 | 22 | EmbeddingModelCohereMultilingualLightV3 EmbeddingModelCohere = "embed-multilingual-light-v3.0" 23 | EmbeddingModelCohereEnglishLightV3 EmbeddingModelCohere = "embed-english-light-v3.0" 24 | EmbeddingModelCohereMultilingualV3 EmbeddingModelCohere = "embed-multilingual-v3.0" 25 | EmbeddingModelCohereEnglishV3 EmbeddingModelCohere = "embed-english-v3.0" 26 | ) 27 | 28 | // Prefixes for 
external use. 29 | const ( 30 | InputTypeCohereSearchDocumentPrefix string = "search_document: " 31 | InputTypeCohereSearchQueryPrefix string = "search_query: " 32 | InputTypeCohereClassificationPrefix string = "classification: " 33 | InputTypeCohereClusteringPrefix string = "clustering: " 34 | ) 35 | 36 | // Input types for internal use. 37 | const ( 38 | inputTypeCohereSearchDocument string = "search_document" 39 | inputTypeCohereSearchQuery string = "search_query" 40 | inputTypeCohereClassification string = "classification" 41 | inputTypeCohereClustering string = "clustering" 42 | ) 43 | 44 | const baseURLCohere = "https://api.cohere.ai/v1" 45 | 46 | var validInputTypesCohere = map[string]string{ 47 | inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix, 48 | inputTypeCohereSearchQuery: InputTypeCohereSearchQueryPrefix, 49 | inputTypeCohereClassification: InputTypeCohereClassificationPrefix, 50 | inputTypeCohereClustering: InputTypeCohereClusteringPrefix, 51 | } 52 | 53 | type cohereResponse struct { 54 | Embeddings [][]float32 `json:"embeddings"` 55 | } 56 | 57 | // NewEmbeddingFuncCohere returns a function that creates embeddings for a text 58 | // using Cohere's API. One important difference to OpenAI's and other's APIs is 59 | // that Cohere differentiates between document embeddings and search/query embeddings. 60 | // In order for this embedding func to do the differentiation, you have to prepend 61 | // the text with either "search_document" or "search_query". We'll cut off that 62 | // prefix before sending the document/query body to the API, we'll just use the 63 | // prefix to choose the right "input type" as they call it. 64 | // 65 | // When you set up a chromem-go collection with this embedding function, you might 66 | // want to create the document separately with [NewDocument] and then cut off the 67 | // prefix before adding the document to the collection. Otherwise, when you query 68 | // the collection, the returned documents will still have the prefix in their content. 69 | // 70 | // cohereFunc := chromem.NewEmbeddingFuncCohere(cohereApiKey, chromem.EmbeddingModelCohereEnglishV3) 71 | // content := "The sky is blue because of Rayleigh scattering." 72 | // // Create the document with the prefix. 73 | // contentWithPrefix := chromem.InputTypeCohereSearchDocumentPrefix + content 74 | // doc, _ := NewDocument(ctx, id, metadata, nil, contentWithPrefix, cohereFunc) 75 | // // Remove the prefix so that later query results don't have it. 76 | // doc.Content = content 77 | // _ = collection.AddDocument(ctx, doc) 78 | // 79 | // This is not necessary if you don't keep the content in the documents, as chromem-go 80 | // also works when documents only have embeddings. 81 | // You can also keep the prefix in the document, and only remove it after querying. 82 | // 83 | // We plan to improve this in the future. 84 | func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) EmbeddingFunc { 85 | // We don't set a default timeout here, although it's usually a good idea. 86 | // In our case though, the library user can set the timeout on the context, 87 | // and it might have to be a long timeout, depending on the text length. 
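// A caller-side sketch of setting such a timeout on the context (the 60-second
// value and the `embedFunc` variable are assumptions, not recommendations):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
//	defer cancel()
//	embedding, err := embedFunc(ctx, "search_query: some text")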
88 | client := &http.Client{} 89 | 90 | var checkedNormalized bool 91 | checkNormalized := sync.Once{} 92 | 93 | return func(ctx context.Context, text string) ([]float32, error) { 94 | var inputType string 95 | for validInputType, validInputTypePrefix := range validInputTypesCohere { 96 | if strings.HasPrefix(text, validInputTypePrefix) { 97 | inputType = validInputType 98 | text = strings.TrimPrefix(text, validInputTypePrefix) 99 | break 100 | } 101 | } 102 | if inputType == "" { 103 | return nil, errors.New("text must start with a valid input type plus colon and space") 104 | } 105 | 106 | // Prepare the request body. 107 | reqBody, err := json.Marshal(map[string]any{ 108 | "model": model, 109 | "texts": []string{text}, 110 | "input_type": inputType, 111 | }) 112 | if err != nil { 113 | return nil, fmt.Errorf("couldn't marshal request body: %w", err) 114 | } 115 | 116 | // Create the request. Creating it with context is important for a timeout 117 | // to be possible, because the client is configured without a timeout. 118 | req, err := http.NewRequestWithContext(ctx, "POST", baseURLCohere+"/embed", bytes.NewBuffer(reqBody)) 119 | if err != nil { 120 | return nil, fmt.Errorf("couldn't create request: %w", err) 121 | } 122 | req.Header.Set("Accept", "application/json") 123 | req.Header.Set("Content-Type", "application/json") 124 | req.Header.Set("Authorization", "Bearer "+apiKey) 125 | 126 | // Send the request. 127 | resp, err := client.Do(req) 128 | if err != nil { 129 | return nil, fmt.Errorf("couldn't send request: %w", err) 130 | } 131 | defer resp.Body.Close() 132 | 133 | // Check the response status. 134 | if resp.StatusCode != http.StatusOK { 135 | return nil, errors.New("error response from the embedding API: " + resp.Status) 136 | } 137 | 138 | // Read and decode the response body. 139 | body, err := io.ReadAll(resp.Body) 140 | if err != nil { 141 | return nil, fmt.Errorf("couldn't read response body: %w", err) 142 | } 143 | var embeddingResponse cohereResponse 144 | err = json.Unmarshal(body, &embeddingResponse) 145 | if err != nil { 146 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err) 147 | } 148 | 149 | // Check if the response contains embeddings. 150 | if len(embeddingResponse.Embeddings) == 0 || len(embeddingResponse.Embeddings[0]) == 0 { 151 | return nil, errors.New("no embeddings found in the response") 152 | } 153 | 154 | v := embeddingResponse.Embeddings[0] 155 | checkNormalized.Do(func() { 156 | if isNormalized(v) { 157 | checkedNormalized = true 158 | } else { 159 | checkedNormalized = false 160 | } 161 | }) 162 | if !checkedNormalized { 163 | v = normalizeVector(v) 164 | } 165 | 166 | return v, nil 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /embed_compat.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | const ( 4 | baseURLMistral = "https://api.mistral.ai/v1" 5 | // Currently there's only one. Let's turn this into a pseudo-enum as soon as there are more. 6 | embeddingModelMistral = "mistral-embed" 7 | ) 8 | 9 | // NewEmbeddingFuncMistral returns a function that creates embeddings for a text 10 | // using the Mistral API. 11 | func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc { 12 | // Mistral embeddings are normalized, see section "Distance Measures" on 13 | // https://docs.mistral.ai/guides/embeddings/. 
14 | normalized := true 15 | 16 | // The Mistral API docs don't mention the `encoding_format` as optional, 17 | // but it seems to be, just like OpenAI. So we reuse the OpenAI function. 18 | return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized) 19 | } 20 | 21 | const baseURLJina = "https://api.jina.ai/v1" 22 | 23 | type EmbeddingModelJina string 24 | 25 | const ( 26 | EmbeddingModelJina2BaseEN EmbeddingModelJina = "jina-embeddings-v2-base-en" 27 | EmbeddingModelJina2BaseES EmbeddingModelJina = "jina-embeddings-v2-base-es" 28 | EmbeddingModelJina2BaseDE EmbeddingModelJina = "jina-embeddings-v2-base-de" 29 | EmbeddingModelJina2BaseZH EmbeddingModelJina = "jina-embeddings-v2-base-zh" 30 | 31 | EmbeddingModelJina2BaseCode EmbeddingModelJina = "jina-embeddings-v2-base-code" 32 | 33 | EmbeddingModelJinaClipV1 EmbeddingModelJina = "jina-clip-v1" 34 | ) 35 | 36 | // NewEmbeddingFuncJina returns a function that creates embeddings for a text 37 | // using the Jina API. 38 | func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc { 39 | return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil) 40 | } 41 | 42 | const baseURLMixedbread = "https://api.mixedbread.ai" 43 | 44 | type EmbeddingModelMixedbread string 45 | 46 | const ( 47 | // Possibly outdated / not available anymore 48 | EmbeddingModelMixedbreadUAELargeV1 EmbeddingModelMixedbread = "UAE-Large-V1" 49 | // Possibly outdated / not available anymore 50 | EmbeddingModelMixedbreadBGELargeENV15 EmbeddingModelMixedbread = "bge-large-en-v1.5" 51 | // Possibly outdated / not available anymore 52 | EmbeddingModelMixedbreadGTELarge EmbeddingModelMixedbread = "gte-large" 53 | // Possibly outdated / not available anymore 54 | EmbeddingModelMixedbreadE5LargeV2 EmbeddingModelMixedbread = "e5-large-v2" 55 | // Possibly outdated / not available anymore 56 | EmbeddingModelMixedbreadMultilingualE5Large EmbeddingModelMixedbread = "multilingual-e5-large" 57 | // Possibly outdated / not available anymore 58 | EmbeddingModelMixedbreadMultilingualE5Base EmbeddingModelMixedbread = "multilingual-e5-base" 59 | // Possibly outdated / not available anymore 60 | EmbeddingModelMixedbreadAllMiniLML6V2 EmbeddingModelMixedbread = "all-MiniLM-L6-v2" 61 | // Possibly outdated / not available anymore 62 | EmbeddingModelMixedbreadGTELargeZh EmbeddingModelMixedbread = "gte-large-zh" 63 | 64 | EmbeddingModelMixedbreadLargeV1 EmbeddingModelMixedbread = "mxbai-embed-large-v1" 65 | EmbeddingModelMixedbreadDeepsetDELargeV1 EmbeddingModelMixedbread = "deepset-mxbai-embed-de-large-v1" 66 | EmbeddingModelMixedbread2DLargeV1 EmbeddingModelMixedbread = "mxbai-embed-2d-large-v1" 67 | ) 68 | 69 | // NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text 70 | // using the mixedbread.ai API. 71 | func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc { 72 | return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil) 73 | } 74 | 75 | const baseURLLocalAI = "http://localhost:8080/v1" 76 | 77 | // NewEmbeddingFuncLocalAI returns a function that creates embeddings for a text 78 | // using the LocalAI API. 79 | // You can start a LocalAI instance like this: 80 | // 81 | // docker run -it -p 127.0.0.1:8080:8080 localai/localai:v2.7.0-ffmpeg-core bert-cpp 82 | // 83 | // And then call this constructor with model "bert-cpp-minilm-v6". 84 | // But other embedding models are supported as well. 
See the LocalAI documentation 85 | // for details. 86 | func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc { 87 | return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil) 88 | } 89 | 90 | const ( 91 | azureDefaultAPIVersion = "2024-02-01" 92 | ) 93 | 94 | // NewEmbeddingFuncAzureOpenAI returns a function that creates embeddings for a text 95 | // using the Azure OpenAI API. 96 | // The `deploymentURL` is the URL of the deployed model, e.g. "https://YOUR_RESOURCE_NAME.openai.azure.com/openai/deployments/YOUR_DEPLOYMENT_NAME" 97 | // See https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#how-to-get-embeddings 98 | func NewEmbeddingFuncAzureOpenAI(apiKey string, deploymentURL string, apiVersion string, model string) EmbeddingFunc { 99 | if apiVersion == "" { 100 | apiVersion = azureDefaultAPIVersion 101 | } 102 | return newEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, map[string]string{"api-key": apiKey}, map[string]string{"api-version": apiVersion}) 103 | } 104 | -------------------------------------------------------------------------------- /embed_ollama.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | "sync" 12 | ) 13 | 14 | const defaultBaseURLOllama = "http://localhost:11434/api" 15 | 16 | type ollamaResponse struct { 17 | Embedding []float32 `json:"embedding"` 18 | } 19 | 20 | // NewEmbeddingFuncOllama returns a function that creates embeddings for a text 21 | // using Ollama's embedding API. You can pass any model that Ollama supports and 22 | // that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text". 23 | // See https://ollama.com/library/nomic-embed-text 24 | // baseURLOllama is the base URL of the Ollama API. If it's empty, 25 | // "http://localhost:11434/api" is used. 26 | func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc { 27 | if baseURLOllama == "" { 28 | baseURLOllama = defaultBaseURLOllama 29 | } 30 | 31 | // We don't set a default timeout here, although it's usually a good idea. 32 | // In our case though, the library user can set the timeout on the context, 33 | // and it might have to be a long timeout, depending on the text length. 34 | client := &http.Client{} 35 | 36 | var checkedNormalized bool 37 | checkNormalized := sync.Once{} 38 | 39 | return func(ctx context.Context, text string) ([]float32, error) { 40 | // Prepare the request body. 41 | reqBody, err := json.Marshal(map[string]string{ 42 | "model": model, 43 | "prompt": text, 44 | }) 45 | if err != nil { 46 | return nil, fmt.Errorf("couldn't marshal request body: %w", err) 47 | } 48 | 49 | // Create the request. Creating it with context is important for a timeout 50 | // to be possible, because the client is configured without a timeout. 51 | req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embeddings", bytes.NewBuffer(reqBody)) 52 | if err != nil { 53 | return nil, fmt.Errorf("couldn't create request: %w", err) 54 | } 55 | req.Header.Set("Content-Type", "application/json") 56 | 57 | // Send the request. 58 | resp, err := client.Do(req) 59 | if err != nil { 60 | return nil, fmt.Errorf("couldn't send request: %w", err) 61 | } 62 | defer resp.Body.Close() 63 | 64 | // Check the response status. 
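// A non-OK status here usually just means the model hasn't been pulled into Ollama yet
// (e.g. via `ollama pull nomic-embed-text`) or the server isn't reachable at the configured base URL.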
65 | if resp.StatusCode != http.StatusOK { 66 | return nil, errors.New("error response from the embedding API: " + resp.Status) 67 | } 68 | 69 | // Read and decode the response body. 70 | body, err := io.ReadAll(resp.Body) 71 | if err != nil { 72 | return nil, fmt.Errorf("couldn't read response body: %w", err) 73 | } 74 | var embeddingResponse ollamaResponse 75 | err = json.Unmarshal(body, &embeddingResponse) 76 | if err != nil { 77 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err) 78 | } 79 | 80 | // Check if the response contains embeddings. 81 | if len(embeddingResponse.Embedding) == 0 { 82 | return nil, errors.New("no embeddings found in the response") 83 | } 84 | 85 | v := embeddingResponse.Embedding 86 | checkNormalized.Do(func() { 87 | if isNormalized(v) { 88 | checkedNormalized = true 89 | } else { 90 | checkedNormalized = false 91 | } 92 | }) 93 | if !checkedNormalized { 94 | v = normalizeVector(v) 95 | } 96 | 97 | return v, nil 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /embed_ollama_test.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "net/url" 11 | "slices" 12 | "strings" 13 | "testing" 14 | ) 15 | 16 | func TestNewEmbeddingFuncOllama(t *testing.T) { 17 | model := "model-small" 18 | baseURLSuffix := "/api" 19 | prompt := "hello world" 20 | 21 | wantBody, err := json.Marshal(map[string]string{ 22 | "model": model, 23 | "prompt": prompt, 24 | }) 25 | if err != nil { 26 | t.Fatal("unexpected error:", err) 27 | } 28 | wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 29 | 30 | // Mock server 31 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 32 | // Check URL 33 | if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") { 34 | t.Fatal("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path) 35 | } 36 | // Check method 37 | if r.Method != "POST" { 38 | t.Fatal("expected method POST, got", r.Method) 39 | } 40 | // Check headers 41 | if r.Header.Get("Content-Type") != "application/json" { 42 | t.Fatal("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type")) 43 | } 44 | // Check body 45 | body, err := io.ReadAll(r.Body) 46 | if err != nil { 47 | t.Fatal("unexpected error:", err) 48 | } 49 | if !bytes.Equal(body, wantBody) { 50 | t.Fatal("expected body", wantBody, "got", body) 51 | } 52 | 53 | // Write response 54 | resp := ollamaResponse{ 55 | Embedding: wantRes, 56 | } 57 | w.WriteHeader(http.StatusOK) 58 | _ = json.NewEncoder(w).Encode(resp) 59 | })) 60 | defer ts.Close() 61 | 62 | // Get port from URL 63 | u, err := url.Parse(ts.URL) 64 | if err != nil { 65 | t.Fatal("unexpected error:", err) 66 | } 67 | 68 | f := NewEmbeddingFuncOllama(model, strings.Replace(defaultBaseURLOllama, "11434", u.Port(), 1)) 69 | res, err := f(context.Background(), prompt) 70 | if err != nil { 71 | t.Fatal("expected nil, got", err) 72 | } 73 | if slices.Compare(wantRes, res) != 0 { 74 | t.Fatal("expected res", wantRes, "got", res) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /embed_openai.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 
8 | "fmt" 9 | "io" 10 | "net/http" 11 | "os" 12 | "sync" 13 | ) 14 | 15 | const BaseURLOpenAI = "https://api.openai.com/v1" 16 | 17 | type EmbeddingModelOpenAI string 18 | 19 | const ( 20 | EmbeddingModelOpenAI2Ada EmbeddingModelOpenAI = "text-embedding-ada-002" 21 | 22 | EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small" 23 | EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large" 24 | ) 25 | 26 | type openAIResponse struct { 27 | Data []struct { 28 | Embedding []float32 `json:"embedding"` 29 | } `json:"data"` 30 | } 31 | 32 | // NewEmbeddingFuncDefault returns a function that creates embeddings for a text 33 | // using OpenAI`s "text-embedding-3-small" model via their API. 34 | // The model supports a maximum text length of 8191 tokens. 35 | // The API key is read from the environment variable "OPENAI_API_KEY". 36 | func NewEmbeddingFuncDefault() EmbeddingFunc { 37 | apiKey := os.Getenv("OPENAI_API_KEY") 38 | return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small) 39 | } 40 | 41 | // NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text 42 | // using the OpenAI API. 43 | func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc { 44 | // OpenAI embeddings are normalized 45 | normalized := true 46 | return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized) 47 | } 48 | 49 | // NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text 50 | // using an OpenAI compatible API. For example: 51 | // - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service 52 | // - LitLLM: https://github.com/BerriAI/litellm 53 | // - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md 54 | // - etc. 55 | // 56 | // The `normalized` parameter indicates whether the vectors returned by the embedding 57 | // model are already normalized, as is the case for OpenAI's and Mistral's models. 58 | // The flag is optional. If it's nil, it will be autodetected on the first request 59 | // (which bears a small risk that the vector just happens to have a length of 1). 60 | func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc { 61 | return newEmbeddingFuncOpenAICompat(baseURL, apiKey, model, normalized, nil, nil) 62 | } 63 | 64 | // newEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text 65 | // using an OpenAI compatible API. 66 | // It offers options to set request headers and query parameters 67 | // e.g. to pass the `api-key` header and the `api-version` query parameter for Azure OpenAI. 68 | // 69 | // The `normalized` parameter indicates whether the vectors returned by the embedding 70 | // model are already normalized, as is the case for OpenAI's and Mistral's models. 71 | // The flag is optional. If it's nil, it will be autodetected on the first request 72 | // (which bears a small risk that the vector just happens to have a length of 1). 73 | func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, headers map[string]string, queryParams map[string]string) EmbeddingFunc { 74 | // We don't set a default timeout here, although it's usually a good idea. 75 | // In our case though, the library user can set the timeout on the context, 76 | // and it might have to be a long timeout, depending on the text length. 
77 | client := &http.Client{} 78 | 79 | var checkedNormalized bool 80 | checkNormalized := sync.Once{} 81 | 82 | return func(ctx context.Context, text string) ([]float32, error) { 83 | // Prepare the request body. 84 | reqBody, err := json.Marshal(map[string]string{ 85 | "input": text, 86 | "model": model, 87 | }) 88 | if err != nil { 89 | return nil, fmt.Errorf("couldn't marshal request body: %w", err) 90 | } 91 | 92 | // Create the request. Creating it with context is important for a timeout 93 | // to be possible, because the client is configured without a timeout. 94 | req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody)) 95 | if err != nil { 96 | return nil, fmt.Errorf("couldn't create request: %w", err) 97 | } 98 | req.Header.Set("Content-Type", "application/json") 99 | req.Header.Set("Authorization", "Bearer "+apiKey) 100 | 101 | // Add headers 102 | for k, v := range headers { 103 | req.Header.Add(k, v) 104 | } 105 | 106 | // Add query parameters 107 | q := req.URL.Query() 108 | for k, v := range queryParams { 109 | q.Add(k, v) 110 | } 111 | req.URL.RawQuery = q.Encode() 112 | 113 | // Send the request. 114 | resp, err := client.Do(req) 115 | if err != nil { 116 | return nil, fmt.Errorf("couldn't send request: %w", err) 117 | } 118 | defer resp.Body.Close() 119 | 120 | // Check the response status. 121 | if resp.StatusCode != http.StatusOK { 122 | return nil, errors.New("error response from the embedding API: " + resp.Status) 123 | } 124 | 125 | // Read and decode the response body. 126 | body, err := io.ReadAll(resp.Body) 127 | if err != nil { 128 | return nil, fmt.Errorf("couldn't read response body: %w", err) 129 | } 130 | var embeddingResponse openAIResponse 131 | err = json.Unmarshal(body, &embeddingResponse) 132 | if err != nil { 133 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err) 134 | } 135 | 136 | // Check if the response contains embeddings. 
137 | if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 { 138 | return nil, errors.New("no embeddings found in the response") 139 | } 140 | 141 | v := embeddingResponse.Data[0].Embedding 142 | if normalized != nil { 143 | if *normalized { 144 | return v, nil 145 | } 146 | return normalizeVector(v), nil 147 | } 148 | checkNormalized.Do(func() { 149 | if isNormalized(v) { 150 | checkedNormalized = true 151 | } else { 152 | checkedNormalized = false 153 | } 154 | }) 155 | if !checkedNormalized { 156 | v = normalizeVector(v) 157 | } 158 | 159 | return v, nil 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /embed_openai_test.go: -------------------------------------------------------------------------------- 1 | package chromem_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "net/http/httptest" 10 | "slices" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/philippgille/chromem-go" 15 | ) 16 | 17 | type openAIResponse struct { 18 | Data []struct { 19 | Embedding []float32 `json:"embedding"` 20 | } `json:"data"` 21 | } 22 | 23 | func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { 24 | apiKey := "secret" 25 | model := "model-small" 26 | baseURLSuffix := "/v1" 27 | input := "hello world" 28 | 29 | wantBody, err := json.Marshal(map[string]string{ 30 | "input": input, 31 | "model": model, 32 | }) 33 | if err != nil { 34 | t.Fatal("unexpected error:", err) 35 | } 36 | wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` 37 | 38 | // Mock server 39 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 40 | // Check URL 41 | if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") { 42 | t.Fatal("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path) 43 | } 44 | // Check method 45 | if r.Method != "POST" { 46 | t.Fatal("expected method POST, got", r.Method) 47 | } 48 | // Check headers 49 | if r.Header.Get("Authorization") != "Bearer "+apiKey { 50 | t.Fatal("expected Authorization header", "Bearer "+apiKey, "got", r.Header.Get("Authorization")) 51 | } 52 | if r.Header.Get("Content-Type") != "application/json" { 53 | t.Fatal("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type")) 54 | } 55 | // Check body 56 | body, err := io.ReadAll(r.Body) 57 | if err != nil { 58 | t.Fatal("unexpected error:", err) 59 | } 60 | if !bytes.Equal(body, wantBody) { 61 | t.Fatal("expected body", wantBody, "got", body) 62 | } 63 | 64 | // Write response 65 | resp := openAIResponse{ 66 | Data: []struct { 67 | Embedding []float32 `json:"embedding"` 68 | }{ 69 | {Embedding: wantRes}, 70 | }, 71 | } 72 | w.WriteHeader(http.StatusOK) 73 | _ = json.NewEncoder(w).Encode(resp) 74 | })) 75 | defer ts.Close() 76 | baseURL := ts.URL + baseURLSuffix 77 | 78 | f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil) 79 | res, err := f(context.Background(), input) 80 | if err != nil { 81 | t.Fatal("expected nil, got", err) 82 | } 83 | if slices.Compare(wantRes, res) != 0 { 84 | t.Fatal("expected res", wantRes, "got", res) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /embed_vertex.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | "net/http" 11 | 
"sync" 12 | ) 13 | 14 | type EmbeddingModelVertex string 15 | 16 | const ( 17 | EmbeddingModelVertexEnglishV1 EmbeddingModelVertex = "textembedding-gecko@001" 18 | EmbeddingModelVertexEnglishV2 EmbeddingModelVertex = "textembedding-gecko@002" 19 | EmbeddingModelVertexEnglishV3 EmbeddingModelVertex = "textembedding-gecko@003" 20 | EmbeddingModelVertexEnglishV4 EmbeddingModelVertex = "text-embedding-004" 21 | 22 | EmbeddingModelVertexMultilingualV1 EmbeddingModelVertex = "textembedding-gecko-multilingual@001" 23 | EmbeddingModelVertexMultilingualV2 EmbeddingModelVertex = "text-multilingual-embedding-002" 24 | ) 25 | 26 | const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1" 27 | 28 | type vertexOptions struct { 29 | apiEndpoint string 30 | autoTruncate bool 31 | } 32 | 33 | func defaultVertexOptions() *vertexOptions { 34 | return &vertexOptions{ 35 | apiEndpoint: baseURLVertex, 36 | autoTruncate: false, 37 | } 38 | } 39 | 40 | type VertexOption func(*vertexOptions) 41 | 42 | func WithVertexAPIEndpoint(apiEndpoint string) VertexOption { 43 | return func(o *vertexOptions) { 44 | o.apiEndpoint = apiEndpoint 45 | } 46 | } 47 | 48 | func WithVertexAutoTruncate(autoTruncate bool) VertexOption { 49 | return func(o *vertexOptions) { 50 | o.autoTruncate = autoTruncate 51 | } 52 | } 53 | 54 | type vertexResponse struct { 55 | Predictions []vertexPrediction `json:"predictions"` 56 | } 57 | 58 | type vertexPrediction struct { 59 | Embeddings vertexEmbeddings `json:"embeddings"` 60 | } 61 | 62 | type vertexEmbeddings struct { 63 | Values []float32 `json:"values"` 64 | // there's more here, but we only care about the embeddings 65 | } 66 | 67 | func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts ...VertexOption) EmbeddingFunc { 68 | cfg := defaultVertexOptions() 69 | for _, opt := range opts { 70 | opt(cfg) 71 | } 72 | 73 | if cfg.apiEndpoint == "" { 74 | cfg.apiEndpoint = baseURLVertex 75 | } 76 | 77 | // We don't set a default timeout here, although it's usually a good idea. 78 | // In our case though, the library user can set the timeout on the context, 79 | // and it might have to be a long timeout, depending on the text length. 80 | client := &http.Client{} 81 | 82 | var checkedNormalized bool 83 | checkNormalized := sync.Once{} 84 | 85 | return func(ctx context.Context, text string) ([]float32, error) { 86 | b := map[string]any{ 87 | "instances": []map[string]any{ 88 | { 89 | "content": text, 90 | }, 91 | }, 92 | "parameters": map[string]any{ 93 | "autoTruncate": cfg.autoTruncate, 94 | }, 95 | } 96 | 97 | // Prepare the request body. 98 | reqBody, err := json.Marshal(b) 99 | if err != nil { 100 | return nil, fmt.Errorf("couldn't marshal request body: %w", err) 101 | } 102 | 103 | fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.apiEndpoint, project, model) 104 | 105 | // Create the request. Creating it with context is important for a timeout 106 | // to be possible, because the client is configured without a timeout. 107 | req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewBuffer(reqBody)) 108 | if err != nil { 109 | return nil, fmt.Errorf("couldn't create request: %w", err) 110 | } 111 | req.Header.Set("Accept", "application/json") 112 | req.Header.Set("Content-Type", "application/json") 113 | req.Header.Set("Authorization", "Bearer "+apiKey) 114 | 115 | // Send the request. 
116 | resp, err := client.Do(req) 117 | if err != nil { 118 | return nil, fmt.Errorf("couldn't send request: %w", err) 119 | } 120 | defer resp.Body.Close() 121 | 122 | // Check the response status. 123 | if resp.StatusCode != http.StatusOK { 124 | return nil, errors.New("error response from the embedding API: " + resp.Status) 125 | } 126 | 127 | // Read and decode the response body. 128 | body, err := io.ReadAll(resp.Body) 129 | if err != nil { 130 | return nil, fmt.Errorf("couldn't read response body: %w", err) 131 | } 132 | var embeddingResponse vertexResponse 133 | err = json.Unmarshal(body, &embeddingResponse) 134 | if err != nil { 135 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err) 136 | } 137 | 138 | // Check if the response contains embeddings. 139 | if len(embeddingResponse.Predictions) == 0 || len(embeddingResponse.Predictions[0].Embeddings.Values) == 0 { 140 | return nil, errors.New("no embeddings found in the response") 141 | } 142 | 143 | v := embeddingResponse.Predictions[0].Embeddings.Values 144 | checkNormalized.Do(func() { 145 | if isNormalized(v) { 146 | checkedNormalized = true 147 | } else { 148 | checkedNormalized = false 149 | } 150 | }) 151 | if !checkedNormalized { 152 | v = normalizeVector(v) 153 | } 154 | 155 | return v, nil 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | 1. [Minimal example](minimal) 4 | - A minimal example with the least amount of code and no comments 5 | - Uses OpenAI for creating the embeddings 6 | 2. [RAG Wikipedia Ollama](rag-wikipedia-ollama) 7 | - This example shows a retrieval augmented generation (RAG) application, using `chromem-go` as knowledge base for finding relevant info for a question. More specifically the app is doing *question answering*. 8 | - The underlying data is 200 Wikipedia articles (or rather their lead section / introduction). 9 | - Runs the embeddings model and LLM in [Ollama](https://github.com/ollama/ollama), to showcase how a RAG application can run entirely offline, without relying on OpenAI or other third party APIs. 10 | 3. [Semantic search arXiv OpenAI](semantic-search-arxiv-openai) 11 | - This example shows a semantic search application, using `chromem-go` as vector database for finding semantically relevant search results. 12 | - Loads and searches across ~5,000 arXiv papers in the "Computer Science - Computation and Language" category, which is the relevant one for Natural Language Processing (NLP) related papers. 13 | - Uses OpenAI for creating the embeddings 14 | 4. [WebAssembly](webassembly) 15 | - This example shows how `chromem-go` can be compiled to WebAssembly and then used from JavaScript in a browser 16 | 5. [S3 Export/Import](s3-export-import) 17 | - This example shows how to export the DB to and import it from any S3-compatible blob storage service 18 | -------------------------------------------------------------------------------- /examples/minimal/README.md: -------------------------------------------------------------------------------- 1 | # Minimal example 2 | 3 | This is a minimal example that shows how `chromem-go` works. For more sophisticated examples that use the persistent DB, locally running embedding models, a basic RAG pipeline, with more explanations in the README as well as code comments, check the other examples! 4 | 5 | ## How to run 6 | 7 | 1. 
Set the OpenAI API key in your env as `OPENAI_API_KEY` 8 | 2. Run the example: `go run .` 9 | 10 | ## Output 11 | 12 | ```text 13 | ID: 1 14 | Similarity: 0.6833369 15 | Content: The sky is blue because of Rayleigh scattering. 16 | ``` 17 | -------------------------------------------------------------------------------- /examples/minimal/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/philippgille/chromem-go/examples/minimal 2 | 3 | go 1.21 4 | 5 | require github.com/philippgille/chromem-go v0.0.0 6 | 7 | replace github.com/philippgille/chromem-go => ./../.. 8 | -------------------------------------------------------------------------------- /examples/minimal/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "runtime" 7 | 8 | "github.com/philippgille/chromem-go" 9 | ) 10 | 11 | func main() { 12 | ctx := context.Background() 13 | 14 | db := chromem.NewDB() 15 | 16 | // Passing nil as embedding function leads to OpenAI being used and requires 17 | // "OPENAI_API_KEY" env var to be set. Other providers are supported as well. 18 | // For example pass `chromem.NewEmbeddingFuncOllama(...)` to use Ollama. 19 | c, err := db.CreateCollection("knowledge-base", nil, nil) 20 | if err != nil { 21 | panic(err) 22 | } 23 | 24 | err = c.AddDocuments(ctx, []chromem.Document{ 25 | { 26 | ID: "1", 27 | Content: "The sky is blue because of Rayleigh scattering.", 28 | }, 29 | { 30 | ID: "2", 31 | Content: "Leaves are green because chlorophyll absorbs red and blue light.", 32 | }, 33 | }, runtime.NumCPU()) 34 | if err != nil { 35 | panic(err) 36 | } 37 | 38 | res, err := c.Query(ctx, "Why is the sky blue?", 1, nil, nil) 39 | if err != nil { 40 | panic(err) 41 | } 42 | 43 | fmt.Printf("ID: %v\nSimilarity: %v\nContent: %v\n", res[0].ID, res[0].Similarity, res[0].Content) 44 | 45 | /* Output: 46 | ID: 1 47 | Similarity: 0.6833369 48 | Content: The sky is blue because of Rayleigh scattering. 49 | */ 50 | } 51 | -------------------------------------------------------------------------------- /examples/rag-wikipedia-ollama/.gitignore: -------------------------------------------------------------------------------- 1 | /db 2 | -------------------------------------------------------------------------------- /examples/rag-wikipedia-ollama/README.md: -------------------------------------------------------------------------------- 1 | # RAG Wikipedia Ollama 2 | 3 | This example shows a retrieval augmented generation (RAG) application, using `chromem-go` as knowledge base for finding relevant info for a question. More specifically the app is doing *question answering*. The underlying data is 200 Wikipedia articles (or rather their lead section / introduction). 4 | 5 | We run the embeddings model and LLM in [Ollama](https://github.com/ollama/ollama), to showcase how a RAG application can run entirely offline, without relying on OpenAI or other third party APIs. It doesn't require a GPU, and a CPU like an 11th Gen Intel i5-1135G7 (like in the first generation Framework Laptop 13) is fast enough. 6 | 7 | As LLM we use Google's [Gemma (2B)](https://huggingface.co/google/gemma-2b), a very small model that doesn't need many resources and is fast, but doesn't have much knowledge, so it's a prime example for the combination of LLMs and vector databases. 
We found Gemma 2B to be superior to [TinyLlama (1.1B)](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0), [Stable LM 2 (1.6B)](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) and [Phi-2 (2.7B)](https://huggingface.co/microsoft/phi-2) for the RAG use case. 8 | 9 | As embeddings model we use Nomic's [nomic-embed-text v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5). 10 | 11 | ## How to run 12 | 13 | 1. Install Ollama: 14 | 2. Download the two models: 15 | - `ollama pull gemma:2b` 16 | - `ollama pull nomic-embed-text` 17 | 3. Run the example: `go run .` 18 | 19 | ## Output 20 | 21 | The output can differ slightly on each run, but it's along the lines of: 22 | 23 | ```log 24 | 2024/03/02 20:02:30 Warming up Ollama... 25 | 2024/03/02 20:02:33 Question: When did the Monarch Company exist? 26 | 2024/03/02 20:02:33 Asking LLM... 27 | 2024/03/02 20:02:34 Initial reply from the LLM: "I cannot provide information on the Monarch Company, as I am unable to access real-time or comprehensive knowledge sources." 28 | 2024/03/02 20:02:34 Setting up chromem-go... 29 | 2024/03/02 20:02:34 Reading JSON lines... 30 | 2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API... 31 | 2024/03/02 20:03:11 Querying chromem-go... 32 | 2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms 33 | 2024/03/02 20:03:11 Document 1 (similarity: 0.655357): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch." 34 | 2024/03/02 20:03:11 Document 2 (similarity: 0.504042): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design." 35 | 2024/03/02 20:03:11 Asking LLM with augmented question... 36 | 2024/03/02 20:03:32 Reply after augmenting the question with knowledge: "The Monarch Company existed from 1896 to 1985." 37 | ``` 38 | 39 | The majority of the time here is spent during the embeddings creation, where we are limited by the performance of the Ollama API, which depends on your CPU/GPU and the embeddings model. 40 | 41 | ## OpenAI 42 | 43 | You can easily adapt the code to work with OpenAI instead of locally in Ollama. 44 | 45 | Add the OpenAI API key in your environment as `OPENAI_API_KEY`. 46 | 47 | Then, if you want to create the embeddings via OpenAI, but still use Gemma 2B as LLM: 48 | 49 |
Apply this patch 50 | 51 | ```diff 52 | diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go 53 | index 55b3076..cee9561 100644 54 | --- a/examples/rag-wikipedia-ollama/main.go 55 | +++ b/examples/rag-wikipedia-ollama/main.go 56 | @@ -14,8 +14,6 @@ import ( 57 | 58 | const ( 59 | question = "When did the Monarch Company exist?" 60 | - // We use a local LLM running in Ollama for the embedding: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5 61 | - embeddingModel = "nomic-embed-text" 62 | ) 63 | 64 | func main() { 65 | @@ -48,7 +46,7 @@ func main() { 66 | // variable to be set. 67 | // For this example we choose to use a locally running embedding model though. 68 | // It requires Ollama to serve its API at "http://localhost:11434/api". 69 | - collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel)) 70 | + collection, err := db.GetOrCreateCollection("Wikipedia", nil, nil) 71 | if err != nil { 72 | panic(err) 73 | } 74 | @@ -82,7 +80,7 @@ func main() { 75 | Content: article.Text, 76 | }) 77 | } 78 | - log.Println("Adding documents to chromem-go, including creating their embeddings via Ollama API...") 79 | + log.Println("Adding documents to chromem-go, including creating their embeddings via OpenAI API...") 80 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU()) 81 | if err != nil { 82 | panic(err) 83 | ``` 84 | 85 |
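The patch above simply falls back to the `nil` default embedding function. If you prefer to be explicit, the equivalent change is to pass the OpenAI embedding function directly (a sketch, assuming `OPENAI_API_KEY` is set and `os` is imported):

```go
collection, err := db.GetOrCreateCollection("Wikipedia", nil,
	chromem.NewEmbeddingFuncOpenAI(os.Getenv("OPENAI_API_KEY"), chromem.EmbeddingModelOpenAI3Small))
if err != nil {
	panic(err)
}
```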
86 | 87 | Or alternatively, if you want to use OpenAI for everything (embeddings creation and LLM): 88 | 89 |
Apply this patch 90 | 91 | ```diff 92 | diff --git a/examples/rag-wikipedia-ollama/llm.go b/examples/rag-wikipedia-ollama/llm.go 93 | index 1fde4ec..7cb81cc 100644 94 | --- a/examples/rag-wikipedia-ollama/llm.go 95 | +++ b/examples/rag-wikipedia-ollama/llm.go 96 | @@ -2,23 +2,13 @@ package main 97 | 98 | import ( 99 | "context" 100 | - "net/http" 101 | + "os" 102 | "strings" 103 | "text/template" 104 | 105 | "github.com/sashabaranov/go-openai" 106 | ) 107 | 108 | -const ( 109 | - // We use a local LLM running in Ollama for asking the question: https://github.com/ollama/ollama 110 | - ollamaBaseURL = "http://localhost:11434/v1" 111 | - // We use Google's Gemma (2B), a very small model that doesn't need much resources 112 | - // and is fast, but doesn't have much knowledge: https://huggingface.co/google/gemma-2b 113 | - // We found Gemma 2B to be superior to TinyLlama (1.1B), Stable LM 2 (1.6B) 114 | - // and Phi-2 (2.7B) for the retrieval augmented generation (RAG) use case. 115 | - llmModel = "gemma:2b" 116 | -) 117 | - 118 | // There are many different ways to provide the context to the LLM. 119 | // You can pass each context as user message, or the list as one user message, 120 | // or pass it in the system prompt. The system prompt itself also has a big impact 121 | @@ -47,10 +37,7 @@ Don't mention the knowledge base, context or search results in your answer. 122 | 123 | func askLLM(ctx context.Context, contexts []string, question string) string { 124 | // We can use the OpenAI client because Ollama is compatible with OpenAI's API. 125 | - openAIClient := openai.NewClientWithConfig(openai.ClientConfig{ 126 | - BaseURL: ollamaBaseURL, 127 | - HTTPClient: http.DefaultClient, 128 | - }) 129 | + openAIClient := openai.NewClient(os.Getenv("OPENAI_API_KEY")) 130 | sb := &strings.Builder{} 131 | err := systemPromptTpl.Execute(sb, contexts) 132 | if err != nil { 133 | @@ -66,7 +53,7 @@ func askLLM(ctx context.Context, contexts []string, question string) string { 134 | }, 135 | } 136 | res, err := openAIClient.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ 137 | - Model: llmModel, 138 | + Model: openai.GPT3Dot5Turbo, 139 | Messages: messages, 140 | }) 141 | if err != nil { 142 | diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go 143 | index 55b3076..044a246 100644 144 | --- a/examples/rag-wikipedia-ollama/main.go 145 | +++ b/examples/rag-wikipedia-ollama/main.go 146 | @@ -12,19 +12,11 @@ import ( 147 | "github.com/philippgille/chromem-go" 148 | ) 149 | 150 | -const ( 151 | - question = "When did the Monarch Company exist?" 152 | - // We use a local LLM running in Ollama for the embedding: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5 153 | - embeddingModel = "nomic-embed-text" 154 | -) 155 | +const question = "When did the Monarch Company exist?" 156 | 157 | func main() { 158 | ctx := context.Background() 159 | 160 | - // Warm up Ollama, in case the model isn't loaded yet 161 | - log.Println("Warming up Ollama...") 162 | - _ = askLLM(ctx, nil, "Hello!") 163 | - 164 | // First we ask an LLM a fairly specific question that it likely won't know 165 | // the answer to. 166 | log.Println("Question: " + question) 167 | @@ -48,7 +40,7 @@ func main() { 168 | // variable to be set. 169 | // For this example we choose to use a locally running embedding model though. 170 | // It requires Ollama to serve its API at "http://localhost:11434/api". 
171 | - collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel)) 172 | + collection, err := db.GetOrCreateCollection("Wikipedia", nil, nil) 173 | if err != nil { 174 | panic(err) 175 | } 176 | @@ -82,7 +74,7 @@ func main() { 177 | Content: article.Text, 178 | }) 179 | } 180 | - log.Println("Adding documents to chromem-go, including creating their embeddings via Ollama API...") 181 | + log.Println("Adding documents to chromem-go, including creating their embeddings via OpenAI API...") 182 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU()) 183 | if err != nil { 184 | panic(err) 185 | ``` 186 | 187 |
188 | -------------------------------------------------------------------------------- /examples/rag-wikipedia-ollama/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/philippgille/chromem-go/examples/rag-wikipedia-ollama 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/philippgille/chromem-go v0.0.0 7 | github.com/sashabaranov/go-openai v1.17.9 8 | ) 9 | 10 | replace github.com/philippgille/chromem-go => ./../.. 11 | -------------------------------------------------------------------------------- /examples/rag-wikipedia-ollama/go.sum: -------------------------------------------------------------------------------- 1 | github.com/sashabaranov/go-openai v1.17.9 h1:QEoBiGKWW68W79YIfXWEFZ7l5cEgZBV4/Ow3uy+5hNY= 2 | github.com/sashabaranov/go-openai v1.17.9/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= 3 | -------------------------------------------------------------------------------- /examples/rag-wikipedia-ollama/llm.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "net/http" 6 | "strings" 7 | "text/template" 8 | 9 | "github.com/sashabaranov/go-openai" 10 | ) 11 | 12 | const ( 13 | // We use a local LLM running in Ollama to ask a question: https://github.com/ollama/ollama 14 | ollamaBaseURL = "http://localhost:11434/v1" 15 | // We use Google's Gemma (2B), a very small model that doesn't need many resources 16 | // and is fast, but doesn't have much knowledge: https://huggingface.co/google/gemma-2b 17 | // We found Gemma 2B to be superior to TinyLlama (1.1B), Stable LM 2 (1.6B) 18 | // and Phi-2 (2.7B) for the retrieval augmented generation (RAG) use case. 19 | llmModel = "gemma:2b" 20 | ) 21 | 22 | // There are many different ways to provide the context to the LLM. 23 | // You can pass each context as user message, or the list as one user message, 24 | // or pass it in the system prompt. The system prompt itself also has a big impact 25 | // on how well the LLM handles the context, especially for LLMs with < 7B parameters. 26 | // The prompt engineering is up to you, it's out of scope for the vector database. 27 | var systemPromptTpl = template.Must(template.New("system_prompt").Parse(` 28 | You are a helpful assistant with access to a knowledge base, tasked with answering questions about the world and its history, people, places and other things. 29 | 30 | Answer the question in a very concise manner. Use an unbiased and journalistic tone. Do not repeat text. Don't make anything up. If you are not sure about something, just say that you don't know. 31 | {{- /* Stop here if no context is provided. The rest below is for handling contexts. */ -}} 32 | {{- if . -}} 33 | Answer the question solely based on the provided search results from the knowledge base. If the search results from the knowledge base are not relevant to the question at hand, just say that you don't know. Don't make anything up. 34 | 35 | Anything between the following 'context' XML blocks is retrieved from the knowledge base, not part of the conversation with the user. The bullet points are ordered by relevance, so the first one is the most relevant. 36 | 37 | <context> 38 | {{- if . -}} 39 | {{- range $context := .}} 40 | - {{.}}{{end}} 41 | {{- end}} 42 | </context> 43 | {{- end -}} 44 | 45 | Don't mention the knowledge base, context or search results in your answer.
46 | `)) 47 | 48 | func askLLM(ctx context.Context, contexts []string, question string) string { 49 | // We can use the OpenAI client because Ollama is compatible with OpenAI's API. 50 | openAIClient := openai.NewClientWithConfig(openai.ClientConfig{ 51 | BaseURL: ollamaBaseURL, 52 | HTTPClient: http.DefaultClient, 53 | }) 54 | sb := &strings.Builder{} 55 | err := systemPromptTpl.Execute(sb, contexts) 56 | if err != nil { 57 | panic(err) 58 | } 59 | messages := []openai.ChatCompletionMessage{ 60 | { 61 | Role: openai.ChatMessageRoleSystem, 62 | Content: sb.String(), 63 | }, { 64 | Role: openai.ChatMessageRoleUser, 65 | Content: "Question: " + question, 66 | }, 67 | } 68 | res, err := openAIClient.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ 69 | Model: llmModel, 70 | Messages: messages, 71 | }) 72 | if err != nil { 73 | panic(err) 74 | } 75 | reply := res.Choices[0].Message.Content 76 | reply = strings.TrimSpace(reply) 77 | 78 | return reply 79 | } 80 | -------------------------------------------------------------------------------- /examples/rag-wikipedia-ollama/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "io" 7 | "log" 8 | "os" 9 | "runtime" 10 | "strconv" 11 | "strings" 12 | "time" 13 | 14 | "github.com/philippgille/chromem-go" 15 | ) 16 | 17 | const ( 18 | question = "When did the Monarch Company exist?" 19 | // We use a local LLM running in Ollama for the embedding: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5 20 | embeddingModel = "nomic-embed-text" 21 | ) 22 | 23 | func main() { 24 | ctx := context.Background() 25 | 26 | // Warm up Ollama, in case the model isn't loaded yet 27 | log.Println("Warming up Ollama...") 28 | _ = askLLM(ctx, nil, "Hello!") 29 | 30 | // First we ask an LLM a fairly specific question that it likely won't know 31 | // the answer to. 32 | log.Println("Question: " + question) 33 | log.Println("Asking LLM...") 34 | reply := askLLM(ctx, nil, question) 35 | log.Printf("Initial reply from the LLM: \"" + reply + "\"\n") 36 | 37 | // Now we use our vector database for retrieval augmented generation (RAG), 38 | // which means we provide the LLM with relevant knowledge. 39 | 40 | // Set up chromem-go with persistence, so that when the program restarts, the 41 | // DB's data is still available. 42 | log.Println("Setting up chromem-go...") 43 | db, err := chromem.NewPersistentDB("./db", false) 44 | if err != nil { 45 | panic(err) 46 | } 47 | // Create collection if it wasn't loaded from persistent storage yet. 48 | // You can pass nil as embedding function to use the default (OpenAI text-embedding-3-small), 49 | // which is very good and cheap. It would require the OPENAI_API_KEY environment 50 | // variable to be set. 51 | // For this example we choose to use a locally running embedding model though. 52 | // It requires Ollama to serve its API at "http://localhost:11434/api". 53 | collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel, "")) 54 | if err != nil { 55 | panic(err) 56 | } 57 | // Add docs to the collection, if the collection was just created (and not 58 | // loaded from persistent storage). 59 | var docs []chromem.Document 60 | if collection.Count() == 0 { 61 | // Here we use a DBpedia sample, where each line contains the lead section/introduction 62 | // to some Wikipedia article and its category. 
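// A line looks roughly like this (illustrative, shortened; only the "text" and
// "category" fields are used below):
//
//	{"text": "Malleable Iron Range Company was a company that existed from 1896 to 1985 ...", "category": "Company"}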
63 | f, err := os.Open("dbpedia_sample.jsonl") 64 | if err != nil { 65 | panic(err) 66 | } 67 | defer f.Close() 68 | d := json.NewDecoder(f) 69 | log.Println("Reading JSON lines...") 70 | for i := 1; ; i++ { 71 | var article struct { 72 | Text string `json:"text"` 73 | Category string `json:"category"` 74 | } 75 | err := d.Decode(&article) 76 | if err == io.EOF { 77 | break // reached end of file 78 | } else if err != nil { 79 | panic(err) 80 | } 81 | 82 | // The embeddings model we use in this example ("nomic-embed-text") 83 | // fares better with a prefix to differentiate between document and query. 84 | // We'll have to cut it off later when we retrieve the documents. 85 | // An alternative is to create the embedding with `chromem.NewDocument()`, 86 | // and then change back the content before adding it to the collection 87 | // with `collection.AddDocument()`. 88 | content := "search_document: " + article.Text 89 | 90 | docs = append(docs, chromem.Document{ 91 | ID: strconv.Itoa(i), 92 | Metadata: map[string]string{"category": article.Category}, 93 | Content: content, 94 | }) 95 | } 96 | log.Println("Adding documents to chromem-go, including creating their embeddings via Ollama API...") 97 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU()) 98 | if err != nil { 99 | panic(err) 100 | } 101 | } else { 102 | log.Println("Not reading JSON lines because collection was loaded from persistent storage.") 103 | } 104 | 105 | // Search for documents that are semantically similar to the original question. 106 | // We ask for the two most similar documents, but you can use more or less depending 107 | // on your needs and the supported context size of the LLM you use. 108 | // You can limit the search by filtering on content or metadata (like the article's 109 | // category), but we don't do that in this example. 110 | start := time.Now() 111 | log.Println("Querying chromem-go...") 112 | // "nomic-embed-text" specific prefix (not required with OpenAI's or other models) 113 | query := "search_query: " + question 114 | docRes, err := collection.Query(ctx, query, 2, nil, nil) 115 | if err != nil { 116 | panic(err) 117 | } 118 | log.Println("Search (incl query embedding) took", time.Since(start)) 119 | // Here you could filter out any documents whose similarity is below a certain threshold. 120 | // if docRes[...].Similarity < 0.5 { ... 121 | 122 | // Print the retrieved documents and their similarity to the question. 123 | for i, res := range docRes { 124 | // Cut off the prefix we added before adding the document (see comment above). 125 | // This is specific to the "nomic-embed-text" model. 126 | content := strings.TrimPrefix(res.Content, "search_document: ") 127 | log.Printf("Document %d (similarity: %f): \"%s\"\n", i+1, res.Similarity, content) 128 | } 129 | 130 | // Now we can ask the LLM again, augmenting the question with the knowledge we retrieved. 131 | // In this example we just use both retrieved documents as context. 132 | contexts := []string{docRes[0].Content, docRes[1].Content} 133 | log.Println("Asking LLM with augmented question...") 134 | reply = askLLM(ctx, contexts, question) 135 | log.Printf("Reply after augmenting the question with knowledge: \"" + reply + "\"\n") 136 | 137 | /* Output (can differ slightly on each run): 138 | 2024/03/02 20:02:30 Warming up Ollama... 139 | 2024/03/02 20:02:33 Question: When did the Monarch Company exist? 140 | 2024/03/02 20:02:33 Asking LLM...
141 | 2024/03/02 20:02:34 Initial reply from the LLM: "I cannot provide information on the Monarch Company, as I am unable to access real-time or comprehensive knowledge sources." 142 | 2024/03/02 20:02:34 Setting up chromem-go... 143 | 2024/03/02 20:02:34 Reading JSON lines... 144 | 2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API... 145 | 2024/03/02 20:03:11 Querying chromem-go... 146 | 2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms 147 | 2024/03/02 20:03:11 Document 1 (similarity: 0.655357): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch." 148 | 2024/03/02 20:03:11 Document 2 (similarity: 0.504042): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design." 149 | 2024/03/02 20:03:11 Asking LLM with augmented question... 150 | 2024/03/02 20:03:32 Reply after augmenting the question with knowledge: "The Monarch Company existed from 1896 to 1985." 151 | */ 152 | } 153 | -------------------------------------------------------------------------------- /examples/s3-export-import/README.md: -------------------------------------------------------------------------------- 1 | # S3 Export/Import 2 | 3 | This example shows how to export the DB to and import it from any S3-compatible blob storage service. 4 | 5 | - The example uses [MinIO](https://github.com/minio/minio), but any S3-compatible storage works. 6 | - The example uses [gocloud.dev](https://github.com/google/go-cloud) Go "Cloud Development Kit" from Google for interfacing with any S3-compatible storage, and because it provides methods for creating writers and readers that make it easy to use with `chromem-go`. 7 | 8 | ## How to run 9 | 10 | 1. Prepare the S3-compatible storage 11 | 1. `docker run -d --rm --name minio -p 127.0.0.1:9000:9000 -p 127.0.0.1:9001:9001 quay.io/minio/minio:RELEASE.2024-05-01T01-11-10Z server /data --console-address ":9001"` 12 | 2. Open the MinIO Console in your browser: 13 | 3. Log in with user `minioadmin` and password `minioadmin` 14 | 4. Use the web UI to create a bucket named `mybucket` 15 | 2. Set the OpenAI API key in your env as `OPENAI_API_KEY` 16 | 3. `go run .` 17 | 18 | You can also check and see the exported DB as `chromem.gob.gz`. 19 | 20 | To stop the MinIO server run `docker stop minio`. 21 | 22 | ## Output 23 | 24 | ```text 25 | 2024/05/04 19:24:07 Successfully exported DB to S3 storage. 26 | 2024/05/04 19:24:07 Imported collection with 1 documents 27 | 2024/05/04 19:24:07 Successfully imported DB from S3 storage. 28 | ``` 29 | -------------------------------------------------------------------------------- /examples/s3-export-import/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/philippgille/chromem-go/examples/s3-export-import 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/philippgille/chromem-go v0.0.0 7 | gocloud.dev v0.37.0 8 | ) 9 | 10 | replace github.com/philippgille/chromem-go => ./../.. 
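// The require block below lists transitive dependencies, mostly pulled in by gocloud.dev's S3 blob driver.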
11 | 12 | require ( 13 | github.com/aws/aws-sdk-go v1.50.36 // indirect 14 | github.com/aws/aws-sdk-go-v2 v1.25.3 // indirect 15 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 // indirect 16 | github.com/aws/aws-sdk-go-v2/config v1.27.7 // indirect 17 | github.com/aws/aws-sdk-go-v2/credentials v1.17.7 // indirect 18 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 // indirect 19 | github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.9 // indirect 20 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 // indirect 21 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 // indirect 22 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect 23 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.3 // indirect 24 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 // indirect 25 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.5 // indirect 26 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.5 // indirect 27 | github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.3 // indirect 28 | github.com/aws/aws-sdk-go-v2/service/s3 v1.51.4 // indirect 29 | github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 // indirect 30 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 // indirect 31 | github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 // indirect 32 | github.com/aws/smithy-go v1.20.1 // indirect 33 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 34 | github.com/golang/protobuf v1.5.4 // indirect 35 | github.com/google/wire v0.6.0 // indirect 36 | github.com/googleapis/gax-go/v2 v2.12.2 // indirect 37 | github.com/jmespath/go-jmespath v0.4.0 // indirect 38 | go.opencensus.io v0.24.0 // indirect 39 | golang.org/x/net v0.22.0 // indirect 40 | golang.org/x/sys v0.18.0 // indirect 41 | golang.org/x/text v0.14.0 // indirect 42 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect 43 | google.golang.org/api v0.169.0 // indirect 44 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7 // indirect 45 | google.golang.org/grpc v1.62.1 // indirect 46 | google.golang.org/protobuf v1.33.0 // indirect 47 | ) 48 | -------------------------------------------------------------------------------- /examples/s3-export-import/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log" 6 | "os" 7 | 8 | "github.com/philippgille/chromem-go" 9 | "gocloud.dev/blob" 10 | _ "gocloud.dev/blob/s3blob" 11 | ) 12 | 13 | func main() { 14 | ctx := context.Background() 15 | 16 | // As S3-style storage we use a local MinIO instance in this example. It has 17 | // default credentials which we set in environment variables in order to be 18 | // read when calling `blob.OpenBucket`, because that call implies using the 19 | // AWS SDK default "shared config" loader. 20 | // That loader checks the environment variables, but can also fall back to a 21 | // file in `~/.aws/config` from `aws sso login` or similar. 22 | // See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html#specifying-credentials 23 | // for details about credential loading. 24 | err := os.Setenv("AWS_ACCESS_KEY_ID", "minioadmin") 25 | if err != nil { 26 | panic(err) 27 | } 28 | err = os.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin") 29 | if err != nil { 30 | panic(err) 31 | } 32 | // A region configuration is also required. 
Alternatively it can be passed in 33 | // the connection string with "&region=us-west-1" for example. 34 | err = os.Setenv("AWS_DEFAULT_REGION", "us-west-1") 35 | if err != nil { 36 | panic(err) 37 | } 38 | 39 | // Export DB 40 | err = exportDB(ctx) 41 | if err != nil { 42 | panic(err) 43 | } 44 | log.Println("Successfully exported DB to S3 storage.") 45 | 46 | // Import DB 47 | err = importDB(ctx) 48 | if err != nil { 49 | panic(err) 50 | } 51 | log.Println("Successfully imported DB from S3 storage.") 52 | } 53 | 54 | func exportDB(ctx context.Context) error { 55 | // Create and fill DB 56 | db := chromem.NewDB() 57 | c, err := db.CreateCollection("knowledge-base", nil, nil) 58 | if err != nil { 59 | return err 60 | } 61 | err = c.AddDocument(ctx, chromem.Document{ 62 | ID: "1", 63 | Content: "The sky is blue because of Rayleigh scattering.", 64 | }) 65 | if err != nil { 66 | return err 67 | } 68 | 69 | // Open S3 bucket. We're using a local MinIO instance here, but it can be any 70 | // S3-compatible storage. We're also using the gocloud.dev/blob package instead 71 | // of the AWS SDK for Go directly, because it provides a unified Writer/Reader 72 | // API for different cloud storage providers. 73 | bucket, err := blob.OpenBucket(ctx, "s3://mybucket?"+ 74 | "endpoint=localhost:9000&"+ 75 | "disableSSL=true&"+ 76 | "s3ForcePathStyle=true") 77 | if err != nil { 78 | return err 79 | } 80 | 81 | // Create writer to an S3 object 82 | w, err := bucket.NewWriter(ctx, "chromem.gob.gz", nil) 83 | if err != nil { 84 | return err 85 | } 86 | // Instead of deferring w.Close() here, we close it at the end of the function 87 | // to handle its errors, as the close is important for the actual write to happen 88 | // and can lead to errors such as "The specified bucket does not exist" etc. 89 | // Another option is to use a named return value and defer a function that 90 | // overwrites the error with the close error or uses [errors.Join] or similar. 91 | 92 | // Persist the DB to the S3 object 93 | err = db.ExportToWriter(w, true, "") 94 | if err != nil { 95 | return err 96 | } 97 | 98 | return w.Close() 99 | } 100 | 101 | func importDB(ctx context.Context) error { 102 | // Open S3 bucket. We're using a local MinIO instance here, but it can be any 103 | // S3-compatible storage. We're also using the gocloud.dev/blob package instead 104 | // of the AWS SDK for Go directly, because it provides a unified Writer/Reader 105 | // API for different cloud storage providers.
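// As in exportDB, the query parameters point the S3 driver at MinIO: "endpoint" is the
// local instance, and "disableSSL" plus "s3ForcePathStyle" are needed because MinIO here
// serves plain HTTP on a single host/port instead of per-bucket subdomains.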
106 | bucket, err := blob.OpenBucket(ctx, "s3://mybucket?"+ 107 | "endpoint=localhost:9000&"+ 108 | "disableSSL=true&"+ 109 | "s3ForcePathStyle=true") 110 | if err != nil { 111 | return err 112 | } 113 | 114 | // Open reader to the S3 object 115 | r, err := bucket.NewReader(ctx, "chromem.gob.gz", nil) 116 | if err != nil { 117 | return err 118 | } 119 | defer r.Close() 120 | 121 | // Create empty DB 122 | db := chromem.NewDB() 123 | 124 | // Import the DB from the S3 object 125 | err = db.ImportFromReader(r, "") 126 | if err != nil { 127 | return err 128 | } 129 | 130 | c := db.GetCollection("knowledge-base", nil) 131 | log.Printf("Imported collection with %d documents\n", c.Count()) 132 | 133 | return nil 134 | } 135 | -------------------------------------------------------------------------------- /examples/semantic-search-arxiv-openai/.gitignore: -------------------------------------------------------------------------------- 1 | /db 2 | -------------------------------------------------------------------------------- /examples/semantic-search-arxiv-openai/README.md: -------------------------------------------------------------------------------- 1 | # Semantic search arXiv OpenAI 2 | 3 | This example shows a semantic search application, using `chromem-go` as vector database for finding semantically relevant search results. We load and search across ~5,000 arXiv papers in the "Computer Science - Computation and Language" category, which is the relevant one for Natural Language Processing (NLP) related papers. 4 | 5 | This is not a retrieval augmented generation (RAG) app, because after *retrieving* the semantically relevant results, we don't *augment* any prompt to an LLM. No LLM generates the final output. 6 | 7 | ## How to run 8 | 9 | 1. Prepare the dataset 10 | 1. Download `arxiv-metadata-oai-snapshot.json` from 11 | 2. Filter by "Computer Science - Computation and Language" category (see [taxonomy](https://arxiv.org/category_taxonomy)), filter by updates from 2023 12 | 1. Ensure you have [ripgrep](https://github.com/BurntSushi/ripgrep) installed, or adapt the following commands to use grep 13 | 2. Run `rg '"categories":"cs.CL"' ~/Downloads/arxiv-metadata-oai-snapshot.json | rg '"update_date":"2023' > /tmp/arxiv_cs-cl_2023.jsonl` (adapt input file path if necessary) 14 | 3. Check the data 15 | 1. `wc -l /tmp/arxiv_cs-cl_2023.jsonl` should show ~5,000 lines 16 | 2. `du -h /tmp/arxiv_cs-cl_2023.jsonl` should show ~8.8 MB 17 | 2. Set the OpenAI API key in your env as `OPENAI_API_KEY` 18 | 3. Run the example: `go run .` 19 | 20 | ## Output 21 | 22 | The output can differ slightly on each run, but it's along the lines of: 23 | 24 | ```log 25 | 2024/03/10 18:23:55 Setting up chromem-go... 26 | 2024/03/10 18:23:55 Reading JSON lines... 27 | 2024/03/10 18:23:55 Read and parsed 5006 documents. 28 | 2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API... 29 | 2024/03/10 18:28:12 Querying chromem-go... 30 | 2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms 31 | 2024/03/10 18:28:12 Search results: 32 | 1) Similarity 0.488895: 33 | URL: https://arxiv.org/abs/2209.15469 34 | Submitter: Christian Buck 35 | Title: Zero-Shot Retrieval with Search Agents and Hybrid Environments 36 | Abstract: Learning to search is the task of building artificial agents that learn to autonomously use a search... 37 | 2) Similarity 0.480713: 38 | URL: https://arxiv.org/abs/2305.11516 39 | Submitter: Ryo Nagata Dr. 
40 | Title: Contextualized Word Vector-based Methods for Discovering Semantic Differences with No Training nor Word Alignment 41 | Abstract: In this paper, we propose methods for discovering semantic differences in words appearing in two cor... 42 | 3) Similarity 0.476079: 43 | URL: https://arxiv.org/abs/2310.14025 44 | Submitter: Maria Lymperaiou 45 | Title: Large Language Models and Multimodal Retrieval for Visual Word Sense Disambiguation 46 | Abstract: Visual Word Sense Disambiguation (VWSD) is a novel challenging task with the goal of retrieving an i... 47 | 4) Similarity 0.474883: 48 | URL: https://arxiv.org/abs/2302.14785 49 | Submitter: Teven Le Scao 50 | Title: Joint Representations of Text and Knowledge Graphs for Retrieval and Evaluation 51 | Abstract: A key feature of neural models is that they can produce semantic vector representations of objects (... 52 | 5) Similarity 0.470326: 53 | URL: https://arxiv.org/abs/2309.02403 54 | Submitter: Dallas Card 55 | Title: Substitution-based Semantic Change Detection using Contextual Embeddings 56 | Abstract: Measuring semantic change has thus far remained a task where methods using contextual embeddings hav... 57 | 6) Similarity 0.466851: 58 | URL: https://arxiv.org/abs/2309.08187 59 | Submitter: Vu Tran 60 | Title: Encoded Summarization: Summarizing Documents into Continuous Vector Space for Legal Case Retrieval 61 | Abstract: We present our method for tackling a legal case retrieval task by introducing our method of encoding... 62 | 7) Similarity 0.461783: 63 | URL: https://arxiv.org/abs/2307.16638 64 | Submitter: Maiia Bocharova Bocharova 65 | Title: VacancySBERT: the approach for representation of titles and skills for semantic similarity search in the recruitment domain 66 | Abstract: The paper focuses on deep learning semantic search algorithms applied in the HR domain. The aim of t... 67 | 8) Similarity 0.460481: 68 | URL: https://arxiv.org/abs/2106.07400 69 | Submitter: Clara Meister 70 | Title: Determinantal Beam Search 71 | Abstract: Beam search is a go-to strategy for decoding neural sequence models. The algorithm can naturally be ... 72 | 9) Similarity 0.460001: 73 | URL: https://arxiv.org/abs/2305.04049 74 | Submitter: Yuxia Wu 75 | Title: Actively Discovering New Slots for Task-oriented Conversation 76 | Abstract: Existing task-oriented conversational search systems heavily rely on domain ontologies with pre-defi... 77 | 10) Similarity 0.458321: 78 | URL: https://arxiv.org/abs/2305.08654 79 | Submitter: Taichi Aida 80 | Title: Unsupervised Semantic Variation Prediction using the Distribution of Sibling Embeddings 81 | Abstract: Languages are dynamic entities, where the meanings associated with words constantly change with time... 82 | ``` 83 | 84 | The majority of the time here is spent during the embeddings creation, where we are limited by the performance of the OpenAI API. 85 | -------------------------------------------------------------------------------- /examples/semantic-search-arxiv-openai/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/philippgille/chromem-go/examples/semantic-search-arxiv-openai 2 | 3 | go 1.21 4 | 5 | require github.com/philippgille/chromem-go v0.0.0 6 | 7 | replace github.com/philippgille/chromem-go => ./../.. 
8 | -------------------------------------------------------------------------------- /examples/semantic-search-arxiv-openai/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log" 9 | "os" 10 | "runtime" 11 | "strings" 12 | "time" 13 | 14 | "github.com/philippgille/chromem-go" 15 | ) 16 | 17 | const searchTerm = "semantic search with vector databases" 18 | 19 | func main() { 20 | ctx := context.Background() 21 | 22 | // Set up chromem-go with persistence, so that when the program restarts, the 23 | // DB's data is still available. 24 | log.Println("Setting up chromem-go...") 25 | db, err := chromem.NewPersistentDB("./db", false) 26 | if err != nil { 27 | panic(err) 28 | } 29 | // Create collection if it wasn't loaded from persistent storage yet. 30 | // We pass nil as embedding function to use the default (OpenAI text-embedding-3-small), 31 | // which is very good and cheap. It requires the OPENAI_API_KEY environment 32 | // variable to be set. See the other examples for how to use Ollama. 33 | collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil) 34 | if err != nil { 35 | panic(err) 36 | } 37 | // Add docs to the collection, if the collection was just created (and not 38 | // loaded from persistent storage). 39 | var docs []chromem.Document 40 | if collection.Count() == 0 { 41 | // Here we use an arXiv metadata sample, where each line contains the metadata 42 | // of a paper, including its submitter, title and abstract. 43 | f, err := os.Open("/tmp/arxiv_cs-cl_2023.jsonl") 44 | if err != nil { 45 | panic(err) 46 | } 47 | defer f.Close() 48 | d := json.NewDecoder(f) 49 | log.Println("Reading JSON lines...") 50 | i := 0 51 | for { 52 | var paper struct { 53 | ID string `json:"id"` 54 | Submitter string `json:"submitter"` 55 | Title string `json:"title"` 56 | Abstract string `json:"abstract"` 57 | } 58 | err := d.Decode(&paper) 59 | if err == io.EOF { 60 | break // reached end of file 61 | } else if err != nil { 62 | panic(err) 63 | } 64 | 65 | title := strings.ReplaceAll(paper.Title, "\n", " ") 66 | title = strings.ReplaceAll(title, "  ", " ") // collapse double spaces left after replacing the newlines 67 | content := strings.TrimSpace(paper.Abstract) 68 | docs = append(docs, chromem.Document{ 69 | ID: paper.ID, 70 | Metadata: map[string]string{"submitter": paper.Submitter, "title": title}, 71 | Content: content, 72 | }) 73 | i++ 74 | } 75 | log.Println("Read and parsed", i, "documents.") 76 | log.Println("Adding documents to chromem-go, including creating their embeddings via OpenAI API...") 77 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU()) 78 | if err != nil { 79 | panic(err) 80 | } 81 | } else { 82 | log.Println("Not reading JSON lines because collection was loaded from persistent storage.") 83 | } 84 | 85 | // Search for documents that are semantically similar to the search term. 86 | // We ask for the 10 most similar documents, but you can use more or less depending 87 | // on your needs. 88 | // You can limit the search by filtering on content or metadata (like the paper's 89 | // submitter), but we don't do that in this example. 90 | log.Println("Querying chromem-go...") 91 | start := time.Now() 92 | docRes, err := collection.Query(ctx, searchTerm, 10, nil, nil) 93 | if err != nil { 94 | panic(err) 95 | } 96 | log.Println("Search (incl query embedding) took", time.Since(start)) 97 | // Here you could filter out any documents whose similarity is below a certain threshold.
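// For example (a sketch; 0.5 is an arbitrary cutoff to tune for your data and embedding model):
//
//	filtered := docRes[:0]
//	for _, res := range docRes {
//		if res.Similarity >= 0.5 {
//			filtered = append(filtered, res)
//		}
//	}
//	docRes = filtered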
98 | // if docRes[...].Similarity < 0.5 { ... 99 | 100 | // Print the retrieved documents and their similarity to the question. 101 | buf := &strings.Builder{} 102 | for i, res := range docRes { 103 | content := strings.ReplaceAll(res.Content, "\n", " ") 104 | content = content[:min(100, len(content))] + "..." 105 | fmt.Fprintf(buf, "\t%d) Similarity %f:\n"+ 106 | "\t\tURL: https://arxiv.org/abs/%s\n"+ 107 | "\t\tSubmitter: %s\n"+ 108 | "\t\tTitle: %s\n"+ 109 | "\t\tAbstract: %s\n", 110 | i+1, res.Similarity, res.ID, res.Metadata["submitter"], res.Metadata["title"], content) 111 | } 112 | log.Printf("Search results:\n%s\n", buf.String()) 113 | 114 | /* Output: 115 | 2024/03/10 18:23:55 Setting up chromem-go... 116 | 2024/03/10 18:23:55 Reading JSON lines... 117 | 2024/03/10 18:23:55 Read and parsed 5006 documents. 118 | 2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API... 119 | 2024/03/10 18:28:12 Querying chromem-go... 120 | 2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms 121 | 2024/03/10 18:28:12 Search results: 122 | 1) Similarity 0.488895: 123 | URL: https://arxiv.org/abs/2209.15469 124 | Submitter: Christian Buck 125 | Title: Zero-Shot Retrieval with Search Agents and Hybrid Environments 126 | Abstract: Learning to search is the task of building artificial agents that learn to autonomously use a search... 127 | 2) Similarity 0.480713: 128 | URL: https://arxiv.org/abs/2305.11516 129 | Submitter: Ryo Nagata Dr. 130 | Title: Contextualized Word Vector-based Methods for Discovering Semantic Differences with No Training nor Word Alignment 131 | Abstract: In this paper, we propose methods for discovering semantic differences in words appearing in two cor... 132 | 3) Similarity 0.476079: 133 | URL: https://arxiv.org/abs/2310.14025 134 | Submitter: Maria Lymperaiou 135 | Title: Large Language Models and Multimodal Retrieval for Visual Word Sense Disambiguation 136 | Abstract: Visual Word Sense Disambiguation (VWSD) is a novel challenging task with the goal of retrieving an i... 137 | 4) Similarity 0.474883: 138 | URL: https://arxiv.org/abs/2302.14785 139 | Submitter: Teven Le Scao 140 | Title: Joint Representations of Text and Knowledge Graphs for Retrieval and Evaluation 141 | Abstract: A key feature of neural models is that they can produce semantic vector representations of objects (... 142 | 5) Similarity 0.470326: 143 | URL: https://arxiv.org/abs/2309.02403 144 | Submitter: Dallas Card 145 | Title: Substitution-based Semantic Change Detection using Contextual Embeddings 146 | Abstract: Measuring semantic change has thus far remained a task where methods using contextual embeddings hav... 147 | 6) Similarity 0.466851: 148 | URL: https://arxiv.org/abs/2309.08187 149 | Submitter: Vu Tran 150 | Title: Encoded Summarization: Summarizing Documents into Continuous Vector Space for Legal Case Retrieval 151 | Abstract: We present our method for tackling a legal case retrieval task by introducing our method of encoding... 152 | 7) Similarity 0.461783: 153 | URL: https://arxiv.org/abs/2307.16638 154 | Submitter: Maiia Bocharova Bocharova 155 | Title: VacancySBERT: the approach for representation of titles and skills for semantic similarity search in the recruitment domain 156 | Abstract: The paper focuses on deep learning semantic search algorithms applied in the HR domain. The aim of t... 
157 | 8) Similarity 0.460481: 158 | URL: https://arxiv.org/abs/2106.07400 159 | Submitter: Clara Meister 160 | Title: Determinantal Beam Search 161 | Abstract: Beam search is a go-to strategy for decoding neural sequence models. The algorithm can naturally be ... 162 | 9) Similarity 0.460001: 163 | URL: https://arxiv.org/abs/2305.04049 164 | Submitter: Yuxia Wu 165 | Title: Actively Discovering New Slots for Task-oriented Conversation 166 | Abstract: Existing task-oriented conversational search systems heavily rely on domain ontologies with pre-defi... 167 | 10) Similarity 0.458321: 168 | URL: https://arxiv.org/abs/2305.08654 169 | Submitter: Taichi Aida 170 | Title: Unsupervised Semantic Variation Prediction using the Distribution of Sibling Embeddings 171 | Abstract: Languages are dynamic entities, where the meanings associated with words constantly change with time... 172 | */ 173 | } 174 | -------------------------------------------------------------------------------- /examples/webassembly/README.md: -------------------------------------------------------------------------------- 1 | # WebAssembly (WASM) 2 | 3 | Go can compile to WebAssembly, which you can then use from JavaScript in a Browser or similar environments (Node, Deno, Bun etc.). You could also target WASI (WebAssembly System Interface) and run it in a standalone runtime (wazero, wasmtime, Wasmer), but in this example we focus on the Browser use case. 4 | 5 | ## How to run 6 | 7 | 1. Compile the `chromem-go` WASM binding to WebAssembly: 8 | 1. `cd /path/to/chromem-go/wasm` 9 | 2. `GOOS=js GOARCH=wasm go build -o ../examples/webassembly/chromem-go.wasm` 10 | 2. Copy Go's wrapper JavaScript: 11 | 1. `cp $(go env GOROOT)/misc/wasm/wasm_exec.js ../examples/webassembly/wasm_exec.js` 12 | 3. Serve the files 13 | 1. `cd ../examples/webassembly` 14 | 2. `go run github.com/philippgille/serve@latest -b localhost -p 8080` or similar 15 | 4. Open `http://localhost:8080` in your browser 16 | -------------------------------------------------------------------------------- /examples/webassembly/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 61 | 62 | 63 | 64 | 65 | 66 |

67 | 68 | 69 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/philippgille/chromem-go 2 | 3 | go 1.21 4 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/philippgille/chromem-go/8311eb0938d6877effd931efb18c1d7460277bac/go.sum -------------------------------------------------------------------------------- /persistence.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "bytes" 5 | "compress/gzip" 6 | "crypto/aes" 7 | "crypto/cipher" 8 | "crypto/rand" 9 | "crypto/sha256" 10 | "encoding/gob" 11 | "encoding/hex" 12 | "errors" 13 | "fmt" 14 | "io" 15 | "io/fs" 16 | "os" 17 | "path/filepath" 18 | ) 19 | 20 | const metadataFileName = "00000000" 21 | 22 | func hash2hex(name string) string { 23 | hash := sha256.Sum256([]byte(name)) 24 | // We encode 4 of the 32 bytes (32 out of 256 bits), so 8 hex characters. 25 | // It's enough to avoid collisions in reasonable amounts of documents per collection 26 | // and being shorter is better for file paths. 27 | return hex.EncodeToString(hash[:4]) 28 | } 29 | 30 | // persistToFile persists an object to a file at the given path. The object is serialized 31 | // as gob, optionally compressed with flate (as gzip) and optionally encrypted with 32 | // AES-GCM. The encryption key must be 32 bytes long. If the file exists, it's 33 | // overwritten, otherwise created. 34 | func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error { 35 | if filePath == "" { 36 | return fmt.Errorf("file path is empty") 37 | } 38 | // AES 256 requires a 32 byte key 39 | if encryptionKey != "" { 40 | if len(encryptionKey) != 32 { 41 | return errors.New("encryption key must be 32 bytes long") 42 | } 43 | } 44 | 45 | // If path doesn't exist, create the parent path. 46 | // If path exists, and it's a directory, return an error. 47 | fi, err := os.Stat(filePath) 48 | if err != nil { 49 | if !errors.Is(err, fs.ErrNotExist) { 50 | return fmt.Errorf("couldn't get info about the path: %w", err) 51 | } else { 52 | // If the file doesn't exist, create the parent path 53 | err := os.MkdirAll(filepath.Dir(filePath), 0o700) 54 | if err != nil { 55 | return fmt.Errorf("couldn't create parent directories to path: %w", err) 56 | } 57 | } 58 | } else if fi.IsDir() { 59 | return fmt.Errorf("path is a directory: %s", filePath) 60 | } 61 | 62 | // Open file for writing 63 | f, err := os.Create(filePath) 64 | if err != nil { 65 | return fmt.Errorf("couldn't create file: %w", err) 66 | } 67 | defer f.Close() 68 | 69 | return persistToWriter(f, obj, compress, encryptionKey) 70 | } 71 | 72 | // persistToWriter persists an object to a writer. The object is serialized 73 | // as gob, optionally compressed with flate (as gzip) and optionally encrypted with 74 | // AES-GCM. The encryption key must be 32 bytes long. 75 | // If the writer has to be closed, it's the caller's responsibility. 
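// A minimal usage sketch (gob-encoded and gzip-compressed, no encryption), e.g. when
// persisting to an in-memory buffer instead of a file:
//
//	var buf bytes.Buffer
//	if err := persistToWriter(&buf, obj, true, ""); err != nil {
//		return err
//	}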
76 | func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error { 77 | // AES 256 requires a 32 byte key 78 | if encryptionKey != "" { 79 | if len(encryptionKey) != 32 { 80 | return errors.New("encryption key must be 32 bytes long") 81 | } 82 | } 83 | 84 | // We want to: 85 | // Encode as gob -> compress with flate -> encrypt with AES-GCM -> write to 86 | // passed writer. 87 | // To reduce memory usage we chain the writers instead of buffering, so we start 88 | // from the end. For AES GCM sealing the stdlib doesn't provide a writer though. 89 | 90 | var chainedWriter io.Writer 91 | if encryptionKey == "" { 92 | chainedWriter = w 93 | } else { 94 | chainedWriter = &bytes.Buffer{} 95 | } 96 | 97 | var gzw *gzip.Writer 98 | var enc *gob.Encoder 99 | if compress { 100 | gzw = gzip.NewWriter(chainedWriter) 101 | enc = gob.NewEncoder(gzw) 102 | } else { 103 | enc = gob.NewEncoder(chainedWriter) 104 | } 105 | 106 | // Start encoding, it will write to the chain of writers. 107 | if err := enc.Encode(obj); err != nil { 108 | return fmt.Errorf("couldn't encode or write object: %w", err) 109 | } 110 | 111 | // If compressing, close the gzip writer. Otherwise, the gzip footer won't be 112 | // written yet. When using encryption (and chainedWriter is a buffer) then 113 | // we'll encrypt an incomplete stream. Without encryption, if we relied on a deferred 114 | // Close() instead of closing here, its error would be silently dropped. 115 | if compress { 116 | err := gzw.Close() 117 | if err != nil { 118 | return fmt.Errorf("couldn't close gzip writer: %w", err) 119 | } 120 | } 121 | 122 | // Without encryption, the chain is done and the writing is finished. 123 | if encryptionKey == "" { 124 | return nil 125 | } 126 | 127 | // Otherwise, encrypt and then write to the unchained target writer. 128 | block, err := aes.NewCipher([]byte(encryptionKey)) 129 | if err != nil { 130 | return fmt.Errorf("couldn't create new AES cipher: %w", err) 131 | } 132 | gcm, err := cipher.NewGCM(block) 133 | if err != nil { 134 | return fmt.Errorf("couldn't create GCM wrapper: %w", err) 135 | } 136 | nonce := make([]byte, gcm.NonceSize()) 137 | if _, err := io.ReadFull(rand.Reader, nonce); err != nil { 138 | return fmt.Errorf("couldn't read random bytes for nonce: %w", err) 139 | } 140 | // chainedWriter is a *bytes.Buffer 141 | buf := chainedWriter.(*bytes.Buffer) 142 | encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil) 143 | _, err = w.Write(encrypted) 144 | if err != nil { 145 | return fmt.Errorf("couldn't write encrypted data: %w", err) 146 | } 147 | 148 | return nil 149 | } 150 | 151 | // readFromFile reads an object from a file at the given path. The object is deserialized 152 | // from gob. `obj` must be a pointer to an instantiated object. The file may 153 | // optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption 154 | // key must be 32 bytes long. 155 | func readFromFile(filePath string, obj any, encryptionKey string) error { 156 | if filePath == "" { 157 | return fmt.Errorf("file path is empty") 158 | } 159 | // AES 256 requires a 32 byte key 160 | if encryptionKey != "" { 161 | if len(encryptionKey) != 32 { 162 | return errors.New("encryption key must be 32 bytes long") 163 | } 164 | } 165 | 166 | r, err := os.Open(filePath) 167 | if err != nil { 168 | return fmt.Errorf("couldn't open file: %w", err) 169 | } 170 | defer r.Close() 171 | 172 | return readFromReader(r, obj, encryptionKey) 173 | } 174 | 175 | // readFromReader reads an object from a Reader.
The object is deserialized from gob. 176 | // `obj` must be a pointer to an instantiated object. The stream may optionally 177 | // be compressed as gzip and/or encrypted with AES-GCM. The encryption key must 178 | // be 32 bytes long. 179 | // If the reader has to be closed, it's the caller's responsibility. 180 | func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error { 181 | // AES 256 requires a 32 byte key 182 | if encryptionKey != "" { 183 | if len(encryptionKey) != 32 { 184 | return errors.New("encryption key must be 32 bytes long") 185 | } 186 | } 187 | 188 | // We want to: 189 | // Read from reader -> decrypt with AES-GCM -> decompress with flate -> decode 190 | // as gob. 191 | // To reduce memory usage we chain the readers instead of buffering, so we start 192 | // from the end. For the decryption there's no reader though. 193 | 194 | // For the chainedReader we don't declare it as ReadSeeker, so we can reassign 195 | // the gzip reader to it. 196 | var chainedReader io.Reader 197 | 198 | // Decrypt if an encryption key is provided 199 | if encryptionKey != "" { 200 | encrypted, err := io.ReadAll(r) 201 | if err != nil { 202 | return fmt.Errorf("couldn't read from reader: %w", err) 203 | } 204 | block, err := aes.NewCipher([]byte(encryptionKey)) 205 | if err != nil { 206 | return fmt.Errorf("couldn't create AES cipher: %w", err) 207 | } 208 | gcm, err := cipher.NewGCM(block) 209 | if err != nil { 210 | return fmt.Errorf("couldn't create GCM wrapper: %w", err) 211 | } 212 | nonceSize := gcm.NonceSize() 213 | if len(encrypted) < nonceSize { 214 | return fmt.Errorf("encrypted data too short") 215 | } 216 | nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:] 217 | data, err := gcm.Open(nil, nonce, ciphertext, nil) 218 | if err != nil { 219 | return fmt.Errorf("couldn't decrypt data: %w", err) 220 | } 221 | 222 | chainedReader = bytes.NewReader(data) 223 | } else { 224 | chainedReader = r 225 | } 226 | 227 | // Determine if the stream is compressed 228 | magicNumber := make([]byte, 2) 229 | _, err := chainedReader.Read(magicNumber) 230 | if err != nil { 231 | return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err) 232 | } 233 | var compressed bool 234 | if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b { 235 | compressed = true 236 | } 237 | 238 | // Reset reader. Both the reader from the param and bytes.Reader support seeking. 239 | if s, ok := chainedReader.(io.Seeker); !ok { 240 | return fmt.Errorf("reader doesn't support seeking") 241 | } else { 242 | _, err := s.Seek(0, 0) 243 | if err != nil { 244 | return fmt.Errorf("couldn't reset reader: %w", err) 245 | } 246 | } 247 | 248 | if compressed { 249 | gzr, err := gzip.NewReader(chainedReader) 250 | if err != nil { 251 | return fmt.Errorf("couldn't create gzip reader: %w", err) 252 | } 253 | defer gzr.Close() 254 | chainedReader = gzr 255 | } 256 | 257 | dec := gob.NewDecoder(chainedReader) 258 | err = dec.Decode(obj) 259 | if err != nil { 260 | return fmt.Errorf("couldn't decode object: %w", err) 261 | } 262 | 263 | return nil 264 | } 265 | 266 | // removeFile removes a file at the given path. If the file doesn't exist, it's a no-op. 
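// This makes removals idempotent: removing a file that was already removed (or never
// written) is not treated as an error.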
267 | func removeFile(filePath string) error { 268 | if filePath == "" { 269 | return fmt.Errorf("file path is empty") 270 | } 271 | 272 | err := os.Remove(filePath) 273 | if err != nil { 274 | if !errors.Is(err, fs.ErrNotExist) { 275 | return fmt.Errorf("couldn't remove file %q: %w", filePath, err) 276 | } 277 | } 278 | 279 | return nil 280 | } 281 | -------------------------------------------------------------------------------- /persistence_test.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "compress/gzip" 5 | "encoding/gob" 6 | "math/rand" 7 | "os" 8 | "path/filepath" 9 | "reflect" 10 | "testing" 11 | ) 12 | 13 | func TestPersistenceWrite(t *testing.T) { 14 | tempDir, err := os.MkdirTemp("", "chromem-go") 15 | if err != nil { 16 | t.Fatal("expected nil, got", err) 17 | } 18 | defer os.RemoveAll(tempDir) 19 | 20 | type s struct { 21 | Foo string 22 | Bar []float32 23 | } 24 | obj := s{ 25 | Foo: "test", 26 | Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` 27 | } 28 | 29 | t.Run("gob", func(t *testing.T) { 30 | tempFilePath := tempDir + ".gob" 31 | if err := persistToFile(tempFilePath, obj, false, ""); err != nil { 32 | t.Fatal("expected nil, got", err) 33 | } 34 | 35 | // Check if the file exists. 36 | _, err = os.Stat(tempFilePath) 37 | if err != nil { 38 | t.Fatal("expected nil, got", err) 39 | } 40 | 41 | // Read file and decode 42 | f, err := os.Open(tempFilePath) 43 | if err != nil { 44 | t.Fatal("expected nil, got", err) 45 | } 46 | defer f.Close() 47 | d := gob.NewDecoder(f) 48 | res := s{} 49 | err = d.Decode(&res) 50 | if err != nil { 51 | t.Fatal("expected nil, got", err) 52 | } 53 | 54 | // Compare 55 | if !reflect.DeepEqual(obj, res) { 56 | t.Fatalf("expected %+v, got %+v", obj, res) 57 | } 58 | }) 59 | 60 | t.Run("gob gzipped", func(t *testing.T) { 61 | tempFilePath := tempDir + ".gob.gz" 62 | if err := persistToFile(tempFilePath, obj, true, ""); err != nil { 63 | t.Fatal("expected nil, got", err) 64 | } 65 | 66 | // Check if the file exists. 
67 | _, err = os.Stat(tempFilePath) 68 | if err != nil { 69 | t.Fatal("expected nil, got", err) 70 | } 71 | 72 | // Read file, decompress and decode 73 | f, err := os.Open(tempFilePath) 74 | if err != nil { 75 | t.Fatal("expected nil, got", err) 76 | } 77 | defer f.Close() 78 | gzr, err := gzip.NewReader(f) 79 | if err != nil { 80 | t.Fatal("expected nil, got", err) 81 | } 82 | d := gob.NewDecoder(gzr) 83 | res := s{} 84 | err = d.Decode(&res) 85 | if err != nil { 86 | t.Fatal("expected nil, got", err) 87 | } 88 | 89 | // Compare 90 | if !reflect.DeepEqual(obj, res) { 91 | t.Fatalf("expected %+v, got %+v", obj, res) 92 | } 93 | }) 94 | } 95 | 96 | func TestPersistenceRead(t *testing.T) { 97 | tempDir, err := os.MkdirTemp("", "chromem-go") 98 | if err != nil { 99 | t.Fatal("expected nil, got", err) 100 | } 101 | defer os.RemoveAll(tempDir) 102 | 103 | type s struct { 104 | Foo string 105 | Bar []float32 106 | } 107 | obj := s{ 108 | Foo: "test", 109 | Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` 110 | } 111 | 112 | t.Run("gob", func(t *testing.T) { 113 | tempFilePath := tempDir + ".gob" 114 | f, err := os.Create(tempFilePath) 115 | if err != nil { 116 | t.Fatal("expected nil, got", err) 117 | } 118 | enc := gob.NewEncoder(f) 119 | err = enc.Encode(obj) 120 | if err != nil { 121 | t.Fatal("expected nil, got", err) 122 | } 123 | err = f.Close() 124 | if err != nil { 125 | t.Fatal("expected nil, got", err) 126 | } 127 | 128 | // Read the file. 129 | var res s 130 | err = readFromFile(tempFilePath, &res, "") 131 | if err != nil { 132 | t.Fatal("expected nil, got", err) 133 | } 134 | 135 | // Compare 136 | if !reflect.DeepEqual(obj, res) { 137 | t.Fatalf("expected %+v, got %+v", obj, res) 138 | } 139 | }) 140 | 141 | t.Run("gob gzipped", func(t *testing.T) { 142 | tempFilePath := tempDir + ".gob.gz" 143 | f, err := os.Create(tempFilePath) 144 | if err != nil { 145 | t.Fatal("expected nil, got", err) 146 | } 147 | gzw := gzip.NewWriter(f) 148 | enc := gob.NewEncoder(gzw) 149 | err = enc.Encode(obj) 150 | if err != nil { 151 | t.Fatal("expected nil, got", err) 152 | } 153 | err = gzw.Close() 154 | if err != nil { 155 | t.Fatal("expected nil, got", err) 156 | } 157 | err = f.Close() 158 | if err != nil { 159 | t.Fatal("expected nil, got", err) 160 | } 161 | 162 | // Read the file. 163 | var res s 164 | err = readFromFile(tempFilePath, &res, "") 165 | if err != nil { 166 | t.Fatal("expected nil, got", err) 167 | } 168 | 169 | // Compare 170 | if !reflect.DeepEqual(obj, res) { 171 | t.Fatalf("expected %+v, got %+v", obj, res) 172 | } 173 | }) 174 | } 175 | 176 | func TestPersistenceEncryption(t *testing.T) { 177 | // Instead of copy pasting encryption/decryption code, we resort to using both 178 | // functions under test, instead of one combined with an independent implementation. 
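// In other words, this is a round-trip test: whatever persistToFile writes with a given
// key and compression setting, readFromFile must restore into an equal value.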
179 | 180 | r := rand.New(rand.NewSource(rand.Int63())) 181 | // randString := randomString(r, 10) 182 | path := filepath.Join(os.TempDir(), "a", "chromem-go") 183 | // defer os.RemoveAll(path) 184 | 185 | type s struct { 186 | Foo string 187 | Bar []float32 188 | } 189 | obj := s{ 190 | Foo: "test", 191 | Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` 192 | } 193 | encryptionKey := randomString(r, 32) 194 | 195 | tt := []struct { 196 | name string 197 | filePath string 198 | compress bool 199 | }{ 200 | { 201 | name: "compress false", 202 | filePath: path + ".gob.enc", 203 | compress: false, 204 | }, 205 | { 206 | name: "compress true", 207 | filePath: path + ".gob.gz.enc", 208 | compress: true, 209 | }, 210 | } 211 | 212 | for _, tc := range tt { 213 | t.Run(tc.name, func(t *testing.T) { 214 | err := persistToFile(tc.filePath, obj, tc.compress, encryptionKey) 215 | if err != nil { 216 | t.Fatal("expected nil, got", err) 217 | } 218 | 219 | // Check if the file exists. 220 | _, err = os.Stat(tc.filePath) 221 | if err != nil { 222 | t.Fatal("expected nil, got", err) 223 | } 224 | 225 | // Read the file. 226 | var res s 227 | err = readFromFile(tc.filePath, &res, encryptionKey) 228 | if err != nil { 229 | t.Fatal("expected nil, got", err) 230 | } 231 | 232 | // Compare 233 | if !reflect.DeepEqual(obj, res) { 234 | t.Fatalf("expected %+v, got %+v", obj, res) 235 | } 236 | }) 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "cmp" 5 | "container/heap" 6 | "context" 7 | "fmt" 8 | "runtime" 9 | "slices" 10 | "strings" 11 | "sync" 12 | ) 13 | 14 | var supportedFilters = []string{"$contains", "$not_contains"} 15 | 16 | type docSim struct { 17 | docID string 18 | similarity float32 19 | } 20 | 21 | // docMaxHeap is a max-heap of docSims, based on similarity. 22 | // See https://pkg.go.dev/container/heap@go1.22#example-package-IntHeap 23 | type docMaxHeap []docSim 24 | 25 | func (h docMaxHeap) Len() int { return len(h) } 26 | func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity } 27 | func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 28 | 29 | func (h *docMaxHeap) Push(x any) { 30 | // Push and Pop use pointer receivers because they modify the slice's length, 31 | // not just its contents. 32 | *h = append(*h, x.(docSim)) 33 | } 34 | 35 | func (h *docMaxHeap) Pop() any { 36 | old := *h 37 | n := len(old) 38 | x := old[n-1] 39 | *h = old[0 : n-1] 40 | return x 41 | } 42 | 43 | // maxDocSims manages a max-heap of docSims with a fixed size, keeping the n highest 44 | // similarities. It's safe for concurrent use, but not the result of values(). 45 | // In our benchmarks this was faster than sorting a slice of docSims at the end. 46 | type maxDocSims struct { 47 | h docMaxHeap 48 | lock sync.RWMutex 49 | size int 50 | } 51 | 52 | // newMaxDocSims creates a new nMaxDocs with a fixed size. 53 | func newMaxDocSims(size int) *maxDocSims { 54 | return &maxDocSims{ 55 | h: make(docMaxHeap, 0, size), 56 | size: size, 57 | } 58 | } 59 | 60 | // add inserts a new docSim into the heap, keeping only the top n similarities. 
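// Note: Less compares with "<", so the heap root h[0] always holds the smallest of the
// currently kept similarities; add() keeps the overall top n by replacing that root
// whenever a more similar document arrives.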
61 | func (d *maxDocSims) add(doc docSim) { 62 | d.lock.Lock() 63 | defer d.lock.Unlock() 64 | if d.h.Len() < d.size { 65 | heap.Push(&d.h, doc) 66 | } else if d.h.Len() > 0 && d.h[0].similarity < doc.similarity { 67 | // Replace the smallest similarity if the new doc's similarity is higher 68 | heap.Pop(&d.h) 69 | heap.Push(&d.h, doc) 70 | } 71 | } 72 | 73 | // values returns the docSims in the heap, sorted by similarity (descending). 74 | // The call itself is safe for concurrent use with add(), but the result isn't. 75 | // Only work with the result after all calls to add() have finished. 76 | func (d *maxDocSims) values() []docSim { 77 | d.lock.RLock() 78 | defer d.lock.RUnlock() 79 | slices.SortFunc(d.h, func(i, j docSim) int { 80 | return cmp.Compare(j.similarity, i.similarity) 81 | }) 82 | return d.h 83 | } 84 | 85 | // filterDocs filters a map of documents by metadata and content. 86 | // It does this concurrently. 87 | func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { 88 | filteredDocs := make([]*Document, 0, len(docs)) 89 | filteredDocsLock := sync.Mutex{} 90 | 91 | // Determine concurrency. Use number of docs or CPUs, whichever is smaller. 92 | numCPUs := runtime.NumCPU() 93 | numDocs := len(docs) 94 | concurrency := numCPUs 95 | if numDocs < numCPUs { 96 | concurrency = numDocs 97 | } 98 | 99 | docChan := make(chan *Document, concurrency*2) 100 | 101 | wg := sync.WaitGroup{} 102 | for i := 0; i < concurrency; i++ { 103 | wg.Add(1) 104 | go func() { 105 | defer wg.Done() 106 | for doc := range docChan { 107 | if documentMatchesFilters(doc, where, whereDocument) { 108 | filteredDocsLock.Lock() 109 | filteredDocs = append(filteredDocs, doc) 110 | filteredDocsLock.Unlock() 111 | } 112 | } 113 | }() 114 | } 115 | 116 | for _, doc := range docs { 117 | docChan <- doc 118 | } 119 | close(docChan) 120 | 121 | wg.Wait() 122 | 123 | // With filteredDocs being initialized as potentially large slice, let's return 124 | // nil instead of the empty slice. 125 | if len(filteredDocs) == 0 { 126 | filteredDocs = nil 127 | } 128 | return filteredDocs 129 | } 130 | 131 | // documentMatchesFilters checks if a document matches the given filters. 132 | // When calling this function, the whereDocument keys must already be validated! 133 | func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool { 134 | // A document's metadata must have *all* the fields in the where clause. 135 | for k, v := range where { 136 | // TODO: Do we want to check for existence of the key? I.e. should 137 | // a where clause with empty string as value match a document's 138 | // metadata that doesn't have the key at all? 139 | if document.Metadata[k] != v { 140 | return false 141 | } 142 | } 143 | 144 | // A document must satisfy *all* filters, until we support the `$or` operator. 145 | for k, v := range whereDocument { 146 | switch k { 147 | case "$contains": 148 | if !strings.Contains(document.Content, v) { 149 | return false 150 | } 151 | case "$not_contains": 152 | if strings.Contains(document.Content, v) { 153 | return false 154 | } 155 | default: 156 | // No handling (error) required because we already validated the 157 | // operators. This simplifies the concurrency logic (no err var 158 | // and lock, no context to cancel). 
159 | } 160 | } 161 | 162 | return true 163 | } 164 | 165 | func getMostSimilarDocs(ctx context.Context, queryVectors, negativeVector []float32, negativeFilterThreshold float32, docs []*Document, n int) ([]docSim, error) { 166 | nMaxDocs := newMaxDocSims(n) 167 | 168 | // Determine concurrency. Use number of docs or CPUs, whichever is smaller. 169 | numCPUs := runtime.NumCPU() 170 | numDocs := len(docs) 171 | concurrency := numCPUs 172 | if numDocs < numCPUs { 173 | concurrency = numDocs 174 | } 175 | 176 | var sharedErr error 177 | sharedErrLock := sync.Mutex{} 178 | ctx, cancel := context.WithCancelCause(ctx) 179 | defer cancel(nil) 180 | setSharedErr := func(err error) { 181 | sharedErrLock.Lock() 182 | defer sharedErrLock.Unlock() 183 | // Another goroutine might have already set the error. 184 | if sharedErr == nil { 185 | sharedErr = err 186 | // Cancel the operation for all other goroutines. 187 | cancel(sharedErr) 188 | } 189 | } 190 | 191 | wg := sync.WaitGroup{} 192 | // Instead of using a channel to pass documents into the goroutines, we just 193 | // split the slice into sub-slices and pass those to the goroutines. 194 | // This turned out to be faster in the query benchmarks. 195 | subSliceSize := len(docs) / concurrency // Can leave remainder, e.g. 10/3 = 3; leaves 1 196 | rem := len(docs) % concurrency 197 | for i := 0; i < concurrency; i++ { 198 | start := i * subSliceSize 199 | end := start + subSliceSize 200 | // Add remainder to last goroutine 201 | if i == concurrency-1 { 202 | end += rem 203 | } 204 | 205 | wg.Add(1) 206 | go func(subSlice []*Document) { 207 | defer wg.Done() 208 | for _, doc := range subSlice { 209 | // Stop work if another goroutine encountered an error. 210 | if ctx.Err() != nil { 211 | return 212 | } 213 | 214 | // As the vectors are normalized, the dot product is the cosine similarity. 
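// (For unit-length vectors |a| = |b| = 1, so cosine similarity a·b/(|a|*|b|) reduces
// to the plain dot product.)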
215 | sim, err := dotProduct(queryVectors, doc.Embedding) 216 | if err != nil { 217 | setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) 218 | return 219 | } 220 | 221 | if negativeFilterThreshold > 0 { 222 | nsim, err := dotProduct(negativeVector, doc.Embedding) 223 | if err != nil { 224 | setSharedErr(fmt.Errorf("couldn't calculate negative similarity for document '%s': %w", doc.ID, err)) 225 | return 226 | } 227 | 228 | if nsim > negativeFilterThreshold { 229 | continue 230 | } 231 | } 232 | 233 | nMaxDocs.add(docSim{docID: doc.ID, similarity: sim}) 234 | } 235 | }(docs[start:end]) 236 | } 237 | 238 | wg.Wait() 239 | 240 | if sharedErr != nil { 241 | return nil, sharedErr 242 | } 243 | 244 | return nMaxDocs.values(), nil 245 | } 246 | -------------------------------------------------------------------------------- /query_test.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "slices" 7 | "testing" 8 | ) 9 | 10 | func TestFilterDocs(t *testing.T) { 11 | docs := map[string]*Document{ 12 | "1": { 13 | ID: "1", 14 | Metadata: map[string]string{ 15 | "language": "en", 16 | }, 17 | Embedding: []float32{0.1, 0.2, 0.3}, 18 | Content: "hello world", 19 | }, 20 | "2": { 21 | ID: "2", 22 | Metadata: map[string]string{ 23 | "language": "de", 24 | }, 25 | Embedding: []float32{0.2, 0.3, 0.4}, 26 | Content: "hallo welt", 27 | }, 28 | } 29 | 30 | tt := []struct { 31 | name string 32 | where map[string]string 33 | whereDocument map[string]string 34 | want []*Document 35 | }{ 36 | { 37 | name: "meta match", 38 | where: map[string]string{"language": "de"}, 39 | whereDocument: nil, 40 | want: []*Document{docs["2"]}, 41 | }, 42 | { 43 | name: "meta no match", 44 | where: map[string]string{"language": "fr"}, 45 | whereDocument: nil, 46 | want: nil, 47 | }, 48 | { 49 | name: "content contains all", 50 | where: nil, 51 | whereDocument: map[string]string{"$contains": "llo"}, 52 | want: []*Document{docs["1"], docs["2"]}, 53 | }, 54 | { 55 | name: "content contains one", 56 | where: nil, 57 | whereDocument: map[string]string{"$contains": "hallo"}, 58 | want: []*Document{docs["2"]}, 59 | }, 60 | { 61 | name: "content contains none", 62 | where: nil, 63 | whereDocument: map[string]string{"$contains": "bonjour"}, 64 | want: nil, 65 | }, 66 | { 67 | name: "content not_contains all", 68 | where: nil, 69 | whereDocument: map[string]string{"$not_contains": "bonjour"}, 70 | want: []*Document{docs["1"], docs["2"]}, 71 | }, 72 | { 73 | name: "content not_contains one", 74 | where: nil, 75 | whereDocument: map[string]string{"$not_contains": "hello"}, 76 | want: []*Document{docs["2"]}, 77 | }, 78 | { 79 | name: "meta and content match", 80 | where: map[string]string{"language": "de"}, 81 | whereDocument: map[string]string{"$contains": "hallo"}, 82 | want: []*Document{docs["2"]}, 83 | }, 84 | { 85 | name: "meta + contains + not_contains", 86 | where: map[string]string{"language": "de"}, 87 | whereDocument: map[string]string{"$contains": "hallo", "$not_contains": "bonjour"}, 88 | want: []*Document{docs["2"]}, 89 | }, 90 | } 91 | 92 | for _, tc := range tt { 93 | t.Run(tc.name, func(t *testing.T) { 94 | got := filterDocs(docs, tc.where, tc.whereDocument) 95 | 96 | if !reflect.DeepEqual(got, tc.want) { 97 | // If len is 2, the order might be different (function under test 98 | // is concurrent and order is not guaranteed). 
99 | if len(got) == 2 && len(tc.want) == 2 { 100 | slices.Reverse(got) 101 | if reflect.DeepEqual(got, tc.want) { 102 | return 103 | } 104 | } 105 | t.Fatalf("got %v; want %v", got, tc.want) 106 | } 107 | }) 108 | } 109 | } 110 | 111 | func TestNegative(t *testing.T) { 112 | ctx := context.Background() 113 | db := NewDB() 114 | 115 | c, err := db.CreateCollection("test", nil, nil) 116 | if err != nil { 117 | panic(err) 118 | } 119 | 120 | if err := c.AddDocuments(ctx, []Document{ 121 | { 122 | ID: "1", 123 | Embedding: testEmbeddings["search_document: Village Builder Game"], 124 | }, 125 | { 126 | ID: "2", 127 | Embedding: testEmbeddings["search_document: Town Craft Idle Game"], 128 | }, 129 | { 130 | ID: "3", 131 | Embedding: testEmbeddings["search_document: Some Idle Game"], 132 | }, 133 | }, 1); err != nil { 134 | t.Fatalf("failed to add documents: %v", err) 135 | } 136 | 137 | t.Run("NEGATIVE_MODE_SUBTRACT", func(t *testing.T) { 138 | res, err := c.QueryWithOptions(ctx, QueryOptions{ 139 | QueryEmbedding: testEmbeddings["search_query: town"], 140 | NResults: c.Count(), 141 | Negative: NegativeQueryOptions{ 142 | Embedding: testEmbeddings["search_query: idle"], 143 | Mode: NEGATIVE_MODE_SUBTRACT, 144 | }, 145 | }) 146 | if err != nil { 147 | panic(err) 148 | } 149 | 150 | for _, r := range res { 151 | t.Logf("%s: %v", r.ID, r.Similarity) 152 | } 153 | 154 | if len(res) != 3 { 155 | t.Fatalf("expected 3 results, got %d", len(res)) 156 | } 157 | 158 | // Village Builder Game 159 | if res[0].ID != "1" { 160 | t.Fatalf("expected document with ID 1, got %s", res[0].ID) 161 | } 162 | // Town Craft Idle Game 163 | if res[1].ID != "2" { 164 | t.Fatalf("expected document with ID 2, got %s", res[1].ID) 165 | } 166 | // Some Idle Game 167 | if res[2].ID != "3" { 168 | t.Fatalf("expected document with ID 3, got %s", res[2].ID) 169 | } 170 | }) 171 | 172 | t.Run("NEGATIVE_MODE_FILTER", func(t *testing.T) { 173 | res, err := c.QueryWithOptions(ctx, QueryOptions{ 174 | QueryEmbedding: testEmbeddings["search_query: town"], 175 | NResults: c.Count(), 176 | Negative: NegativeQueryOptions{ 177 | Embedding: testEmbeddings["search_query: idle"], 178 | Mode: NEGATIVE_MODE_FILTER, 179 | }, 180 | }) 181 | if err != nil { 182 | panic(err) 183 | } 184 | 185 | for _, r := range res { 186 | t.Logf("%s: %v", r.ID, r.Similarity) 187 | } 188 | 189 | if len(res) != 1 { 190 | t.Fatalf("expected 1 result, got %d", len(res)) 191 | } 192 | 193 | // Village Builder Game 194 | if res[0].ID != "1" { 195 | t.Fatalf("expected document with ID 1, got %s", res[0].ID) 196 | } 197 | }) 198 | } 199 | -------------------------------------------------------------------------------- /vector.go: -------------------------------------------------------------------------------- 1 | package chromem 2 | 3 | import ( 4 | "errors" 5 | "math" 6 | ) 7 | 8 | const isNormalizedPrecisionTolerance = 1e-6 9 | 10 | // dotProduct calculates the dot product between two vectors. 11 | // It's the same as cosine similarity for normalized vectors. 12 | // The resulting value represents the similarity, so a higher value means the 13 | // vectors are more similar. 
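// For example, two identical normalized vectors have a dot product of 1, orthogonal
// ones 0, and opposite ones -1.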
14 | func dotProduct(a, b []float32) (float32, error) { 15 | // The vectors must have the same length 16 | if len(a) != len(b) { 17 | return 0, errors.New("vectors must have the same length") 18 | } 19 | 20 | var dotProduct float32 21 | for i := range a { 22 | dotProduct += a[i] * b[i] 23 | } 24 | 25 | return dotProduct, nil 26 | } 27 | 28 | func normalizeVector(v []float32) []float32 { 29 | var norm float32 30 | for _, val := range v { 31 | norm += val * val 32 | } 33 | norm = float32(math.Sqrt(float64(norm))) 34 | 35 | res := make([]float32, len(v)) 36 | for i, val := range v { 37 | res[i] = val / norm 38 | } 39 | 40 | return res 41 | } 42 | 43 | // subtractVector subtracts vector b from vector a and returns the result as a new vector. 44 | func subtractVector(a, b []float32) []float32 { 45 | res := make([]float32, len(a)) 46 | 47 | for i := range a { 48 | res[i] = a[i] - b[i] 49 | } 50 | 51 | return res 52 | } 53 | 54 | // isNormalized checks if the vector is normalized. 55 | func isNormalized(v []float32) bool { 56 | var sqSum float64 57 | for _, val := range v { 58 | sqSum += float64(val) * float64(val) 59 | } 60 | magnitude := math.Sqrt(sqSum) 61 | return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance 62 | } 63 | -------------------------------------------------------------------------------- /wasm/main.go: -------------------------------------------------------------------------------- 1 | //go:build js 2 | 3 | package main 4 | 5 | import ( 6 | "context" 7 | "errors" 8 | "syscall/js" 9 | 10 | "github.com/philippgille/chromem-go" 11 | ) 12 | 13 | var c *chromem.Collection 14 | 15 | func main() { 16 | js.Global().Set("initDB", js.FuncOf(initDB)) 17 | js.Global().Set("addDocument", js.FuncOf(addDocument)) 18 | js.Global().Set("query", js.FuncOf(query)) 19 | 20 | select {} // prevent main from exiting 21 | } 22 | 23 | // Exported function to initialize the database and collection. 24 | // Takes an OpenAI API key as argument. 25 | func initDB(this js.Value, args []js.Value) interface{} { 26 | if len(args) != 1 { 27 | return "expected 1 argument with the OpenAI API key" 28 | } 29 | 30 | openAIAPIKey := args[0].String() 31 | embeddingFunc := chromem.NewEmbeddingFuncOpenAI(openAIAPIKey, chromem.EmbeddingModelOpenAI3Small) 32 | 33 | db := chromem.NewDB() 34 | var err error 35 | c, err = db.CreateCollection("chromem", nil, embeddingFunc) 36 | if err != nil { 37 | return err.Error() 38 | } 39 | 40 | return nil 41 | } 42 | 43 | // Exported function to add a document to the collection. 44 | // Takes the document ID and content as arguments.
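// From JavaScript it returns a Promise, so after a successful initDB call a usage
// example would be: await addDocument("1", "The sky is blue because of Rayleigh scattering.")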
45 | func addDocument(this js.Value, args []js.Value) interface{} { 46 | ctx := context.Background() 47 | 48 | var id string 49 | var content string 50 | var err error 51 | if len(args) != 2 { 52 | err = errors.New("expected 2 arguments with the document ID and content") 53 | } else { 54 | id = args[0].String() 55 | content = args[1].String() 56 | } 57 | 58 | handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} { 59 | resolve := args[0] 60 | reject := args[1] 61 | go func() { 62 | if err != nil { 63 | handleErr(err, reject) 64 | return 65 | } 66 | 67 | err = c.AddDocument(ctx, chromem.Document{ 68 | ID: id, 69 | Content: content, 70 | }) 71 | if err != nil { 72 | handleErr(err, reject) 73 | return 74 | } 75 | resolve.Invoke() 76 | }() 77 | return nil 78 | }) 79 | 80 | promiseConstructor := js.Global().Get("Promise") 81 | return promiseConstructor.New(handler) 82 | } 83 | 84 | // Exported function to query the collection. 85 | // Takes the query string as argument and currently returns only the single most similar document. 86 | func query(this js.Value, args []js.Value) interface{} { 87 | ctx := context.Background() 88 | 89 | var q string 90 | var err error 91 | if len(args) != 1 { 92 | err = errors.New("expected 1 argument with the query string") 93 | } else { 94 | q = args[0].String() 95 | } 96 | 97 | handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} { 98 | resolve := args[0] 99 | reject := args[1] 100 | go func() { 101 | if err != nil { 102 | handleErr(err, reject) 103 | return 104 | } 105 | 106 | res, err := c.Query(ctx, q, 1, nil, nil) 107 | if err != nil { 108 | handleErr(err, reject) 109 | return 110 | } 111 | 112 | // Convert response to JS values 113 | // TODO: Return more than one result 114 | o := js.Global().Get("Object").New() 115 | o.Set("ID", res[0].ID) 116 | o.Set("Similarity", res[0].Similarity) 117 | o.Set("Content", res[0].Content) 118 | 119 | resolve.Invoke(o) 120 | }() 121 | return nil 122 | }) 123 | 124 | promiseConstructor := js.Global().Get("Promise") 125 | return promiseConstructor.New(handler) 126 | } 127 | 128 | func handleErr(err error, reject js.Value) { 129 | errorConstructor := js.Global().Get("Error") 130 | errorObject := errorConstructor.New(err.Error()) 131 | reject.Invoke(errorObject) 132 | } 133 | --------------------------------------------------------------------------------