├── .github
└── workflows
│ ├── go.yml
│ ├── lint.yaml
│ ├── mkdocs.yml
│ └── nightly.yml
├── .gitignore
├── .golangci.yml
├── CNAME
├── CODEOWNERS
├── LICENSE.md
├── Makefile
├── README.md
├── chroma.go
├── collection
├── collection.go
└── collection_test.go
├── docs
├── CNAME
├── docs
│ ├── assets
│ │ └── images
│ │ │ ├── favicon.png
│ │ │ └── logo.png
│ ├── auth.md
│ ├── client.md
│ ├── embeddings.md
│ ├── filtering.md
│ ├── index.md
│ ├── javascripts
│ │ └── gtag.js
│ ├── records.md
│ └── rerankers.md
└── mkdocs.yml
├── examples
└── v2
│ ├── basic
│ ├── README.md
│ ├── go.mod
│ ├── go.sum
│ └── main.go
│ ├── custom_embedding_function
│ ├── README.md
│ ├── go.mod
│ └── go.sum
│ ├── embedding_function_basic
│ ├── README.md
│ ├── go.mod
│ ├── go.sum
│ └── main.go
│ ├── reranking_function_basic
│ ├── README.md
│ ├── go.mod
│ └── go.sum
│ └── tenant_and_db
│ ├── README.md
│ ├── go.mod
│ ├── go.sum
│ └── main.go
├── gen_api.sh
├── gen_api_v3.sh
├── go.mod
├── go.sum
├── internal
└── http
│ ├── constants.go
│ ├── errors.go
│ ├── retry.go
│ ├── retry_test.go
│ ├── strategy.go
│ └── utils.go
├── metadata
├── metadata.go
└── metadata_test.go
├── openapi.yaml
├── patches
└── model_anyof.mustache
├── pkg
├── api
│ └── v2
│ │ ├── auth.go
│ │ ├── base.go
│ │ ├── client.go
│ │ ├── client_http.go
│ │ ├── client_http_integration_test.go
│ │ ├── client_http_test.go
│ │ ├── collection.go
│ │ ├── collection_http.go
│ │ ├── collection_http_integration_test.go
│ │ ├── collection_http_test.go
│ │ ├── constants.go
│ │ ├── document.go
│ │ ├── document_test.go
│ │ ├── ids.go
│ │ ├── metadata.go
│ │ ├── metadata_test.go
│ │ ├── record.go
│ │ ├── record_test.go
│ │ ├── reranking.go
│ │ ├── results.go
│ │ ├── results_test.go
│ │ ├── server.htpasswd
│ │ ├── utils.go
│ │ ├── v1-config.yaml
│ │ ├── where.go
│ │ ├── where_document.go
│ │ ├── where_document_test.go
│ │ └── where_test.go
├── commons
│ ├── cohere
│ │ ├── cohere_commons.go
│ │ └── cohere_commons_test.go
│ └── http
│ │ ├── constants.go
│ │ ├── errors.go
│ │ ├── retry.go
│ │ ├── retry_test.go
│ │ ├── strategy.go
│ │ └── utils.go
├── embeddings
│ ├── cloudflare
│ │ ├── cloudflare.go
│ │ ├── cloudflare_test.go
│ │ └── option.go
│ ├── cohere
│ │ ├── cohere.go
│ │ ├── cohere_test.go
│ │ └── option.go
│ ├── default_ef
│ │ ├── constants.go
│ │ ├── default_ef.go
│ │ ├── default_ef_test.go
│ │ ├── download_utils.go
│ │ ├── download_utils_test.go
│ │ └── tensors_utils.go
│ ├── distance_metric.go
│ ├── embedding.go
│ ├── embedding_test.go
│ ├── gemini
│ │ ├── gemini.go
│ │ ├── gemini_test.go
│ │ └── option.go
│ ├── hf
│ │ ├── hf.go
│ │ ├── hf_test.go
│ │ └── option.go
│ ├── jina
│ │ ├── jina.go
│ │ ├── jina_test.go
│ │ └── option.go
│ ├── mistral
│ │ ├── mistral.go
│ │ ├── mistral_test.go
│ │ └── option.go
│ ├── nomic
│ │ ├── nomic.go
│ │ ├── nomic_test.go
│ │ └── option.go
│ ├── ollama
│ │ ├── ollama.go
│ │ ├── ollama_test.go
│ │ └── option.go
│ ├── openai
│ │ ├── openai.go
│ │ ├── openai_test.go
│ │ └── options.go
│ ├── together
│ │ ├── option.go
│ │ ├── together.go
│ │ └── together_test.go
│ └── voyage
│ │ ├── option.go
│ │ ├── voyage.go
│ │ └── voyage_test.go
├── rerankings
│ ├── cohere
│ │ ├── cohere.go
│ │ ├── cohere_test.go
│ │ └── option.go
│ ├── hf
│ │ ├── huggingface.go
│ │ ├── huggingface_test.go
│ │ └── option.go
│ ├── jina
│ │ ├── jina.go
│ │ ├── jina_test.go
│ │ └── option.go
│ ├── reranking.go
│ └── reranking_test.go
└── tokenizers
│ └── libtokenizers
│ ├── tokenizer.go
│ └── tokenizers.h
├── scripts
└── chroma_server.sh
├── swagger
├── api_default.go
├── client.go
├── configuration.go
├── model_add_embedding.go
├── model_collection.go
├── model_create_collection.go
├── model_create_database.go
├── model_create_tenant.go
├── model_database.go
├── model_delete_embedding.go
├── model_embeddings_inner.go
├── model_get_embedding.go
├── model_get_result.go
├── model_http_validation_error.go
├── model_include_inner.go
├── model_location_inner.go
├── model_metadata.go
├── model_query_embedding.go
├── model_query_result.go
├── model_tenant.go
├── model_update_collection.go
├── model_update_embedding.go
├── model_validation_error.go
├── response.go
└── utils.go
├── test
├── chroma_api_test.go
├── chroma_client_test.go
└── test_utils.go
├── types
├── record.go
├── record_test.go
├── types.go
└── types_test.go
├── where
├── where.go
└── where_test.go
└── where_document
├── wheredoc.go
└── wheredoc_test.go
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: Go Lint
2 |
3 | on:
4 | pull_request: {}
5 |
6 | jobs:
7 | lint:
8 | runs-on: ubuntu-latest
9 | permissions:
10 | # Required: allow read access to the content for analysis.
11 | contents: read
12 | # Optional: allow read access to pull request. Use with `only-new-issues` option.
13 | pull-requests: read
14 | # Optional: Allow write access to checks to allow the action to annotate code in the PR.
15 | checks: write
16 | steps:
17 | - name: Checkout
18 | uses: actions/checkout@v4
19 | with:
20 | fetch-depth: 0
21 | - name: Set up Go
22 | uses: actions/setup-go@v4
23 | with:
24 | go-version-file: 'go.mod'
25 | - name: Run golangci-lint
26 | uses: golangci/golangci-lint-action@v8
27 | with:
28 | version: v2.1
--------------------------------------------------------------------------------
/.github/workflows/mkdocs.yml:
--------------------------------------------------------------------------------
1 | name: Deploy MkDocs Site
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v2
14 |
15 | - name: Set up Python
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: 3.x
19 |
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install mkdocs-material
24 |
25 | - name: Build the MkDocs site
26 | run: |
27 | cd docs
28 | mkdocs build -d ../site
29 | cp ../CNAME ../site
30 |
31 | - name: Deploy to GitHub Pages
32 | uses: peaceiris/actions-gh-pages@v3
33 | with:
34 | github_token: ${{ secrets.GITHUB_TOKEN }}
35 | publish_dir: ./site
--------------------------------------------------------------------------------
/.github/workflows/nightly.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build a golang project
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
3 |
4 | name: Nightly Test
5 |
6 | on:
7 | schedule:
8 | - cron: '0 0 * * *' # Run nightly at 00:00 UTC
9 | workflow_dispatch:
10 |
11 |
12 | jobs:
13 | build:
14 | name: Test API V2
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Log in to GitHub Container Registry
19 | uses: docker/login-action@v3
20 | with:
21 | registry: ghcr.io
22 | username: ${{ github.actor }}
23 | password: ${{ secrets.GITHUB_TOKEN }}
24 | - name: Set up Go
25 | uses: actions/setup-go@v4
26 | with:
27 | go-version-file: 'go.mod'
28 | - name: Lint
29 | uses: golangci/golangci-lint-action@v8
30 | with:
31 | version: v2.1
32 | - name: Build
33 | run: make build
34 | - name: Test
35 | run: make test-v2
36 | env:
37 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
38 | COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
39 | HF_API_KEY: ${{ secrets.HF_API_KEY }}
40 | CF_API_TOKEN: ${{ secrets.CF_API_TOKEN }}
41 | CF_ACCOUNT_ID: ${{ secrets.CF_ACCOUNT_ID }}
42 | CF_GATEWAY_ENDPOINT: ${{ secrets.CF_GATEWAY_ENDPOINT }}
43 | TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
44 | VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
45 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
46 | MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
47 | NOMIC_API_KEY: ${{ secrets.NOMIC_API_KEY }}
48 | CHROMA_VERSION: "latest"
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled Object files, Static and Dynamic libs (Shared Objects)
2 | *.o
3 | *.a
4 | *.so
5 |
6 | # Folders
7 | _obj
8 | _test
9 |
10 | # Architecture specific extensions/prefixes
11 | *.[568vq]
12 | [568vq].out
13 |
14 | *.cgo1.go
15 | *.cgo2.c
16 | _cgo_defun.c
17 | _cgo_gotypes.go
18 | _cgo_export.*
19 |
20 | _testmain.go
21 |
22 | *.exe
23 | *.test
24 | *.prof
25 | .env
26 | .idea/*
27 | *.iml
28 | /data
29 | /pkg/rerankings/hf/data
30 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | version: "2"
2 | run:
3 | modules-download-mode: readonly
4 | linters:
5 | enable:
6 | - dupword
7 | - ginkgolinter
8 | - gocritic
9 | - mirror
10 | settings:
11 | gocritic:
12 | disable-all: false
13 | staticcheck:
14 | checks:
15 | - all
16 | - ST1000
17 | - ST1001
18 | - ST1003
19 | dot-import-whitelist:
20 | - fmt
21 | exclusions:
22 | generated: lax
23 | presets:
24 | - comments
25 | - common-false-positives
26 | - legacy
27 | - std-error-handling
28 | rules:
29 | - linters:
30 | - ineffassign
31 | path: conversion\.go
32 | - linters:
33 | - staticcheck
34 | text: 'ST1003: should not use underscores in Go names; func (Convert_.*_To_.*|SetDefaults_)'
35 | - linters:
36 | - ginkgolinter
37 | text: use a function call in (Eventually|Consistently)
38 | paths:
39 | - third_party$
40 | - builtin$
41 | - examples$
42 | - ./swagger
43 | issues:
44 | max-same-issues: 0
45 | formatters:
46 | enable:
47 | - gci
48 | settings:
49 | gci:
50 | sections:
51 | - standard
52 | - default
53 | - prefix(github.com/amikos-tech/chroma-go)
54 | - blank
55 | - dot
56 | custom-order: true
57 | exclusions:
58 | generated: lax
59 | paths:
60 | - third_party$
61 | - builtin$
62 | - examples$
63 |
--------------------------------------------------------------------------------
/CNAME:
--------------------------------------------------------------------------------
1 | go-client.chromadb.dev
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @tazarov
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2023 Amikos Tech Ltd.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | generate:
2 | echo "This is deprecated. 0.2.x or later does not use generated client."
3 | sh ./gen_api_v3.sh
4 |
5 | build:
6 | go build -v ./...
7 |
8 | .PHONY: gotestsum-bin
9 | gotestsum-bin:
10 | go install gotest.tools/gotestsum@latest
11 |
12 | .PHONY: test
13 | test: gotestsum-bin
14 | gotestsum \
15 | --format short-verbose \
16 | --rerun-fails=5 \
17 | --packages="./..." \
18 | --junitfile unit-v1.xml \
19 | -- \
20 | -v \
21 | -tags=basic \
22 | -coverprofile=coverage-v1.out \
23 | -timeout=30m
24 |
25 | .PHONY: test-v2
26 | test-v2: gotestsum-bin
27 | gotestsum \
28 | --format short-verbose \
29 | --rerun-fails=5 \
30 | --packages="./..." \
31 | --junitfile unit-v2.xml \
32 | -- \
33 | -v \
34 | -tags=basicv2 \
35 | -coverprofile=coverage-v2.out \
36 | -timeout=30m
37 |
38 | .PHONY: test-rf
39 | test-rf: gotestsum-bin
40 | gotestsum \
41 | --format short-verbose \
42 | --rerun-fails=5 \
43 | --packages="./..." \
44 | --junitfile unit-rf.xml \
45 | -- \
46 | -v \
47 | -tags=rf \
48 | -coverprofile=coverage-rf.out \
49 | -timeout=30m
50 |
51 | .PHONY: test-ef
52 | test-ef: gotestsum-bin
53 | gotestsum \
54 | --format short-verbose \
55 | --rerun-fails=5 \
56 | --packages="./..." \
57 | --junitfile unit-ef.xml \
58 | -- \
59 | -v \
60 | -tags=ef \
61 | -coverprofile=coverage-ef.out \
62 | -timeout=30m
63 |
64 | .PHONY: lint
65 | lint:
66 | golangci-lint run
67 |
68 | .PHONY: lint-fix
69 | lint-fix:
70 | golangci-lint run --fix ./...
71 |
72 | .PHONY: clean-lint-cache
73 | clean-lint-cache:
74 | golangci-lint cache clean
75 |
76 |
77 | .PHONY: server
78 | server:
79 | sh ./scripts/chroma_server.sh
80 |
--------------------------------------------------------------------------------
/collection/collection.go:
--------------------------------------------------------------------------------
1 | package collection
2 |
import (
	"fmt"
	"math"

	"github.com/amikos-tech/chroma-go/metadata"
	"github.com/amikos-tech/chroma-go/types"
)
9 |
10 | type Builder struct {
11 | Tenant string
12 | Database string
13 | Name string
14 | Metadata map[string]interface{}
15 | CreateIfNotExist bool
16 | EmbeddingFunction types.EmbeddingFunction
17 | IDGenerator types.IDGenerator
18 | }
19 |
20 | type Option func(*Builder) error
21 |
22 | func WithEmbeddingFunction(embeddingFunction types.EmbeddingFunction) Option {
23 | return func(c *Builder) error {
24 | c.EmbeddingFunction = embeddingFunction
25 | return nil
26 | }
27 | }
28 |
29 | func WithIDGenerator(idGenerator types.IDGenerator) Option {
30 | return func(c *Builder) error {
31 | c.IDGenerator = idGenerator
32 | return nil
33 | }
34 | }
35 |
36 | func WithCreateIfNotExist(create bool) Option {
37 | return func(c *Builder) error {
38 | c.CreateIfNotExist = create
39 | return nil
40 | }
41 | }
42 |
43 | func WithHNSWDistanceFunction(distanceFunction types.DistanceFunction) Option {
44 | return func(b *Builder) error {
45 | if distanceFunction != types.L2 && distanceFunction != types.IP && distanceFunction != types.COSINE {
46 | return fmt.Errorf("invalid distance function, must be one of l2, ip, or cosine")
47 | }
48 | return WithMetadata(types.HNSWSpace, distanceFunction)(b)
49 | }
50 | }
51 |
52 | func WithHNSWBatchSize(batchSize int32) Option {
53 | return func(b *Builder) error {
54 | if batchSize < 1 {
55 | return fmt.Errorf("batch size must be greater than 0")
56 | }
57 | return WithMetadata(types.HNSWBatchSize, batchSize)(b)
58 | }
59 | }
60 |
61 | func WithHNSWSyncThreshold(syncThreshold int32) Option {
62 | return func(b *Builder) error {
63 | if syncThreshold < 1 {
64 | return fmt.Errorf("sync threshold must be greater than 0")
65 | }
66 | return WithMetadata(types.HNSWSyncThreshold, syncThreshold)(b)
67 | }
68 | }
69 |
70 | func WithHNSWM(m int32) Option {
71 | return func(b *Builder) error {
72 | if m < 1 {
73 | return fmt.Errorf("m must be greater than 0")
74 | }
75 | return WithMetadata(types.HNSWM, m)(b)
76 | }
77 | }
78 |
79 | func WithHNSWConstructionEf(efConstruction int32) Option {
80 | return func(b *Builder) error {
81 | if efConstruction < 1 {
82 | return fmt.Errorf("efConstruction must be greater than 0")
83 | }
84 | return WithMetadata(types.HNSWConstructionEF, efConstruction)(b)
85 | }
86 | }
87 |
88 | // WithMetadatas adds metadata to the collection. If the metadata key already exists, the value is overwritten.
89 | func WithMetadatas(metadata map[string]interface{}) Option {
90 | return func(b *Builder) error {
91 | if b.Metadata == nil {
92 | b.Metadata = make(map[string]interface{})
93 | }
94 | for k, v := range metadata {
95 | err := WithMetadata(k, v)(b)
96 | if err != nil {
97 | return err
98 | }
99 | }
100 | return nil
101 | }
102 | }
103 |
104 | func WithHNSWSearchEf(efSearch int32) Option {
105 | return func(b *Builder) error {
106 | if efSearch < 1 {
107 | return fmt.Errorf("efSearch must be greater than 0")
108 | }
109 | return WithMetadata(types.HNSWSearchEF, efSearch)(b)
110 | }
111 | }
112 |
113 | func WithHNSWNumThreads(numThreads int32) Option {
114 | return func(b *Builder) error {
115 | if numThreads < 1 {
116 | return fmt.Errorf("numThreads must be greater than 0")
117 | }
118 | return WithMetadata(types.HNSWNumThreads, numThreads)(b)
119 | }
120 | }
121 |
122 | func WithHNSWResizeFactor(resizeFactor float32) Option {
123 | return func(b *Builder) error {
124 | if resizeFactor < 0 {
125 | return fmt.Errorf("resizeFactor must be greater than or equal to 0")
126 | }
127 | return WithMetadata(types.HNSWResizeFactor, resizeFactor)(b)
128 | }
129 | }
130 |
131 | func WithMetadata(key string, value interface{}) Option {
132 | return func(b *Builder) error {
133 | if b.Metadata == nil {
134 | b.Metadata = make(map[string]interface{})
135 | }
136 | err := metadata.WithMetadata(key, value)(metadata.NewMetadataBuilder(&b.Metadata))
137 | if err != nil {
138 | return err
139 | }
140 | return nil
141 | }
142 | }
143 |
144 | func WithTenant(tenant string) Option {
145 | return func(c *Builder) error {
146 | if tenant == "" {
147 | return fmt.Errorf("tenant cannot be empty")
148 | }
149 | c.Tenant = tenant
150 | return nil
151 | }
152 | }
153 |
154 | func WithDatabase(database string) Option {
155 | return func(c *Builder) error {
156 | if database == "" {
157 | return fmt.Errorf("database cannot be empty")
158 | }
159 | c.Database = database
160 | return nil
161 | }
162 | }
163 |
--------------------------------------------------------------------------------
/collection/collection_test.go:
--------------------------------------------------------------------------------
1 | //go:build basic
2 |
3 | package collection
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 |
10 | "github.com/amikos-tech/chroma-go/types"
11 | )
12 |
13 | func TestCollectionBuilder(t *testing.T) {
14 |
15 | t.Run("With Embedding Function", func(t *testing.T) {
16 | b := &Builder{}
17 | err := WithEmbeddingFunction(nil)(b)
18 | require.NoError(t, err, "Unexpected error: %v", err)
19 | require.Nil(t, b.EmbeddingFunction)
20 | })
21 | t.Run("With ID Generator", func(t *testing.T) {
22 | var generator = types.NewULIDGenerator()
23 | b := &Builder{}
24 | err := WithIDGenerator(generator)(b)
25 | require.NoError(t, err, "Unexpected error: %v", err)
26 | require.Equal(t, generator, b.IDGenerator)
27 | })
28 | t.Run("With Create If Not Exist", func(t *testing.T) {
29 | b := &Builder{}
30 | err := WithCreateIfNotExist(true)(b)
31 | require.NoError(t, err, "Unexpected error: %v", err)
32 | require.True(t, b.CreateIfNotExist)
33 | })
34 |
35 | t.Run("With HNSW Distance Function", func(t *testing.T) {
36 | b := &Builder{}
37 | err := WithHNSWDistanceFunction(types.L2)(b)
38 | require.NoError(t, err, "Unexpected error: %v", err)
39 | require.NoError(t, err, "Unexpected error: %v", err)
40 | require.Equal(t, types.L2, b.Metadata[types.HNSWSpace])
41 | })
42 |
43 | t.Run("With Metadata", func(t *testing.T) {
44 | b := &Builder{}
45 | err := WithMetadata("testKey", "testValue")(b)
46 | require.NoError(t, err, "Unexpected error: %v", err)
47 | require.Equal(t, "testValue", b.Metadata["testKey"])
48 | })
49 |
50 | t.Run("With Metadatas", func(t *testing.T) {
51 | b := &Builder{}
52 | err := WithMetadatas(map[string]interface{}{"testKey": "testValue"})(b)
53 | require.NoError(t, err, "Unexpected error: %v", err)
54 | require.Equal(t, "testValue", b.Metadata["testKey"])
55 | })
56 |
57 | t.Run("With Metadatas for existing no override", func(t *testing.T) {
58 | b := &Builder{}
59 | b.Metadata = map[string]interface{}{"existingKey": "existingValue"}
60 | err := WithMetadatas(map[string]interface{}{"testKey": "testValue"})(b)
61 | require.NoError(t, err, "Unexpected error: %v", err)
62 | require.Contains(t, b.Metadata, "testKey")
63 | require.Equal(t, "testValue", b.Metadata["testKey"])
64 | require.Contains(t, b.Metadata, "existingKey")
65 | require.Equal(t, "existingValue", b.Metadata["existingKey"])
66 | })
67 |
68 | t.Run("With Metadatas for existing with override", func(t *testing.T) {
69 | b := &Builder{}
70 | b.Metadata = map[string]interface{}{"existingKey": "existingValue"}
71 | err := WithMetadatas(map[string]interface{}{"existingKey": "newValue"})(b)
72 | require.NoError(t, err, "Unexpected error: %v", err)
73 | require.Contains(t, b.Metadata, "existingKey")
74 | require.Equal(t, "newValue", b.Metadata["existingKey"])
75 | })
76 |
77 | t.Run("With Metadatas for invalid type", func(t *testing.T) {
78 | b := &Builder{}
79 | err := WithMetadatas(map[string]interface{}{"testKey": map[string]interface{}{"invalid": "value"}})(b)
80 | require.Error(t, err)
81 | })
82 |
83 | t.Run("With HNSW Batch Size", func(t *testing.T) {
84 | b := &Builder{}
85 | err := WithHNSWBatchSize(10)(b)
86 | require.NoError(t, err, "Unexpected error: %v", err)
87 | require.Equal(t, int32(10), b.Metadata[types.HNSWBatchSize])
88 | })
89 |
90 | t.Run("With HNSW Sync Threshold", func(t *testing.T) {
91 | b := &Builder{}
92 | err := WithHNSWSyncThreshold(10)(b)
93 | require.NoError(t, err, "Unexpected error: %v", err)
94 | require.Equal(t, int32(10), b.Metadata[types.HNSWSyncThreshold])
95 | })
96 |
97 | t.Run("With HNSWM", func(t *testing.T) {
98 | b := &Builder{}
99 | err := WithHNSWM(10)(b)
100 | require.NoError(t, err, "Unexpected error: %v", err)
101 | require.Equal(t, int32(10), b.Metadata[types.HNSWM])
102 | })
103 |
104 | t.Run("With HNSW Construction Ef", func(t *testing.T) {
105 | b := &Builder{}
106 | err := WithHNSWConstructionEf(10)(b)
107 | require.NoError(t, err, "Unexpected error: %v", err)
108 | require.Equal(t, int32(10), b.Metadata[types.HNSWConstructionEF])
109 | })
110 |
111 | t.Run("With HNSW Search Ef", func(t *testing.T) {
112 | b := &Builder{}
113 | err := WithHNSWSearchEf(10)(b)
114 | require.NoError(t, err, "Unexpected error: %v", err)
115 | require.Equal(t, int32(10), b.Metadata[types.HNSWSearchEF])
116 | })
117 | }
118 |
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | go-client.chromadb.dev
--------------------------------------------------------------------------------
/docs/docs/assets/images/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amikos-tech/chroma-go/6bff4f299fe5cf068434eaf8cf2c4c6738639309/docs/docs/assets/images/favicon.png
--------------------------------------------------------------------------------
/docs/docs/assets/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amikos-tech/chroma-go/6bff4f299fe5cf068434eaf8cf2c4c6738639309/docs/docs/assets/images/logo.png
--------------------------------------------------------------------------------
/docs/docs/auth.md:
--------------------------------------------------------------------------------
1 | # Authentication
2 |
3 | There are four ways to authenticate with Chroma:
4 |
5 | - Manual Header authentication - this approach requires familiarity with the server-side auth configuration, because you generate and insert the necessary headers yourself.
6 | - Chroma Basic Auth mechanism
7 | - Chroma Token Auth mechanism with Bearer Authorization header
8 | - Chroma Token Auth mechanism with X-Chroma-Token header
9 |
10 | ### Manual Header Authentication
11 |
12 | ```go
13 | package main
14 |
15 | import (
16 | "context"
17 | "log"
18 | chroma "github.com/amikos-tech/chroma-go"
19 | )
20 |
21 | func main() {
22 | var defaultHeaders = map[string]string{"Authorization": "Bearer my-custom-token"}
23 | clientWithTenant, err := chroma.NewClient(chroma.WithBasePath("http://api.trychroma.com/v1/"), chroma.WithDefaultHeaders(defaultHeaders))
24 | if err != nil {
25 | log.Fatalf("Error creating client: %s \n", err)
26 | }
27 | _, err = clientWithTenant.Heartbeat(context.TODO())
28 | if err != nil {
29 | log.Fatalf("Error calling heartbeat: %s \n", err)
30 | }
31 | }
32 | ```
33 |
34 | ### Chroma Basic Auth mechanism
35 |
36 | ```go
37 | package main
38 |
39 | import (
40 | "context"
41 | "log"
42 | chroma "github.com/amikos-tech/chroma-go"
43 | "github.com/amikos-tech/chroma-go/types"
44 | )
45 |
46 | func main() {
47 | client, err := chroma.NewClient(
48 | chroma.WithBasePath("http://api.trychroma.com/v1/"),
49 | chroma.WithAuth(types.NewBasicAuthCredentialsProvider("myUser", "myPassword")),
50 | )
51 | if err != nil {
52 | log.Fatalf("Error creating client: %s \n", err)
53 | }
54 | _, err = client.Heartbeat(context.TODO())
55 | if err != nil {
56 | log.Fatalf("Error calling heartbeat: %s \n", err)
57 | }
58 | }
59 | ```
60 |
61 | ### Chroma Token Auth mechanism with Bearer Authorization header
62 |
63 | ```go
64 | package main
65 |
66 | import (
67 | "context"
68 | "log"
69 | chroma "github.com/amikos-tech/chroma-go"
70 | "github.com/amikos-tech/chroma-go/types"
71 | )
72 |
73 | func main() {
74 | client, err := chroma.NewClient(
75 | chroma.WithBasePath("http://api.trychroma.com/v1/"),
76 | chroma.WithAuth(types.NewTokenAuthCredentialsProvider("my-auth-token", types.AuthorizationTokenHeader)),
77 | )
78 | if err != nil {
79 | log.Fatalf("Error creating client: %s \n", err)
80 | }
81 | _, err = client.Heartbeat(context.TODO())
82 | if err != nil {
83 | log.Fatalf("Error calling heartbeat: %s \n", err)
84 | }
85 | }
86 | ```
87 |
88 | ### Chroma Token Auth mechanism with X-Chroma-Token header
89 |
90 | ```go
91 | package main
92 |
93 | import (
94 | "context"
95 | "log"
96 | chroma "github.com/amikos-tech/chroma-go"
97 | "github.com/amikos-tech/chroma-go/types"
98 | )
99 |
100 | func main() {
101 | client, err := chroma.NewClient(
102 | chroma.WithBasePath("http://api.trychroma.com/v1/"),
103 | chroma.WithAuth(types.NewTokenAuthCredentialsProvider("my-auth-token", types.XChromaTokenHeader)),
104 | )
105 | if err != nil {
106 | log.Fatalf("Error creating client: %s \n", err)
107 | }
108 | _, err = client.Heartbeat(context.TODO())
109 | if err != nil {
110 | log.Fatalf("Error calling heartbeat: %s \n", err)
111 | }
112 | }
113 | ```
114 |
115 |
--------------------------------------------------------------------------------
/docs/docs/client.md:
--------------------------------------------------------------------------------
1 | # Chroma Client
2 |
3 | Options:
4 |
5 | | Options | Usage | Description | Value | Required |
6 | |-------------------|-----------------------------------------|-----------------------------------------------------------------------------------------|----------------------------|---------------------------------------|
7 | | basePath | `WithBasePath("http://localhost:8000")` | The Chroma server base API. | Non-empty valid URL string | No (default: `http://localhost:8000`) |
8 | | Tenant | `WithTenant("tenant")` | The default tenant to use. | `string` | No (default: `default_tenant`) |
9 | | Database | `WithDatabase("database")` | The default database to use. | `string` | No (default: `default_database`) |
10 | | Debug | `WithDebug(true/false)` | Enable debug mode. | `bool` | No (default: `false`) |
11 | | Default Headers | `WithDefaultHeaders(map[string]string)` | Set default headers for the client. | `map[string]string` | No (default: `nil`) |
12 | | SSL Cert | `WithSSLCert("path/to/cert.pem")` | Set the path to the SSL certificate. | valid path to SSL cert. | No (default: Not Set) |
13 | | Insecure | `WithInsecure()` | Disable SSL certificate verification | | No (default: Not Set) |
14 | | Custom HttpClient | `WithHTTPClient(*http.Client)` | Set a custom HTTP client. If this is set, the SSL Cert and Insecure options are ignored. | `*http.Client` | No (default: Default HTTPClient) |
15 |
16 | !!! note "Tenant and Database"
17 |
18 | The tenant and database are only supported for Chroma API version `0.4.15+`.
19 |
20 | Creating a new client:
21 |
22 | ```go
23 | package main
24 |
25 | import (
26 | "log"
27 |
28 | chroma "github.com/amikos-tech/chroma-go"
29 | )
30 |
31 | func main() {
32 | client, err := chroma.NewClient(
33 | chroma.WithBasePath("http://localhost:8000"),
34 | chroma.WithTenant("my_tenant"),
35 | chroma.WithDatabase("my_db"),
36 | chroma.WithDebug(true),
37 | chroma.WithDefaultHeaders(map[string]string{"Authorization": "Bearer my token"}),
38 | chroma.WithSSLCert("path/to/cert.pem"),
39 | )
40 | if err != nil {
41 | // Stop here: client is nil when NewClient returns an error.
42 | log.Fatalf("Failed to create client: %v", err)
43 | }
44 | // do something with client
45 |
46 | // Close the client to release any resources such as local embedding functions
47 | err = client.Close()
48 | if err != nil {
49 | log.Printf("Failed to close client: %v", err)
50 | }
51 | }
52 | ```
53 |
--------------------------------------------------------------------------------
/docs/docs/filtering.md:
--------------------------------------------------------------------------------
1 | # Filtering
2 |
3 | Chroma offers two types of filters:
4 |
5 | - Metadata - filtering based on metadata attribute values
6 | - Documents - filtering based on document content (contains or not contains)
7 |
8 | ## Metadata
9 |
10 | * TODO - Add builder example
11 | * TODO - Describe all available operations
12 |
13 | ```go
14 | package main
15 |
16 | import (
17 | "context"
18 | "fmt"
19 | chroma "github.com/amikos-tech/chroma-go"
20 | "github.com/amikos-tech/chroma-go/pkg/embeddings/openai"
21 | "github.com/amikos-tech/chroma-go/types"
22 | "github.com/amikos-tech/chroma-go/where"
23 | )
24 |
25 | func main() {
26 | embeddingF, err := openai.NewOpenAIEmbeddingFunction("sk-xxxx")
27 | if err != nil {
28 | fmt.Println(err)
29 | return
30 | }
31 | client, err := chroma.NewClient() // connects to localhost:8000
32 | if err != nil {
33 | fmt.Println(err)
34 | return
35 | }
36 | collection, err := client.GetCollection(context.TODO(), "my-collection", embeddingF)
37 | if err != nil {
38 | fmt.Println(err)
39 | return
40 | }
41 | // Filter by metadata
42 |
43 | result, err := collection.GetWithOptions(
44 | context.Background(),
45 | types.WithWhere(
46 | where.Or(
47 | where.Eq("category", "Chroma"),
48 | where.Eq("type", "vector database"),
49 | ),
50 | ),
51 | )
52 | if err != nil {
53 | fmt.Println(err)
54 | return
55 | }
56 | // do something with result
57 | fmt.Println(result)
58 | }
59 |
60 | ```
61 |
62 | ## Document
63 |
64 | * TODO - Add builder example
65 | * TODO - Describe all available operations
66 |
67 | ```go
68 | package main
69 |
70 | import (
71 | "context"
72 | "fmt"
73 | chroma "github.com/amikos-tech/chroma-go"
74 | "github.com/amikos-tech/chroma-go/pkg/embeddings/openai"
75 | "github.com/amikos-tech/chroma-go/types"
76 | "github.com/amikos-tech/chroma-go/where_document"
77 | )
78 |
79 | func main() {
80 | embeddingF, err := openai.NewOpenAIEmbeddingFunction("sk-xxxx")
81 | if err != nil {
82 | fmt.Println(err)
83 | return
84 | }
85 | client, err := chroma.NewClient(chroma.WithBasePath("http://localhost:8000"))
86 | if err != nil {
87 | fmt.Println(err)
88 | return
89 | }
90 | collection, err := client.GetCollection(context.TODO(), "my-collection", embeddingF)
91 | if err != nil {
92 | fmt.Println(err)
93 | return
94 | }
95 | // Filter by document content
96 |
97 | result, err := collection.GetWithOptions(
98 | context.Background(),
99 | types.WithWhereDocument(
100 | wheredoc.Or(
101 | wheredoc.Contains("Vector database"),
102 | wheredoc.Contains("Chroma"),
103 | ),
104 | ),
105 | )
106 |
107 | if err != nil {
108 | fmt.Println(err)
109 | return
110 | }
111 | // do something with result
112 | fmt.Println(result)
113 | }
114 | ```
--------------------------------------------------------------------------------
/docs/docs/javascripts/gtag.js:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/docs/docs/records.md:
--------------------------------------------------------------------------------
1 | # Records
2 |
3 | Records are a mechanism that allows you to manage Chroma documents as a cohesive unit. This has several advantages over
4 | the traditional approach of managing documents, ids, embeddings, and metadata separately.
5 |
6 | Two concepts are important to keep in mind here:
7 |
8 | - Record - corresponds to a single document in Chroma which includes id, embedding, metadata, the document or URI
9 | - RecordSet - a single unit of work to insert, upsert, update or delete records.
10 |
11 |
12 | ## Record
13 |
14 | A Record contains the following fields:
15 |
16 | - ID (string)
17 | - Document (string) - optional
18 | - Metadata (map[string]interface{}) - optional
19 | - Embedding ([]float32 or []int32, wrapped in Embedding struct)
20 | - URI (string) - optional
21 |
22 | Here's the `Record` type:
23 |
24 | ```go
25 | package types
26 |
// Record is a single Chroma document together with the data stored
// alongside it.
type Record struct {
	ID        string                 // record ID
	Embedding Embedding              // the document's embedding ([]float32 or []int32, wrapped in Embedding)
	Metadata  map[string]interface{} // optional key/value metadata
	Document  string                 // optional document text
	URI       string                 // optional URI
	err       error                  // indicating whether the record is valid
}
35 | ```
36 |
37 | ## RecordSet
38 |
39 | A record set is a cohesive unit of work, allowing the user to add, upsert, update, or delete records.
40 |
41 |
42 | !!! note "Operation support"
43 |
44 | Currently the record set only supports the add operation.
45 |
46 | ```go
rs, err := types.NewRecordSet(
	types.WithEmbeddingFunction(types.NewConsistentHashEmbeddingFunction()),
	types.WithIDGenerator(types.NewULIDGenerator()),
)
if err != nil {
	log.Fatalf("Error creating record set: %s", err)
}
// you can loop here to add multiple records
rs.WithRecord(types.WithDocument("Document 1 content"), types.WithMetadata("key1", "value1"))
rs.WithRecord(types.WithDocument("Document 2 content"), types.WithMetadata("key2", "value2"))
// BuildAndValidate returns the built records and an error; always check it.
records, err := rs.BuildAndValidate(context.Background())
if err != nil {
	log.Fatalf("Error building record set: %s", err)
}
_ = records // use the validated records here, e.g. add them to a collection
59 | ```
--------------------------------------------------------------------------------
/docs/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: ChromaDB Go Client
2 | site_url: https://go-client.chromadb.dev
3 | repo_url: https://github.com/amikos-tech/chroma-go
4 | copyright: "Amikos Tech LTD, 2024 (core ChromaDB contributors)"
5 | theme:
6 | name: material
7 | palette:
8 | primary: black
9 | logo: assets/images/logo.png
10 | favicon: assets/images/favicon.png
11 | font:
12 | text: Roboto
13 | code: Roboto Mono
14 | features:
15 | - content.code.annotate
16 | - content.code.copy
17 | - navigation.instant
18 | - navigation.instant.progress
19 | - navigation.tracking
20 | - navigation.indexes
21 | extra:
22 | homepage: https://www.trychroma.com
23 | social:
24 | - icon: fontawesome/brands/github
25 | link: https://github.com/chroma-core/chroma
26 | name: Chroma on GitHub
27 | - icon: fontawesome/brands/twitter
28 | link: https://twitter.com/trychroma
29 | name: Chroma on Twitter
30 | - icon: fontawesome/brands/github
31 | link: https://github.com/amikos-tech
32 | name: Amikos on GitHub
33 | - icon: fontawesome/brands/medium
34 | link: https://medium.com/@amikostech
35 | name: Amikos on Medium
36 | analytics:
37 | provider: google
38 | property: G-NNN722BJKE
39 | consent:
40 | title: Cookie consent
41 | description: >-
42 | We use cookies for analytics purposes. By continuing to use this website, you agree to their use.
43 | extra_javascript:
44 | - javascripts/gtag.js
45 | markdown_extensions:
46 | - abbr
47 | - admonition
48 | - attr_list
49 | - md_in_html
50 | - markdown.extensions.extra
51 | - toc:
52 | permalink: true
53 | title: On this page
54 | toc_depth: 3
55 | - tables
56 | - pymdownx.highlight:
57 | anchor_linenums: true
58 | line_spans: __span
59 | pygments_lang_class: true
60 | - pymdownx.inlinehilite
61 | - pymdownx.snippets:
62 | base_path: assets/snippets/
63 | - pymdownx.superfences
64 | - pymdownx.tabbed:
65 | alternate_style: true
66 | - pymdownx.tasklist:
67 | custom_checkbox: true
68 | plugins:
69 | - tags
70 | - search
--------------------------------------------------------------------------------
/examples/v2/basic/README.md:
--------------------------------------------------------------------------------
1 | # Basic Usage
2 |
3 | This example is a getting-started guide showing how to use the Go client with the new Chroma API v2 and Chroma v1.0.x.
--------------------------------------------------------------------------------
/examples/v2/basic/go.mod:
--------------------------------------------------------------------------------
1 | module main
2 |
3 | go 1.24
4 |
5 | replace github.com/amikos-tech/chroma-go => ../../../
6 |
7 | require github.com/amikos-tech/chroma-go v0.2.0
8 |
9 | require (
10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
11 | github.com/google/uuid v1.6.0 // indirect
12 | github.com/oklog/ulid v1.3.1 // indirect
13 | github.com/pkg/errors v0.9.1 // indirect
14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/examples/v2/basic/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 |
8 | chroma "github.com/amikos-tech/chroma-go/pkg/api/v2"
9 | )
10 |
// main walks through the basic lifecycle of a collection on a local Chroma
// server: get-or-create a collection, add two documents, count them, run a
// text query and finally delete the documents again.
func main() {
	ctx := context.Background()

	// Create a new Chroma client (with the WithDebug option).
	client, err := chroma.NewHTTPClient(chroma.WithDebug())
	if err != nil {
		log.Fatalf("Error creating client: %s \n", err)
	}
	// Closing the client releases resources it holds, such as local
	// embedding functions.
	defer func() {
		if err = client.Close(); err != nil {
			log.Fatalf("Error closing client: %s \n", err)
		}
	}()

	// No embedding function is passed, so the collection uses the default
	// embedding function.
	metadata := chroma.NewMetadata(
		chroma.NewStringAttribute("str", "hello2"),
		chroma.NewIntAttribute("int", 1),
		chroma.NewFloatAttribute("float", 1.1),
	)
	col, err := client.GetOrCreateCollection(ctx, "col1", chroma.WithCollectionMetadataCreate(metadata))
	if err != nil {
		log.Fatalf("Error creating collection: %s \n", err)
	}

	// IDs are given explicitly; chroma.WithIDGenerator(chroma.NewULIDGenerator())
	// could be used to generate them instead.
	err = col.Add(ctx,
		chroma.WithIDs("1", "2"),
		chroma.WithTexts("hello world", "goodbye world"),
		chroma.WithMetadatas(
			chroma.NewDocumentMetadata(chroma.NewIntAttribute("int", 1)),
			chroma.NewDocumentMetadata(chroma.NewStringAttribute("str1", "hello2")),
		),
	)
	if err != nil {
		log.Fatalf("Error adding collection: %s \n", err)
	}

	count, err := col.Count(ctx)
	if err != nil {
		log.Fatalf("Error counting collection: %s \n", err)
	}
	fmt.Printf("Count collection: %d\n", count)

	qr, err := col.Query(ctx,
		chroma.WithQueryTexts("say hello"),
		chroma.WithIncludeQuery(chroma.IncludeDocuments, chroma.IncludeMetadatas),
	)
	if err != nil {
		log.Fatalf("Error querying collection: %s \n", err)
	}
	// First document of the first result group (one group per query text).
	firstDoc := qr.GetDocumentsGroups()[0][0]
	fmt.Printf("Query result: %v\n", firstDoc)

	if err = col.Delete(ctx, chroma.WithIDsDelete("1", "2")); err != nil {
		log.Fatalf("Error deleting collection: %s \n", err)
	}
}
76 |
--------------------------------------------------------------------------------
/examples/v2/custom_embedding_function/README.md:
--------------------------------------------------------------------------------
1 | # Custom Embedding Function Example
2 |
3 | > [!NOTE]
4 | > Coming soon...
--------------------------------------------------------------------------------
/examples/v2/custom_embedding_function/go.mod:
--------------------------------------------------------------------------------
1 | module main
2 |
3 | go 1.24.1
4 |
5 | replace github.com/amikos-tech/chroma-go => ../../../
6 |
7 | require github.com/amikos-tech/chroma-go v0.2.0
8 |
9 | require (
10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
11 | github.com/google/uuid v1.6.0 // indirect
12 | github.com/oklog/ulid v1.3.1 // indirect
13 | github.com/pkg/errors v0.9.1 // indirect
14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/examples/v2/custom_embedding_function/go.sum:
--------------------------------------------------------------------------------
1 | github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
2 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
3 | github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
4 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
5 | github.com/yalue/onnxruntime_go v1.19.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
6 |
--------------------------------------------------------------------------------
/examples/v2/embedding_function_basic/README.md:
--------------------------------------------------------------------------------
1 | # Embedding Function Usage Example
2 |
3 | This example demonstrates how to use embedding functions.
4 |
5 | ## Running
6 |
7 | ```bash
8 | go mod download
9 | OPENAI_API_KEY=sk-xxxxx go run main.go
10 | ```
11 |
--------------------------------------------------------------------------------
/examples/v2/embedding_function_basic/go.mod:
--------------------------------------------------------------------------------
1 | module main
2 |
3 | go 1.24.1
4 |
5 | replace github.com/amikos-tech/chroma-go => ../../../
6 |
7 | require github.com/amikos-tech/chroma-go v0.2.0
8 |
9 | require (
10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
11 | github.com/google/uuid v1.6.0 // indirect
12 | github.com/oklog/ulid v1.3.1 // indirect
13 | github.com/pkg/errors v0.9.1 // indirect
14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/examples/v2/embedding_function_basic/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "log"
7 | "os"
8 |
9 | chroma "github.com/amikos-tech/chroma-go/pkg/api/v2"
10 | openai "github.com/amikos-tech/chroma-go/pkg/embeddings/openai"
11 | )
12 |
// main shows how to attach an OpenAI embedding function to a collection:
// documents added to it and query texts are embedded with the
// openai.TextEmbedding3Small model instead of the default embedding function.
// OPENAI_API_KEY must be set in the environment (see README).
func main() {
	// Fail fast with a clear message instead of a less obvious error later on.
	apiKey := os.Getenv("OPENAI_API_KEY")
	if apiKey == "" {
		log.Fatalf("OPENAI_API_KEY environment variable is not set \n")
	}

	// Create a new Chroma client
	client, err := chroma.NewHTTPClient(chroma.WithDebug())
	if err != nil {
		log.Fatalf("Error creating client: %s \n", err)
	}
	// Close the client to release any resources such as local embedding functions
	defer func() {
		err = client.Close()
		if err != nil {
			log.Fatalf("Error closing client: %s \n", err)
		}
	}()

	ef, err := openai.NewOpenAIEmbeddingFunction(apiKey, openai.WithModel(openai.TextEmbedding3Small))
	if err != nil {
		log.Fatalf("Error creating embedding function: %s \n", err)
	}

	// Create a new collection with options. The OpenAI embedding function is
	// passed explicitly via WithEmbeddingFunctionCreate, so it is used instead
	// of the default embedding function.
	col, err := client.GetOrCreateCollection(context.Background(), "openai-embedding-function-test",
		chroma.WithCollectionMetadataCreate(
			chroma.NewMetadata(
				chroma.NewStringAttribute("str", "hello2"),
				chroma.NewIntAttribute("int", 1),
				chroma.NewFloatAttribute("float", 1.1),
			),
		),
		chroma.WithEmbeddingFunctionCreate(ef),
	)
	if err != nil {
		log.Fatalf("Error creating collection: %s \n", err)
	}

	err = col.Add(context.Background(),
		//chroma.WithIDGenerator(chroma.NewULIDGenerator()),
		chroma.WithIDs("1", "2"),
		chroma.WithTexts("hello world", "goodbye world"),
		chroma.WithMetadatas(
			chroma.NewDocumentMetadata(chroma.NewIntAttribute("int", 1)),
			chroma.NewDocumentMetadata(chroma.NewStringAttribute("str1", "hello2")),
		))
	if err != nil {
		log.Fatalf("Error adding collection: %s \n", err)
	}

	count, err := col.Count(context.Background())
	if err != nil {
		log.Fatalf("Error counting collection: %s \n", err)
	}
	fmt.Printf("Count collection: %d\n", count)

	qr, err := col.Query(context.Background(), chroma.WithQueryTexts("say hello"))
	if err != nil {
		log.Fatalf("Error querying collection: %s \n", err)
	}
	// Guard the [0][0] access: an empty result would otherwise panic.
	docGroups := qr.GetDocumentsGroups()
	if len(docGroups) == 0 || len(docGroups[0]) == 0 {
		log.Fatalf("Query returned no documents \n")
	}
	fmt.Printf("Query result: %v\n", docGroups[0][0])

	err = col.Delete(context.Background(), chroma.WithIDsDelete("1", "2"))
	if err != nil {
		log.Fatalf("Error deleting collection: %s \n", err)
	}
}
82 |
--------------------------------------------------------------------------------
/examples/v2/reranking_function_basic/README.md:
--------------------------------------------------------------------------------
1 | # Reranking Function Usage Examples
2 |
3 | > [!NOTE]
4 | > Coming soon...
--------------------------------------------------------------------------------
/examples/v2/reranking_function_basic/go.mod:
--------------------------------------------------------------------------------
1 | module main
2 |
3 | go 1.24.1
4 |
5 | replace github.com/amikos-tech/chroma-go => ../../../
6 |
7 | require github.com/amikos-tech/chroma-go v0.2.0
8 |
9 | require (
10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
11 | github.com/google/uuid v1.6.0 // indirect
12 | github.com/oklog/ulid v1.3.1 // indirect
13 | github.com/pkg/errors v0.9.1 // indirect
14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/examples/v2/reranking_function_basic/go.sum:
--------------------------------------------------------------------------------
1 | github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
2 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
3 | github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U=
4 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
5 | github.com/yalue/onnxruntime_go v1.19.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
6 |
--------------------------------------------------------------------------------
/examples/v2/tenant_and_db/README.md:
--------------------------------------------------------------------------------
1 | # Tenant and DB management
2 |
3 | This example shows how to manage tenants and databases.
--------------------------------------------------------------------------------
/examples/v2/tenant_and_db/go.mod:
--------------------------------------------------------------------------------
1 | module main
2 |
3 | go 1.24.1
4 |
5 | replace github.com/amikos-tech/chroma-go => ../../../
6 |
7 | require github.com/amikos-tech/chroma-go v0.2.0
8 |
9 | require (
10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
11 | github.com/google/uuid v1.6.0 // indirect
12 | github.com/oklog/ulid v1.3.1 // indirect
13 | github.com/pkg/errors v0.9.1 // indirect
14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect
15 | )
16 |
--------------------------------------------------------------------------------
/examples/v2/tenant_and_db/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | chroma "github.com/amikos-tech/chroma-go/pkg/api/v2"
7 | "log"
8 | "math/rand"
9 | )
10 |
// main demonstrates tenant and database management. It does the following:
//   - creates a uniquely named tenant and a database inside it;
//   - runs a small collection round-trip scoped to that database (add, count,
//     list, query, delete);
//   - removes the collection and the database.
//
// The tenant itself is not deleted.
func main() {
	client, err := chroma.NewHTTPClient(chroma.WithDebug())
	if err != nil {
		log.Fatalf("Error creating client: %s \n", err)
	}
	// Close the client to release any resources such as local embedding functions
	defer func() {
		err = client.Close()
		if err != nil {
			log.Fatalf("Error closing client: %s \n", err)
		}
	}()
	// A random suffix keeps repeated runs from colliding on the tenant name.
	r := rand.Int()
	tenant, err := client.CreateTenant(context.Background(), chroma.NewTenant(fmt.Sprintf("tenant-%d", r)))
	if err != nil {
		log.Fatalf("Error creating tenant: %s \n", err)
	}
	fmt.Printf("Created tenant %v\n", tenant)
	db1, err := client.CreateDatabase(context.Background(), tenant.Database("db2"))
	if err != nil {
		log.Fatalf("Error creating database: %s \n", err)
	}
	col, err := client.GetOrCreateCollection(context.Background(), "col1",
		chroma.WithDatabaseCreate(db1), chroma.WithCollectionMetadataCreate(
			chroma.NewMetadata(
				chroma.NewStringAttribute("str", "hello"),
				chroma.NewIntAttribute("int", 1),
				chroma.NewFloatAttribute("float", 1.1),
			),
		),
	)
	if err != nil {
		log.Fatalf("Error creating collection: %s \n", err)
	}
	// %+v (not "%v+") prints the value including struct field names.
	fmt.Printf("Created collection %+v\n", col)

	err = col.Add(context.Background(),
		//chroma.WithIDGenerator(chroma.NewULIDGenerator()),
		chroma.WithIDs("1", "2"),
		chroma.WithTexts("hello world", "goodbye world"),
		chroma.WithMetadatas(
			chroma.NewDocumentMetadata(chroma.NewIntAttribute("int", 1)),
			chroma.NewDocumentMetadata(chroma.NewStringAttribute("str", "hello")),
		))
	if err != nil {
		log.Fatalf("Error adding collection: %s \n", err)
	}

	colCount, err := client.CountCollections(context.Background(), chroma.WithDatabaseCount(db1))
	if err != nil {
		log.Fatalf("Error counting collections: %s \n", err)
	}
	fmt.Printf("Count collections in %s : %d\n", db1.String(), colCount)
	cols, err := client.ListCollections(context.Background(), chroma.WithDatabaseList(db1))
	if err != nil {
		log.Fatalf("Error listing collections: %s \n", err)
	}
	fmt.Printf("List collections in %s : %d\n", db1.String(), len(cols))

	qr, err := col.Query(context.Background(), chroma.WithQueryTexts("say hello"))
	if err != nil {
		log.Fatalf("Error querying collection: %s \n", err)
	}
	// Guard the [0][0] access: an empty result would otherwise panic.
	docGroups := qr.GetDocumentsGroups()
	if len(docGroups) == 0 || len(docGroups[0]) == 0 {
		log.Fatalf("Query returned no documents \n")
	}
	fmt.Printf("Query result: %v\n", docGroups[0][0])
	err = col.Delete(context.Background(), chroma.WithIDsDelete("1", "2"))
	if err != nil {
		log.Fatalf("Error deleting collection: %s \n", err)
	}
	fmt.Printf("Deleted items from collection %s\n", col.Name())

	err = client.DeleteCollection(context.Background(), "col1", chroma.WithDatabaseDelete(db1))
	if err != nil {
		log.Fatalf("Error deleting collection: %s \n", err)
	}
	err = client.DeleteDatabase(context.Background(), db1)
	if err != nil {
		log.Fatalf("Error deleting database: %s \n", err)
	}
	fmt.Printf("Deleted database %s\n", db1)
}
101 |
--------------------------------------------------------------------------------
/gen_api.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Generates the Go API client into ./swagger from openapi.yaml using the
# swagger-codegen v3 Docker image. Run from the directory that contains
# openapi.yaml (the repository root), since $PWD is mounted as /local.
set -euo pipefail

#docker run --rm -v ${PWD}:/local swaggerapi/swagger-codegen-cli -v

# Quote the mount spec so a working directory containing spaces still works.
docker run --rm -v "${PWD}:/local" swaggerapi/swagger-codegen-cli-v3 generate \
  -DapiTests=false \
  -i /local/openapi.yaml \
  -l go \
  -o /local/swagger
--------------------------------------------------------------------------------
/gen_api_v3.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Regenerates the Go client into ./swagger from openapi.yaml using
# openapi-generator-cli 6.6.0. The script first copies
# patches/model_anyof.mustache into the generator's Go templates, then strips
# generated files that are not kept in this repository.
# Run from the repository root (where openapi.yaml and patches/ live).
set -euo pipefail

rm -rf swagger
# if openapi-generator-cli.jar exists in the current directory, use it; otherwise, download it
# on windows: Invoke-WebRequest -OutFile openapi-generator-cli.jar https://repo1.maven.org/maven2/org/openapitools/openapi-generator-cli/6.6.0/openapi-generator-cli-6.6.0.jar
if [ ! -f openapi-generator-cli.jar ]; then
  # Download under a temporary name so that a failed or interrupted download
  # never leaves a truncated jar that later runs would treat as valid.
  wget https://repo1.maven.org/maven2/org/openapitools/openapi-generator-cli/6.6.0/openapi-generator-cli-6.6.0.jar -O openapi-generator-cli.jar.part
  mv openapi-generator-cli.jar.part openapi-generator-cli.jar
fi

# Unpack the generator, drop in the patched template and repackage it.
# Start from a clean directory, because mkdir fails on leftovers from an
# aborted run. Work in a subshell so a failed `cd` can never make `jar -xf`
# unpack into the repository root.
rm -rf generator
mkdir generator
(
  cd generator
  jar -xf ../openapi-generator-cli.jar
  cp ../patches/model_anyof.mustache ./go/
  # NOTE(review): openapi-generator-cli-patched.jar is not used below; only
  # the -fixed jar (built with the original manifest) is executed.
  jar -cf ../openapi-generator-cli-patched.jar *
  jar -cmf META-INF/MANIFEST.MF ../openapi-generator-cli-patched-fixed.jar *
)
rm -rf generator
java -jar openapi-generator-cli-patched-fixed.jar generate -i openapi.yaml -g go -o swagger/

# Remove generator scaffolding that is not kept in this repository.
# -f: a file that was not emitted must not abort the script under `set -e`.
rm -f swagger/go.mod swagger/go.sum swagger/.gitignore swagger/.openapi-generator-ignore \
  swagger/.travis.yml swagger/git_push.sh swagger/README.md
rm -rf swagger/test/ swagger/docs/ swagger/api/ swagger/.openapi-generator/
31 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/amikos-tech/chroma-go
2 |
3 | go 1.23.0
4 |
5 | toolchain go1.23.6
6 |
7 | require (
8 | github.com/Masterminds/semver v1.5.0
9 | github.com/docker/docker v28.0.1+incompatible
10 | github.com/go-playground/validator/v10 v10.22.0
11 | github.com/go-viper/mapstructure/v2 v2.2.1
12 | github.com/google/generative-ai-go v0.19.0
13 | github.com/google/uuid v1.6.0
14 | github.com/joho/godotenv v1.5.1
15 | github.com/leanovate/gopter v0.2.11
16 | github.com/oklog/ulid v1.3.1
17 | github.com/pkg/errors v0.9.1
18 | github.com/stretchr/testify v1.10.0
19 | github.com/testcontainers/testcontainers-go v0.36.0
20 | github.com/testcontainers/testcontainers-go/modules/chroma v0.36.0
21 | github.com/testcontainers/testcontainers-go/modules/ollama v0.36.0
22 | github.com/yalue/onnxruntime_go v1.19.0
23 | google.golang.org/api v0.186.0
24 | )
25 |
26 | require (
27 | cloud.google.com/go v0.115.0 // indirect
28 | cloud.google.com/go/ai v0.8.0 // indirect
29 | cloud.google.com/go/auth v0.6.0 // indirect
30 | cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
31 | cloud.google.com/go/compute/metadata v0.5.0 // indirect
32 | cloud.google.com/go/longrunning v0.5.7 // indirect
33 | dario.cat/mergo v1.0.1 // indirect
34 | github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
35 | github.com/Microsoft/go-winio v0.6.2 // indirect
36 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect
37 | github.com/containerd/log v0.1.0 // indirect
38 | github.com/containerd/platforms v1.0.0-rc.1 // indirect
39 | github.com/cpuguy83/dockercfg v0.3.2 // indirect
40 | github.com/davecgh/go-spew v1.1.1 // indirect
41 | github.com/distribution/reference v0.6.0 // indirect
42 | github.com/docker/go-connections v0.5.0 // indirect
43 | github.com/docker/go-units v0.5.0 // indirect
44 | github.com/ebitengine/purego v0.8.2 // indirect
45 | github.com/felixge/httpsnoop v1.0.4 // indirect
46 | github.com/gabriel-vasile/mimetype v1.4.3 // indirect
47 | github.com/go-logr/logr v1.4.2 // indirect
48 | github.com/go-logr/stdr v1.2.2 // indirect
49 | github.com/go-ole/go-ole v1.2.6 // indirect
50 | github.com/go-playground/locales v0.14.1 // indirect
51 | github.com/go-playground/universal-translator v0.18.1 // indirect
52 | github.com/gogo/protobuf v1.3.2 // indirect
53 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
54 | github.com/golang/protobuf v1.5.4 // indirect
55 | github.com/google/s2a-go v0.1.7 // indirect
56 | github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
57 | github.com/googleapis/gax-go/v2 v2.12.5 // indirect
58 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect
59 | github.com/klauspost/compress v1.17.11 // indirect
60 | github.com/leodido/go-urn v1.4.0 // indirect
61 | github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
62 | github.com/magiconair/properties v1.8.9 // indirect
63 | github.com/moby/docker-image-spec v1.3.1 // indirect
64 | github.com/moby/patternmatcher v0.6.0 // indirect
65 | github.com/moby/sys/sequential v0.6.0 // indirect
66 | github.com/moby/sys/user v0.3.0 // indirect
67 | github.com/moby/sys/userns v0.1.0 // indirect
68 | github.com/moby/term v0.5.0 // indirect
69 | github.com/morikuni/aec v1.0.0 // indirect
70 | github.com/opencontainers/go-digest v1.0.0 // indirect
71 | github.com/opencontainers/image-spec v1.1.1 // indirect
72 | github.com/pmezard/go-difflib v1.0.0 // indirect
73 | github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
74 | github.com/shirou/gopsutil/v4 v4.25.1 // indirect
75 | github.com/sirupsen/logrus v1.9.3 // indirect
76 | github.com/tklauser/go-sysconf v0.3.12 // indirect
77 | github.com/tklauser/numcpus v0.6.1 // indirect
78 | github.com/yusufpapurcu/wmi v1.2.4 // indirect
79 | go.opencensus.io v0.24.0 // indirect
80 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect
81 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
82 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
83 | go.opentelemetry.io/otel v1.35.0 // indirect
84 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 // indirect
85 | go.opentelemetry.io/otel/metric v1.35.0 // indirect
86 | go.opentelemetry.io/otel/trace v1.35.0 // indirect
87 | go.opentelemetry.io/proto/otlp v1.3.1 // indirect
88 | golang.org/x/crypto v0.37.0 // indirect
89 | golang.org/x/net v0.39.0 // indirect
90 | golang.org/x/oauth2 v0.23.0 // indirect
91 | golang.org/x/sync v0.13.0 // indirect
92 | golang.org/x/sys v0.32.0 // indirect
93 | golang.org/x/text v0.24.0 // indirect
94 | golang.org/x/time v0.5.0 // indirect
95 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect
96 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect
97 | google.golang.org/grpc v1.68.1 // indirect
98 | google.golang.org/protobuf v1.35.2 // indirect
99 | gopkg.in/yaml.v3 v3.0.1 // indirect
100 | )
101 |
--------------------------------------------------------------------------------
/internal/http/constants.go:
--------------------------------------------------------------------------------
1 | package http
2 |
// ChromaGoClientUserAgent is the user-agent string identifying this client
// library (presumably sent as the HTTP User-Agent header; see its call sites).
const (
	ChromaGoClientUserAgent = "chroma-go-client/0.1.x"
)
4 |
--------------------------------------------------------------------------------
/internal/http/errors.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "net/http"
7 | )
8 |
// ChromaError represents an error returned by the Chroma API. It contains the ID of the error, the error message and the status code from the HTTP call.
// The JSON body of an error response is decoded into it, for example:
//
//	{
//	  "error": "NotFoundError",
//	  "message": "Tenant default_tenant2 not found"
//	}
//
// The example body carries no error code. ChromaErrorFromHTTPResponse sets
// ErrorCode from the HTTP status before decoding the body. An "error_code" key
// in the body, if present, would therefore overwrite the status.
type ChromaError struct {
	ErrorID   string `json:"error"`      // error identifier from the body, e.g. "NotFoundError"
	ErrorCode int    `json:"error_code"` // HTTP status code of the failed call
	Message   string `json:"message"`    // human-readable error description
}
21 |
// ChromaErrorFromHTTPResponse builds a *ChromaError describing a failed API
// call. The returned value is never nil.
//
// err is the transport-level error, if any. Its text is used as Message when
// the response provides nothing better.
//
// When resp is non-nil, ErrorCode is set to resp.StatusCode and the body is
// read once:
//   - A JSON payload such as {"error": "...", "message": "..."} fills ErrorID
//     and Message.
//   - Any other non-empty body is used verbatim as Message.
//
// The body is consumed but not closed; closing it is left to the caller.
//
// ErrorID and Message default to "unknown". ErrorCode stays 0 when resp is nil.
func ChromaErrorFromHTTPResponse(resp *http.Response, err error) *ChromaError {
	chromaAPIError := &ChromaError{
		ErrorID: "unknown",
		Message: "unknown",
	}
	if err != nil {
		chromaAPIError.Message = err.Error()
	}
	if resp == nil {
		return chromaAPIError
	}
	chromaAPIError.ErrorCode = resp.StatusCode
	// Read the whole body up front. Decoding straight from the stream (as a
	// json.Decoder does) buffers and consumes it, so on a non-JSON body the raw
	// text would no longer be available for the fallback below. ReadRespBody
	// also tolerates a nil Body, where a decoder would panic.
	body := ReadRespBody(resp.Body)
	if jsonErr := json.Unmarshal([]byte(body), chromaAPIError); jsonErr != nil && body != "" {
		chromaAPIError.Message = body
	}
	return chromaAPIError
}
39 |
// Error implements the error interface. The text has the form
// "Error (<ErrorCode>) <ErrorID>: <Message>".
func (e *ChromaError) Error() string {
	code, id, msg := e.ErrorCode, e.ErrorID, e.Message
	return fmt.Sprintf("Error (%v) %v: %v", code, id, msg)
}
43 |
--------------------------------------------------------------------------------
/internal/http/retry.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "net/http"
7 | "time"
8 | )
9 |
10 | type Option func(*SimpleRetryStrategy) error
11 |
12 | func WithMaxRetries(retries int) Option {
13 | return func(r *SimpleRetryStrategy) error {
14 | if retries <= 0 {
15 | return fmt.Errorf("retries must be a positive integer")
16 | }
17 | r.MaxRetries = retries
18 | return nil
19 | }
20 | }
21 |
22 | func WithFixedDelay(delay time.Duration) Option {
23 | return func(r *SimpleRetryStrategy) error {
24 | if delay <= 0 {
25 | return fmt.Errorf("delay must be a positive integer")
26 | }
27 | r.FixedDelay = delay
28 | return nil
29 | }
30 | }
31 |
32 | func WithRetryableStatusCodes(statusCodes ...int) Option {
33 | return func(r *SimpleRetryStrategy) error {
34 | r.RetryableStatusCodes = statusCodes
35 | return nil
36 | }
37 | }
38 |
39 | func WithExponentialBackOff() Option {
40 | return func(r *SimpleRetryStrategy) error {
41 | r.ExponentialBackOff = true
42 | return nil
43 | }
44 | }
45 |
46 | type SimpleRetryStrategy struct {
47 | MaxRetries int
48 | FixedDelay time.Duration
49 | ExponentialBackOff bool
50 | RetryableStatusCodes []int
51 | }
52 |
53 | func NewSimpleRetryStrategy(opts ...Option) (*SimpleRetryStrategy, error) {
54 | var strategy = &SimpleRetryStrategy{
55 | MaxRetries: 3,
56 | FixedDelay: time.Duration(1000) * time.Millisecond,
57 | RetryableStatusCodes: []int{},
58 | }
59 | for _, opt := range opts {
60 | if err := opt(strategy); err != nil {
61 | return nil, err
62 | }
63 | }
64 | return strategy, nil
65 | }
66 |
67 | func (r *SimpleRetryStrategy) DoWithRetry(client *http.Client, req *http.Request) (*http.Response, error) {
68 | var resp *http.Response
69 | var err error
70 | for i := 0; i < r.MaxRetries; i++ {
71 | resp, err = client.Do(req)
72 | if err != nil {
73 | break
74 | }
75 | if resp.StatusCode >= 200 && resp.StatusCode < 400 {
76 | break
77 | }
78 | if r.isRetryable(resp.StatusCode) {
79 | if r.ExponentialBackOff {
80 | time.Sleep(r.FixedDelay * time.Duration(math.Pow(2, float64(i))))
81 | } else {
82 | time.Sleep(r.FixedDelay)
83 | }
84 | }
85 | }
86 | return resp, err
87 | }
88 |
89 | func (r *SimpleRetryStrategy) isRetryable(code int) bool {
90 | for _, retryableCode := range r.RetryableStatusCodes {
91 | if code == retryableCode {
92 | return true
93 | }
94 | }
95 | return false
96 | }
97 |
--------------------------------------------------------------------------------
/internal/http/retry_test.go:
--------------------------------------------------------------------------------
1 | //go:build basic
2 |
3 | package http
4 |
5 | import (
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | "time"
10 |
11 | "github.com/stretchr/testify/require"
12 | )
13 |
// TestRetryStrategyWithExponentialBackOff checks two things for a request that
// always fails with a retryable status:
//   - it is attempted exactly MaxRetries (3) times;
//   - the exponential back-off delays (100ms, 200ms, 400ms) are actually
//     waited.
func TestRetryStrategyWithExponentialBackOff(t *testing.T) {
	client := &http.Client{}

	retryableStatusCodes := []int{http.StatusInternalServerError}

	// Create a new SimpleRetryStrategy with exponential backoff enabled
	retryStrategy, err := NewSimpleRetryStrategy(
		WithMaxRetries(3),
		WithFixedDelay(100*time.Millisecond),
		WithRetryableStatusCodes(retryableStatusCodes...),
		WithExponentialBackOff(),
	)
	require.NoError(t, err, "error setting up strategy: %v", err)
	var serverRetries = 0
	// Create a test server that always returns a 500 status code
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		serverRetries++
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer server.Close()

	req, err := http.NewRequest(http.MethodGet, server.URL, nil)
	require.NoError(t, err, "unexpected error: %v", err)

	startTime := time.Now()
	resp, err := retryStrategy.DoWithRetry(client, req)
	// Measure before any assertion so only DoWithRetry is timed.
	elapsedTime := time.Since(startTime)
	require.NoError(t, err, "unexpected error: %v", err)
	// The final response is returned with its body open. Close it so the
	// connection is released; this defer runs before server.Close.
	defer resp.Body.Close()
	require.Equal(t, http.StatusInternalServerError, resp.StatusCode)

	// Since we have exponential backoff with delays 100ms, 200ms, 400ms, the total delay should be at least 700ms
	expectedMinDelay := 100*time.Millisecond + 200*time.Millisecond + 400*time.Millisecond
	require.LessOrEqual(t, expectedMinDelay, elapsedTime, "expected total delay to be at least %v, but got %v", expectedMinDelay, elapsedTime)
	require.Equal(t, 3, serverRetries)
}
51 |
--------------------------------------------------------------------------------
/internal/http/strategy.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import "net/http"
4 |
// RetryStrategy executes an HTTP request according to some retry policy. Implementations
// (e.g. SimpleRetryStrategy) may send the request more than once before returning.
type RetryStrategy interface {
	// DoWithRetry sends req using client and returns the response or error the strategy
	// settled on.
	DoWithRetry(client *http.Client, req *http.Request) (*http.Response, error)
}
8 |
--------------------------------------------------------------------------------
/internal/http/utils.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import "io"
4 |
// ReadRespBody reads resp until EOF and returns everything read as a string.
// It returns "" when resp is a nil interface or when reading fails; in the failure case any
// partially read data is discarded.
func ReadRespBody(resp io.Reader) string {
	if resp != nil {
		if body, err := io.ReadAll(resp); err == nil {
			return string(body)
		}
	}
	return ""
}
15 |
--------------------------------------------------------------------------------
/metadata/metadata.go:
--------------------------------------------------------------------------------
1 | package metadata
2 |
3 | import (
4 | "github.com/amikos-tech/chroma-go/types"
5 | )
6 |
// MetadataBuilder accumulates metadata key/value pairs applied through Option functions.
type MetadataBuilder struct {
	Metadata map[string]interface{}
}

// NewMetadataBuilder returns a builder that writes into the map pointed to by metadata, so
// the caller's map observes every Option applied to the builder.
//
// If metadata is nil, the builder uses a fresh private map. If metadata points to a nil map,
// a new map is allocated and stored back through the pointer. Without that step the first
// write would panic with "assignment to entry in nil map", and the caller would never see
// the values.
func NewMetadataBuilder(metadata *map[string]interface{}) *MetadataBuilder {
	if metadata == nil {
		return &MetadataBuilder{Metadata: make(map[string]interface{})}
	}
	if *metadata == nil {
		*metadata = make(map[string]interface{})
	}
	return &MetadataBuilder{Metadata: *metadata}
}
17 |
18 | type Option func(*MetadataBuilder) error
19 |
20 | func WithMetadata(key string, value interface{}) Option {
21 | return func(b *MetadataBuilder) error {
22 | switch value.(type) {
23 | case string, int, float32, bool, int32, uint32, int64, uint64:
24 | b.Metadata[key] = value
25 | case types.DistanceFunction:
26 | b.Metadata[key] = value
27 | default:
28 | return &types.InvalidMetadataValueError{Key: key, Value: value}
29 | }
30 | return nil
31 | }
32 | }
33 |
34 | func WithMetadatas(metadata map[string]interface{}) Option {
35 | return func(b *MetadataBuilder) error {
36 | for k, v := range metadata {
37 | switch v.(type) {
38 | case string, int, float32, bool:
39 | b.Metadata[k] = v
40 | default:
41 | return &types.InvalidMetadataValueError{Key: k, Value: v}
42 | }
43 | }
44 | return nil
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/metadata/metadata_test.go:
--------------------------------------------------------------------------------
1 | //go:build basic
2 |
3 | package metadata
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 |
10 | "github.com/amikos-tech/chroma-go/test"
11 | )
12 |
13 | func TestWithMetadata(t *testing.T) {
14 | t.Run("Test invalid Metadata", func(t *testing.T) {
15 | builder := NewMetadataBuilder(nil)
16 | err := WithMetadata("testKey", map[string]interface{}{"invalid": "value"})(builder)
17 | require.Error(t, err)
18 | })
19 |
20 | t.Run("Test int", func(t *testing.T) {
21 | actual := make(map[string]interface{})
22 | builder := NewMetadataBuilder(&actual)
23 | err := WithMetadata("testKey", 1)(builder)
24 | require.NoError(t, err)
25 | expected := map[string]interface{}{
26 | "testKey": 1,
27 | }
28 | test.Compare(t, actual, expected)
29 | })
30 | t.Run("Test float32", func(t *testing.T) {
31 | actual := make(map[string]interface{})
32 | builder := NewMetadataBuilder(&actual)
33 | err := WithMetadata("testKey", float32(1.1))(builder)
34 | require.NoError(t, err)
35 | expected := map[string]interface{}{
36 | "testKey": float32(1.1),
37 | }
38 | test.Compare(t, actual, expected)
39 | })
40 |
41 | t.Run("Test bool", func(t *testing.T) {
42 | actual := make(map[string]interface{})
43 | builder := NewMetadataBuilder(&actual)
44 | err := WithMetadata("testKey", true)(builder)
45 | require.NoError(t, err)
46 | expected := map[string]interface{}{
47 | "testKey": true,
48 | }
49 | test.Compare(t, builder.Metadata, expected)
50 | })
51 |
52 | t.Run("Test string", func(t *testing.T) {
53 | actual := make(map[string]interface{})
54 | builder := NewMetadataBuilder(&actual)
55 | err := WithMetadata("testKey", "value")(builder)
56 | require.NoError(t, err)
57 | expected := map[string]interface{}{
58 | "testKey": "value",
59 | }
60 | test.Compare(t, actual, expected)
61 | })
62 |
63 | t.Run("Test all types", func(t *testing.T) {
64 | actual := make(map[string]interface{})
65 | builder := NewMetadataBuilder(&actual)
66 | err := WithMetadata("testKey", "value")(builder)
67 | require.NoError(t, err)
68 | err = WithMetadata("testKey2", 1)(builder)
69 | require.NoError(t, err)
70 | err = WithMetadata("testKey3", true)(builder)
71 | require.NoError(t, err)
72 | err = WithMetadata("testKey4", float32(1.1))(builder)
73 | require.NoError(t, err)
74 | expected := map[string]interface{}{
75 | "testKey": "value",
76 | "testKey2": 1,
77 | "testKey3": true,
78 | "testKey4": float32(1.1),
79 | }
80 | test.Compare(t, actual, expected)
81 | })
82 | }
83 |
--------------------------------------------------------------------------------
/patches/model_anyof.mustache:
--------------------------------------------------------------------------------
// {{classname}} {{{description}}}{{^description}}struct for {{{classname}}}{{/description}}
type {{classname}} struct {
{{#anyOf}}
	{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} *{{{.}}}
{{/anyOf}}
}

// Unmarshal JSON data into any of the pointers in the struct
func (dst *{{classname}}) UnmarshalJSON(data []byte) error {
	var err error
	{{#isNullable}}
	// this object is nullable so treat an empty payload, JSON null or an empty object as "no value"
	if string(data) == "" || string(data) == "null" || string(data) == "{}" {
		return nil
	}

	{{/isNullable}}
	{{#discriminator}}
	{{#mappedModels}}
	{{#-first}}
	// use discriminator value to speed up the lookup
	var jsonDict map[string]interface{}
	err = json.Unmarshal(data, &jsonDict)
	if err != nil {
		return fmt.Errorf("failed to unmarshal JSON into map for the discriminator lookup")
	}

	{{/-first}}
	// check if the discriminator value is '{{{mappingName}}}'
	if jsonDict["{{{propertyBaseName}}}"] == "{{{mappingName}}}" {
		// try to unmarshal JSON data into {{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}}
		err = json.Unmarshal(data, &dst.{{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}});
		if err == nil {
			json{{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}}, _ := json.Marshal(dst.{{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}})
			if string(json{{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}}) == "{}" { // empty struct
				dst.{{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}} = nil
			} else {
				return nil // data stored in dst.{{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}}, return on the first match
			}
		} else {
			dst.{{#lambda.type-to-name}}{{{modelName}}}{{/lambda.type-to-name}} = nil
		}
	}

	{{/mappedModels}}
	{{/discriminator}}
	{{#anyOf}}
	// try to unmarshal JSON data into {{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}
	err = json.Unmarshal(data, &dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}});
	if err == nil {
		json{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}, _ := json.Marshal(dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}})
		if string(json{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}) == "{}" { // empty struct
			dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} = nil
		} else {
			return nil // data stored in dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}, return on the first match
		}
	} else {
		dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} = nil
	}

	{{/anyOf}}
	return fmt.Errorf("data failed to match schemas in anyOf({{classname}})")
}

// Marshal data from the first non-nil pointers in the struct to JSON
func (src *{{classname}}) MarshalJSON() ([]byte, error) {
	{{#anyOf}}
	if src.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} != nil {
		return json.Marshal(&src.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}})
	}

	{{/anyOf}}
	return nil, nil // no data in anyOf schemas
}

{{>nullable_model}}
77 |
--------------------------------------------------------------------------------
/pkg/api/v2/auth.go:
--------------------------------------------------------------------------------
1 | package v2
2 |
3 | import (
4 | "encoding/base64"
5 |
6 | "github.com/pkg/errors"
7 | )
8 |
// CredentialsProvider applies authentication to a BaseAPIClient. The providers in this
// file do so by adding a default header to the client.
type CredentialsProvider interface {
	// Authenticate configures apiClient's credentials. It returns an error when the provider
	// is misconfigured (for example, an unsupported token header).
	Authenticate(apiClient *BaseAPIClient) error
}
12 |
// BasicAuthCredentialsProvider authenticates with HTTP Basic auth (username and password).
type BasicAuthCredentialsProvider struct {
	Username string
	Password string
}

// NewBasicAuthCredentialsProvider returns a provider holding the given username and password.
func NewBasicAuthCredentialsProvider(username, password string) *BasicAuthCredentialsProvider {
	provider := new(BasicAuthCredentialsProvider)
	provider.Username = username
	provider.Password = password
	return provider
}
24 |
25 | func (b *BasicAuthCredentialsProvider) Authenticate(client *BaseAPIClient) error {
26 | auth := b.Username + ":" + b.Password
27 | encodedAuth := base64.StdEncoding.EncodeToString([]byte(auth))
28 | client.defaultHeaders["Authorization"] = "Basic " + encodedAuth
29 | return nil
30 | }
31 |
// TokenTransportHeader names the HTTP header a token is sent in.
type TokenTransportHeader string

const (
	// AuthorizationTokenHeader sends the token as "Authorization: Bearer <token>".
	AuthorizationTokenHeader TokenTransportHeader = "Authorization"
	// XChromaTokenHeader sends the raw token in the "X-Chroma-Token" header.
	XChromaTokenHeader TokenTransportHeader = "X-Chroma-Token"
)

// TokenAuthCredentialsProvider authenticates by sending Token in the header named by Header.
type TokenAuthCredentialsProvider struct {
	Token  string
	Header TokenTransportHeader
}

// NewTokenAuthCredentialsProvider returns a provider for token sent via header. The header
// is not validated here; Authenticate rejects unsupported values.
func NewTokenAuthCredentialsProvider(token string, header TokenTransportHeader) *TokenAuthCredentialsProvider {
	provider := new(TokenAuthCredentialsProvider)
	provider.Token, provider.Header = token, header
	return provider
}
50 |
51 | func (t *TokenAuthCredentialsProvider) Authenticate(client *BaseAPIClient) error {
52 | switch t.Header {
53 | case AuthorizationTokenHeader:
54 | client.defaultHeaders[string(t.Header)] = "Bearer " + t.Token
55 | return nil
56 | case XChromaTokenHeader:
57 | client.defaultHeaders[string(t.Header)] = t.Token
58 | return nil
59 | default:
60 | return errors.Errorf("unsupported token header: %v", t.Header)
61 | }
62 | }
63 |
--------------------------------------------------------------------------------
/pkg/api/v2/base.go:
--------------------------------------------------------------------------------
1 | package v2
2 |
3 | import (
4 | "encoding/json"
5 |
6 | "github.com/go-viper/mapstructure/v2"
7 | "github.com/pkg/errors"
8 | )
9 |
// Tenant identifies a tenant and creates Database handles that belong to it.
type Tenant interface {
	// Name returns the tenant's name.
	Name() string
	// String returns a printable form of the tenant (TenantBase returns its name).
	String() string
	// Database returns a Database handle named dbName bound to this tenant.
	Database(dbName string) Database
	// Validate reports whether the tenant is usable (TenantBase requires a non-empty name).
	Validate() error
}
16 |
// Database identifies a database that belongs to a Tenant.
type Database interface {
	// ID returns the database ID. It may be empty, e.g. for handles built with NewDatabase.
	ID() string
	// Name returns the database name.
	Name() string
	// Tenant returns the tenant the database belongs to.
	Tenant() Tenant
	// String returns a printable form of the database (DatabaseBase returns its name).
	String() string
	// Validate reports whether the handle is usable (DatabaseBase requires a name and a tenant).
	Validate() error
}
24 |
// Include names a record field to request in results. The values match the JSON field
// names used in result payloads. NOTE(review): presumably sent as the "include" list of
// get/query requests; confirm against the callers.
type Include string

const (
	IncludeMetadatas  Include = "metadatas"
	IncludeDocuments  Include = "documents"
	IncludeEmbeddings Include = "embeddings"
	IncludeURIs       Include = "uris"
)
33 |
// Identity is the JSON shape of a user identity: a user ID, a tenant name and a list of
// database names. NOTE(review): presumably the databases the user can access in that
// tenant; confirm against the server API.
type Identity struct {
	UserID    string   `json:"user_id"`
	Tenant    string   `json:"tenant"`
	Databases []string `json:"databases"`
}
39 |
40 | type TenantBase struct {
41 | TenantName string `json:"name"`
42 | }
43 |
44 | func (t *TenantBase) Name() string {
45 | return t.TenantName
46 | }
47 |
48 | func NewTenant(name string) Tenant {
49 | return &TenantBase{TenantName: name}
50 | }
51 |
52 | func NewTenantFromJSON(jsonString string) (Tenant, error) {
53 | tenant := &TenantBase{}
54 | err := json.Unmarshal([]byte(jsonString), tenant)
55 | if err != nil {
56 | return nil, err
57 | }
58 | return tenant, nil
59 | }
60 | func (t *TenantBase) String() string {
61 | return t.Name()
62 | }
63 |
64 | func (t *TenantBase) Validate() error {
65 | if t.TenantName == "" {
66 | return errors.New("tenant name cannot be empty")
67 | }
68 | return nil
69 | }
70 |
71 | // Database returns a new Database object that can be used for creating collections
72 | func (t *TenantBase) Database(dbName string) Database {
73 | return NewDatabase(dbName, t)
74 | }
75 |
76 | // TODO this may fail for v1 API
77 | // func (t *TenantBase) MarshalJSON() ([]byte, error) {
78 | // return []byte(`"` + t.Name() + `"`), nil
79 | //}
80 |
81 | func NewDefaultTenant() Tenant {
82 | return NewTenant(DefaultTenant)
83 | }
84 |
85 | type DatabaseBase struct {
86 | DBName string `json:"name" mapstructure:"name"`
87 | DBID string `json:"id,omitempty" mapstructure:"id"`
88 | TenantName string `json:"tenant,omitempty" mapstructure:"tenant"`
89 | tenant Tenant
90 | }
91 |
92 | func (d DatabaseBase) Name() string {
93 | return d.DBName
94 | }
95 |
96 | func (d DatabaseBase) Tenant() Tenant {
97 | if d.tenant == nil && d.TenantName != "" {
98 | d.tenant = NewTenant(d.TenantName)
99 | }
100 | return d.tenant
101 | }
102 |
103 | func (d DatabaseBase) String() string {
104 | return d.Name()
105 | }
106 |
107 | func (d DatabaseBase) ID() string {
108 | return d.DBID
109 | }
110 | func (d DatabaseBase) Validate() error {
111 | if d.DBName == "" {
112 | return errors.New("database name cannot be empty")
113 | }
114 | if d.tenant == nil {
115 | return errors.New("tenant cannot be empty")
116 | }
117 | return nil
118 | }
119 |
120 | // TODO this may fail for v1 API
121 | // func (d *DatabaseBase) MarshalJSON() ([]byte, error) {
122 | // return []byte(`"` + d.Name() + `"`), nil
123 | //}
124 |
125 | func NewDatabase(name string, tenant Tenant) Database {
126 | return &DatabaseBase{DBName: name, tenant: tenant}
127 | }
128 |
129 | func NewDatabaseFromJSON(jsonString string) (Database, error) {
130 | database := &DatabaseBase{}
131 | err := json.Unmarshal([]byte(jsonString), database)
132 | if err != nil {
133 | return nil, err
134 | }
135 | if database.TenantName != "" {
136 | database.tenant = NewTenant(database.TenantName)
137 | } else {
138 | database.tenant = NewDefaultTenant()
139 | }
140 | return database, nil
141 | }
142 |
143 | func NewDatabaseFromMap(data map[string]interface{}) (Database, error) {
144 | database := &DatabaseBase{}
145 | err := mapstructure.Decode(data, database)
146 | if err != nil {
147 | return nil, errors.Wrap(err, "error decoding database")
148 | }
149 | if database.TenantName != "" {
150 | database.tenant = NewTenant(database.TenantName)
151 | } else {
152 | database.tenant = NewDefaultTenant()
153 | }
154 | return database, nil
155 | }
156 |
157 | func NewDefaultDatabase() Database {
158 | return NewDatabase(DefaultDatabase, NewDefaultTenant())
159 | }
160 |
--------------------------------------------------------------------------------
/pkg/api/v2/constants.go:
--------------------------------------------------------------------------------
1 | package v2
2 |
const (
	// DefaultTenant is the tenant name used by NewDefaultTenant.
	DefaultTenant = "default_tenant"
	// DefaultDatabase is the database name used by NewDefaultDatabase.
	DefaultDatabase = "default_database"

	// The HNSW* constants are collection metadata keys with the "hnsw:" prefix.
	// NOTE(review): presumably they tune the server's HNSW index; see the Chroma
	// documentation for the meaning and valid range of each value.
	HNSWSpace          = "hnsw:space"
	HNSWConstructionEF = "hnsw:construction_ef"
	HNSWBatchSize      = "hnsw:batch_size"
	HNSWSyncThreshold  = "hnsw:sync_threshold"
	HNSWM              = "hnsw:M"
	HNSWSearchEF       = "hnsw:search_ef"
	HNSWNumThreads     = "hnsw:num_threads"
	HNSWResizeFactor   = "hnsw:resize_factor"
)
15 |
--------------------------------------------------------------------------------
/pkg/api/v2/document_test.go:
--------------------------------------------------------------------------------
1 | //go:build basicv2
2 |
3 | package v2
4 |
5 | import (
6 | "encoding/json"
7 | "github.com/stretchr/testify/require"
8 | "testing"
9 | )
10 |
11 | func TestTextDocument(t *testing.T) {
12 |
13 | doc := "Hello, world!\n"
14 |
15 | tdoc := NewTextDocument(doc)
16 |
17 | marshal, err := json.Marshal(tdoc)
18 | require.NoError(t, err)
19 | require.Equal(t, `"Hello, world!\n"`, string(marshal))
20 | }
21 |
--------------------------------------------------------------------------------
/pkg/api/v2/ids.go:
--------------------------------------------------------------------------------
1 | package v2
2 |
3 | import (
4 | "crypto/sha256"
5 | "encoding/hex"
6 | "math/rand"
7 | "time"
8 |
9 | "github.com/google/uuid"
10 | "github.com/oklog/ulid"
11 | )
12 |
// GenerateOptions carries optional inputs for a single IDGenerator.Generate call.
type GenerateOptions struct {
	Document string
}

// IDGeneratorOption sets a field of GenerateOptions.
type IDGeneratorOption func(opts *GenerateOptions)

// WithDocument supplies the document text to the generator. SHA256Generator, for example,
// hashes it.
func WithDocument(document string) IDGeneratorOption {
	setDocument := func(opts *GenerateOptions) {
		opts.Document = document
	}
	return setDocument
}
24 |
// IDGenerator produces record IDs. Implementations in this file are UUIDGenerator,
// SHA256Generator (hashes the WithDocument text) and ULIDGenerator.
type IDGenerator interface {
	// Generate returns a new ID; opts may supply inputs such as the document text.
	Generate(opts ...IDGeneratorOption) string
}
28 |
29 | type UUIDGenerator struct{}
30 |
31 | func (u *UUIDGenerator) Generate(opts ...IDGeneratorOption) string {
32 | uuidV4 := uuid.New()
33 | return uuidV4.String()
34 | }
35 |
36 | func NewUUIDGenerator() *UUIDGenerator {
37 | return &UUIDGenerator{}
38 | }
39 |
40 | type SHA256Generator struct{}
41 |
42 | func (s *SHA256Generator) Generate(opts ...IDGeneratorOption) string {
43 | op := GenerateOptions{}
44 | for _, opt := range opts {
45 | opt(&op)
46 | }
47 | if op.Document == "" {
48 | op.Document = uuid.New().String()
49 | }
50 | hasher := sha256.New()
51 | hasher.Write([]byte(op.Document))
52 | sha256Hash := hex.EncodeToString(hasher.Sum(nil))
53 | return sha256Hash
54 | }
55 |
56 | func NewSHA256Generator() *SHA256Generator {
57 | return &SHA256Generator{}
58 | }
59 |
60 | type ULIDGenerator struct{}
61 |
62 | func (u *ULIDGenerator) Generate(opts ...IDGeneratorOption) string {
63 | t := time.Now()
64 | entropy := rand.New(rand.NewSource(t.UnixNano()))
65 | docULID := ulid.MustNew(ulid.Timestamp(t), entropy)
66 | return docULID.String()
67 | }
68 |
69 | func NewULIDGenerator() *ULIDGenerator {
70 | return &ULIDGenerator{}
71 | }
72 |
--------------------------------------------------------------------------------
/pkg/api/v2/record.go:
--------------------------------------------------------------------------------
1 | package v2
2 |
3 | import (
4 | "github.com/pkg/errors"
5 |
6 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
7 | )
8 |
// Record bundles an ID with its document, embedding and metadata.
type Record interface {
	// ID returns the record's ID.
	ID() DocumentID
	Document() Document // should work for both text and URI based documents
	// Embedding returns the record's embedding (may be nil if none was set).
	Embedding() embeddings.Embedding
	// Metadata returns the record's metadata (may be nil if none was set).
	Metadata() DocumentMetadata
	// Validate returns the record's validation error, or nil.
	Validate() error
	// Unwrap returns ID, Document, Embedding and Metadata in one call.
	Unwrap() (DocumentID, Document, embeddings.Embedding, DocumentMetadata)
}

// Records is a list of Record values.
type Records []Record
19 |
// SimpleRecord is the basic Record implementation built by NewSimpleRecord.
type SimpleRecord struct {
	id        string               // record ID; NewSimpleRecord rejects an empty one
	embedding embeddings.Embedding // set by WithRecordEmbedding
	metadata  DocumentMetadata     // set by WithRecordMetadatas
	document  string               // returned (as a text document) by Document(); NOTE(review): no option in this file sets it
	uri       string               // returned by URI(); NOTE(review): no option in this file sets it
	err       error                // validation error returned by Validate(); NOTE(review): nothing in this file assigns it, so Validate currently always returns nil
}

// RecordOption configures a SimpleRecord during construction; a non-nil error aborts
// NewSimpleRecord.
type RecordOption func(record *SimpleRecord) error
29 |
30 | func WithRecordID(id string) RecordOption {
31 | return func(r *SimpleRecord) error {
32 | r.id = id
33 | return nil
34 | }
35 | }
36 |
37 | func WithRecordEmbedding(embedding embeddings.Embedding) RecordOption {
38 | return func(r *SimpleRecord) error {
39 | r.embedding = embedding
40 | return nil
41 | }
42 | }
43 |
44 | func WithRecordMetadatas(metadata DocumentMetadata) RecordOption {
45 | return func(r *SimpleRecord) error {
46 | r.metadata = metadata
47 | return nil
48 | }
49 | }
50 | func (r *SimpleRecord) constructValidate() error {
51 | if r.id == "" {
52 | return errors.New("record id is empty")
53 | }
54 | return nil
55 | }
56 | func NewSimpleRecord(opts ...RecordOption) (*SimpleRecord, error) {
57 | r := &SimpleRecord{}
58 | for _, opt := range opts {
59 | err := opt(r)
60 | if err != nil {
61 | return nil, errors.Wrap(err, "error applying record option")
62 | }
63 | }
64 |
65 | err := r.constructValidate()
66 | if err != nil {
67 | return nil, errors.Wrap(err, "error validating record")
68 | }
69 | return r, nil
70 | }
71 |
72 | func (r *SimpleRecord) ID() DocumentID {
73 | return DocumentID(r.id)
74 | }
75 |
76 | func (r *SimpleRecord) Document() Document {
77 | return NewTextDocument(r.document)
78 | }
79 |
80 | func (r *SimpleRecord) URI() string {
81 | return r.uri
82 | }
83 |
84 | func (r *SimpleRecord) Embedding() embeddings.Embedding {
85 | return r.embedding
86 | }
87 |
88 | func (r *SimpleRecord) Metadata() DocumentMetadata {
89 | return r.metadata
90 | }
91 |
92 | func (r *SimpleRecord) Validate() error {
93 | return r.err
94 | }
95 |
96 | func (r *SimpleRecord) Unwrap() (DocumentID, Document, embeddings.Embedding, DocumentMetadata) {
97 | return r.ID(), r.Document(), r.Embedding(), r.Metadata()
98 | }
99 |
--------------------------------------------------------------------------------
/pkg/api/v2/record_test.go:
--------------------------------------------------------------------------------
1 | //go:build basicv2
2 |
3 | package v2
4 |
5 | import (
6 | "testing"
7 |
8 | "github.com/stretchr/testify/require"
9 |
10 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
11 | )
12 |
13 | func TestSimpleRecord(t *testing.T) {
14 | record, err := NewSimpleRecord(WithRecordID("1"),
15 | WithRecordEmbedding(embeddings.NewEmbeddingFromFloat32([]float32{1, 2, 3})),
16 | WithRecordMetadatas(NewDocumentMetadata(NewStringAttribute("key", "value"))))
17 | require.NoError(t, err)
18 | require.NotNil(t, record)
19 | }
20 |
--------------------------------------------------------------------------------
/pkg/api/v2/reranking.go:
--------------------------------------------------------------------------------
1 | package v2
2 |
--------------------------------------------------------------------------------
/pkg/api/v2/results_test.go:
--------------------------------------------------------------------------------
1 | //go:build basicv2
2 |
3 | package v2
4 |
5 | import (
6 | "encoding/json"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/require"
10 |
11 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
12 | )
13 |
14 | func TestGetResultDeserialization(t *testing.T) {
15 | var apiResponse = `{
16 | "documents": [
17 | "document1",
18 | "document2"
19 | ],
20 | "embeddings": [
21 | [0.1,0.2],
22 | [0.3,0.4]
23 | ],
24 | "ids": [
25 | "id1",
26 | "id2"
27 | ],
28 | "include": [
29 | "distances"
30 | ],
31 | "metadatas": [
32 | {
33 | "additionalProp1": true,
34 | "additionalProp2": 1,
35 | "additionalProp3": "test"
36 | },
37 | {"additionalProp1": false}
38 | ]
39 | }`
40 |
41 | var result GetResultImpl
42 | err := json.Unmarshal([]byte(apiResponse), &result)
43 | require.NoError(t, err)
44 | require.Len(t, result.GetDocuments(), 2)
45 | require.Len(t, result.GetIDs(), 2)
46 | require.Equal(t, result.GetIDs()[0], DocumentID("id1"))
47 | require.Equal(t, result.GetDocuments()[0], NewTextDocument("document1"))
48 | require.Equal(t, []float32{0.1, 0.2}, result.GetEmbeddings()[0].ContentAsFloat32())
49 | require.Len(t, result.GetEmbeddings(), 2)
50 | require.Len(t, result.GetMetadatas(), 2)
51 | }
52 |
53 | func TestQueryResultDeserialization(t *testing.T) {
54 | var apiResponse = `{
55 | "distances": [
56 | [
57 | 0.1
58 | ]
59 | ],
60 | "documents": [
61 | [
62 | "string"
63 | ]
64 | ],
65 | "embeddings": [
66 | [
67 | [
68 | 0.1
69 | ]
70 | ]
71 | ],
72 | "ids": [
73 | [
74 | "id1"
75 | ]
76 | ],
77 | "include": [
78 | "distances"
79 | ],
80 | "metadatas": [
81 | [
82 | {
83 | "additionalProp1": true,
84 | "additionalProp2": true,
85 | "additionalProp3": true
86 | }
87 | ]
88 | ]
89 | }`
90 |
91 | var result QueryResultImpl
92 | err := json.Unmarshal([]byte(apiResponse), &result)
93 | require.NoError(t, err)
94 | require.Len(t, result.GetIDGroups(), 1)
95 | require.Len(t, result.GetIDGroups()[0], 1)
96 | require.Equal(t, DocumentID("id1"), result.GetIDGroups()[0][0])
97 |
98 | require.Len(t, result.GetDocumentsGroups(), 1)
99 | require.Len(t, result.GetDocumentsGroups()[0], 1)
100 | require.Equal(t, NewTextDocument("string"), result.GetDocumentsGroups()[0][0])
101 |
102 | require.Len(t, result.GetEmbeddingsGroups(), 1)
103 | require.Len(t, result.GetEmbeddingsGroups()[0], 1)
104 | require.Equal(t, []float32{0.1}, result.GetEmbeddingsGroups()[0][0].ContentAsFloat32())
105 |
106 | require.Len(t, result.GetMetadatasGroups(), 1)
107 | require.Len(t, result.GetMetadatasGroups()[0], 1)
108 | metadata := NewDocumentMetadata(
109 | NewBoolAttribute("additionalProp1", true),
110 | NewBoolAttribute("additionalProp3", true),
111 | NewBoolAttribute("additionalProp2", true),
112 | )
113 | require.Equal(t, metadata, result.GetMetadatasGroups()[0][0])
114 |
115 | require.Len(t, result.GetDistancesGroups(), 1)
116 | require.Len(t, result.GetDistancesGroups()[0], 1)
117 | require.Equal(t, embeddings.Distance(0.1), result.GetDistancesGroups()[0][0])
118 | }
119 |
--------------------------------------------------------------------------------
/pkg/api/v2/server.htpasswd:
--------------------------------------------------------------------------------
1 | admin:$2y$05$sHZUfxjz/M70r02rdtrYgurCrSIZjWGDyzm5i2CeN0dLsHd9qc8Zy
2 |
3 |
--------------------------------------------------------------------------------
/pkg/api/v2/utils.go:
--------------------------------------------------------------------------------
1 | package v2
2 |
3 | import (
4 | "crypto/ecdsa"
5 | "crypto/elliptic"
6 | "crypto/rand"
7 | "crypto/x509"
8 | "crypto/x509/pkix"
9 | "encoding/pem"
10 | "log"
11 | "math/big"
12 | "net"
13 | "os"
14 | "time"
15 | )
16 |
// CreateSelfSignedCert generates a self-signed ECDSA P-256 certificate and writes it to
// disk. The certificate:
//   - names "localhost" and 127.0.0.1,
//   - is valid for one year from now,
//   - carries ServerAuth usage.
//
// The certificate is written PEM-encoded ("CERTIFICATE") to certPath. The PKCS#8 private
// key is written PEM-encoded ("PRIVATE KEY") to keyPath with file mode 0600.
//
// Any failure terminates the process via log.Fatalf, so this is meant for tests and local
// tooling only. Fatal messages name the actual paths involved.
func CreateSelfSignedCert(certPath, keyPath string) {
	// Generate a private key
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		log.Fatalf("Failed to generate private key: %v", err)
	}

	// Prepare certificate
	notBefore := time.Now()
	notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year

	// Random serial number in [0, 2^128).
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		log.Fatalf("Failed to generate serial number: %v", err)
	}

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"Chroma, Inc."},
			CommonName:   "localhost",
		},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	// Create the certificate; template doubles as the parent, making it self-signed.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		log.Fatalf("Failed to create certificate: %v", err)
	}

	// Write the certificate to file
	certOut, err := os.Create(certPath)
	if err != nil {
		log.Fatalf("Failed to open %s for writing: %v", certPath, err)
	}
	if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		log.Fatalf("Failed to write data to %s: %v", certPath, err)
	}
	if err := certOut.Close(); err != nil {
		log.Fatalf("Error closing %s: %v", certPath, err)
	}
	log.Printf("Written %s", certPath)

	// Write the private key to file; 0600 keeps it readable by the owner only.
	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		log.Fatalf("Failed to open %s for writing: %v", keyPath, err)
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		log.Fatalf("Unable to marshal private key: %v", err)
	}
	if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
		log.Fatalf("Failed to write data to %s: %v", keyPath, err)
	}
	if err := keyOut.Close(); err != nil {
		log.Fatalf("Error closing %s: %v", keyPath, err)
	}
	log.Printf("Written %s", keyPath)
}
85 |
--------------------------------------------------------------------------------
/pkg/api/v2/v1-config.yaml:
--------------------------------------------------------------------------------
1 | ########################
2 | # HTTP server settings #
3 | ########################
4 | port: 8000
5 | listen_address: "0.0.0.0"
6 |
7 | ####################
8 | # General settings #
9 | ####################
10 | persist_path: "/data"
11 | allow_reset: true
12 |
--------------------------------------------------------------------------------
/pkg/api/v2/where_document_test.go:
--------------------------------------------------------------------------------
1 | //go:build basicv2
2 |
3 | package v2
4 |
5 | import (
6 | "testing"
7 | )
8 |
9 | func TestWhereDocument(t *testing.T) {
10 | tests := []struct {
11 | name string
12 | filter WhereDocumentFilter
13 | expected string
14 | }{
15 | {
16 | name: "contain",
17 | filter: Contains("test"),
18 | expected: `{"$contains":"test"}`,
19 | },
20 | {
21 | name: "not contain",
22 | filter: NotContains("test"),
23 | expected: `{"$not_contains":"test"}`,
24 | },
25 | {
26 | name: "or",
27 | filter: OrDocument(Contains("test"), NotContains("test")),
28 | expected: `{"$or":[{"$contains":"test"},{"$not_contains":"test"}]}`,
29 | },
30 | {
31 | name: "and",
32 | filter: AndDocument(Contains("test"), NotContains("test")),
33 | expected: `{"$and":[{"$contains":"test"},{"$not_contains":"test"}]}`,
34 | },
35 | {
36 | name: "or and",
37 | filter: OrDocument(AndDocument(Contains("test"), NotContains("test")), Contains("test")),
38 | expected: `{"$or":[{"$and":[{"$contains":"test"},{"$not_contains":"test"}]},{"$contains":"test"}]}`,
39 | },
40 | }
41 | for _, test := range tests {
42 | t.Run(test.name, func(t *testing.T) {
43 | actual, err := test.filter.MarshalJSON()
44 | if err != nil {
45 | t.Errorf("error marshalling filter: %v", err)
46 | }
47 | if string(actual) != test.expected {
48 | t.Errorf("expected %s, got %s", test.expected, string(actual))
49 | }
50 | })
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/commons/cohere/cohere_commons_test.go:
--------------------------------------------------------------------------------
1 | //go:build ef || rf
2 |
3 | package cohere
4 |
5 | import (
6 | "fmt"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | func TestValidations(t *testing.T) {
13 | tests := []struct {
14 | name string
15 | options []Option
16 | expectedError string
17 | }{
18 | {
19 | name: "Test empty API key",
20 | options: []Option{
21 | WithDefaultModel("model"),
22 | },
23 | expectedError: "'apiKey' failed on the 'required'",
24 | },
25 | {
26 | name: "Test without default model",
27 | options: []Option{
28 | WithAPIKey("dummy"),
29 | },
30 | expectedError: "'DefaultModel' failed on the 'required'",
31 | },
32 | }
33 | for _, tt := range tests {
34 | t.Run(tt.name, func(t *testing.T) {
35 | _, err := NewCohereClient(tt.options...)
36 | fmt.Printf("Error: %v\n", err)
37 | require.Error(t, err)
38 | require.Contains(t, err.Error(), tt.expectedError)
39 | })
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/pkg/commons/http/constants.go:
--------------------------------------------------------------------------------
1 | package http
2 |
// ChromaGoClientUserAgent identifies this client library. It is presumably sent
// as the HTTP User-Agent header (TODO confirm at the call sites).
const (
	ChromaGoClientUserAgent = "chroma-go-client/0.1.x"
)
4 |
--------------------------------------------------------------------------------
/pkg/commons/http/errors.go:
--------------------------------------------------------------------------------
1 | package http
2 |
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)
8 |
// ChromaError represents an error returned by the Chroma API. It contains the ID of the error,
// the error message and the status code from the HTTP call.
// Example:
//
//	{
//	  "error": "NotFoundError",
//	  "message": "Tenant default_tenant2 not found"
//	}
type ChromaError struct {
	ErrorID   string `json:"error"`
	ErrorCode int    `json:"error_code"`
	Message   string `json:"message"`
}

// ChromaErrorFromHTTPResponse builds a ChromaError from a failed HTTP exchange.
//
// When err is non-nil, its text seeds Message. When resp is non-nil, its status code
// becomes ErrorCode and its body is read in full:
//   - a JSON body overlays the fields it contains;
//   - an empty body leaves Message as it was;
//   - any other body is used verbatim as Message.
//
// The body is consumed but not closed; closing it remains the caller's job.
func ChromaErrorFromHTTPResponse(resp *http.Response, err error) *ChromaError {
	chromaAPIError := &ChromaError{
		ErrorID: "unknown",
		Message: "unknown",
	}
	if err != nil {
		chromaAPIError.Message = err.Error()
	}
	if resp == nil {
		return chromaAPIError
	}
	chromaAPIError.ErrorCode = resp.StatusCode
	if resp.Body == nil {
		return chromaAPIError
	}
	// Read the whole body once. Decoding straight from the stream would consume it,
	// leaving nothing to report when the payload turns out not to be JSON.
	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil || len(body) == 0 {
		return chromaAPIError
	}
	if jsonErr := json.Unmarshal(body, chromaAPIError); jsonErr != nil {
		chromaAPIError.Message = string(body)
	}
	return chromaAPIError
}

// Error implements the error interface as "Error (<code>) <id>: <message>".
func (e *ChromaError) Error() string {
	return fmt.Sprintf("Error (%d) %s: %s", e.ErrorCode, e.ErrorID, e.Message)
}
43 |
--------------------------------------------------------------------------------
/pkg/commons/http/retry.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "net/http"
7 | "time"
8 | )
9 |
// Option configures a SimpleRetryStrategy. A non-nil error means the supplied value
// was rejected; NewSimpleRetryStrategy returns it unchanged.
type Option func(*SimpleRetryStrategy) error

// WithMaxRetries sets the total number of attempts DoWithRetry makes (must be > 0).
func WithMaxRetries(retries int) Option {
	return func(r *SimpleRetryStrategy) error {
		if retries <= 0 {
			return fmt.Errorf("retries must be a positive integer")
		}
		r.MaxRetries = retries
		return nil
	}
}

// WithFixedDelay sets the wait after a retryable response, which is also the base
// of exponential back-off (must be > 0).
func WithFixedDelay(delay time.Duration) Option {
	return func(r *SimpleRetryStrategy) error {
		if delay <= 0 {
			return fmt.Errorf("delay must be a positive integer")
		}
		r.FixedDelay = delay
		return nil
	}
}

// WithRetryableStatusCodes replaces the set of HTTP status codes that trigger another attempt.
func WithRetryableStatusCodes(statusCodes ...int) Option {
	return func(r *SimpleRetryStrategy) error {
		r.RetryableStatusCodes = statusCodes
		return nil
	}
}

// WithExponentialBackOff makes the wait grow as FixedDelay * 2^attempt (0-based).
func WithExponentialBackOff() Option {
	return func(r *SimpleRetryStrategy) error {
		r.ExponentialBackOff = true
		return nil
	}
}

// SimpleRetryStrategy re-sends a request while the server answers with one of
// RetryableStatusCodes, waiting between attempts.
type SimpleRetryStrategy struct {
	// MaxRetries is the total number of attempts; values below 1 are treated as 1.
	MaxRetries int
	// FixedDelay is the wait after a retryable response (or the back-off base).
	FixedDelay time.Duration
	// ExponentialBackOff doubles the wait on every attempt when true.
	ExponentialBackOff bool
	// RetryableStatusCodes lists the status codes that trigger another attempt.
	RetryableStatusCodes []int
}

// NewSimpleRetryStrategy returns a strategy with 3 attempts, a 1s delay and no
// retryable codes, then applies opts in order. The first option error aborts construction.
func NewSimpleRetryStrategy(opts ...Option) (*SimpleRetryStrategy, error) {
	var strategy = &SimpleRetryStrategy{
		MaxRetries:           3,
		FixedDelay:           time.Duration(1000) * time.Millisecond,
		RetryableStatusCodes: []int{},
	}
	for _, opt := range opts {
		if err := opt(strategy); err != nil {
			return nil, err
		}
	}
	return strategy, nil
}

// DoWithRetry sends req with client and returns the first response that is not
// retryable, or the last response once all attempts are used.
//
// Behavior:
//   - A transport error, a 2xx/3xx status, or any status not in RetryableStatusCodes
//     ends the loop immediately.
//   - Responses superseded by a later attempt have their bodies closed.
//   - Request bodies are rewound with req.GetBody before each retry.
//   - Cancelling req's context while waiting aborts with the context error.
func (r *SimpleRetryStrategy) DoWithRetry(client *http.Client, req *http.Request) (*http.Response, error) {
	attempts := r.MaxRetries
	if attempts < 1 {
		// Always send at least once; otherwise callers would receive (nil, nil).
		attempts = 1
	}
	var resp *http.Response
	var err error
	for i := 0; i < attempts; i++ {
		if i > 0 && req.GetBody != nil {
			// The previous attempt consumed the body; obtain a fresh copy to resend.
			body, bodyErr := req.GetBody()
			if bodyErr != nil {
				return nil, fmt.Errorf("resetting request body for retry: %w", bodyErr)
			}
			req.Body = body
		}
		resp, err = client.Do(req)
		if err != nil {
			break
		}
		if resp.StatusCode >= 200 && resp.StatusCode < 400 {
			break
		}
		if !r.isRetryable(resp.StatusCode) {
			// Re-sending cannot change a non-retryable outcome; hand it back now.
			break
		}
		delay := r.FixedDelay
		if r.ExponentialBackOff {
			delay = r.FixedDelay * time.Duration(math.Pow(2, float64(i)))
		}
		// NOTE: the wait also happens after the final attempt. This matches the
		// previous timing, which retry_test.go relies on.
		timer := time.NewTimer(delay)
		select {
		case <-timer.C:
		case <-req.Context().Done():
			timer.Stop()
			_ = resp.Body.Close()
			return nil, req.Context().Err()
		}
		if i < attempts-1 {
			// This response is about to be replaced by the next attempt; close its
			// body so the connection is released instead of leaked.
			_ = resp.Body.Close()
		}
	}
	return resp, err
}

// isRetryable reports whether code is one of RetryableStatusCodes.
func (r *SimpleRetryStrategy) isRetryable(code int) bool {
	for _, retryableCode := range r.RetryableStatusCodes {
		if code == retryableCode {
			return true
		}
	}
	return false
}
96 | }
97 |
--------------------------------------------------------------------------------
/pkg/commons/http/retry_test.go:
--------------------------------------------------------------------------------
1 | //go:build basic
2 |
3 | package http
4 |
5 | import (
6 | "net/http"
7 | "net/http/httptest"
8 | "testing"
9 | "time"
10 |
11 | "github.com/stretchr/testify/require"
12 | )
13 |
14 | func TestRetryStrategyWithExponentialBackOff(t *testing.T) {
15 | client := &http.Client{}
16 |
17 | retryableStatusCodes := []int{http.StatusInternalServerError}
18 |
19 | // Create a new SimpleRetryStrategy with exponential backoff enabled
20 | retryStrategy, err := NewSimpleRetryStrategy(
21 | WithMaxRetries(3),
22 | WithFixedDelay(100*time.Millisecond),
23 | WithRetryableStatusCodes(retryableStatusCodes...),
24 | WithExponentialBackOff(),
25 | )
26 | require.NoError(t, err, "error setting up strategy: %v", err)
27 | var serverRetries = 0
28 | // Create a test server that always returns a 500 status code
29 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30 | serverRetries++
31 | w.WriteHeader(http.StatusInternalServerError)
32 | }))
33 | defer server.Close()
34 |
35 | req, err := http.NewRequest("GET", server.URL, nil)
36 | require.NoError(t, err, "unexpected error: %v", err)
37 |
38 | startTime := time.Now()
39 |
40 | _, err = retryStrategy.DoWithRetry(client, req)
41 | if err != nil {
42 | t.Fatalf("unexpected error: %v", err)
43 | }
44 | // Calculate the total elapsed time
45 | elapsedTime := time.Since(startTime)
46 | // Since we have exponential backoff with delays 100ms, 200ms, 400ms, the total delay should be at least 700ms
47 | expectedMinDelay := 100*time.Millisecond + 200*time.Millisecond + 400*time.Millisecond
48 | require.Less(t, expectedMinDelay, elapsedTime, "expected total delay to be at least %v, but got %v", expectedMinDelay, elapsedTime)
49 | require.Equal(t, 3, serverRetries)
50 | }
51 |
--------------------------------------------------------------------------------
/pkg/commons/http/strategy.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import "net/http"
4 |
// RetryStrategy sends an HTTP request through the given client, possibly re-sending
// it according to the implementation's policy, and returns the final outcome.
type RetryStrategy interface {
	DoWithRetry(client *http.Client, request *http.Request) (*http.Response, error)
}
8 |
--------------------------------------------------------------------------------
/pkg/commons/http/utils.go:
--------------------------------------------------------------------------------
1 | package http
2 |
3 | import "io"
4 |
// ReadRespBody drains resp and returns its content as a string. It returns "" when
// resp is nil or when reading fails, discarding any partially read data.
func ReadRespBody(resp io.Reader) string {
	if resp != nil {
		if body, err := io.ReadAll(resp); err == nil {
			return string(body)
		}
	}
	return ""
}
15 |
--------------------------------------------------------------------------------
/pkg/embeddings/cloudflare/option.go:
--------------------------------------------------------------------------------
1 | package cloudflare
2 |
3 | import (
4 | "net/http"
5 | "os"
6 |
7 | "github.com/pkg/errors"
8 |
9 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
10 | )
11 |
12 | type Option func(p *CloudflareClient) error
13 |
14 | func WithGatewayEndpoint(endpoint string) Option {
15 | return func(p *CloudflareClient) error {
16 | if endpoint == "" {
17 | return errors.New("endpoint cannot be empty")
18 | }
19 | p.BaseAPI = endpoint
20 | p.IsGateway = true
21 | return nil
22 | }
23 | }
24 |
25 | func WithDefaultModel(model embeddings.EmbeddingModel) Option {
26 | return func(p *CloudflareClient) error {
27 | p.DefaultModel = model
28 | return nil
29 | }
30 | }
31 |
32 | func WithMaxBatchSize(size int) Option {
33 | return func(p *CloudflareClient) error {
34 | if size <= 0 {
35 | return errors.New("max batch size must be greater than 0")
36 | }
37 | p.MaxBatchSize = size
38 | return nil
39 | }
40 | }
41 |
42 | func WithDefaultHeaders(headers map[string]string) Option {
43 | return func(p *CloudflareClient) error {
44 | p.DefaultHeaders = headers
45 | return nil
46 | }
47 | }
48 |
49 | func WithAPIToken(apiToken string) Option {
50 | return func(p *CloudflareClient) error {
51 | p.APIToken = apiToken
52 | return nil
53 | }
54 | }
55 |
56 | func WithAccountID(accountID string) Option {
57 | return func(p *CloudflareClient) error {
58 | if accountID == "" {
59 | return errors.New("account ID cannot be empty")
60 | }
61 | p.AccountID = accountID
62 | return nil
63 | }
64 | }
65 |
66 | func WithEnvAPIToken() Option {
67 | return func(p *CloudflareClient) error {
68 | if apiToken := os.Getenv("CF_API_TOKEN"); apiToken != "" {
69 | p.APIToken = apiToken
70 | return nil
71 | }
72 | return errors.Errorf("CF_API_TOKEN not set")
73 | }
74 | }
75 |
76 | func WithEnvAccountID() Option {
77 | return func(p *CloudflareClient) error {
78 | if accountID := os.Getenv("CF_ACCOUNT_ID"); accountID != "" {
79 | p.AccountID = accountID
80 | return nil
81 | }
82 | return errors.Errorf("CF_ACCOUNT_ID not set")
83 | }
84 | }
85 |
86 | func WithHTTPClient(client *http.Client) Option {
87 | return func(p *CloudflareClient) error {
88 | if client == nil {
89 | return errors.New("http client cannot be nil")
90 | }
91 | p.Client = client
92 | return nil
93 | }
94 | }
95 |
96 | func WithEnvGatewayEndpoint() Option {
97 | return func(p *CloudflareClient) error {
98 | if endpoint := os.Getenv("CF_GATEWAY_ENDPOINT"); endpoint != "" {
99 | p.BaseAPI = endpoint
100 | p.IsGateway = true
101 | return nil
102 | }
103 | return errors.Errorf("CF_GATEWAY_ENDPOINT not set")
104 | }
105 | }
106 |
--------------------------------------------------------------------------------
/pkg/embeddings/cohere/cohere_test.go:
--------------------------------------------------------------------------------
1 | //go:build ef
2 |
3 | package cohere
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "os"
9 | "testing"
10 |
11 | ccommons "github.com/amikos-tech/chroma-go/pkg/commons/cohere"
12 |
13 | "github.com/joho/godotenv"
14 | "github.com/stretchr/testify/assert"
15 | "github.com/stretchr/testify/require"
16 | )
17 |
18 | func Test_ef(t *testing.T) {
19 | apiKey := os.Getenv("COHERE_API_KEY")
20 | if apiKey == "" {
21 | err := godotenv.Load("../../../.env")
22 | if err != nil {
23 | assert.Failf(t, "Error loading .env file", "%s", err)
24 | }
25 | apiKey = os.Getenv("COHERE_API_KEY")
26 | }
27 |
28 | t.Run("Test Create Embed", func(t *testing.T) {
29 | ef, err := NewCohereEmbeddingFunction(WithAPIKey(apiKey))
30 | require.NoError(t, err)
31 | documents := []string{
32 | "Document 1 content here",
33 | "Document 2 content here",
34 | // Add more documents as needed
35 | }
36 | resp, err := ef.EmbedDocuments(context.Background(), documents)
37 | require.Nil(t, err)
38 | require.NotNil(t, resp)
39 | })
40 |
41 | t.Run("Test Create Embed with model option", func(t *testing.T) {
42 | ef, err := NewCohereEmbeddingFunction(WithAPIKey(apiKey), WithModel("embed-multilingual-v3.0"))
43 | require.NoError(t, err)
44 | documents := []string{
45 | "Document 1 content here",
46 | "Document 2 content here",
47 | // Add more documents as needed
48 | }
49 | resp, rerr := ef.EmbedDocuments(context.Background(), documents)
50 | require.Nil(t, rerr)
51 | require.NotNil(t, resp)
52 | })
53 |
54 | t.Run("Test Create Embed with model option embeddings type uint8", func(t *testing.T) {
55 | ef, err := NewCohereEmbeddingFunction(WithAPIKey(apiKey), WithModel("embed-multilingual-v3.0"), WithEmbeddingTypes(EmbeddingTypeUInt8))
56 | require.NoError(t, err)
57 | documents := []string{
58 | "Document 1 content here",
59 | "Document 2 content here",
60 | // Add more documents as needed
61 | }
62 | resp, err := ef.EmbedDocuments(context.Background(), documents)
63 | require.Nil(t, err)
64 | require.NotNil(t, resp)
65 | require.Len(t, resp, 2)
66 | fmt.Printf("resp %T\n", resp[0])
67 | require.Empty(t, resp[0].ContentAsFloat32())
68 | require.NotNil(t, resp[0].ContentAsInt32())
69 | })
70 |
71 | t.Run("Test Create Embed with model option embeddings type int8", func(t *testing.T) {
72 | ef, err := NewCohereEmbeddingFunction(WithEnvAPIKey(), WithModel("embed-multilingual-v3.0"), WithEmbeddingTypes(EmbeddingTypeInt8))
73 | require.NoError(t, err)
74 | documents := []string{
75 | "Document 1 content here",
76 | "Document 2 content here",
77 | // Add more documents as needed
78 | }
79 | resp, err := ef.EmbedDocuments(context.Background(), documents)
80 | require.Nil(t, err)
81 | require.NotNil(t, resp)
82 | require.Len(t, resp, 2)
83 | require.Empty(t, resp[0].ContentAsFloat32())
84 | require.NotNil(t, resp[0].ContentAsInt32())
85 | })
86 |
87 | t.Run("Test Create Embed for query", func(t *testing.T) {
88 | ef, err := NewCohereEmbeddingFunction(
89 | WithEnvAPIKey(),
90 | WithModel("embed-multilingual-v3.0"),
91 | )
92 | require.NoError(t, err)
93 | resp, err := ef.EmbedQuery(context.Background(), "This is a query")
94 | require.Nil(t, err)
95 | require.NotNil(t, resp)
96 | require.NotNil(t, resp.ContentAsFloat32())
97 | require.Empty(t, resp.ContentAsInt32())
98 | })
99 |
100 | t.Run("Test With API options", func(t *testing.T) {
101 | ef, err := NewCohereEmbeddingFunction(
102 | WithEnvAPIKey(),
103 | WithBaseURL(ccommons.DefaultBaseURL),
104 | WithAPIVersion(ccommons.DefaultAPIVersion),
105 | WithModel("embed-multilingual-v3.0"),
106 | )
107 | require.NoError(t, err)
108 | resp, err := ef.EmbedQuery(context.Background(), "This is a query")
109 | require.Nil(t, err)
110 | require.NotNil(t, resp)
111 | require.NotNil(t, resp.ContentAsFloat32())
112 | require.Empty(t, resp.ContentAsInt32())
113 | })
114 | }
115 |
--------------------------------------------------------------------------------
/pkg/embeddings/cohere/option.go:
--------------------------------------------------------------------------------
1 | package cohere
2 |
3 | import (
4 | "github.com/pkg/errors"
5 |
6 | ccommons "github.com/amikos-tech/chroma-go/pkg/commons/cohere"
7 | httpc "github.com/amikos-tech/chroma-go/pkg/commons/http"
8 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
9 | )
10 |
11 | type Option func(p *CohereEmbeddingFunction) ccommons.Option
12 |
13 | // WithBaseURL sets the base URL for the Cohere API - the default is https://api.cohere.ai
14 | func WithBaseURL(baseURL string) Option {
15 | return func(p *CohereEmbeddingFunction) ccommons.Option {
16 | return ccommons.WithBaseURL(baseURL)
17 | }
18 | }
19 |
20 | func WithAPIKey(apiKey string) Option {
21 | return func(p *CohereEmbeddingFunction) ccommons.Option {
22 | return ccommons.WithAPIKey(apiKey)
23 | }
24 | }
25 |
26 | // WithEnvAPIKey configures the client to use the COHERE_API_KEY environment variable as the API key
27 | func WithEnvAPIKey() Option {
28 | return func(p *CohereEmbeddingFunction) ccommons.Option {
29 | return ccommons.WithEnvAPIKey()
30 | }
31 | }
32 |
33 | func WithAPIVersion(apiVersion ccommons.APIVersion) Option {
34 | return func(p *CohereEmbeddingFunction) ccommons.Option {
35 | return ccommons.WithAPIVersion(apiVersion)
36 | }
37 | }
38 |
39 | // WithModel sets the default model for the Cohere API - Available models:
40 | // embed-english-v3.0 1024
41 | // embed-multilingual-v3.0 1024
42 | // embed-english-light-v3.0 384
43 | // embed-multilingual-light-v3.0 384
44 | // embed-english-v2.0 4096 (default)
45 | // embed-english-light-v2.0 1024
46 | // embed-multilingual-v2.0 768
47 | func WithModel(model embeddings.EmbeddingModel) Option {
48 | return func(p *CohereEmbeddingFunction) ccommons.Option {
49 | return ccommons.WithDefaultModel(model)
50 | }
51 | }
52 |
53 | // WithDefaultModel sets the default model for the Cohere. This can be overridden in the context of EF embed call. Available models:
54 | // embed-english-v3.0 1024
55 | // embed-multilingual-v3.0 1024
56 | // embed-english-light-v3.0 384
57 | // embed-multilingual-light-v3.0 384
58 | // embed-english-v2.0 4096 (default)
59 | // embed-english-light-v2.0 1024
60 | // embed-multilingual-v2.0 768
61 | func WithDefaultModel(model embeddings.EmbeddingModel) Option {
62 | return func(p *CohereEmbeddingFunction) ccommons.Option {
63 | return ccommons.WithDefaultModel(model)
64 | }
65 | }
66 |
67 | // WithTruncateMode sets the default truncate mode for the Cohere API - Available modes:
68 | // NONE
69 | // START
70 | // END (default)
71 | func WithTruncateMode(truncate TruncateMode) Option {
72 | return func(p *CohereEmbeddingFunction) ccommons.Option {
73 | if truncate != NONE && truncate != START && truncate != END {
74 | return func(c *ccommons.CohereClient) error {
75 | return errors.Errorf("invalid truncate mode %s", truncate)
76 | }
77 | }
78 | p.DefaultTruncateMode = truncate
79 | return ccommons.NoOp()
80 | }
81 | }
82 |
83 | // WithEmbeddingTypes sets the default embedding types for the Cohere API - Available types:
84 | // float (default)
85 | // int8
86 | // uint8
87 | // binary
88 | // ubinary
89 | // TODO we do not have support for returning multiple embedding types from the EmbeddingFunction, so for float->int8, unit8 are supported and returned in the that order
90 | func WithEmbeddingTypes(embeddingTypes ...EmbeddingType) Option {
91 | return func(p *CohereEmbeddingFunction) ccommons.Option {
92 | // if embeddingstypes contains binary or ubinary error
93 | for _, et := range embeddingTypes {
94 | if et == EmbeddingTypeBinary || et == EmbeddingTypeUBinary {
95 | return func(c *ccommons.CohereClient) error {
96 | return errors.Errorf("embedding type %s is not supported", et)
97 | }
98 | }
99 | }
100 | // if embeddingstypes is empty, set to default
101 | if len(embeddingTypes) == 0 {
102 | embeddingTypes = []EmbeddingType{EmbeddingTypeFloat32}
103 | }
104 | p.DefaultEmbeddingTypes = embeddingTypes
105 | return ccommons.NoOp()
106 | }
107 | }
108 |
109 | // WithRetryStrategy configures the client to use the specified retry strategy
110 | func WithRetryStrategy(retryStrategy httpc.RetryStrategy) Option {
111 | return func(p *CohereEmbeddingFunction) ccommons.Option {
112 | return ccommons.WithRetryStrategy(retryStrategy)
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/pkg/embeddings/default_ef/constants.go:
--------------------------------------------------------------------------------
1 | package defaultef
2 |
const (
	// LibTokenizersVersion is the version embedded in the tokenizers shared-library
	// file name ("libtokenizers.<version>.<ext>").
	LibTokenizersVersion = "0.9.0"
	// LibOnnxRuntimeVersion is the version embedded in the ONNX Runtime shared-library
	// file name ("libonnxruntime.<version>.<ext>").
	LibOnnxRuntimeVersion = "1.21.0"
	// ChromaCacheDir is a relative cache directory; presumably resolved against the
	// user's home directory (TODO confirm).
	ChromaCacheDir = ".cache/chroma/"
)

// onnxModelDownloadEndpoint is the tarball URL of the all-MiniLM-L6-v2 ONNX model.
const onnxModelDownloadEndpoint = "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
9 |
--------------------------------------------------------------------------------
/pkg/embeddings/default_ef/default_ef_test.go:
--------------------------------------------------------------------------------
1 | package defaultef
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func Test_Default_EF(t *testing.T) {
11 | ef, closeEf, err := NewDefaultEmbeddingFunction()
12 | require.NoError(t, err)
13 | t.Cleanup(func() {
14 | err := closeEf()
15 | if err != nil {
16 | t.Logf("error while closing embedding function: %v", err)
17 | }
18 | })
19 | require.NotNil(t, ef)
20 | embeddings, err := ef.EmbedDocuments(context.TODO(), []string{"Hello Chroma!", "Hello world!"})
21 | require.NoError(t, err)
22 | require.NotNil(t, embeddings)
23 | require.Len(t, embeddings, 2)
24 | for _, embedding := range embeddings {
25 | require.Equal(t, embedding.Len(), 384)
26 | }
27 | }
28 |
29 | func TestClose(t *testing.T) {
30 | ef, closeEf, err := NewDefaultEmbeddingFunction()
31 | require.NoError(t, err)
32 | require.NotNil(t, ef)
33 | err = closeEf()
34 | require.NoError(t, err)
35 | _, err = ef.EmbedQuery(context.TODO(), "Hello Chroma!")
36 | require.Error(t, err)
37 | require.Contains(t, err.Error(), "embedding function is closed")
38 | }
39 | func TestCloseClosed(t *testing.T) {
40 | ef := &DefaultEmbeddingFunction{}
41 | err := ef.Close()
42 | require.NoError(t, err)
43 | }
44 |
--------------------------------------------------------------------------------
/pkg/embeddings/default_ef/download_utils_test.go:
--------------------------------------------------------------------------------
1 | package defaultef
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "path/filepath"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | func TestDownload(t *testing.T) {
13 | t.Run("Download", func(t *testing.T) {
14 | onnxLibPath = filepath.Join(t.TempDir(), "libonnxruntime."+LibOnnxRuntimeVersion+"."+getExtensionForOs())
15 | fmt.Println(onnxLibPath)
16 | err := os.RemoveAll(onnxLibPath)
17 | require.NoError(t, err)
18 | err = EnsureOnnxRuntimeSharedLibrary()
19 | require.NoError(t, err)
20 | })
21 | t.Run("Download Tokenizers", func(t *testing.T) {
22 | libTokenizersLibPath = filepath.Join(t.TempDir(), "libtokenizers."+LibTokenizersVersion+"."+getExtensionForOs())
23 | err := os.RemoveAll(libTokenizersLibPath)
24 | require.NoError(t, err)
25 | err = EnsureLibTokenizersSharedLibrary()
26 | require.NoError(t, err)
27 | })
28 | t.Run("Download Model", func(t *testing.T) {
29 | onnxModelCachePath = filepath.Join(t.TempDir(), "onnx_model")
30 | err := os.RemoveAll(onnxModelCachePath)
31 | require.NoError(t, err)
32 | err = EnsureDefaultEmbeddingFunctionModel()
33 | require.NoError(t, err)
34 | })
35 | }
36 |
--------------------------------------------------------------------------------
/pkg/embeddings/distance_metric.go:
--------------------------------------------------------------------------------
1 | package embeddings
2 |
// DistanceMetric names a distance function by its string identifier.
type DistanceMetric string

// Known distance metric identifiers.
const (
	L2     DistanceMetric = "l2"     // L2 (Euclidean-family) distance
	COSINE DistanceMetric = "cosine" // cosine distance
	IP     DistanceMetric = "ip"     // presumably inner product (TODO confirm)
)

// DistanceMetricOperator compares two float32 vectors and returns a float64 score.
type DistanceMetricOperator interface {
	Compare(a, b []float32) float64
}

type (
	// Distance is a single distance value.
	Distance float32
	// Distances is a list of Distance values.
	Distances []Distance
)
17 |
--------------------------------------------------------------------------------
/pkg/embeddings/embedding_test.go:
--------------------------------------------------------------------------------
1 | package embeddings
2 |
3 | import (
4 | "encoding/json"
5 | "testing"
6 |
7 | "github.com/stretchr/testify/require"
8 | )
9 |
10 | func TestMarshalEmbeddings(t *testing.T) {
11 | embed := NewEmbeddingFromFloat32([]float32{1.1234567891, 2.4, 3.5})
12 |
13 | bytes, err := json.Marshal(embed)
14 | require.NoError(t, err)
15 | require.JSONEq(t, `[1.1234568,2.4,3.5]`, string(bytes))
16 | }
17 |
18 | func TestUnmarshalEmbeddings(t *testing.T) {
19 | var embed Float32Embedding
20 | jsonStr := `[1.1234568,2.4,3.5]`
21 |
22 | err := json.Unmarshal([]byte(jsonStr), &embed)
23 | require.NoError(t, err)
24 | require.Equal(t, 3, embed.Len())
25 | require.Equal(t, float32(1.1234568), embed.ContentAsFloat32()[0])
26 | require.Equal(t, float32(2.4), embed.ContentAsFloat32()[1])
27 | require.Equal(t, float32(3.5), embed.ContentAsFloat32()[2])
28 | }
29 |
--------------------------------------------------------------------------------
/pkg/embeddings/gemini/gemini.go:
--------------------------------------------------------------------------------
1 | package gemini
2 |
3 | import (
4 | "context"
5 |
6 | "github.com/google/generative-ai-go/genai"
7 | "github.com/pkg/errors"
8 | "google.golang.org/api/option"
9 |
10 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
11 | )
12 |
const (
	// DefaultEmbeddingModel is assigned by applyDefaults when no model is configured.
	DefaultEmbeddingModel = "text-embedding-004"
	// ModelContextVar is the context key CreateEmbedding reads for a per-call model.
	// NOTE(review): an untyped string context key can collide with other packages.
	ModelContextVar = "model"
	// APIKeyEnvVar is the environment variable WithEnvAPIKey reads.
	APIKeyEnvVar = "GEMINI_API_KEY"
)
18 |
19 | type Client struct {
20 | apiKey string
21 | DefaultModel embeddings.EmbeddingModel
22 | Client *genai.Client
23 | DefaultContext *context.Context
24 | MaxBatchSize int
25 | }
26 |
27 | func applyDefaults(c *Client) (err error) {
28 | if c.DefaultModel == "" {
29 | c.DefaultModel = DefaultEmbeddingModel
30 | }
31 |
32 | if c.DefaultContext == nil {
33 | ctx := context.Background()
34 | c.DefaultContext = &ctx
35 | }
36 |
37 | if c.Client == nil {
38 | c.Client, err = genai.NewClient(*c.DefaultContext, option.WithAPIKey(c.apiKey))
39 | if err != nil {
40 | return errors.WithStack(err)
41 | }
42 | }
43 | return nil
44 | }
45 |
46 | func validate(c *Client) error {
47 | if c.apiKey == "" {
48 | return errors.New("API key is required")
49 | }
50 | return nil
51 | }
52 |
53 | func NewGeminiClient(opts ...Option) (*Client, error) {
54 | client := &Client{}
55 |
56 | for _, opt := range opts {
57 | err := opt(client)
58 | if err != nil {
59 | return nil, errors.Wrap(err, "failed to apply Gemini option")
60 | }
61 | }
62 | err := applyDefaults(client)
63 | if err != nil {
64 | return nil, err
65 | }
66 | if err := validate(client); err != nil {
67 | return nil, errors.Wrap(err, "failed to validate Gemini client options")
68 | }
69 | return client, nil
70 | }
71 |
72 | func (c *Client) CreateEmbedding(ctx context.Context, req []string) ([]embeddings.Embedding, error) {
73 | var em *genai.EmbeddingModel
74 | if ctx.Value(ModelContextVar) != nil {
75 | em = c.Client.EmbeddingModel(ctx.Value(ModelContextVar).(string))
76 | } else {
77 | em = c.Client.EmbeddingModel(string(c.DefaultModel))
78 | }
79 | b := em.NewBatch()
80 | for _, t := range req {
81 | b.AddContent(genai.Text(t))
82 | }
83 | res, err := em.BatchEmbedContents(ctx, b)
84 | if err != nil {
85 | return nil, errors.Wrap(err, "failed to embed contents")
86 | }
87 | var embs = make([][]float32, 0)
88 | for _, e := range res.Embeddings {
89 | embs = append(embs, e.Values)
90 | }
91 |
92 | return embeddings.NewEmbeddingsFromFloat32(embs)
93 | }
94 |
95 | // close closes the underlying client
96 | //
97 | //nolint:unused
98 | func (c *Client) close() error {
99 | return c.Client.Close()
100 | }
101 |
102 | var _ embeddings.EmbeddingFunction = (*GeminiEmbeddingFunction)(nil)
103 |
104 | type GeminiEmbeddingFunction struct {
105 | apiClient *Client
106 | }
107 |
108 | func NewGeminiEmbeddingFunction(opts ...Option) (*GeminiEmbeddingFunction, error) {
109 | client, err := NewGeminiClient(opts...)
110 | if err != nil {
111 | return nil, err
112 | }
113 |
114 | return &GeminiEmbeddingFunction{apiClient: client}, nil
115 | }
116 |
117 | // close closes the underlying client
118 | //
119 | //nolint:unused
120 | func (e *GeminiEmbeddingFunction) close() error {
121 | return e.apiClient.close()
122 | }
123 |
124 | func (e *GeminiEmbeddingFunction) EmbedDocuments(ctx context.Context, documents []string) ([]embeddings.Embedding, error) {
125 | if e.apiClient.MaxBatchSize > 0 && len(documents) > e.apiClient.MaxBatchSize {
126 | return nil, errors.Errorf("number of documents exceeds the maximum batch size %v", e.apiClient.MaxBatchSize)
127 | }
128 | if len(documents) == 0 {
129 | return embeddings.NewEmptyEmbeddings(), nil
130 | }
131 |
132 | response, err := e.apiClient.CreateEmbedding(ctx, documents)
133 | if err != nil {
134 | return nil, errors.Wrap(err, "failed to embed documents")
135 | }
136 | return response, nil
137 | }
138 |
139 | func (e *GeminiEmbeddingFunction) EmbedQuery(ctx context.Context, document string) (embeddings.Embedding, error) {
140 | response, err := e.apiClient.CreateEmbedding(ctx, []string{document})
141 | if err != nil {
142 | return nil, errors.Wrap(err, "failed to embed query")
143 | }
144 | return response[0], nil
145 | }
146 |
--------------------------------------------------------------------------------
/pkg/embeddings/gemini/gemini_test.go:
--------------------------------------------------------------------------------
1 | //go:build ef
2 |
3 | package gemini
4 |
5 | import (
6 | "context"
7 | "os"
8 | "testing"
9 |
10 | "github.com/joho/godotenv"
11 | "github.com/stretchr/testify/assert"
12 | "github.com/stretchr/testify/require"
13 | )
14 |
15 | func Test_gemini_client(t *testing.T) {
16 | apiKey := os.Getenv(APIKeyEnvVar)
17 | if apiKey == "" {
18 | err := godotenv.Load("../../../.env")
19 | if err != nil {
20 | assert.Failf(t, "Error loading .env file", "%s", err)
21 | }
22 | apiKey = os.Getenv(APIKeyEnvVar)
23 | }
24 | client, err := NewGeminiClient(WithEnvAPIKey())
25 | require.NoError(t, err)
26 | defer func(client *Client) {
27 | err := client.close()
28 | if err != nil {
29 |
30 | }
31 | }(client)
32 |
33 | t.Run("Test CreateEmbedding", func(t *testing.T) {
34 | resp, rerr := client.CreateEmbedding(context.Background(), []string{"Test document"})
35 | require.Nil(t, rerr)
36 | require.NotNil(t, resp)
37 | require.Len(t, resp, 1)
38 | })
39 | }
40 |
41 | func Test_gemini_embedding_function(t *testing.T) {
42 | apiKey := os.Getenv(APIKeyEnvVar)
43 | if apiKey == "" {
44 | err := godotenv.Load("../../../.env")
45 | if err != nil {
46 | assert.Failf(t, "Error loading .env file", "%s", err)
47 | }
48 | apiKey = os.Getenv(APIKeyEnvVar)
49 | }
50 |
51 | t.Run("Test EmbedDocuments with env-based api key", func(t *testing.T) {
52 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey())
53 | defer func(embeddingFunction *GeminiEmbeddingFunction) {
54 | err := embeddingFunction.close()
55 | if err != nil {
56 |
57 | }
58 | }(embeddingFunction)
59 | require.NoError(t, err)
60 | resp, rerr := embeddingFunction.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
61 | require.Nil(t, rerr)
62 | require.NotNil(t, resp)
63 | require.Len(t, resp, 2)
64 | require.Len(t, resp[0].ContentAsFloat32(), 768)
65 |
66 | })
67 |
68 | t.Run("Test EmbedDocuments with provided API key", func(t *testing.T) {
69 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithAPIKey(apiKey))
70 | defer func(embeddingFunction *GeminiEmbeddingFunction) {
71 | err := embeddingFunction.close()
72 | if err != nil {
73 |
74 | }
75 | }(embeddingFunction)
76 | require.NoError(t, err)
77 | resp, rerr := embeddingFunction.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
78 |
79 | require.Nil(t, rerr)
80 | require.NotNil(t, resp)
81 | require.Len(t, resp, 2)
82 | require.Len(t, resp[0].ContentAsFloat32(), 768)
83 |
84 | })
85 |
86 | t.Run("Test EmbedDocuments with provided model", func(t *testing.T) {
87 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel(DefaultEmbeddingModel))
88 | defer func(embeddingFunction *GeminiEmbeddingFunction) {
89 | err := embeddingFunction.close()
90 | if err != nil {
91 |
92 | }
93 | }(embeddingFunction)
94 | require.NoError(t, err)
95 | resp, rerr := embeddingFunction.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
96 |
97 | require.Nil(t, rerr)
98 | require.NotNil(t, resp)
99 | require.Len(t, resp, 2)
100 | require.Len(t, resp[0].ContentAsFloat32(), 768)
101 |
102 | })
103 |
104 | t.Run("Test EmbedQuery", func(t *testing.T) {
105 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel(DefaultEmbeddingModel))
106 | defer func(embeddingFunction *GeminiEmbeddingFunction) {
107 | err := embeddingFunction.close()
108 | if err != nil {
109 |
110 | }
111 | }(embeddingFunction)
112 | require.NoError(t, err)
113 | resp, rerr := embeddingFunction.EmbedQuery(context.Background(), "this is my query")
114 | require.Nil(t, rerr)
115 | require.NotNil(t, resp)
116 | require.Len(t, resp.ContentAsFloat32(), 768)
117 | })
118 |
119 | t.Run("Test wrong model", func(t *testing.T) {
120 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel("model-does-not-exist"))
121 | defer func(embeddingFunction *GeminiEmbeddingFunction) {
122 | err := embeddingFunction.close()
123 | if err != nil {
124 |
125 | }
126 | }(embeddingFunction)
127 | require.NoError(t, err)
128 | _, rerr := embeddingFunction.EmbedQuery(context.Background(), "this is my query")
129 | require.Contains(t, rerr.Error(), "Error 404")
130 | require.Error(t, rerr)
131 | })
132 |
133 | t.Run("Test wrong API key", func(t *testing.T) {
134 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithAPIKey("wrong-api-key"))
135 | defer func(embeddingFunction *GeminiEmbeddingFunction) {
136 | err := embeddingFunction.close()
137 | if err != nil {
138 |
139 | }
140 | }(embeddingFunction)
141 | require.NoError(t, err)
142 | _, rerr := embeddingFunction.EmbedQuery(context.Background(), "this is my query")
143 | require.Contains(t, rerr.Error(), "API key not valid")
144 | require.Error(t, rerr)
145 | })
146 | }
147 |
--------------------------------------------------------------------------------
/pkg/embeddings/gemini/option.go:
--------------------------------------------------------------------------------
1 | package gemini
2 |
3 | import (
4 | "os"
5 |
6 | "github.com/google/generative-ai-go/genai"
7 | "github.com/pkg/errors"
8 |
9 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
10 | )
11 |
12 | type Option func(p *Client) error
13 |
14 | // WithDefaultModel sets the default model for the client
15 | func WithDefaultModel(model embeddings.EmbeddingModel) Option {
16 | return func(p *Client) error {
17 | if model == "" {
18 | return errors.New("model cannot be empty")
19 | }
20 | p.DefaultModel = model
21 | return nil
22 | }
23 | }
24 |
25 | // WithAPIKey sets the API key for the client
26 | func WithAPIKey(apiKey string) Option {
27 | return func(p *Client) error {
28 | if apiKey == "" {
29 | return errors.New("API key cannot be empty")
30 | }
31 | p.apiKey = apiKey
32 | return nil
33 | }
34 | }
35 |
36 | // WithEnvAPIKey sets the API key for the client from the environment variable GOOGLE_API_KEY
37 | func WithEnvAPIKey() Option {
38 | return func(p *Client) error {
39 | if apiKey := os.Getenv(APIKeyEnvVar); apiKey != "" {
40 | p.apiKey = apiKey
41 | return nil
42 | }
43 | return errors.Errorf("%s not set", APIKeyEnvVar)
44 | }
45 | }
46 |
47 | // WithClient sets the generative AI client for the client
48 | func WithClient(client *genai.Client) Option {
49 | return func(p *Client) error {
50 | if client == nil {
51 | return errors.New("google generative AI client is nil")
52 | }
53 | p.Client = client
54 | return nil
55 | }
56 | }
57 |
58 | // WithMaxBatchSize sets the max batch size for the client - this acts as a limit for the number of embeddings that can be sent in a single request
59 | func WithMaxBatchSize(maxBatchSize int) Option {
60 | return func(p *Client) error {
61 | if maxBatchSize < 1 {
62 | return errors.New("max batch size must be greater than 0")
63 | }
64 | p.MaxBatchSize = maxBatchSize
65 | return nil
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/pkg/embeddings/hf/option.go:
--------------------------------------------------------------------------------
1 | package hf
2 |
3 | import (
4 | "os"
5 |
6 | "github.com/pkg/errors"
7 | )
8 |
9 | type Option func(p *HuggingFaceClient) error
10 |
11 | func WithBaseURL(baseURL string) Option {
12 | return func(p *HuggingFaceClient) error {
13 | if baseURL == "" {
14 | return errors.New("base URL cannot be empty")
15 | }
16 | p.BaseURL = baseURL
17 | return nil
18 | }
19 | }
20 |
21 | func WithAPIKey(apiKey string) Option {
22 | return func(p *HuggingFaceClient) error {
23 | if apiKey == "" {
24 | return errors.New("API key cannot be empty")
25 | }
26 | p.APIKey = apiKey
27 | return nil
28 | }
29 | }
30 |
31 | func WithEnvAPIKey() Option {
32 | return func(p *HuggingFaceClient) error {
33 | if os.Getenv("HF_API_KEY") == "" {
34 | return errors.New("HF_API_KEY not set")
35 | }
36 | p.APIKey = os.Getenv("HF_API_KEY")
37 | return nil
38 | }
39 | }
40 |
41 | func WithModel(model string) Option {
42 | return func(p *HuggingFaceClient) error {
43 | if model == "" {
44 | return errors.New("model cannot be empty")
45 | }
46 | p.Model = model
47 | return nil
48 | }
49 | }
50 |
51 | func WithDefaultHeaders(headers map[string]string) Option {
52 | return func(p *HuggingFaceClient) error {
53 | p.DefaultHeaders = headers
54 | return nil
55 | }
56 | }
57 |
58 | func WithIsHFEIEndpoint() Option {
59 | return func(p *HuggingFaceClient) error {
60 | p.IsHFEIEndpoint = true
61 | return nil
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/pkg/embeddings/jina/jina.go:
--------------------------------------------------------------------------------
1 | package jina
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "io"
8 | "net/http"
9 |
10 | "github.com/pkg/errors"
11 |
12 | chttp "github.com/amikos-tech/chroma-go/pkg/commons/http"
13 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
14 | )
15 |
16 | type EmbeddingType string
17 |
18 | const (
19 | EmbeddingTypeFloat EmbeddingType = "float"
20 | DefaultBaseAPIEndpoint = "https://api.jina.ai/v1/embeddings"
21 | DefaultEmbeddingModel embeddings.EmbeddingModel = "jina-embeddings-v2-base-en"
22 | )
23 |
24 | type EmbeddingRequest struct {
25 | Model string `json:"model"`
26 | Normalized bool `json:"normalized,omitempty"`
27 | EmbeddingType EmbeddingType `json:"embedding_type,omitempty"`
28 | Input []map[string]string `json:"input"`
29 | }
30 |
// EmbeddingResponse is the decoded body of a successful Jina embeddings call.
type EmbeddingResponse struct {
	Model  string `json:"model"`
	Object string `json:"object"`
	// Usage reports token accounting for the request. The explicit tag keeps
	// the wire name lowercase on marshal, consistent with the other fields.
	Usage struct {
		TotalTokens  int `json:"total_tokens"`
		PromptTokens int `json:"prompt_tokens"`
	} `json:"usage"`
	// Data holds the returned embeddings; Index is presumably the position of
	// the corresponding input in the request.
	Data []struct {
		Object    string    `json:"object"`
		Index     int       `json:"index"`
		Embedding []float32 `json:"embedding"` // TODO what about other embedding types - see cohere for example
	} `json:"data"`
}
44 |
45 | var _ embeddings.EmbeddingFunction = (*JinaEmbeddingFunction)(nil)
46 |
47 | func getDefaults() *JinaEmbeddingFunction {
48 | return &JinaEmbeddingFunction{
49 | httpClient: http.DefaultClient,
50 | defaultModel: DefaultEmbeddingModel,
51 | embeddingEndpoint: DefaultBaseAPIEndpoint,
52 | normalized: true,
53 | embeddingType: EmbeddingTypeFloat,
54 | }
55 | }
56 |
57 | type JinaEmbeddingFunction struct {
58 | httpClient *http.Client
59 | apiKey string
60 | defaultModel embeddings.EmbeddingModel
61 | embeddingEndpoint string
62 | normalized bool
63 | embeddingType EmbeddingType
64 | }
65 |
66 | func NewJinaEmbeddingFunction(opts ...Option) (*JinaEmbeddingFunction, error) {
67 | ef := getDefaults()
68 | for _, opt := range opts {
69 | err := opt(ef)
70 | if err != nil {
71 | return nil, err
72 | }
73 | }
74 | return ef, nil
75 | }
76 |
77 | func (e *JinaEmbeddingFunction) sendRequest(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
78 | payload, err := json.Marshal(req)
79 | if err != nil {
80 | return nil, errors.Wrapf(err, "failed to marshal embedding request body")
81 | }
82 |
83 | httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.embeddingEndpoint, bytes.NewBuffer(payload))
84 | if err != nil {
85 | return nil, errors.Wrapf(err, "failed to create embedding request")
86 | }
87 |
88 | httpReq.Header.Set("Content-Type", "application/json")
89 | httpReq.Header.Set("Accept", "application/json")
90 | httpReq.Header.Set("User-Agent", chttp.ChromaGoClientUserAgent)
91 | httpReq.Header.Set("Authorization", "Bearer "+e.apiKey)
92 |
93 | resp, err := e.httpClient.Do(httpReq)
94 | if err != nil {
95 | return nil, errors.Wrapf(err, "failed to send embedding request")
96 | }
97 | defer resp.Body.Close()
98 |
99 | respData, err := io.ReadAll(resp.Body)
100 | if err != nil {
101 | return nil, errors.Wrapf(err, "failed to read response body")
102 | }
103 |
104 | if resp.StatusCode != http.StatusOK {
105 | return nil, errors.Errorf("unexpected response %v: %s", resp.Status, string(respData))
106 | }
107 | var response *EmbeddingResponse
108 | if err := json.Unmarshal(respData, &response); err != nil {
109 | return nil, errors.Wrapf(err, "failed to unmarshal embedding response")
110 | }
111 |
112 | return response, nil
113 | }
114 |
115 | func (e *JinaEmbeddingFunction) EmbedDocuments(ctx context.Context, documents []string) ([]embeddings.Embedding, error) {
116 | var Input = make([]map[string]string, len(documents))
117 |
118 | for i, doc := range documents {
119 | Input[i] = map[string]string{
120 | "text": doc,
121 | }
122 | }
123 | req := &EmbeddingRequest{
124 | Model: string(e.defaultModel),
125 | Input: Input,
126 | }
127 | response, err := e.sendRequest(ctx, req)
128 | if err != nil {
129 | return nil, errors.Wrapf(err, "failed to embed documents")
130 | }
131 | var embs []embeddings.Embedding
132 | for _, data := range response.Data {
133 | embs = append(embs, embeddings.NewEmbeddingFromFloat32(data.Embedding))
134 | }
135 |
136 | return embs, nil
137 | }
138 |
139 | func (e *JinaEmbeddingFunction) EmbedQuery(ctx context.Context, document string) (embeddings.Embedding, error) {
140 | var Input = make([]map[string]string, 1)
141 |
142 | Input[0] = map[string]string{
143 | "text": document,
144 | }
145 | req := &EmbeddingRequest{
146 | Model: string(e.defaultModel),
147 | Input: Input,
148 | }
149 | response, err := e.sendRequest(ctx, req)
150 | if err != nil {
151 | return nil, errors.Wrapf(err, "failed to embed query")
152 | }
153 |
154 | return embeddings.NewEmbeddingFromFloat32(response.Data[0].Embedding), nil
155 | }
156 |
--------------------------------------------------------------------------------
/pkg/embeddings/jina/jina_test.go:
--------------------------------------------------------------------------------
1 | //go:build ef
2 |
3 | package jina
4 |
5 | import (
6 | "context"
7 | "os"
8 | "testing"
9 |
10 | "github.com/joho/godotenv"
11 | "github.com/stretchr/testify/assert"
12 | "github.com/stretchr/testify/require"
13 | )
14 |
15 | func TestJinaEmbeddingFunction(t *testing.T) {
16 | apiKey := os.Getenv("JINA_API_KEY")
17 | if apiKey == "" {
18 | err := godotenv.Load("../../../.env")
19 | if err != nil {
20 | assert.Failf(t, "Error loading .env file", "%s", err)
21 | }
22 | apiKey = os.Getenv("JINA_API_KEY")
23 | }
24 |
25 | t.Run("Test with defaults", func(t *testing.T) {
26 | ef, err := NewJinaEmbeddingFunction(WithAPIKey(apiKey))
27 | require.NoError(t, err)
28 | documents := []string{
29 | "Document 1 content here",
30 | "Document 2 content here",
31 | }
32 | resp, err := ef.EmbedDocuments(context.Background(), documents)
33 | require.NoError(t, err)
34 | require.NotNil(t, resp)
35 | require.Len(t, resp, 2)
36 | require.Equal(t, 768, resp[0].Len())
37 | })
38 |
39 | t.Run("Test with env API key", func(t *testing.T) {
40 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey())
41 | require.NoError(t, err)
42 | documents := []string{
43 | "Document 1 content here",
44 | }
45 | resp, err := ef.EmbedDocuments(context.Background(), documents)
46 | require.NoError(t, err)
47 | require.NotNil(t, resp)
48 | require.Len(t, resp, 1)
49 | require.Equal(t, 768, resp[0].Len())
50 | })
51 |
52 | t.Run("Test with normalized off", func(t *testing.T) {
53 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithNormalized(false))
54 | require.NoError(t, err)
55 | documents := []string{
56 | "Document 1 content here",
57 | }
58 | resp, err := ef.EmbedDocuments(context.Background(), documents)
59 | require.NoError(t, err)
60 | require.NotNil(t, resp)
61 | require.Len(t, resp, 1)
62 | require.Equal(t, 768, resp[0].Len())
63 | })
64 |
65 | t.Run("Test with model", func(t *testing.T) {
66 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithModel("jina-embeddings-v3"))
67 | require.NoError(t, err)
68 | documents := []string{
69 | "import chromadb;client=chromadb.Client();collection=client.get_or_create_collection('col_name')",
70 | }
71 | resp, err := ef.EmbedDocuments(context.Background(), documents)
72 | require.NoError(t, err)
73 | require.NotNil(t, resp)
74 | require.Len(t, resp, 1)
75 | require.Equal(t, 1024, resp[0].Len())
76 | })
77 |
78 | t.Run("Test with EmbeddingType float", func(t *testing.T) {
79 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithEmbeddingType(EmbeddingTypeFloat))
80 | require.NoError(t, err)
81 | documents := []string{
82 | "Document 1 content here",
83 | }
84 | resp, err := ef.EmbedDocuments(context.Background(), documents)
85 | require.NoError(t, err)
86 | require.NotNil(t, resp)
87 | require.Len(t, resp, 1)
88 | require.Equal(t, 768, resp[0].Len())
89 | })
90 |
91 | t.Run("Test with embedding endpoint", func(t *testing.T) {
92 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithEmbeddingEndpoint(DefaultBaseAPIEndpoint))
93 | require.NoError(t, err)
94 | documents := []string{
95 | "Document 1 content here",
96 | }
97 | resp, err := ef.EmbedDocuments(context.Background(), documents)
98 | require.NoError(t, err)
99 | require.NotNil(t, resp)
100 | require.Len(t, resp, 1)
101 | require.Equal(t, 768, resp[0].Len())
102 | })
103 | }
104 |
--------------------------------------------------------------------------------
/pkg/embeddings/jina/option.go:
--------------------------------------------------------------------------------
1 | package jina
2 |
3 | import (
4 | "os"
5 |
6 | "github.com/pkg/errors"
7 |
8 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
9 | )
10 |
11 | type Option func(c *JinaEmbeddingFunction) error
12 |
13 | func WithAPIKey(apiKey string) Option {
14 | return func(c *JinaEmbeddingFunction) error {
15 | c.apiKey = apiKey
16 | return nil
17 | }
18 | }
19 |
20 | func WithEnvAPIKey() Option {
21 | return func(c *JinaEmbeddingFunction) error {
22 | if os.Getenv("JINA_API_KEY") == "" {
23 | return errors.Errorf("JINA_API_KEY not set")
24 | }
25 | c.apiKey = os.Getenv("JINA_API_KEY")
26 | return nil
27 | }
28 | }
29 |
30 | func WithModel(model embeddings.EmbeddingModel) Option {
31 | return func(c *JinaEmbeddingFunction) error {
32 | if model == "" {
33 | return errors.New("model cannot be empty")
34 | }
35 | c.defaultModel = model
36 | return nil
37 | }
38 | }
39 |
40 | func WithEmbeddingEndpoint(endpoint string) Option {
41 | return func(c *JinaEmbeddingFunction) error {
42 | if endpoint == "" {
43 | return errors.New("embedding endpoint cannot be empty")
44 | }
45 | c.embeddingEndpoint = endpoint
46 | return nil
47 | }
48 | }
49 |
50 | // WithNormalized sets the flag to indicate to Jina whether to normalize (L2 norm) the output embeddings or not. Defaults to true
51 | func WithNormalized(normalized bool) Option {
52 | return func(c *JinaEmbeddingFunction) error {
53 | c.normalized = normalized
54 | return nil
55 | }
56 | }
57 |
58 | // WithEmbeddingType sets the type of the embedding to be returned by Jina. The default is float. Right now no other options are supported
59 | func WithEmbeddingType(embeddingType EmbeddingType) Option {
60 | return func(c *JinaEmbeddingFunction) error {
61 | if embeddingType == "" {
62 | return errors.New("embedding type cannot be empty")
63 | }
64 | c.embeddingType = embeddingType
65 | return nil
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/pkg/embeddings/mistral/option.go:
--------------------------------------------------------------------------------
1 | package mistral
2 |
3 | import (
4 | "net/http"
5 | "net/url"
6 | "os"
7 |
8 | "github.com/pkg/errors"
9 | )
10 |
11 | type Option func(p *Client) error
12 |
13 | // WithDefaultModel sets the default model for the client
14 | func WithDefaultModel(model string) Option {
15 | return func(p *Client) error {
16 | if model == "" {
17 | return errors.Errorf("default model cannot be empty")
18 | }
19 | p.DefaultModel = model
20 | return nil
21 | }
22 | }
23 |
24 | // WithAPIKey sets the API key for the client
25 | func WithAPIKey(apiKey string) Option {
26 | return func(p *Client) error {
27 | if apiKey == "" {
28 | return errors.Errorf("api key cannot be empty")
29 | }
30 | p.apiKey = apiKey
31 | return nil
32 | }
33 | }
34 |
35 | // WithEnvAPIKey sets the API key for the client from the environment variable GOOGLE_API_KEY
36 | func WithEnvAPIKey() Option {
37 | return func(p *Client) error {
38 | if apiKey := os.Getenv(APIKeyEnvVar); apiKey != "" {
39 | p.apiKey = apiKey
40 | return nil
41 | }
42 | return errors.Errorf("%s not set", APIKeyEnvVar)
43 | }
44 | }
45 |
46 | // WithHTTPClient sets the generative AI client for the client
47 | func WithHTTPClient(client *http.Client) Option {
48 | return func(p *Client) error {
49 | if client == nil {
50 | return errors.Errorf("http client cannot be nil")
51 | }
52 | p.Client = client
53 | return nil
54 | }
55 | }
56 |
57 | // WithMaxBatchSize sets the max batch size for the client - this acts as a limit for the number of embeddings that can be sent in a single request
58 | func WithMaxBatchSize(maxBatchSize int) Option {
59 | return func(p *Client) error {
60 | if maxBatchSize <= 0 {
61 | return errors.Errorf("max batch size must be greater than 0")
62 | }
63 | p.MaxBatchSize = maxBatchSize
64 | return nil
65 | }
66 | }
67 |
68 | // WithBaseURL sets the base URL for the client
69 | func WithBaseURL(baseURL string) Option {
70 | return func(p *Client) error {
71 | if baseURL == "" {
72 | return errors.Errorf("base URL cannot be empty")
73 | }
74 | var err error
75 | p.EmbeddingEndpoint, err = url.JoinPath(baseURL, EmbeddingsEndpoint)
76 | if err != nil {
77 | return errors.Wrap(err, "failed to parse embedding endpoint")
78 | }
79 | return nil
80 | }
81 | }
82 |
--------------------------------------------------------------------------------
/pkg/embeddings/nomic/option.go:
--------------------------------------------------------------------------------
1 | package nomic
2 |
3 | import (
4 | "net/http"
5 | "net/url"
6 | "os"
7 |
8 | "github.com/pkg/errors"
9 |
10 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
11 | )
12 |
13 | type Option func(p *Client) error
14 |
15 | // WithDefaultModel sets the default model for the client
16 | func WithDefaultModel(model embeddings.EmbeddingModel) Option {
17 | return func(p *Client) error {
18 | if model == "" {
19 | return errors.New("default model cannot be empty")
20 | }
21 | p.DefaultModel = model
22 | return nil
23 | }
24 | }
25 |
26 | // WithAPIKey sets the API key for the client
27 | func WithAPIKey(apiKey string) Option {
28 | return func(p *Client) error {
29 | if apiKey == "" {
30 | return errors.New("api key cannot be empty")
31 | }
32 | p.apiKey = apiKey
33 | return nil
34 | }
35 | }
36 |
37 | // WithEnvAPIKey sets the API key for the client from the environment variable GOOGLE_API_KEY
38 | func WithEnvAPIKey() Option {
39 | return func(p *Client) error {
40 | if apiKey := os.Getenv(APIKeyEnvVar); apiKey != "" {
41 | p.apiKey = apiKey
42 | return nil
43 | }
44 | return errors.Errorf("%s not set", APIKeyEnvVar)
45 | }
46 | }
47 |
48 | // WithHTTPClient sets the generative AI client for the client
49 | func WithHTTPClient(client *http.Client) Option {
50 | return func(p *Client) error {
51 | if client == nil {
52 | return errors.New("http client cannot be nil")
53 | }
54 | p.Client = client
55 | return nil
56 | }
57 | }
58 |
59 | // WithMaxBatchSize sets the max batch size for the client - this acts as a limit for the number of embeddings that can be sent in a single request
60 | func WithMaxBatchSize(maxBatchSize int) Option {
61 | return func(p *Client) error {
62 | if maxBatchSize <= 0 {
63 | return errors.New("max batch size must be greater than 0")
64 | }
65 | p.MaxBatchSize = maxBatchSize
66 | return nil
67 | }
68 | }
69 |
70 | // WithBaseURL sets the base URL for the client
71 | func WithBaseURL(baseURL string) Option {
72 | return func(p *Client) error {
73 | if baseURL == "" {
74 | return errors.New("base URL cannot be empty")
75 | }
76 | if _, err := url.ParseRequestURI(baseURL); err != nil {
77 | return errors.Wrap(err, "invalid basePath URL")
78 | }
79 | p.BaseURL = baseURL
80 | return nil
81 | }
82 | }
83 |
84 | // WithTextEmbeddings sets the endpoint to text embeddings
85 | func WithTextEmbeddings() Option {
86 | return func(p *Client) error {
87 | p.EmbeddingsEndpointSuffix = TextEmbeddingsEndpoint
88 | return nil
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/pkg/embeddings/ollama/ollama.go:
--------------------------------------------------------------------------------
1 | package ollama
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "io"
8 | "net/http"
9 | "net/url"
10 |
11 | "github.com/pkg/errors"
12 |
13 | chttp "github.com/amikos-tech/chroma-go/pkg/commons/http"
14 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
15 | )
16 |
17 | type OllamaClient struct {
18 | BaseURL string
19 | Model embeddings.EmbeddingModel
20 | Client *http.Client
21 | DefaultHeaders map[string]string
22 | }
23 |
24 | type EmbeddingInput struct {
25 | Input string
26 | Inputs []string
27 | }
28 |
29 | func (e EmbeddingInput) MarshalJSON() ([]byte, error) {
30 | if e.Input != "" {
31 | b, err := json.Marshal(e.Input)
32 | if err != nil {
33 | return nil, errors.Wrap(err, "failed to marshal embedding input")
34 | }
35 | return b, nil
36 | } else if len(e.Inputs) > 0 {
37 | b, err := json.Marshal(e.Inputs)
38 | if err != nil {
39 | return nil, errors.Wrap(err, "failed to marshal embedding input")
40 | }
41 | return b, nil
42 | }
43 | return json.Marshal(nil)
44 | }
45 |
46 | type CreateEmbeddingRequest struct {
47 | Model string `json:"model"`
48 | Input *EmbeddingInput `json:"input"`
49 | }
50 |
51 | type CreateEmbeddingResponse struct {
52 | Embeddings [][]float32 `json:"embeddings"`
53 | }
54 |
55 | func (c *CreateEmbeddingRequest) JSON() (string, error) {
56 | data, err := json.Marshal(c)
57 | if err != nil {
58 | return "", errors.Wrap(err, "failed to marshal embedding request JSON")
59 | }
60 | return string(data), nil
61 | }
62 |
63 | func NewOllamaClient(opts ...Option) (*OllamaClient, error) {
64 | client := &OllamaClient{
65 | Client: &http.Client{},
66 | }
67 | for _, opt := range opts {
68 | err := opt(client)
69 | if err != nil {
70 | return nil, errors.Wrap(err, "failed to apply Ollama option")
71 | }
72 | }
73 | return client, nil
74 | }
75 |
76 | func (c *OllamaClient) createEmbedding(ctx context.Context, req *CreateEmbeddingRequest) (*CreateEmbeddingResponse, error) {
77 | reqJSON, err := req.JSON()
78 | if err != nil {
79 | return nil, errors.Wrap(err, "failed to marshal embedding request JSON")
80 | }
81 | endpoint, err := url.JoinPath(c.BaseURL, "/api/embed")
82 | if err != nil {
83 | return nil, errors.Wrap(err, "failed to parse Ollama embedding endpoint")
84 | }
85 |
86 | httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBufferString(reqJSON))
87 | if err != nil {
88 | return nil, errors.Wrap(err, "failed to create HTTP request")
89 | }
90 | for k, v := range c.DefaultHeaders {
91 | httpReq.Header.Set(k, v)
92 | }
93 | httpReq.Header.Set("Accept", "application/json")
94 | httpReq.Header.Set("Content-Type", "application/json")
95 | httpReq.Header.Set("User-Agent", chttp.ChromaGoClientUserAgent)
96 |
97 | resp, err := c.Client.Do(httpReq)
98 | if err != nil {
99 | return nil, errors.Wrap(err, "failed to make HTTP request to Ollama embedding endpoint")
100 | }
101 | defer resp.Body.Close()
102 |
103 | respData, err := io.ReadAll(resp.Body)
104 | if err != nil {
105 | return nil, errors.Wrap(err, "failed to read response body")
106 | }
107 |
108 | if resp.StatusCode != http.StatusOK {
109 | return nil, errors.Errorf("unexpected code [%v] while making a request to %v: %v", resp.Status, endpoint, string(respData))
110 | }
111 |
112 | var embeddingResponse CreateEmbeddingResponse
113 | if err := json.Unmarshal(respData, &embeddingResponse); err != nil {
114 | return nil, errors.Wrap(err, "failed to unmarshal embedding response")
115 | }
116 | return &embeddingResponse, nil
117 | }
118 |
119 | type OllamaEmbeddingFunction struct {
120 | apiClient *OllamaClient
121 | }
122 |
123 | var _ embeddings.EmbeddingFunction = (*OllamaEmbeddingFunction)(nil)
124 |
125 | func NewOllamaEmbeddingFunction(option ...Option) (*OllamaEmbeddingFunction, error) {
126 | client, err := NewOllamaClient(option...)
127 | if err != nil {
128 | return nil, errors.Wrap(err, "failed to initialize OllamaClient")
129 | }
130 | return &OllamaEmbeddingFunction{
131 | apiClient: client,
132 | }, nil
133 | }
134 |
135 | func (e *OllamaEmbeddingFunction) EmbedDocuments(ctx context.Context, documents []string) ([]embeddings.Embedding, error) {
136 | response, err := e.apiClient.createEmbedding(ctx, &CreateEmbeddingRequest{
137 | Model: string(e.apiClient.Model),
138 | Input: &EmbeddingInput{Inputs: documents},
139 | })
140 | if err != nil {
141 | return nil, errors.Wrap(err, "failed to embed documents")
142 | }
143 | return embeddings.NewEmbeddingsFromFloat32(response.Embeddings)
144 | }
145 |
146 | func (e *OllamaEmbeddingFunction) EmbedQuery(ctx context.Context, document string) (embeddings.Embedding, error) {
147 | response, err := e.apiClient.createEmbedding(ctx, &CreateEmbeddingRequest{
148 | Model: string(e.apiClient.Model),
149 | Input: &EmbeddingInput{Input: document},
150 | })
151 | if err != nil {
152 | return nil, errors.Wrap(err, "failed to embed query")
153 | }
154 | return embeddings.NewEmbeddingFromFloat32(response.Embeddings[0]), nil
155 | }
156 |
--------------------------------------------------------------------------------
/pkg/embeddings/ollama/ollama_test.go:
--------------------------------------------------------------------------------
1 | //go:build ef
2 |
3 | package ollama
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
9 | "github.com/stretchr/testify/require"
10 | tcollama "github.com/testcontainers/testcontainers-go/modules/ollama"
11 | "io"
12 | "net/http"
13 | "strings"
14 | "testing"
15 | )
16 |
17 | func Test_ollama(t *testing.T) {
18 | ctx := context.Background()
19 | ollamaContainer, err := tcollama.Run(ctx, "ollama/ollama:latest")
20 | require.NoError(t, err)
21 | // Clean up the container
22 | defer func() {
23 | if err := ollamaContainer.Terminate(ctx); err != nil {
24 | t.Logf("failed to terminate container: %s\n", err)
25 | }
26 | }()
27 |
28 | model := "nomic-embed-text"
29 | connectionStr, err := ollamaContainer.ConnectionString(ctx)
30 | require.NoError(t, err)
31 | pullURL := fmt.Sprintf("%s/api/pull", connectionStr)
32 | pullPayload := fmt.Sprintf(`{"name": "%s"}`, model)
33 |
34 | resp, err := http.Post(
35 | pullURL,
36 | "application/json",
37 | strings.NewReader(pullPayload),
38 | )
39 | require.NoError(t, err)
40 | respStr, err := io.ReadAll(resp.Body)
41 | require.NoError(t, err)
42 | defer resp.Body.Close()
43 | require.Contains(t, string(respStr), "success")
44 |
45 | // Ensure successful response
46 | require.Equal(t, http.StatusOK, resp.StatusCode)
47 | client, err := NewOllamaClient(WithBaseURL(connectionStr), WithModel(embeddings.EmbeddingModel(model)))
48 | require.NoError(t, err)
49 | t.Run("Test Create Embed Single document", func(t *testing.T) {
50 | resp, rerr := client.createEmbedding(context.Background(), &CreateEmbeddingRequest{Model: "nomic-embed-text", Input: &EmbeddingInput{Input: "Document 1 content here"}})
51 | require.Nil(t, rerr)
52 | require.NotNil(t, resp)
53 | })
54 | t.Run("Test Create Embed multi-document", func(t *testing.T) {
55 | documents := []string{
56 | "Document 1 content here",
57 | "Document 2 content here",
58 | }
59 | ef, err := NewOllamaEmbeddingFunction(WithBaseURL(connectionStr), WithModel(embeddings.EmbeddingModel(model)))
60 | require.NoError(t, err)
61 | resp, rerr := ef.EmbedDocuments(context.Background(), documents)
62 | require.Nil(t, rerr)
63 | require.NotNil(t, resp)
64 | require.Equal(t, 2, len(resp))
65 | })
66 | }
67 |
--------------------------------------------------------------------------------
/pkg/embeddings/ollama/option.go:
--------------------------------------------------------------------------------
1 | package ollama
2 |
3 | import (
4 | "net/url"
5 |
6 | "github.com/pkg/errors"
7 |
8 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
9 | )
10 |
11 | type Option func(p *OllamaClient) error
12 |
13 | func WithBaseURL(baseURL string) Option {
14 | return func(p *OllamaClient) error {
15 | if baseURL == "" {
16 | return errors.New("base URL cannot be empty")
17 | }
18 | if _, err := url.ParseRequestURI(baseURL); err != nil {
19 | return errors.Wrap(err, "invalid base URL")
20 | }
21 | p.BaseURL = baseURL
22 | return nil
23 | }
24 | }
25 | func WithModel(model embeddings.EmbeddingModel) Option {
26 | return func(p *OllamaClient) error {
27 | if model == "" {
28 | return errors.New("model cannot be empty")
29 | }
30 | p.Model = model
31 | return nil
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/pkg/embeddings/openai/options.go:
--------------------------------------------------------------------------------
1 | package openai
2 |
3 | import (
4 | "net/url"
5 |
6 | "github.com/pkg/errors"
7 | )
8 |
9 | // Option is a function type that can be used to modify the client.
10 | type Option func(c *OpenAIClient) error
11 |
12 | func WithBaseURL(baseURL string) Option {
13 | return func(p *OpenAIClient) error {
14 | if baseURL == "" {
15 | return errors.New("Base URL cannot be empty")
16 | }
17 | if _, err := url.ParseRequestURI(baseURL); err != nil {
18 | return errors.Wrap(err, "invalid base URL")
19 | }
20 | p.BaseURL = baseURL
21 | return nil
22 | }
23 | }
24 |
25 | // WithOpenAIOrganizationID is an option for setting the OpenAI org id.
26 | func WithOpenAIOrganizationID(orgID string) Option {
27 | return func(c *OpenAIClient) error {
28 | if orgID == "" {
29 | return errors.New("OrgID cannot be empty")
30 | }
31 | c.OrgID = orgID
32 | return nil
33 | }
34 | }
35 |
36 | // WithOpenAIUser is an option for setting the OpenAI user. The user is passed with every request to OpenAI. It serves for auditing purposes. If not set the user defaults to ChromaGo client.
37 | func WithOpenAIUser(user string) Option {
38 | return func(c *OpenAIClient) error {
39 | if user == "" {
40 | return errors.New("User cannot be empty")
41 | }
42 | c.User = user
43 | return nil
44 | }
45 | }
46 |
47 | // WithModel is an option for setting the model to use. Must be one of: text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large
48 | func WithModel(model EmbeddingModel) Option {
49 | return func(c *OpenAIClient) error {
50 | if string(model) == "" {
51 | return errors.New("Model cannot be empty")
52 | }
53 | if model != TextEmbeddingAda002 && model != TextEmbedding3Small && model != TextEmbedding3Large {
54 | return errors.Errorf("invalid model name %s. Must be one of: %v", model, []string{string(TextEmbeddingAda002), string(TextEmbedding3Small), string(TextEmbedding3Large)})
55 | }
56 | c.Model = string(model)
57 | return nil
58 | }
59 | }
60 | func WithDimensions(dimensions int) Option {
61 | return func(c *OpenAIClient) error {
62 | if dimensions <= 0 {
63 | return errors.Errorf("dimensions must be greater than 0, got %d", dimensions)
64 | }
65 | c.Dimensions = &dimensions
66 | return nil
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/pkg/embeddings/together/option.go:
--------------------------------------------------------------------------------
1 | package together
2 |
3 | import (
4 | "net/http"
5 | "os"
6 |
7 | "github.com/pkg/errors"
8 |
9 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
10 | )
11 |
12 | type Option func(p *TogetherAIClient) error
13 |
14 | func WithDefaultModel(model embeddings.EmbeddingModel) Option {
15 | return func(p *TogetherAIClient) error {
16 | if model == "" {
17 | return errors.New("default model cannot be empty")
18 | }
19 | p.DefaultModel = model
20 | return nil
21 | }
22 | }
23 |
24 | func WithMaxBatchSize(size int) Option {
25 | return func(p *TogetherAIClient) error {
26 | if size <= 0 {
27 | return errors.New("max batch size must be greater than 0")
28 | }
29 | p.MaxBatchSize = size
30 | return nil
31 | }
32 | }
33 |
34 | func WithDefaultHeaders(headers map[string]string) Option {
35 | return func(p *TogetherAIClient) error {
36 | p.DefaultHeaders = headers
37 | return nil
38 | }
39 | }
40 |
41 | func WithAPIToken(apiToken string) Option {
42 | return func(p *TogetherAIClient) error {
43 | if apiToken == "" {
44 | return errors.New("API token cannot be empty")
45 | }
46 | p.APIToken = apiToken
47 | return nil
48 | }
49 | }
50 |
51 | func WithEnvAPIKey() Option {
52 | return func(p *TogetherAIClient) error {
53 | if apiToken := os.Getenv("TOGETHER_API_KEY"); apiToken != "" {
54 | p.APIToken = apiToken
55 | return nil
56 | }
57 | return errors.New("TOGETHER_API_KEY not set")
58 | }
59 | }
60 |
61 | func WithHTTPClient(client *http.Client) Option {
62 | return func(p *TogetherAIClient) error {
63 | if client == nil {
64 | return errors.New("HTTP client cannot be nil")
65 | }
66 | p.Client = client
67 | return nil
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/pkg/embeddings/together/together_test.go:
--------------------------------------------------------------------------------
1 | //go:build ef
2 |
3 | package together
4 |
5 | import (
6 | "context"
7 | "net/http"
8 | "os"
9 | "testing"
10 |
11 | "github.com/joho/godotenv"
12 | "github.com/stretchr/testify/assert"
13 | "github.com/stretchr/testify/require"
14 | )
15 |
16 | func Test_client(t *testing.T) {
17 | apiKey := os.Getenv("TOGETHER_API_KEY")
18 | if apiKey == "" {
19 | err := godotenv.Load("../../../.env")
20 | if err != nil {
21 | assert.Failf(t, "Error loading .env file", "%s", err)
22 | }
23 | apiKey = os.Getenv("TOGETHER_API_KEY")
24 | }
25 | client, err := NewTogetherClient(WithEnvAPIKey())
26 | require.NoError(t, err)
27 |
28 | t.Run("Test CreateEmbedding", func(t *testing.T) {
29 | req := CreateEmbeddingRequest{
30 | Model: "togethercomputer/m2-bert-80M-8k-retrieval",
31 | Input: &EmbeddingInputs{Input: "Test document"},
32 | }
33 | resp, rerr := client.CreateEmbedding(context.Background(), &req)
34 |
35 | require.Nil(t, rerr)
36 | require.NotNil(t, resp)
37 | require.NotNil(t, resp.Data)
38 | require.Len(t, resp.Data, 1)
39 | })
40 | }
41 |
42 | func Test_together_embedding_function(t *testing.T) {
43 | apiKey := os.Getenv("TOGETHER_API_KEY")
44 | if apiKey == "" {
45 | err := godotenv.Load("../../../.env")
46 | if err != nil {
47 | assert.Failf(t, "Error loading .env file", "%s", err)
48 | }
49 | }
50 |
51 | t.Run("Test EmbedDocuments with env-based API Key", func(t *testing.T) {
52 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey())
53 | require.NoError(t, err)
54 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
55 |
56 | require.Nil(t, rerr)
57 | require.NotNil(t, resp)
58 | require.Len(t, resp, 2)
59 | require.Equal(t, 768, resp[0].Len())
60 |
61 | })
62 |
63 | t.Run("Test EmbedDocuments for model with env-based API Key", func(t *testing.T) {
64 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel("togethercomputer/m2-bert-80M-2k-retrieval"))
65 | require.NoError(t, err)
66 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
67 |
68 | require.Nil(t, rerr)
69 | require.NotNil(t, resp)
70 | require.Len(t, resp, 2)
71 | require.Equal(t, 768, resp[0].Len())
72 | })
73 |
74 | t.Run("Test EmbedDocuments with too large init batch", func(t *testing.T) {
75 | _, err := NewTogetherEmbeddingFunction(WithEnvAPIKey(), WithMaxBatchSize(200))
76 | require.Error(t, err)
77 | require.Contains(t, err.Error(), "max batch size must be less than")
78 | })
79 |
80 | t.Run("Test EmbedDocuments with too large batch at inference", func(t *testing.T) {
81 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey())
82 | require.NoError(t, err)
83 | docs200 := make([]string, 200)
84 | for i := 0; i < 200; i++ {
85 | docs200[i] = "Test document"
86 | }
87 | _, err = client.EmbedDocuments(context.Background(), docs200)
88 | require.Error(t, err)
89 | require.Contains(t, err.Error(), "number of documents exceeds the maximum batch")
90 | })
91 |
92 | t.Run("Test EmbedQuery", func(t *testing.T) {
93 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey())
94 | require.NoError(t, err)
95 | resp, err := client.EmbedQuery(context.Background(), "Test query")
96 | require.Nil(t, err)
97 | require.NotNil(t, resp)
98 | require.Equal(t, 768, resp.Len())
99 | })
100 |
101 | t.Run("Test EmbedDocuments with env-based API Key and WithDefaultHeaders", func(t *testing.T) {
102 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel("togethercomputer/m2-bert-80M-2k-retrieval"), WithDefaultHeaders(map[string]string{"X-Test-Header": "test"}))
103 | require.NoError(t, err)
104 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
105 |
106 | require.Nil(t, rerr)
107 | require.NotNil(t, resp)
108 | require.Len(t, resp, 2)
109 | require.Equal(t, 768, resp[0].Len())
110 | })
111 |
112 | t.Run("Test EmbedDocuments with var API Key", func(t *testing.T) {
113 | client, err := NewTogetherEmbeddingFunction(WithAPIToken(os.Getenv("TOGETHER_API_KEY")))
114 | require.NoError(t, err)
115 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
116 |
117 | require.Nil(t, rerr)
118 | require.NotNil(t, resp)
119 | require.Len(t, resp, 2)
120 | require.Equal(t, 768, resp[0].Len())
121 | })
122 |
123 | t.Run("Test EmbedDocuments with var token and account id and http client", func(t *testing.T) {
124 | client, err := NewTogetherEmbeddingFunction(WithAPIToken(os.Getenv("TOGETHER_API_KEY")), WithHTTPClient(http.DefaultClient))
125 | require.NoError(t, err)
126 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"})
127 |
128 | require.Nil(t, rerr)
129 | require.NotNil(t, resp)
130 | require.Equal(t, 2, len(resp))
131 | require.Equal(t, 768, resp[0].Len())
132 | })
133 | }
134 |
--------------------------------------------------------------------------------
/pkg/embeddings/voyage/option.go:
--------------------------------------------------------------------------------
1 | package voyage
2 |
3 | import (
4 | "net/http"
5 | "os"
6 |
7 | "github.com/pkg/errors"
8 |
9 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
10 | )
11 |
12 | type Option func(p *VoyageAIClient) error
13 |
14 | func WithDefaultModel(model embeddings.EmbeddingModel) Option {
15 | return func(p *VoyageAIClient) error {
16 | if model == "" {
17 | return errors.New("model cannot be empty")
18 | }
19 | p.DefaultModel = model
20 | return nil
21 | }
22 | }
23 |
24 | func WithMaxBatchSize(size int) Option {
25 | return func(p *VoyageAIClient) error {
26 | if size <= 0 {
27 | return errors.New("max batch size must be greater than 0")
28 | }
29 | p.MaxBatchSize = size
30 | return nil
31 | }
32 | }
33 |
34 | func WithDefaultHeaders(headers map[string]string) Option {
35 | return func(p *VoyageAIClient) error {
36 | p.DefaultHeaders = headers
37 | return nil
38 | }
39 | }
40 |
41 | func WithAPIKey(apiToken string) Option {
42 | return func(p *VoyageAIClient) error {
43 | if apiToken == "" {
44 | return errors.New("API key cannot be empty")
45 | }
46 | p.APIKey = apiToken
47 | return nil
48 | }
49 | }
50 |
51 | func WithEnvAPIKey() Option {
52 | return func(p *VoyageAIClient) error {
53 | if apiToken := os.Getenv(APIKeyEnvVar); apiToken != "" {
54 | p.APIKey = apiToken
55 | return nil
56 | }
57 | return errors.Errorf("%s not set", APIKeyEnvVar)
58 | }
59 | }
60 |
61 | func WithHTTPClient(client *http.Client) Option {
62 | return func(p *VoyageAIClient) error {
63 | if client == nil {
64 | return errors.New("HTTP client cannot be nil")
65 | }
66 | p.Client = client
67 | return nil
68 | }
69 | }
70 |
71 | func WithTruncation(truncation bool) Option {
72 | return func(p *VoyageAIClient) error {
73 | p.DefaultTruncation = &truncation
74 | return nil
75 | }
76 | }
77 |
78 | func WithEncodingFormat(format EncodingFormat) Option {
79 | return func(p *VoyageAIClient) error {
80 | if format == "" {
81 | return errors.New("encoding format cannot be empty")
82 | }
83 | var defaultEncodingFormat = format
84 | p.DefaultEncodingFormat = &defaultEncodingFormat
85 | return nil
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/pkg/rerankings/cohere/option.go:
--------------------------------------------------------------------------------
1 | package cohere
2 |
3 | import (
4 | ccommons "github.com/amikos-tech/chroma-go/pkg/commons/cohere"
5 | httpc "github.com/amikos-tech/chroma-go/pkg/commons/http"
6 | "github.com/amikos-tech/chroma-go/pkg/embeddings"
7 | "github.com/amikos-tech/chroma-go/pkg/rerankings"
8 | )
9 |
10 | type Option func(p *CohereRerankingFunction) ccommons.Option
11 |
12 | func WithBaseURL(baseURL string) Option {
13 | return func(p *CohereRerankingFunction) ccommons.Option {
14 | return ccommons.WithBaseURL(baseURL)
15 | }
16 | }
17 |
18 | func WithDefaultModel(model rerankings.RerankingModel) Option {
19 | return func(p *CohereRerankingFunction) ccommons.Option {
20 | return ccommons.WithDefaultModel(embeddings.EmbeddingModel(model))
21 | }
22 | }
23 |
24 | func WithAPIKey(apiKey string) Option {
25 | return func(p *CohereRerankingFunction) ccommons.Option {
26 | return ccommons.WithAPIKey(apiKey)
27 | }
28 | }
29 |
30 | // WithEnvAPIKey configures the client to use the COHERE_API_KEY environment variable as the API key
31 | func WithEnvAPIKey() Option {
32 | return func(p *CohereRerankingFunction) ccommons.Option {
33 | return ccommons.WithEnvAPIKey()
34 | }
35 | }
36 |
37 | func WithTopN(topN int) Option {
38 | return func(p *CohereRerankingFunction) ccommons.Option {
39 | p.TopN = topN
40 | return ccommons.NoOp()
41 | }
42 | }
43 |
44 | // WithRerankFields configures the client to use the specified fields for reranking if the documents are in JSON format
45 | func WithRerankFields(fields []string) Option {
46 | return func(p *CohereRerankingFunction) ccommons.Option {
47 | p.RerankFields = fields
48 | return ccommons.NoOp()
49 | }
50 | }
51 |
52 | // WithReturnDocuments configures the client to return the original documents in the response
53 | func WithReturnDocuments() Option {
54 | return func(p *CohereRerankingFunction) ccommons.Option {
55 | p.ReturnDocuments = true
56 | return ccommons.NoOp()
57 | }
58 | }
59 |
60 | func WithMaxChunksPerDoc(maxChunks int) Option {
61 | return func(p *CohereRerankingFunction) ccommons.Option {
62 | p.MaxChunksPerDoc = maxChunks
63 | return ccommons.NoOp()
64 | }
65 | }
66 |
67 | // WithRetryStrategy configures the client to use the specified retry strategy
68 | func WithRetryStrategy(retryStrategy httpc.RetryStrategy) Option {
69 | return func(p *CohereRerankingFunction) ccommons.Option {
70 | return ccommons.WithRetryStrategy(retryStrategy)
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/pkg/rerankings/hf/huggingface.go:
--------------------------------------------------------------------------------
1 | package huggingface
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "io"
9 | "net/http"
10 |
11 | chromago "github.com/amikos-tech/chroma-go"
12 | chttp "github.com/amikos-tech/chroma-go/pkg/commons/http"
13 | "github.com/amikos-tech/chroma-go/pkg/rerankings"
14 | )
15 |
// DefaultBaseAPIEndpoint is the rerank URL a new HFRerankingFunction posts to unless
// another endpoint is configured.
const DefaultBaseAPIEndpoint = "http://127.0.0.1:8080/rerank"

type (
	// RerankingRequest is the JSON body POSTed to the rerank endpoint.
	RerankingRequest struct {
		Model *string  `json:"model,omitempty"` // omitted entirely when nil
		Query string   `json:"query"`
		Texts []string `json:"texts"`
	}

	// RerankingResponse is the JSON array returned by the rerank endpoint. Each element
	// pairs a position in RerankingRequest.Texts (Index) with its score.
	RerankingResponse []struct {
		Index int     `json:"index"`
		Score float32 `json:"score"`
	}
)
30 |
31 | var _ rerankings.RerankingFunction = (*HFRerankingFunction)(nil)
32 |
33 | func getDefaults() *HFRerankingFunction {
34 | return &HFRerankingFunction{
35 | httpClient: http.DefaultClient,
36 | rerankingEndpoint: DefaultBaseAPIEndpoint,
37 | }
38 | }
39 |
40 | type HFRerankingFunction struct {
41 | httpClient *http.Client
42 | apiKey string
43 | defaultModel *rerankings.RerankingModel
44 | rerankingEndpoint string
45 | }
46 |
47 | func NewHFRerankingFunction(opts ...Option) (*HFRerankingFunction, error) {
48 | ef := getDefaults()
49 | for _, opt := range opts {
50 | err := opt(ef)
51 | if err != nil {
52 | return nil, err
53 | }
54 | }
55 | return ef, nil
56 | }
57 |
58 | func (r *HFRerankingFunction) sendRequest(ctx context.Context, req *RerankingRequest) (*RerankingResponse, error) {
59 | payload, err := json.Marshal(req)
60 | if err != nil {
61 | return nil, err
62 | }
63 |
64 | httpReq, err := http.NewRequest("POST", r.rerankingEndpoint, bytes.NewBuffer(payload))
65 | if err != nil {
66 | return nil, err
67 | }
68 |
69 | httpReq.Header.Set("Content-Type", "application/json")
70 | httpReq.Header.Set("Accept", "application/json")
71 | httpReq.Header.Set("User-Agent", chttp.ChromaGoClientUserAgent)
72 | if r.apiKey != "" {
73 | httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", r.apiKey))
74 | }
75 |
76 | resp, err := r.httpClient.Do(httpReq)
77 | if err != nil {
78 | return nil, err
79 | }
80 | defer resp.Body.Close()
81 | respData, err := io.ReadAll(resp.Body)
82 | if err != nil {
83 | return nil, err
84 | }
85 |
86 | if resp.StatusCode != http.StatusOK {
87 | // TODO serialize body in error
88 | return nil, fmt.Errorf("unexpected response %v: %s", resp.Status, respData)
89 | }
90 | var response *RerankingResponse
91 | if err := json.Unmarshal(respData, &response); err != nil {
92 | return nil, err
93 | }
94 |
95 | return response, nil
96 | }
97 |
98 | func (r *HFRerankingFunction) Rerank(ctx context.Context, query string, results []rerankings.Result) (map[string][]rerankings.RankedResult, error) {
99 | docs := make([]string, 0)
100 | for _, result := range results {
101 | d, err := result.ToText()
102 | if err != nil {
103 | return nil, err
104 | }
105 | docs = append(docs, d)
106 | }
107 | req := &RerankingRequest{
108 | Model: (*string)(r.defaultModel),
109 | Texts: docs,
110 | Query: query,
111 | }
112 |
113 | rerankResp, err := r.sendRequest(ctx, req)
114 | if err != nil {
115 | return nil, err
116 | }
117 |
118 | rankedResults := map[string][]rerankings.RankedResult{r.ID(): make([]rerankings.RankedResult, len(*rerankResp))}
119 | for i, rr := range *rerankResp {
120 | originalDoc, err := results[rr.Index].ToText()
121 | if err != nil {
122 | return nil, err
123 | }
124 |
125 | rankedResults[r.ID()][i] = rerankings.RankedResult{
126 | String: originalDoc,
127 | Index: rr.Index,
128 | Rank: rr.Score,
129 | }
130 | }
131 | return rankedResults, nil
132 | }
133 |
134 | // ID returns the of the reranking function. We use `cohere-` prefix with the default model
135 | func (r *HFRerankingFunction) ID() string {
136 | if r.defaultModel != nil {
137 | return fmt.Sprintf("hf-%s", *(*string)(r.defaultModel))
138 | }
139 | return "hfei"
140 | }
141 |
142 | func (r *HFRerankingFunction) RerankResults(ctx context.Context, queryResults *chromago.QueryResults) (*rerankings.RerankedChromaResults, error) {
143 | rerankedResults := &rerankings.RerankedChromaResults{
144 | QueryResults: *queryResults,
145 | Ranks: map[string][][]float32{r.ID(): make([][]float32, len(queryResults.Ids))},
146 | }
147 | for i, rs := range queryResults.Ids {
148 | if len(rs) == 0 {
149 | return nil, fmt.Errorf("no results to rerank")
150 | }
151 | docs := make([]string, 0)
152 | docs = append(docs, queryResults.Documents[i]...)
153 | req := &RerankingRequest{
154 | Model: (*string)(r.defaultModel),
155 | Texts: docs,
156 | Query: queryResults.QueryTexts[i],
157 | }
158 | rerankResp, err := r.sendRequest(ctx, req)
159 | if err != nil {
160 | return nil, err
161 | }
162 | rerankedResults.Ranks[r.ID()][i] = make([]float32, len(*rerankResp))
163 | for _, rr := range *rerankResp {
164 | rerankedResults.Ranks[r.ID()][i][rr.Index] = rr.Score
165 | }
166 | }
167 | return rerankedResults, nil
168 | }
169 |
--------------------------------------------------------------------------------
/pkg/rerankings/hf/option.go:
--------------------------------------------------------------------------------
1 | package huggingface
2 |
3 | import (
4 | "fmt"
5 | "os"
6 |
7 | "github.com/amikos-tech/chroma-go/pkg/rerankings"
8 | )
9 |
10 | type Option func(c *HFRerankingFunction) error
11 |
12 | func WithAPIKey(apiKey string) Option {
13 | return func(c *HFRerankingFunction) error {
14 | c.apiKey = apiKey
15 | return nil
16 | }
17 | }
18 |
19 | func WithEnvAPIKey() Option {
20 | return func(c *HFRerankingFunction) error {
21 | if os.Getenv("HF_API_KEY") == "" {
22 | return fmt.Errorf("HF_API_KEY not set")
23 | }
24 | c.apiKey = os.Getenv("HF_API_KEY")
25 | return nil
26 | }
27 | }
28 |
29 | func WithModel(model rerankings.RerankingModel) Option {
30 | return func(c *HFRerankingFunction) error {
31 | c.defaultModel = &model
32 | return nil
33 | }
34 | }
35 |
36 | func WithRerankingEndpoint(endpoint string) Option {
37 | return func(c *HFRerankingFunction) error {
38 | c.rerankingEndpoint = endpoint
39 | return nil
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/pkg/rerankings/jina/option.go:
--------------------------------------------------------------------------------
1 | package jina
2 |
3 | import (
4 | "fmt"
5 | "os"
6 |
7 | "github.com/amikos-tech/chroma-go/types"
8 | )
9 |
10 | type Option func(c *JinaRerankingFunction) error
11 |
12 | func WithAPIKey(apiKey string) Option {
13 | return func(c *JinaRerankingFunction) error {
14 | c.apiKey = apiKey
15 | return nil
16 | }
17 | }
18 |
19 | func WithEnvAPIKey() Option {
20 | return func(c *JinaRerankingFunction) error {
21 | if os.Getenv("JINA_API_KEY") == "" {
22 | return fmt.Errorf("JINA_API_KEY not set")
23 | }
24 | c.apiKey = os.Getenv("JINA_API_KEY")
25 | return nil
26 | }
27 | }
28 |
29 | func WithModel(model types.RerankingModel) Option {
30 | return func(c *JinaRerankingFunction) error {
31 | c.defaultModel = model
32 | return nil
33 | }
34 | }
35 |
36 | func WithRerankingEndpoint(endpoint string) Option {
37 | return func(c *JinaRerankingFunction) error {
38 | c.rerankingEndpoint = endpoint
39 | return nil
40 | }
41 | }
42 |
43 | func WithTopN(topN int) Option {
44 | return func(c *JinaRerankingFunction) error {
45 | if topN <= 0 {
46 | return fmt.Errorf("topN must be a positive integer")
47 | }
48 | c.topN = &topN
49 | return nil
50 | }
51 | }
52 |
53 | func WithReturnDocuments(returnDocuments bool) Option {
54 | return func(c *JinaRerankingFunction) error {
55 | c.returnDocuments = &returnDocuments
56 | return nil
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/pkg/rerankings/reranking.go:
--------------------------------------------------------------------------------
1 | package rerankings
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 |
8 | chromago "github.com/amikos-tech/chroma-go"
9 | )
10 |
11 | type RerankingModel string
12 |
13 | type RankedResult struct {
14 | Index int // Index in the original input []string
15 | String string
16 | Rank float32
17 | }
18 |
19 | type RerankedChromaResults struct {
20 | chromago.QueryResults
21 | Ranks map[string][][]float32 // each reranker adds a rank for each result
22 | }
23 |
// Result is a reranking input: either plain text or an arbitrary object that is
// serialized to JSON when its text form is needed. Exactly one field is expected to be
// set; Text takes precedence if both are.
type Result struct {
	Text   *string
	Object *any
}

// FromText wraps text as a Result.
func FromText(text string) Result {
	return Result{Text: &text}
}

// FromTexts wraps each string as a Result, preserving order.
func FromTexts(texts []string) []Result {
	out := make([]Result, len(texts))
	for i := range texts {
		out[i] = FromText(texts[i])
	}
	return out
}

// FromObject wraps an arbitrary value as a Result.
func FromObject(object any) Result {
	return Result{Object: &object}
}

// FromObjects wraps each value as a Result, preserving order.
func FromObjects(objects []any) []Result {
	out := make([]Result, len(objects))
	for i := range objects {
		out[i] = FromObject(objects[i])
	}
	return out
}

// ToText returns the text of a text Result, or the JSON encoding of an object Result.
// It errors when neither field is set or when JSON encoding fails.
func (r *Result) ToText() (string, error) {
	switch {
	case r.IsText():
		return *r.Text, nil
	case r.IsObject():
		encoded, err := json.Marshal(r.Object)
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	default:
		return "", fmt.Errorf("result is neither text nor object")
	}
}

// IsText reports whether the Result carries text.
func (r *Result) IsText() bool {
	return r.Text != nil
}

// IsObject reports whether the Result carries an object.
func (r *Result) IsObject() bool {
	return r.Object != nil
}
77 |
78 | type RerankingFunction interface {
79 | ID() string
80 | Rerank(ctx context.Context, query string, results []Result) (map[string][]RankedResult, error)
81 | RerankResults(ctx context.Context, queryResults *chromago.QueryResults) (*RerankedChromaResults, error)
82 | }
83 |
--------------------------------------------------------------------------------
/pkg/rerankings/reranking_test.go:
--------------------------------------------------------------------------------
1 | //go:build rf
2 |
3 | package rerankings
4 |
5 | import (
6 | "context"
7 | "fmt"
8 | "math/rand"
9 | "testing"
10 |
11 | "github.com/stretchr/testify/require"
12 |
13 | chromago "github.com/amikos-tech/chroma-go"
14 | )
15 |
16 | type DummyRerankingFunction struct {
17 | }
18 |
19 | func NewDummyRerankingFunction() *DummyRerankingFunction {
20 | return &DummyRerankingFunction{}
21 | }
22 |
23 | func (d *DummyRerankingFunction) ID() string {
24 | return "dummy"
25 | }
26 | func (d *DummyRerankingFunction) Rerank(_ context.Context, _ string, results []Result) (map[string][]RankedResult, error) {
27 | if len(results) == 0 {
28 | return nil, fmt.Errorf("no results to rerank")
29 | }
30 | rerankedResults := make([]RankedResult, len(results))
31 | for i, result := range results {
32 | doc, err := result.ToText()
33 | if err != nil {
34 | return nil, err
35 | }
36 | rerankedResults[i] = RankedResult{
37 | String: doc,
38 | Index: i,
39 | Rank: rand.Float32(),
40 | }
41 | }
42 | return map[string][]RankedResult{d.ID(): rerankedResults}, nil
43 | }
44 |
45 | func (d *DummyRerankingFunction) RerankResults(_ context.Context, queryResults *chromago.QueryResults) (*RerankedChromaResults, error) {
46 | if len(queryResults.Ids) == 0 {
47 | return nil, fmt.Errorf("no results to rerank")
48 | }
49 | results := &RerankedChromaResults{
50 | QueryResults: *queryResults,
51 | Ranks: map[string][][]float32{d.ID(): make([][]float32, len(queryResults.Ids))},
52 | }
53 | for i, qr := range queryResults.Ids {
54 | results.Ranks[d.ID()][i] = make([]float32, len(qr))
55 | for j := range qr {
56 | results.Ranks[d.ID()][i][j] = rand.Float32()
57 | }
58 | }
59 | return results, nil
60 | }
61 |
62 | func Test_reranking_function(t *testing.T) {
63 | rerankingFunction := NewDummyRerankingFunction()
64 | t.Run("Rerank string results", func(t *testing.T) {
65 | query := "hello world"
66 | results := []string{"hello", "world"}
67 | rerankedResults, err := rerankingFunction.Rerank(context.Background(), query, FromTexts(results))
68 | require.NoError(t, err)
69 | require.NotNil(t, rerankedResults)
70 | require.Contains(t, rerankedResults, rerankingFunction.ID())
71 | require.Equal(t, len(results), len(rerankedResults[rerankingFunction.ID()]))
72 | for _, result := range rerankedResults[rerankingFunction.ID()] {
73 | require.Equal(t, results[result.Index], result.String)
74 | }
75 | })
76 |
77 | t.Run("Rerank chroma results", func(t *testing.T) {
78 | query := "hello world"
79 | results := &chromago.QueryResults{
80 | Ids: [][]string{{"1"}, {"2"}},
81 | Documents: [][]string{{"hello"}, {"world"}},
82 | Distances: [][]float32{{0.1}, {0.2}},
83 | QueryTexts: []string{query},
84 | }
85 | rerankedResults, err := rerankingFunction.RerankResults(context.Background(), results)
86 | require.NoError(t, err)
87 | require.NotNil(t, rerankedResults)
88 | require.Contains(t, rerankedResults.Ranks, rerankingFunction.ID())
89 | require.Equal(t, len(results.Ids), len(rerankedResults.Ids))
90 | require.Equal(t, results.Ids, rerankedResults.Ids)
91 | require.Equal(t, results.Documents, rerankedResults.Documents)
92 | require.Equal(t, results.QueryTexts, rerankedResults.QueryTexts)
93 | })
94 | }
95 |
--------------------------------------------------------------------------------
/pkg/tokenizers/libtokenizers/tokenizers.h:
--------------------------------------------------------------------------------
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

// Options passed to encode(). bool comes from <stdbool.h>, uint32_t from <stdint.h> and
// size_t (used by struct Buffer) from <stddef.h>. The original #include lines had lost
// their header names and would not compile.
// NOTE(review): the field names suggest each return_* flag asks encode() to fill the
// matching struct Buffer array; confirm against the Rust implementation.
struct EncodeOptions {
  bool add_special_token;
  bool return_type_ids;
  bool return_tokens;
  bool return_special_tokens_mask;
  bool return_attention_mask;
  bool return_offsets;
};
12 |
// Options passed to from_bytes() when creating a tokenizer.
// NOTE(review): encode_special_tokens presumably controls whether special-token text in
// the input is recognised as special tokens; confirm against the Rust implementation.
struct TokenizerOptions {
  bool encode_special_tokens;
};
16 |
// Result of encode(), returned by value. Field order must match the Rust side of this
// FFI boundary exactly.
// NOTE(review): each pointer field is presumably an array of `len` entries (and possibly
// NULL when not requested via struct EncodeOptions); confirm. Release with free_buffer().
struct Buffer {
  uint32_t *ids;
  uint32_t *type_ids;
  uint32_t *special_tokens_mask;
  uint32_t *attention_mask;
  char *tokens;    // NOTE(review): layout of the token strings behind this pointer not visible here
  size_t *offsets; // NOTE(review): presumably start/end pairs per token; confirm
  uint32_t len;
};
26 |
// Creates a tokenizer from `len` bytes of in-memory config and returns an opaque handle.
// NOTE(review): presumably NULL on failure and released with free_tokenizer(); confirm.
void *from_bytes(const uint8_t *config, uint32_t len, const struct TokenizerOptions *options);

// Like from_bytes(), additionally configuring truncation to max_len.
// NOTE(review): the accepted `direction` values are not visible here; confirm.
void *from_bytes_with_truncation(const uint8_t *config, uint32_t len, uint32_t max_len, uint8_t direction);

// Creates a tokenizer from `config`, presumably a NUL-terminated file path; returns an opaque handle.
void *from_file(const char *config);

// Encodes the NUL-terminated `message` with the tokenizer `ptr`; release the result with free_buffer().
struct Buffer encode(void *ptr, const char *message, const struct EncodeOptions *options);

// Decodes `len` token ids back to text.
// NOTE(review): the returned string is presumably heap-allocated and released with free_string(); confirm.
char *decode(void *ptr, const uint32_t *ids, uint32_t len, bool skip_special_tokens);

// Returns the vocabulary size of the tokenizer `ptr`.
uint32_t vocab_size(void *ptr);

// Releases a handle returned by from_bytes / from_bytes_with_truncation / from_file.
void free_tokenizer(void *ptr);

// Releases the arrays owned by a Buffer returned from encode().
void free_buffer(struct Buffer buffer);

// Releases a string returned by this library (e.g. from decode()).
void free_string(char *string);
44 |
--------------------------------------------------------------------------------
/scripts/chroma_server.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Starts a throwaway Chroma server on localhost:8000 with ALLOW_RESET=TRUE so tests can
# reset the database. The container is removed on exit (--rm).
# (The previous shebang pointed at /usr/scripts/env, which does not exist.)

docker run --rm -it -p 8000:8000 -e ALLOW_RESET=TRUE chromadb/chroma:latest
4 |
--------------------------------------------------------------------------------
/swagger/model_collection.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | )
16 |
// checks if the Collection type satisfies the MappedNullable interface at compile time
var _ MappedNullable = &Collection{}

// Collection struct for Collection
//
// Generated by OpenAPI Generator from the ChromaDB schema; regenerate rather than editing
// by hand. Name and Id are always serialized. Metadata is a pointer so that an unset
// value is left out of the JSON output (see ToMap).
type Collection struct {
	Name     string               `json:"name"`
	Id       string               `json:"id"`
	Metadata *map[string]Metadata `json:"metadata,omitempty"`
}

// NewCollection instantiates a new Collection object
// This constructor will assign default values to properties that have it defined,
// and makes sure properties required by API are set, but the set of arguments
// will change when the set of required properties is changed
func NewCollection(name string, id string) *Collection {
	this := Collection{}
	this.Name = name
	this.Id = id
	return &this
}

// NewCollectionWithDefaults instantiates a new Collection object
// This constructor will only assign default values to properties that have it defined,
// but it doesn't guarantee that properties required by API are set
func NewCollectionWithDefaults() *Collection {
	this := Collection{}
	return &this
}

// GetName returns the Name field value
// (the zero value when called on a nil *Collection).
func (o *Collection) GetName() string {
	if o == nil {
		var ret string
		return ret
	}

	return o.Name
}

// GetNameOk returns a tuple with the Name field value
// and a boolean to check if the value has been set.
func (o *Collection) GetNameOk() (*string, bool) {
	if o == nil {
		return nil, false
	}
	return &o.Name, true
}

// SetName sets field value
func (o *Collection) SetName(v string) {
	o.Name = v
}

// GetId returns the Id field value
// (the zero value when called on a nil *Collection).
func (o *Collection) GetId() string {
	if o == nil {
		var ret string
		return ret
	}

	return o.Id
}

// GetIdOk returns a tuple with the Id field value
// and a boolean to check if the value has been set.
func (o *Collection) GetIdOk() (*string, bool) {
	if o == nil {
		return nil, false
	}
	return &o.Id, true
}

// SetId sets field value
func (o *Collection) SetId(v string) {
	o.Id = v
}

// GetMetadata returns the Metadata field value if set, zero value otherwise.
func (o *Collection) GetMetadata() map[string]Metadata {
	if o == nil || IsNil(o.Metadata) {
		var ret map[string]Metadata
		return ret
	}
	return *o.Metadata
}

// GetMetadataOk returns a tuple with the Metadata field value if set, nil otherwise
// and a boolean to check if the value has been set.
func (o *Collection) GetMetadataOk() (*map[string]Metadata, bool) {
	if o == nil || IsNil(o.Metadata) {
		return nil, false
	}
	return o.Metadata, true
}

// HasMetadata returns a boolean if a field has been set.
func (o *Collection) HasMetadata() bool {
	if o != nil && !IsNil(o.Metadata) {
		return true
	}

	return false
}

// SetMetadata gets a reference to the given map[string]Metadata and assigns it to the Metadata field.
func (o *Collection) SetMetadata(v map[string]Metadata) {
	o.Metadata = &v
}

// MarshalJSON serializes the Collection through ToMap, so metadata is omitted when unset.
func (o Collection) MarshalJSON() ([]byte, error) {
	toSerialize, err := o.ToMap()
	if err != nil {
		return []byte{}, err
	}
	return json.Marshal(toSerialize)
}

// ToMap returns the JSON field map of the Collection: "name" and "id" are always present,
// "metadata" only when set. This generated implementation never returns an error.
func (o Collection) ToMap() (map[string]interface{}, error) {
	toSerialize := map[string]interface{}{}
	toSerialize["name"] = o.Name
	toSerialize["id"] = o.Id
	if !IsNil(o.Metadata) {
		toSerialize["metadata"] = o.Metadata
	}
	return toSerialize, nil
}
143 |
// NullableCollection wraps a *Collection together with an explicit isSet flag, so a
// value that was set (even to nil) can be told apart from one that was never set.
type NullableCollection struct {
	value *Collection
	isSet bool
}

// Get returns the wrapped value (nil when unset or set to nil).
func (v NullableCollection) Get() *Collection {
	return v.value
}

// Set stores val and marks the value as set.
func (v *NullableCollection) Set(val *Collection) {
	v.value = val
	v.isSet = true
}

// IsSet reports whether Set, NewNullableCollection or UnmarshalJSON has set the value.
func (v NullableCollection) IsSet() bool {
	return v.isSet
}

// Unset clears the value and marks it as not set.
func (v *NullableCollection) Unset() {
	v.value = nil
	v.isSet = false
}

// NewNullableCollection returns a wrapper holding val, already marked as set.
func NewNullableCollection(val *Collection) *NullableCollection {
	return &NullableCollection{value: val, isSet: true}
}

// MarshalJSON encodes the wrapped value; a nil value encodes as JSON null.
func (v NullableCollection) MarshalJSON() ([]byte, error) {
	return json.Marshal(v.value)
}

// UnmarshalJSON marks the value as set and decodes src into it (JSON null leaves it nil).
func (v *NullableCollection) UnmarshalJSON(src []byte) error {
	v.isSet = true
	return json.Unmarshal(src, &v.value)
}
179 |
--------------------------------------------------------------------------------
/swagger/model_create_database.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | )
16 |
// checks if the CreateDatabase type satisfies the MappedNullable interface at compile time
var _ MappedNullable = &CreateDatabase{}

// CreateDatabase struct for CreateDatabase
//
// Generated by OpenAPI Generator from the ChromaDB schema; regenerate rather than editing
// by hand. Name is always serialized.
type CreateDatabase struct {
	Name string `json:"name"`
}

// NewCreateDatabase instantiates a new CreateDatabase object
// This constructor will assign default values to properties that have it defined,
// and makes sure properties required by API are set, but the set of arguments
// will change when the set of required properties is changed
func NewCreateDatabase(name string) *CreateDatabase {
	this := CreateDatabase{}
	this.Name = name
	return &this
}

// NewCreateDatabaseWithDefaults instantiates a new CreateDatabase object
// This constructor will only assign default values to properties that have it defined,
// but it doesn't guarantee that properties required by API are set
func NewCreateDatabaseWithDefaults() *CreateDatabase {
	this := CreateDatabase{}
	return &this
}

// GetName returns the Name field value
// (the zero value when called on a nil *CreateDatabase).
func (o *CreateDatabase) GetName() string {
	if o == nil {
		var ret string
		return ret
	}

	return o.Name
}

// GetNameOk returns a tuple with the Name field value
// and a boolean to check if the value has been set.
func (o *CreateDatabase) GetNameOk() (*string, bool) {
	if o == nil {
		return nil, false
	}
	return &o.Name, true
}

// SetName sets field value
func (o *CreateDatabase) SetName(v string) {
	o.Name = v
}

// MarshalJSON serializes the CreateDatabase through ToMap.
func (o CreateDatabase) MarshalJSON() ([]byte, error) {
	toSerialize, err := o.ToMap()
	if err != nil {
		return []byte{}, err
	}
	return json.Marshal(toSerialize)
}

// ToMap returns the JSON field map of the CreateDatabase ("name" only). This generated
// implementation never returns an error.
func (o CreateDatabase) ToMap() (map[string]interface{}, error) {
	toSerialize := map[string]interface{}{}
	toSerialize["name"] = o.Name
	return toSerialize, nil
}
80 |
81 | type NullableCreateDatabase struct {
82 | value *CreateDatabase
83 | isSet bool
84 | }
85 |
86 | func (v NullableCreateDatabase) Get() *CreateDatabase {
87 | return v.value
88 | }
89 |
90 | func (v *NullableCreateDatabase) Set(val *CreateDatabase) {
91 | v.value = val
92 | v.isSet = true
93 | }
94 |
95 | func (v NullableCreateDatabase) IsSet() bool {
96 | return v.isSet
97 | }
98 |
99 | func (v *NullableCreateDatabase) Unset() {
100 | v.value = nil
101 | v.isSet = false
102 | }
103 |
104 | func NewNullableCreateDatabase(val *CreateDatabase) *NullableCreateDatabase {
105 | return &NullableCreateDatabase{value: val, isSet: true}
106 | }
107 |
108 | func (v NullableCreateDatabase) MarshalJSON() ([]byte, error) {
109 | return json.Marshal(v.value)
110 | }
111 |
112 | func (v *NullableCreateDatabase) UnmarshalJSON(src []byte) error {
113 | v.isSet = true
114 | return json.Unmarshal(src, &v.value)
115 | }
116 |
--------------------------------------------------------------------------------
/swagger/model_create_tenant.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | )
16 |
17 | // checks if the CreateTenant type satisfies the MappedNullable interface at compile time
18 | var _ MappedNullable = &CreateTenant{}
19 |
20 | // CreateTenant struct for CreateTenant
21 | type CreateTenant struct {
22 | Name string `json:"name"`
23 | }
24 |
25 | // NewCreateTenant instantiates a new CreateTenant object
26 | // This constructor will assign default values to properties that have it defined,
27 | // and makes sure properties required by API are set, but the set of arguments
28 | // will change when the set of required properties is changed
29 | func NewCreateTenant(name string) *CreateTenant {
30 | this := CreateTenant{}
31 | this.Name = name
32 | return &this
33 | }
34 |
35 | // NewCreateTenantWithDefaults instantiates a new CreateTenant object
36 | // This constructor will only assign default values to properties that have it defined,
37 | // but it doesn't guarantee that properties required by API are set
38 | func NewCreateTenantWithDefaults() *CreateTenant {
39 | this := CreateTenant{}
40 | return &this
41 | }
42 |
43 | // GetName returns the Name field value
44 | func (o *CreateTenant) GetName() string {
45 | if o == nil {
46 | var ret string
47 | return ret
48 | }
49 |
50 | return o.Name
51 | }
52 |
53 | // GetNameOk returns a tuple with the Name field value
54 | // and a boolean to check if the value has been set.
55 | func (o *CreateTenant) GetNameOk() (*string, bool) {
56 | if o == nil {
57 | return nil, false
58 | }
59 | return &o.Name, true
60 | }
61 |
62 | // SetName sets field value
63 | func (o *CreateTenant) SetName(v string) {
64 | o.Name = v
65 | }
66 |
67 | func (o CreateTenant) MarshalJSON() ([]byte, error) {
68 | toSerialize, err := o.ToMap()
69 | if err != nil {
70 | return []byte{}, err
71 | }
72 | return json.Marshal(toSerialize)
73 | }
74 |
75 | func (o CreateTenant) ToMap() (map[string]interface{}, error) {
76 | toSerialize := map[string]interface{}{}
77 | toSerialize["name"] = o.Name
78 | return toSerialize, nil
79 | }
80 |
81 | type NullableCreateTenant struct {
82 | value *CreateTenant
83 | isSet bool
84 | }
85 |
86 | func (v NullableCreateTenant) Get() *CreateTenant {
87 | return v.value
88 | }
89 |
90 | func (v *NullableCreateTenant) Set(val *CreateTenant) {
91 | v.value = val
92 | v.isSet = true
93 | }
94 |
95 | func (v NullableCreateTenant) IsSet() bool {
96 | return v.isSet
97 | }
98 |
99 | func (v *NullableCreateTenant) Unset() {
100 | v.value = nil
101 | v.isSet = false
102 | }
103 |
104 | func NewNullableCreateTenant(val *CreateTenant) *NullableCreateTenant {
105 | return &NullableCreateTenant{value: val, isSet: true}
106 | }
107 |
108 | func (v NullableCreateTenant) MarshalJSON() ([]byte, error) {
109 | return json.Marshal(v.value)
110 | }
111 |
112 | func (v *NullableCreateTenant) UnmarshalJSON(src []byte) error {
113 | v.isSet = true
114 | return json.Unmarshal(src, &v.value)
115 | }
116 |
--------------------------------------------------------------------------------
/swagger/model_embeddings_inner.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | "fmt"
16 | )
17 |
18 | // EmbeddingsInner struct for EmbeddingsInner
19 | type EmbeddingsInner struct {
20 | ArrayOfFloat32 *[]float32
21 | ArrayOfInt32 *[]int32
22 | }
23 |
24 | // Unmarshal JSON data into any of the pointers in the struct
25 | func (dst *EmbeddingsInner) UnmarshalJSON(data []byte) error {
26 | var err error
27 | // try to unmarshal JSON data into ArrayOfFloat32
28 | err = json.Unmarshal(data, &dst.ArrayOfFloat32)
29 | if err == nil {
30 | jsonArrayOfFloat32, _ := json.Marshal(dst.ArrayOfFloat32)
31 | if string(jsonArrayOfFloat32) == "{}" { // empty struct
32 | dst.ArrayOfFloat32 = nil
33 | } else {
34 | return nil // data stored in dst.ArrayOfFloat32, return on the first match
35 | }
36 | } else {
37 | dst.ArrayOfFloat32 = nil
38 | }
39 |
40 | // try to unmarshal JSON data into ArrayOfInt32
41 | err = json.Unmarshal(data, &dst.ArrayOfInt32)
42 | if err == nil {
43 | jsonArrayOfInt32, _ := json.Marshal(dst.ArrayOfInt32)
44 | if string(jsonArrayOfInt32) == "{}" { // empty struct
45 | dst.ArrayOfInt32 = nil
46 | } else {
47 | return nil // data stored in dst.ArrayOfInt32, return on the first match
48 | }
49 | } else {
50 | dst.ArrayOfInt32 = nil
51 | }
52 |
53 | return fmt.Errorf("data failed to match schemas in anyOf(EmbeddingsInner)")
54 | }
55 |
56 | // Marshal data from the first non-nil pointers in the struct to JSON
57 | func (src *EmbeddingsInner) MarshalJSON() ([]byte, error) {
58 | if src.ArrayOfFloat32 != nil {
59 | return json.Marshal(&src.ArrayOfFloat32)
60 | }
61 |
62 | if src.ArrayOfInt32 != nil {
63 | return json.Marshal(&src.ArrayOfInt32)
64 | }
65 |
66 | return nil, nil // no data in anyOf schemas
67 | }
68 |
69 | type NullableEmbeddingsInner struct {
70 | value *EmbeddingsInner
71 | isSet bool
72 | }
73 |
74 | func (v NullableEmbeddingsInner) Get() *EmbeddingsInner {
75 | return v.value
76 | }
77 |
78 | func (v *NullableEmbeddingsInner) Set(val *EmbeddingsInner) {
79 | v.value = val
80 | v.isSet = true
81 | }
82 |
83 | func (v NullableEmbeddingsInner) IsSet() bool {
84 | return v.isSet
85 | }
86 |
87 | func (v *NullableEmbeddingsInner) Unset() {
88 | v.value = nil
89 | v.isSet = false
90 | }
91 |
92 | func NewNullableEmbeddingsInner(val *EmbeddingsInner) *NullableEmbeddingsInner {
93 | return &NullableEmbeddingsInner{value: val, isSet: true}
94 | }
95 |
96 | func (v NullableEmbeddingsInner) MarshalJSON() ([]byte, error) {
97 | return json.Marshal(v.value)
98 | }
99 |
100 | func (v *NullableEmbeddingsInner) UnmarshalJSON(src []byte) error {
101 | v.isSet = true
102 | return json.Unmarshal(src, &v.value)
103 | }
104 |
--------------------------------------------------------------------------------
/swagger/model_http_validation_error.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | )
16 |
17 | // checks if the HTTPValidationError type satisfies the MappedNullable interface at compile time
18 | var _ MappedNullable = &HTTPValidationError{}
19 |
20 | // HTTPValidationError struct for HTTPValidationError
21 | type HTTPValidationError struct {
22 | Detail []ValidationError `json:"detail,omitempty"`
23 | }
24 |
25 | // NewHTTPValidationError instantiates a new HTTPValidationError object
26 | // This constructor will assign default values to properties that have it defined,
27 | // and makes sure properties required by API are set, but the set of arguments
28 | // will change when the set of required properties is changed
29 | func NewHTTPValidationError() *HTTPValidationError {
30 | this := HTTPValidationError{}
31 | return &this
32 | }
33 |
34 | // NewHTTPValidationErrorWithDefaults instantiates a new HTTPValidationError object
35 | // This constructor will only assign default values to properties that have it defined,
36 | // but it doesn't guarantee that properties required by API are set
37 | func NewHTTPValidationErrorWithDefaults() *HTTPValidationError {
38 | this := HTTPValidationError{}
39 | return &this
40 | }
41 |
42 | // GetDetail returns the Detail field value if set, zero value otherwise.
43 | func (o *HTTPValidationError) GetDetail() []ValidationError {
44 | if o == nil || IsNil(o.Detail) {
45 | var ret []ValidationError
46 | return ret
47 | }
48 | return o.Detail
49 | }
50 |
51 | // GetDetailOk returns a tuple with the Detail field value if set, nil otherwise
52 | // and a boolean to check if the value has been set.
53 | func (o *HTTPValidationError) GetDetailOk() ([]ValidationError, bool) {
54 | if o == nil || IsNil(o.Detail) {
55 | return nil, false
56 | }
57 | return o.Detail, true
58 | }
59 |
60 | // HasDetail returns a boolean if a field has been set.
61 | func (o *HTTPValidationError) HasDetail() bool {
62 | if o != nil && !IsNil(o.Detail) {
63 | return true
64 | }
65 |
66 | return false
67 | }
68 |
69 | // SetDetail gets a reference to the given []ValidationError and assigns it to the Detail field.
70 | func (o *HTTPValidationError) SetDetail(v []ValidationError) {
71 | o.Detail = v
72 | }
73 |
74 | func (o HTTPValidationError) MarshalJSON() ([]byte, error) {
75 | toSerialize, err := o.ToMap()
76 | if err != nil {
77 | return []byte{}, err
78 | }
79 | return json.Marshal(toSerialize)
80 | }
81 |
82 | func (o HTTPValidationError) ToMap() (map[string]interface{}, error) {
83 | toSerialize := map[string]interface{}{}
84 | if !IsNil(o.Detail) {
85 | toSerialize["detail"] = o.Detail
86 | }
87 | return toSerialize, nil
88 | }
89 |
90 | type NullableHTTPValidationError struct {
91 | value *HTTPValidationError
92 | isSet bool
93 | }
94 |
95 | func (v NullableHTTPValidationError) Get() *HTTPValidationError {
96 | return v.value
97 | }
98 |
99 | func (v *NullableHTTPValidationError) Set(val *HTTPValidationError) {
100 | v.value = val
101 | v.isSet = true
102 | }
103 |
104 | func (v NullableHTTPValidationError) IsSet() bool {
105 | return v.isSet
106 | }
107 |
108 | func (v *NullableHTTPValidationError) Unset() {
109 | v.value = nil
110 | v.isSet = false
111 | }
112 |
113 | func NewNullableHTTPValidationError(val *HTTPValidationError) *NullableHTTPValidationError {
114 | return &NullableHTTPValidationError{value: val, isSet: true}
115 | }
116 |
117 | func (v NullableHTTPValidationError) MarshalJSON() ([]byte, error) {
118 | return json.Marshal(v.value)
119 | }
120 |
121 | func (v *NullableHTTPValidationError) UnmarshalJSON(src []byte) error {
122 | v.isSet = true
123 | return json.Unmarshal(src, &v.value)
124 | }
125 |
--------------------------------------------------------------------------------
/swagger/model_include_inner.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | "fmt"
16 | )
17 |
18 | // IncludeInner struct for IncludeInner
19 | type IncludeInner struct {
20 | String *string
21 | }
22 |
23 | // Unmarshal JSON data into any of the pointers in the struct
24 | func (dst *IncludeInner) UnmarshalJSON(data []byte) error {
25 | var err error
26 | // try to unmarshal JSON data into String
27 | err = json.Unmarshal(data, &dst.String)
28 | if err == nil {
29 | jsonString, _ := json.Marshal(dst.String)
30 | if string(jsonString) == "{}" { // empty struct
31 | dst.String = nil
32 | } else {
33 | return nil // data stored in dst.String, return on the first match
34 | }
35 | } else {
36 | dst.String = nil
37 | }
38 |
39 | return fmt.Errorf("data failed to match schemas in anyOf(IncludeInner)")
40 | }
41 |
42 | // Marshal data from the first non-nil pointers in the struct to JSON
43 | func (src *IncludeInner) MarshalJSON() ([]byte, error) {
44 | if src.String != nil {
45 | return json.Marshal(&src.String)
46 | }
47 |
48 | return nil, nil // no data in anyOf schemas
49 | }
50 |
51 | type NullableIncludeInner struct {
52 | value *IncludeInner
53 | isSet bool
54 | }
55 |
56 | func (v NullableIncludeInner) Get() *IncludeInner {
57 | return v.value
58 | }
59 |
60 | func (v *NullableIncludeInner) Set(val *IncludeInner) {
61 | v.value = val
62 | v.isSet = true
63 | }
64 |
65 | func (v NullableIncludeInner) IsSet() bool {
66 | return v.isSet
67 | }
68 |
69 | func (v *NullableIncludeInner) Unset() {
70 | v.value = nil
71 | v.isSet = false
72 | }
73 |
74 | func NewNullableIncludeInner(val *IncludeInner) *NullableIncludeInner {
75 | return &NullableIncludeInner{value: val, isSet: true}
76 | }
77 |
78 | func (v NullableIncludeInner) MarshalJSON() ([]byte, error) {
79 | return json.Marshal(v.value)
80 | }
81 |
82 | func (v *NullableIncludeInner) UnmarshalJSON(src []byte) error {
83 | v.isSet = true
84 | return json.Unmarshal(src, &v.value)
85 | }
86 |
--------------------------------------------------------------------------------
/swagger/model_location_inner.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | "fmt"
16 | )
17 |
18 | // LocationInner struct for LocationInner
19 | type LocationInner struct {
20 | Int32 *int32
21 | String *string
22 | }
23 |
24 | // Unmarshal JSON data into any of the pointers in the struct
25 | func (dst *LocationInner) UnmarshalJSON(data []byte) error {
26 | var err error
27 | // try to unmarshal JSON data into Int32
28 | err = json.Unmarshal(data, &dst.Int32)
29 | if err == nil {
30 | jsonInt32, _ := json.Marshal(dst.Int32)
31 | if string(jsonInt32) == "{}" { // empty struct
32 | dst.Int32 = nil
33 | } else {
34 | return nil // data stored in dst.Int32, return on the first match
35 | }
36 | } else {
37 | dst.Int32 = nil
38 | }
39 |
40 | // try to unmarshal JSON data into String
41 | err = json.Unmarshal(data, &dst.String)
42 | if err == nil {
43 | jsonString, _ := json.Marshal(dst.String)
44 | if string(jsonString) == "{}" { // empty struct
45 | dst.String = nil
46 | } else {
47 | return nil // data stored in dst.String, return on the first match
48 | }
49 | } else {
50 | dst.String = nil
51 | }
52 |
53 | return fmt.Errorf("data failed to match schemas in anyOf(LocationInner)")
54 | }
55 |
56 | // Marshal data from the first non-nil pointers in the struct to JSON
57 | func (src *LocationInner) MarshalJSON() ([]byte, error) {
58 | if src.Int32 != nil {
59 | return json.Marshal(&src.Int32)
60 | }
61 |
62 | if src.String != nil {
63 | return json.Marshal(&src.String)
64 | }
65 |
66 | return nil, nil // no data in anyOf schemas
67 | }
68 |
69 | type NullableLocationInner struct {
70 | value *LocationInner
71 | isSet bool
72 | }
73 |
74 | func (v NullableLocationInner) Get() *LocationInner {
75 | return v.value
76 | }
77 |
78 | func (v *NullableLocationInner) Set(val *LocationInner) {
79 | v.value = val
80 | v.isSet = true
81 | }
82 |
83 | func (v NullableLocationInner) IsSet() bool {
84 | return v.isSet
85 | }
86 |
87 | func (v *NullableLocationInner) Unset() {
88 | v.value = nil
89 | v.isSet = false
90 | }
91 |
92 | func NewNullableLocationInner(val *LocationInner) *NullableLocationInner {
93 | return &NullableLocationInner{value: val, isSet: true}
94 | }
95 |
96 | func (v NullableLocationInner) MarshalJSON() ([]byte, error) {
97 | return json.Marshal(v.value)
98 | }
99 |
100 | func (v *NullableLocationInner) UnmarshalJSON(src []byte) error {
101 | v.isSet = true
102 | return json.Unmarshal(src, &v.value)
103 | }
104 |
--------------------------------------------------------------------------------
/swagger/model_metadata.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | "fmt"
16 | )
17 |
18 | // Metadata struct for Metadata
19 | type Metadata struct {
20 | Bool *bool
21 | Float32 *float32
22 | Int32 *int32
23 | String *string
24 | }
25 |
26 | // Unmarshal JSON data into any of the pointers in the struct
27 | func (dst *Metadata) UnmarshalJSON(data []byte) error {
28 | var err error
29 | // try to unmarshal JSON data into Bool
30 | err = json.Unmarshal(data, &dst.Bool)
31 | if err == nil {
32 | jsonBool, _ := json.Marshal(dst.Bool)
33 | if string(jsonBool) == "{}" { // empty struct
34 | dst.Bool = nil
35 | } else {
36 | return nil // data stored in dst.Bool, return on the first match
37 | }
38 | } else {
39 | dst.Bool = nil
40 | }
41 |
42 | // try to unmarshal JSON data into Int32
43 | err = json.Unmarshal(data, &dst.Int32)
44 | if err == nil {
45 | jsonInt32, _ := json.Marshal(dst.Int32)
46 | if string(jsonInt32) == "{}" { // empty struct
47 | dst.Int32 = nil
48 | } else {
49 | return nil // data stored in dst.Int32, return on the first match
50 | }
51 | } else {
52 | dst.Int32 = nil
53 | }
54 | // try to unmarshal JSON data into Float32
55 | err = json.Unmarshal(data, &dst.Float32)
56 | if err == nil {
57 | jsonFloat32, _ := json.Marshal(dst.Float32)
58 | if string(jsonFloat32) == "{}" { // empty struct
59 | dst.Float32 = nil
60 | } else {
61 | return nil // data stored in dst.Float32, return on the first match
62 | }
63 | } else {
64 | dst.Float32 = nil
65 | }
66 |
67 | // try to unmarshal JSON data into String
68 | err = json.Unmarshal(data, &dst.String)
69 | if err == nil {
70 | jsonString, _ := json.Marshal(dst.String)
71 | if string(jsonString) == "{}" { // empty struct
72 | dst.String = nil
73 | } else {
74 | return nil // data stored in dst.String, return on the first match
75 | }
76 | } else {
77 | dst.String = nil
78 | }
79 |
80 | return fmt.Errorf("data failed to match schemas in anyOf(Metadata)")
81 | }
82 |
83 | // Marshal data from the first non-nil pointers in the struct to JSON
84 | func (src *Metadata) MarshalJSON() ([]byte, error) {
85 | if src.Bool != nil {
86 | return json.Marshal(&src.Bool)
87 | }
88 |
89 | if src.Float32 != nil {
90 | return json.Marshal(&src.Float32)
91 | }
92 |
93 | if src.Int32 != nil {
94 | return json.Marshal(&src.Int32)
95 | }
96 |
97 | if src.String != nil {
98 | return json.Marshal(&src.String)
99 | }
100 |
101 | return nil, nil // no data in anyOf schemas
102 | }
103 |
104 | type NullableMetadata struct {
105 | value *Metadata
106 | isSet bool
107 | }
108 |
109 | func (v NullableMetadata) Get() *Metadata {
110 | return v.value
111 | }
112 |
113 | func (v *NullableMetadata) Set(val *Metadata) {
114 | v.value = val
115 | v.isSet = true
116 | }
117 |
118 | func (v NullableMetadata) IsSet() bool {
119 | return v.isSet
120 | }
121 |
122 | func (v *NullableMetadata) Unset() {
123 | v.value = nil
124 | v.isSet = false
125 | }
126 |
127 | func NewNullableMetadata(val *Metadata) *NullableMetadata {
128 | return &NullableMetadata{value: val, isSet: true}
129 | }
130 |
131 | func (v NullableMetadata) MarshalJSON() ([]byte, error) {
132 | return json.Marshal(v.value)
133 | }
134 |
135 | func (v *NullableMetadata) UnmarshalJSON(src []byte) error {
136 | v.isSet = true
137 | return json.Unmarshal(src, &v.value)
138 | }
139 |
--------------------------------------------------------------------------------
/swagger/model_tenant.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | )
16 |
17 | // checks if the Tenant type satisfies the MappedNullable interface at compile time
18 | var _ MappedNullable = &Tenant{}
19 |
20 | // Tenant struct for Tenant
21 | type Tenant struct {
22 | Name *string `json:"name,omitempty"`
23 | }
24 |
25 | // NewTenant instantiates a new Tenant object
26 | // This constructor will assign default values to properties that have it defined,
27 | // and makes sure properties required by API are set, but the set of arguments
28 | // will change when the set of required properties is changed
29 | func NewTenant() *Tenant {
30 | this := Tenant{}
31 | return &this
32 | }
33 |
34 | // NewTenantWithDefaults instantiates a new Tenant object
35 | // This constructor will only assign default values to properties that have it defined,
36 | // but it doesn't guarantee that properties required by API are set
37 | func NewTenantWithDefaults() *Tenant {
38 | this := Tenant{}
39 | return &this
40 | }
41 |
42 | // GetName returns the Name field value if set, zero value otherwise.
43 | func (o *Tenant) GetName() string {
44 | if o == nil || IsNil(o.Name) {
45 | var ret string
46 | return ret
47 | }
48 | return *o.Name
49 | }
50 |
51 | // GetNameOk returns a tuple with the Name field value if set, nil otherwise
52 | // and a boolean to check if the value has been set.
53 | func (o *Tenant) GetNameOk() (*string, bool) {
54 | if o == nil || IsNil(o.Name) {
55 | return nil, false
56 | }
57 | return o.Name, true
58 | }
59 |
60 | // HasName returns a boolean if a field has been set.
61 | func (o *Tenant) HasName() bool {
62 | if o != nil && !IsNil(o.Name) {
63 | return true
64 | }
65 |
66 | return false
67 | }
68 |
69 | // SetName gets a reference to the given string and assigns it to the Name field.
70 | func (o *Tenant) SetName(v string) {
71 | o.Name = &v
72 | }
73 |
74 | func (o Tenant) MarshalJSON() ([]byte, error) {
75 | toSerialize, err := o.ToMap()
76 | if err != nil {
77 | return []byte{}, err
78 | }
79 | return json.Marshal(toSerialize)
80 | }
81 |
82 | func (o Tenant) ToMap() (map[string]interface{}, error) {
83 | toSerialize := map[string]interface{}{}
84 | if !IsNil(o.Name) {
85 | toSerialize["name"] = o.Name
86 | }
87 | return toSerialize, nil
88 | }
89 |
90 | type NullableTenant struct {
91 | value *Tenant
92 | isSet bool
93 | }
94 |
95 | func (v NullableTenant) Get() *Tenant {
96 | return v.value
97 | }
98 |
99 | func (v *NullableTenant) Set(val *Tenant) {
100 | v.value = val
101 | v.isSet = true
102 | }
103 |
104 | func (v NullableTenant) IsSet() bool {
105 | return v.isSet
106 | }
107 |
108 | func (v *NullableTenant) Unset() {
109 | v.value = nil
110 | v.isSet = false
111 | }
112 |
113 | func NewNullableTenant(val *Tenant) *NullableTenant {
114 | return &NullableTenant{value: val, isSet: true}
115 | }
116 |
117 | func (v NullableTenant) MarshalJSON() ([]byte, error) {
118 | return json.Marshal(v.value)
119 | }
120 |
121 | func (v *NullableTenant) UnmarshalJSON(src []byte) error {
122 | v.isSet = true
123 | return json.Unmarshal(src, &v.value)
124 | }
125 |
--------------------------------------------------------------------------------
/swagger/model_update_collection.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | )
16 |
17 | // checks if the UpdateCollection type satisfies the MappedNullable interface at compile time
18 | var _ MappedNullable = &UpdateCollection{}
19 |
20 | // UpdateCollection struct for UpdateCollection
21 | type UpdateCollection struct {
22 | NewName *string `json:"new_name,omitempty"`
23 | NewMetadata map[string]interface{} `json:"new_metadata,omitempty"`
24 | }
25 |
26 | // NewUpdateCollection instantiates a new UpdateCollection object
27 | // This constructor will assign default values to properties that have it defined,
28 | // and makes sure properties required by API are set, but the set of arguments
29 | // will change when the set of required properties is changed
30 | func NewUpdateCollection() *UpdateCollection {
31 | this := UpdateCollection{}
32 | return &this
33 | }
34 |
35 | // NewUpdateCollectionWithDefaults instantiates a new UpdateCollection object
36 | // This constructor will only assign default values to properties that have it defined,
37 | // but it doesn't guarantee that properties required by API are set
38 | func NewUpdateCollectionWithDefaults() *UpdateCollection {
39 | this := UpdateCollection{}
40 | return &this
41 | }
42 |
43 | // GetNewName returns the NewName field value if set, zero value otherwise.
44 | func (o *UpdateCollection) GetNewName() string {
45 | if o == nil || IsNil(o.NewName) {
46 | var ret string
47 | return ret
48 | }
49 | return *o.NewName
50 | }
51 |
52 | // GetNewNameOk returns a tuple with the NewName field value if set, nil otherwise
53 | // and a boolean to check if the value has been set.
54 | func (o *UpdateCollection) GetNewNameOk() (*string, bool) {
55 | if o == nil || IsNil(o.NewName) {
56 | return nil, false
57 | }
58 | return o.NewName, true
59 | }
60 |
61 | // HasNewName returns a boolean if a field has been set.
62 | func (o *UpdateCollection) HasNewName() bool {
63 | if o != nil && !IsNil(o.NewName) {
64 | return true
65 | }
66 |
67 | return false
68 | }
69 |
70 | // SetNewName gets a reference to the given string and assigns it to the NewName field.
71 | func (o *UpdateCollection) SetNewName(v string) {
72 | o.NewName = &v
73 | }
74 |
75 | // GetNewMetadata returns the NewMetadata field value if set, zero value otherwise.
76 | func (o *UpdateCollection) GetNewMetadata() map[string]interface{} {
77 | if o == nil || IsNil(o.NewMetadata) {
78 | var ret map[string]interface{}
79 | return ret
80 | }
81 | return o.NewMetadata
82 | }
83 |
84 | // GetNewMetadataOk returns a tuple with the NewMetadata field value if set, nil otherwise
85 | // and a boolean to check if the value has been set.
86 | func (o *UpdateCollection) GetNewMetadataOk() (map[string]interface{}, bool) {
87 | if o == nil || IsNil(o.NewMetadata) {
88 | return map[string]interface{}{}, false
89 | }
90 | return o.NewMetadata, true
91 | }
92 |
93 | // HasNewMetadata returns a boolean if a field has been set.
94 | func (o *UpdateCollection) HasNewMetadata() bool {
95 | if o != nil && !IsNil(o.NewMetadata) {
96 | return true
97 | }
98 |
99 | return false
100 | }
101 |
102 | // SetNewMetadata gets a reference to the given map[string]interface{} and assigns it to the NewMetadata field.
103 | func (o *UpdateCollection) SetNewMetadata(v map[string]interface{}) {
104 | o.NewMetadata = v
105 | }
106 |
107 | func (o UpdateCollection) MarshalJSON() ([]byte, error) {
108 | toSerialize, err := o.ToMap()
109 | if err != nil {
110 | return []byte{}, err
111 | }
112 | return json.Marshal(toSerialize)
113 | }
114 |
115 | func (o UpdateCollection) ToMap() (map[string]interface{}, error) {
116 | toSerialize := map[string]interface{}{}
117 | if !IsNil(o.NewName) {
118 | toSerialize["new_name"] = o.NewName
119 | }
120 | if !IsNil(o.NewMetadata) {
121 | toSerialize["new_metadata"] = o.NewMetadata
122 | }
123 | return toSerialize, nil
124 | }
125 |
126 | type NullableUpdateCollection struct {
127 | value *UpdateCollection
128 | isSet bool
129 | }
130 |
131 | func (v NullableUpdateCollection) Get() *UpdateCollection {
132 | return v.value
133 | }
134 |
135 | func (v *NullableUpdateCollection) Set(val *UpdateCollection) {
136 | v.value = val
137 | v.isSet = true
138 | }
139 |
140 | func (v NullableUpdateCollection) IsSet() bool {
141 | return v.isSet
142 | }
143 |
144 | func (v *NullableUpdateCollection) Unset() {
145 | v.value = nil
146 | v.isSet = false
147 | }
148 |
149 | func NewNullableUpdateCollection(val *UpdateCollection) *NullableUpdateCollection {
150 | return &NullableUpdateCollection{value: val, isSet: true}
151 | }
152 |
153 | func (v NullableUpdateCollection) MarshalJSON() ([]byte, error) {
154 | return json.Marshal(v.value)
155 | }
156 |
157 | func (v *NullableUpdateCollection) UnmarshalJSON(src []byte) error {
158 | v.isSet = true
159 | return json.Unmarshal(src, &v.value)
160 | }
161 |
--------------------------------------------------------------------------------
/swagger/model_validation_error.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "encoding/json"
15 | )
16 |
17 | // checks if the ValidationError type satisfies the MappedNullable interface at compile time
18 | var _ MappedNullable = &ValidationError{}
19 |
20 | // ValidationError struct for ValidationError
21 | type ValidationError struct {
22 | Loc []LocationInner `json:"loc"`
23 | Msg string `json:"msg"`
24 | Type string `json:"type"`
25 | }
26 |
27 | // NewValidationError instantiates a new ValidationError object
28 | // This constructor will assign default values to properties that have it defined,
29 | // and makes sure properties required by API are set, but the set of arguments
30 | // will change when the set of required properties is changed
31 | func NewValidationError(loc []LocationInner, msg string, type_ string) *ValidationError {
32 | this := ValidationError{}
33 | this.Loc = loc
34 | this.Msg = msg
35 | this.Type = type_
36 | return &this
37 | }
38 |
39 | // NewValidationErrorWithDefaults instantiates a new ValidationError object
40 | // This constructor will only assign default values to properties that have it defined,
41 | // but it doesn't guarantee that properties required by API are set
42 | func NewValidationErrorWithDefaults() *ValidationError {
43 | this := ValidationError{}
44 | return &this
45 | }
46 |
47 | // GetLoc returns the Loc field value
48 | func (o *ValidationError) GetLoc() []LocationInner {
49 | if o == nil {
50 | var ret []LocationInner
51 | return ret
52 | }
53 |
54 | return o.Loc
55 | }
56 |
57 | // GetLocOk returns a tuple with the Loc field value
58 | // and a boolean to check if the value has been set.
59 | func (o *ValidationError) GetLocOk() ([]LocationInner, bool) {
60 | if o == nil {
61 | return nil, false
62 | }
63 | return o.Loc, true
64 | }
65 |
66 | // SetLoc sets field value
67 | func (o *ValidationError) SetLoc(v []LocationInner) {
68 | o.Loc = v
69 | }
70 |
71 | // GetMsg returns the Msg field value
72 | func (o *ValidationError) GetMsg() string {
73 | if o == nil {
74 | var ret string
75 | return ret
76 | }
77 |
78 | return o.Msg
79 | }
80 |
81 | // GetMsgOk returns a tuple with the Msg field value
82 | // and a boolean to check if the value has been set.
83 | func (o *ValidationError) GetMsgOk() (*string, bool) {
84 | if o == nil {
85 | return nil, false
86 | }
87 | return &o.Msg, true
88 | }
89 |
90 | // SetMsg sets field value
91 | func (o *ValidationError) SetMsg(v string) {
92 | o.Msg = v
93 | }
94 |
95 | // GetType returns the Type field value
96 | func (o *ValidationError) GetType() string {
97 | if o == nil {
98 | var ret string
99 | return ret
100 | }
101 |
102 | return o.Type
103 | }
104 |
105 | // GetTypeOk returns a tuple with the Type field value
106 | // and a boolean to check if the value has been set.
107 | func (o *ValidationError) GetTypeOk() (*string, bool) {
108 | if o == nil {
109 | return nil, false
110 | }
111 | return &o.Type, true
112 | }
113 |
114 | // SetType sets field value
115 | func (o *ValidationError) SetType(v string) {
116 | o.Type = v
117 | }
118 |
119 | func (o ValidationError) MarshalJSON() ([]byte, error) {
120 | toSerialize, err := o.ToMap()
121 | if err != nil {
122 | return []byte{}, err
123 | }
124 | return json.Marshal(toSerialize)
125 | }
126 |
127 | func (o ValidationError) ToMap() (map[string]interface{}, error) {
128 | toSerialize := map[string]interface{}{}
129 | toSerialize["loc"] = o.Loc
130 | toSerialize["msg"] = o.Msg
131 | toSerialize["type"] = o.Type
132 | return toSerialize, nil
133 | }
134 |
135 | type NullableValidationError struct {
136 | value *ValidationError
137 | isSet bool
138 | }
139 |
140 | func (v NullableValidationError) Get() *ValidationError {
141 | return v.value
142 | }
143 |
144 | func (v *NullableValidationError) Set(val *ValidationError) {
145 | v.value = val
146 | v.isSet = true
147 | }
148 |
149 | func (v NullableValidationError) IsSet() bool {
150 | return v.isSet
151 | }
152 |
153 | func (v *NullableValidationError) Unset() {
154 | v.value = nil
155 | v.isSet = false
156 | }
157 |
158 | func NewNullableValidationError(val *ValidationError) *NullableValidationError {
159 | return &NullableValidationError{value: val, isSet: true}
160 | }
161 |
162 | func (v NullableValidationError) MarshalJSON() ([]byte, error) {
163 | return json.Marshal(v.value)
164 | }
165 |
166 | func (v *NullableValidationError) UnmarshalJSON(src []byte) error {
167 | v.isSet = true
168 | return json.Unmarshal(src, &v.value)
169 | }
170 |
--------------------------------------------------------------------------------
/swagger/response.go:
--------------------------------------------------------------------------------
1 | /*
2 | ChromaDB API
3 |
4 | This is OpenAPI schema for ChromaDB API.
5 |
6 | API version: 1.0.0
7 | */
8 |
9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT.
10 |
11 | package openapi
12 |
13 | import (
14 | "net/http"
15 | )
16 |
// APIResponse records the outcome of an API call. The embedded *http.Response
// can be nil (NewAPIResponseWithError never sets it), which is why the request
// details are also kept in plain fields.
type APIResponse struct {
	*http.Response `json:"-"`
	// Message holds an error message, as set by NewAPIResponseWithError.
	Message string `json:"message,omitempty"`
	// Operation is the name of the OpenAPI operation that was invoked.
	Operation string `json:"operation,omitempty"`
	// RequestURL is the requested URL. It is available even when the embedded
	// *http.Response is nil.
	RequestURL string `json:"url,omitempty"`
	// Method is the HTTP method of the request. It is available even when the
	// embedded *http.Response is nil.
	Method string `json:"method,omitempty"`
	// Payload is the raw response body (possibly nil or empty). It is kept here
	// because the response body reader has already been drained.
	Payload []byte `json:"-"`
}

