├── .github
│   └── workflows
│       └── go.yml
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── collection.go
├── collection_test.go
├── db.go
├── db_test.go
├── document.go
├── document_test.go
├── embed_cohere.go
├── embed_compat.go
├── embed_ollama.go
├── embed_ollama_test.go
├── embed_openai.go
├── embed_openai_test.go
├── embed_vertex.go
├── examples
│   ├── README.md
│   ├── minimal
│   │   ├── README.md
│   │   ├── go.mod
│   │   └── main.go
│   ├── rag-wikipedia-ollama
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── dbpedia_sample.jsonl
│   │   ├── go.mod
│   │   ├── go.sum
│   │   ├── llm.go
│   │   └── main.go
│   ├── s3-export-import
│   │   ├── README.md
│   │   ├── go.mod
│   │   ├── go.sum
│   │   └── main.go
│   ├── semantic-search-arxiv-openai
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── go.mod
│   │   └── main.go
│   └── webassembly
│       ├── README.md
│       └── index.html
├── fixtures_test.go
├── go.mod
├── go.sum
├── persistence.go
├── persistence_test.go
├── query.go
├── query_test.go
├── vector.go
└── wasm
    └── main.go
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build a golang project
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
3 |
4 | name: Go
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 |
11 | jobs:
12 |
13 | lint:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | # We make use of the `slices` feature, only available in 1.21 and newer
18 | go-version: [ '1.21', '1.22', '1.23' ]
19 |
20 | steps:
21 | - uses: actions/checkout@v4
22 |
23 | - name: Set up Go
24 | uses: actions/setup-go@v5
25 | with:
26 | go-version: ${{ matrix.go-version }}
27 |
28 | - name: Lint
29 | uses: golangci/golangci-lint-action@v6
30 | with:
31 | version: v1.60.3
32 | args: --verbose
33 |
34 | build:
35 | runs-on: ubuntu-latest
36 | strategy:
37 | matrix:
38 | # We make use of the `slices` feature, only available in 1.21 and newer
39 | go-version: [ '1.21', '1.22', '1.23' ]
40 |
41 | steps:
42 | - uses: actions/checkout@v4
43 |
44 | - name: Set up Go
45 | uses: actions/setup-go@v5
46 | with:
47 | go-version: ${{ matrix.go-version }}
48 |
49 | - name: Build
50 | run: go build -v ./...
51 |
52 | - name: Test
53 | run: go test -v -race ./...
54 |
55 | examples:
56 | runs-on: ubuntu-latest
57 | strategy:
58 | matrix:
59 | # We make use of the `slices` feature, only available in 1.21 and newer
60 | go-version: [ '1.21', '1.22', '1.23' ]
61 |
62 | steps:
63 | - uses: actions/checkout@v4
64 |
65 | - name: Set up Go
66 | uses: actions/setup-go@v5
67 | with:
68 | go-version: ${{ matrix.go-version }}
69 |
70 | - name: Build minimal example
71 | run: |
72 | cd examples/minimal
73 | go build -v ./...
74 |
75 | - name: Build RAG Wikipedia Ollama
76 | run: |
77 | cd examples/rag-wikipedia-ollama
78 | go build -v ./...
79 |
80 | - name: Semantic search arXiv OpenAI
81 | run: |
82 | cd examples/semantic-search-arxiv-openai
83 | go build -v ./...
84 |
85 | - name: S3 export/import
86 | run: |
87 | cd examples/s3-export-import
88 | go build -v ./...
89 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # If you prefer the allow list template instead of the deny list, see community template:
2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
3 | #
4 | # Binaries for programs and plugins
5 | *.exe
6 | *.exe~
7 | *.dll
8 | *.so
9 | *.dylib
10 |
11 | # Test binary, built with `go test -c`
12 | *.test
13 |
14 | # Output of the go coverage tool, specifically when used with LiteIDE
15 | *.out
16 |
17 | # Dependency directories (remove the comment below to include it)
18 | # vendor/
19 |
20 | # Go workspace file
21 | go.work
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Mozilla Public License Version 2.0
2 | ==================================
3 |
4 | 1. Definitions
5 | --------------
6 |
7 | 1.1. "Contributor"
8 | means each individual or legal entity that creates, contributes to
9 | the creation of, or owns Covered Software.
10 |
11 | 1.2. "Contributor Version"
12 | means the combination of the Contributions of others (if any) used
13 | by a Contributor and that particular Contributor's Contribution.
14 |
15 | 1.3. "Contribution"
16 | means Covered Software of a particular Contributor.
17 |
18 | 1.4. "Covered Software"
19 | means Source Code Form to which the initial Contributor has attached
20 | the notice in Exhibit A, the Executable Form of such Source Code
21 | Form, and Modifications of such Source Code Form, in each case
22 | including portions thereof.
23 |
24 | 1.5. "Incompatible With Secondary Licenses"
25 | means
26 |
27 | (a) that the initial Contributor has attached the notice described
28 | in Exhibit B to the Covered Software; or
29 |
30 | (b) that the Covered Software was made available under the terms of
31 | version 1.1 or earlier of the License, but not also under the
32 | terms of a Secondary License.
33 |
34 | 1.6. "Executable Form"
35 | means any form of the work other than Source Code Form.
36 |
37 | 1.7. "Larger Work"
38 | means a work that combines Covered Software with other material, in
39 | a separate file or files, that is not Covered Software.
40 |
41 | 1.8. "License"
42 | means this document.
43 |
44 | 1.9. "Licensable"
45 | means having the right to grant, to the maximum extent possible,
46 | whether at the time of the initial grant or subsequently, any and
47 | all of the rights conveyed by this License.
48 |
49 | 1.10. "Modifications"
50 | means any of the following:
51 |
52 | (a) any file in Source Code Form that results from an addition to,
53 | deletion from, or modification of the contents of Covered
54 | Software; or
55 |
56 | (b) any new file in Source Code Form that contains any Covered
57 | Software.
58 |
59 | 1.11. "Patent Claims" of a Contributor
60 | means any patent claim(s), including without limitation, method,
61 | process, and apparatus claims, in any patent Licensable by such
62 | Contributor that would be infringed, but for the grant of the
63 | License, by the making, using, selling, offering for sale, having
64 | made, import, or transfer of either its Contributions or its
65 | Contributor Version.
66 |
67 | 1.12. "Secondary License"
68 | means either the GNU General Public License, Version 2.0, the GNU
69 | Lesser General Public License, Version 2.1, the GNU Affero General
70 | Public License, Version 3.0, or any later versions of those
71 | licenses.
72 |
73 | 1.13. "Source Code Form"
74 | means the form of the work preferred for making modifications.
75 |
76 | 1.14. "You" (or "Your")
77 | means an individual or a legal entity exercising rights under this
78 | License. For legal entities, "You" includes any entity that
79 | controls, is controlled by, or is under common control with You. For
80 | purposes of this definition, "control" means (a) the power, direct
81 | or indirect, to cause the direction or management of such entity,
82 | whether by contract or otherwise, or (b) ownership of more than
83 | fifty percent (50%) of the outstanding shares or beneficial
84 | ownership of such entity.
85 |
86 | 2. License Grants and Conditions
87 | --------------------------------
88 |
89 | 2.1. Grants
90 |
91 | Each Contributor hereby grants You a world-wide, royalty-free,
92 | non-exclusive license:
93 |
94 | (a) under intellectual property rights (other than patent or trademark)
95 | Licensable by such Contributor to use, reproduce, make available,
96 | modify, display, perform, distribute, and otherwise exploit its
97 | Contributions, either on an unmodified basis, with Modifications, or
98 | as part of a Larger Work; and
99 |
100 | (b) under Patent Claims of such Contributor to make, use, sell, offer
101 | for sale, have made, import, and otherwise transfer either its
102 | Contributions or its Contributor Version.
103 |
104 | 2.2. Effective Date
105 |
106 | The licenses granted in Section 2.1 with respect to any Contribution
107 | become effective for each Contribution on the date the Contributor first
108 | distributes such Contribution.
109 |
110 | 2.3. Limitations on Grant Scope
111 |
112 | The licenses granted in this Section 2 are the only rights granted under
113 | this License. No additional rights or licenses will be implied from the
114 | distribution or licensing of Covered Software under this License.
115 | Notwithstanding Section 2.1(b) above, no patent license is granted by a
116 | Contributor:
117 |
118 | (a) for any code that a Contributor has removed from Covered Software;
119 | or
120 |
121 | (b) for infringements caused by: (i) Your and any other third party's
122 | modifications of Covered Software, or (ii) the combination of its
123 | Contributions with other software (except as part of its Contributor
124 | Version); or
125 |
126 | (c) under Patent Claims infringed by Covered Software in the absence of
127 | its Contributions.
128 |
129 | This License does not grant any rights in the trademarks, service marks,
130 | or logos of any Contributor (except as may be necessary to comply with
131 | the notice requirements in Section 3.4).
132 |
133 | 2.4. Subsequent Licenses
134 |
135 | No Contributor makes additional grants as a result of Your choice to
136 | distribute the Covered Software under a subsequent version of this
137 | License (see Section 10.2) or under the terms of a Secondary License (if
138 | permitted under the terms of Section 3.3).
139 |
140 | 2.5. Representation
141 |
142 | Each Contributor represents that the Contributor believes its
143 | Contributions are its original creation(s) or it has sufficient rights
144 | to grant the rights to its Contributions conveyed by this License.
145 |
146 | 2.6. Fair Use
147 |
148 | This License is not intended to limit any rights You have under
149 | applicable copyright doctrines of fair use, fair dealing, or other
150 | equivalents.
151 |
152 | 2.7. Conditions
153 |
154 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
155 | in Section 2.1.
156 |
157 | 3. Responsibilities
158 | -------------------
159 |
160 | 3.1. Distribution of Source Form
161 |
162 | All distribution of Covered Software in Source Code Form, including any
163 | Modifications that You create or to which You contribute, must be under
164 | the terms of this License. You must inform recipients that the Source
165 | Code Form of the Covered Software is governed by the terms of this
166 | License, and how they can obtain a copy of this License. You may not
167 | attempt to alter or restrict the recipients' rights in the Source Code
168 | Form.
169 |
170 | 3.2. Distribution of Executable Form
171 |
172 | If You distribute Covered Software in Executable Form then:
173 |
174 | (a) such Covered Software must also be made available in Source Code
175 | Form, as described in Section 3.1, and You must inform recipients of
176 | the Executable Form how they can obtain a copy of such Source Code
177 | Form by reasonable means in a timely manner, at a charge no more
178 | than the cost of distribution to the recipient; and
179 |
180 | (b) You may distribute such Executable Form under the terms of this
181 | License, or sublicense it under different terms, provided that the
182 | license for the Executable Form does not attempt to limit or alter
183 | the recipients' rights in the Source Code Form under this License.
184 |
185 | 3.3. Distribution of a Larger Work
186 |
187 | You may create and distribute a Larger Work under terms of Your choice,
188 | provided that You also comply with the requirements of this License for
189 | the Covered Software. If the Larger Work is a combination of Covered
190 | Software with a work governed by one or more Secondary Licenses, and the
191 | Covered Software is not Incompatible With Secondary Licenses, this
192 | License permits You to additionally distribute such Covered Software
193 | under the terms of such Secondary License(s), so that the recipient of
194 | the Larger Work may, at their option, further distribute the Covered
195 | Software under the terms of either this License or such Secondary
196 | License(s).
197 |
198 | 3.4. Notices
199 |
200 | You may not remove or alter the substance of any license notices
201 | (including copyright notices, patent notices, disclaimers of warranty,
202 | or limitations of liability) contained within the Source Code Form of
203 | the Covered Software, except that You may alter any license notices to
204 | the extent required to remedy known factual inaccuracies.
205 |
206 | 3.5. Application of Additional Terms
207 |
208 | You may choose to offer, and to charge a fee for, warranty, support,
209 | indemnity or liability obligations to one or more recipients of Covered
210 | Software. However, You may do so only on Your own behalf, and not on
211 | behalf of any Contributor. You must make it absolutely clear that any
212 | such warranty, support, indemnity, or liability obligation is offered by
213 | You alone, and You hereby agree to indemnify every Contributor for any
214 | liability incurred by such Contributor as a result of warranty, support,
215 | indemnity or liability terms You offer. You may include additional
216 | disclaimers of warranty and limitations of liability specific to any
217 | jurisdiction.
218 |
219 | 4. Inability to Comply Due to Statute or Regulation
220 | ---------------------------------------------------
221 |
222 | If it is impossible for You to comply with any of the terms of this
223 | License with respect to some or all of the Covered Software due to
224 | statute, judicial order, or regulation then You must: (a) comply with
225 | the terms of this License to the maximum extent possible; and (b)
226 | describe the limitations and the code they affect. Such description must
227 | be placed in a text file included with all distributions of the Covered
228 | Software under this License. Except to the extent prohibited by statute
229 | or regulation, such description must be sufficiently detailed for a
230 | recipient of ordinary skill to be able to understand it.
231 |
232 | 5. Termination
233 | --------------
234 |
235 | 5.1. The rights granted under this License will terminate automatically
236 | if You fail to comply with any of its terms. However, if You become
237 | compliant, then the rights granted under this License from a particular
238 | Contributor are reinstated (a) provisionally, unless and until such
239 | Contributor explicitly and finally terminates Your grants, and (b) on an
240 | ongoing basis, if such Contributor fails to notify You of the
241 | non-compliance by some reasonable means prior to 60 days after You have
242 | come back into compliance. Moreover, Your grants from a particular
243 | Contributor are reinstated on an ongoing basis if such Contributor
244 | notifies You of the non-compliance by some reasonable means, this is the
245 | first time You have received notice of non-compliance with this License
246 | from such Contributor, and You become compliant prior to 30 days after
247 | Your receipt of the notice.
248 |
249 | 5.2. If You initiate litigation against any entity by asserting a patent
250 | infringement claim (excluding declaratory judgment actions,
251 | counter-claims, and cross-claims) alleging that a Contributor Version
252 | directly or indirectly infringes any patent, then the rights granted to
253 | You by any and all Contributors for the Covered Software under Section
254 | 2.1 of this License shall terminate.
255 |
256 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all
257 | end user license agreements (excluding distributors and resellers) which
258 | have been validly granted by You or Your distributors under this License
259 | prior to termination shall survive termination.
260 |
261 | ************************************************************************
262 | * *
263 | * 6. Disclaimer of Warranty *
264 | * ------------------------- *
265 | * *
266 | * Covered Software is provided under this License on an "as is" *
267 | * basis, without warranty of any kind, either expressed, implied, or *
268 | * statutory, including, without limitation, warranties that the *
269 | * Covered Software is free of defects, merchantable, fit for a *
270 | * particular purpose or non-infringing. The entire risk as to the *
271 | * quality and performance of the Covered Software is with You. *
272 | * Should any Covered Software prove defective in any respect, You *
273 | * (not any Contributor) assume the cost of any necessary servicing, *
274 | * repair, or correction. This disclaimer of warranty constitutes an *
275 | * essential part of this License. No use of any Covered Software is *
276 | * authorized under this License except under this disclaimer. *
277 | * *
278 | ************************************************************************
279 |
280 | ************************************************************************
281 | * *
282 | * 7. Limitation of Liability *
283 | * -------------------------- *
284 | * *
285 | * Under no circumstances and under no legal theory, whether tort *
286 | * (including negligence), contract, or otherwise, shall any *
287 | * Contributor, or anyone who distributes Covered Software as *
288 | * permitted above, be liable to You for any direct, indirect, *
289 | * special, incidental, or consequential damages of any character *
290 | * including, without limitation, damages for lost profits, loss of *
291 | * goodwill, work stoppage, computer failure or malfunction, or any *
292 | * and all other commercial damages or losses, even if such party *
293 | * shall have been informed of the possibility of such damages. This *
294 | * limitation of liability shall not apply to liability for death or *
295 | * personal injury resulting from such party's negligence to the *
296 | * extent applicable law prohibits such limitation. Some *
297 | * jurisdictions do not allow the exclusion or limitation of *
298 | * incidental or consequential damages, so this exclusion and *
299 | * limitation may not apply to You. *
300 | * *
301 | ************************************************************************
302 |
303 | 8. Litigation
304 | -------------
305 |
306 | Any litigation relating to this License may be brought only in the
307 | courts of a jurisdiction where the defendant maintains its principal
308 | place of business and such litigation shall be governed by laws of that
309 | jurisdiction, without reference to its conflict-of-law provisions.
310 | Nothing in this Section shall prevent a party's ability to bring
311 | cross-claims or counter-claims.
312 |
313 | 9. Miscellaneous
314 | ----------------
315 |
316 | This License represents the complete agreement concerning the subject
317 | matter hereof. If any provision of this License is held to be
318 | unenforceable, such provision shall be reformed only to the extent
319 | necessary to make it enforceable. Any law or regulation which provides
320 | that the language of a contract shall be construed against the drafter
321 | shall not be used to construe this License against a Contributor.
322 |
323 | 10. Versions of the License
324 | ---------------------------
325 |
326 | 10.1. New Versions
327 |
328 | Mozilla Foundation is the license steward. Except as provided in Section
329 | 10.3, no one other than the license steward has the right to modify or
330 | publish new versions of this License. Each version will be given a
331 | distinguishing version number.
332 |
333 | 10.2. Effect of New Versions
334 |
335 | You may distribute the Covered Software under the terms of the version
336 | of the License under which You originally received the Covered Software,
337 | or under the terms of any subsequent version published by the license
338 | steward.
339 |
340 | 10.3. Modified Versions
341 |
342 | If you create software not governed by this License, and you want to
343 | create a new license for such software, you may create and use a
344 | modified version of this License if you rename the license and remove
345 | any references to the name of the license steward (except to note that
346 | such modified license differs from this License).
347 |
348 | 10.4. Distributing Source Code Form that is Incompatible With Secondary
349 | Licenses
350 |
351 | If You choose to distribute Source Code Form that is Incompatible With
352 | Secondary Licenses under the terms of this version of the License, the
353 | notice described in Exhibit B of this License must be attached.
354 |
355 | Exhibit A - Source Code Form License Notice
356 | -------------------------------------------
357 |
358 | This Source Code Form is subject to the terms of the Mozilla Public
359 | License, v. 2.0. If a copy of the MPL was not distributed with this
360 | file, You can obtain one at https://mozilla.org/MPL/2.0/.
361 |
362 | If it is not possible or desirable to put the notice in a particular
363 | file, then You may include the notice in a location (such as a LICENSE
364 | file in a relevant directory) where a recipient would be likely to look
365 | for such a notice.
366 |
367 | You may add additional accurate notices of copyright ownership.
368 |
369 | Exhibit B - "Incompatible With Secondary Licenses" Notice
370 | ---------------------------------------------------------
371 |
372 | This Source Code Form is "Incompatible With Secondary Licenses", as
373 | defined by the Mozilla Public License, v. 2.0.
374 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # chromem-go
2 |
3 | [](https://pkg.go.dev/github.com/philippgille/chromem-go)
4 | [](https://github.com/philippgille/chromem-go/actions/workflows/go.yml)
5 | [](https://goreportcard.com/report/github.com/philippgille/chromem-go)
6 | [](https://github.com/philippgille/chromem-go/releases)
7 |
8 | Embeddable vector database for Go with Chroma-like interface and zero third-party dependencies. In-memory with optional persistence.
9 |
10 | Because `chromem-go` is embeddable, it enables you to add retrieval augmented generation (RAG) and similar embeddings-based features into your Go app *without having to run a separate database*, similar to using SQLite instead of PostgreSQL/MySQL/etc.
11 |
12 | It's *not* a library to connect to Chroma and also not a reimplementation of it in Go. It's a database on its own.
13 |
14 | The focus is not scale (millions of documents) or number of features, but simplicity and performance for the most common use cases. On a mid-range 2020 Intel laptop CPU you can query 1,000 documents in 0.3 ms and 100,000 documents in 40 ms, with very few and small memory allocations. See [Benchmarks](#benchmarks) for details.
15 |
16 | > ⚠️ The project is in beta, under heavy construction, and may introduce breaking changes in releases before `v1.0.0`. All changes are documented in the [`CHANGELOG`](./CHANGELOG.md).
17 |
18 | ## Contents
19 |
20 | 1. [Use cases](#use-cases)
21 | 2. [Interface](#interface)
22 | 3. [Features + Roadmap](#features)
23 | 4. [Installation](#installation)
24 | 5. [Usage](#usage)
25 | 6. [Benchmarks](#benchmarks)
26 | 7. [Development](#development)
27 | 8. [Motivation](#motivation)
28 | 9. [Related projects](#related-projects)
29 |
30 | ## Use cases
31 |
32 | With a vector database you can do various things:
33 |
34 | - Retrieval augmented generation (RAG), question answering (Q&A)
35 | - Text and code search
36 | - Recommendation systems
37 | - Classification
38 | - Clustering
39 |
40 | Let's look at the RAG use case in more detail:
41 |
42 | ### RAG
43 |
44 | The knowledge of large language models (LLMs) - even the ones with 30 billion, 70 billion parameters and more - is limited. They don't know anything about what happened after their training ended, they don't know anything about data they were not trained with (like your company's intranet, Jira / bug tracker, wiki or other kinds of knowledge bases), and even the data they *do* know they often can't reproduce *exactly*, but start to *hallucinate* instead.
45 |
46 | Fine-tuning an LLM can help a bit, but it's more meant to improve the LLM's reasoning about specific topics, or to reproduce the style of written text or code. Fine-tuning does *not* add knowledge *1:1* into the model. Details are lost or mixed up. And the knowledge cutoff (about anything that happened after the fine-tuning) isn't solved either.
47 |
48 | => A vector database can act as the up-to-date, precise knowledge for LLMs:
49 |
50 | 1. You store relevant documents that you want the LLM to know in the database.
51 | 2. The database stores the *embeddings* alongside the documents; you can either provide these embeddings yourself or have them created by specific "embedding models" like OpenAI's `text-embedding-3-small`.
52 | - `chromem-go` can do this for you and supports multiple embedding providers and models out-of-the-box.
53 | 3. Later, when you want to talk to the LLM, you first send the question to the vector DB to find *similar*/*related* content. This is called "nearest neighbor search".
54 | 4. In the question to the LLM, you provide this content alongside your question.
55 | 5. The LLM can take this up-to-date precise content into account when answering.
56 |
57 | Check out the [example code](examples) to see it in action!
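
To make steps 3-5 concrete, here is a minimal sketch. It assumes a collection `c` that has already been filled with documents, and `askLLM` is a hypothetical stand-in for whatever LLM client you use (OpenAI, Ollama, ...):

```go
// answerWithRAG sketches steps 3-5 of the RAG flow described above.
func answerWithRAG(ctx context.Context, c *chromem.Collection, question string) (string, error) {
	// Step 3: nearest neighbor search for the 3 documents most similar to the question.
	res, err := c.Query(ctx, question, 3, nil, nil)
	if err != nil {
		return "", err
	}

	// Step 4: provide the retrieved content alongside the question.
	prompt := "Answer the question using only the following context:\n"
	for _, r := range res {
		prompt += "- " + r.Content + "\n"
	}
	prompt += "\nQuestion: " + question

	// Step 5: the LLM answers with this up-to-date, precise content in mind.
	return askLLM(ctx, prompt) // askLLM is hypothetical, not part of chromem-go
}
```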
58 |
59 | ## Interface
60 |
61 | Our original inspiration was the [Chroma](https://www.trychroma.com/) interface, whose core API is the following (taken from their [README](https://github.com/chroma-core/chroma/blob/0.4.21/README.md)):
62 |
63 | Chroma core interface
64 |
65 | ```python
66 | import chromadb
67 | # setup Chroma in-memory, for easy prototyping. Can add persistence easily!
68 | client = chromadb.Client()
69 |
70 | # Create collection. get_collection, get_or_create_collection, delete_collection also available!
71 | collection = client.create_collection("all-my-documents")
72 |
73 | # Add docs to the collection. Can also update and delete. Row-based API coming soon!
74 | collection.add(
75 | documents=["This is document1", "This is document2"], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well
76 | metadatas=[{"source": "notion"}, {"source": "google-docs"}], # filter on these!
77 | ids=["doc1", "doc2"], # unique for each doc
78 | )
79 |
80 | # Query/search 2 most similar results. You can also .get by id
81 | results = collection.query(
82 | query_texts=["This is a query document"],
83 | n_results=2,
84 | # where={"metadata_field": "is_equal_to_this"}, # optional filter
85 | # where_document={"$contains":"search_string"} # optional filter
86 | )
87 | ```
88 |
89 |
90 |
91 | Our Go library exposes the same interface:
92 |
93 | chromem-go equivalent
94 |
95 | ```go
96 | package main
97 |
98 | import "github.com/philippgille/chromem-go"
99 |
100 | func main() {
101 | // Set up chromem-go in-memory, for easy prototyping. Can add persistence easily!
102 | // We call it DB instead of client because there's no client-server separation. The DB is embedded.
103 | db := chromem.NewDB()
104 |
105 | // Create collection. GetCollection, GetOrCreateCollection, DeleteCollection also available!
106 | collection, _ := db.CreateCollection("all-my-documents", nil, nil)
107 |
108 | // Add docs to the collection. Update and delete will be added in the future.
109 | // Can be multi-threaded with AddConcurrently()!
110 | // We're showing the Chroma-like method here, but more Go-idiomatic methods are also available!
111 | _ = collection.Add(ctx,
112 | []string{"doc1", "doc2"}, // unique ID for each doc
113 | nil, // We handle embedding automatically. You can skip that and add your own embeddings as well.
114 | []map[string]string{{"source": "notion"}, {"source": "google-docs"}}, // Filter on these!
115 | []string{"This is document1", "This is document2"},
116 | )
117 |
118 | // Query/search 2 most similar results. You can also get by ID.
119 | results, _ := collection.Query(ctx,
120 | "This is a query document",
121 | 2,
122 | map[string]string{"metadata_field": "is_equal_to_this"}, // optional filter
123 | map[string]string{"$contains": "search_string"}, // optional filter
124 | )
125 | }
126 | ```
127 |
128 |
129 |
130 | Initially `chromem-go` started with just the four core methods, but we added more over time. We intentionally don't want to cover 100% of Chroma's API surface though.
131 | We're providing some alternative methods that are more Go-idiomatic instead.
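
For example, besides the Chroma-like `Add` and `Query` shown above, there are methods such as `GetByID` and `Delete`. A short sketch, assuming `ctx` and a collection `c` as in the snippet above:

```go
// Fetch a single document (returned as a copy) by its ID.
doc, err := c.GetByID(ctx, "doc1")

// Delete documents by ID, or alternatively by metadata/content filters.
err = c.Delete(ctx, nil, nil, "doc2")
```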
132 |
133 | For the full interface see the Godoc: <https://pkg.go.dev/github.com/philippgille/chromem-go>
134 |
135 | ## Features
136 |
137 | - [X] Zero dependencies on third party libraries
138 | - [X] Embeddable (like SQLite, i.e. no client-server model, no separate DB to maintain)
139 | - [X] Multithreaded processing (when adding and querying documents), making use of Go's native concurrency features
140 | - [X] Experimental WebAssembly binding
141 | - Embedding creators:
142 | - Hosted:
143 | - [X] [OpenAI](https://platform.openai.com/docs/guides/embeddings/embedding-models) (default)
144 | - [X] [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings)
145 | - [X] [GCP Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings)
146 | - [X] [Cohere](https://cohere.com/models/embed)
147 | - [X] [Mistral](https://docs.mistral.ai/platform/endpoints/#embedding-models)
148 | - [X] [Jina](https://jina.ai/embeddings)
149 | - [X] [mixedbread.ai](https://www.mixedbread.ai/)
150 | - Local:
151 | - [X] [Ollama](https://github.com/ollama/ollama)
152 | - [X] [LocalAI](https://github.com/mudler/LocalAI)
153 |   - Bring your own (implement [`chromem.EmbeddingFunc`](https://pkg.go.dev/github.com/philippgille/chromem-go#EmbeddingFunc)), as sketched right after this list
154 | - You can also pass existing embeddings when adding documents to a collection, instead of letting `chromem-go` create them
155 | - Similarity search:
156 | - [X] Exhaustive nearest neighbor search using cosine similarity (sometimes also called exact search or brute-force search or FLAT index)
157 | - Filters:
158 | - [X] Document filters: `$contains`, `$not_contains`
159 | - [X] Metadata filters: Exact matches
160 | - Storage:
161 | - [X] In-memory
162 | - [X] Optional immediate persistence (writes one file for each added collection and document, encoded as [gob](https://go.dev/blog/gob), optionally gzip-compressed)
163 | - [X] Backups: Export and import of the entire DB to/from a single file (encoded as [gob](https://go.dev/blog/gob), optionally gzip-compressed and AES-GCM encrypted)
164 | - Includes methods for generic `io.Writer`/`io.Reader` so you can plug S3 buckets and other blob storage, see [examples/s3-export-import](examples/s3-export-import) for example code
165 | - Data types:
166 | - [X] Documents (text)
167 |
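If none of the built-in embedding providers fit, you can bring your own embedding function, as mentioned in the list above. A minimal sketch, where `localEmbed` is a hypothetical helper that calls your own model or service; note that the returned vectors should be normalized, because document embeddings created by the embedding function are stored as-is and the cosine similarity assumes normalized vectors:

```go
// NewEmbeddingFuncCustom returns an EmbeddingFunc backed by your own embedding model.
func NewEmbeddingFuncCustom() chromem.EmbeddingFunc {
	return func(ctx context.Context, text string) ([]float32, error) {
		// Hypothetical: call your own embedding model or service here.
		// The returned vector should be normalized.
		vec, err := localEmbed(ctx, text)
		if err != nil {
			return nil, fmt.Errorf("couldn't create embedding: %w", err)
		}
		return vec, nil
	}
}
```

You can then pass it when creating a collection, e.g. `db.CreateCollection("docs", nil, NewEmbeddingFuncCustom())`.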
168 | ### Roadmap
169 |
170 | - Performance:
171 | - Use SIMD for dot product calculation on supported CPUs (draft PR: [#48](https://github.com/philippgille/chromem-go/pull/48))
172 | - Add [roaring bitmaps](https://github.com/RoaringBitmap/roaring) to speed up full text filtering
173 | - Embedding creators:
174 | - Add an `EmbeddingFunc` that downloads and shells out to [llamafile](https://github.com/Mozilla-Ocho/llamafile)
175 | - Similarity search:
176 | - Approximate nearest neighbor search with index (ANN)
177 | - Hierarchical Navigable Small World (HNSW)
178 | - Inverted file flat (IVFFlat)
179 | - Filters:
180 | - Operators (`$and`, `$or` etc.)
181 | - Storage:
182 | - JSON as second encoding format
183 | - Write-ahead log (WAL) as second file format
184 | - Optional remote storage (S3, PostgreSQL, ...)
185 | - Data types:
186 | - Images
187 | - Videos
188 |
189 | ## Installation
190 |
191 | `go get github.com/philippgille/chromem-go@latest`
192 |
193 | ## Usage
194 |
195 | See the Godoc for a reference: <https://pkg.go.dev/github.com/philippgille/chromem-go>
196 |
197 | For full, working examples that use the vector database for retrieval augmented generation (RAG) and semantic search, with either OpenAI or a locally running embedding model and LLM (via Ollama), see the [example code](examples).
198 |
199 | ### Quickstart
200 |
201 | This is taken from the ["minimal" example](examples/minimal):
202 |
203 | ```go
204 | package main
205 |
206 | import (
207 | "context"
208 | "fmt"
209 | "runtime"
210 |
211 | "github.com/philippgille/chromem-go"
212 | )
213 |
214 | func main() {
215 | ctx := context.Background()
216 |
217 | db := chromem.NewDB()
218 |
219 | // Passing nil as embedding function leads to OpenAI being used and requires
220 | // "OPENAI_API_KEY" env var to be set. Other providers are supported as well.
221 | // For example pass `chromem.NewEmbeddingFuncOllama(...)` to use Ollama.
222 | c, err := db.CreateCollection("knowledge-base", nil, nil)
223 | if err != nil {
224 | panic(err)
225 | }
226 |
227 | err = c.AddDocuments(ctx, []chromem.Document{
228 | {
229 | ID: "1",
230 | Content: "The sky is blue because of Rayleigh scattering.",
231 | },
232 | {
233 | ID: "2",
234 | Content: "Leaves are green because chlorophyll absorbs red and blue light.",
235 | },
236 | }, runtime.NumCPU())
237 | if err != nil {
238 | panic(err)
239 | }
240 |
241 | res, err := c.Query(ctx, "Why is the sky blue?", 1, nil, nil)
242 | if err != nil {
243 | panic(err)
244 | }
245 |
246 | fmt.Printf("ID: %v\nSimilarity: %v\nContent: %v\n", res[0].ID, res[0].Similarity, res[0].Content)
247 | }
248 | ```
249 |
250 | Output:
251 |
252 | ```text
253 | ID: 1
254 | Similarity: 0.6833369
255 | Content: The sky is blue because of Rayleigh scattering.
256 | ```
257 |
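Queries can also be narrowed down with the optional filters mentioned in the feature list: exact-match metadata filters and `$contains`/`$not_contains` document filters. A small sketch, assuming the documents were added with a hypothetical `category` metadata field:

```go
res, err := c.Query(ctx, "Why is the sky blue?", 1,
	map[string]string{"category": "physics"},     // metadata filter (exact match)
	map[string]string{"$contains": "scattering"}, // document content filter
)
```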
258 | ## Benchmarks
259 |
260 | Benchmarked on 2024-03-17 with:
261 |
262 | - Computer: Framework Laptop 13 (first generation, 2021)
263 | - CPU: 11th Gen Intel Core i5-1135G7 (2020)
264 | - Memory: 32 GB
265 | - OS: Fedora Linux 39
266 | - Kernel: 6.7
267 |
268 | ```console
269 | $ go test -benchmem -run=^$ -bench .
270 | goos: linux
271 | goarch: amd64
272 | pkg: github.com/philippgille/chromem-go
273 | cpu: 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
274 | BenchmarkCollection_Query_NoContent_100-8 13164 90276 ns/op 5176 B/op 95 allocs/op
275 | BenchmarkCollection_Query_NoContent_1000-8 2142 520261 ns/op 13558 B/op 141 allocs/op
276 | BenchmarkCollection_Query_NoContent_5000-8 561 2150354 ns/op 47096 B/op 173 allocs/op
277 | BenchmarkCollection_Query_NoContent_25000-8 120 9890177 ns/op 211783 B/op 208 allocs/op
278 | BenchmarkCollection_Query_NoContent_100000-8 30 39574238 ns/op 810370 B/op 232 allocs/op
279 | BenchmarkCollection_Query_100-8 13225 91058 ns/op 5177 B/op 95 allocs/op
280 | BenchmarkCollection_Query_1000-8 2226 519693 ns/op 13552 B/op 140 allocs/op
281 | BenchmarkCollection_Query_5000-8 550 2128121 ns/op 47108 B/op 173 allocs/op
282 | BenchmarkCollection_Query_25000-8 100 10063260 ns/op 211705 B/op 205 allocs/op
283 | BenchmarkCollection_Query_100000-8 30 39404005 ns/op 810295 B/op 229 allocs/op
284 | PASS
285 | ok github.com/philippgille/chromem-go 28.402s
286 | ```
287 |
288 | ## Development
289 |
290 | - Build: `go build ./...`
291 | - Test: `go test -v -race -count 1 ./...`
292 | - Benchmark:
293 | - `go test -benchmem -run=^$ -bench .` (add `> bench.out` or similar to write to a file)
294 | - With profiling: `go test -benchmem -run ^$ -cpuprofile cpu.out -bench .`
295 | - (profiles: `-cpuprofile`, `-memprofile`, `-blockprofile`, `-mutexprofile`)
296 | - Compare benchmarks:
297 | 1. Install `benchstat`: `go install golang.org/x/perf/cmd/benchstat@latest`
298 | 2. Compare two benchmark results: `benchstat before.out after.out`
299 |
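For example, a before/after comparison could look like this (file names are just examples):

```console
$ go test -benchmem -run=^$ -bench . > before.out
$ # ...make your changes...
$ go test -benchmem -run=^$ -bench . > after.out
$ benchstat before.out after.out
```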
300 | ## Motivation
301 |
302 | In December 2023, when I wanted to play around with retrieval augmented generation (RAG) in a Go program, I looked for a vector database that could be embedded in the Go program, just like you would embed SQLite in order to not require any separate DB setup and maintenance. I was surprised when I didn't find any, given the abundance of embedded key-value stores in the Go ecosystem.
303 |
304 | At the time most of the popular vector databases like Pinecone, Qdrant, Milvus, Chroma, Weaviate and others were not embeddable at all or only in Python or JavaScript/TypeScript.
305 |
306 | Then I found [@eliben](https://github.com/eliben)'s [blog post](https://eli.thegreenplace.net/2023/retrieval-augmented-generation-in-go/) and [example code](https://github.com/eliben/code-for-blog/tree/eda87b87dad9ed8bd45d1c8d6395efba3741ed39/2023/go-rag-openai) which showed that with very little Go code you could create a very basic PoC of a vector database.
307 |
308 | That's when I decided to build my own vector database, embeddable in Go, inspired by the ChromaDB interface. ChromaDB stood out for being embeddable (in Python), and by showing its core API in 4 commands on their README and on the landing page of their website.
309 |
310 | ## Related projects
311 |
312 | - Shoutout to [@eliben](https://github.com/eliben) whose [blog post](https://eli.thegreenplace.net/2023/retrieval-augmented-generation-in-go/) and [example code](https://github.com/eliben/code-for-blog/tree/eda87b87dad9ed8bd45d1c8d6395efba3741ed39/2023/go-rag-openai) inspired me to start this project!
313 | - [Chroma](https://github.com/chroma-core/chroma): Looking at Pinecone, Qdrant, Milvus, Weaviate and others, Chroma stood out by showing its core API in 4 commands on their README and on the landing page of their website. It was also putting the most emphasis on its embeddability (in Python).
314 | - The big, full-fledged client-server-based vector databases for maximum scale and performance:
315 | - [Pinecone](https://www.pinecone.io/): Closed source
316 | - [Qdrant](https://github.com/qdrant/qdrant): Written in Rust, not embeddable in Go
317 | - [Milvus](https://github.com/milvus-io/milvus): Written in Go and C++, but not embeddable as of December 2023
318 | - [Weaviate](https://github.com/weaviate/weaviate): Written in Go, but not embeddable in Go as of March 2024 (only in Python and JavaScript/TypeScript and that's experimental)
319 | - Some non-specialized SQL, NoSQL and Key-Value databases added support for storing vectors and (some of them) querying based on similarity:
320 | - [pgvector](https://github.com/pgvector/pgvector) extension for [PostgreSQL](https://www.postgresql.org/): Client-server model
321 | - [Redis](https://github.com/redis/redis) ([1](https://redis.io/docs/interact/search-and-query/query/vector-search/), [2](https://redis.io/docs/interact/search-and-query/advanced-concepts/vectors/)): Client-server model
322 | - [sqlite-vss](https://github.com/asg017/sqlite-vss) extension for [SQLite](https://www.sqlite.org/): Embedded, but the [Go bindings](https://github.com/asg017/sqlite-vss/tree/8fc44301843029a13a474d1f292378485e1fdd62/bindings/go) require CGO. There's a [CGO-free Go library](https://gitlab.com/cznic/sqlite) for SQLite, but then it's without the vector search extension.
323 | - [DuckDB](https://github.com/duckdb/duckdb) has a function to calculate cosine similarity ([1](https://duckdb.org/docs/sql/functions/nested)): Embedded, but the Go bindings use CGO
324 | - [MongoDB](https://github.com/mongodb/mongo)'s cloud platform offers a vector search product ([1](https://www.mongodb.com/products/platform/atlas-vector-search)): Client-server model
325 | - Some libraries for vector similarity search:
326 | - [Faiss](https://github.com/facebookresearch/faiss): Written in C++; 3rd party Go bindings use CGO
327 | - [Annoy](https://github.com/spotify/annoy): Written in C++; Go bindings use CGO ([1](https://github.com/spotify/annoy/blob/2be37c9e015544be2cf60c431f0cccc076151a2d/README_GO.rst))
328 | - [USearch](https://github.com/unum-cloud/usearch): Written in C++; Go bindings use CGO
329 | - Some orchestration libraries, inspired by the Python library [LangChain](https://github.com/langchain-ai/langchain), but with no or only rudimentary embedded vector DB:
330 | - [LangChain Go](https://github.com/tmc/langchaingo)
331 | - [LinGoose](https://github.com/henomis/lingoose)
332 | - [GoLC](https://github.com/hupe1980/golc)
333 |
--------------------------------------------------------------------------------
/collection.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "maps"
8 | "path/filepath"
9 | "slices"
10 | "sync"
11 | )
12 |
13 | // Collection represents a collection of documents.
14 | // It also has a configured embedding function, which is used when adding documents
15 | // that don't have embeddings yet.
16 | type Collection struct {
17 | Name string
18 |
19 | metadata map[string]string
20 | documents map[string]*Document
21 | documentsLock sync.RWMutex
22 | embed EmbeddingFunc
23 |
24 | persistDirectory string
25 | compress bool
26 |
27 | // ⚠️ When adding fields here, consider adding them to the persistence struct
28 | // versions in [DB.Export] and [DB.Import] as well!
29 | }
30 |
31 | // NegativeMode represents the mode to use for the negative text.
32 | // See QueryOptions for more information.
33 | type NegativeMode string
34 |
35 | const (
36 | // NEGATIVE_MODE_FILTER filters out results based on the similarity between the
37 | // negative embedding and the document embeddings.
38 | // NegativeFilterThreshold controls the threshold for filtering. Documents with
39 | // similarity above the threshold will be removed from the results.
40 | NEGATIVE_MODE_FILTER NegativeMode = "filter"
41 |
42 | // NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding.
43 | // This is the default behavior.
44 | NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"
45 |
46 | // The default threshold for the negative filter.
47 | DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5
48 | )
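// For illustration, a query using these negative modes (written from a consuming
// package, with ctx and a filled collection) could look like this sketch:
//
//	res, err := collection.QueryWithOptions(ctx, chromem.QueryOptions{
//		QueryText: "running shoes",
//		NResults:  5,
//		Negative: chromem.NegativeQueryOptions{
//			Mode: chromem.NEGATIVE_MODE_FILTER,
//			Text: "sandals",
//		},
//	})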
49 |
50 | // QueryOptions represents the options for a query.
51 | type QueryOptions struct {
52 | // The text to search for.
53 | QueryText string
54 |
55 | // The embedding of the query to search for. It must be created
56 | // with the same embedding model as the document embeddings in the collection.
57 | // The embedding will be normalized if it's not the case yet.
58 | // If both QueryText and QueryEmbedding are set, QueryEmbedding will be used.
59 | QueryEmbedding []float32
60 |
61 | // The number of results to return.
62 | NResults int
63 |
64 | // Conditional filtering on metadata.
65 | Where map[string]string
66 |
67 | // Conditional filtering on documents.
68 | WhereDocument map[string]string
69 |
70 | // Negative is the negative query options.
71 | // They can be used to exclude certain results from the query.
72 | Negative NegativeQueryOptions
73 | }
74 |
75 | type NegativeQueryOptions struct {
76 | // Mode is the mode to use for the negative text.
77 | Mode NegativeMode
78 |
79 | // Text is the text to exclude from the results.
80 | Text string
81 |
82 | // Embedding is the embedding of the negative text. It must be created
83 | // with the same embedding model as the document embeddings in the collection.
84 | // The embedding will be normalized if it's not the case yet.
85 | // If both Text and Embedding are set, Embedding will be used.
86 | Embedding []float32
87 |
88 | // FilterThreshold is the threshold for the negative filter. Used when Mode is NEGATIVE_MODE_FILTER.
89 | FilterThreshold float32
90 | }
91 |
92 | // We don't export this yet to keep the API surface to the bare minimum.
93 | // Users create collections via [DB.CreateCollection].
94 | func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
95 | // We copy the metadata to avoid data races in case the caller modifies the
96 | // map after creating the collection while we range over it.
97 | m := make(map[string]string, len(metadata))
98 | for k, v := range metadata {
99 | m[k] = v
100 | }
101 |
102 | c := &Collection{
103 | Name: name,
104 |
105 | metadata: m,
106 | documents: make(map[string]*Document),
107 | embed: embed,
108 | }
109 |
110 | // Persistence
111 | if dbDir != "" {
112 | safeName := hash2hex(name)
113 | c.persistDirectory = filepath.Join(dbDir, safeName)
114 | c.compress = compress
115 | return c, c.persistMetadata()
116 | }
117 |
118 | return c, nil
119 | }
120 |
121 | // Add embeddings to the datastore.
122 | //
123 | // - ids: The ids of the embeddings you wish to add
124 | // - embeddings: The embeddings to add. If nil, embeddings will be computed based
125 | // on the contents using the embeddingFunc set for the Collection. Optional.
126 | // - metadatas: The metadata to associate with the embeddings. When querying,
127 | // you can filter on this metadata. Optional.
128 | // - contents: The contents to associate with the embeddings.
129 | //
130 | // This is a Chroma-like method. For a more Go-idiomatic one, see [Collection.AddDocuments].
131 | func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
132 | return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
133 | }
134 |
135 | // AddConcurrently is like Add, but adds embeddings concurrently.
136 | // This is mostly useful when you don't pass any embeddings, so they have to be created.
137 | // Upon error, concurrently running operations are canceled and the error is returned.
138 | //
139 | // This is a Chroma-like method. For a more Go-idiomatic one, see [Collection.AddDocuments].
140 | func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
141 | if len(ids) == 0 {
142 | return errors.New("ids are empty")
143 | }
144 | if len(embeddings) == 0 && len(contents) == 0 {
145 | return errors.New("either embeddings or contents must be filled")
146 | }
147 | if len(embeddings) != 0 {
148 | if len(embeddings) != len(ids) {
149 | return errors.New("ids and embeddings must have the same length")
150 | }
151 | } else {
152 | // Assign empty slice, so we can simply access via index later
153 | embeddings = make([][]float32, len(ids))
154 | }
155 | if len(metadatas) != 0 {
156 | if len(ids) != len(metadatas) {
157 | return errors.New("when metadatas is not empty it must have the same length as ids")
158 | }
159 | } else {
160 | // Assign empty slice, so we can simply access via index later
161 | metadatas = make([]map[string]string, len(ids))
162 | }
163 | if len(contents) != 0 {
164 | if len(contents) != len(ids) {
165 | return errors.New("ids and contents must have the same length")
166 | }
167 | } else {
168 | // Assign empty slice, so we can simply access via index later
169 | contents = make([]string, len(ids))
170 | }
171 | if concurrency < 1 {
172 | return errors.New("concurrency must be at least 1")
173 | }
174 |
175 | // Convert Chroma-style parameters into a slice of documents.
176 | docs := make([]Document, 0, len(ids))
177 | for i, id := range ids {
178 | docs = append(docs, Document{
179 | ID: id,
180 | Metadata: metadatas[i],
181 | Embedding: embeddings[i],
182 | Content: contents[i],
183 | })
184 | }
185 |
186 | return c.AddDocuments(ctx, docs, concurrency)
187 | }
188 |
189 | // AddDocuments adds documents to the collection with the specified concurrency.
190 | // If the documents don't have embeddings, they will be created using the collection's
191 | // embedding function.
192 | // Upon error, concurrently running operations are canceled and the error is returned.
193 | func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
194 | if len(documents) == 0 {
195 | // TODO: Should this be a no-op instead?
196 | return errors.New("documents slice is nil or empty")
197 | }
198 | if concurrency < 1 {
199 | return errors.New("concurrency must be at least 1")
200 | }
201 | // For other validations we rely on AddDocument.
202 |
203 | var sharedErr error
204 | sharedErrLock := sync.Mutex{}
205 | ctx, cancel := context.WithCancelCause(ctx)
206 | defer cancel(nil)
207 | setSharedErr := func(err error) {
208 | sharedErrLock.Lock()
209 | defer sharedErrLock.Unlock()
210 | // Another goroutine might have already set the error.
211 | if sharedErr == nil {
212 | sharedErr = err
213 | // Cancel the operation for all other goroutines.
214 | cancel(sharedErr)
215 | }
216 | }
217 |
218 | var wg sync.WaitGroup
219 | semaphore := make(chan struct{}, concurrency)
220 | for _, doc := range documents {
221 | wg.Add(1)
222 | go func(doc Document) {
223 | defer wg.Done()
224 |
225 | // Don't even start if another goroutine already failed.
226 | if ctx.Err() != nil {
227 | return
228 | }
229 |
230 | // Wait here while $concurrency other goroutines are creating documents.
231 | semaphore <- struct{}{}
232 | defer func() { <-semaphore }()
233 |
234 | err := c.AddDocument(ctx, doc)
235 | if err != nil {
236 | setSharedErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err))
237 | return
238 | }
239 | }(doc)
240 | }
241 |
242 | wg.Wait()
243 |
244 | return sharedErr
245 | }
246 |
247 | // AddDocument adds a document to the collection.
248 | // If the document doesn't have an embedding, it will be created using the collection's
249 | // embedding function.
250 | func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
251 | if doc.ID == "" {
252 | return errors.New("document ID is empty")
253 | }
254 | if len(doc.Embedding) == 0 && doc.Content == "" {
255 | return errors.New("either document embedding or content must be filled")
256 | }
257 |
258 | // We copy the metadata to avoid data races in case the caller modifies the
259 | // map after creating the document while we range over it.
260 | m := make(map[string]string, len(doc.Metadata))
261 | for k, v := range doc.Metadata {
262 | m[k] = v
263 | }
264 |
265 | // Create embedding if they don't exist, otherwise normalize if necessary
266 | if len(doc.Embedding) == 0 {
267 | embedding, err := c.embed(ctx, doc.Content)
268 | if err != nil {
269 | return fmt.Errorf("couldn't create embedding of document: %w", err)
270 | }
271 | doc.Embedding = embedding
272 | } else {
273 | if !isNormalized(doc.Embedding) {
274 | doc.Embedding = normalizeVector(doc.Embedding)
275 | }
276 | }
277 |
278 | c.documentsLock.Lock()
279 | // We don't defer the unlock because we want to do it earlier.
280 | c.documents[doc.ID] = &doc
281 | c.documentsLock.Unlock()
282 |
283 | // Persist the document
284 | if c.persistDirectory != "" {
285 | docPath := c.getDocPath(doc.ID)
286 | err := persistToFile(docPath, doc, c.compress, "")
287 | if err != nil {
288 | return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
289 | }
290 | }
291 |
292 | return nil
293 | }
294 |
295 | // GetByID returns a document by its ID.
296 | // The returned document is a copy of the original document, so it can be safely
297 | // modified without affecting the collection.
298 | func (c *Collection) GetByID(ctx context.Context, id string) (Document, error) {
299 | if id == "" {
300 | return Document{}, errors.New("document ID is empty")
301 | }
302 |
303 | c.documentsLock.RLock()
304 | defer c.documentsLock.RUnlock()
305 |
306 | doc, ok := c.documents[id]
307 | if ok {
308 | // Clone the document
309 | res := *doc
310 | // Above copies the simple fields, but we need to copy the slices and maps
311 | res.Metadata = maps.Clone(doc.Metadata)
312 | res.Embedding = slices.Clone(doc.Embedding)
313 |
314 | return res, nil
315 | }
316 |
317 | return Document{}, fmt.Errorf("document with ID '%v' not found", id)
318 | }
319 |
320 | // Delete removes document(s) from the collection.
321 | //
322 | // - where: Conditional filtering on metadata. Optional.
323 | // - whereDocument: Conditional filtering on documents. Optional.
324 | // - ids: The ids of the documents to delete. If empty, the documents matching the where/whereDocument filters are deleted.
325 | func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {
326 | // must have at least one of where, whereDocument or ids
327 | if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
328 | return fmt.Errorf("must have at least one of where, whereDocument or ids")
329 | }
330 |
331 | if len(c.documents) == 0 {
332 | return nil
333 | }
334 |
335 | for k := range whereDocument {
336 | if !slices.Contains(supportedFilters, k) {
337 | return errors.New("unsupported whereDocument operator")
338 | }
339 | }
340 |
341 | var docIDs []string
342 |
343 | c.documentsLock.Lock()
344 | defer c.documentsLock.Unlock()
345 |
346 | if where != nil || whereDocument != nil {
347 | // metadata + content filters
348 | filteredDocs := filterDocs(c.documents, where, whereDocument)
349 | for _, doc := range filteredDocs {
350 | docIDs = append(docIDs, doc.ID)
351 | }
352 | } else {
353 | docIDs = ids
354 | }
355 |
356 | // No-op if no docs are left
357 | if len(docIDs) == 0 {
358 | return nil
359 | }
360 |
361 | for _, docID := range docIDs {
362 | delete(c.documents, docID)
363 |
364 | // Remove the document from disk
365 | if c.persistDirectory != "" {
366 | docPath := c.getDocPath(docID)
367 | err := removeFile(docPath)
368 | if err != nil {
369 | return fmt.Errorf("couldn't remove document at %q: %w", docPath, err)
370 | }
371 | }
372 | }
373 |
374 | return nil
375 | }
376 |
377 | // Count returns the number of documents in the collection.
378 | func (c *Collection) Count() int {
379 | c.documentsLock.RLock()
380 | defer c.documentsLock.RUnlock()
381 | return len(c.documents)
382 | }
383 |
384 | // Result represents a single result from a query.
385 | type Result struct {
386 | ID string
387 | Metadata map[string]string
388 | Embedding []float32
389 | Content string
390 |
391 | // The cosine similarity between the query and the document.
392 | // The higher the value, the more similar the document is to the query.
393 | // The value is in the range [-1, 1].
394 | Similarity float32
395 | }
396 |
397 | // Query performs an exhaustive nearest neighbor search on the collection.
398 | //
399 | // - queryText: The text to search for. Its embedding will be created using the
400 | // collection's embedding function.
401 | // - nResults: The maximum number of results to return. Must be > 0.
402 | // There can be fewer results if a filter is applied.
403 | // - where: Conditional filtering on metadata. Optional.
404 | // - whereDocument: Conditional filtering on documents. Optional.
405 | func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
406 | if queryText == "" {
407 | return nil, errors.New("queryText is empty")
408 | }
409 |
410 | queryVector, err := c.embed(ctx, queryText)
411 | if err != nil {
412 | return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
413 | }
414 |
415 | return c.QueryEmbedding(ctx, queryVector, nResults, where, whereDocument)
416 | }
417 |
418 | // QueryWithOptions performs an exhaustive nearest neighbor search on the collection.
419 | //
420 | // - options: The options for the query. See [QueryOptions] for more information.
421 | func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) ([]Result, error) {
422 | if options.QueryText == "" && len(options.QueryEmbedding) == 0 {
423 | return nil, errors.New("QueryText and QueryEmbedding options are empty")
424 | }
425 |
426 | var err error
427 | queryVector := options.QueryEmbedding
428 | if len(queryVector) == 0 {
429 | queryVector, err = c.embed(ctx, options.QueryText)
430 | if err != nil {
431 | return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
432 | }
433 | }
434 |
435 | negativeFilterThreshold := options.Negative.FilterThreshold
436 | negativeVector := options.Negative.Embedding
437 | if len(negativeVector) == 0 && options.Negative.Text != "" {
438 | negativeVector, err = c.embed(ctx, options.Negative.Text)
439 | if err != nil {
440 | return nil, fmt.Errorf("couldn't create embedding of negative: %w", err)
441 | }
442 | }
443 |
444 | if len(negativeVector) != 0 {
445 | if !isNormalized(negativeVector) {
446 | negativeVector = normalizeVector(negativeVector)
447 | }
448 |
449 | if options.Negative.Mode == NEGATIVE_MODE_SUBTRACT {
450 | queryVector = subtractVector(queryVector, negativeVector)
451 | queryVector = normalizeVector(queryVector)
452 | } else if options.Negative.Mode == NEGATIVE_MODE_FILTER {
453 | if negativeFilterThreshold == 0 {
454 | negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
455 | }
456 | } else {
457 | return nil, fmt.Errorf("unsupported negative mode: %q", options.Negative.Mode)
458 | }
459 | }
460 |
461 | result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument)
462 | if err != nil {
463 | return nil, err
464 | }
465 |
466 | return result, nil
467 | }
468 |
469 | // QueryEmbedding performs an exhaustive nearest neighbor search on the collection.
470 | //
471 | // - queryEmbedding: The embedding of the query to search for. It must be created
472 | // with the same embedding model as the document embeddings in the collection.
473 | //   The embedding will be normalized if it isn't already.
474 | // - nResults: The maximum number of results to return. Must be > 0.
475 | // There can be fewer results if a filter is applied.
476 | // - where: Conditional filtering on metadata. Optional.
477 | // - whereDocument: Conditional filtering on documents. Optional.
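//
// A minimal usage sketch with a precomputed query embedding (how the embedding
// is obtained is an assumption here; any function compatible with the
// collection's embedding model works):
//
//	queryVec, _ := embeddingFunc(ctx, "Why is the sky blue?")
//	res, err := collection.QueryEmbedding(ctx, queryVec, 3, nil, nil)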
478 | func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
479 | return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument)
480 | }
481 |
482 | // queryEmbedding performs an exhaustive nearest neighbor search on the collection.
483 | func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
484 | if len(queryEmbedding) == 0 {
485 | return nil, errors.New("queryEmbedding is empty")
486 | }
487 | if nResults <= 0 {
488 | return nil, errors.New("nResults must be > 0")
489 | }
490 | c.documentsLock.RLock()
491 | defer c.documentsLock.RUnlock()
492 | if nResults > len(c.documents) {
493 | return nil, errors.New("nResults must be <= the number of documents in the collection")
494 | }
495 |
496 | if len(c.documents) == 0 {
497 | return nil, nil
498 | }
499 |
500 | // Validate whereDocument operators
501 | for k := range whereDocument {
502 | if !slices.Contains(supportedFilters, k) {
503 | return nil, errors.New("unsupported operator")
504 | }
505 | }
506 |
507 | // Filter docs by metadata and content
508 | filteredDocs := filterDocs(c.documents, where, whereDocument)
509 |
510 | // No need to continue if the filters got rid of all documents
511 | if len(filteredDocs) == 0 {
512 | return nil, nil
513 | }
514 |
515 | 	// Normalize the embedding if it isn't already. We only support cosine similarity
516 | 	// for now, and all documents were already normalized when added to the collection.
517 | if !isNormalized(queryEmbedding) {
518 | queryEmbedding = normalizeVector(queryEmbedding)
519 | }
520 |
521 | // If the filtering already reduced the number of documents to fewer than nResults,
522 | // we only need to find the most similar docs among the filtered ones.
523 | resLen := nResults
524 | if len(filteredDocs) < nResults {
525 | resLen = len(filteredDocs)
526 | }
527 |
528 | // For the remaining documents, get the most similar docs.
529 | nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, negativeEmbeddings, negativeFilterThreshold, filteredDocs, resLen)
530 | if err != nil {
531 | return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
532 | }
533 |
534 | res := make([]Result, 0, len(nMaxDocs))
535 | for i := 0; i < len(nMaxDocs); i++ {
536 | res = append(res, Result{
537 | ID: nMaxDocs[i].docID,
538 | Metadata: c.documents[nMaxDocs[i].docID].Metadata,
539 | Embedding: c.documents[nMaxDocs[i].docID].Embedding,
540 | Content: c.documents[nMaxDocs[i].docID].Content,
541 | Similarity: nMaxDocs[i].similarity,
542 | })
543 | }
544 |
545 | return res, nil
546 | }
547 |
548 | // getDocPath generates the path to the document file.
549 | func (c *Collection) getDocPath(docID string) string {
550 | safeID := hash2hex(docID)
551 | docPath := filepath.Join(c.persistDirectory, safeID)
552 | docPath += ".gob"
553 | if c.compress {
554 | docPath += ".gz"
555 | }
556 | return docPath
557 | }
558 |
559 | // persistMetadata persists the collection metadata to disk.
560 | func (c *Collection) persistMetadata() error {
561 | // Persist name and metadata
562 | metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
563 | metadataPath += ".gob"
564 | if c.compress {
565 | metadataPath += ".gz"
566 | }
567 | pc := struct {
568 | Name string
569 | Metadata map[string]string
570 | }{
571 | Name: c.Name,
572 | Metadata: c.metadata,
573 | }
574 | err := persistToFile(metadataPath, pc, c.compress, "")
575 | if err != nil {
576 | return err
577 | }
578 |
579 | return nil
580 | }
581 |
--------------------------------------------------------------------------------
/collection_test.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "math/rand"
7 | "os"
8 | "slices"
9 | "strconv"
10 | "testing"
11 | )
12 |
13 | func TestCollection_Add(t *testing.T) {
14 | ctx := context.Background()
15 | name := "test"
16 | metadata := map[string]string{"foo": "bar"}
17 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
18 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
19 | return vectors, nil
20 | }
21 |
22 | // Create collection
23 | db := NewDB()
24 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
25 | if err != nil {
26 | t.Fatal("expected no error, got", err)
27 | }
28 | if c == nil {
29 | t.Fatal("expected collection, got nil")
30 | }
31 |
32 | // Add documents
33 |
34 | ids := []string{"1", "2"}
35 | embeddings := [][]float32{vectors, vectors}
36 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
37 | contents := []string{"hello world", "hallo welt"}
38 |
39 | tt := []struct {
40 | name string
41 | ids []string
42 | embeddings [][]float32
43 | metadatas []map[string]string
44 | contents []string
45 | }{
46 | {
47 | name: "No embeddings",
48 | ids: ids,
49 | embeddings: nil,
50 | metadatas: metadatas,
51 | contents: contents,
52 | },
53 | {
54 | name: "With embeddings",
55 | ids: ids,
56 | embeddings: embeddings,
57 | metadatas: metadatas,
58 | contents: contents,
59 | },
60 | {
61 | name: "With embeddings but no contents",
62 | ids: ids,
63 | embeddings: embeddings,
64 | metadatas: metadatas,
65 | contents: nil,
66 | },
67 | }
68 |
69 | for _, tc := range tt {
70 | t.Run(tc.name, func(t *testing.T) {
71 | err = c.Add(ctx, ids, nil, metadatas, contents)
72 | if err != nil {
73 | t.Fatal("expected nil, got", err)
74 | }
75 |
76 | // Check documents
77 | if len(c.documents) != 2 {
78 | t.Fatal("expected 2, got", len(c.documents))
79 | }
80 | for i, id := range ids {
81 | doc, ok := c.documents[id]
82 | if !ok {
83 | t.Fatal("expected document, got nil")
84 | }
85 | if doc.ID != id {
86 | t.Fatal("expected", id, "got", doc.ID)
87 | }
88 | if len(doc.Metadata) != 1 {
89 | t.Fatal("expected 1, got", len(doc.Metadata))
90 | }
91 | if !slices.Equal(doc.Embedding, vectors) {
92 | t.Fatal("expected", vectors, "got", doc.Embedding)
93 | }
94 | if doc.Content != contents[i] {
95 | t.Fatal("expected", contents[i], "got", doc.Content)
96 | }
97 | }
98 | 			// The metadata maps differ per document, so they can't be checked with the loop's i.
99 | if c.documents[ids[0]].Metadata["foo"] != "bar" {
100 | t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"])
101 | }
102 | if c.documents[ids[1]].Metadata["a"] != "b" {
103 | t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"])
104 | }
105 | })
106 | }
107 | }
108 |
109 | func TestCollection_Add_Error(t *testing.T) {
110 | ctx := context.Background()
111 | name := "test"
112 | metadata := map[string]string{"foo": "bar"}
113 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
114 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
115 | return vectors, nil
116 | }
117 |
118 | // Create collection
119 | db := NewDB()
120 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
121 | if err != nil {
122 | t.Fatal("expected no error, got", err)
123 | }
124 | if c == nil {
125 | t.Fatal("expected collection, got nil")
126 | }
127 |
128 | // Add documents, provoking errors
129 | ids := []string{"1", "2"}
130 | embeddings := [][]float32{vectors, vectors}
131 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
132 | contents := []string{"hello world", "hallo welt"}
133 |
134 | // Empty IDs
135 | err = c.Add(ctx, []string{}, embeddings, metadatas, contents)
136 | if err == nil {
137 | t.Fatal("expected error, got nil")
138 | }
139 | // Empty embeddings and contents (both at the same time!)
140 | err = c.Add(ctx, ids, [][]float32{}, metadatas, []string{})
141 | if err == nil {
142 | t.Fatal("expected error, got nil")
143 | }
144 | // Bad embeddings length
145 | err = c.Add(ctx, ids, [][]float32{vectors}, metadatas, contents)
146 | if err == nil {
147 | t.Fatal("expected error, got nil")
148 | }
149 | // Bad metadatas length
150 | err = c.Add(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents)
151 | if err == nil {
152 | t.Fatal("expected error, got nil")
153 | }
154 | // Bad contents length
155 | err = c.Add(ctx, ids, embeddings, metadatas, []string{"hello world"})
156 | if err == nil {
157 | t.Fatal("expected error, got nil")
158 | }
159 | }
160 |
161 | func TestCollection_AddConcurrently(t *testing.T) {
162 | ctx := context.Background()
163 | name := "test"
164 | metadata := map[string]string{"foo": "bar"}
165 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
166 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
167 | return vectors, nil
168 | }
169 |
170 | // Create collection
171 | db := NewDB()
172 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
173 | if err != nil {
174 | t.Fatal("expected no error, got", err)
175 | }
176 | if c == nil {
177 | t.Fatal("expected collection, got nil")
178 | }
179 |
180 | // Add documents
181 |
182 | ids := []string{"1", "2"}
183 | embeddings := [][]float32{vectors, vectors}
184 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
185 | contents := []string{"hello world", "hallo welt"}
186 |
187 | tt := []struct {
188 | name string
189 | ids []string
190 | embeddings [][]float32
191 | metadatas []map[string]string
192 | contents []string
193 | }{
194 | {
195 | name: "No embeddings",
196 | ids: ids,
197 | embeddings: nil,
198 | metadatas: metadatas,
199 | contents: contents,
200 | },
201 | {
202 | name: "With embeddings",
203 | ids: ids,
204 | embeddings: embeddings,
205 | metadatas: metadatas,
206 | contents: contents,
207 | },
208 | {
209 | name: "With embeddings but no contents",
210 | ids: ids,
211 | embeddings: embeddings,
212 | metadatas: metadatas,
213 | contents: nil,
214 | },
215 | }
216 |
217 | for _, tc := range tt {
218 | t.Run(tc.name, func(t *testing.T) {
219 | err = c.AddConcurrently(ctx, ids, nil, metadatas, contents, 2)
220 | if err != nil {
221 | t.Fatal("expected nil, got", err)
222 | }
223 |
224 | // Check documents
225 | if len(c.documents) != 2 {
226 | t.Fatal("expected 2, got", len(c.documents))
227 | }
228 | for i, id := range ids {
229 | doc, ok := c.documents[id]
230 | if !ok {
231 | t.Fatal("expected document, got nil")
232 | }
233 | if doc.ID != id {
234 | t.Fatal("expected", id, "got", doc.ID)
235 | }
236 | if len(doc.Metadata) != 1 {
237 | t.Fatal("expected 1, got", len(doc.Metadata))
238 | }
239 | if !slices.Equal(doc.Embedding, vectors) {
240 | t.Fatal("expected", vectors, "got", doc.Embedding)
241 | }
242 | if doc.Content != contents[i] {
243 | t.Fatal("expected", contents[i], "got", doc.Content)
244 | }
245 | }
246 | 			// The metadata maps differ per document, so they can't be checked with the loop's i.
247 | if c.documents[ids[0]].Metadata["foo"] != "bar" {
248 | t.Fatal("expected bar, got", c.documents[ids[0]].Metadata["foo"])
249 | }
250 | if c.documents[ids[1]].Metadata["a"] != "b" {
251 | t.Fatal("expected b, got", c.documents[ids[1]].Metadata["a"])
252 | }
253 | })
254 | }
255 | }
256 |
257 | func TestCollection_AddConcurrently_Error(t *testing.T) {
258 | ctx := context.Background()
259 | name := "test"
260 | metadata := map[string]string{"foo": "bar"}
261 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
262 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
263 | return vectors, nil
264 | }
265 |
266 | // Create collection
267 | db := NewDB()
268 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
269 | if err != nil {
270 | t.Fatal("expected no error, got", err)
271 | }
272 | if c == nil {
273 | t.Fatal("expected collection, got nil")
274 | }
275 |
276 | // Add documents, provoking errors
277 | ids := []string{"1", "2"}
278 | embeddings := [][]float32{vectors, vectors}
279 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
280 | contents := []string{"hello world", "hallo welt"}
281 | // Empty IDs
282 | err = c.AddConcurrently(ctx, []string{}, embeddings, metadatas, contents, 2)
283 | if err == nil {
284 | t.Fatal("expected error, got nil")
285 | }
286 | // Empty embeddings and contents (both at the same time!)
287 | err = c.AddConcurrently(ctx, ids, [][]float32{}, metadatas, []string{}, 2)
288 | if err == nil {
289 | t.Fatal("expected error, got nil")
290 | }
291 | // Bad embeddings length
292 | err = c.AddConcurrently(ctx, ids, [][]float32{vectors}, metadatas, contents, 2)
293 | if err == nil {
294 | t.Fatal("expected error, got nil")
295 | }
296 | // Bad metadatas length
297 | err = c.AddConcurrently(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents, 2)
298 | if err == nil {
299 | t.Fatal("expected error, got nil")
300 | }
301 | // Bad contents length
302 | err = c.AddConcurrently(ctx, ids, embeddings, metadatas, []string{"hello world"}, 2)
303 | if err == nil {
304 | t.Fatal("expected error, got nil")
305 | }
306 | // Bad concurrency
307 | err = c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 0)
308 | if err == nil {
309 | t.Fatal("expected error, got nil")
310 | }
311 | }
312 |
313 | func TestCollection_QueryError(t *testing.T) {
314 | // Create collection
315 | db := NewDB()
316 | name := "test"
317 | metadata := map[string]string{"foo": "bar"}
318 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
319 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
320 | return vectors, nil
321 | }
322 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
323 | if err != nil {
324 | t.Fatal("expected no error, got", err)
325 | }
326 | if c == nil {
327 | t.Fatal("expected collection, got nil")
328 | }
329 | // Add a document
330 | err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"})
331 | if err != nil {
332 | t.Fatal("expected nil, got", err)
333 | }
334 |
335 | tt := []struct {
336 | name string
337 | query func() error
338 | expErr string
339 | }{
340 | {
341 | name: "Empty query",
342 | query: func() error {
343 | _, err := c.Query(context.Background(), "", 1, nil, nil)
344 | return err
345 | },
346 | expErr: "queryText is empty",
347 | },
348 | {
349 | name: "Negative limit",
350 | query: func() error {
351 | _, err := c.Query(context.Background(), "foo", -1, nil, nil)
352 | return err
353 | },
354 | expErr: "nResults must be > 0",
355 | },
356 | {
357 | name: "Zero limit",
358 | query: func() error {
359 | _, err := c.Query(context.Background(), "foo", 0, nil, nil)
360 | return err
361 | },
362 | expErr: "nResults must be > 0",
363 | },
364 | {
365 | name: "Limit greater than number of documents",
366 | query: func() error {
367 | _, err := c.Query(context.Background(), "foo", 2, nil, nil)
368 | return err
369 | },
370 | expErr: "nResults must be <= the number of documents in the collection",
371 | },
372 | {
373 | name: "Bad content filter",
374 | query: func() error {
375 | _, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
376 | return err
377 | },
378 | expErr: "unsupported operator",
379 | },
380 | }
381 |
382 | for _, tc := range tt {
383 | t.Run(tc.name, func(t *testing.T) {
384 | err := tc.query()
385 | if err == nil {
386 | t.Fatal("expected error, got nil")
387 | } else if err.Error() != tc.expErr {
388 | t.Fatal("expected", tc.expErr, "got", err)
389 | }
390 | })
391 | }
392 | }
393 |
394 | func TestCollection_Get(t *testing.T) {
395 | ctx := context.Background()
396 |
397 | // Create collection
398 | db := NewDB()
399 | name := "test"
400 | metadata := map[string]string{"foo": "bar"}
401 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
402 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
403 | return vectors, nil
404 | }
405 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
406 | if err != nil {
407 | t.Fatal("expected no error, got", err)
408 | }
409 | if c == nil {
410 | t.Fatal("expected collection, got nil")
411 | }
412 |
413 | // Add documents
414 | ids := []string{"1", "2"}
415 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
416 | contents := []string{"hello world", "hallo welt"}
417 | err = c.Add(context.Background(), ids, nil, metadatas, contents)
418 | if err != nil {
419 | t.Fatal("expected nil, got", err)
420 | }
421 |
422 | // Get by ID
423 | doc, err := c.GetByID(ctx, ids[0])
424 | if err != nil {
425 | t.Fatal("expected nil, got", err)
426 | }
427 | // Check fields
428 | if doc.ID != ids[0] {
429 | t.Fatal("expected", ids[0], "got", doc.ID)
430 | }
431 | if len(doc.Metadata) != 1 {
432 | t.Fatal("expected 1, got", len(doc.Metadata))
433 | }
434 | if !slices.Equal(doc.Embedding, vectors) {
435 | t.Fatal("expected", vectors, "got", doc.Embedding)
436 | }
437 | if doc.Content != contents[0] {
438 | t.Fatal("expected", contents[0], "got", doc.Content)
439 | }
440 |
441 | // Check error
442 | _, err = c.GetByID(ctx, "3")
443 | if err == nil {
444 | t.Fatal("expected error, got nil")
445 | }
446 | }
447 |
448 | func TestCollection_Count(t *testing.T) {
449 | // Create collection
450 | db := NewDB()
451 | name := "test"
452 | metadata := map[string]string{"foo": "bar"}
453 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
454 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
455 | return vectors, nil
456 | }
457 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
458 | if err != nil {
459 | t.Fatal("expected no error, got", err)
460 | }
461 | if c == nil {
462 | t.Fatal("expected collection, got nil")
463 | }
464 |
465 | // Add documents
466 | ids := []string{"1", "2"}
467 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
468 | contents := []string{"hello world", "hallo welt"}
469 | err = c.Add(context.Background(), ids, nil, metadatas, contents)
470 | if err != nil {
471 | t.Fatal("expected nil, got", err)
472 | }
473 |
474 | // Check count
475 | if c.Count() != 2 {
476 | t.Fatal("expected 2, got", c.Count())
477 | }
478 | }
479 |
480 | func TestCollection_Delete(t *testing.T) {
481 | // Create persistent collection
482 | tmpdir, err := os.MkdirTemp(os.TempDir(), "chromem-test-*")
483 | if err != nil {
484 | t.Fatal("expected no error, got", err)
485 | }
486 | db, err := NewPersistentDB(tmpdir, false)
487 | if err != nil {
488 | t.Fatal("expected no error, got", err)
489 | }
490 | name := "test"
491 | metadata := map[string]string{"foo": "bar"}
492 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
493 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
494 | return vectors, nil
495 | }
496 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
497 | if err != nil {
498 | t.Fatal("expected no error, got", err)
499 | }
500 | if c == nil {
501 | t.Fatal("expected collection, got nil")
502 | }
503 |
504 | // Add documents
505 | ids := []string{"1", "2", "3", "4"}
506 | metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}, {"foo": "bar"}, {"e": "f"}}
507 | contents := []string{"hello world", "hallo welt", "bonjour le monde", "hola mundo"}
508 | err = c.Add(context.Background(), ids, nil, metadatas, contents)
509 | if err != nil {
510 | t.Fatal("expected nil, got", err)
511 | }
512 |
513 | // Check count
514 | if c.Count() != 4 {
515 | t.Fatal("expected 4 documents, got", c.Count())
516 | }
517 |
518 | // Check number of files in the persist directory
519 | d, err := os.ReadDir(c.persistDirectory)
520 | if err != nil {
521 | t.Fatal("expected nil, got", err)
522 | }
523 | if len(d) != 5 { // 4 documents + 1 metadata file
524 | t.Fatal("expected 4 document files + 1 metadata file in persist_dir, got", len(d))
525 | }
526 |
527 | checkCount := func(expected int) {
528 | // Check count
529 | if c.Count() != expected {
530 | t.Fatalf("expected %d documents, got %d", expected, c.Count())
531 | }
532 |
533 | // Check number of files in the persist directory
534 | d, err = os.ReadDir(c.persistDirectory)
535 | if err != nil {
536 | t.Fatal("expected nil, got", err)
537 | }
538 | 		if len(d) != expected+1 { // expected document files + 1 metadata file
539 | t.Fatalf("expected %d document files + 1 metadata file in persist_dir, got %d", expected, len(d))
540 | }
541 | }
542 |
543 | // Test 1 - Remove document by ID: should delete one document
544 | err = c.Delete(context.Background(), nil, nil, "4")
545 | if err != nil {
546 | t.Fatal("expected nil, got", err)
547 | }
548 | checkCount(3)
549 |
550 | // Test 2 - Remove document by metadata
551 | err = c.Delete(context.Background(), map[string]string{"foo": "bar"}, nil)
552 | if err != nil {
553 | t.Fatal("expected nil, got", err)
554 | }
555 |
556 | checkCount(1)
557 |
558 | // Test 3 - Remove document by content
559 | err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"})
560 | if err != nil {
561 | t.Fatal("expected nil, got", err)
562 | }
563 |
564 | checkCount(0)
565 | }
566 |
567 | // Global var for assignment in the benchmark to avoid compiler optimizations.
568 | var globalRes []Result
569 |
570 | func BenchmarkCollection_Query_NoContent_100(b *testing.B) {
571 | benchmarkCollection_Query(b, 100, false)
572 | }
573 |
574 | func BenchmarkCollection_Query_NoContent_1000(b *testing.B) {
575 | benchmarkCollection_Query(b, 1000, false)
576 | }
577 |
578 | func BenchmarkCollection_Query_NoContent_5000(b *testing.B) {
579 | benchmarkCollection_Query(b, 5000, false)
580 | }
581 |
582 | func BenchmarkCollection_Query_NoContent_25000(b *testing.B) {
583 | benchmarkCollection_Query(b, 25000, false)
584 | }
585 |
586 | func BenchmarkCollection_Query_NoContent_100000(b *testing.B) {
587 | benchmarkCollection_Query(b, 100_000, false)
588 | }
589 |
590 | func BenchmarkCollection_Query_100(b *testing.B) {
591 | benchmarkCollection_Query(b, 100, true)
592 | }
593 |
594 | func BenchmarkCollection_Query_1000(b *testing.B) {
595 | benchmarkCollection_Query(b, 1000, true)
596 | }
597 |
598 | func BenchmarkCollection_Query_5000(b *testing.B) {
599 | benchmarkCollection_Query(b, 5000, true)
600 | }
601 |
602 | func BenchmarkCollection_Query_25000(b *testing.B) {
603 | benchmarkCollection_Query(b, 25000, true)
604 | }
605 |
606 | func BenchmarkCollection_Query_100000(b *testing.B) {
607 | benchmarkCollection_Query(b, 100_000, true)
608 | }
609 |
610 | // n is the number of documents in the collection.
611 | func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
612 | ctx := context.Background()
613 |
614 | // Seed to make deterministic
615 | r := rand.New(rand.NewSource(42))
616 |
617 | d := 1536 // dimensions, same as text-embedding-3-small
618 | // Random query vector
619 | qv := make([]float32, d)
620 | for j := 0; j < d; j++ {
621 | qv[j] = r.Float32()
622 | }
623 | // The document embeddings are normalized, so the query must be normalized too.
624 | qv = normalizeVector(qv)
625 |
626 | // Create collection
627 | db := NewDB()
628 | name := "test"
629 | embeddingFunc := func(_ context.Context, text string) ([]float32, error) {
630 | return nil, errors.New("embedding func not expected to be called")
631 | }
632 | c, err := db.CreateCollection(name, nil, embeddingFunc)
633 | if err != nil {
634 | b.Fatal("expected no error, got", err)
635 | }
636 | if c == nil {
637 | b.Fatal("expected collection, got nil")
638 | }
639 |
640 | // Add documents
641 | for i := 0; i < n; i++ {
642 | // Random embedding
643 | v := make([]float32, d)
644 | for j := 0; j < d; j++ {
645 | v[j] = r.Float32()
646 | }
647 | v = normalizeVector(v)
648 |
649 | // Add document with some metadata and content depending on parameter.
650 | // When providing embeddings, the embedding func is not called.
651 | is := strconv.Itoa(i)
652 | doc := Document{
653 | ID: is,
654 | Metadata: map[string]string{"i": is, "foo": "bar" + is},
655 | Embedding: v,
656 | }
657 | if withContent {
658 | // Let's say we embed 500 tokens, that's ~375 words, ~1875 characters
659 | doc.Content = randomString(r, 1875)
660 | }
661 |
662 | if err := c.AddDocument(ctx, doc); err != nil {
663 | b.Fatal("expected nil, got", err)
664 | }
665 | }
666 |
667 | b.ResetTimer()
668 |
669 | // Query
670 | var res []Result
671 | for i := 0; i < b.N; i++ {
672 | res, err = c.QueryEmbedding(ctx, qv, 10, nil, nil)
673 | }
674 | if err != nil {
675 | b.Fatal("expected nil, got", err)
676 | }
677 | globalRes = res
678 | }
679 |
680 | // randomString returns a random string of length n using lowercase letters and space.
681 | func randomString(r *rand.Rand, n int) string {
682 | 	// We append a space to the alphabet so the string contains occasional spaces.
683 | characters := []rune("abcdefghijklmnopqrstuvwxyz ")
684 |
685 | b := make([]rune, n)
686 | for i := range b {
687 | b[i] = characters[r.Intn(len(characters))]
688 | }
689 | return string(b)
690 | }
691 |
--------------------------------------------------------------------------------
/db_test.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "context"
5 | "math/rand"
6 | "os"
7 | "path/filepath"
8 | "reflect"
9 | "slices"
10 | "testing"
11 | )
12 |
13 | func TestNewPersistentDB(t *testing.T) {
14 | t.Run("Create directory", func(t *testing.T) {
15 | r := rand.New(rand.NewSource(rand.Int63()))
16 | randString := randomString(r, 10)
17 | path := filepath.Join(os.TempDir(), randString)
18 | defer os.RemoveAll(path)
19 |
20 | // Path shouldn't exist yet
21 | if _, err := os.Stat(path); !os.IsNotExist(err) {
22 | t.Fatal("expected path to not exist, got", err)
23 | }
24 |
25 | db, err := NewPersistentDB(path, false)
26 | if err != nil {
27 | t.Fatal("expected no error, got", err)
28 | }
29 | if db == nil {
30 | t.Fatal("expected DB, got nil")
31 | }
32 |
33 | // Path should exist now
34 | if _, err := os.Stat(path); err != nil {
35 | t.Fatal("expected path to exist, got", err)
36 | }
37 | })
38 | t.Run("Existing directory", func(t *testing.T) {
39 | path, err := os.MkdirTemp(os.TempDir(), "")
40 | if err != nil {
41 | t.Fatal("couldn't create temp dir:", err)
42 | }
43 | defer os.RemoveAll(path)
44 |
45 | db, err := NewPersistentDB(path, false)
46 | if err != nil {
47 | t.Fatal("expected no error, got", err)
48 | }
49 | if db == nil {
50 | t.Fatal("expected DB, got nil")
51 | }
52 | })
53 | }
54 |
55 | func TestNewPersistentDB_Errors(t *testing.T) {
56 | t.Run("Path is an existing file", func(t *testing.T) {
57 | f, err := os.CreateTemp(os.TempDir(), "")
58 | if err != nil {
59 | t.Fatal("couldn't create temp file:", err)
60 | }
61 | defer os.RemoveAll(f.Name())
62 |
63 | _, err = NewPersistentDB(f.Name(), false)
64 | if err == nil {
65 | t.Fatal("expected error, got nil")
66 | }
67 | })
68 | }
69 |
70 | func TestDB_ImportExport(t *testing.T) {
71 | r := rand.New(rand.NewSource(rand.Int63()))
72 | randString := randomString(r, 10)
73 | path := filepath.Join(os.TempDir(), randString)
74 | defer os.RemoveAll(path)
75 |
76 | // Values in the collection
77 | name := "test"
78 | metadata := map[string]string{"foo": "bar"}
79 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
80 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
81 | return vectors, nil
82 | }
83 |
84 | tt := []struct {
85 | name string
86 | filePath string
87 | compress bool
88 | encryptionKey string
89 | }{
90 | {
91 | name: "gob",
92 | filePath: path + ".gob",
93 | compress: false,
94 | encryptionKey: "",
95 | },
96 | {
97 | name: "gob compressed",
98 | filePath: path + ".gob.gz",
99 | compress: true,
100 | encryptionKey: "",
101 | },
102 | {
103 | name: "gob compressed encrypted",
104 | filePath: path + ".gob.gz.enc",
105 | compress: true,
106 | encryptionKey: randomString(r, 32),
107 | },
108 | {
109 | name: "gob encrypted",
110 | filePath: path + ".gob.enc",
111 | compress: false,
112 | encryptionKey: randomString(r, 32),
113 | },
114 | }
115 |
116 | for _, tc := range tt {
117 | t.Run(tc.name, func(t *testing.T) {
118 | // Create DB, can just be in-memory
119 | origDB := NewDB()
120 |
121 | // Create collection
122 | c, err := origDB.CreateCollection(name, metadata, embeddingFunc)
123 | if err != nil {
124 | t.Fatal("expected no error, got", err)
125 | }
126 | if c == nil {
127 | t.Fatal("expected collection, got nil")
128 | }
129 | // Add document
130 | doc := Document{
131 | ID: name,
132 | Metadata: metadata,
133 | Embedding: vectors,
134 | Content: "test",
135 | }
136 | err = c.AddDocument(context.Background(), doc)
137 | if err != nil {
138 | t.Fatal("expected no error, got", err)
139 | }
140 |
141 | // Export
142 | err = origDB.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
143 | if err != nil {
144 | t.Fatal("expected no error, got", err)
145 | }
146 |
147 | newDB := NewDB()
148 |
149 | // Import
150 | err = newDB.ImportFromFile(tc.filePath, tc.encryptionKey)
151 | if err != nil {
152 | t.Fatal("expected no error, got", err)
153 | }
154 |
155 | // Check expectations
156 | // We have to reset the embed function, but otherwise the DB objects
157 | // should be deep equal.
158 | c.embed = nil
159 | if !reflect.DeepEqual(origDB, newDB) {
160 | t.Fatalf("expected DB %+v, got %+v", origDB, newDB)
161 | }
162 | })
163 | }
164 | }
165 |
166 | func TestDB_ImportExportSpecificCollections(t *testing.T) {
167 | r := rand.New(rand.NewSource(rand.Int63()))
168 | randString := randomString(r, 10)
169 | path := filepath.Join(os.TempDir(), randString)
170 | filePath := path + ".gob"
171 | defer os.RemoveAll(path)
172 |
173 | // Values in the collection
174 | name := "test"
175 | name2 := "test2"
176 | metadata := map[string]string{"foo": "bar"}
177 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
178 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
179 | return vectors, nil
180 | }
181 |
182 | // Create DB, can just be in-memory
183 | origDB := NewDB()
184 |
185 | // Create collections
186 | c, err := origDB.CreateCollection(name, metadata, embeddingFunc)
187 | if err != nil {
188 | t.Fatal("expected no error, got", err)
189 | }
190 |
191 | c2, err := origDB.CreateCollection(name2, metadata, embeddingFunc)
192 | if err != nil {
193 | t.Fatal("expected no error, got", err)
194 | }
195 |
196 | // Add documents
197 | doc := Document{
198 | ID: name,
199 | Metadata: metadata,
200 | Embedding: vectors,
201 | Content: "test",
202 | }
203 |
204 | doc2 := Document{
205 | ID: name2,
206 | Metadata: metadata,
207 | Embedding: vectors,
208 | Content: "test2",
209 | }
210 |
211 | err = c.AddDocument(context.Background(), doc)
212 | if err != nil {
213 | t.Fatal("expected no error, got", err)
214 | }
215 |
216 | err = c2.AddDocument(context.Background(), doc2)
217 | if err != nil {
218 | t.Fatal("expected no error, got", err)
219 | }
220 |
221 | // Export only one of the two collections
222 | err = origDB.ExportToFile(filePath, false, "", name2)
223 | if err != nil {
224 | t.Fatal("expected no error, got", err)
225 | }
226 |
227 | dir := filepath.Join(path, randomString(r, 10))
228 | defer os.RemoveAll(dir)
229 |
230 | 	// Instead of importing into an in-memory DB, we use a persistent one to cover the behavior of persistence files being created immediately for the imported data.
231 | newPDB, err := NewPersistentDB(dir, false)
232 | if err != nil {
233 | t.Fatal("expected no error, got", err)
234 | }
235 |
236 | err = newPDB.ImportFromFile(filePath, "")
237 | if err != nil {
238 | t.Fatal("expected no error, got", err)
239 | }
240 |
241 | if len(newPDB.collections) != 1 {
242 | t.Fatalf("expected 1 collection, got %d", len(newPDB.collections))
243 | }
244 |
245 | // Make sure that the imported documents are actually persisted on disk
246 | for _, col := range newPDB.collections {
247 | for _, d := range col.documents {
248 | _, err = os.Stat(col.getDocPath(d.ID))
249 | if err != nil {
250 | t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err)
251 | }
252 | }
253 | }
254 |
255 | // Now export both collections and import them into the same persistent DB (overwriting the one we just imported)
256 | filePath2 := filepath.Join(path, "2.gob")
257 | err = origDB.ExportToFile(filePath2, false, "")
258 | if err != nil {
259 | t.Fatal("expected no error, got", err)
260 | }
261 |
262 | err = newPDB.ImportFromFile(filePath2, "")
263 | if err != nil {
264 | t.Fatal("expected no error, got", err)
265 | }
266 |
267 | if len(newPDB.collections) != 2 {
268 | t.Fatalf("expected 2 collections, got %d", len(newPDB.collections))
269 | }
270 |
271 | // Make sure that the imported documents are actually persisted on disk
272 | for _, col := range newPDB.collections {
273 | for _, d := range col.documents {
274 | _, err = os.Stat(col.getDocPath(d.ID))
275 | if err != nil {
276 | t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err)
277 | }
278 | }
279 | }
280 | }
281 |
282 | func TestDB_CreateCollection(t *testing.T) {
283 | // Values in the collection
284 | name := "test"
285 | metadata := map[string]string{"foo": "bar"}
286 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
287 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
288 | return vectors, nil
289 | }
290 |
291 | db := NewDB()
292 |
293 | t.Run("OK", func(t *testing.T) {
294 | c, err := db.CreateCollection(name, metadata, embeddingFunc)
295 | if err != nil {
296 | t.Fatal("expected no error, got", err)
297 | }
298 | if c == nil {
299 | t.Fatal("expected collection, got nil")
300 | }
301 |
302 | // Check expectations
303 |
304 | // DB should have one collection now
305 | if len(db.collections) != 1 {
306 | t.Fatal("expected 1 collection, got", len(db.collections))
307 | }
308 | // The collection should be the one we just created
309 | c2, ok := db.collections[name]
310 | if !ok {
311 | t.Fatal("expected collection", name, "not found")
312 | }
313 | // Check the embedding function first, then the rest with DeepEqual
314 | gotVectors, err := c.embed(context.Background(), "test")
315 | if err != nil {
316 | t.Fatal("expected no error, got", err)
317 | }
318 | if !slices.Equal(gotVectors, vectors) {
319 | t.Fatal("expected vectors", vectors, "got", gotVectors)
320 | }
321 | c.embed, c2.embed = nil, nil
322 | if !reflect.DeepEqual(c, c2) {
323 | t.Fatalf("expected collection %+v, got %+v", c, c2)
324 | }
325 | })
326 |
327 | t.Run("NOK - Empty name", func(t *testing.T) {
328 | _, err := db.CreateCollection("", metadata, embeddingFunc)
329 | if err == nil {
330 | t.Fatal("expected error, got nil")
331 | }
332 | })
333 | }
334 |
335 | func TestDB_ListCollections(t *testing.T) {
336 | // Values in the collection
337 | name := "test"
338 | metadata := map[string]string{"foo": "bar"}
339 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
340 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
341 | return vectors, nil
342 | }
343 |
344 | // Create initial collection
345 | db := NewDB()
346 | orig, err := db.CreateCollection(name, metadata, embeddingFunc)
347 | if err != nil {
348 | t.Fatal("expected no error, got", err)
349 | }
350 |
351 | // List collections
352 | res := db.ListCollections()
353 |
354 | // Check expectations
355 |
356 | // Should've returned a map with one collection
357 | if len(res) != 1 {
358 | t.Fatal("expected 1 collection, got", len(res))
359 | }
360 | // The collection should be the one we just created
361 | c, ok := res[name]
362 | if !ok {
363 | t.Fatal("expected collection", name, "not found")
364 | }
365 | // Check the embedding function first, then the rest with DeepEqual
366 | gotVectors, err := c.embed(context.Background(), "test")
367 | if err != nil {
368 | t.Fatal("expected no error, got", err)
369 | }
370 | if !slices.Equal(gotVectors, vectors) {
371 | t.Fatal("expected vectors", vectors, "got", gotVectors)
372 | }
373 | orig.embed, c.embed = nil, nil
374 | if !reflect.DeepEqual(orig, c) {
375 | t.Fatalf("expected collection %+v, got %+v", orig, c)
376 | }
377 |
378 | 	// And it should be a copy. Adding a value here should not affect the DB's
379 | 	// internal collection map.
380 | res["foo"] = &Collection{}
381 | if len(db.ListCollections()) != 1 {
382 | t.Fatal("expected 1 collection, got", len(db.ListCollections()))
383 | }
384 | }
385 |
386 | func TestDB_GetCollection(t *testing.T) {
387 | // Values in the collection
388 | name := "test"
389 | metadata := map[string]string{"foo": "bar"}
390 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
391 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
392 | return vectors, nil
393 | }
394 |
395 | // Create initial collection
396 | db := NewDB()
397 | orig, err := db.CreateCollection(name, metadata, embeddingFunc)
398 | if err != nil {
399 | t.Fatal("expected no error, got", err)
400 | }
401 |
402 | // Get collection
403 | c := db.GetCollection(name, nil)
404 |
405 | // Check the embedding function first, then the rest with DeepEqual
406 | gotVectors, err := c.embed(context.Background(), "test")
407 | if err != nil {
408 | t.Fatal("expected no error, got", err)
409 | }
410 | if !slices.Equal(gotVectors, vectors) {
411 | t.Fatal("expected vectors", vectors, "got", gotVectors)
412 | }
413 | orig.embed, c.embed = nil, nil
414 | if !reflect.DeepEqual(orig, c) {
415 | t.Fatalf("expected collection %+v, got %+v", orig, c)
416 | }
417 | }
418 |
419 | func TestDB_GetOrCreateCollection(t *testing.T) {
420 | // Values in the collection
421 | name := "test"
422 | metadata := map[string]string{"foo": "bar"}
423 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
424 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
425 | return vectors, nil
426 | }
427 |
428 | t.Run("Get", func(t *testing.T) {
429 | // Create initial collection
430 | db := NewDB()
431 | // Create collection so that the GetOrCreateCollection() call below only
432 | // gets it.
433 | orig, err := db.CreateCollection(name, metadata, embeddingFunc)
434 | if err != nil {
435 | t.Fatal("expected no error, got", err)
436 | }
437 |
438 | // Call GetOrCreateCollection() with the same name to only get it. We pass
439 | // nil for the metadata and embeddingFunc, so we can check that the returned
440 | // collection is the original one, and not a new one.
441 | c, err := db.GetOrCreateCollection(name, nil, nil)
442 | if err != nil {
443 | t.Fatal("expected no error, got", err)
444 | }
445 | if c == nil {
446 | t.Fatal("expected collection, got nil")
447 | }
448 |
449 | // Check the embedding function first, then the rest with DeepEqual
450 | gotVectors, err := c.embed(context.Background(), "test")
451 | if err != nil {
452 | t.Fatal("expected no error, got", err)
453 | }
454 | if !slices.Equal(gotVectors, vectors) {
455 | t.Fatal("expected vectors", vectors, "got", gotVectors)
456 | }
457 | orig.embed, c.embed = nil, nil
458 | if !reflect.DeepEqual(orig, c) {
459 | t.Fatalf("expected collection %+v, got %+v", orig, c)
460 | }
461 | })
462 |
463 | t.Run("Create", func(t *testing.T) {
464 | // Create initial collection
465 | db := NewDB()
466 |
467 | // Call GetOrCreateCollection()
468 | c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc)
469 | if err != nil {
470 | t.Fatal("expected no error, got", err)
471 | }
472 | if c == nil {
473 | t.Fatal("expected collection, got nil")
474 | }
475 |
476 | // Check like we check CreateCollection()
477 | c2, ok := db.collections[name]
478 | if !ok {
479 | t.Fatal("expected collection", name, "not found")
480 | }
481 | gotVectors, err := c.embed(context.Background(), "test")
482 | if err != nil {
483 | t.Fatal("expected no error, got", err)
484 | }
485 | if !slices.Equal(gotVectors, vectors) {
486 | t.Fatal("expected vectors", vectors, "got", gotVectors)
487 | }
488 | c.embed, c2.embed = nil, nil
489 | if !reflect.DeepEqual(c, c2) {
490 | t.Fatalf("expected collection %+v, got %+v", c, c2)
491 | }
492 | })
493 | }
494 |
495 | func TestDB_DeleteCollection(t *testing.T) {
496 | // Values in the collection
497 | name := "test"
498 | metadata := map[string]string{"foo": "bar"}
499 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
500 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
501 | return vectors, nil
502 | }
503 |
504 | // Create initial collection
505 | db := NewDB()
506 | // We ignore the return value. CreateCollection is tested elsewhere.
507 | _, err := db.CreateCollection(name, metadata, embeddingFunc)
508 | if err != nil {
509 | t.Fatal("expected no error, got", err)
510 | }
511 |
512 | // Delete collection
513 | if err := db.DeleteCollection(name); err != nil {
514 | t.Fatal("expected no error, got", err)
515 | }
516 |
517 | // Check expectations
518 | // We don't have access to the documents field, but we can rely on DB.ListCollections()
519 | // because it's tested elsewhere.
520 | if len(db.ListCollections()) != 0 {
521 | t.Fatal("expected 0 collections, got", len(db.ListCollections()))
522 | }
523 | // Also check internally
524 | if len(db.collections) != 0 {
525 | t.Fatal("expected 0 collections, got", len(db.collections))
526 | }
527 | }
528 |
529 | func TestDB_Reset(t *testing.T) {
530 | // Values in the collection
531 | name := "test"
532 | metadata := map[string]string{"foo": "bar"}
533 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
534 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
535 | return vectors, nil
536 | }
537 |
538 | // Create initial collection
539 | db := NewDB()
540 | // We ignore the return value. CreateCollection is tested elsewhere.
541 | _, err := db.CreateCollection(name, metadata, embeddingFunc)
542 | if err != nil {
543 | t.Fatal("expected no error, got", err)
544 | }
545 |
546 | // Reset DB
547 | if err := db.Reset(); err != nil {
548 | t.Fatal("expected no error, got", err)
549 | }
550 |
551 | // Check expectations
552 | // We don't have access to the documents field, but we can rely on DB.ListCollections()
553 | // because it's tested elsewhere.
554 | if len(db.ListCollections()) != 0 {
555 | t.Fatal("expected 0 collections, got", len(db.ListCollections()))
556 | }
557 | // Also check internally
558 | if len(db.collections) != 0 {
559 | t.Fatal("expected 0 collections, got", len(db.collections))
560 | }
561 | }
562 |
--------------------------------------------------------------------------------
/document.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "context"
5 | "errors"
6 | )
7 |
8 | // Document represents a single document.
9 | type Document struct {
10 | ID string
11 | Metadata map[string]string
12 | Embedding []float32
13 | Content string
14 |
15 | // ⚠️ When adding unexported fields here, consider adding a persistence struct
16 | // version of this in [DB.Export] and [DB.Import].
17 | }
18 |
19 | // NewDocument creates a new document, including its embedding.
20 | // Metadata is optional.
21 | // If the embedding is not provided, it is created using the embedding function.
22 | // You can leave the content empty if you only want to store embeddings.
23 | // If embeddingFunc is nil, the default embedding function is used.
24 | //
25 | // If you want to create a document without embeddings, for example to let [Collection.AddDocuments]
26 | // create them concurrently, you can create a document with `chromem.Document{...}`
27 | // instead of using this constructor.
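//
// A minimal usage sketch (ID, metadata and content are placeholders; passing nil
// for the embedding and embeddingFunc makes the default embedding function create
// the embedding from the content):
//
//	doc, err := chromem.NewDocument(ctx, "1", map[string]string{"category": "facts"}, nil, "The sky is blue.", nil)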
28 | func NewDocument(ctx context.Context, id string, metadata map[string]string, embedding []float32, content string, embeddingFunc EmbeddingFunc) (Document, error) {
29 | if id == "" {
30 | return Document{}, errors.New("id is empty")
31 | }
32 | if len(embedding) == 0 && content == "" {
33 | return Document{}, errors.New("either embedding or content must be filled")
34 | }
35 | if embeddingFunc == nil {
36 | embeddingFunc = NewEmbeddingFuncDefault()
37 | }
38 |
39 | if len(embedding) == 0 {
40 | var err error
41 | embedding, err = embeddingFunc(ctx, content)
42 | if err != nil {
43 | return Document{}, err
44 | }
45 | }
46 |
47 | return Document{
48 | ID: id,
49 | Metadata: metadata,
50 | Embedding: embedding,
51 | Content: content,
52 | }, nil
53 | }
54 |
--------------------------------------------------------------------------------
/document_test.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "context"
5 | "reflect"
6 | "testing"
7 | )
8 |
9 | func TestDocument_New(t *testing.T) {
10 | ctx := context.Background()
11 | id := "test"
12 | metadata := map[string]string{"foo": "bar"}
13 | vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
14 | content := "hello world"
15 | embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
16 | return vectors, nil
17 | }
18 |
19 | tt := []struct {
20 | name string
21 | id string
22 | metadata map[string]string
23 | vectors []float32
24 | content string
25 | embeddingFunc EmbeddingFunc
26 | }{
27 | {
28 | name: "No embedding",
29 | id: id,
30 | metadata: metadata,
31 | vectors: nil,
32 | content: content,
33 | embeddingFunc: embeddingFunc,
34 | },
35 | {
36 | name: "With embedding",
37 | id: id,
38 | metadata: metadata,
39 | vectors: vectors,
40 | content: content,
41 | embeddingFunc: embeddingFunc,
42 | },
43 | }
44 |
45 | for _, tc := range tt {
46 | t.Run(tc.name, func(t *testing.T) {
47 | // Create document
48 | d, err := NewDocument(ctx, id, metadata, vectors, content, embeddingFunc)
49 | if err != nil {
50 | t.Fatal("expected no error, got", err)
51 | }
52 | // We can compare with DeepEqual after removing the embedding function
53 | d.Embedding = nil
54 | exp := Document{
55 | ID: id,
56 | Metadata: metadata,
57 | Content: content,
58 | }
59 | if !reflect.DeepEqual(exp, d) {
60 | t.Fatalf("expected %+v, got %+v", exp, d)
61 | }
62 | })
63 | }
64 | }
65 |
--------------------------------------------------------------------------------
/embed_cohere.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "strings"
12 | "sync"
13 | )
14 |
15 | type EmbeddingModelCohere string
16 |
17 | const (
18 | EmbeddingModelCohereMultilingualV2 EmbeddingModelCohere = "embed-multilingual-v2.0"
19 | EmbeddingModelCohereEnglishLightV2 EmbeddingModelCohere = "embed-english-light-v2.0"
20 | EmbeddingModelCohereEnglishV2 EmbeddingModelCohere = "embed-english-v2.0"
21 |
22 | EmbeddingModelCohereMultilingualLightV3 EmbeddingModelCohere = "embed-multilingual-light-v3.0"
23 | EmbeddingModelCohereEnglishLightV3 EmbeddingModelCohere = "embed-english-light-v3.0"
24 | EmbeddingModelCohereMultilingualV3 EmbeddingModelCohere = "embed-multilingual-v3.0"
25 | EmbeddingModelCohereEnglishV3 EmbeddingModelCohere = "embed-english-v3.0"
26 | )
27 |
28 | // Prefixes for external use.
29 | const (
30 | InputTypeCohereSearchDocumentPrefix string = "search_document: "
31 | InputTypeCohereSearchQueryPrefix string = "search_query: "
32 | InputTypeCohereClassificationPrefix string = "classification: "
33 | InputTypeCohereClusteringPrefix string = "clustering: "
34 | )
35 |
36 | // Input types for internal use.
37 | const (
38 | inputTypeCohereSearchDocument string = "search_document"
39 | inputTypeCohereSearchQuery string = "search_query"
40 | inputTypeCohereClassification string = "classification"
41 | inputTypeCohereClustering string = "clustering"
42 | )
43 |
44 | const baseURLCohere = "https://api.cohere.ai/v1"
45 |
46 | var validInputTypesCohere = map[string]string{
47 | inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix,
48 | inputTypeCohereSearchQuery: InputTypeCohereSearchQueryPrefix,
49 | inputTypeCohereClassification: InputTypeCohereClassificationPrefix,
50 | inputTypeCohereClustering: InputTypeCohereClusteringPrefix,
51 | }
52 |
53 | type cohereResponse struct {
54 | Embeddings [][]float32 `json:"embeddings"`
55 | }
56 |
57 | // NewEmbeddingFuncCohere returns a function that creates embeddings for a text
58 | // using Cohere's API. One important difference from OpenAI's and others' APIs is
59 | // that Cohere differentiates between document embeddings and search/query embeddings.
60 | // In order for this embedding func to do the differentiation, you have to prepend
61 | // the text with one of the InputTypeCohere*Prefix constants, e.g. "search_document: ".
62 | // We cut off that prefix before sending the document/query body to the API and just
63 | // use the prefix to choose the right "input type", as Cohere calls it.
64 | //
65 | // When you set up a chromem-go collection with this embedding function, you might
66 | // want to create the document separately with [NewDocument] and then cut off the
67 | // prefix before adding the document to the collection. Otherwise, when you query
68 | // the collection, the returned documents will still have the prefix in their content.
69 | //
70 | // cohereFunc := chromem.NewEmbeddingFuncCohere(cohereApiKey, chromem.EmbeddingModelCohereEnglishV3)
71 | // content := "The sky is blue because of Rayleigh scattering."
72 | // // Create the document with the prefix.
73 | // contentWithPrefix := chromem.InputTypeCohereSearchDocumentPrefix + content
74 | // doc, _ := NewDocument(ctx, id, metadata, nil, contentWithPrefix, cohereFunc)
75 | // // Remove the prefix so that later query results don't have it.
76 | // doc.Content = content
77 | // _ = collection.AddDocument(ctx, doc)
78 | //
79 | // This is not necessary if you don't keep the content in the documents, as chromem-go
80 | // also works when documents only have embeddings.
81 | // You can also keep the prefix in the document, and only remove it after querying.
82 | //
83 | // We plan to improve this in the future.
84 | func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) EmbeddingFunc {
85 | // We don't set a default timeout here, although it's usually a good idea.
86 | // In our case though, the library user can set the timeout on the context,
87 | // and it might have to be a long timeout, depending on the text length.
88 | client := &http.Client{}
89 |
90 | var checkedNormalized bool
91 | checkNormalized := sync.Once{}
92 |
93 | return func(ctx context.Context, text string) ([]float32, error) {
94 | var inputType string
95 | for validInputType, validInputTypePrefix := range validInputTypesCohere {
96 | if strings.HasPrefix(text, validInputTypePrefix) {
97 | inputType = validInputType
98 | text = strings.TrimPrefix(text, validInputTypePrefix)
99 | break
100 | }
101 | }
102 | if inputType == "" {
103 | return nil, errors.New("text must start with a valid input type plus colon and space")
104 | }
105 |
106 | // Prepare the request body.
107 | reqBody, err := json.Marshal(map[string]any{
108 | "model": model,
109 | "texts": []string{text},
110 | "input_type": inputType,
111 | })
112 | if err != nil {
113 | return nil, fmt.Errorf("couldn't marshal request body: %w", err)
114 | }
115 |
116 | // Create the request. Creating it with context is important for a timeout
117 | // to be possible, because the client is configured without a timeout.
118 | req, err := http.NewRequestWithContext(ctx, "POST", baseURLCohere+"/embed", bytes.NewBuffer(reqBody))
119 | if err != nil {
120 | return nil, fmt.Errorf("couldn't create request: %w", err)
121 | }
122 | req.Header.Set("Accept", "application/json")
123 | req.Header.Set("Content-Type", "application/json")
124 | req.Header.Set("Authorization", "Bearer "+apiKey)
125 |
126 | // Send the request.
127 | resp, err := client.Do(req)
128 | if err != nil {
129 | return nil, fmt.Errorf("couldn't send request: %w", err)
130 | }
131 | defer resp.Body.Close()
132 |
133 | // Check the response status.
134 | if resp.StatusCode != http.StatusOK {
135 | return nil, errors.New("error response from the embedding API: " + resp.Status)
136 | }
137 |
138 | // Read and decode the response body.
139 | body, err := io.ReadAll(resp.Body)
140 | if err != nil {
141 | return nil, fmt.Errorf("couldn't read response body: %w", err)
142 | }
143 | var embeddingResponse cohereResponse
144 | err = json.Unmarshal(body, &embeddingResponse)
145 | if err != nil {
146 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
147 | }
148 |
149 | // Check if the response contains embeddings.
150 | if len(embeddingResponse.Embeddings) == 0 || len(embeddingResponse.Embeddings[0]) == 0 {
151 | return nil, errors.New("no embeddings found in the response")
152 | }
153 |
154 | v := embeddingResponse.Embeddings[0]
155 | checkNormalized.Do(func() {
156 | if isNormalized(v) {
157 | checkedNormalized = true
158 | } else {
159 | checkedNormalized = false
160 | }
161 | })
162 | if !checkedNormalized {
163 | v = normalizeVector(v)
164 | }
165 |
166 | return v, nil
167 | }
168 | }
169 |
--------------------------------------------------------------------------------
/embed_compat.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | const (
4 | baseURLMistral = "https://api.mistral.ai/v1"
5 | // Currently there's only one. Let's turn this into a pseudo-enum as soon as there are more.
6 | embeddingModelMistral = "mistral-embed"
7 | )
8 |
9 | // NewEmbeddingFuncMistral returns a function that creates embeddings for a text
10 | // using the Mistral API.
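//
// A minimal usage sketch (the environment variable name is a placeholder):
//
//	embeddingFunc := chromem.NewEmbeddingFuncMistral(os.Getenv("MISTRAL_API_KEY"))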
11 | func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
12 | // Mistral embeddings are normalized, see section "Distance Measures" on
13 | // https://docs.mistral.ai/guides/embeddings/.
14 | normalized := true
15 |
16 | 	// The Mistral API docs don't mark the `encoding_format` as optional,
17 | 	// but it seems to be, just like OpenAI's. So we reuse the OpenAI function.
18 | return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
19 | }
20 |
21 | const baseURLJina = "https://api.jina.ai/v1"
22 |
23 | type EmbeddingModelJina string
24 |
25 | const (
26 | EmbeddingModelJina2BaseEN EmbeddingModelJina = "jina-embeddings-v2-base-en"
27 | EmbeddingModelJina2BaseES EmbeddingModelJina = "jina-embeddings-v2-base-es"
28 | EmbeddingModelJina2BaseDE EmbeddingModelJina = "jina-embeddings-v2-base-de"
29 | EmbeddingModelJina2BaseZH EmbeddingModelJina = "jina-embeddings-v2-base-zh"
30 |
31 | EmbeddingModelJina2BaseCode EmbeddingModelJina = "jina-embeddings-v2-base-code"
32 |
33 | EmbeddingModelJinaClipV1 EmbeddingModelJina = "jina-clip-v1"
34 | )
35 |
36 | // NewEmbeddingFuncJina returns a function that creates embeddings for a text
37 | // using the Jina API.
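//
// A minimal usage sketch (apiKey is a placeholder):
//
//	embeddingFunc := chromem.NewEmbeddingFuncJina(apiKey, chromem.EmbeddingModelJina2BaseEN)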
38 | func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
39 | return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
40 | }
41 |
42 | const baseURLMixedbread = "https://api.mixedbread.ai"
43 |
44 | type EmbeddingModelMixedbread string
45 |
46 | const (
47 | // Possibly outdated / not available anymore
48 | EmbeddingModelMixedbreadUAELargeV1 EmbeddingModelMixedbread = "UAE-Large-V1"
49 | // Possibly outdated / not available anymore
50 | EmbeddingModelMixedbreadBGELargeENV15 EmbeddingModelMixedbread = "bge-large-en-v1.5"
51 | // Possibly outdated / not available anymore
52 | EmbeddingModelMixedbreadGTELarge EmbeddingModelMixedbread = "gte-large"
53 | // Possibly outdated / not available anymore
54 | EmbeddingModelMixedbreadE5LargeV2 EmbeddingModelMixedbread = "e5-large-v2"
55 | // Possibly outdated / not available anymore
56 | EmbeddingModelMixedbreadMultilingualE5Large EmbeddingModelMixedbread = "multilingual-e5-large"
57 | // Possibly outdated / not available anymore
58 | EmbeddingModelMixedbreadMultilingualE5Base EmbeddingModelMixedbread = "multilingual-e5-base"
59 | // Possibly outdated / not available anymore
60 | EmbeddingModelMixedbreadAllMiniLML6V2 EmbeddingModelMixedbread = "all-MiniLM-L6-v2"
61 | // Possibly outdated / not available anymore
62 | EmbeddingModelMixedbreadGTELargeZh EmbeddingModelMixedbread = "gte-large-zh"
63 |
64 | EmbeddingModelMixedbreadLargeV1 EmbeddingModelMixedbread = "mxbai-embed-large-v1"
65 | EmbeddingModelMixedbreadDeepsetDELargeV1 EmbeddingModelMixedbread = "deepset-mxbai-embed-de-large-v1"
66 | EmbeddingModelMixedbread2DLargeV1 EmbeddingModelMixedbread = "mxbai-embed-2d-large-v1"
67 | )
68 |
69 | // NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
70 | // using the mixedbread.ai API.
71 | func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
72 | return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
73 | }
74 |
75 | const baseURLLocalAI = "http://localhost:8080/v1"
76 |
77 | // NewEmbeddingFuncLocalAI returns a function that creates embeddings for a text
78 | // using the LocalAI API.
79 | // You can start a LocalAI instance like this:
80 | //
81 | // docker run -it -p 127.0.0.1:8080:8080 localai/localai:v2.7.0-ffmpeg-core bert-cpp
82 | //
83 | // And then call this constructor with model "bert-cpp-minilm-v6".
84 | // But other embedding models are supported as well. See the LocalAI documentation
85 | // for details.
86 | func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
87 | return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
88 | }
89 |
90 | const (
91 | azureDefaultAPIVersion = "2024-02-01"
92 | )
93 |
94 | // NewEmbeddingFuncAzureOpenAI returns a function that creates embeddings for a text
95 | // using the Azure OpenAI API.
96 | // The `deploymentURL` is the URL of the deployed model, e.g. "https://YOUR_RESOURCE_NAME.openai.azure.com/openai/deployments/YOUR_DEPLOYMENT_NAME"
97 | // See https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#how-to-get-embeddings
98 | func NewEmbeddingFuncAzureOpenAI(apiKey string, deploymentURL string, apiVersion string, model string) EmbeddingFunc {
99 | if apiVersion == "" {
100 | apiVersion = azureDefaultAPIVersion
101 | }
102 | return newEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, map[string]string{"api-key": apiKey}, map[string]string{"api-version": apiVersion})
103 | }
104 |
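For illustration, a minimal usage sketch of the Azure OpenAI constructor above, wired into a collection. The resource/deployment URL, the `AZURE_OPENAI_API_KEY` environment variable, and the model name are placeholders, not values defined in this repository:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/philippgille/chromem-go"
)

func main() {
	ctx := context.Background()

	// Placeholder deployment URL, env var and model name; replace with your own Azure resource.
	embed := chromem.NewEmbeddingFuncAzureOpenAI(
		os.Getenv("AZURE_OPENAI_API_KEY"),
		"https://YOUR_RESOURCE_NAME.openai.azure.com/openai/deployments/YOUR_DEPLOYMENT_NAME",
		"", // empty string falls back to the default API version ("2024-02-01")
		"text-embedding-3-small",
	)

	db := chromem.NewDB()
	c, err := db.CreateCollection("knowledge-base", nil, embed)
	if err != nil {
		panic(err)
	}

	err = c.AddDocument(ctx, chromem.Document{
		ID:      "1",
		Content: "The sky is blue because of Rayleigh scattering.",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println("added 1 document to the collection")
}
```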
--------------------------------------------------------------------------------
/embed_ollama.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "sync"
12 | )
13 |
14 | const defaultBaseURLOllama = "http://localhost:11434/api"
15 |
16 | type ollamaResponse struct {
17 | Embedding []float32 `json:"embedding"`
18 | }
19 |
20 | // NewEmbeddingFuncOllama returns a function that creates embeddings for a text
21 | // using Ollama's embedding API. You can pass any model that Ollama supports and
22 | // that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text".
23 | // See https://ollama.com/library/nomic-embed-text
24 | // baseURLOllama is the base URL of the Ollama API. If it's empty,
25 | // "http://localhost:11434/api" is used.
26 | func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
27 | if baseURLOllama == "" {
28 | baseURLOllama = defaultBaseURLOllama
29 | }
30 |
31 | // We don't set a default timeout here, although it's usually a good idea.
32 | // In our case though, the library user can set the timeout on the context,
33 | // and it might have to be a long timeout, depending on the text length.
34 | client := &http.Client{}
35 |
36 | var checkedNormalized bool
37 | checkNormalized := sync.Once{}
38 |
39 | return func(ctx context.Context, text string) ([]float32, error) {
40 | // Prepare the request body.
41 | reqBody, err := json.Marshal(map[string]string{
42 | "model": model,
43 | "prompt": text,
44 | })
45 | if err != nil {
46 | return nil, fmt.Errorf("couldn't marshal request body: %w", err)
47 | }
48 |
49 | // Create the request. Creating it with context is important for a timeout
50 | // to be possible, because the client is configured without a timeout.
51 | req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embeddings", bytes.NewBuffer(reqBody))
52 | if err != nil {
53 | return nil, fmt.Errorf("couldn't create request: %w", err)
54 | }
55 | req.Header.Set("Content-Type", "application/json")
56 |
57 | // Send the request.
58 | resp, err := client.Do(req)
59 | if err != nil {
60 | return nil, fmt.Errorf("couldn't send request: %w", err)
61 | }
62 | defer resp.Body.Close()
63 |
64 | // Check the response status.
65 | if resp.StatusCode != http.StatusOK {
66 | return nil, errors.New("error response from the embedding API: " + resp.Status)
67 | }
68 |
69 | // Read and decode the response body.
70 | body, err := io.ReadAll(resp.Body)
71 | if err != nil {
72 | return nil, fmt.Errorf("couldn't read response body: %w", err)
73 | }
74 | var embeddingResponse ollamaResponse
75 | err = json.Unmarshal(body, &embeddingResponse)
76 | if err != nil {
77 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
78 | }
79 |
80 | // Check if the response contains embeddings.
81 | if len(embeddingResponse.Embedding) == 0 {
82 | return nil, errors.New("no embeddings found in the response")
83 | }
84 |
85 | v := embeddingResponse.Embedding
86 | checkNormalized.Do(func() {
87 | if isNormalized(v) {
88 | checkedNormalized = true
89 | } else {
90 | checkedNormalized = false
91 | }
92 | })
93 | if !checkedNormalized {
94 | v = normalizeVector(v)
95 | }
96 |
97 | return v, nil
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/embed_ollama_test.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "io"
8 | "net/http"
9 | "net/http/httptest"
10 | "net/url"
11 | "slices"
12 | "strings"
13 | "testing"
14 | )
15 |
16 | func TestNewEmbeddingFuncOllama(t *testing.T) {
17 | model := "model-small"
18 | baseURLSuffix := "/api"
19 | prompt := "hello world"
20 |
21 | wantBody, err := json.Marshal(map[string]string{
22 | "model": model,
23 | "prompt": prompt,
24 | })
25 | if err != nil {
26 | t.Fatal("unexpected error:", err)
27 | }
28 | wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
29 |
30 | // Mock server
31 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32 | // Check URL
33 | if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") {
34 | t.Fatal("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path)
35 | }
36 | // Check method
37 | if r.Method != "POST" {
38 | t.Fatal("expected method POST, got", r.Method)
39 | }
40 | // Check headers
41 | if r.Header.Get("Content-Type") != "application/json" {
42 | t.Fatal("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type"))
43 | }
44 | // Check body
45 | body, err := io.ReadAll(r.Body)
46 | if err != nil {
47 | t.Fatal("unexpected error:", err)
48 | }
49 | if !bytes.Equal(body, wantBody) {
50 | t.Fatal("expected body", wantBody, "got", body)
51 | }
52 |
53 | // Write response
54 | resp := ollamaResponse{
55 | Embedding: wantRes,
56 | }
57 | w.WriteHeader(http.StatusOK)
58 | _ = json.NewEncoder(w).Encode(resp)
59 | }))
60 | defer ts.Close()
61 |
62 | // Get port from URL
63 | u, err := url.Parse(ts.URL)
64 | if err != nil {
65 | t.Fatal("unexpected error:", err)
66 | }
67 |
68 | f := NewEmbeddingFuncOllama(model, strings.Replace(defaultBaseURLOllama, "11434", u.Port(), 1))
69 | res, err := f(context.Background(), prompt)
70 | if err != nil {
71 | t.Fatal("expected nil, got", err)
72 | }
73 | if slices.Compare(wantRes, res) != 0 {
74 | t.Fatal("expected res", wantRes, "got", res)
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
/embed_openai.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "os"
12 | "sync"
13 | )
14 |
15 | const BaseURLOpenAI = "https://api.openai.com/v1"
16 |
17 | type EmbeddingModelOpenAI string
18 |
19 | const (
20 | EmbeddingModelOpenAI2Ada EmbeddingModelOpenAI = "text-embedding-ada-002"
21 |
22 | EmbeddingModelOpenAI3Small EmbeddingModelOpenAI = "text-embedding-3-small"
23 | EmbeddingModelOpenAI3Large EmbeddingModelOpenAI = "text-embedding-3-large"
24 | )
25 |
26 | type openAIResponse struct {
27 | Data []struct {
28 | Embedding []float32 `json:"embedding"`
29 | } `json:"data"`
30 | }
31 |
32 | // NewEmbeddingFuncDefault returns a function that creates embeddings for a text
33 | // using OpenAI's "text-embedding-3-small" model via their API.
34 | // The model supports a maximum text length of 8191 tokens.
35 | // The API key is read from the environment variable "OPENAI_API_KEY".
36 | func NewEmbeddingFuncDefault() EmbeddingFunc {
37 | apiKey := os.Getenv("OPENAI_API_KEY")
38 | return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small)
39 | }
40 |
41 | // NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text
42 | // using the OpenAI API.
43 | func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
44 | // OpenAI embeddings are normalized
45 | normalized := true
46 | return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
47 | }
48 |
49 | // NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
50 | // using an OpenAI compatible API. For example:
51 | // - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service
52 | // - LiteLLM: https://github.com/BerriAI/litellm
53 | // - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
54 | // - etc.
55 | //
56 | // The `normalized` parameter indicates whether the vectors returned by the embedding
57 | // model are already normalized, as is the case for OpenAI's and Mistral's models.
58 | // The flag is optional. If it's nil, it will be autodetected on the first request
59 | // (which bears a small risk that the vector just happens to have a length of 1).
60 | func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
61 | return newEmbeddingFuncOpenAICompat(baseURL, apiKey, model, normalized, nil, nil)
62 | }
63 |
64 | // newEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
65 | // using an OpenAI compatible API.
66 | // It offers options to set request headers and query parameters
67 | // e.g. to pass the `api-key` header and the `api-version` query parameter for Azure OpenAI.
68 | //
69 | // The `normalized` parameter indicates whether the vectors returned by the embedding
70 | // model are already normalized, as is the case for OpenAI's and Mistral's models.
71 | // The flag is optional. If it's nil, it will be autodetected on the first request
72 | // (which bears a small risk that the vector just happens to have a length of 1).
73 | func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, headers map[string]string, queryParams map[string]string) EmbeddingFunc {
74 | // We don't set a default timeout here, although it's usually a good idea.
75 | // In our case though, the library user can set the timeout on the context,
76 | // and it might have to be a long timeout, depending on the text length.
77 | client := &http.Client{}
78 |
79 | var checkedNormalized bool
80 | checkNormalized := sync.Once{}
81 |
82 | return func(ctx context.Context, text string) ([]float32, error) {
83 | // Prepare the request body.
84 | reqBody, err := json.Marshal(map[string]string{
85 | "input": text,
86 | "model": model,
87 | })
88 | if err != nil {
89 | return nil, fmt.Errorf("couldn't marshal request body: %w", err)
90 | }
91 |
92 | // Create the request. Creating it with context is important for a timeout
93 | // to be possible, because the client is configured without a timeout.
94 | req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewBuffer(reqBody))
95 | if err != nil {
96 | return nil, fmt.Errorf("couldn't create request: %w", err)
97 | }
98 | req.Header.Set("Content-Type", "application/json")
99 | req.Header.Set("Authorization", "Bearer "+apiKey)
100 |
101 | // Add headers
102 | for k, v := range headers {
103 | req.Header.Add(k, v)
104 | }
105 |
106 | // Add query parameters
107 | q := req.URL.Query()
108 | for k, v := range queryParams {
109 | q.Add(k, v)
110 | }
111 | req.URL.RawQuery = q.Encode()
112 |
113 | // Send the request.
114 | resp, err := client.Do(req)
115 | if err != nil {
116 | return nil, fmt.Errorf("couldn't send request: %w", err)
117 | }
118 | defer resp.Body.Close()
119 |
120 | // Check the response status.
121 | if resp.StatusCode != http.StatusOK {
122 | return nil, errors.New("error response from the embedding API: " + resp.Status)
123 | }
124 |
125 | // Read and decode the response body.
126 | body, err := io.ReadAll(resp.Body)
127 | if err != nil {
128 | return nil, fmt.Errorf("couldn't read response body: %w", err)
129 | }
130 | var embeddingResponse openAIResponse
131 | err = json.Unmarshal(body, &embeddingResponse)
132 | if err != nil {
133 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
134 | }
135 |
136 | // Check if the response contains embeddings.
137 | if len(embeddingResponse.Data) == 0 || len(embeddingResponse.Data[0].Embedding) == 0 {
138 | return nil, errors.New("no embeddings found in the response")
139 | }
140 |
141 | v := embeddingResponse.Data[0].Embedding
142 | if normalized != nil {
143 | if *normalized {
144 | return v, nil
145 | }
146 | return normalizeVector(v), nil
147 | }
148 | checkNormalized.Do(func() {
149 | if isNormalized(v) {
150 | checkedNormalized = true
151 | } else {
152 | checkedNormalized = false
153 | }
154 | })
155 | if !checkedNormalized {
156 | v = normalizeVector(v)
157 | }
158 |
159 | return v, nil
160 | }
161 | }
162 |
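A sketch of pointing `NewEmbeddingFuncOpenAICompat` at a non-OpenAI endpoint. The base URL, the `MY_PROXY_API_KEY` environment variable, and the model name are placeholders; passing a non-nil `normalized` skips the autodetection described above:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/philippgille/chromem-go"
)

func main() {
	ctx := context.Background()

	// Placeholder endpoint (e.g. a self-hosted OpenAI-compatible proxy) and model.
	normalized := true // only set this if you know the model returns unit-length vectors
	embed := chromem.NewEmbeddingFuncOpenAICompat(
		"http://localhost:4000/v1",
		os.Getenv("MY_PROXY_API_KEY"),
		"text-embedding-3-small",
		&normalized, // or nil to autodetect on the first request
	)

	v, err := embed(ctx, "hello world")
	if err != nil {
		panic(err)
	}
	fmt.Println("embedding dimensions:", len(v))
}
```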
--------------------------------------------------------------------------------
/embed_openai_test.go:
--------------------------------------------------------------------------------
1 | package chromem_test
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "io"
8 | "net/http"
9 | "net/http/httptest"
10 | "slices"
11 | "strings"
12 | "testing"
13 |
14 | "github.com/philippgille/chromem-go"
15 | )
16 |
17 | type openAIResponse struct {
18 | Data []struct {
19 | Embedding []float32 `json:"embedding"`
20 | } `json:"data"`
21 | }
22 |
23 | func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
24 | apiKey := "secret"
25 | model := "model-small"
26 | baseURLSuffix := "/v1"
27 | input := "hello world"
28 |
29 | wantBody, err := json.Marshal(map[string]string{
30 | "input": input,
31 | "model": model,
32 | })
33 | if err != nil {
34 | t.Fatal("unexpected error:", err)
35 | }
36 | wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
37 |
38 | // Mock server
39 | ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
40 | // Check URL
41 | if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") {
42 | t.Fatal("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path)
43 | }
44 | // Check method
45 | if r.Method != "POST" {
46 | t.Fatal("expected method POST, got", r.Method)
47 | }
48 | // Check headers
49 | if r.Header.Get("Authorization") != "Bearer "+apiKey {
50 | t.Fatal("expected Authorization header", "Bearer "+apiKey, "got", r.Header.Get("Authorization"))
51 | }
52 | if r.Header.Get("Content-Type") != "application/json" {
53 | t.Fatal("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type"))
54 | }
55 | // Check body
56 | body, err := io.ReadAll(r.Body)
57 | if err != nil {
58 | t.Fatal("unexpected error:", err)
59 | }
60 | if !bytes.Equal(body, wantBody) {
61 | t.Fatal("expected body", wantBody, "got", body)
62 | }
63 |
64 | // Write response
65 | resp := openAIResponse{
66 | Data: []struct {
67 | Embedding []float32 `json:"embedding"`
68 | }{
69 | {Embedding: wantRes},
70 | },
71 | }
72 | w.WriteHeader(http.StatusOK)
73 | _ = json.NewEncoder(w).Encode(resp)
74 | }))
75 | defer ts.Close()
76 | baseURL := ts.URL + baseURLSuffix
77 |
78 | f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil)
79 | res, err := f(context.Background(), input)
80 | if err != nil {
81 | t.Fatal("expected nil, got", err)
82 | }
83 | if slices.Compare(wantRes, res) != 0 {
84 | t.Fatal("expected res", wantRes, "got", res)
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/embed_vertex.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 | "net/http"
11 | "sync"
12 | )
13 |
14 | type EmbeddingModelVertex string
15 |
16 | const (
17 | EmbeddingModelVertexEnglishV1 EmbeddingModelVertex = "textembedding-gecko@001"
18 | EmbeddingModelVertexEnglishV2 EmbeddingModelVertex = "textembedding-gecko@002"
19 | EmbeddingModelVertexEnglishV3 EmbeddingModelVertex = "textembedding-gecko@003"
20 | EmbeddingModelVertexEnglishV4 EmbeddingModelVertex = "text-embedding-004"
21 |
22 | EmbeddingModelVertexMultilingualV1 EmbeddingModelVertex = "textembedding-gecko-multilingual@001"
23 | EmbeddingModelVertexMultilingualV2 EmbeddingModelVertex = "text-multilingual-embedding-002"
24 | )
25 |
26 | const baseURLVertex = "https://us-central1-aiplatform.googleapis.com/v1"
27 |
28 | type vertexOptions struct {
29 | apiEndpoint string
30 | autoTruncate bool
31 | }
32 |
33 | func defaultVertexOptions() *vertexOptions {
34 | return &vertexOptions{
35 | apiEndpoint: baseURLVertex,
36 | autoTruncate: false,
37 | }
38 | }
39 |
40 | type VertexOption func(*vertexOptions)
41 |
42 | func WithVertexAPIEndpoint(apiEndpoint string) VertexOption {
43 | return func(o *vertexOptions) {
44 | o.apiEndpoint = apiEndpoint
45 | }
46 | }
47 |
48 | func WithVertexAutoTruncate(autoTruncate bool) VertexOption {
49 | return func(o *vertexOptions) {
50 | o.autoTruncate = autoTruncate
51 | }
52 | }
53 |
54 | type vertexResponse struct {
55 | Predictions []vertexPrediction `json:"predictions"`
56 | }
57 |
58 | type vertexPrediction struct {
59 | Embeddings vertexEmbeddings `json:"embeddings"`
60 | }
61 |
62 | type vertexEmbeddings struct {
63 | Values []float32 `json:"values"`
64 | // there's more here, but we only care about the embeddings
65 | }
66 |
67 | func NewEmbeddingFuncVertex(apiKey, project string, model EmbeddingModelVertex, opts ...VertexOption) EmbeddingFunc {
68 | cfg := defaultVertexOptions()
69 | for _, opt := range opts {
70 | opt(cfg)
71 | }
72 |
73 | if cfg.apiEndpoint == "" {
74 | cfg.apiEndpoint = baseURLVertex
75 | }
76 |
77 | // We don't set a default timeout here, although it's usually a good idea.
78 | // In our case though, the library user can set the timeout on the context,
79 | // and it might have to be a long timeout, depending on the text length.
80 | client := &http.Client{}
81 |
82 | var checkedNormalized bool
83 | checkNormalized := sync.Once{}
84 |
85 | return func(ctx context.Context, text string) ([]float32, error) {
86 | b := map[string]any{
87 | "instances": []map[string]any{
88 | {
89 | "content": text,
90 | },
91 | },
92 | "parameters": map[string]any{
93 | "autoTruncate": cfg.autoTruncate,
94 | },
95 | }
96 |
97 | // Prepare the request body.
98 | reqBody, err := json.Marshal(b)
99 | if err != nil {
100 | return nil, fmt.Errorf("couldn't marshal request body: %w", err)
101 | }
102 |
103 | fullURL := fmt.Sprintf("%s/projects/%s/locations/us-central1/publishers/google/models/%s:predict", cfg.apiEndpoint, project, model)
104 |
105 | // Create the request. Creating it with context is important for a timeout
106 | // to be possible, because the client is configured without a timeout.
107 | req, err := http.NewRequestWithContext(ctx, "POST", fullURL, bytes.NewBuffer(reqBody))
108 | if err != nil {
109 | return nil, fmt.Errorf("couldn't create request: %w", err)
110 | }
111 | req.Header.Set("Accept", "application/json")
112 | req.Header.Set("Content-Type", "application/json")
113 | req.Header.Set("Authorization", "Bearer "+apiKey)
114 |
115 | // Send the request.
116 | resp, err := client.Do(req)
117 | if err != nil {
118 | return nil, fmt.Errorf("couldn't send request: %w", err)
119 | }
120 | defer resp.Body.Close()
121 |
122 | // Check the response status.
123 | if resp.StatusCode != http.StatusOK {
124 | return nil, errors.New("error response from the embedding API: " + resp.Status)
125 | }
126 |
127 | // Read and decode the response body.
128 | body, err := io.ReadAll(resp.Body)
129 | if err != nil {
130 | return nil, fmt.Errorf("couldn't read response body: %w", err)
131 | }
132 | var embeddingResponse vertexResponse
133 | err = json.Unmarshal(body, &embeddingResponse)
134 | if err != nil {
135 | return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
136 | }
137 |
138 | // Check if the response contains embeddings.
139 | if len(embeddingResponse.Predictions) == 0 || len(embeddingResponse.Predictions[0].Embeddings.Values) == 0 {
140 | return nil, errors.New("no embeddings found in the response")
141 | }
142 |
143 | v := embeddingResponse.Predictions[0].Embeddings.Values
144 | checkNormalized.Do(func() {
145 | if isNormalized(v) {
146 | checkedNormalized = true
147 | } else {
148 | checkedNormalized = false
149 | }
150 | })
151 | if !checkedNormalized {
152 | v = normalizeVector(v)
153 | }
154 |
155 | return v, nil
156 | }
157 | }
158 |
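A usage sketch for the Vertex constructor and options above. The project ID and the `GOOGLE_ACCESS_TOKEN` environment variable are placeholders; a real setup needs a valid OAuth2 access token (e.g. from `gcloud auth print-access-token`):

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/philippgille/chromem-go"
)

func main() {
	ctx := context.Background()

	// Placeholder project ID and token source.
	embed := chromem.NewEmbeddingFuncVertex(
		os.Getenv("GOOGLE_ACCESS_TOKEN"),
		"my-gcp-project",
		chromem.EmbeddingModelVertexEnglishV4,
		chromem.WithVertexAutoTruncate(true),
	)

	v, err := embed(ctx, "hello world")
	if err != nil {
		panic(err)
	}
	fmt.Println("embedding dimensions:", len(v))
}
```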
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | 1. [Minimal example](minimal)
4 | - A minimal example with the least amount of code and no comments
5 | - Uses OpenAI for creating the embeddings
6 | 2. [RAG Wikipedia Ollama](rag-wikipedia-ollama)
7 | - This example shows a retrieval augmented generation (RAG) application, using `chromem-go` as knowledge base for finding relevant info for a question. More specifically the app is doing *question answering*.
8 | - The underlying data is 200 Wikipedia articles (or rather their lead section / introduction).
9 | - Runs the embeddings model and LLM in [Ollama](https://github.com/ollama/ollama), to showcase how a RAG application can run entirely offline, without relying on OpenAI or other third party APIs.
10 | 3. [Semantic search arXiv OpenAI](semantic-search-arxiv-openai)
11 | - This example shows a semantic search application, using `chromem-go` as vector database for finding semantically relevant search results.
12 | - Loads and searches across ~5,000 arXiv papers in the "Computer Science - Computation and Language" category, which is the relevant one for Natural Language Processing (NLP) related papers.
13 | - Uses OpenAI for creating the embeddings
14 | 4. [WebAssembly](webassembly)
15 | - This example shows how `chromem-go` can be compiled to WebAssembly and then used from JavaScript in a browser
16 | 5. [S3 Export/Import](s3-export-import)
17 | - This example shows how to export the DB to and import it from any S3-compatible blob storage service
18 |
--------------------------------------------------------------------------------
/examples/minimal/README.md:
--------------------------------------------------------------------------------
1 | # Minimal example
2 |
3 | This is a minimal example that shows how `chromem-go` works. For more sophisticated examples that use the persistent DB, locally running embedding models, and a basic RAG pipeline, with more explanations in the READMEs as well as code comments, check out the other examples!
4 |
5 | ## How to run
6 |
7 | 1. Set the OpenAI API key in your env as `OPENAI_API_KEY`
8 | 2. Run the example: `go run .`
9 |
10 | ## Output
11 |
12 | ```text
13 | ID: 1
14 | Similarity: 0.6833369
15 | Content: The sky is blue because of Rayleigh scattering.
16 | ```
17 |
--------------------------------------------------------------------------------
/examples/minimal/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/philippgille/chromem-go/examples/minimal
2 |
3 | go 1.21
4 |
5 | require github.com/philippgille/chromem-go v0.0.0
6 |
7 | replace github.com/philippgille/chromem-go => ./../..
8 |
--------------------------------------------------------------------------------
/examples/minimal/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "runtime"
7 |
8 | "github.com/philippgille/chromem-go"
9 | )
10 |
11 | func main() {
12 | ctx := context.Background()
13 |
14 | db := chromem.NewDB()
15 |
16 | // Passing nil as embedding function leads to OpenAI being used and requires
17 | // "OPENAI_API_KEY" env var to be set. Other providers are supported as well.
18 | // For example pass `chromem.NewEmbeddingFuncOllama(...)` to use Ollama.
19 | c, err := db.CreateCollection("knowledge-base", nil, nil)
20 | if err != nil {
21 | panic(err)
22 | }
23 |
24 | err = c.AddDocuments(ctx, []chromem.Document{
25 | {
26 | ID: "1",
27 | Content: "The sky is blue because of Rayleigh scattering.",
28 | },
29 | {
30 | ID: "2",
31 | Content: "Leaves are green because chlorophyll absorbs red and blue light.",
32 | },
33 | }, runtime.NumCPU())
34 | if err != nil {
35 | panic(err)
36 | }
37 |
38 | res, err := c.Query(ctx, "Why is the sky blue?", 1, nil, nil)
39 | if err != nil {
40 | panic(err)
41 | }
42 |
43 | fmt.Printf("ID: %v\nSimilarity: %v\nContent: %v\n", res[0].ID, res[0].Similarity, res[0].Content)
44 |
45 | /* Output:
46 | ID: 1
47 | Similarity: 0.6833369
48 | Content: The sky is blue because of Rayleigh scattering.
49 | */
50 | }
51 |
--------------------------------------------------------------------------------
/examples/rag-wikipedia-ollama/.gitignore:
--------------------------------------------------------------------------------
1 | /db
2 |
--------------------------------------------------------------------------------
/examples/rag-wikipedia-ollama/README.md:
--------------------------------------------------------------------------------
1 | # RAG Wikipedia Ollama
2 |
3 | This example shows a retrieval augmented generation (RAG) application, using `chromem-go` as the knowledge base for finding relevant info for a question. More specifically, the app does *question answering*. The underlying data is 200 Wikipedia articles (or rather their lead section / introduction).
4 |
5 | We run the embeddings model and LLM in [Ollama](https://github.com/ollama/ollama), to showcase how a RAG application can run entirely offline, without relying on OpenAI or other third party APIs. It doesn't require a GPU, and a CPU like an 11th Gen Intel i5-1135G7 (like in the first generation Framework Laptop 13) is fast enough.
6 |
7 | As LLM we use Google's [Gemma (2B)](https://huggingface.co/google/gemma-2b), a very small model that doesn't need many resources and is fast, but doesn't have much knowledge, so it's a prime example for the combination of LLMs and vector databases. We found Gemma 2B to be superior to [TinyLlama (1.1B)](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0), [Stable LM 2 (1.6B)](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) and [Phi-2 (2.7B)](https://huggingface.co/microsoft/phi-2) for the RAG use case.
8 |
9 | As embeddings model we use Nomic's [nomic-embed-text v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5).
10 |
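For orientation, this is roughly how the example wires `chromem-go` to Ollama (condensed from `main.go` below; imports, error handling and the `ctx`/`question`/`docs` setup are omitted):

```go
// Persistent DB, so the embeddings survive restarts.
db, _ := chromem.NewPersistentDB("./db", false)

// Collection backed by the locally running "nomic-embed-text" model in Ollama.
collection, _ := db.GetOrCreateCollection("Wikipedia", nil,
	chromem.NewEmbeddingFuncOllama("nomic-embed-text", ""))

// Add the Wikipedia articles as documents, then query semantically.
_ = collection.AddDocuments(ctx, docs, runtime.NumCPU())
docRes, _ := collection.Query(ctx, "search_query: "+question, 2, nil, nil)
```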
11 | ## How to run
12 |
13 | 1. Install Ollama: <https://github.com/ollama/ollama>
14 | 2. Download the two models:
15 | - `ollama pull gemma:2b`
16 | - `ollama pull nomic-embed-text`
17 | 3. Run the example: `go run .`
18 |
19 | ## Output
20 |
21 | The output can differ slightly on each run, but it's along the lines of:
22 |
23 | ```log
24 | 2024/03/02 20:02:30 Warming up Ollama...
25 | 2024/03/02 20:02:33 Question: When did the Monarch Company exist?
26 | 2024/03/02 20:02:33 Asking LLM...
27 | 2024/03/02 20:02:34 Initial reply from the LLM: "I cannot provide information on the Monarch Company, as I am unable to access real-time or comprehensive knowledge sources."
28 | 2024/03/02 20:02:34 Setting up chromem-go...
29 | 2024/03/02 20:02:34 Reading JSON lines...
30 | 2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API...
31 | 2024/03/02 20:03:11 Querying chromem-go...
32 | 2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms
33 | 2024/03/02 20:03:11 Document 1 (similarity: 0.655357): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch."
34 | 2024/03/02 20:03:11 Document 2 (similarity: 0.504042): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design."
35 | 2024/03/02 20:03:11 Asking LLM with augmented question...
36 | 2024/03/02 20:03:32 Reply after augmenting the question with knowledge: "The Monarch Company existed from 1896 to 1985."
37 | ```
38 |
39 | The majority of the time here is spent during the embeddings creation, where we are limited by the performance of the Ollama API, which depends on your CPU/GPU and the embeddings model.
40 |
41 | ## OpenAI
42 |
43 | You can easily adapt the code to work with OpenAI instead of locally in Ollama.
44 |
45 | Add the OpenAI API key in your environment as `OPENAI_API_KEY`.
46 |
47 | Then, if you want to create the embeddings via OpenAI, but still use Gemma 2B as LLM:
48 |
49 | Apply this patch:
50 |
51 | ```diff
52 | diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go
53 | index 55b3076..cee9561 100644
54 | --- a/examples/rag-wikipedia-ollama/main.go
55 | +++ b/examples/rag-wikipedia-ollama/main.go
56 | @@ -14,8 +14,6 @@ import (
57 |
58 | const (
59 | question = "When did the Monarch Company exist?"
60 | - // We use a local LLM running in Ollama for the embedding: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5
61 | - embeddingModel = "nomic-embed-text"
62 | )
63 |
64 | func main() {
65 | @@ -48,7 +46,7 @@ func main() {
66 | // variable to be set.
67 | // For this example we choose to use a locally running embedding model though.
68 | // It requires Ollama to serve its API at "http://localhost:11434/api".
69 | - collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel, ""))
70 | + collection, err := db.GetOrCreateCollection("Wikipedia", nil, nil)
71 | if err != nil {
72 | panic(err)
73 | }
74 | @@ -82,7 +80,7 @@ func main() {
75 | Content: article.Text,
76 | })
77 | }
78 | - log.Println("Adding documents to chromem-go, including creating their embeddings via Ollama API...")
79 | + log.Println("Adding documents to chromem-go, including creating their embeddings via OpenAI API...")
80 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU())
81 | if err != nil {
82 | panic(err)
83 | ```
84 |
85 |
86 |
87 | Or alternatively, if you want to use OpenAI for everything (embeddings creation and LLM):
88 |
89 | Apply this patch:
90 |
91 | ```diff
92 | diff --git a/examples/rag-wikipedia-ollama/llm.go b/examples/rag-wikipedia-ollama/llm.go
93 | index 1fde4ec..7cb81cc 100644
94 | --- a/examples/rag-wikipedia-ollama/llm.go
95 | +++ b/examples/rag-wikipedia-ollama/llm.go
96 | @@ -2,23 +2,13 @@ package main
97 |
98 | import (
99 | "context"
100 | - "net/http"
101 | + "os"
102 | "strings"
103 | "text/template"
104 |
105 | "github.com/sashabaranov/go-openai"
106 | )
107 |
108 | -const (
109 | - // We use a local LLM running in Ollama for asking the question: https://github.com/ollama/ollama
110 | - ollamaBaseURL = "http://localhost:11434/v1"
113 | - // We use Google's Gemma (2B), a very small model that doesn't need many resources
112 | - // and is fast, but doesn't have much knowledge: https://huggingface.co/google/gemma-2b
113 | - // We found Gemma 2B to be superior to TinyLlama (1.1B), Stable LM 2 (1.6B)
114 | - // and Phi-2 (2.7B) for the retrieval augmented generation (RAG) use case.
115 | - llmModel = "gemma:2b"
116 | -)
117 | -
118 | // There are many different ways to provide the context to the LLM.
119 | // You can pass each context as user message, or the list as one user message,
120 | // or pass it in the system prompt. The system prompt itself also has a big impact
121 | @@ -47,10 +37,7 @@ Don't mention the knowledge base, context or search results in your answer.
122 |
123 | func askLLM(ctx context.Context, contexts []string, question string) string {
124 | // We can use the OpenAI client because Ollama is compatible with OpenAI's API.
125 | - openAIClient := openai.NewClientWithConfig(openai.ClientConfig{
126 | - BaseURL: ollamaBaseURL,
127 | - HTTPClient: http.DefaultClient,
128 | - })
129 | + openAIClient := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
130 | sb := &strings.Builder{}
131 | err := systemPromptTpl.Execute(sb, contexts)
132 | if err != nil {
133 | @@ -66,7 +53,7 @@ func askLLM(ctx context.Context, contexts []string, question string) string {
134 | },
135 | }
136 | res, err := openAIClient.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
137 | - Model: llmModel,
138 | + Model: openai.GPT3Dot5Turbo,
139 | Messages: messages,
140 | })
141 | if err != nil {
142 | diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go
143 | index 55b3076..044a246 100644
144 | --- a/examples/rag-wikipedia-ollama/main.go
145 | +++ b/examples/rag-wikipedia-ollama/main.go
146 | @@ -12,19 +12,11 @@ import (
147 | "github.com/philippgille/chromem-go"
148 | )
149 |
150 | -const (
151 | - question = "When did the Monarch Company exist?"
152 | - // We use a local LLM running in Ollama for the embedding: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5
153 | - embeddingModel = "nomic-embed-text"
154 | -)
155 | +const question = "When did the Monarch Company exist?"
156 |
157 | func main() {
158 | ctx := context.Background()
159 |
160 | - // Warm up Ollama, in case the model isn't loaded yet
161 | - log.Println("Warming up Ollama...")
162 | - _ = askLLM(ctx, nil, "Hello!")
163 | -
164 | // First we ask an LLM a fairly specific question that it likely won't know
165 | // the answer to.
166 | log.Println("Question: " + question)
167 | @@ -48,7 +40,7 @@ func main() {
168 | // variable to be set.
169 | // For this example we choose to use a locally running embedding model though.
170 | // It requires Ollama to serve its API at "http://localhost:11434/api".
171 | - collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel, ""))
172 | + collection, err := db.GetOrCreateCollection("Wikipedia", nil, nil)
173 | if err != nil {
174 | panic(err)
175 | }
176 | @@ -82,7 +74,7 @@ func main() {
177 | Content: article.Text,
178 | })
179 | }
180 | - log.Println("Adding documents to chromem-go, including creating their embeddings via Ollama API...")
181 | + log.Println("Adding documents to chromem-go, including creating their embeddings via OpenAI API...")
182 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU())
183 | if err != nil {
184 | panic(err)
185 | ```
186 |
187 |
188 |
--------------------------------------------------------------------------------
/examples/rag-wikipedia-ollama/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/philippgille/chromem-go/examples/rag-wikipedia-ollama
2 |
3 | go 1.21
4 |
5 | require (
6 | github.com/philippgille/chromem-go v0.0.0
7 | github.com/sashabaranov/go-openai v1.17.9
8 | )
9 |
10 | replace github.com/philippgille/chromem-go => ./../..
11 |
--------------------------------------------------------------------------------
/examples/rag-wikipedia-ollama/go.sum:
--------------------------------------------------------------------------------
1 | github.com/sashabaranov/go-openai v1.17.9 h1:QEoBiGKWW68W79YIfXWEFZ7l5cEgZBV4/Ow3uy+5hNY=
2 | github.com/sashabaranov/go-openai v1.17.9/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
3 |
--------------------------------------------------------------------------------
/examples/rag-wikipedia-ollama/llm.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "net/http"
6 | "strings"
7 | "text/template"
8 |
9 | "github.com/sashabaranov/go-openai"
10 | )
11 |
12 | const (
13 | // We use a local LLM running in Ollama to ask a question: https://github.com/ollama/ollama
14 | ollamaBaseURL = "http://localhost:11434/v1"
15 | // We use Google's Gemma (2B), a very small model that doesn't need many resources
16 | // and is fast, but doesn't have much knowledge: https://huggingface.co/google/gemma-2b
17 | // We found Gemma 2B to be superior to TinyLlama (1.1B), Stable LM 2 (1.6B)
18 | // and Phi-2 (2.7B) for the retrieval augmented generation (RAG) use case.
19 | llmModel = "gemma:2b"
20 | )
21 |
22 | // There are many different ways to provide the context to the LLM.
23 | // You can pass each context as user message, or the list as one user message,
24 | // or pass it in the system prompt. The system prompt itself also has a big impact
25 | // on how well the LLM handles the context, especially for LLMs with < 7B parameters.
26 | // The prompt engineering is up to you, it's out of scope for the vector database.
27 | var systemPromptTpl = template.Must(template.New("system_prompt").Parse(`
28 | You are a helpful assistant with access to a knowledge base, tasked with answering questions about the world and its history, people, places and other things.
29 |
30 | Answer the question in a very concise manner. Use an unbiased and journalistic tone. Do not repeat text. Don't make anything up. If you are not sure about something, just say that you don't know.
31 | {{- /* Stop here if no context is provided. The rest below is for handling contexts. */ -}}
32 | {{- if . -}}
33 | Answer the question solely based on the provided search results from the knowledge base. If the search results from the knowledge base are not relevant to the question at hand, just say that you don't know. Don't make anything up.
34 |
35 | Anything between the following 'context' XML blocks is retrieved from the knowledge base, not part of the conversation with the user. The bullet points are ordered by relevance, so the first one is the most relevant.
36 |
37 |
38 | {{- if . -}}
39 | {{- range $context := .}}
40 | - {{.}}{{end}}
41 | {{- end}}
42 |
43 | {{- end -}}
44 |
45 | Don't mention the knowledge base, context or search results in your answer.
46 | `))
47 |
48 | func askLLM(ctx context.Context, contexts []string, question string) string {
49 | // We can use the OpenAI client because Ollama is compatible with OpenAI's API.
50 | openAIClient := openai.NewClientWithConfig(openai.ClientConfig{
51 | BaseURL: ollamaBaseURL,
52 | HTTPClient: http.DefaultClient,
53 | })
54 | sb := &strings.Builder{}
55 | err := systemPromptTpl.Execute(sb, contexts)
56 | if err != nil {
57 | panic(err)
58 | }
59 | messages := []openai.ChatCompletionMessage{
60 | {
61 | Role: openai.ChatMessageRoleSystem,
62 | Content: sb.String(),
63 | }, {
64 | Role: openai.ChatMessageRoleUser,
65 | Content: "Question: " + question,
66 | },
67 | }
68 | res, err := openAIClient.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
69 | Model: llmModel,
70 | Messages: messages,
71 | })
72 | if err != nil {
73 | panic(err)
74 | }
75 | reply := res.Choices[0].Message.Content
76 | reply = strings.TrimSpace(reply)
77 |
78 | return reply
79 | }
80 |
--------------------------------------------------------------------------------
/examples/rag-wikipedia-ollama/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "io"
7 | "log"
8 | "os"
9 | "runtime"
10 | "strconv"
11 | "strings"
12 | "time"
13 |
14 | "github.com/philippgille/chromem-go"
15 | )
16 |
17 | const (
18 | question = "When did the Monarch Company exist?"
19 | // We use a local LLM running in Ollama for the embedding: https://huggingface.co/nomic-ai/nomic-embed-text-v1.5
20 | embeddingModel = "nomic-embed-text"
21 | )
22 |
23 | func main() {
24 | ctx := context.Background()
25 |
26 | // Warm up Ollama, in case the model isn't loaded yet
27 | log.Println("Warming up Ollama...")
28 | _ = askLLM(ctx, nil, "Hello!")
29 |
30 | // First we ask an LLM a fairly specific question that it likely won't know
31 | // the answer to.
32 | log.Println("Question: " + question)
33 | log.Println("Asking LLM...")
34 | reply := askLLM(ctx, nil, question)
35 | log.Printf("Initial reply from the LLM: \"" + reply + "\"\n")
36 |
37 | // Now we use our vector database for retrieval augmented generation (RAG),
38 | // which means we provide the LLM with relevant knowledge.
39 |
40 | // Set up chromem-go with persistence, so that when the program restarts, the
41 | // DB's data is still available.
42 | log.Println("Setting up chromem-go...")
43 | db, err := chromem.NewPersistentDB("./db", false)
44 | if err != nil {
45 | panic(err)
46 | }
47 | // Create collection if it wasn't loaded from persistent storage yet.
48 | // You can pass nil as embedding function to use the default (OpenAI text-embedding-3-small),
49 | // which is very good and cheap. It would require the OPENAI_API_KEY environment
50 | // variable to be set.
51 | // For this example we choose to use a locally running embedding model though.
52 | // It requires Ollama to serve its API at "http://localhost:11434/api".
53 | collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel, ""))
54 | if err != nil {
55 | panic(err)
56 | }
57 | // Add docs to the collection, if the collection was just created (and not
58 | // loaded from persistent storage).
59 | var docs []chromem.Document
60 | if collection.Count() == 0 {
61 | // Here we use a DBpedia sample, where each line contains the lead section/introduction
62 | // to some Wikipedia article and its category.
63 | f, err := os.Open("dbpedia_sample.jsonl")
64 | if err != nil {
65 | panic(err)
66 | }
67 | defer f.Close()
68 | d := json.NewDecoder(f)
69 | log.Println("Reading JSON lines...")
70 | for i := 1; ; i++ {
71 | var article struct {
72 | Text string `json:"text"`
73 | Category string `json:"category"`
74 | }
75 | err := d.Decode(&article)
76 | if err == io.EOF {
77 | break // reached end of file
78 | } else if err != nil {
79 | panic(err)
80 | }
81 |
82 | // The embeddings model we use in this example ("nomic-embed-text")
83 | // fares better with a prefix to differentiate between document and query.
84 | // We'll have to cut it off later when we retrieve the documents.
85 | // An alternative is to create the embedding with `chromem.NewDocument()`,
86 | // and then change back the content before adding it to the collection
87 | // with `collection.AddDocument()`.
88 | content := "search_document: " + article.Text
89 |
90 | docs = append(docs, chromem.Document{
91 | ID: strconv.Itoa(i),
92 | Metadata: map[string]string{"category": article.Category},
93 | Content: content,
94 | })
95 | }
96 | log.Println("Adding documents to chromem-go, including creating their embeddings via Ollama API...")
97 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU())
98 | if err != nil {
99 | panic(err)
100 | }
101 | } else {
102 | log.Println("Not reading JSON lines because collection was loaded from persistent storage.")
103 | }
104 |
105 | // Search for documents that are semantically similar to the original question.
106 | // We ask for the two most similar documents, but you can use more or less depending
107 | // on your needs and the supported context size of the LLM you use.
108 | // You can limit the search by filtering on content or metadata (like the article's
109 | // category), but we don't do that in this example.
110 | start := time.Now()
111 | log.Println("Querying chromem-go...")
112 | // "nomic-embed-text" specific prefix (not required with OpenAI's or other models)
113 | query := "search_query: " + question
114 | docRes, err := collection.Query(ctx, query, 2, nil, nil)
115 | if err != nil {
116 | panic(err)
117 | }
118 | log.Println("Search (incl query embedding) took", time.Since(start))
119 | // Here you could filter out any documents whose similarity is below a certain threshold.
120 | // if docRes[...].Similarity < 0.5 { ...
121 |
122 | // Print the retrieved documents and their similarity to the question.
123 | for i, res := range docRes {
124 | // Cut off the prefix we added before adding the document (see comment above).
125 | // This is specific to the "nomic-embed-text" model.
126 | content := strings.TrimPrefix(res.Content, "search_document: ")
127 | log.Printf("Document %d (similarity: %f): \"%s\"\n", i+1, res.Similarity, content)
128 | }
129 |
130 | // Now we can ask the LLM again, augmenting the question with the knowledge we retrieved.
131 | // In this example we just use both retrieved documents as context.
132 | contexts := []string{docRes[0].Content, docRes[1].Content}
133 | log.Println("Asking LLM with augmented question...")
134 | reply = askLLM(ctx, contexts, question)
135 | log.Printf("Reply after augmenting the question with knowledge: \"" + reply + "\"\n")
136 |
137 | /* Output (can differ slightly on each run):
138 | 2024/03/02 20:02:30 Warming up Ollama...
139 | 2024/03/02 20:02:33 Question: When did the Monarch Company exist?
140 | 2024/03/02 20:02:33 Asking LLM...
141 | 2024/03/02 20:02:34 Initial reply from the LLM: "I cannot provide information on the Monarch Company, as I am unable to access real-time or comprehensive knowledge sources."
142 | 2024/03/02 20:02:34 Setting up chromem-go...
143 | 2024/03/02 20:02:34 Reading JSON lines...
144 | 2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API...
145 | 2024/03/02 20:03:11 Querying chromem-go...
146 | 2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms
147 | 2024/03/02 20:03:11 Document 1 (similarity: 0.655357): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch."
148 | 2024/03/02 20:03:11 Document 2 (similarity: 0.504042): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design."
149 | 2024/03/02 20:03:11 Asking LLM with augmented question...
150 | 2024/03/02 20:03:32 Reply after augmenting the question with knowledge: "The Monarch Company existed from 1896 to 1985."
151 | */
152 | }
153 |
--------------------------------------------------------------------------------
/examples/s3-export-import/README.md:
--------------------------------------------------------------------------------
1 | # S3 Export/Import
2 |
3 | This example shows how to export the DB to and import it from any S3-compatible blob storage service.
4 |
5 | - The example uses [MinIO](https://github.com/minio/minio), but any S3-compatible storage works.
6 | - The example uses [gocloud.dev](https://github.com/google/go-cloud), the Go "Cloud Development Kit" from Google, for interfacing with any S3-compatible storage, and because it provides methods for creating writers and readers that make it easy to use with `chromem-go` (see the sketch below).
7 |
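A minimal sketch of the export/import round trip (mirroring `main.go` below; bucket setup and error handling omitted):

```go
// Export: write the whole DB, gzip-compressed, to an S3 object.
w, _ := bucket.NewWriter(ctx, "chromem.gob.gz", nil)
_ = db.ExportToWriter(w, true, "") // compress=true, no encryption key
_ = w.Close()                      // closing the writer finishes the actual upload

// Import: read the object back into an empty DB.
r, _ := bucket.NewReader(ctx, "chromem.gob.gz", nil)
defer r.Close()
db2 := chromem.NewDB()
_ = db2.ImportFromReader(r, "")
```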
8 | ## How to run
9 |
10 | 1. Prepare the S3-compatible storage
11 | 1. `docker run -d --rm --name minio -p 127.0.0.1:9000:9000 -p 127.0.0.1:9001:9001 quay.io/minio/minio:RELEASE.2024-05-01T01-11-10Z server /data --console-address ":9001"`
12 | 2. Open the MinIO Console in your browser: <http://localhost:9001>
13 | 3. Log in with user `minioadmin` and password `minioadmin`
14 | 4. Use the web UI to create a bucket named `mybucket`
15 | 2. Set the OpenAI API key in your env as `OPENAI_API_KEY`
16 | 3. `go run .`
17 |
18 | You can also check the bucket in the MinIO Console and see the exported DB as `chromem.gob.gz`.
19 |
20 | To stop the MinIO server run `docker stop minio`.
21 |
22 | ## Output
23 |
24 | ```text
25 | 2024/05/04 19:24:07 Successfully exported DB to S3 storage.
26 | 2024/05/04 19:24:07 Imported collection with 1 documents
27 | 2024/05/04 19:24:07 Successfully imported DB from S3 storage.
28 | ```
29 |
--------------------------------------------------------------------------------
/examples/s3-export-import/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/philippgille/chromem-go/examples/s3-export-import
2 |
3 | go 1.21
4 |
5 | require (
6 | github.com/philippgille/chromem-go v0.0.0
7 | gocloud.dev v0.37.0
8 | )
9 |
10 | replace github.com/philippgille/chromem-go => ./../..
11 |
12 | require (
13 | github.com/aws/aws-sdk-go v1.50.36 // indirect
14 | github.com/aws/aws-sdk-go-v2 v1.25.3 // indirect
15 | github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.1 // indirect
16 | github.com/aws/aws-sdk-go-v2/config v1.27.7 // indirect
17 | github.com/aws/aws-sdk-go-v2/credentials v1.17.7 // indirect
18 | github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.15.3 // indirect
19 | github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.16.9 // indirect
20 | github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 // indirect
21 | github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 // indirect
22 | github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
23 | github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.3 // indirect
24 | github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.1 // indirect
25 | github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.5 // indirect
26 | github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.5 // indirect
27 | github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.3 // indirect
28 | github.com/aws/aws-sdk-go-v2/service/s3 v1.51.4 // indirect
29 | github.com/aws/aws-sdk-go-v2/service/sso v1.20.2 // indirect
30 | github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.2 // indirect
31 | github.com/aws/aws-sdk-go-v2/service/sts v1.28.4 // indirect
32 | github.com/aws/smithy-go v1.20.1 // indirect
33 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
34 | github.com/golang/protobuf v1.5.4 // indirect
35 | github.com/google/wire v0.6.0 // indirect
36 | github.com/googleapis/gax-go/v2 v2.12.2 // indirect
37 | github.com/jmespath/go-jmespath v0.4.0 // indirect
38 | go.opencensus.io v0.24.0 // indirect
39 | golang.org/x/net v0.22.0 // indirect
40 | golang.org/x/sys v0.18.0 // indirect
41 | golang.org/x/text v0.14.0 // indirect
42 | golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect
43 | google.golang.org/api v0.169.0 // indirect
44 | google.golang.org/genproto/googleapis/rpc v0.0.0-20240311173647-c811ad7063a7 // indirect
45 | google.golang.org/grpc v1.62.1 // indirect
46 | google.golang.org/protobuf v1.33.0 // indirect
47 | )
48 |
--------------------------------------------------------------------------------
/examples/s3-export-import/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "log"
6 | "os"
7 |
8 | "github.com/philippgille/chromem-go"
9 | "gocloud.dev/blob"
10 | _ "gocloud.dev/blob/s3blob"
11 | )
12 |
13 | func main() {
14 | ctx := context.Background()
15 |
16 | // As S3-style storage we use a local MinIO instance in this example. It has
17 | // default credentials which we set in environment variables in order to be
18 | // read when calling `blob.OpenBucket`, because that call implies using the
19 | // AWS SDK default "shared config" loader.
20 | // That loader checks the environment variables, but can also fall back to a
21 | // file in `~/.aws/config` from `aws sso login` or similar.
22 | // See https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html#specifying-credentials
23 | // for details about credential loading.
24 | err := os.Setenv("AWS_ACCESS_KEY_ID", "minioadmin")
25 | if err != nil {
26 | panic(err)
27 | }
28 | err = os.Setenv("AWS_SECRET_ACCESS_KEY", "minioadmin")
29 | if err != nil {
30 | panic(err)
31 | }
32 | // A region configuration is also required. Alternatively it can be passed in
33 | // the connection string with "&region=us-west-1" for example.
34 | err = os.Setenv("AWS_DEFAULT_REGION", "us-west-1")
35 | if err != nil {
36 | panic(err)
37 | }
38 |
39 | // Export DB
40 | err = exportDB(ctx)
41 | if err != nil {
42 | panic(err)
43 | }
44 | log.Println("Successfully exported DB to S3 storage.")
45 |
46 | // Import DB
47 | err = importDB(ctx)
48 | if err != nil {
49 | panic(err)
50 | }
51 | log.Println("Successfully imported DB from S3 storage.")
52 | }
53 |
54 | func exportDB(ctx context.Context) error {
55 | // Create and fill DB
56 | db := chromem.NewDB()
57 | c, err := db.CreateCollection("knowledge-base", nil, nil)
58 | if err != nil {
59 | return err
60 | }
61 | err = c.AddDocument(ctx, chromem.Document{
62 | ID: "1",
63 | Content: "The sky is blue because of Rayleigh scattering.",
64 | })
65 | if err != nil {
66 | return err
67 | }
68 |
69 | // Open S3 bucket. We're using a local MinIO instance here, but it can be any
70 | // S3-compatible storage. We're also using the gocloud.dev/blob package instead
71 | // of the AWS SDK for Go directly, because it provides a unified Writer/Reader
72 | // API for different cloud storage providers.
73 | bucket, err := blob.OpenBucket(ctx, "s3://mybucket?"+
74 | "endpoint=localhost:9000&"+
75 | "disableSSL=true&"+
76 | "s3ForcePathStyle=true")
77 | if err != nil {
78 | return err
79 | }
80 |
81 | // Create writer to an S3 object
82 | w, err := bucket.NewWriter(ctx, "chromem.gob.gz", nil)
83 | if err != nil {
84 | return err
85 | }
86 | // Instead of deferring w.Close() here, we close it at the end of the function
87 | // to handle its errors, as the close is important for the actual write to happen
88 | // and can lead to errors such as "The specified bucket does not exist" etc.
89 | // Another option is to use a named return value and defer a function that
90 | // overwrites the error with the close error or uses [errors.Join] or similar.
91 |
92 | // Persist the DB to the S3 object
93 | err = db.ExportToWriter(w, true, "")
94 | if err != nil {
95 | return err
96 | }
97 |
98 | return w.Close()
99 | }
100 |
101 | func importDB(ctx context.Context) error {
102 | // Open S3 bucket. We're using a local MinIO instance here, but it can be any
103 | // S3-compatible storage. We're also using the gocloud.dev/blob package instead
104 | // of the AWS SDK for Go directly, because it provides a unified Writer/Reader
105 | // API for different cloud storage providers.
106 | bucket, err := blob.OpenBucket(ctx, "s3://mybucket?"+
107 | "endpoint=localhost:9000&"+
108 | "disableSSL=true&"+
109 | "s3ForcePathStyle=true")
110 | if err != nil {
111 | return err
112 | }
113 |
114 | // Open reader to the S3 object
115 | r, err := bucket.NewReader(ctx, "chromem.gob.gz", nil)
116 | if err != nil {
117 | return err
118 | }
119 | defer r.Close()
120 |
121 | // Create empty DB
122 | db := chromem.NewDB()
123 |
124 | // Import the DB from the S3 object
125 | err = db.ImportFromReader(r, "")
126 | if err != nil {
127 | return err
128 | }
129 |
130 | c := db.GetCollection("knowledge-base", nil)
131 | log.Printf("Imported collection with %d documents\n", c.Count())
132 |
133 | return nil
134 | }
135 |
--------------------------------------------------------------------------------
/examples/semantic-search-arxiv-openai/.gitignore:
--------------------------------------------------------------------------------
1 | /db
2 |
--------------------------------------------------------------------------------
/examples/semantic-search-arxiv-openai/README.md:
--------------------------------------------------------------------------------
1 | # Semantic search arXiv OpenAI
2 |
3 | This example shows a semantic search application, using `chromem-go` as the vector database for finding semantically relevant search results. We load and search across ~5,000 arXiv papers in the "Computer Science - Computation and Language" category, which is the relevant one for Natural Language Processing (NLP) related papers.
4 |
5 | This is not a retrieval augmented generation (RAG) app, because after *retrieving* the semantically relevant results, we don't *augment* an LLM prompt with them; no LLM generates the final output.
6 |
7 | ## How to run
8 |
9 | 1. Prepare the dataset
10 |    1. Download `arxiv-metadata-oai-snapshot.json` from the [arXiv dataset on Kaggle](https://www.kaggle.com/datasets/Cornell-University/arxiv)
11 |    2. Filter by the "Computer Science - Computation and Language" category (see [taxonomy](https://arxiv.org/category_taxonomy)) and by updates from 2023
12 |       1. Ensure you have [ripgrep](https://github.com/BurntSushi/ripgrep) installed, or adapt the following commands to use grep
13 |       2. Run `rg '"categories":"cs.CL"' ~/Downloads/arxiv-metadata-oai-snapshot.json | rg '"update_date":"2023' > /tmp/arxiv_cs-cl_2023.jsonl` (adapt the input file path if necessary)
14 |    3. Check the data
15 |       1. `wc -l /tmp/arxiv_cs-cl_2023.jsonl` should show ~5,000 lines
16 |       2. `du -h /tmp/arxiv_cs-cl_2023.jsonl` should show ~8.8 MB
17 | 2. Set the OpenAI API key in your env as `OPENAI_API_KEY`
18 | 3. Run the example: `go run .`
19 |
20 | ## Output
21 |
22 | The output can differ slightly on each run, but it's along the lines of:
23 |
24 | ```log
25 | 2024/03/10 18:23:55 Setting up chromem-go...
26 | 2024/03/10 18:23:55 Reading JSON lines...
27 | 2024/03/10 18:23:55 Read and parsed 5006 documents.
28 | 2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API...
29 | 2024/03/10 18:28:12 Querying chromem-go...
30 | 2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms
31 | 2024/03/10 18:28:12 Search results:
32 | 1) Similarity 0.488895:
33 | URL: https://arxiv.org/abs/2209.15469
34 | Submitter: Christian Buck
35 | Title: Zero-Shot Retrieval with Search Agents and Hybrid Environments
36 | Abstract: Learning to search is the task of building artificial agents that learn to autonomously use a search...
37 | 2) Similarity 0.480713:
38 | URL: https://arxiv.org/abs/2305.11516
39 | Submitter: Ryo Nagata Dr.
40 | Title: Contextualized Word Vector-based Methods for Discovering Semantic Differences with No Training nor Word Alignment
41 | Abstract: In this paper, we propose methods for discovering semantic differences in words appearing in two cor...
42 | 3) Similarity 0.476079:
43 | URL: https://arxiv.org/abs/2310.14025
44 | Submitter: Maria Lymperaiou
45 | Title: Large Language Models and Multimodal Retrieval for Visual Word Sense Disambiguation
46 | Abstract: Visual Word Sense Disambiguation (VWSD) is a novel challenging task with the goal of retrieving an i...
47 | 4) Similarity 0.474883:
48 | URL: https://arxiv.org/abs/2302.14785
49 | Submitter: Teven Le Scao
50 | Title: Joint Representations of Text and Knowledge Graphs for Retrieval and Evaluation
51 | Abstract: A key feature of neural models is that they can produce semantic vector representations of objects (...
52 | 5) Similarity 0.470326:
53 | URL: https://arxiv.org/abs/2309.02403
54 | Submitter: Dallas Card
55 | Title: Substitution-based Semantic Change Detection using Contextual Embeddings
56 | Abstract: Measuring semantic change has thus far remained a task where methods using contextual embeddings hav...
57 | 6) Similarity 0.466851:
58 | URL: https://arxiv.org/abs/2309.08187
59 | Submitter: Vu Tran
60 | Title: Encoded Summarization: Summarizing Documents into Continuous Vector Space for Legal Case Retrieval
61 | Abstract: We present our method for tackling a legal case retrieval task by introducing our method of encoding...
62 | 7) Similarity 0.461783:
63 | URL: https://arxiv.org/abs/2307.16638
64 | Submitter: Maiia Bocharova Bocharova
65 | Title: VacancySBERT: the approach for representation of titles and skills for semantic similarity search in the recruitment domain
66 | Abstract: The paper focuses on deep learning semantic search algorithms applied in the HR domain. The aim of t...
67 | 8) Similarity 0.460481:
68 | URL: https://arxiv.org/abs/2106.07400
69 | Submitter: Clara Meister
70 | Title: Determinantal Beam Search
71 | Abstract: Beam search is a go-to strategy for decoding neural sequence models. The algorithm can naturally be ...
72 | 9) Similarity 0.460001:
73 | URL: https://arxiv.org/abs/2305.04049
74 | Submitter: Yuxia Wu
75 | Title: Actively Discovering New Slots for Task-oriented Conversation
76 | Abstract: Existing task-oriented conversational search systems heavily rely on domain ontologies with pre-defi...
77 | 10) Similarity 0.458321:
78 | URL: https://arxiv.org/abs/2305.08654
79 | Submitter: Taichi Aida
80 | Title: Unsupervised Semantic Variation Prediction using the Distribution of Sibling Embeddings
81 | Abstract: Languages are dynamic entities, where the meanings associated with words constantly change with time...
82 | ```
83 |
84 | Most of the time is spent creating the embeddings, where we're limited by the performance of the OpenAI API.
85 |
--------------------------------------------------------------------------------
/examples/semantic-search-arxiv-openai/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/philippgille/chromem-go/examples/semantic-search-arxiv-openai
2 |
3 | go 1.21
4 |
5 | require github.com/philippgille/chromem-go v0.0.0
6 |
7 | replace github.com/philippgille/chromem-go => ./../..
8 |
--------------------------------------------------------------------------------
/examples/semantic-search-arxiv-openai/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "io"
8 | "log"
9 | "os"
10 | "runtime"
11 | "strings"
12 | "time"
13 |
14 | "github.com/philippgille/chromem-go"
15 | )
16 |
17 | const searchTerm = "semantic search with vector databases"
18 |
19 | func main() {
20 | ctx := context.Background()
21 |
22 | // Set up chromem-go with persistence, so that when the program restarts, the
23 | // DB's data is still available.
24 | log.Println("Setting up chromem-go...")
25 | db, err := chromem.NewPersistentDB("./db", false)
26 | if err != nil {
27 | panic(err)
28 | }
29 | // Create collection if it wasn't loaded from persistent storage yet.
30 | // We pass nil as embedding function to use the default (OpenAI text-embedding-3-small),
31 | // which is very good and cheap. It requires the OPENAI_API_KEY environment
32 | // variable to be set. See the other examples for how to use Ollama.
33 | collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil)
34 | if err != nil {
35 | panic(err)
36 | }
37 | // Add docs to the collection, if the collection was just created (and not
38 | // loaded from persistent storage).
39 | var docs []chromem.Document
40 | if collection.Count() == 0 {
41 | // Here we use an arXiv metadata sample, where each line contains the metadata
42 | // of a paper, including its submitter, title and abstract.
43 | f, err := os.Open("/tmp/arxiv_cs-cl_2023.jsonl")
44 | if err != nil {
45 | panic(err)
46 | }
47 | defer f.Close()
48 | d := json.NewDecoder(f)
49 | log.Println("Reading JSON lines...")
50 | i := 0
51 | for {
52 | var paper struct {
53 | ID string `json:"id"`
54 | Submitter string `json:"submitter"`
55 | Title string `json:"title"`
56 | Abstract string `json:"abstract"`
57 | }
58 | err := d.Decode(&paper)
59 | if err == io.EOF {
60 | break // reached end of file
61 | } else if err != nil {
62 | panic(err)
63 | }
64 |
65 | title := strings.ReplaceAll(paper.Title, "\n", " ")
66 | title = strings.ReplaceAll(title, "  ", " ")
67 | content := strings.TrimSpace(paper.Abstract)
68 | docs = append(docs, chromem.Document{
69 | ID: paper.ID,
70 | Metadata: map[string]string{"submitter": paper.Submitter, "title": title},
71 | Content: content,
72 | })
73 | i++
74 | }
75 | log.Println("Read and parsed", i, "documents.")
76 | log.Println("Adding documents to chromem-go, including creating their embeddings via OpenAI API...")
77 | err = collection.AddDocuments(ctx, docs, runtime.NumCPU())
78 | if err != nil {
79 | panic(err)
80 | }
81 | } else {
82 | log.Println("Not reading JSON lines because collection was loaded from persistent storage.")
83 | }
84 |
85 | // Search for documents that are semantically similar to the search term.
86 | // We ask for the 10 most similar documents, but you can use more or less depending
87 | // on your needs.
88 | // You can limit the search by filtering on content or metadata (like the paper's
89 | // submitter), but we don't do that in this example.
90 | log.Println("Querying chromem-go...")
91 | start := time.Now()
92 | docRes, err := collection.Query(ctx, searchTerm, 10, nil, nil)
93 | if err != nil {
94 | panic(err)
95 | }
96 | log.Println("Search (incl query embedding) took", time.Since(start))
97 | // Here you could filter out any documents whose similarity is below a certain threshold.
98 | // if docRes[...].Similarity < 0.5 { ...
99 |
100 | // Print the retrieved documents and their similarity to the question.
101 | buf := &strings.Builder{}
102 | for i, res := range docRes {
103 | content := strings.ReplaceAll(res.Content, "\n", " ")
104 | content = content[:min(100, len(content))] + "..."
105 | fmt.Fprintf(buf, "\t%d) Similarity %f:\n"+
106 | "\t\tURL: https://arxiv.org/abs/%s\n"+
107 | "\t\tSubmitter: %s\n"+
108 | "\t\tTitle: %s\n"+
109 | "\t\tAbstract: %s\n",
110 | i+1, res.Similarity, res.ID, res.Metadata["submitter"], res.Metadata["title"], content)
111 | }
112 | log.Printf("Search results:\n%s\n", buf.String())
113 |
114 | /* Output:
115 | 2024/03/10 18:23:55 Setting up chromem-go...
116 | 2024/03/10 18:23:55 Reading JSON lines...
117 | 2024/03/10 18:23:55 Read and parsed 5006 documents.
118 | 2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API...
119 | 2024/03/10 18:28:12 Querying chromem-go...
120 | 2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms
121 | 2024/03/10 18:28:12 Search results:
122 | 1) Similarity 0.488895:
123 | URL: https://arxiv.org/abs/2209.15469
124 | Submitter: Christian Buck
125 | Title: Zero-Shot Retrieval with Search Agents and Hybrid Environments
126 | Abstract: Learning to search is the task of building artificial agents that learn to autonomously use a search...
127 | 2) Similarity 0.480713:
128 | URL: https://arxiv.org/abs/2305.11516
129 | Submitter: Ryo Nagata Dr.
130 | Title: Contextualized Word Vector-based Methods for Discovering Semantic Differences with No Training nor Word Alignment
131 | Abstract: In this paper, we propose methods for discovering semantic differences in words appearing in two cor...
132 | 3) Similarity 0.476079:
133 | URL: https://arxiv.org/abs/2310.14025
134 | Submitter: Maria Lymperaiou
135 | Title: Large Language Models and Multimodal Retrieval for Visual Word Sense Disambiguation
136 | Abstract: Visual Word Sense Disambiguation (VWSD) is a novel challenging task with the goal of retrieving an i...
137 | 4) Similarity 0.474883:
138 | URL: https://arxiv.org/abs/2302.14785
139 | Submitter: Teven Le Scao
140 | Title: Joint Representations of Text and Knowledge Graphs for Retrieval and Evaluation
141 | Abstract: A key feature of neural models is that they can produce semantic vector representations of objects (...
142 | 5) Similarity 0.470326:
143 | URL: https://arxiv.org/abs/2309.02403
144 | Submitter: Dallas Card
145 | Title: Substitution-based Semantic Change Detection using Contextual Embeddings
146 | Abstract: Measuring semantic change has thus far remained a task where methods using contextual embeddings hav...
147 | 6) Similarity 0.466851:
148 | URL: https://arxiv.org/abs/2309.08187
149 | Submitter: Vu Tran
150 | Title: Encoded Summarization: Summarizing Documents into Continuous Vector Space for Legal Case Retrieval
151 | Abstract: We present our method for tackling a legal case retrieval task by introducing our method of encoding...
152 | 7) Similarity 0.461783:
153 | URL: https://arxiv.org/abs/2307.16638
154 | Submitter: Maiia Bocharova Bocharova
155 | Title: VacancySBERT: the approach for representation of titles and skills for semantic similarity search in the recruitment domain
156 | Abstract: The paper focuses on deep learning semantic search algorithms applied in the HR domain. The aim of t...
157 | 8) Similarity 0.460481:
158 | URL: https://arxiv.org/abs/2106.07400
159 | Submitter: Clara Meister
160 | Title: Determinantal Beam Search
161 | Abstract: Beam search is a go-to strategy for decoding neural sequence models. The algorithm can naturally be ...
162 | 9) Similarity 0.460001:
163 | URL: https://arxiv.org/abs/2305.04049
164 | Submitter: Yuxia Wu
165 | Title: Actively Discovering New Slots for Task-oriented Conversation
166 | Abstract: Existing task-oriented conversational search systems heavily rely on domain ontologies with pre-defi...
167 | 10) Similarity 0.458321:
168 | URL: https://arxiv.org/abs/2305.08654
169 | Submitter: Taichi Aida
170 | Title: Unsupervised Semantic Variation Prediction using the Distribution of Sibling Embeddings
171 | Abstract: Languages are dynamic entities, where the meanings associated with words constantly change with time...
172 | */
173 | }
174 |
--------------------------------------------------------------------------------
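The comments in `main.go` above mention two ways to narrow results: metadata/content filters passed to `Query`, and a client-side similarity threshold. A small sketch of both, assuming the collection from the example; the submitter name, the `$contains` term and the 0.45 threshold are arbitrary illustration values, and the helper itself is not part of the example:

```go
package main

import (
	"context"

	"github.com/philippgille/chromem-go"
)

// queryFiltered is a hypothetical helper: it restricts the search with a
// metadata ("where") filter and a content ("whereDocument") filter, then drops
// results below a similarity threshold on the client side.
func queryFiltered(ctx context.Context, c *chromem.Collection, searchTerm string) ([]chromem.Result, error) {
	res, err := c.Query(ctx, searchTerm, 10,
		map[string]string{"submitter": "Christian Buck"}, // metadata must match exactly
		map[string]string{"$contains": "retrieval"},      // content must contain this substring
	)
	if err != nil {
		return nil, err
	}

	// Client-side threshold: keep only sufficiently similar documents.
	filtered := res[:0]
	for _, r := range res {
		if r.Similarity >= 0.45 {
			filtered = append(filtered, r)
		}
	}
	return filtered, nil
}
```

The `$contains` and `$not_contains` operators are the ones validated in `query.go` further below; the metadata filter requires an exact match on the given key.
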
/examples/webassembly/README.md:
--------------------------------------------------------------------------------
1 | # WebAssembly (WASM)
2 |
3 | Go can compile to WebAssembly, which you can then use from JavaScript in a Browser or similar environments (Node, Deno, Bun etc.). You could also target WASI (WebAssembly System Interface) and run it in a standalone runtime (wazero, wasmtime, Wasmer), but in this example we focus on the Browser use case.
4 |
5 | ## How to run
6 |
7 | 1. Compile the `chromem-go` WASM binding to WebAssembly:
8 |    1. `cd /path/to/chromem-go/wasm`
9 |    2. `GOOS=js GOARCH=wasm go build -o ../examples/webassembly/chromem-go.wasm`
10 | 2. Copy Go's wrapper JavaScript:
11 |    1. `cp $(go env GOROOT)/misc/wasm/wasm_exec.js ../examples/webassembly/wasm_exec.js`
12 | 3. Serve the files:
13 |    1. `cd ../examples/webassembly`
14 |    2. `go run github.com/philippgille/serve@latest -b localhost -p 8080` or similar
15 | 4. Open <http://localhost:8080> in your browser
--------------------------------------------------------------------------------
/examples/webassembly/index.html:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/philippgille/chromem-go
2 |
3 | go 1.21
4 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/philippgille/chromem-go/8311eb0938d6877effd931efb18c1d7460277bac/go.sum
--------------------------------------------------------------------------------
/persistence.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "bytes"
5 | "compress/gzip"
6 | "crypto/aes"
7 | "crypto/cipher"
8 | "crypto/rand"
9 | "crypto/sha256"
10 | "encoding/gob"
11 | "encoding/hex"
12 | "errors"
13 | "fmt"
14 | "io"
15 | "io/fs"
16 | "os"
17 | "path/filepath"
18 | )
19 |
20 | const metadataFileName = "00000000"
21 |
22 | func hash2hex(name string) string {
23 | hash := sha256.Sum256([]byte(name))
24 | // We encode 4 of the 32 bytes (32 out of 256 bits), so 8 hex characters.
25 | // It's enough to avoid collisions in reasonable amounts of documents per collection
26 | // and being shorter is better for file paths.
27 | return hex.EncodeToString(hash[:4])
28 | }
29 |
30 | // persistToFile persists an object to a file at the given path. The object is serialized
31 | // as gob, optionally compressed with flate (as gzip) and optionally encrypted with
32 | // AES-GCM. The encryption key must be 32 bytes long. If the file exists, it's
33 | // overwritten, otherwise created.
34 | func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error {
35 | if filePath == "" {
36 | return fmt.Errorf("file path is empty")
37 | }
38 | // AES 256 requires a 32 byte key
39 | if encryptionKey != "" {
40 | if len(encryptionKey) != 32 {
41 | return errors.New("encryption key must be 32 bytes long")
42 | }
43 | }
44 |
45 | // If path doesn't exist, create the parent path.
46 | // If path exists, and it's a directory, return an error.
47 | fi, err := os.Stat(filePath)
48 | if err != nil {
49 | if !errors.Is(err, fs.ErrNotExist) {
50 | return fmt.Errorf("couldn't get info about the path: %w", err)
51 | } else {
52 | // If the file doesn't exist, create the parent path
53 | err := os.MkdirAll(filepath.Dir(filePath), 0o700)
54 | if err != nil {
55 | return fmt.Errorf("couldn't create parent directories to path: %w", err)
56 | }
57 | }
58 | } else if fi.IsDir() {
59 | return fmt.Errorf("path is a directory: %s", filePath)
60 | }
61 |
62 | // Open file for writing
63 | f, err := os.Create(filePath)
64 | if err != nil {
65 | return fmt.Errorf("couldn't create file: %w", err)
66 | }
67 | defer f.Close()
68 |
69 | return persistToWriter(f, obj, compress, encryptionKey)
70 | }
71 |
72 | // persistToWriter persists an object to a writer. The object is serialized
73 | // as gob, optionally compressed with flate (as gzip) and optionally encrypted with
74 | // AES-GCM. The encryption key must be 32 bytes long.
75 | // If the writer has to be closed, it's the caller's responsibility.
76 | func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error {
77 | // AES 256 requires a 32 byte key
78 | if encryptionKey != "" {
79 | if len(encryptionKey) != 32 {
80 | return errors.New("encryption key must be 32 bytes long")
81 | }
82 | }
83 |
84 | // We want to:
85 | // Encode as gob -> compress with flate -> encrypt with AES-GCM -> write to
86 | // passed writer.
87 | // To reduce memory usage we chain the writers instead of buffering, so we start
88 | // from the end. For AES GCM sealing the stdlib doesn't provide a writer though.
89 |
90 | var chainedWriter io.Writer
91 | if encryptionKey == "" {
92 | chainedWriter = w
93 | } else {
94 | chainedWriter = &bytes.Buffer{}
95 | }
96 |
97 | var gzw *gzip.Writer
98 | var enc *gob.Encoder
99 | if compress {
100 | gzw = gzip.NewWriter(chainedWriter)
101 | enc = gob.NewEncoder(gzw)
102 | } else {
103 | enc = gob.NewEncoder(chainedWriter)
104 | }
105 |
106 | // Start encoding, it will write to the chain of writers.
107 | if err := enc.Encode(obj); err != nil {
108 | return fmt.Errorf("couldn't encode or write object: %w", err)
109 | }
110 |
111 | // If compressing, close the gzip writer. Otherwise, the gzip footer won't be
112 | // written yet. When using encryption (and chainedWriter is a buffer) then
113 | // we'll encrypt an incomplete stream. Without encryption when we return here and having
114 | // a deferred Close(), there might be a silenced error.
115 | if compress {
116 | err := gzw.Close()
117 | if err != nil {
118 | return fmt.Errorf("couldn't close gzip writer: %w", err)
119 | }
120 | }
121 |
122 | // Without encryption, the chain is done and the writing is finished.
123 | if encryptionKey == "" {
124 | return nil
125 | }
126 |
127 | // Otherwise, encrypt and then write to the unchained target writer.
128 | block, err := aes.NewCipher([]byte(encryptionKey))
129 | if err != nil {
130 | return fmt.Errorf("couldn't create new AES cipher: %w", err)
131 | }
132 | gcm, err := cipher.NewGCM(block)
133 | if err != nil {
134 | return fmt.Errorf("couldn't create GCM wrapper: %w", err)
135 | }
136 | nonce := make([]byte, gcm.NonceSize())
137 | if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
138 | return fmt.Errorf("couldn't read random bytes for nonce: %w", err)
139 | }
140 | // chainedWriter is a *bytes.Buffer
141 | buf := chainedWriter.(*bytes.Buffer)
142 | encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil)
143 | _, err = w.Write(encrypted)
144 | if err != nil {
145 | return fmt.Errorf("couldn't write encrypted data: %w", err)
146 | }
147 |
148 | return nil
149 | }
150 |
151 | // readFromFile reads an object from a file at the given path. The object is deserialized
152 | // from gob. `obj` must be a pointer to an instantiated object. The file may
153 | // optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption
154 | // key must be 32 bytes long.
155 | func readFromFile(filePath string, obj any, encryptionKey string) error {
156 | if filePath == "" {
157 | return fmt.Errorf("file path is empty")
158 | }
159 | // AES 256 requires a 32 byte key
160 | if encryptionKey != "" {
161 | if len(encryptionKey) != 32 {
162 | return errors.New("encryption key must be 32 bytes long")
163 | }
164 | }
165 |
166 | r, err := os.Open(filePath)
167 | if err != nil {
168 | return fmt.Errorf("couldn't open file: %w", err)
169 | }
170 | defer r.Close()
171 |
172 | return readFromReader(r, obj, encryptionKey)
173 | }
174 |
175 | // readFromReader reads an object from a Reader. The object is deserialized from gob.
176 | // `obj` must be a pointer to an instantiated object. The stream may optionally
177 | // be compressed as gzip and/or encrypted with AES-GCM. The encryption key must
178 | // be 32 bytes long.
179 | // If the reader has to be closed, it's the caller's responsibility.
180 | func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error {
181 | // AES 256 requires a 32 byte key
182 | if encryptionKey != "" {
183 | if len(encryptionKey) != 32 {
184 | return errors.New("encryption key must be 32 bytes long")
185 | }
186 | }
187 |
188 | // We want to:
189 | // Read from reader -> decrypt with AES-GCM -> decompress with flate -> decode
190 | // as gob.
191 | // To reduce memory usage we chain the readers instead of buffering, so we start
192 | // from the end. For the decryption there's no reader though.
193 |
194 | // For the chainedReader we don't declare it as ReadSeeker, so we can reassign
195 | // the gzip reader to it.
196 | var chainedReader io.Reader
197 |
198 | // Decrypt if an encryption key is provided
199 | if encryptionKey != "" {
200 | encrypted, err := io.ReadAll(r)
201 | if err != nil {
202 | return fmt.Errorf("couldn't read from reader: %w", err)
203 | }
204 | block, err := aes.NewCipher([]byte(encryptionKey))
205 | if err != nil {
206 | return fmt.Errorf("couldn't create AES cipher: %w", err)
207 | }
208 | gcm, err := cipher.NewGCM(block)
209 | if err != nil {
210 | return fmt.Errorf("couldn't create GCM wrapper: %w", err)
211 | }
212 | nonceSize := gcm.NonceSize()
213 | if len(encrypted) < nonceSize {
214 | return fmt.Errorf("encrypted data too short")
215 | }
216 | nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:]
217 | data, err := gcm.Open(nil, nonce, ciphertext, nil)
218 | if err != nil {
219 | return fmt.Errorf("couldn't decrypt data: %w", err)
220 | }
221 |
222 | chainedReader = bytes.NewReader(data)
223 | } else {
224 | chainedReader = r
225 | }
226 |
227 | // Determine if the stream is compressed
228 | magicNumber := make([]byte, 2)
229 | _, err := chainedReader.Read(magicNumber)
230 | if err != nil {
231 | return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err)
232 | }
233 | var compressed bool
234 | if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b {
235 | compressed = true
236 | }
237 |
238 | // Reset reader. Both the reader from the param and bytes.Reader support seeking.
239 | if s, ok := chainedReader.(io.Seeker); !ok {
240 | return fmt.Errorf("reader doesn't support seeking")
241 | } else {
242 | _, err := s.Seek(0, 0)
243 | if err != nil {
244 | return fmt.Errorf("couldn't reset reader: %w", err)
245 | }
246 | }
247 |
248 | if compressed {
249 | gzr, err := gzip.NewReader(chainedReader)
250 | if err != nil {
251 | return fmt.Errorf("couldn't create gzip reader: %w", err)
252 | }
253 | defer gzr.Close()
254 | chainedReader = gzr
255 | }
256 |
257 | dec := gob.NewDecoder(chainedReader)
258 | err = dec.Decode(obj)
259 | if err != nil {
260 | return fmt.Errorf("couldn't decode object: %w", err)
261 | }
262 |
263 | return nil
264 | }
265 |
266 | // removeFile removes a file at the given path. If the file doesn't exist, it's a no-op.
267 | func removeFile(filePath string) error {
268 | if filePath == "" {
269 | return fmt.Errorf("file path is empty")
270 | }
271 |
272 | err := os.Remove(filePath)
273 | if err != nil {
274 | if !errors.Is(err, fs.ErrNotExist) {
275 | return fmt.Errorf("couldn't remove file %q: %w", filePath, err)
276 | }
277 | }
278 |
279 | return nil
280 | }
281 |
--------------------------------------------------------------------------------
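The comments in `persistToWriter` and `readFromReader` describe the gob → gzip → AES-GCM chain and its reversal. A minimal in-package sketch of that round trip through an in-memory buffer; the struct, the helper name and the 32-byte key are arbitrary illustration values:

```go
package chromem

import (
	"bytes"
	"fmt"
)

// roundTripSketch exercises the write/read chain documented above: gob-encode,
// gzip-compress and AES-GCM-encrypt into a buffer, then reverse the chain.
func roundTripSketch() error {
	type payload struct {
		Foo string
		Bar []float32
	}
	in := payload{Foo: "test", Bar: []float32{0.1, 0.2, 0.3}}
	key := "0123456789abcdef0123456789abcdef" // exactly 32 bytes, as required for AES-256

	var buf bytes.Buffer
	if err := persistToWriter(&buf, in, true, key); err != nil {
		return err
	}

	// Decrypt, detect gzip via the magic number, decompress and decode.
	var out payload
	if err := readFromReader(bytes.NewReader(buf.Bytes()), &out, key); err != nil {
		return err
	}
	fmt.Printf("%+v\n", out) // {Foo:test Bar:[0.1 0.2 0.3]}
	return nil
}
```

A `*bytes.Reader` satisfies the `io.ReadSeeker` that `readFromReader` needs for the compression sniffing, which is why the buffer's bytes are wrapped instead of being read from the buffer directly.
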
/persistence_test.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "compress/gzip"
5 | "encoding/gob"
6 | "math/rand"
7 | "os"
8 | "path/filepath"
9 | "reflect"
10 | "testing"
11 | )
12 |
13 | func TestPersistenceWrite(t *testing.T) {
14 | tempDir, err := os.MkdirTemp("", "chromem-go")
15 | if err != nil {
16 | t.Fatal("expected nil, got", err)
17 | }
18 | defer os.RemoveAll(tempDir)
19 |
20 | type s struct {
21 | Foo string
22 | Bar []float32
23 | }
24 | obj := s{
25 | Foo: "test",
26 | Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}`
27 | }
28 |
29 | t.Run("gob", func(t *testing.T) {
30 | tempFilePath := tempDir + ".gob"
31 | if err := persistToFile(tempFilePath, obj, false, ""); err != nil {
32 | t.Fatal("expected nil, got", err)
33 | }
34 |
35 | // Check if the file exists.
36 | _, err = os.Stat(tempFilePath)
37 | if err != nil {
38 | t.Fatal("expected nil, got", err)
39 | }
40 |
41 | // Read file and decode
42 | f, err := os.Open(tempFilePath)
43 | if err != nil {
44 | t.Fatal("expected nil, got", err)
45 | }
46 | defer f.Close()
47 | d := gob.NewDecoder(f)
48 | res := s{}
49 | err = d.Decode(&res)
50 | if err != nil {
51 | t.Fatal("expected nil, got", err)
52 | }
53 |
54 | // Compare
55 | if !reflect.DeepEqual(obj, res) {
56 | t.Fatalf("expected %+v, got %+v", obj, res)
57 | }
58 | })
59 |
60 | t.Run("gob gzipped", func(t *testing.T) {
61 | tempFilePath := tempDir + ".gob.gz"
62 | if err := persistToFile(tempFilePath, obj, true, ""); err != nil {
63 | t.Fatal("expected nil, got", err)
64 | }
65 |
66 | // Check if the file exists.
67 | _, err = os.Stat(tempFilePath)
68 | if err != nil {
69 | t.Fatal("expected nil, got", err)
70 | }
71 |
72 | // Read file, decompress and decode
73 | f, err := os.Open(tempFilePath)
74 | if err != nil {
75 | t.Fatal("expected nil, got", err)
76 | }
77 | defer f.Close()
78 | gzr, err := gzip.NewReader(f)
79 | if err != nil {
80 | t.Fatal("expected nil, got", err)
81 | }
82 | d := gob.NewDecoder(gzr)
83 | res := s{}
84 | err = d.Decode(&res)
85 | if err != nil {
86 | t.Fatal("expected nil, got", err)
87 | }
88 |
89 | // Compare
90 | if !reflect.DeepEqual(obj, res) {
91 | t.Fatalf("expected %+v, got %+v", obj, res)
92 | }
93 | })
94 | }
95 |
96 | func TestPersistenceRead(t *testing.T) {
97 | tempDir, err := os.MkdirTemp("", "chromem-go")
98 | if err != nil {
99 | t.Fatal("expected nil, got", err)
100 | }
101 | defer os.RemoveAll(tempDir)
102 |
103 | type s struct {
104 | Foo string
105 | Bar []float32
106 | }
107 | obj := s{
108 | Foo: "test",
109 | Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}`
110 | }
111 |
112 | t.Run("gob", func(t *testing.T) {
113 | tempFilePath := tempDir + ".gob"
114 | f, err := os.Create(tempFilePath)
115 | if err != nil {
116 | t.Fatal("expected nil, got", err)
117 | }
118 | enc := gob.NewEncoder(f)
119 | err = enc.Encode(obj)
120 | if err != nil {
121 | t.Fatal("expected nil, got", err)
122 | }
123 | err = f.Close()
124 | if err != nil {
125 | t.Fatal("expected nil, got", err)
126 | }
127 |
128 | // Read the file.
129 | var res s
130 | err = readFromFile(tempFilePath, &res, "")
131 | if err != nil {
132 | t.Fatal("expected nil, got", err)
133 | }
134 |
135 | // Compare
136 | if !reflect.DeepEqual(obj, res) {
137 | t.Fatalf("expected %+v, got %+v", obj, res)
138 | }
139 | })
140 |
141 | t.Run("gob gzipped", func(t *testing.T) {
142 | tempFilePath := tempDir + ".gob.gz"
143 | f, err := os.Create(tempFilePath)
144 | if err != nil {
145 | t.Fatal("expected nil, got", err)
146 | }
147 | gzw := gzip.NewWriter(f)
148 | enc := gob.NewEncoder(gzw)
149 | err = enc.Encode(obj)
150 | if err != nil {
151 | t.Fatal("expected nil, got", err)
152 | }
153 | err = gzw.Close()
154 | if err != nil {
155 | t.Fatal("expected nil, got", err)
156 | }
157 | err = f.Close()
158 | if err != nil {
159 | t.Fatal("expected nil, got", err)
160 | }
161 |
162 | // Read the file.
163 | var res s
164 | err = readFromFile(tempFilePath, &res, "")
165 | if err != nil {
166 | t.Fatal("expected nil, got", err)
167 | }
168 |
169 | // Compare
170 | if !reflect.DeepEqual(obj, res) {
171 | t.Fatalf("expected %+v, got %+v", obj, res)
172 | }
173 | })
174 | }
175 |
176 | func TestPersistenceEncryption(t *testing.T) {
177 | // Instead of copy-pasting the encryption/decryption code into the test, we exercise
178 | // both functions under test together, rather than pairing one with an independent implementation.
179 |
180 | r := rand.New(rand.NewSource(rand.Int63()))
181 | // randString := randomString(r, 10)
182 | path := filepath.Join(os.TempDir(), "a", "chromem-go")
183 | // defer os.RemoveAll(path)
184 |
185 | type s struct {
186 | Foo string
187 | Bar []float32
188 | }
189 | obj := s{
190 | Foo: "test",
191 | Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}`
192 | }
193 | encryptionKey := randomString(r, 32)
194 |
195 | tt := []struct {
196 | name string
197 | filePath string
198 | compress bool
199 | }{
200 | {
201 | name: "compress false",
202 | filePath: path + ".gob.enc",
203 | compress: false,
204 | },
205 | {
206 | name: "compress true",
207 | filePath: path + ".gob.gz.enc",
208 | compress: true,
209 | },
210 | }
211 |
212 | for _, tc := range tt {
213 | t.Run(tc.name, func(t *testing.T) {
214 | err := persistToFile(tc.filePath, obj, tc.compress, encryptionKey)
215 | if err != nil {
216 | t.Fatal("expected nil, got", err)
217 | }
218 |
219 | // Check if the file exists.
220 | _, err = os.Stat(tc.filePath)
221 | if err != nil {
222 | t.Fatal("expected nil, got", err)
223 | }
224 |
225 | // Read the file.
226 | var res s
227 | err = readFromFile(tc.filePath, &res, encryptionKey)
228 | if err != nil {
229 | t.Fatal("expected nil, got", err)
230 | }
231 |
232 | // Compare
233 | if !reflect.DeepEqual(obj, res) {
234 | t.Fatalf("expected %+v, got %+v", obj, res)
235 | }
236 | })
237 | }
238 | }
239 |
--------------------------------------------------------------------------------
/query.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "cmp"
5 | "container/heap"
6 | "context"
7 | "fmt"
8 | "runtime"
9 | "slices"
10 | "strings"
11 | "sync"
12 | )
13 |
14 | var supportedFilters = []string{"$contains", "$not_contains"}
15 |
16 | type docSim struct {
17 | docID string
18 | similarity float32
19 | }
20 |
21 | // docMaxHeap is a heap of docSims used to keep the docs with the highest similarities. Technically it's a min-heap (smallest kept similarity at the root), so the least similar doc can be dropped efficiently.
22 | // See https://pkg.go.dev/container/heap@go1.22#example-package-IntHeap
23 | type docMaxHeap []docSim
24 |
25 | func (h docMaxHeap) Len() int { return len(h) }
26 | func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity }
27 | func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
28 |
29 | func (h *docMaxHeap) Push(x any) {
30 | // Push and Pop use pointer receivers because they modify the slice's length,
31 | // not just its contents.
32 | *h = append(*h, x.(docSim))
33 | }
34 |
35 | func (h *docMaxHeap) Pop() any {
36 | old := *h
37 | n := len(old)
38 | x := old[n-1]
39 | *h = old[0 : n-1]
40 | return x
41 | }
42 |
43 | // maxDocSims manages a fixed-size heap of docSims, keeping the n highest
44 | // similarities. It's safe for concurrent use, but not the result of values().
45 | // In our benchmarks this was faster than sorting a slice of docSims at the end.
46 | type maxDocSims struct {
47 | h docMaxHeap
48 | lock sync.RWMutex
49 | size int
50 | }
51 |
52 | // newMaxDocSims creates a new maxDocSims with a fixed size.
53 | func newMaxDocSims(size int) *maxDocSims {
54 | return &maxDocSims{
55 | h: make(docMaxHeap, 0, size),
56 | size: size,
57 | }
58 | }
59 |
60 | // add inserts a new docSim into the heap, keeping only the top n similarities.
61 | func (d *maxDocSims) add(doc docSim) {
62 | d.lock.Lock()
63 | defer d.lock.Unlock()
64 | if d.h.Len() < d.size {
65 | heap.Push(&d.h, doc)
66 | } else if d.h.Len() > 0 && d.h[0].similarity < doc.similarity {
67 | // Replace the smallest similarity if the new doc's similarity is higher
68 | heap.Pop(&d.h)
69 | heap.Push(&d.h, doc)
70 | }
71 | }
72 |
73 | // values returns the docSims in the heap, sorted by similarity (descending).
74 | // The call itself is safe for concurrent use with add(), but the result isn't.
75 | // Only work with the result after all calls to add() have finished.
76 | func (d *maxDocSims) values() []docSim {
77 | d.lock.RLock()
78 | defer d.lock.RUnlock()
79 | slices.SortFunc(d.h, func(i, j docSim) int {
80 | return cmp.Compare(j.similarity, i.similarity)
81 | })
82 | return d.h
83 | }
84 |
85 | // filterDocs filters a map of documents by metadata and content.
86 | // It does this concurrently.
87 | func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
88 | filteredDocs := make([]*Document, 0, len(docs))
89 | filteredDocsLock := sync.Mutex{}
90 |
91 | // Determine concurrency. Use number of docs or CPUs, whichever is smaller.
92 | numCPUs := runtime.NumCPU()
93 | numDocs := len(docs)
94 | concurrency := numCPUs
95 | if numDocs < numCPUs {
96 | concurrency = numDocs
97 | }
98 |
99 | docChan := make(chan *Document, concurrency*2)
100 |
101 | wg := sync.WaitGroup{}
102 | for i := 0; i < concurrency; i++ {
103 | wg.Add(1)
104 | go func() {
105 | defer wg.Done()
106 | for doc := range docChan {
107 | if documentMatchesFilters(doc, where, whereDocument) {
108 | filteredDocsLock.Lock()
109 | filteredDocs = append(filteredDocs, doc)
110 | filteredDocsLock.Unlock()
111 | }
112 | }
113 | }()
114 | }
115 |
116 | for _, doc := range docs {
117 | docChan <- doc
118 | }
119 | close(docChan)
120 |
121 | wg.Wait()
122 |
123 | // With filteredDocs being initialized as potentially large slice, let's return
124 | // nil instead of the empty slice.
125 | if len(filteredDocs) == 0 {
126 | filteredDocs = nil
127 | }
128 | return filteredDocs
129 | }
130 |
131 | // documentMatchesFilters checks if a document matches the given filters.
132 | // When calling this function, the whereDocument keys must already be validated!
133 | func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool {
134 | // A document's metadata must have *all* the fields in the where clause.
135 | for k, v := range where {
136 | // TODO: Do we want to check for existence of the key? I.e. should
137 | // a where clause with empty string as value match a document's
138 | // metadata that doesn't have the key at all?
139 | if document.Metadata[k] != v {
140 | return false
141 | }
142 | }
143 |
144 | // A document must satisfy *all* filters, until we support the `$or` operator.
145 | for k, v := range whereDocument {
146 | switch k {
147 | case "$contains":
148 | if !strings.Contains(document.Content, v) {
149 | return false
150 | }
151 | case "$not_contains":
152 | if strings.Contains(document.Content, v) {
153 | return false
154 | }
155 | default:
156 | // No handling (error) required because we already validated the
157 | // operators. This simplifies the concurrency logic (no err var
158 | // and lock, no context to cancel).
159 | }
160 | }
161 |
162 | return true
163 | }
164 |
165 | func getMostSimilarDocs(ctx context.Context, queryVectors, negativeVector []float32, negativeFilterThreshold float32, docs []*Document, n int) ([]docSim, error) {
166 | nMaxDocs := newMaxDocSims(n)
167 |
168 | // Determine concurrency. Use number of docs or CPUs, whichever is smaller.
169 | numCPUs := runtime.NumCPU()
170 | numDocs := len(docs)
171 | concurrency := numCPUs
172 | if numDocs < numCPUs {
173 | concurrency = numDocs
174 | }
175 |
176 | var sharedErr error
177 | sharedErrLock := sync.Mutex{}
178 | ctx, cancel := context.WithCancelCause(ctx)
179 | defer cancel(nil)
180 | setSharedErr := func(err error) {
181 | sharedErrLock.Lock()
182 | defer sharedErrLock.Unlock()
183 | // Another goroutine might have already set the error.
184 | if sharedErr == nil {
185 | sharedErr = err
186 | // Cancel the operation for all other goroutines.
187 | cancel(sharedErr)
188 | }
189 | }
190 |
191 | wg := sync.WaitGroup{}
192 | // Instead of using a channel to pass documents into the goroutines, we just
193 | // split the slice into sub-slices and pass those to the goroutines.
194 | // This turned out to be faster in the query benchmarks.
195 | subSliceSize := len(docs) / concurrency // Can leave remainder, e.g. 10/3 = 3; leaves 1
196 | rem := len(docs) % concurrency
197 | for i := 0; i < concurrency; i++ {
198 | start := i * subSliceSize
199 | end := start + subSliceSize
200 | // Add remainder to last goroutine
201 | if i == concurrency-1 {
202 | end += rem
203 | }
204 |
205 | wg.Add(1)
206 | go func(subSlice []*Document) {
207 | defer wg.Done()
208 | for _, doc := range subSlice {
209 | // Stop work if another goroutine encountered an error.
210 | if ctx.Err() != nil {
211 | return
212 | }
213 |
214 | // As the vectors are normalized, the dot product is the cosine similarity.
215 | sim, err := dotProduct(queryVectors, doc.Embedding)
216 | if err != nil {
217 | setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
218 | return
219 | }
220 |
221 | if negativeFilterThreshold > 0 {
222 | nsim, err := dotProduct(negativeVector, doc.Embedding)
223 | if err != nil {
224 | setSharedErr(fmt.Errorf("couldn't calculate negative similarity for document '%s': %w", doc.ID, err))
225 | return
226 | }
227 |
228 | if nsim > negativeFilterThreshold {
229 | continue
230 | }
231 | }
232 |
233 | nMaxDocs.add(docSim{docID: doc.ID, similarity: sim})
234 | }
235 | }(docs[start:end])
236 | }
237 |
238 | wg.Wait()
239 |
240 | if sharedErr != nil {
241 | return nil, sharedErr
242 | }
243 |
244 | return nMaxDocs.values(), nil
245 | }
246 |
--------------------------------------------------------------------------------
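A short in-package sketch of the fixed-size heap described above: `maxDocSims` keeps only the n highest similarities and `values()` returns them sorted descending. The function name, document IDs and similarity values are arbitrary illustration values:

```go
package chromem

import "fmt"

// topNSketch feeds six similarities into a maxDocSims of size 3 and prints the
// three highest, sorted by similarity in descending order.
func topNSketch() {
	top := newMaxDocSims(3)
	for i, sim := range []float32{0.1, 0.9, 0.4, 0.7, 0.3, 0.8} {
		top.add(docSim{docID: fmt.Sprint(i), similarity: sim})
	}
	for _, d := range top.values() {
		fmt.Println(d.docID, d.similarity)
	}
	// Prints documents 1, 5 and 3 with similarities 0.9, 0.8 and 0.7.
}
```

This is the same mechanism `getMostSimilarDocs` uses per query, with each worker goroutine calling `add` concurrently and `values()` being read only after `wg.Wait()`.
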
/query_test.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "context"
5 | "reflect"
6 | "slices"
7 | "testing"
8 | )
9 |
10 | func TestFilterDocs(t *testing.T) {
11 | docs := map[string]*Document{
12 | "1": {
13 | ID: "1",
14 | Metadata: map[string]string{
15 | "language": "en",
16 | },
17 | Embedding: []float32{0.1, 0.2, 0.3},
18 | Content: "hello world",
19 | },
20 | "2": {
21 | ID: "2",
22 | Metadata: map[string]string{
23 | "language": "de",
24 | },
25 | Embedding: []float32{0.2, 0.3, 0.4},
26 | Content: "hallo welt",
27 | },
28 | }
29 |
30 | tt := []struct {
31 | name string
32 | where map[string]string
33 | whereDocument map[string]string
34 | want []*Document
35 | }{
36 | {
37 | name: "meta match",
38 | where: map[string]string{"language": "de"},
39 | whereDocument: nil,
40 | want: []*Document{docs["2"]},
41 | },
42 | {
43 | name: "meta no match",
44 | where: map[string]string{"language": "fr"},
45 | whereDocument: nil,
46 | want: nil,
47 | },
48 | {
49 | name: "content contains all",
50 | where: nil,
51 | whereDocument: map[string]string{"$contains": "llo"},
52 | want: []*Document{docs["1"], docs["2"]},
53 | },
54 | {
55 | name: "content contains one",
56 | where: nil,
57 | whereDocument: map[string]string{"$contains": "hallo"},
58 | want: []*Document{docs["2"]},
59 | },
60 | {
61 | name: "content contains none",
62 | where: nil,
63 | whereDocument: map[string]string{"$contains": "bonjour"},
64 | want: nil,
65 | },
66 | {
67 | name: "content not_contains all",
68 | where: nil,
69 | whereDocument: map[string]string{"$not_contains": "bonjour"},
70 | want: []*Document{docs["1"], docs["2"]},
71 | },
72 | {
73 | name: "content not_contains one",
74 | where: nil,
75 | whereDocument: map[string]string{"$not_contains": "hello"},
76 | want: []*Document{docs["2"]},
77 | },
78 | {
79 | name: "meta and content match",
80 | where: map[string]string{"language": "de"},
81 | whereDocument: map[string]string{"$contains": "hallo"},
82 | want: []*Document{docs["2"]},
83 | },
84 | {
85 | name: "meta + contains + not_contains",
86 | where: map[string]string{"language": "de"},
87 | whereDocument: map[string]string{"$contains": "hallo", "$not_contains": "bonjour"},
88 | want: []*Document{docs["2"]},
89 | },
90 | }
91 |
92 | for _, tc := range tt {
93 | t.Run(tc.name, func(t *testing.T) {
94 | got := filterDocs(docs, tc.where, tc.whereDocument)
95 |
96 | if !reflect.DeepEqual(got, tc.want) {
97 | // If len is 2, the order might be different (function under test
98 | // is concurrent and order is not guaranteed).
99 | if len(got) == 2 && len(tc.want) == 2 {
100 | slices.Reverse(got)
101 | if reflect.DeepEqual(got, tc.want) {
102 | return
103 | }
104 | }
105 | t.Fatalf("got %v; want %v", got, tc.want)
106 | }
107 | })
108 | }
109 | }
110 |
111 | func TestNegative(t *testing.T) {
112 | ctx := context.Background()
113 | db := NewDB()
114 |
115 | c, err := db.CreateCollection("test", nil, nil)
116 | if err != nil {
117 | panic(err)
118 | }
119 |
120 | if err := c.AddDocuments(ctx, []Document{
121 | {
122 | ID: "1",
123 | Embedding: testEmbeddings["search_document: Village Builder Game"],
124 | },
125 | {
126 | ID: "2",
127 | Embedding: testEmbeddings["search_document: Town Craft Idle Game"],
128 | },
129 | {
130 | ID: "3",
131 | Embedding: testEmbeddings["search_document: Some Idle Game"],
132 | },
133 | }, 1); err != nil {
134 | t.Fatalf("failed to add documents: %v", err)
135 | }
136 |
137 | t.Run("NEGATIVE_MODE_SUBTRACT", func(t *testing.T) {
138 | res, err := c.QueryWithOptions(ctx, QueryOptions{
139 | QueryEmbedding: testEmbeddings["search_query: town"],
140 | NResults: c.Count(),
141 | Negative: NegativeQueryOptions{
142 | Embedding: testEmbeddings["search_query: idle"],
143 | Mode: NEGATIVE_MODE_SUBTRACT,
144 | },
145 | })
146 | if err != nil {
147 | panic(err)
148 | }
149 |
150 | for _, r := range res {
151 | t.Logf("%s: %v", r.ID, r.Similarity)
152 | }
153 |
154 | if len(res) != 3 {
155 | t.Fatalf("expected 3 results, got %d", len(res))
156 | }
157 |
158 | // Village Builder Game
159 | if res[0].ID != "1" {
160 | t.Fatalf("expected document with ID 1, got %s", res[0].ID)
161 | }
162 | // Town Craft Idle Game
163 | if res[1].ID != "2" {
164 | t.Fatalf("expected document with ID 2, got %s", res[1].ID)
165 | }
166 | // Some Idle Game
167 | if res[2].ID != "3" {
168 | t.Fatalf("expected document with ID 3, got %s", res[2].ID)
169 | }
170 | })
171 |
172 | t.Run("NEGATIVE_MODE_FILTER", func(t *testing.T) {
173 | res, err := c.QueryWithOptions(ctx, QueryOptions{
174 | QueryEmbedding: testEmbeddings["search_query: town"],
175 | NResults: c.Count(),
176 | Negative: NegativeQueryOptions{
177 | Embedding: testEmbeddings["search_query: idle"],
178 | Mode: NEGATIVE_MODE_FILTER,
179 | },
180 | })
181 | if err != nil {
182 | panic(err)
183 | }
184 |
185 | for _, r := range res {
186 | t.Logf("%s: %v", r.ID, r.Similarity)
187 | }
188 |
189 | if len(res) != 1 {
190 | t.Fatalf("expected 1 result, got %d", len(res))
191 | }
192 |
193 | // Village Builder Game
194 | if res[0].ID != "1" {
195 | t.Fatalf("expected document with ID 1, got %s", res[0].ID)
196 | }
197 | })
198 | }
199 |
--------------------------------------------------------------------------------
/vector.go:
--------------------------------------------------------------------------------
1 | package chromem
2 |
3 | import (
4 | "errors"
5 | "math"
6 | )
7 |
8 | const isNormalizedPrecisionTolerance = 1e-6
9 |
10 | // dotProduct calculates the dot product between two vectors.
11 | // It's the same as cosine similarity for normalized vectors.
12 | // The resulting value represents the similarity, so a higher value means the
13 | // vectors are more similar.
14 | func dotProduct(a, b []float32) (float32, error) {
15 | // The vectors must have the same length
16 | if len(a) != len(b) {
17 | return 0, errors.New("vectors must have the same length")
18 | }
19 |
20 | var dotProduct float32
21 | for i := range a {
22 | dotProduct += a[i] * b[i]
23 | }
24 |
25 | return dotProduct, nil
26 | }
27 |
28 | func normalizeVector(v []float32) []float32 {
29 | var norm float32
30 | for _, val := range v {
31 | norm += val * val
32 | }
33 | norm = float32(math.Sqrt(float64(norm)))
34 |
35 | res := make([]float32, len(v))
36 | for i, val := range v {
37 | res[i] = val / norm
38 | }
39 |
40 | return res
41 | }
42 |
43 | // subtractVector returns a new vector with the element-wise difference a - b; it does not modify its inputs.
44 | func subtractVector(a, b []float32) []float32 {
45 | res := make([]float32, len(a))
46 |
47 | for i := range a {
48 | res[i] = a[i] - b[i]
49 | }
50 |
51 | return res
52 | }
53 |
54 | // isNormalized checks if the vector is normalized.
55 | func isNormalized(v []float32) bool {
56 | var sqSum float64
57 | for _, val := range v {
58 | sqSum += float64(val) * float64(val)
59 | }
60 | magnitude := math.Sqrt(sqSum)
61 | return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance
62 | }
63 |
--------------------------------------------------------------------------------
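Since the dot product only equals the cosine similarity for unit-length vectors (cos θ = a·b / (‖a‖·‖b‖), and the denominators become 1 after normalization), here is a tiny in-package sketch combining the helpers above; the function name and input vectors are arbitrary illustration values:

```go
package chromem

import "fmt"

// cosineViaDotProductSketch normalizes two arbitrary vectors and shows that
// their dot product can then be read directly as the cosine similarity.
func cosineViaDotProductSketch() {
	a := normalizeVector([]float32{1, 2, 3})
	b := normalizeVector([]float32{2, 3, 4})

	sim, err := dotProduct(a, b) // == cosine similarity, because ‖a‖ = ‖b‖ = 1
	if err != nil {
		panic(err)
	}
	fmt.Printf("normalized: %t %t, similarity: %.4f\n", isNormalized(a), isNormalized(b), sim)
	// Expected: normalized: true true, similarity: ~0.9926
}
```
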
/wasm/main.go:
--------------------------------------------------------------------------------
1 | //go:build js
2 |
3 | package main
4 |
5 | import (
6 | "context"
7 | "errors"
8 | "syscall/js"
9 |
10 | "github.com/philippgille/chromem-go"
11 | )
12 |
13 | var c *chromem.Collection
14 |
15 | func main() {
16 | js.Global().Set("initDB", js.FuncOf(initDB))
17 | js.Global().Set("addDocument", js.FuncOf(addDocument))
18 | js.Global().Set("query", js.FuncOf(query))
19 |
20 | select {} // prevent main from exiting
21 | }
22 |
23 | // Exported function to initialize the database and collection.
24 | // Takes an OpenAI API key as argument.
25 | func initDB(this js.Value, args []js.Value) interface{} {
26 | if len(args) != 1 {
27 | return "expected 1 argument with the OpenAI API key"
28 | }
29 |
30 | openAIAPIKey := args[0].String()
31 | embeddingFunc := chromem.NewEmbeddingFuncOpenAI(openAIAPIKey, chromem.EmbeddingModelOpenAI3Small)
32 |
33 | db := chromem.NewDB()
34 | var err error
35 | c, err = db.CreateCollection("chromem", nil, embeddingFunc)
36 | if err != nil {
37 | return err.Error()
38 | }
39 |
40 | return nil
41 | }
42 |
43 | // Exported function to add documents to the collection.
44 | // Takes the document ID and content as arguments.
45 | func addDocument(this js.Value, args []js.Value) interface{} {
46 | ctx := context.Background()
47 |
48 | var id string
49 | var content string
50 | var err error
51 | if len(args) != 2 {
52 | err = errors.New("expected 2 arguments with the document ID and content")
53 | } else {
54 | id = args[0].String()
55 | content = args[1].String()
56 | }
57 |
58 | handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
59 | resolve := args[0]
60 | reject := args[1]
61 | go func() {
62 | if err != nil {
63 | handleErr(err, reject)
64 | return
65 | }
66 |
67 | err = c.AddDocument(ctx, chromem.Document{
68 | ID: id,
69 | Content: content,
70 | })
71 | if err != nil {
72 | handleErr(err, reject)
73 | return
74 | }
75 | resolve.Invoke()
76 | }()
77 | return nil
78 | })
79 |
80 | promiseConstructor := js.Global().Get("Promise")
81 | return promiseConstructor.New(handler)
82 | }
83 |
84 | // Exported function to query the collection
85 | // Takes the query string as argument; currently only the single most similar document is returned (see the TODO below).
86 | func query(this js.Value, args []js.Value) interface{} {
87 | ctx := context.Background()
88 |
89 | var q string
90 | var err error
91 | if len(args) != 1 {
92 | err = errors.New("expected 1 argument with the query string")
93 | } else {
94 | q = args[0].String()
95 | }
96 |
97 | handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
98 | resolve := args[0]
99 | reject := args[1]
100 | go func() {
101 | if err != nil {
102 | handleErr(err, reject)
103 | return
104 | }
105 |
106 | res, err := c.Query(ctx, q, 1, nil, nil)
107 | if err != nil {
108 | handleErr(err, reject)
109 | return
110 | }
111 |
112 | // Convert response to JS values
113 | // TODO: Return more than one result
114 | o := js.Global().Get("Object").New()
115 | o.Set("ID", res[0].ID)
116 | o.Set("Similarity", res[0].Similarity)
117 | o.Set("Content", res[0].Content)
118 |
119 | resolve.Invoke(o)
120 | }()
121 | return nil
122 | })
123 |
124 | promiseConstructor := js.Global().Get("Promise")
125 | return promiseConstructor.New(handler)
126 | }
127 |
128 | func handleErr(err error, reject js.Value) {
129 | errorConstructor := js.Global().Get("Error")
130 | errorObject := errorConstructor.New(err.Error())
131 | reject.Invoke(errorObject)
132 | }
133 |
--------------------------------------------------------------------------------
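The `TODO: Return more than one result` in `query` above could be addressed by converting the whole result slice into a JavaScript array before resolving the promise. A minimal sketch, assuming the same binding setup; the helper name is illustrative and not part of the binding:

```go
//go:build js

package main

import (
	"syscall/js"

	"github.com/philippgille/chromem-go"
)

// resultsToJS converts all query results (not just the first one) into a
// JavaScript array of plain objects, so the promise in query() could resolve
// with resolve.Invoke(resultsToJS(res)).
func resultsToJS(results []chromem.Result) js.Value {
	arr := js.Global().Get("Array").New(len(results))
	for i, res := range results {
		o := js.Global().Get("Object").New()
		o.Set("ID", res.ID)
		o.Set("Similarity", res.Similarity)
		o.Set("Content", res.Content)
		arr.SetIndex(i, o)
	}
	return arr
}
```

The caller would then also pass a result count to `c.Query` instead of the hardcoded `1`, mirroring how the other examples expose the number of results as a parameter.
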