// NewAPIResponse wraps r in a freshly allocated APIResponse; every other field
// is left at its zero value.
func NewAPIResponse(r *http.Response) *APIResponse {
	return &APIResponse{Response: r}
}

// NewAPIResponseWithError returns an APIResponse carrying only errorMessage;
// the embedded *http.Response is nil.
func NewAPIResponseWithError(errorMessage string) *APIResponse {
	resp := new(APIResponse)
	resp.Message = errorMessage
	return resp
}
48 |
--------------------------------------------------------------------------------
/types/record_test.go:
--------------------------------------------------------------------------------
1 | //go:build basic
2 |
3 | package types
4 |
5 | import (
6 | "context"
7 | "testing"
8 |
9 | "github.com/stretchr/testify/require"
10 | )
11 |
12 | func TestNewRecordSet(t *testing.T) {
13 | t.Run("Test NewRecordSet", func(t *testing.T) {
14 | recordSet, err := NewRecordSet()
15 | require.NoError(t, err)
16 | require.NotNil(t, recordSet)
17 | require.NotNil(t, recordSet.Records)
18 | require.Equal(t, len(recordSet.Records), 0)
19 | })
20 | t.Run("Test NewRecordSet with options", func(t *testing.T) {
21 | recordSet, err := NewRecordSet(WithIDGenerator(NewULIDGenerator()), WithEmbeddingFunction(NewConsistentHashEmbeddingFunction()))
22 | require.NoError(t, err)
23 | require.NotNil(t, recordSet)
24 | require.NotNil(t, recordSet.IDGenerator)
25 | require.NotNil(t, recordSet.EmbeddingFunction)
26 | require.NotNil(t, recordSet.Records)
27 | require.Equal(t, len(recordSet.Records), 0)
28 | })
29 |
30 | t.Run("Test NewRecordSet with IDGenerator", func(t *testing.T) {
31 | recordSet, err := NewRecordSet(WithIDGenerator(NewULIDGenerator()))
32 | recordSet.WithRecord(WithDocument("test document"))
33 | require.NoError(t, err)
34 | require.NotNil(t, recordSet)
35 | require.NotNil(t, recordSet.IDGenerator)
36 | require.Nil(t, recordSet.EmbeddingFunction)
37 | require.NotNil(t, recordSet.Records)
38 | require.Equal(t, len(recordSet.Records), 1)
39 | require.NotNil(t, recordSet.Records[0].ID)
40 | })
41 |
42 | t.Run("Test NewRecordSet with EmbeddingFunction", func(t *testing.T) {
43 | recordSet, err := NewRecordSet(WithEmbeddingFunction(NewConsistentHashEmbeddingFunction()))
44 | require.NoError(t, err)
45 | recordSet.WithRecord(WithDocument("test document"), WithID("1"))
46 | _, err = recordSet.BuildAndValidate(context.TODO())
47 | require.NoError(t, err)
48 | require.NotNil(t, recordSet)
49 | require.Nil(t, recordSet.IDGenerator)
50 | require.NotNil(t, recordSet.EmbeddingFunction)
51 | require.NotNil(t, recordSet.Records)
52 | require.Equal(t, len(recordSet.Records), 1)
53 | require.NotNil(t, recordSet.Records[0].ID)
54 | require.Equal(t, recordSet.Records[0].ID, "1")
55 | require.NotNil(t, recordSet.Records[0].Document)
56 | require.Equal(t, recordSet.Records[0].Document, "test document")
57 | require.NotNil(t, recordSet.Records[0].Embedding.GetFloat32())
58 | })
59 |
60 | t.Run("Test NewRecordSet with complete Record", func(t *testing.T) {
61 | recordSet, err := NewRecordSet(WithEmbeddingFunction(NewConsistentHashEmbeddingFunction()))
62 | require.NoError(t, err)
63 | var embeddings = []float32{0.1, 0.2, 0.3}
64 | recordSet.WithRecord(
65 | WithID("1"),
66 | WithDocument("test document"),
67 | WithEmbedding(*NewEmbeddingFromFloat32(embeddings)),
68 | WithMetadata("testKey", 1),
69 | )
70 | _, err = recordSet.BuildAndValidate(context.TODO())
71 | require.NoError(t, err)
72 | require.NotNil(t, recordSet)
73 | require.Nil(t, recordSet.IDGenerator)
74 | require.NotNil(t, recordSet.EmbeddingFunction)
75 | require.NotNil(t, recordSet.Records)
76 | require.Equal(t, len(recordSet.Records), 1)
77 | require.NotNil(t, recordSet.Records[0].ID)
78 | require.Equal(t, recordSet.Records[0].ID, "1")
79 | require.NotNil(t, recordSet.Records[0].Embedding.GetFloat32())
80 | require.Equal(t, recordSet.Records[0].Embedding.GetFloat32(), &embeddings)
81 | require.Equal(t, recordSet.Records[0].Document, "test document")
82 | require.Equal(t, recordSet.Records[0].Metadata["testKey"], 1)
83 | })
84 | }
85 |
--------------------------------------------------------------------------------
/where_document/wheredoc.go:
--------------------------------------------------------------------------------
1 | package wheredoc
2 |
3 | import (
4 | "fmt"
5 | )
6 |
// InvalidWhereDocumentValueError reports a where-document operand of an
// unsupported type. Only string operands are accepted.
type InvalidWhereDocumentValueError struct {
	// Value is the operand that was rejected.
	Value interface{}
}

// Error implements the error interface.
//
// NOTE: the capitalized message predates Go error-string conventions and is
// kept verbatim for compatibility with existing callers.
func (e *InvalidWhereDocumentValueError) Error() string {
	return fmt.Sprintf("Invalid value for where document clause for value %v. Allowed values are string", e.Value)
}

// Builder accumulates a where-document expression in WhereClause, keyed by
// operator ("$contains", "$not_contains", "$and", "$or").
//
// Errors are sticky: once a method records an error, later calls are no-ops and
// Build returns that error. The zero value is usable; WhereClause is allocated
// on first write.
type Builder struct {
	WhereClause map[string]interface{}
	err         error
}

// NewWhereDocumentBuilder returns an empty Builder.
func NewWhereDocumentBuilder() *Builder {
	return &Builder{WhereClause: make(map[string]interface{})}
}

// setClause stores value under key. It allocates WhereClause first so that a
// zero-value Builder does not panic on a nil-map write.
func (w *Builder) setClause(key string, value interface{}) {
	if w.WhereClause == nil {
		w.WhereClause = make(map[string]interface{})
	}
	w.WhereClause[key] = value
}

// operation records a single-operand operator. A non-string value records an
// *InvalidWhereDocumentValueError instead.
func (w *Builder) operation(operation string, value interface{}) *Builder {
	if w.err != nil {
		return w
	}
	if _, ok := value.(string); !ok {
		w.err = &InvalidWhereDocumentValueError{Value: value}
		return w
	}
	w.setClause(operation, value)
	return w
}

// Contains adds a "$contains" clause with the given string operand.
func (w *Builder) Contains(value interface{}) *Builder {
	return w.operation("$contains", value)
}

// NotContains adds a "$not_contains" clause with the given string operand.
func (w *Builder) NotContains(value interface{}) *Builder {
	return w.operation("$not_contains", value)
}

// combine builds each sub-builder and stores the resulting expressions, in
// order, under the logical operator op. The first sub-builder error, or a nil
// sub-builder, is recorded on w instead. With no sub-builders, a nil slice is
// stored.
func (w *Builder) combine(op string, builders []*Builder) *Builder {
	if w.err != nil {
		return w
	}
	var clauses []map[string]interface{}
	for i, b := range builders {
		if b == nil {
			w.err = fmt.Errorf("where document %s: sub-builder %d is nil", op, i)
			return w
		}
		expr, err := b.Build()
		if err != nil {
			w.err = err
			return w
		}
		clauses = append(clauses, expr)
	}
	w.setClause(op, clauses)
	return w
}

// And stores the built sub-expressions under "$and".
func (w *Builder) And(builders ...*Builder) *Builder {
	return w.combine("$and", builders)
}

// Or stores the built sub-expressions under "$or".
func (w *Builder) Or(builders ...*Builder) *Builder {
	return w.combine("$or", builders)
}

// Build returns the accumulated clause, or the first error recorded by an
// earlier call. The returned map is the Builder's own WhereClause, not a copy.
func (w *Builder) Build() (map[string]interface{}, error) {
	if w.err != nil {
		return nil, w.err
	}
	return w.WhereClause, nil
}
89 |
90 | type WhereDocumentOperation func(builder *Builder) error
91 |
92 | func Contains(value interface{}) WhereDocumentOperation {
93 | return func(w *Builder) error {
94 | w.Contains(value)
95 | return nil
96 | }
97 | }
98 |
99 | func NotContains(value interface{}) WhereDocumentOperation {
100 | return func(w *Builder) error {
101 | w.NotContains(value)
102 | return nil
103 | }
104 | }
105 |
106 | func And(ops ...WhereDocumentOperation) WhereDocumentOperation {
107 | return func(w *Builder) error {
108 | subBuilders := make([]*Builder, 0, len(ops))
109 | for _, op := range ops {
110 | wdx := NewWhereDocumentBuilder()
111 | if err := op(wdx); err != nil {
112 | return err
113 | }
114 | subBuilders = append(subBuilders, wdx)
115 | }
116 | w.And(subBuilders...)
117 | return nil
118 | }
119 | }
120 |
121 | func Or(ops ...WhereDocumentOperation) WhereDocumentOperation {
122 | return func(w *Builder) error {
123 | subBuilders := make([]*Builder, 0, len(ops))
124 | for _, op := range ops {
125 | wdx := NewWhereDocumentBuilder()
126 | if err := op(wdx); err != nil {
127 | return err
128 | }
129 | subBuilders = append(subBuilders, wdx)
130 | }
131 | w.Or(subBuilders...)
132 | return nil
133 | }
134 | }
135 |
136 | func WhereDocument(operation WhereDocumentOperation) (map[string]interface{}, error) {
137 | w := NewWhereDocumentBuilder()
138 | if err := operation(w); err != nil {
139 | return nil, err
140 | }
141 | return w.Build()
142 | }
143 |
--------------------------------------------------------------------------------