├── .github └── workflows │ ├── go.yml │ ├── lint.yaml │ ├── mkdocs.yml │ └── nightly.yml ├── .gitignore ├── .golangci.yml ├── CNAME ├── CODEOWNERS ├── LICENSE.md ├── Makefile ├── README.md ├── chroma.go ├── collection ├── collection.go └── collection_test.go ├── docs ├── CNAME ├── docs │ ├── assets │ │ └── images │ │ │ ├── favicon.png │ │ │ └── logo.png │ ├── auth.md │ ├── client.md │ ├── embeddings.md │ ├── filtering.md │ ├── index.md │ ├── javascripts │ │ └── gtag.js │ ├── records.md │ └── rerankers.md └── mkdocs.yml ├── examples └── v2 │ ├── basic │ ├── README.md │ ├── go.mod │ ├── go.sum │ └── main.go │ ├── custom_embedding_function │ ├── README.md │ ├── go.mod │ └── go.sum │ ├── embedding_function_basic │ ├── README.md │ ├── go.mod │ ├── go.sum │ └── main.go │ ├── reranking_function_basic │ ├── README.md │ ├── go.mod │ └── go.sum │ └── tenant_and_db │ ├── README.md │ ├── go.mod │ ├── go.sum │ └── main.go ├── gen_api.sh ├── gen_api_v3.sh ├── go.mod ├── go.sum ├── internal └── http │ ├── constants.go │ ├── errors.go │ ├── retry.go │ ├── retry_test.go │ ├── strategy.go │ └── utils.go ├── metadata ├── metadata.go └── metadata_test.go ├── openapi.yaml ├── patches └── model_anyof.mustache ├── pkg ├── api │ └── v2 │ │ ├── auth.go │ │ ├── base.go │ │ ├── client.go │ │ ├── client_http.go │ │ ├── client_http_integration_test.go │ │ ├── client_http_test.go │ │ ├── collection.go │ │ ├── collection_http.go │ │ ├── collection_http_integration_test.go │ │ ├── collection_http_test.go │ │ ├── constants.go │ │ ├── document.go │ │ ├── document_test.go │ │ ├── ids.go │ │ ├── metadata.go │ │ ├── metadata_test.go │ │ ├── record.go │ │ ├── record_test.go │ │ ├── reranking.go │ │ ├── results.go │ │ ├── results_test.go │ │ ├── server.htpasswd │ │ ├── utils.go │ │ ├── v1-config.yaml │ │ ├── where.go │ │ ├── where_document.go │ │ ├── where_document_test.go │ │ └── where_test.go ├── commons │ ├── cohere │ │ ├── cohere_commons.go │ │ └── cohere_commons_test.go │ └── http │ │ ├── 
constants.go │ │ ├── errors.go │ │ ├── retry.go │ │ ├── retry_test.go │ │ ├── strategy.go │ │ └── utils.go ├── embeddings │ ├── cloudflare │ │ ├── cloudflare.go │ │ ├── cloudflare_test.go │ │ └── option.go │ ├── cohere │ │ ├── cohere.go │ │ ├── cohere_test.go │ │ └── option.go │ ├── default_ef │ │ ├── constants.go │ │ ├── default_ef.go │ │ ├── default_ef_test.go │ │ ├── download_utils.go │ │ ├── download_utils_test.go │ │ └── tensors_utils.go │ ├── distance_metric.go │ ├── embedding.go │ ├── embedding_test.go │ ├── gemini │ │ ├── gemini.go │ │ ├── gemini_test.go │ │ └── option.go │ ├── hf │ │ ├── hf.go │ │ ├── hf_test.go │ │ └── option.go │ ├── jina │ │ ├── jina.go │ │ ├── jina_test.go │ │ └── option.go │ ├── mistral │ │ ├── mistral.go │ │ ├── mistral_test.go │ │ └── option.go │ ├── nomic │ │ ├── nomic.go │ │ ├── nomic_test.go │ │ └── option.go │ ├── ollama │ │ ├── ollama.go │ │ ├── ollama_test.go │ │ └── option.go │ ├── openai │ │ ├── openai.go │ │ ├── openai_test.go │ │ └── options.go │ ├── together │ │ ├── option.go │ │ ├── together.go │ │ └── together_test.go │ └── voyage │ │ ├── option.go │ │ ├── voyage.go │ │ └── voyage_test.go ├── rerankings │ ├── cohere │ │ ├── cohere.go │ │ ├── cohere_test.go │ │ └── option.go │ ├── hf │ │ ├── huggingface.go │ │ ├── huggingface_test.go │ │ └── option.go │ ├── jina │ │ ├── jina.go │ │ ├── jina_test.go │ │ └── option.go │ ├── reranking.go │ └── reranking_test.go └── tokenizers │ └── libtokenizers │ ├── tokenizer.go │ └── tokenizers.h ├── scripts └── chroma_server.sh ├── swagger ├── api_default.go ├── client.go ├── configuration.go ├── model_add_embedding.go ├── model_collection.go ├── model_create_collection.go ├── model_create_database.go ├── model_create_tenant.go ├── model_database.go ├── model_delete_embedding.go ├── model_embeddings_inner.go ├── model_get_embedding.go ├── model_get_result.go ├── model_http_validation_error.go ├── model_include_inner.go ├── model_location_inner.go ├── model_metadata.go ├── 
model_query_embedding.go ├── model_query_result.go ├── model_tenant.go ├── model_update_collection.go ├── model_update_embedding.go ├── model_validation_error.go ├── response.go └── utils.go ├── test ├── chroma_api_test.go ├── chroma_client_test.go └── test_utils.go ├── types ├── record.go ├── record_test.go ├── types.go └── types_test.go ├── where ├── where.go └── where_test.go └── where_document ├── wheredoc.go └── wheredoc_test.go /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Go Lint 2 | 3 | on: 4 | pull_request: {} 5 | 6 | jobs: 7 | lint: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | # Required: allow read access to the content for analysis. 11 | contents: read 12 | # Optional: allow read access to pull request. Use with `only-new-issues` option. 13 | pull-requests: read 14 | # Optional: Allow write access to checks to allow the action to annotate code in the PR. 15 | checks: write 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | with: 20 | fetch-depth: 0 21 | - name: Set up Go 22 | uses: actions/setup-go@v4 23 | with: 24 | go-version-file: 'go.mod' 25 | - name: Run golangci-lint 26 | uses: golangci/golangci-lint-action@v8 27 | with: 28 | version: v2.1 -------------------------------------------------------------------------------- /.github/workflows/mkdocs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy MkDocs Site 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.x 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install mkdocs-material 24 | 25 | - name: Build the MkDocs site 26 | run: | 27 | cd docs 28 | mkdocs build -d ../site 29 | cp ../CNAME ../site 30 | 
31 | - name: Deploy to GitHub Pages 32 | uses: peaceiris/actions-gh-pages@v3 33 | with: 34 | github_token: ${{ secrets.GITHUB_TOKEN }} 35 | publish_dir: ./site -------------------------------------------------------------------------------- /.github/workflows/nightly.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Nightly Test 5 | 6 | on: 7 | schedule: 8 | - cron: '0 0 * * *' # Run nightly at 00:00 UTC 9 | workflow_dispatch: 10 | 11 | 12 | jobs: 13 | build: 14 | name: Test API V2 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Log in to GitHub Container Registry 19 | uses: docker/login-action@v3 20 | with: 21 | registry: ghcr.io 22 | username: ${{ github.actor }} 23 | password: ${{ secrets.GITHUB_TOKEN }} 24 | - name: Set up Go 25 | uses: actions/setup-go@v4 26 | with: 27 | go-version-file: 'go.mod' 28 | - name: Lint 29 | uses: golangci/golangci-lint-action@v8 30 | with: 31 | version: v2.1 32 | - name: Build 33 | run: make build 34 | - name: Test 35 | run: make test-v2 36 | env: 37 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 38 | COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} 39 | HF_API_KEY: ${{ secrets.HF_API_KEY }} 40 | CF_API_TOKEN: ${{ secrets.CF_API_TOKEN }} 41 | CF_ACCOUNT_ID: ${{ secrets.CF_ACCOUNT_ID }} 42 | CF_GATEWAY_ENDPOINT: ${{ secrets.CF_GATEWAY_ENDPOINT }} 43 | TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }} 44 | VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} 45 | GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} 46 | MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} 47 | NOMIC_API_KEY: ${{ secrets.NOMIC_API_KEY }} 48 | CHROMA_VERSION: "latest" 49 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | .env 26 | .idea/* 27 | *.iml 28 | /data 29 | /pkg/rerankings/hf/data 30 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | run: 3 | modules-download-mode: readonly 4 | linters: 5 | enable: 6 | - dupword 7 | - ginkgolinter 8 | - gocritic 9 | - mirror 10 | settings: 11 | gocritic: 12 | disable-all: false 13 | staticcheck: 14 | checks: 15 | - all 16 | - ST1000 17 | - ST1001 18 | - ST1003 19 | dot-import-whitelist: 20 | - fmt 21 | exclusions: 22 | generated: lax 23 | presets: 24 | - comments 25 | - common-false-positives 26 | - legacy 27 | - std-error-handling 28 | rules: 29 | - linters: 30 | - ineffassign 31 | path: conversion\.go 32 | - linters: 33 | - staticcheck 34 | text: 'ST1003: should not use underscores in Go names; func (Convert_.*_To_.*|SetDefaults_)' 35 | - linters: 36 | - ginkgolinter 37 | text: use a function call in (Eventually|Consistently) 38 | paths: 39 | - third_party$ 40 | - builtin$ 41 | - examples$ 42 | - ./swagger 43 | issues: 44 | max-same-issues: 0 45 | formatters: 46 | enable: 47 | - gci 48 | settings: 49 | gci: 50 | sections: 51 | - standard 52 | - default 53 | - prefix(github.com/amikos-tech/chroma-go) 54 | - blank 55 | - dot 56 | custom-order: true 57 | exclusions: 58 | generated: lax 59 | paths: 60 | - third_party$ 61 | - builtin$ 62 | - examples$ 63 | -------------------------------------------------------------------------------- /CNAME: 
-------------------------------------------------------------------------------- 1 | go-client.chromadb.dev -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @tazarov -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2023 Amikos Tech Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | generate: 2 | echo "This is deprecated. 0.2.x or later does not use generated client." 3 | sh ./gen_api_v3.sh 4 | 5 | build: 6 | go build -v ./... 
7 | 8 | .PHONY: gotestsum-bin 9 | gotestsum-bin: 10 | go install gotest.tools/gotestsum@latest 11 | 12 | .PHONY: test 13 | test: gotestsum-bin 14 | gotestsum \ 15 | --format short-verbose \ 16 | --rerun-fails=5 \ 17 | --packages="./..." \ 18 | --junitfile unit-v1.xml \ 19 | -- \ 20 | -v \ 21 | -tags=basic \ 22 | -coverprofile=coverage-v1.out \ 23 | -timeout=30m 24 | 25 | .PHONY: test-v2 26 | test-v2: gotestsum-bin 27 | gotestsum \ 28 | --format short-verbose \ 29 | --rerun-fails=5 \ 30 | --packages="./..." \ 31 | --junitfile unit-v2.xml \ 32 | -- \ 33 | -v \ 34 | -tags=basicv2 \ 35 | -coverprofile=coverage-v2.out \ 36 | -timeout=30m 37 | 38 | .PHONY: test-rf 39 | test-rf: gotestsum-bin 40 | gotestsum \ 41 | --format short-verbose \ 42 | --rerun-fails=5 \ 43 | --packages="./..." \ 44 | --junitfile unit-rf.xml \ 45 | -- \ 46 | -v \ 47 | -tags=rf \ 48 | -coverprofile=coverage-rf.out \ 49 | -timeout=30m 50 | 51 | .PHONY: test-ef 52 | test-ef: gotestsum-bin 53 | gotestsum \ 54 | --format short-verbose \ 55 | --rerun-fails=5 \ 56 | --packages="./..." \ 57 | --junitfile unit-ef.xml \ 58 | -- \ 59 | -v \ 60 | -tags=ef \ 61 | -coverprofile=coverage-ef.out \ 62 | -timeout=30m 63 | 64 | .PHONY: lint 65 | lint: 66 | golangci-lint run 67 | 68 | .PHONY: lint-fix 69 | lint-fix: 70 | golangci-lint run --fix ./... 
71 | 72 | .PHONY: clean-lint-cache 73 | clean-lint-cache: 74 | golangci-lint cache clean 75 | 76 | 77 | .PHONY: server 78 | server: 79 | sh ./scripts/chroma_server.sh 80 | -------------------------------------------------------------------------------- /collection/collection.go: -------------------------------------------------------------------------------- 1 | package collection 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/amikos-tech/chroma-go/metadata" 7 | "github.com/amikos-tech/chroma-go/types" 8 | ) 9 | 10 | type Builder struct { 11 | Tenant string 12 | Database string 13 | Name string 14 | Metadata map[string]interface{} 15 | CreateIfNotExist bool 16 | EmbeddingFunction types.EmbeddingFunction 17 | IDGenerator types.IDGenerator 18 | } 19 | 20 | type Option func(*Builder) error 21 | 22 | func WithEmbeddingFunction(embeddingFunction types.EmbeddingFunction) Option { 23 | return func(c *Builder) error { 24 | c.EmbeddingFunction = embeddingFunction 25 | return nil 26 | } 27 | } 28 | 29 | func WithIDGenerator(idGenerator types.IDGenerator) Option { 30 | return func(c *Builder) error { 31 | c.IDGenerator = idGenerator 32 | return nil 33 | } 34 | } 35 | 36 | func WithCreateIfNotExist(create bool) Option { 37 | return func(c *Builder) error { 38 | c.CreateIfNotExist = create 39 | return nil 40 | } 41 | } 42 | 43 | func WithHNSWDistanceFunction(distanceFunction types.DistanceFunction) Option { 44 | return func(b *Builder) error { 45 | if distanceFunction != types.L2 && distanceFunction != types.IP && distanceFunction != types.COSINE { 46 | return fmt.Errorf("invalid distance function, must be one of l2, ip, or cosine") 47 | } 48 | return WithMetadata(types.HNSWSpace, distanceFunction)(b) 49 | } 50 | } 51 | 52 | func WithHNSWBatchSize(batchSize int32) Option { 53 | return func(b *Builder) error { 54 | if batchSize < 1 { 55 | return fmt.Errorf("batch size must be greater than 0") 56 | } 57 | return WithMetadata(types.HNSWBatchSize, batchSize)(b) 58 | } 59 | } 60 | 
61 | func WithHNSWSyncThreshold(syncThreshold int32) Option { 62 | return func(b *Builder) error { 63 | if syncThreshold < 1 { 64 | return fmt.Errorf("sync threshold must be greater than 0") 65 | } 66 | return WithMetadata(types.HNSWSyncThreshold, syncThreshold)(b) 67 | } 68 | } 69 | 70 | func WithHNSWM(m int32) Option { 71 | return func(b *Builder) error { 72 | if m < 1 { 73 | return fmt.Errorf("m must be greater than 0") 74 | } 75 | return WithMetadata(types.HNSWM, m)(b) 76 | } 77 | } 78 | 79 | func WithHNSWConstructionEf(efConstruction int32) Option { 80 | return func(b *Builder) error { 81 | if efConstruction < 1 { 82 | return fmt.Errorf("efConstruction must be greater than 0") 83 | } 84 | return WithMetadata(types.HNSWConstructionEF, efConstruction)(b) 85 | } 86 | } 87 | 88 | // WithMetadatas adds metadata to the collection. If the metadata key already exists, the value is overwritten. 89 | func WithMetadatas(metadata map[string]interface{}) Option { 90 | return func(b *Builder) error { 91 | if b.Metadata == nil { 92 | b.Metadata = make(map[string]interface{}) 93 | } 94 | for k, v := range metadata { 95 | err := WithMetadata(k, v)(b) 96 | if err != nil { 97 | return err 98 | } 99 | } 100 | return nil 101 | } 102 | } 103 | 104 | func WithHNSWSearchEf(efSearch int32) Option { 105 | return func(b *Builder) error { 106 | if efSearch < 1 { 107 | return fmt.Errorf("efSearch must be greater than 0") 108 | } 109 | return WithMetadata(types.HNSWSearchEF, efSearch)(b) 110 | } 111 | } 112 | 113 | func WithHNSWNumThreads(numThreads int32) Option { 114 | return func(b *Builder) error { 115 | if numThreads < 1 { 116 | return fmt.Errorf("numThreads must be greater than 0") 117 | } 118 | return WithMetadata(types.HNSWNumThreads, numThreads)(b) 119 | } 120 | } 121 | 122 | func WithHNSWResizeFactor(resizeFactor float32) Option { 123 | return func(b *Builder) error { 124 | if resizeFactor < 0 { 125 | return fmt.Errorf("resizeFactor must be greater than or equal to 0") 126 | } 
127 | return WithMetadata(types.HNSWResizeFactor, resizeFactor)(b) 128 | } 129 | } 130 | 131 | func WithMetadata(key string, value interface{}) Option { 132 | return func(b *Builder) error { 133 | if b.Metadata == nil { 134 | b.Metadata = make(map[string]interface{}) 135 | } 136 | err := metadata.WithMetadata(key, value)(metadata.NewMetadataBuilder(&b.Metadata)) 137 | if err != nil { 138 | return err 139 | } 140 | return nil 141 | } 142 | } 143 | 144 | func WithTenant(tenant string) Option { 145 | return func(c *Builder) error { 146 | if tenant == "" { 147 | return fmt.Errorf("tenant cannot be empty") 148 | } 149 | c.Tenant = tenant 150 | return nil 151 | } 152 | } 153 | 154 | func WithDatabase(database string) Option { 155 | return func(c *Builder) error { 156 | if database == "" { 157 | return fmt.Errorf("database cannot be empty") 158 | } 159 | c.Database = database 160 | return nil 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /collection/collection_test.go: -------------------------------------------------------------------------------- 1 | //go:build basic 2 | 3 | package collection 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/amikos-tech/chroma-go/types" 11 | ) 12 | 13 | func TestCollectionBuilder(t *testing.T) { 14 | 15 | t.Run("With Embedding Function", func(t *testing.T) { 16 | b := &Builder{} 17 | err := WithEmbeddingFunction(nil)(b) 18 | require.NoError(t, err, "Unexpected error: %v", err) 19 | require.Nil(t, b.EmbeddingFunction) 20 | }) 21 | t.Run("With ID Generator", func(t *testing.T) { 22 | var generator = types.NewULIDGenerator() 23 | b := &Builder{} 24 | err := WithIDGenerator(generator)(b) 25 | require.NoError(t, err, "Unexpected error: %v", err) 26 | require.Equal(t, generator, b.IDGenerator) 27 | }) 28 | t.Run("With Create If Not Exist", func(t *testing.T) { 29 | b := &Builder{} 30 | err := WithCreateIfNotExist(true)(b) 31 | 
require.NoError(t, err, "Unexpected error: %v", err) 32 | require.True(t, b.CreateIfNotExist) 33 | }) 34 | 35 | t.Run("With HNSW Distance Function", func(t *testing.T) { 36 | b := &Builder{} 37 | err := WithHNSWDistanceFunction(types.L2)(b) 38 | require.NoError(t, err, "Unexpected error: %v", err) 39 | require.NoError(t, err, "Unexpected error: %v", err) 40 | require.Equal(t, types.L2, b.Metadata[types.HNSWSpace]) 41 | }) 42 | 43 | t.Run("With Metadata", func(t *testing.T) { 44 | b := &Builder{} 45 | err := WithMetadata("testKey", "testValue")(b) 46 | require.NoError(t, err, "Unexpected error: %v", err) 47 | require.Equal(t, "testValue", b.Metadata["testKey"]) 48 | }) 49 | 50 | t.Run("With Metadatas", func(t *testing.T) { 51 | b := &Builder{} 52 | err := WithMetadatas(map[string]interface{}{"testKey": "testValue"})(b) 53 | require.NoError(t, err, "Unexpected error: %v", err) 54 | require.Equal(t, "testValue", b.Metadata["testKey"]) 55 | }) 56 | 57 | t.Run("With Metadatas for existing no override", func(t *testing.T) { 58 | b := &Builder{} 59 | b.Metadata = map[string]interface{}{"existingKey": "existingValue"} 60 | err := WithMetadatas(map[string]interface{}{"testKey": "testValue"})(b) 61 | require.NoError(t, err, "Unexpected error: %v", err) 62 | require.Contains(t, b.Metadata, "testKey") 63 | require.Equal(t, "testValue", b.Metadata["testKey"]) 64 | require.Contains(t, b.Metadata, "existingKey") 65 | require.Equal(t, "existingValue", b.Metadata["existingKey"]) 66 | }) 67 | 68 | t.Run("With Metadatas for existing with override", func(t *testing.T) { 69 | b := &Builder{} 70 | b.Metadata = map[string]interface{}{"existingKey": "existingValue"} 71 | err := WithMetadatas(map[string]interface{}{"existingKey": "newValue"})(b) 72 | require.NoError(t, err, "Unexpected error: %v", err) 73 | require.Contains(t, b.Metadata, "existingKey") 74 | require.Equal(t, "newValue", b.Metadata["existingKey"]) 75 | }) 76 | 77 | t.Run("With Metadatas for invalid type", func(t *testing.T) 
{ 78 | b := &Builder{} 79 | err := WithMetadatas(map[string]interface{}{"testKey": map[string]interface{}{"invalid": "value"}})(b) 80 | require.Error(t, err) 81 | }) 82 | 83 | t.Run("With HNSW Batch Size", func(t *testing.T) { 84 | b := &Builder{} 85 | err := WithHNSWBatchSize(10)(b) 86 | require.NoError(t, err, "Unexpected error: %v", err) 87 | require.Equal(t, int32(10), b.Metadata[types.HNSWBatchSize]) 88 | }) 89 | 90 | t.Run("With HNSW Sync Threshold", func(t *testing.T) { 91 | b := &Builder{} 92 | err := WithHNSWSyncThreshold(10)(b) 93 | require.NoError(t, err, "Unexpected error: %v", err) 94 | require.Equal(t, int32(10), b.Metadata[types.HNSWSyncThreshold]) 95 | }) 96 | 97 | t.Run("With HNSWM", func(t *testing.T) { 98 | b := &Builder{} 99 | err := WithHNSWM(10)(b) 100 | require.NoError(t, err, "Unexpected error: %v", err) 101 | require.Equal(t, int32(10), b.Metadata[types.HNSWM]) 102 | }) 103 | 104 | t.Run("With HNSW Construction Ef", func(t *testing.T) { 105 | b := &Builder{} 106 | err := WithHNSWConstructionEf(10)(b) 107 | require.NoError(t, err, "Unexpected error: %v", err) 108 | require.Equal(t, int32(10), b.Metadata[types.HNSWConstructionEF]) 109 | }) 110 | 111 | t.Run("With HNSW Search Ef", func(t *testing.T) { 112 | b := &Builder{} 113 | err := WithHNSWSearchEf(10)(b) 114 | require.NoError(t, err, "Unexpected error: %v", err) 115 | require.Equal(t, int32(10), b.Metadata[types.HNSWSearchEF]) 116 | }) 117 | } 118 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | go-client.chromadb.dev -------------------------------------------------------------------------------- /docs/docs/assets/images/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amikos-tech/chroma-go/6bff4f299fe5cf068434eaf8cf2c4c6738639309/docs/docs/assets/images/favicon.png 
-------------------------------------------------------------------------------- /docs/docs/assets/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amikos-tech/chroma-go/6bff4f299fe5cf068434eaf8cf2c4c6738639309/docs/docs/assets/images/logo.png -------------------------------------------------------------------------------- /docs/docs/auth.md: -------------------------------------------------------------------------------- 1 | # Authentication 2 | 3 | There are four ways to authenticate with Chroma: 4 | 5 | - Manual Header authentication - this approach requires you to be familiar with the server-side auth and generate and insert the necessary headers manually. 6 | - Chroma Basic Auth mechanism 7 | - Chroma Token Auth mechanism with Bearer Authorization header 8 | - Chroma Token Auth mechanism with X-Chroma-Token header 9 | 10 | ### Manual Header Authentication 11 | 12 | ```go 13 | package main 14 | 15 | import ( 16 | "context" 17 | "log" 18 | chroma "github.com/amikos-tech/chroma-go" 19 | ) 20 | 21 | func main() { 22 | var defaultHeaders = map[string]string{"Authorization": "Bearer my-custom-token"} 23 | clientWithTenant, err := chroma.NewClient(chroma.WithBasePath("http://api.trychroma.com/v1/"), chroma.WithDefaultHeaders(defaultHeaders)) 24 | if err != nil { 25 | log.Fatalf("Error creating client: %s \n", err) 26 | } 27 | _, err = clientWithTenant.Heartbeat(context.TODO()) 28 | if err != nil { 29 | log.Fatalf("Error calling heartbeat: %s \n", err) 30 | } 31 | } 32 | ``` 33 | 34 | ### Chroma Basic Auth mechanism 35 | 36 | ```go 37 | package main 38 | 39 | import ( 40 | "context" 41 | "log" 42 | chroma "github.com/amikos-tech/chroma-go" 43 | "github.com/amikos-tech/chroma-go/types" 44 | ) 45 | 46 | func main() { 47 | client, err := chroma.NewClient( 48 | chroma.WithBasePath("http://api.trychroma.com/v1/"), 49 | chroma.WithAuth(types.NewBasicAuthCredentialsProvider("myUser", 
"myPassword")), 50 | ) 51 | if err != nil { 52 | log.Fatalf("Error creating client: %s \n", err) 53 | } 54 | _, err = client.Heartbeat(context.TODO()) 55 | if err != nil { 56 | log.Fatalf("Error calling heartbeat: %s \n", err) 57 | } 58 | } 59 | ``` 60 | 61 | ### Chroma Token Auth mechanism with Bearer Authorization header 62 | 63 | ```go 64 | package main 65 | 66 | import ( 67 | "context" 68 | "log" 69 | chroma "github.com/amikos-tech/chroma-go" 70 | "github.com/amikos-tech/chroma-go/types" 71 | ) 72 | 73 | func main() { 74 | client, err := chroma.NewClient( 75 | chroma.WithBasePath("http://api.trychroma.com/v1/"), 76 | chroma.WithAuth(types.NewTokenAuthCredentialsProvider("my-auth-token", types.AuthorizationTokenHeader)), 77 | ) 78 | if err != nil { 79 | log.Fatalf("Error creating client: %s \n", err) 80 | } 81 | _, err = client.Heartbeat(context.TODO()) 82 | if err != nil { 83 | log.Fatalf("Error calling heartbeat: %s \n", err) 84 | } 85 | } 86 | ``` 87 | 88 | ### Chroma Token Auth mechanism with X-Chroma-Token header 89 | 90 | ```go 91 | package main 92 | 93 | import ( 94 | "context" 95 | "log" 96 | chroma "github.com/amikos-tech/chroma-go" 97 | "github.com/amikos-tech/chroma-go/types" 98 | ) 99 | 100 | func main() { 101 | client, err := chroma.NewClient( 102 | chroma.WithBasePath("http://api.trychroma.com/v1/"), 103 | chroma.WithAuth(types.NewTokenAuthCredentialsProvider("my-auth-token", types.XChromaTokenHeader)), 104 | ) 105 | if err != nil { 106 | log.Fatalf("Error creating client: %s \n", err) 107 | } 108 | _, err = client.Heartbeat(context.TODO()) 109 | if err != nil { 110 | log.Fatalf("Error calling heartbeat: %s \n", err) 111 | } 112 | } 113 | ``` 114 | 115 | -------------------------------------------------------------------------------- /docs/docs/client.md: -------------------------------------------------------------------------------- 1 | # Chroma Client 2 | 3 | Options: 4 | 5 | | Options | Usage | Description | Value | Required | 6 | 
|-------------------|-----------------------------------------|-----------------------------------------------------------------------------------------|----------------------------|---------------------------------------| 7 | | basePath | `WithBasePath("http://localhost:8000")` | The Chroma server base API. | Non-empty valid URL string | No (default: `http://localhost:8000`) | 8 | | Tenant | `WithTenant("tenant")` | The default tenant to use. | `string` | No (default: `default_tenant`) | 9 | | Database | `WithDatabase("database")` | The default database to use. | `string` | No (default: `default_database`) | 10 | | Debug | `WithDebug(true/false)` | Enable debug mode. | `bool` | No (default: `false`) | 11 | | Default Headers | `WithDefaultHeaders(map[string]string)` | Set default headers for the client. | `map[string]string` | No (default: `nil`) | 12 | | SSL Cert | `WithSSLCert("path/to/cert.pem")` | Set the path to the SSL certificate. | valid path to SSL cert. | No (default: Not Set) | 13 | | Insecure | `WithInsecure()` | Disable SSL certificate verification. | | No (default: Not Set) | 14 | | Custom HttpClient | `WithHTTPClient(http.Client)` | Set a custom http client. If this is set then SSL Cert and Insecure options are ignored. | `*http.Client` | No (default: Default HTTPClient) | 15 | 16 | !!! note "Tenant and Database" 17 | 18 | The tenant and database are only supported for Chroma API version `0.4.15+`. 
19 | 20 | Creating a new client: 21 | 22 | ```go 23 | package main 24 | 25 | import ( 26 | "context" 27 | "fmt" 28 | "log" 29 | "os" 30 | 31 | chroma "github.com/amikos-tech/chroma-go" 32 | ) 33 | 34 | func main() { 35 | client, err := chroma.NewClient( 36 | chroma.WithBasePath("http://localhost:8000"), 37 | chroma.WithTenant("my_tenant"), 38 | chroma.WithDatabase("my_db"), 39 | chroma.WithDebug(true), 40 | chroma.WithDefaultHeaders(map[string]string{"Authorization": "Bearer my token"}), 41 | chroma.WithSSLCert("path/to/cert.pem"), 42 | ) 43 | if err != nil { 44 | fmt.Printf("Failed to create client: %v", err) 45 | } 46 | // do something with client 47 | 48 | // Close the client to release any resources such as local embedding functions 49 | err = client.Close() 50 | if err != nil { 51 | fmt.Printf("Failed to close client: %v",err) 52 | } 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /docs/docs/filtering.md: -------------------------------------------------------------------------------- 1 | # Filtering 2 | 3 | Chroma offers two types of filters: 4 | 5 | - Metadata - filtering based on metadata attribute values 6 | - Documents - filtering based on document content (contains or not contains) 7 | 8 | ## Metadata 9 | 10 | * TODO - Add builder example 11 | * TODO - Describe all available operations 12 | 13 | ```go 14 | package main 15 | 16 | import ( 17 | "context" 18 | "fmt" 19 | chroma "github.com/amikos-tech/chroma-go" 20 | "github.com/amikos-tech/chroma-go/pkg/embeddings/openai" 21 | "github.com/amikos-tech/chroma-go/types" 22 | "github.com/amikos-tech/chroma-go/where" 23 | ) 24 | 25 | func main() { 26 | embeddingF, err := openai.NewOpenAIEmbeddingFunction("sk-xxxx") 27 | if err != nil { 28 | fmt.Println(err) 29 | return 30 | } 31 | client, err := chroma.NewClient() // connects to localhost:8000 32 | if err != nil { 33 | fmt.Println(err) 34 | return 35 | } 36 | collection, err := 
client.GetCollection(context.TODO(), "my-collection", embeddingF) 37 | if err != nil { 38 | fmt.Println(err) 39 | return 40 | } 41 | // Filter by metadata 42 | 43 | result, err := collection.GetWithOptions( 44 | context.Background(), 45 | types.WithWhere( 46 | where.Or( 47 | where.Eq("category", "Chroma"), 48 | where.Eq("type", "vector database"), 49 | ), 50 | ), 51 | ) 52 | if err != nil { 53 | fmt.Println(err) 54 | return 55 | } 56 | // do something with result 57 | fmt.Println(result) 58 | } 59 | 60 | ``` 61 | 62 | ## Document 63 | 64 | * TODO - Add builder example 65 | * TODO - Describe all available operations 66 | 67 | ```go 68 | package main 69 | 70 | import ( 71 | "context" 72 | "fmt" 73 | chroma "github.com/amikos-tech/chroma-go" 74 | "github.com/amikos-tech/chroma-go/pkg/embeddings/openai" 75 | "github.com/amikos-tech/chroma-go/types" 76 | "github.com/amikos-tech/chroma-go/where_document" 77 | ) 78 | 79 | func main() { 80 | embeddingF, err := openai.NewOpenAIEmbeddingFunction("sk-xxxx") 81 | if err != nil { 82 | fmt.Println(err) 83 | return 84 | } 85 | client, err := chroma.NewClient(chroma.WithBasePath("http://localhost:8000")) 86 | if err != nil { 87 | fmt.Println(err) 88 | return 89 | } 90 | collection, err := client.GetCollection(context.TODO(), "my-collection", embeddingF) 91 | if err != nil { 92 | fmt.Println(err) 93 | return 94 | } 95 | // Filter by document content 96 | 97 | result, err := collection.GetWithOptions( 98 | context.Background(), 99 | types.WithWhereDocument( 100 | wheredoc.Or( 101 | wheredoc.Contains("Vector database"), 102 | wheredoc.Contains("Chroma"), 103 | ), 104 | ), 105 | ) 106 | 107 | if err != nil { 108 | fmt.Println(err) 109 | return 110 | } 111 | // do something with result 112 | fmt.Println(result) 113 | } 114 | ``` -------------------------------------------------------------------------------- /docs/docs/javascripts/gtag.js: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /docs/docs/records.md: -------------------------------------------------------------------------------- 1 | # Records 2 | 3 | Records are a mechanism that allows you to manage Chroma documents as a cohesive unit. This has several advantages over 4 | the traditional approach of managing documents, ids, embeddings, and metadata separately. 5 | 6 | Two concepts are important to keep in mind here: 7 | 8 | - Record - corresponds to a single document in Chroma which includes id, embedding, metadata, the document or URI 9 | - RecordSet - a single unit of work to insert, upsert, update or delete records. 10 | 11 | 12 | ## Record 13 | 14 | A Record contains the following fields: 15 | 16 | - ID (string) 17 | - Document (string) - optional 18 | - Metadata (map[string]interface{}) - optional 19 | - Embedding ([]float32 or []int32, wrapped in Embedding struct) 20 | - URI (string) - optional 21 | 22 | Here's the `Record` type: 23 | 24 | ```go 25 | package types 26 | 27 | type Record struct { 28 | ID string 29 | Embedding Embedding 30 | Metadata map[string]interface{} 31 | Document string 32 | URI string 33 | err error // indicating whether the record is valid 34 | } 35 | ``` 36 | 37 | ## RecordSet 38 | 39 | A record set is a cohesive unit of work, allowing the user to add, upsert, update, or delete records. 40 | 41 | 42 | !!! 
note "Operation support" 43 | 44 | Currently, the record set supports only the add operation. 45 | 46 | ```go 47 | rs, err := types.NewRecordSet( 48 | types.WithEmbeddingFunction(types.NewConsistentHashEmbeddingFunction()), 49 | types.WithIDGenerator(types.NewULIDGenerator()), 50 | ) 51 | if err != nil { 52 | log.Fatalf("Error creating record set: %s", err) 53 | } 54 | // you can loop here to add multiple records 55 | rs.WithRecord(types.WithDocument("Document 1 content"), types.WithMetadata("key1", "value1")) 56 | rs.WithRecord(types.WithDocument("Document 2 content"), types.WithMetadata("key2", "value2")) 57 | records, err = rs.BuildAndValidate(context.Background()) 58 | 59 | ``` -------------------------------------------------------------------------------- /docs/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: ChromaDB Go Client 2 | site_url: https://go-client.chromadb.dev 3 | repo_url: https://github.com/amikos-tech/chroma-go 4 | copyright: "Amikos Tech LTD, 2024 (core ChromaDB contributors)" 5 | theme: 6 | name: material 7 | palette: 8 | primary: black 9 | logo: assets/images/logo.png 10 | favicon: assets/images/favicon.png 11 | font: 12 | text: Roboto 13 | code: Roboto Mono 14 | features: 15 | - content.code.annotate 16 | - content.code.copy 17 | - navigation.instant 18 | - navigation.instant.progress 19 | - navigation.tracking 20 | - navigation.indexes 21 | extra: 22 | homepage: https://www.trychroma.com 23 | social: 24 | - icon: fontawesome/brands/github 25 | link: https://github.com/chroma-core/chroma 26 | name: Chroma on GitHub 27 | - icon: fontawesome/brands/twitter 28 | link: https://twitter.com/trychroma 29 | name: Chroma on Twitter 30 | - icon: fontawesome/brands/github 31 | link: https://github.com/amikos-tech 32 | name: Amikos on GitHub 33 | - icon: fontawesome/brands/medium 34 | link: https://medium.com/@amikostech 35 | name: Amikos on Medium 36 | analytics: 37 | provider: google 38 | 
property: G-NNN722BJKE 39 | consent: 40 | title: Cookie consent 41 | description: >- 42 | We use cookies for analytics purposes. By continuing to use this website, you agree to their use. 43 | extra_javascript: 44 | - javascripts/gtag.js 45 | markdown_extensions: 46 | - abbr 47 | - admonition 48 | - attr_list 49 | - md_in_html 50 | - markdown.extensions.extra 51 | - toc: 52 | permalink: true 53 | title: On this page 54 | toc_depth: 3 55 | - tables 56 | - pymdownx.highlight: 57 | anchor_linenums: true 58 | line_spans: __span 59 | pygments_lang_class: true 60 | - pymdownx.inlinehilite 61 | - pymdownx.snippets: 62 | base_path: assets/snippets/ 63 | - pymdownx.superfences 64 | - pymdownx.tabbed: 65 | alternate_style: true 66 | - pymdownx.tasklist: 67 | custom_checkbox: true 68 | plugins: 69 | - tags 70 | - search -------------------------------------------------------------------------------- /examples/v2/basic/README.md: -------------------------------------------------------------------------------- 1 | # Basic Usage 2 | 3 | This example represents a getting started guide on how to use the go client with new Chroma API v2 and Chroma v1.0.x -------------------------------------------------------------------------------- /examples/v2/basic/go.mod: -------------------------------------------------------------------------------- 1 | module main 2 | 3 | go 1.24 4 | 5 | replace github.com/amikos-tech/chroma-go => ../../../ 6 | 7 | require github.com/amikos-tech/chroma-go v0.2.0 8 | 9 | require ( 10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 11 | github.com/google/uuid v1.6.0 // indirect 12 | github.com/oklog/ulid v1.3.1 // indirect 13 | github.com/pkg/errors v0.9.1 // indirect 14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /examples/v2/basic/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | 
import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | 8 | chroma "github.com/amikos-tech/chroma-go/pkg/api/v2" 9 | ) 10 | 11 | func main() { 12 | // Create a new Chroma client 13 | client, err := chroma.NewHTTPClient(chroma.WithDebug()) 14 | if err != nil { 15 | log.Fatalf("Error creating client: %s \n", err) 16 | return 17 | } 18 | // Close the client to release any resources such as local embedding functions 19 | defer func() { 20 | err = client.Close() 21 | if err != nil { 22 | log.Fatalf("Error closing client: %s \n", err) 23 | } 24 | }() 25 | 26 | // Create a new collection with options. We don't provide an embedding function here, so the default embedding function will be used 27 | col, err := client.GetOrCreateCollection(context.Background(), "col1", 28 | chroma.WithCollectionMetadataCreate( 29 | chroma.NewMetadata( 30 | chroma.NewStringAttribute("str", "hello2"), 31 | chroma.NewIntAttribute("int", 1), 32 | chroma.NewFloatAttribute("float", 1.1), 33 | ), 34 | ), 35 | ) 36 | if err != nil { 37 | log.Fatalf("Error creating collection: %s \n", err) 38 | return 39 | } 40 | 41 | err = col.Add(context.Background(), 42 | //chroma.WithIDGenerator(chroma.NewULIDGenerator()), 43 | chroma.WithIDs("1", "2"), 44 | chroma.WithTexts("hello world", "goodbye world"), 45 | chroma.WithMetadatas( 46 | chroma.NewDocumentMetadata(chroma.NewIntAttribute("int", 1)), 47 | chroma.NewDocumentMetadata(chroma.NewStringAttribute("str1", "hello2")), 48 | )) 49 | if err != nil { 50 | log.Fatalf("Error adding collection: %s \n", err) 51 | } 52 | 53 | count, err := col.Count(context.Background()) 54 | if err != nil { 55 | log.Fatalf("Error counting collection: %s \n", err) 56 | return 57 | } 58 | fmt.Printf("Count collection: %d\n", count) 59 | 60 | qr, err := col.Query(context.Background(), 61 | chroma.WithQueryTexts("say hello"), 62 | chroma.WithIncludeQuery(chroma.IncludeDocuments, chroma.IncludeMetadatas), 63 | ) 64 | if err != nil { 65 | log.Fatalf("Error querying collection: %s \n", err) 66 | 
return 67 | } 68 | fmt.Printf("Query result: %v\n", qr.GetDocumentsGroups()[0][0]) 69 | 70 | err = col.Delete(context.Background(), chroma.WithIDsDelete("1", "2")) 71 | if err != nil { 72 | log.Fatalf("Error deleting collection: %s \n", err) 73 | return 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /examples/v2/custom_embedding_function/README.md: -------------------------------------------------------------------------------- 1 | # Custom Embedding Function Example 2 | 3 | > [!NOTE] 4 | > Coming soon... -------------------------------------------------------------------------------- /examples/v2/custom_embedding_function/go.mod: -------------------------------------------------------------------------------- 1 | module main 2 | 3 | go 1.24.1 4 | 5 | replace github.com/amikos-tech/chroma-go => ../../../ 6 | 7 | require github.com/amikos-tech/chroma-go v0.2.0 8 | 9 | require ( 10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 11 | github.com/google/uuid v1.6.0 // indirect 12 | github.com/oklog/ulid v1.3.1 // indirect 13 | github.com/pkg/errors v0.9.1 // indirect 14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /examples/v2/custom_embedding_function/go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= 2 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 3 | github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= 4 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 5 | github.com/yalue/onnxruntime_go v1.19.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 6 | -------------------------------------------------------------------------------- 
/examples/v2/embedding_function_basic/README.md: -------------------------------------------------------------------------------- 1 | # Embedding Function Usage Example 2 | 3 | This example demonstrates who to use embedding functions. 4 | 5 | ## Running 6 | 7 | ```bash 8 | go mod download 9 | OPENAI_API_KEY=sk-xxxxx go run main.go 10 | ``` 11 | -------------------------------------------------------------------------------- /examples/v2/embedding_function_basic/go.mod: -------------------------------------------------------------------------------- 1 | module main 2 | 3 | go 1.24.1 4 | 5 | replace github.com/amikos-tech/chroma-go => ../../../ 6 | 7 | require github.com/amikos-tech/chroma-go v0.2.0 8 | 9 | require ( 10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 11 | github.com/google/uuid v1.6.0 // indirect 12 | github.com/oklog/ulid v1.3.1 // indirect 13 | github.com/pkg/errors v0.9.1 // indirect 14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /examples/v2/embedding_function_basic/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | 9 | chroma "github.com/amikos-tech/chroma-go/pkg/api/v2" 10 | openai "github.com/amikos-tech/chroma-go/pkg/embeddings/openai" 11 | ) 12 | 13 | func main() { 14 | // Create a new Chroma client 15 | client, err := chroma.NewHTTPClient(chroma.WithDebug()) 16 | if err != nil { 17 | log.Fatalf("Error creating client: %s \n", err) 18 | return 19 | } 20 | // Close the client to release any resources such as local embedding functions 21 | defer func() { 22 | err = client.Close() 23 | if err != nil { 24 | log.Fatalf("Error closing client: %s \n", err) 25 | } 26 | }() 27 | 28 | ef, err := openai.NewOpenAIEmbeddingFunction(os.Getenv("OPENAI_API_KEY"), openai.WithModel(openai.TextEmbedding3Small)) 29 | if err != nil { 
30 | log.Fatalf("Error creating embedding function: %s \n", err) 31 | return 32 | } 33 | 34 | // Create a new collection with options. We don't provide an embedding function here, so the default embedding function will be used 35 | col, err := client.GetOrCreateCollection(context.Background(), "openai-embedding-function-test", 36 | chroma.WithCollectionMetadataCreate( 37 | chroma.NewMetadata( 38 | chroma.NewStringAttribute("str", "hello2"), 39 | chroma.NewIntAttribute("int", 1), 40 | chroma.NewFloatAttribute("float", 1.1), 41 | ), 42 | ), 43 | chroma.WithEmbeddingFunctionCreate(ef), 44 | ) 45 | if err != nil { 46 | log.Fatalf("Error creating collection: %s \n", err) 47 | return 48 | } 49 | 50 | err = col.Add(context.Background(), 51 | //chroma.WithIDGenerator(chroma.NewULIDGenerator()), 52 | chroma.WithIDs("1", "2"), 53 | chroma.WithTexts("hello world", "goodbye world"), 54 | chroma.WithMetadatas( 55 | chroma.NewDocumentMetadata(chroma.NewIntAttribute("int", 1)), 56 | chroma.NewDocumentMetadata(chroma.NewStringAttribute("str1", "hello2")), 57 | )) 58 | if err != nil { 59 | log.Fatalf("Error adding collection: %s \n", err) 60 | } 61 | 62 | count, err := col.Count(context.Background()) 63 | if err != nil { 64 | log.Fatalf("Error counting collection: %s \n", err) 65 | return 66 | } 67 | fmt.Printf("Count collection: %d\n", count) 68 | 69 | qr, err := col.Query(context.Background(), chroma.WithQueryTexts("say hello")) 70 | if err != nil { 71 | log.Fatalf("Error querying collection: %s \n", err) 72 | return 73 | } 74 | fmt.Printf("Query result: %v\n", qr.GetDocumentsGroups()[0][0]) 75 | 76 | err = col.Delete(context.Background(), chroma.WithIDsDelete("1", "2")) 77 | if err != nil { 78 | log.Fatalf("Error deleting collection: %s \n", err) 79 | return 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /examples/v2/reranking_function_basic/README.md: 
-------------------------------------------------------------------------------- 1 | # Reranking Function Usage Examples 2 | 3 | > [!NOTE] 4 | > Coming soon... -------------------------------------------------------------------------------- /examples/v2/reranking_function_basic/go.mod: -------------------------------------------------------------------------------- 1 | module main 2 | 3 | go 1.24.1 4 | 5 | replace github.com/amikos-tech/chroma-go => ../../../ 6 | 7 | require github.com/amikos-tech/chroma-go v0.2.0 8 | 9 | require ( 10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 11 | github.com/google/uuid v1.6.0 // indirect 12 | github.com/oklog/ulid v1.3.1 // indirect 13 | github.com/pkg/errors v0.9.1 // indirect 14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /examples/v2/reranking_function_basic/go.sum: -------------------------------------------------------------------------------- 1 | github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= 2 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 3 | github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= 4 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 5 | github.com/yalue/onnxruntime_go v1.19.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= 6 | -------------------------------------------------------------------------------- /examples/v2/tenant_and_db/README.md: -------------------------------------------------------------------------------- 1 | # Tenant and DB management 2 | 3 | This example shows how to manage tenants and databases. 
-------------------------------------------------------------------------------- /examples/v2/tenant_and_db/go.mod: -------------------------------------------------------------------------------- 1 | module main 2 | 3 | go 1.24.1 4 | 5 | replace github.com/amikos-tech/chroma-go => ../../../ 6 | 7 | require github.com/amikos-tech/chroma-go v0.2.0 8 | 9 | require ( 10 | github.com/go-viper/mapstructure/v2 v2.2.1 // indirect 11 | github.com/google/uuid v1.6.0 // indirect 12 | github.com/oklog/ulid v1.3.1 // indirect 13 | github.com/pkg/errors v0.9.1 // indirect 14 | github.com/yalue/onnxruntime_go v1.19.0 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /examples/v2/tenant_and_db/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | chroma "github.com/amikos-tech/chroma-go/pkg/api/v2" 7 | "log" 8 | "math/rand" 9 | ) 10 | 11 | func main() { 12 | client, err := chroma.NewHTTPClient(chroma.WithDebug()) 13 | if err != nil { 14 | log.Fatalf("Error creating client: %s \n", err) 15 | return 16 | } 17 | // Close the client to release any resources such as local embedding functions 18 | defer func() { 19 | err = client.Close() 20 | if err != nil { 21 | log.Fatalf("Error closing client: %s \n", err) 22 | } 23 | }() 24 | r := rand.Int() 25 | tenant, err := client.CreateTenant(context.Background(), chroma.NewTenant(fmt.Sprintf("tenant-%d", r))) 26 | if err != nil { 27 | log.Fatalf("Error creating tenant: %s \n", err) 28 | return 29 | } 30 | fmt.Printf("Created tenant %v\n", tenant) 31 | db1, err := client.CreateDatabase(context.Background(), tenant.Database("db2")) 32 | if err != nil { 33 | log.Fatalf("Error creating database: %s \n", err) 34 | return 35 | } 36 | col, err := client.GetOrCreateCollection(context.Background(), "col1", 37 | chroma.WithDatabaseCreate(db1), chroma.WithCollectionMetadataCreate( 38 | 
chroma.NewMetadata( 39 | chroma.NewStringAttribute("str", "hello"), 40 | chroma.NewIntAttribute("int", 1), 41 | chroma.NewFloatAttribute("float", 1.1), 42 | ), 43 | ), 44 | ) 45 | if err != nil { 46 | log.Fatalf("Error creating collection: %s \n", err) 47 | return 48 | } 49 | fmt.Printf("Created collection %v+\n", col) 50 | 51 | err = col.Add(context.Background(), 52 | //chroma.WithIDGenerator(chroma.NewULIDGenerator()), 53 | chroma.WithIDs("1", "2"), 54 | chroma.WithTexts("hello world", "goodbye world"), 55 | chroma.WithMetadatas( 56 | chroma.NewDocumentMetadata(chroma.NewIntAttribute("int", 1)), 57 | chroma.NewDocumentMetadata(chroma.NewStringAttribute("str", "hello")), 58 | )) 59 | if err != nil { 60 | log.Fatalf("Error adding collection: %s \n", err) 61 | } 62 | 63 | colCount, err := client.CountCollections(context.Background(), chroma.WithDatabaseCount(db1)) 64 | if err != nil { 65 | log.Fatalf("Error counting collections: %s \n", err) 66 | return 67 | } 68 | fmt.Printf("Count collections in %s : %d\n", db1.String(), colCount) 69 | cols, err := client.ListCollections(context.Background(), chroma.WithDatabaseList(db1)) 70 | if err != nil { 71 | log.Fatalf("Error listing collections: %s \n", err) 72 | return 73 | } 74 | fmt.Printf("List collections in %s : %d\n", db1.String(), len(cols)) 75 | 76 | qr, err := col.Query(context.Background(), chroma.WithQueryTexts("say hello")) 77 | if err != nil { 78 | log.Fatalf("Error querying collection: %s \n", err) 79 | return 80 | } 81 | fmt.Printf("Query result: %v\n", qr.GetDocumentsGroups()[0][0]) 82 | err = col.Delete(context.Background(), chroma.WithIDsDelete("1", "2")) 83 | if err != nil { 84 | log.Fatalf("Error deleting collection: %s \n", err) 85 | return 86 | } 87 | fmt.Printf("Deleted items from collection %s\n", col.Name()) 88 | 89 | err = client.DeleteCollection(context.Background(), "col1", chroma.WithDatabaseDelete(db1)) 90 | if err != nil { 91 | log.Fatalf("Error deleting collection: %s \n", err) 92 | return 
93 | } 94 | err = client.DeleteDatabase(context.Background(), db1) 95 | if err != nil { 96 | log.Fatalf("Error deleting database: %s \n", err) 97 | return 98 | } 99 | fmt.Printf("Deleted database %s\n", db1) 100 | } 101 | -------------------------------------------------------------------------------- /gen_api.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | #docker run --rm -v ${PWD}:/local swaggerapi/swagger-codegen-cli -v 5 | 6 | docker run --rm -v ${PWD}:/local swaggerapi/swagger-codegen-cli-v3 generate \ 7 | -DapiTests=false \ 8 | -i /local/openapi.yaml \ 9 | -l go \ 10 | -o /local/swagger -------------------------------------------------------------------------------- /gen_api_v3.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | rm -rf swagger 4 | # if openapi-generator-cli.jar exists in the current directory, use it; otherwise, download it 5 | if [ ! -f openapi-generator-cli.jar ]; then 6 | wget https://repo1.maven.org/maven2/org/openapitools/openapi-generator-cli/6.6.0/openapi-generator-cli-6.6.0.jar -O openapi-generator-cli.jar 7 | fi 8 | mkdir generator 9 | cd generator 10 | jar -xf ../openapi-generator-cli.jar 11 | cp ../patches/model_anyof.mustache ./go/ 12 | jar -cf ../openapi-generator-cli-patched.jar * 13 | jar -cmf META-INF/MANIFEST.MF ../openapi-generator-cli-patched-fixed.jar * 14 | cd .. 
15 | rm -rf generator 16 | # on windows: Invoke-WebRequest -OutFile openapi-generator-cli.jar https://repo1.maven.org/maven2/org/openapitools/openapi-generator-cli/6.6.0/openapi-generator-cli-6.6.0.jar 17 | java -jar openapi-generator-cli-patched-fixed.jar generate -i openapi.yaml -g go -o swagger/ 18 | 19 | rm swagger/go.mod 20 | rm swagger/go.sum 21 | rm swagger/.gitignore 22 | rm swagger/.openapi-generator-ignore 23 | rm swagger/.travis.yml 24 | rm swagger/git_push.sh 25 | rm swagger/README.md 26 | rm -rf swagger/test/ 27 | rm -rf swagger/docs/ 28 | rm -rf swagger/api/ 29 | rm -rf swagger/.openapi-generator/ 30 | 31 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/amikos-tech/chroma-go 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.23.6 6 | 7 | require ( 8 | github.com/Masterminds/semver v1.5.0 9 | github.com/docker/docker v28.0.1+incompatible 10 | github.com/go-playground/validator/v10 v10.22.0 11 | github.com/go-viper/mapstructure/v2 v2.2.1 12 | github.com/google/generative-ai-go v0.19.0 13 | github.com/google/uuid v1.6.0 14 | github.com/joho/godotenv v1.5.1 15 | github.com/leanovate/gopter v0.2.11 16 | github.com/oklog/ulid v1.3.1 17 | github.com/pkg/errors v0.9.1 18 | github.com/stretchr/testify v1.10.0 19 | github.com/testcontainers/testcontainers-go v0.36.0 20 | github.com/testcontainers/testcontainers-go/modules/chroma v0.36.0 21 | github.com/testcontainers/testcontainers-go/modules/ollama v0.36.0 22 | github.com/yalue/onnxruntime_go v1.19.0 23 | google.golang.org/api v0.186.0 24 | ) 25 | 26 | require ( 27 | cloud.google.com/go v0.115.0 // indirect 28 | cloud.google.com/go/ai v0.8.0 // indirect 29 | cloud.google.com/go/auth v0.6.0 // indirect 30 | cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect 31 | cloud.google.com/go/compute/metadata v0.5.0 // indirect 32 | cloud.google.com/go/longrunning v0.5.7 // indirect 
33 | dario.cat/mergo v1.0.1 // indirect 34 | github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect 35 | github.com/Microsoft/go-winio v0.6.2 // indirect 36 | github.com/cenkalti/backoff/v4 v4.3.0 // indirect 37 | github.com/containerd/log v0.1.0 // indirect 38 | github.com/containerd/platforms v1.0.0-rc.1 // indirect 39 | github.com/cpuguy83/dockercfg v0.3.2 // indirect 40 | github.com/davecgh/go-spew v1.1.1 // indirect 41 | github.com/distribution/reference v0.6.0 // indirect 42 | github.com/docker/go-connections v0.5.0 // indirect 43 | github.com/docker/go-units v0.5.0 // indirect 44 | github.com/ebitengine/purego v0.8.2 // indirect 45 | github.com/felixge/httpsnoop v1.0.4 // indirect 46 | github.com/gabriel-vasile/mimetype v1.4.3 // indirect 47 | github.com/go-logr/logr v1.4.2 // indirect 48 | github.com/go-logr/stdr v1.2.2 // indirect 49 | github.com/go-ole/go-ole v1.2.6 // indirect 50 | github.com/go-playground/locales v0.14.1 // indirect 51 | github.com/go-playground/universal-translator v0.18.1 // indirect 52 | github.com/gogo/protobuf v1.3.2 // indirect 53 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 54 | github.com/golang/protobuf v1.5.4 // indirect 55 | github.com/google/s2a-go v0.1.7 // indirect 56 | github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect 57 | github.com/googleapis/gax-go/v2 v2.12.5 // indirect 58 | github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect 59 | github.com/klauspost/compress v1.17.11 // indirect 60 | github.com/leodido/go-urn v1.4.0 // indirect 61 | github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect 62 | github.com/magiconair/properties v1.8.9 // indirect 63 | github.com/moby/docker-image-spec v1.3.1 // indirect 64 | github.com/moby/patternmatcher v0.6.0 // indirect 65 | github.com/moby/sys/sequential v0.6.0 // indirect 66 | github.com/moby/sys/user v0.3.0 // indirect 67 | github.com/moby/sys/userns v0.1.0 // indirect 68 | 
github.com/moby/term v0.5.0 // indirect 69 | github.com/morikuni/aec v1.0.0 // indirect 70 | github.com/opencontainers/go-digest v1.0.0 // indirect 71 | github.com/opencontainers/image-spec v1.1.1 // indirect 72 | github.com/pmezard/go-difflib v1.0.0 // indirect 73 | github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect 74 | github.com/shirou/gopsutil/v4 v4.25.1 // indirect 75 | github.com/sirupsen/logrus v1.9.3 // indirect 76 | github.com/tklauser/go-sysconf v0.3.12 // indirect 77 | github.com/tklauser/numcpus v0.6.1 // indirect 78 | github.com/yusufpapurcu/wmi v1.2.4 // indirect 79 | go.opencensus.io v0.24.0 // indirect 80 | go.opentelemetry.io/auto/sdk v1.1.0 // indirect 81 | go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect 82 | go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect 83 | go.opentelemetry.io/otel v1.35.0 // indirect 84 | go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 // indirect 85 | go.opentelemetry.io/otel/metric v1.35.0 // indirect 86 | go.opentelemetry.io/otel/trace v1.35.0 // indirect 87 | go.opentelemetry.io/proto/otlp v1.3.1 // indirect 88 | golang.org/x/crypto v0.37.0 // indirect 89 | golang.org/x/net v0.39.0 // indirect 90 | golang.org/x/oauth2 v0.23.0 // indirect 91 | golang.org/x/sync v0.13.0 // indirect 92 | golang.org/x/sys v0.32.0 // indirect 93 | golang.org/x/text v0.24.0 // indirect 94 | golang.org/x/time v0.5.0 // indirect 95 | google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect 96 | google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect 97 | google.golang.org/grpc v1.68.1 // indirect 98 | google.golang.org/protobuf v1.35.2 // indirect 99 | gopkg.in/yaml.v3 v3.0.1 // indirect 100 | ) 101 | -------------------------------------------------------------------------------- /internal/http/constants.go: 
-------------------------------------------------------------------------------- 1 | package http 2 | 3 | const ChromaGoClientUserAgent = "chroma-go-client/0.1.x" 4 | -------------------------------------------------------------------------------- /internal/http/errors.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | // ChromaError represents an error returned by the Chroma API. It contains the ID of the error, the error message and the status code from the HTTP call. 10 | // Example: 11 | // 12 | // { 13 | // "error": "NotFoundError", 14 | // "message": "Tenant default_tenant2 not found" 15 | // } 16 | type ChromaError struct { 17 | ErrorID string `json:"error"` 18 | ErrorCode int `json:"error_code"` 19 | Message string `json:"message"` 20 | } 21 | 22 | func ChromaErrorFromHTTPResponse(resp *http.Response, err error) *ChromaError { 23 | chromaAPIError := &ChromaError{ 24 | ErrorID: "unknown", 25 | Message: "unknown", 26 | } 27 | if err != nil { 28 | chromaAPIError.Message = err.Error() 29 | } 30 | if resp == nil { 31 | return chromaAPIError 32 | } 33 | chromaAPIError.ErrorCode = resp.StatusCode 34 | if err := json.NewDecoder(resp.Body).Decode(chromaAPIError); err != nil { 35 | chromaAPIError.Message = ReadRespBody(resp.Body) 36 | } 37 | return chromaAPIError 38 | } 39 | 40 | func (e *ChromaError) Error() string { 41 | return fmt.Sprintf("Error (%d) %s: %s", e.ErrorCode, e.ErrorID, e.Message) 42 | } 43 | -------------------------------------------------------------------------------- /internal/http/retry.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "net/http" 7 | "time" 8 | ) 9 | 10 | type Option func(*SimpleRetryStrategy) error 11 | 12 | func WithMaxRetries(retries int) Option { 13 | return func(r *SimpleRetryStrategy) error { 14 | if 
retries <= 0 { 15 | return fmt.Errorf("retries must be a positive integer") 16 | } 17 | r.MaxRetries = retries 18 | return nil 19 | } 20 | } 21 | 22 | func WithFixedDelay(delay time.Duration) Option { 23 | return func(r *SimpleRetryStrategy) error { 24 | if delay <= 0 { 25 | return fmt.Errorf("delay must be a positive integer") 26 | } 27 | r.FixedDelay = delay 28 | return nil 29 | } 30 | } 31 | 32 | func WithRetryableStatusCodes(statusCodes ...int) Option { 33 | return func(r *SimpleRetryStrategy) error { 34 | r.RetryableStatusCodes = statusCodes 35 | return nil 36 | } 37 | } 38 | 39 | func WithExponentialBackOff() Option { 40 | return func(r *SimpleRetryStrategy) error { 41 | r.ExponentialBackOff = true 42 | return nil 43 | } 44 | } 45 | 46 | type SimpleRetryStrategy struct { 47 | MaxRetries int 48 | FixedDelay time.Duration 49 | ExponentialBackOff bool 50 | RetryableStatusCodes []int 51 | } 52 | 53 | func NewSimpleRetryStrategy(opts ...Option) (*SimpleRetryStrategy, error) { 54 | var strategy = &SimpleRetryStrategy{ 55 | MaxRetries: 3, 56 | FixedDelay: time.Duration(1000) * time.Millisecond, 57 | RetryableStatusCodes: []int{}, 58 | } 59 | for _, opt := range opts { 60 | if err := opt(strategy); err != nil { 61 | return nil, err 62 | } 63 | } 64 | return strategy, nil 65 | } 66 | 67 | func (r *SimpleRetryStrategy) DoWithRetry(client *http.Client, req *http.Request) (*http.Response, error) { 68 | var resp *http.Response 69 | var err error 70 | for i := 0; i < r.MaxRetries; i++ { 71 | resp, err = client.Do(req) 72 | if err != nil { 73 | break 74 | } 75 | if resp.StatusCode >= 200 && resp.StatusCode < 400 { 76 | break 77 | } 78 | if r.isRetryable(resp.StatusCode) { 79 | if r.ExponentialBackOff { 80 | time.Sleep(r.FixedDelay * time.Duration(math.Pow(2, float64(i)))) 81 | } else { 82 | time.Sleep(r.FixedDelay) 83 | } 84 | } 85 | } 86 | return resp, err 87 | } 88 | 89 | func (r *SimpleRetryStrategy) isRetryable(code int) bool { 90 | for _, retryableCode := range 
r.RetryableStatusCodes { 91 | if code == retryableCode { 92 | return true 93 | } 94 | } 95 | return false 96 | } 97 | -------------------------------------------------------------------------------- /internal/http/retry_test.go: -------------------------------------------------------------------------------- 1 | //go:build basic 2 | 3 | package http 4 | 5 | import ( 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestRetryStrategyWithExponentialBackOff(t *testing.T) { 15 | client := &http.Client{} 16 | 17 | retryableStatusCodes := []int{http.StatusInternalServerError} 18 | 19 | // Create a new SimpleRetryStrategy with exponential backoff enabled 20 | retryStrategy, err := NewSimpleRetryStrategy( 21 | WithMaxRetries(3), 22 | WithFixedDelay(100*time.Millisecond), 23 | WithRetryableStatusCodes(retryableStatusCodes...), 24 | WithExponentialBackOff(), 25 | ) 26 | require.NoError(t, err, "error setting up strategy: %v", err) 27 | var serverRetries = 0 28 | // Create a test server that always returns a 500 status code 29 | server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 30 | serverRetries++ 31 | w.WriteHeader(http.StatusInternalServerError) 32 | })) 33 | defer server.Close() 34 | 35 | req, err := http.NewRequest("GET", server.URL, nil) 36 | require.NoError(t, err, "unexpected error: %v", err) 37 | 38 | startTime := time.Now() 39 | 40 | _, err = retryStrategy.DoWithRetry(client, req) 41 | if err != nil { 42 | t.Fatalf("unexpected error: %v", err) 43 | } 44 | // Calculate the total elapsed time 45 | elapsedTime := time.Since(startTime) 46 | // Since we have exponential backoff with delays 100ms, 200ms, 400ms, the total delay should be at least 700ms 47 | expectedMinDelay := 100*time.Millisecond + 200*time.Millisecond + 400*time.Millisecond 48 | require.Less(t, expectedMinDelay, elapsedTime, "expected total delay to be at least %v, but got %v", 
expectedMinDelay, elapsedTime) 49 | require.Equal(t, 3, serverRetries) 50 | } 51 | -------------------------------------------------------------------------------- /internal/http/strategy.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import "net/http" 4 | 5 | type RetryStrategy interface { 6 | DoWithRetry(client *http.Client, req *http.Request) (*http.Response, error) 7 | } 8 | -------------------------------------------------------------------------------- /internal/http/utils.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import "io" 4 | 5 | func ReadRespBody(resp io.Reader) string { 6 | if resp == nil { 7 | return "" 8 | } 9 | body, err := io.ReadAll(resp) 10 | if err != nil { 11 | return "" 12 | } 13 | return string(body) 14 | } 15 | -------------------------------------------------------------------------------- /metadata/metadata.go: -------------------------------------------------------------------------------- 1 | package metadata 2 | 3 | import ( 4 | "github.com/amikos-tech/chroma-go/types" 5 | ) 6 | 7 | type MetadataBuilder struct { 8 | Metadata map[string]interface{} 9 | } 10 | 11 | func NewMetadataBuilder(metadata *map[string]interface{}) *MetadataBuilder { 12 | if metadata != nil { 13 | return &MetadataBuilder{Metadata: *metadata} 14 | } 15 | return &MetadataBuilder{Metadata: make(map[string]interface{})} 16 | } 17 | 18 | type Option func(*MetadataBuilder) error 19 | 20 | func WithMetadata(key string, value interface{}) Option { 21 | return func(b *MetadataBuilder) error { 22 | switch value.(type) { 23 | case string, int, float32, bool, int32, uint32, int64, uint64: 24 | b.Metadata[key] = value 25 | case types.DistanceFunction: 26 | b.Metadata[key] = value 27 | default: 28 | return &types.InvalidMetadataValueError{Key: key, Value: value} 29 | } 30 | return nil 31 | } 32 | } 33 | 34 | func WithMetadatas(metadata 
map[string]interface{}) Option { 35 | return func(b *MetadataBuilder) error { 36 | for k, v := range metadata { 37 | switch v.(type) { 38 | case string, int, float32, bool: 39 | b.Metadata[k] = v 40 | default: 41 | return &types.InvalidMetadataValueError{Key: k, Value: v} 42 | } 43 | } 44 | return nil 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /metadata/metadata_test.go: -------------------------------------------------------------------------------- 1 | //go:build basic 2 | 3 | package metadata 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/amikos-tech/chroma-go/test" 11 | ) 12 | 13 | func TestWithMetadata(t *testing.T) { 14 | t.Run("Test invalid Metadata", func(t *testing.T) { 15 | builder := NewMetadataBuilder(nil) 16 | err := WithMetadata("testKey", map[string]interface{}{"invalid": "value"})(builder) 17 | require.Error(t, err) 18 | }) 19 | 20 | t.Run("Test int", func(t *testing.T) { 21 | actual := make(map[string]interface{}) 22 | builder := NewMetadataBuilder(&actual) 23 | err := WithMetadata("testKey", 1)(builder) 24 | require.NoError(t, err) 25 | expected := map[string]interface{}{ 26 | "testKey": 1, 27 | } 28 | test.Compare(t, actual, expected) 29 | }) 30 | t.Run("Test float32", func(t *testing.T) { 31 | actual := make(map[string]interface{}) 32 | builder := NewMetadataBuilder(&actual) 33 | err := WithMetadata("testKey", float32(1.1))(builder) 34 | require.NoError(t, err) 35 | expected := map[string]interface{}{ 36 | "testKey": float32(1.1), 37 | } 38 | test.Compare(t, actual, expected) 39 | }) 40 | 41 | t.Run("Test bool", func(t *testing.T) { 42 | actual := make(map[string]interface{}) 43 | builder := NewMetadataBuilder(&actual) 44 | err := WithMetadata("testKey", true)(builder) 45 | require.NoError(t, err) 46 | expected := map[string]interface{}{ 47 | "testKey": true, 48 | } 49 | test.Compare(t, builder.Metadata, expected) 50 | }) 51 | 52 | 
t.Run("Test string", func(t *testing.T) { 53 | actual := make(map[string]interface{}) 54 | builder := NewMetadataBuilder(&actual) 55 | err := WithMetadata("testKey", "value")(builder) 56 | require.NoError(t, err) 57 | expected := map[string]interface{}{ 58 | "testKey": "value", 59 | } 60 | test.Compare(t, actual, expected) 61 | }) 62 | 63 | t.Run("Test all types", func(t *testing.T) { 64 | actual := make(map[string]interface{}) 65 | builder := NewMetadataBuilder(&actual) 66 | err := WithMetadata("testKey", "value")(builder) 67 | require.NoError(t, err) 68 | err = WithMetadata("testKey2", 1)(builder) 69 | require.NoError(t, err) 70 | err = WithMetadata("testKey3", true)(builder) 71 | require.NoError(t, err) 72 | err = WithMetadata("testKey4", float32(1.1))(builder) 73 | require.NoError(t, err) 74 | expected := map[string]interface{}{ 75 | "testKey": "value", 76 | "testKey2": 1, 77 | "testKey3": true, 78 | "testKey4": float32(1.1), 79 | } 80 | test.Compare(t, actual, expected) 81 | }) 82 | } 83 | -------------------------------------------------------------------------------- /patches/model_anyof.mustache: -------------------------------------------------------------------------------- 1 | // {{classname}} {{{description}}}{{^description}}struct for {{{classname}}}{{/description}} 2 | type {{classname}} struct { 3 | {{#anyOf}} 4 | {{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} *{{{.}}} 5 | {{/anyOf}} 6 | } 7 | 8 | // Unmarshal JSON data into any of the pointers in the struct 9 | func (dst *{{classname}}) UnmarshalJSON(data []byte) error { 10 | var err error 11 | {{#isNullable}} 12 | // this object is nullable so check if the payload is null or empty string 13 | if string(data) == "" || string(data) == "{}" { 14 | return nil 15 | } 16 | 17 | {{/isNullable}} 18 | {{#discriminator}} 19 | {{#mappedModels}} 20 | {{#-first}} 21 | // use discriminator value to speed up the lookup 22 | var jsonDict map[string]interface{} 23 | err = json.Unmarshal(data, &jsonDict) 24 
| if err != nil { 25 | return fmt.Errorf("failed to unmarshal JSON into map for the discriminator lookup") 26 | } 27 | 28 | {{/-first}} 29 | // check if the discriminator value is '{{{mappingName}}}' 30 | if jsonDict["{{{propertyBaseName}}}"] == "{{{mappingName}}}" { 31 | // try to unmarshal JSON data into {{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}} 32 | err = json.Unmarshal(data, &dst.{{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}}); 33 | if err == nil { 34 | json{{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}}, _ := json.Marshal(dst.{{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}}) 35 | if string(json{{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}}) == "{}" { // empty struct 36 | dst.{{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}} = nil 37 | } else { 38 | return nil // data stored in dst.{{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}}, return on the first match 39 | } 40 | } else { 41 | dst.{{#lambda.type-to-name}}{modelName}{{/lambda.type-to-name}} = nil 42 | } 43 | } 44 | 45 | {{/mappedModels}} 46 | {{/discriminator}} 47 | {{#anyOf}} 48 | // try to unmarshal JSON data into {{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} 49 | err = json.Unmarshal(data, &dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}); 50 | if err == nil { 51 | json{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}, _ := json.Marshal(dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}) 52 | if string(json{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}) == "{}" { // empty struct 53 | dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} = nil 54 | } else { 55 | return nil // data stored in dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}, return on the first match 56 | } 57 | } else { 58 | dst.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} = nil 59 | } 60 | 61 | {{/anyOf}} 62 | return fmt.Errorf("data failed to match schemas in 
anyOf({{classname}})") 63 | } 64 | 65 | // Marshal data from the first non-nil pointers in the struct to JSON 66 | func (src *{{classname}}) MarshalJSON() ([]byte, error) { 67 | {{#anyOf}} 68 | if src.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}} != nil { 69 | return json.Marshal(&src.{{#lambda.type-to-name}}{{{.}}}{{/lambda.type-to-name}}) 70 | } 71 | 72 | {{/anyOf}} 73 | return nil, nil // no data in anyOf schemas 74 | } 75 | 76 | {{>nullable_model}} 77 | -------------------------------------------------------------------------------- /pkg/api/v2/auth.go: -------------------------------------------------------------------------------- 1 | package v2 2 | 3 | import ( 4 | "encoding/base64" 5 | 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | type CredentialsProvider interface { 10 | Authenticate(apiClient *BaseAPIClient) error 11 | } 12 | 13 | type BasicAuthCredentialsProvider struct { 14 | Username string 15 | Password string 16 | } 17 | 18 | func NewBasicAuthCredentialsProvider(username, password string) *BasicAuthCredentialsProvider { 19 | return &BasicAuthCredentialsProvider{ 20 | Username: username, 21 | Password: password, 22 | } 23 | } 24 | 25 | func (b *BasicAuthCredentialsProvider) Authenticate(client *BaseAPIClient) error { 26 | auth := b.Username + ":" + b.Password 27 | encodedAuth := base64.StdEncoding.EncodeToString([]byte(auth)) 28 | client.defaultHeaders["Authorization"] = "Basic " + encodedAuth 29 | return nil 30 | } 31 | 32 | type TokenTransportHeader string 33 | 34 | const ( 35 | AuthorizationTokenHeader TokenTransportHeader = "Authorization" 36 | XChromaTokenHeader TokenTransportHeader = "X-Chroma-Token" 37 | ) 38 | 39 | type TokenAuthCredentialsProvider struct { 40 | Token string 41 | Header TokenTransportHeader 42 | } 43 | 44 | func NewTokenAuthCredentialsProvider(token string, header TokenTransportHeader) *TokenAuthCredentialsProvider { 45 | return &TokenAuthCredentialsProvider{ 46 | Token: token, 47 | Header: header, 48 | } 49 | } 50 | 51 
| func (t *TokenAuthCredentialsProvider) Authenticate(client *BaseAPIClient) error { 52 | switch t.Header { 53 | case AuthorizationTokenHeader: 54 | client.defaultHeaders[string(t.Header)] = "Bearer " + t.Token 55 | return nil 56 | case XChromaTokenHeader: 57 | client.defaultHeaders[string(t.Header)] = t.Token 58 | return nil 59 | default: 60 | return errors.Errorf("unsupported token header: %v", t.Header) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /pkg/api/v2/base.go: -------------------------------------------------------------------------------- 1 | package v2 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/go-viper/mapstructure/v2" 7 | "github.com/pkg/errors" 8 | ) 9 | 10 | type Tenant interface { 11 | Name() string 12 | String() string 13 | Database(dbName string) Database 14 | Validate() error 15 | } 16 | 17 | type Database interface { 18 | ID() string 19 | Name() string 20 | Tenant() Tenant 21 | String() string 22 | Validate() error 23 | } 24 | 25 | type Include string 26 | 27 | const ( 28 | IncludeMetadatas Include = "metadatas" 29 | IncludeDocuments Include = "documents" 30 | IncludeEmbeddings Include = "embeddings" 31 | IncludeURIs Include = "uris" 32 | ) 33 | 34 | type Identity struct { 35 | UserID string `json:"user_id"` 36 | Tenant string `json:"tenant"` 37 | Databases []string `json:"databases"` 38 | } 39 | 40 | type TenantBase struct { 41 | TenantName string `json:"name"` 42 | } 43 | 44 | func (t *TenantBase) Name() string { 45 | return t.TenantName 46 | } 47 | 48 | func NewTenant(name string) Tenant { 49 | return &TenantBase{TenantName: name} 50 | } 51 | 52 | func NewTenantFromJSON(jsonString string) (Tenant, error) { 53 | tenant := &TenantBase{} 54 | err := json.Unmarshal([]byte(jsonString), tenant) 55 | if err != nil { 56 | return nil, err 57 | } 58 | return tenant, nil 59 | } 60 | func (t *TenantBase) String() string { 61 | return t.Name() 62 | } 63 | 64 | func (t *TenantBase) 
Validate() error { 65 | if t.TenantName == "" { 66 | return errors.New("tenant name cannot be empty") 67 | } 68 | return nil 69 | } 70 | 71 | // Database returns a new Database object that can be used for creating collections 72 | func (t *TenantBase) Database(dbName string) Database { 73 | return NewDatabase(dbName, t) 74 | } 75 | 76 | // TODO this may fail for v1 API 77 | // func (t *TenantBase) MarshalJSON() ([]byte, error) { 78 | // return []byte(`"` + t.Name() + `"`), nil 79 | //} 80 | 81 | func NewDefaultTenant() Tenant { 82 | return NewTenant(DefaultTenant) 83 | } 84 | 85 | type DatabaseBase struct { 86 | DBName string `json:"name" mapstructure:"name"` 87 | DBID string `json:"id,omitempty" mapstructure:"id"` 88 | TenantName string `json:"tenant,omitempty" mapstructure:"tenant"` 89 | tenant Tenant 90 | } 91 | 92 | func (d DatabaseBase) Name() string { 93 | return d.DBName 94 | } 95 | 96 | func (d DatabaseBase) Tenant() Tenant { 97 | if d.tenant == nil && d.TenantName != "" { 98 | d.tenant = NewTenant(d.TenantName) 99 | } 100 | return d.tenant 101 | } 102 | 103 | func (d DatabaseBase) String() string { 104 | return d.Name() 105 | } 106 | 107 | func (d DatabaseBase) ID() string { 108 | return d.DBID 109 | } 110 | func (d DatabaseBase) Validate() error { 111 | if d.DBName == "" { 112 | return errors.New("database name cannot be empty") 113 | } 114 | if d.tenant == nil { 115 | return errors.New("tenant cannot be empty") 116 | } 117 | return nil 118 | } 119 | 120 | // TODO this may fail for v1 API 121 | // func (d *DatabaseBase) MarshalJSON() ([]byte, error) { 122 | // return []byte(`"` + d.Name() + `"`), nil 123 | //} 124 | 125 | func NewDatabase(name string, tenant Tenant) Database { 126 | return &DatabaseBase{DBName: name, tenant: tenant} 127 | } 128 | 129 | func NewDatabaseFromJSON(jsonString string) (Database, error) { 130 | database := &DatabaseBase{} 131 | err := json.Unmarshal([]byte(jsonString), database) 132 | if err != nil { 133 | return nil, err 134 | 
} 135 | if database.TenantName != "" { 136 | database.tenant = NewTenant(database.TenantName) 137 | } else { 138 | database.tenant = NewDefaultTenant() 139 | } 140 | return database, nil 141 | } 142 | 143 | func NewDatabaseFromMap(data map[string]interface{}) (Database, error) { 144 | database := &DatabaseBase{} 145 | err := mapstructure.Decode(data, database) 146 | if err != nil { 147 | return nil, errors.Wrap(err, "error decoding database") 148 | } 149 | if database.TenantName != "" { 150 | database.tenant = NewTenant(database.TenantName) 151 | } else { 152 | database.tenant = NewDefaultTenant() 153 | } 154 | return database, nil 155 | } 156 | 157 | func NewDefaultDatabase() Database { 158 | return NewDatabase(DefaultDatabase, NewDefaultTenant()) 159 | } 160 | -------------------------------------------------------------------------------- /pkg/api/v2/constants.go: -------------------------------------------------------------------------------- 1 | package v2 2 | 3 | const ( 4 | DefaultTenant = "default_tenant" 5 | DefaultDatabase = "default_database" 6 | HNSWSpace = "hnsw:space" 7 | HNSWConstructionEF = "hnsw:construction_ef" 8 | HNSWBatchSize = "hnsw:batch_size" 9 | HNSWSyncThreshold = "hnsw:sync_threshold" 10 | HNSWM = "hnsw:M" 11 | HNSWSearchEF = "hnsw:search_ef" 12 | HNSWNumThreads = "hnsw:num_threads" 13 | HNSWResizeFactor = "hnsw:resize_factor" 14 | ) 15 | -------------------------------------------------------------------------------- /pkg/api/v2/document_test.go: -------------------------------------------------------------------------------- 1 | //go:build basicv2 2 | 3 | package v2 4 | 5 | import ( 6 | "encoding/json" 7 | "github.com/stretchr/testify/require" 8 | "testing" 9 | ) 10 | 11 | func TestTextDocument(t *testing.T) { 12 | 13 | doc := "Hello, world!\n" 14 | 15 | tdoc := NewTextDocument(doc) 16 | 17 | marshal, err := json.Marshal(tdoc) 18 | require.NoError(t, err) 19 | require.Equal(t, `"Hello, world!\n"`, string(marshal)) 20 | } 21 | 
-------------------------------------------------------------------------------- /pkg/api/v2/ids.go: -------------------------------------------------------------------------------- 1 | package v2 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/hex" 6 | "math/rand" 7 | "time" 8 | 9 | "github.com/google/uuid" 10 | "github.com/oklog/ulid" 11 | ) 12 | 13 | type GenerateOptions struct { 14 | Document string 15 | } 16 | 17 | type IDGeneratorOption func(opts *GenerateOptions) 18 | 19 | func WithDocument(document string) IDGeneratorOption { 20 | return func(opts *GenerateOptions) { 21 | opts.Document = document 22 | } 23 | } 24 | 25 | type IDGenerator interface { 26 | Generate(opts ...IDGeneratorOption) string 27 | } 28 | 29 | type UUIDGenerator struct{} 30 | 31 | func (u *UUIDGenerator) Generate(opts ...IDGeneratorOption) string { 32 | uuidV4 := uuid.New() 33 | return uuidV4.String() 34 | } 35 | 36 | func NewUUIDGenerator() *UUIDGenerator { 37 | return &UUIDGenerator{} 38 | } 39 | 40 | type SHA256Generator struct{} 41 | 42 | func (s *SHA256Generator) Generate(opts ...IDGeneratorOption) string { 43 | op := GenerateOptions{} 44 | for _, opt := range opts { 45 | opt(&op) 46 | } 47 | if op.Document == "" { 48 | op.Document = uuid.New().String() 49 | } 50 | hasher := sha256.New() 51 | hasher.Write([]byte(op.Document)) 52 | sha256Hash := hex.EncodeToString(hasher.Sum(nil)) 53 | return sha256Hash 54 | } 55 | 56 | func NewSHA256Generator() *SHA256Generator { 57 | return &SHA256Generator{} 58 | } 59 | 60 | type ULIDGenerator struct{} 61 | 62 | func (u *ULIDGenerator) Generate(opts ...IDGeneratorOption) string { 63 | t := time.Now() 64 | entropy := rand.New(rand.NewSource(t.UnixNano())) 65 | docULID := ulid.MustNew(ulid.Timestamp(t), entropy) 66 | return docULID.String() 67 | } 68 | 69 | func NewULIDGenerator() *ULIDGenerator { 70 | return &ULIDGenerator{} 71 | } 72 | -------------------------------------------------------------------------------- /pkg/api/v2/record.go: 
// SimpleRecord is a plain in-memory record. Its fields are populated through
// RecordOption functions passed to NewSimpleRecord and exposed through the
// accessor methods declared on *SimpleRecord.
type SimpleRecord struct {
	id        string               // record identifier; NewSimpleRecord rejects an empty id
	embedding embeddings.Embedding // embedding set via WithRecordEmbedding
	metadata  DocumentMetadata     // metadata set via WithRecordMetadatas
	document  string               // text content returned (wrapped) by Document()
	uri       string               // value returned by URI()
	err       error                // indicating whether the record is valid or not
}
*SimpleRecord) Document() Document { 77 | return NewTextDocument(r.document) 78 | } 79 | 80 | func (r *SimpleRecord) URI() string { 81 | return r.uri 82 | } 83 | 84 | func (r *SimpleRecord) Embedding() embeddings.Embedding { 85 | return r.embedding 86 | } 87 | 88 | func (r *SimpleRecord) Metadata() DocumentMetadata { 89 | return r.metadata 90 | } 91 | 92 | func (r *SimpleRecord) Validate() error { 93 | return r.err 94 | } 95 | 96 | func (r *SimpleRecord) Unwrap() (DocumentID, Document, embeddings.Embedding, DocumentMetadata) { 97 | return r.ID(), r.Document(), r.Embedding(), r.Metadata() 98 | } 99 | -------------------------------------------------------------------------------- /pkg/api/v2/record_test.go: -------------------------------------------------------------------------------- 1 | //go:build basicv2 2 | 3 | package v2 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/require" 9 | 10 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 11 | ) 12 | 13 | func TestSimpleRecord(t *testing.T) { 14 | record, err := NewSimpleRecord(WithRecordID("1"), 15 | WithRecordEmbedding(embeddings.NewEmbeddingFromFloat32([]float32{1, 2, 3})), 16 | WithRecordMetadatas(NewDocumentMetadata(NewStringAttribute("key", "value")))) 17 | require.NoError(t, err) 18 | require.NotNil(t, record) 19 | } 20 | -------------------------------------------------------------------------------- /pkg/api/v2/reranking.go: -------------------------------------------------------------------------------- 1 | package v2 2 | -------------------------------------------------------------------------------- /pkg/api/v2/results_test.go: -------------------------------------------------------------------------------- 1 | //go:build basicv2 2 | 3 | package v2 4 | 5 | import ( 6 | "encoding/json" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | 11 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 12 | ) 13 | 14 | func TestGetResultDeserialization(t *testing.T) { 15 | var 
apiResponse = `{ 16 | "documents": [ 17 | "document1", 18 | "document2" 19 | ], 20 | "embeddings": [ 21 | [0.1,0.2], 22 | [0.3,0.4] 23 | ], 24 | "ids": [ 25 | "id1", 26 | "id2" 27 | ], 28 | "include": [ 29 | "distances" 30 | ], 31 | "metadatas": [ 32 | { 33 | "additionalProp1": true, 34 | "additionalProp2": 1, 35 | "additionalProp3": "test" 36 | }, 37 | {"additionalProp1": false} 38 | ] 39 | }` 40 | 41 | var result GetResultImpl 42 | err := json.Unmarshal([]byte(apiResponse), &result) 43 | require.NoError(t, err) 44 | require.Len(t, result.GetDocuments(), 2) 45 | require.Len(t, result.GetIDs(), 2) 46 | require.Equal(t, result.GetIDs()[0], DocumentID("id1")) 47 | require.Equal(t, result.GetDocuments()[0], NewTextDocument("document1")) 48 | require.Equal(t, []float32{0.1, 0.2}, result.GetEmbeddings()[0].ContentAsFloat32()) 49 | require.Len(t, result.GetEmbeddings(), 2) 50 | require.Len(t, result.GetMetadatas(), 2) 51 | } 52 | 53 | func TestQueryResultDeserialization(t *testing.T) { 54 | var apiResponse = `{ 55 | "distances": [ 56 | [ 57 | 0.1 58 | ] 59 | ], 60 | "documents": [ 61 | [ 62 | "string" 63 | ] 64 | ], 65 | "embeddings": [ 66 | [ 67 | [ 68 | 0.1 69 | ] 70 | ] 71 | ], 72 | "ids": [ 73 | [ 74 | "id1" 75 | ] 76 | ], 77 | "include": [ 78 | "distances" 79 | ], 80 | "metadatas": [ 81 | [ 82 | { 83 | "additionalProp1": true, 84 | "additionalProp2": true, 85 | "additionalProp3": true 86 | } 87 | ] 88 | ] 89 | }` 90 | 91 | var result QueryResultImpl 92 | err := json.Unmarshal([]byte(apiResponse), &result) 93 | require.NoError(t, err) 94 | require.Len(t, result.GetIDGroups(), 1) 95 | require.Len(t, result.GetIDGroups()[0], 1) 96 | require.Equal(t, DocumentID("id1"), result.GetIDGroups()[0][0]) 97 | 98 | require.Len(t, result.GetDocumentsGroups(), 1) 99 | require.Len(t, result.GetDocumentsGroups()[0], 1) 100 | require.Equal(t, NewTextDocument("string"), result.GetDocumentsGroups()[0][0]) 101 | 102 | require.Len(t, result.GetEmbeddingsGroups(), 1) 103 | require.Len(t, 
// CreateSelfSignedCert generates a fresh ECDSA P-256 key and a self-signed
// certificate for "localhost" / 127.0.0.1 that is valid for one year, then
// writes the PEM-encoded certificate to certPath and the PEM-encoded PKCS#8
// private key to keyPath (created with mode 0600).
//
// Any failure terminates the process through log.Fatalf.
func CreateSelfSignedCert(certPath, keyPath string) {
	// Generate a private key
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		log.Fatalf("Failed to generate private key: %v", err)
	}

	// Prepare certificate
	notBefore := time.Now()
	notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		log.Fatalf("Failed to generate serial number: %v", err)
	}

	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"Chroma, Inc."},
			CommonName:   "localhost",
		},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	// Create the certificate; template is both subject and issuer, which is
	// what makes it self-signed.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	if err != nil {
		log.Fatalf("Failed to create certificate: %v", err)
	}

	// Write the certificate to file
	if err := writePEMFile(certPath, 0666, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		log.Fatalf("Failed to write certificate to %s: %v", certPath, err)
	}
	log.Printf("Written %s", certPath)

	// Write the private key to file
	privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		log.Fatalf("Unable to marshal private key: %v", err)
	}
	if err := writePEMFile(keyPath, 0600, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
		log.Fatalf("Failed to write private key to %s: %v", keyPath, err)
	}
	log.Printf("Written %s", keyPath)
}

// writePEMFile creates (or truncates) path with the given permissions and
// writes block to it in PEM form, closing the file on every path.
func writePEMFile(path string, perm os.FileMode, block *pem.Block) error {
	out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	if err := pem.Encode(out, block); err != nil {
		_ = out.Close() // the encode error is the one worth reporting
		return err
	}
	return out.Close()
}
-------------------------------------------------------------------------------- /pkg/api/v2/v1-config.yaml: -------------------------------------------------------------------------------- 1 | ######################## 2 | # HTTP server settings # 3 | ######################## 4 | port: 8000 5 | listen_address: "0.0.0.0" 6 | 7 | #################### 8 | # General settings # 9 | #################### 10 | persist_path: "/data" 11 | allow_reset: true 12 | -------------------------------------------------------------------------------- /pkg/api/v2/where_document_test.go: -------------------------------------------------------------------------------- 1 | //go:build basicv2 2 | 3 | package v2 4 | 5 | import ( 6 | "testing" 7 | ) 8 | 9 | func TestWhereDocument(t *testing.T) { 10 | tests := []struct { 11 | name string 12 | filter WhereDocumentFilter 13 | expected string 14 | }{ 15 | { 16 | name: "contain", 17 | filter: Contains("test"), 18 | expected: `{"$contains":"test"}`, 19 | }, 20 | { 21 | name: "not contain", 22 | filter: NotContains("test"), 23 | expected: `{"$not_contains":"test"}`, 24 | }, 25 | { 26 | name: "or", 27 | filter: OrDocument(Contains("test"), NotContains("test")), 28 | expected: `{"$or":[{"$contains":"test"},{"$not_contains":"test"}]}`, 29 | }, 30 | { 31 | name: "and", 32 | filter: AndDocument(Contains("test"), NotContains("test")), 33 | expected: `{"$and":[{"$contains":"test"},{"$not_contains":"test"}]}`, 34 | }, 35 | { 36 | name: "or and", 37 | filter: OrDocument(AndDocument(Contains("test"), NotContains("test")), Contains("test")), 38 | expected: `{"$or":[{"$and":[{"$contains":"test"},{"$not_contains":"test"}]},{"$contains":"test"}]}`, 39 | }, 40 | } 41 | for _, test := range tests { 42 | t.Run(test.name, func(t *testing.T) { 43 | actual, err := test.filter.MarshalJSON() 44 | if err != nil { 45 | t.Errorf("error marshalling filter: %v", err) 46 | } 47 | if string(actual) != test.expected { 48 | t.Errorf("expected %s, got %s", test.expected, 
string(actual)) 49 | } 50 | }) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /pkg/commons/cohere/cohere_commons_test.go: -------------------------------------------------------------------------------- 1 | //go:build ef || rf 2 | 3 | package cohere 4 | 5 | import ( 6 | "fmt" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestValidations(t *testing.T) { 13 | tests := []struct { 14 | name string 15 | options []Option 16 | expectedError string 17 | }{ 18 | { 19 | name: "Test empty API key", 20 | options: []Option{ 21 | WithDefaultModel("model"), 22 | }, 23 | expectedError: "'apiKey' failed on the 'required'", 24 | }, 25 | { 26 | name: "Test without default model", 27 | options: []Option{ 28 | WithAPIKey("dummy"), 29 | }, 30 | expectedError: "'DefaultModel' failed on the 'required'", 31 | }, 32 | } 33 | for _, tt := range tests { 34 | t.Run(tt.name, func(t *testing.T) { 35 | _, err := NewCohereClient(tt.options...) 36 | fmt.Printf("Error: %v\n", err) 37 | require.Error(t, err) 38 | require.Contains(t, err.Error(), tt.expectedError) 39 | }) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /pkg/commons/http/constants.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | const ChromaGoClientUserAgent = "chroma-go-client/0.1.x" 4 | -------------------------------------------------------------------------------- /pkg/commons/http/errors.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | // ChromaError represents an error returned by the Chroma API. It contains the ID of the error, the error message and the status code from the HTTP call. 
10 | // Example: 11 | // 12 | // { 13 | // "error": "NotFoundError", 14 | // "message": "Tenant default_tenant2 not found" 15 | // } 16 | type ChromaError struct { 17 | ErrorID string `json:"error"` 18 | ErrorCode int `json:"error_code"` 19 | Message string `json:"message"` 20 | } 21 | 22 | func ChromaErrorFromHTTPResponse(resp *http.Response, err error) *ChromaError { 23 | chromaAPIError := &ChromaError{ 24 | ErrorID: "unknown", 25 | Message: "unknown", 26 | } 27 | if err != nil { 28 | chromaAPIError.Message = err.Error() 29 | } 30 | if resp == nil { 31 | return chromaAPIError 32 | } 33 | chromaAPIError.ErrorCode = resp.StatusCode 34 | if err := json.NewDecoder(resp.Body).Decode(chromaAPIError); err != nil { 35 | chromaAPIError.Message = ReadRespBody(resp.Body) 36 | } 37 | return chromaAPIError 38 | } 39 | 40 | func (e *ChromaError) Error() string { 41 | return fmt.Sprintf("Error (%d) %s: %s", e.ErrorCode, e.ErrorID, e.Message) 42 | } 43 | -------------------------------------------------------------------------------- /pkg/commons/http/retry.go: -------------------------------------------------------------------------------- 1 | package http 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "net/http" 7 | "time" 8 | ) 9 | 10 | type Option func(*SimpleRetryStrategy) error 11 | 12 | func WithMaxRetries(retries int) Option { 13 | return func(r *SimpleRetryStrategy) error { 14 | if retries <= 0 { 15 | return fmt.Errorf("retries must be a positive integer") 16 | } 17 | r.MaxRetries = retries 18 | return nil 19 | } 20 | } 21 | 22 | func WithFixedDelay(delay time.Duration) Option { 23 | return func(r *SimpleRetryStrategy) error { 24 | if delay <= 0 { 25 | return fmt.Errorf("delay must be a positive integer") 26 | } 27 | r.FixedDelay = delay 28 | return nil 29 | } 30 | } 31 | 32 | func WithRetryableStatusCodes(statusCodes ...int) Option { 33 | return func(r *SimpleRetryStrategy) error { 34 | r.RetryableStatusCodes = statusCodes 35 | return nil 36 | } 37 | } 38 | 39 | func 
WithExponentialBackOff() Option { 40 | return func(r *SimpleRetryStrategy) error { 41 | r.ExponentialBackOff = true 42 | return nil 43 | } 44 | } 45 | 46 | type SimpleRetryStrategy struct { 47 | MaxRetries int 48 | FixedDelay time.Duration 49 | ExponentialBackOff bool 50 | RetryableStatusCodes []int 51 | } 52 | 53 | func NewSimpleRetryStrategy(opts ...Option) (*SimpleRetryStrategy, error) { 54 | var strategy = &SimpleRetryStrategy{ 55 | MaxRetries: 3, 56 | FixedDelay: time.Duration(1000) * time.Millisecond, 57 | RetryableStatusCodes: []int{}, 58 | } 59 | for _, opt := range opts { 60 | if err := opt(strategy); err != nil { 61 | return nil, err 62 | } 63 | } 64 | return strategy, nil 65 | } 66 | 67 | func (r *SimpleRetryStrategy) DoWithRetry(client *http.Client, req *http.Request) (*http.Response, error) { 68 | var resp *http.Response 69 | var err error 70 | for i := 0; i < r.MaxRetries; i++ { 71 | resp, err = client.Do(req) 72 | if err != nil { 73 | break 74 | } 75 | if resp.StatusCode >= 200 && resp.StatusCode < 400 { 76 | break 77 | } 78 | if r.isRetryable(resp.StatusCode) { 79 | if r.ExponentialBackOff { 80 | time.Sleep(r.FixedDelay * time.Duration(math.Pow(2, float64(i)))) 81 | } else { 82 | time.Sleep(r.FixedDelay) 83 | } 84 | } 85 | } 86 | return resp, err 87 | } 88 | 89 | func (r *SimpleRetryStrategy) isRetryable(code int) bool { 90 | for _, retryableCode := range r.RetryableStatusCodes { 91 | if code == retryableCode { 92 | return true 93 | } 94 | } 95 | return false 96 | } 97 | -------------------------------------------------------------------------------- /pkg/commons/http/retry_test.go: -------------------------------------------------------------------------------- 1 | //go:build basic 2 | 3 | package http 4 | 5 | import ( 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestRetryStrategyWithExponentialBackOff(t *testing.T) { 15 | client := &http.Client{} 16 | 17 | 
// RetryStrategy abstracts how an HTTP request is executed when it may have to
// be sent more than once.
//
// DoWithRetry sends req using client and returns the response and error of the
// call. SimpleRetryStrategy in this package provides a method with this
// signature.
type RetryStrategy interface {
	DoWithRetry(client *http.Client, req *http.Request) (*http.Response, error)
}
// ReadRespBody drains resp and returns everything it yielded as a string.
// A nil reader, or any failure while reading, results in the empty string.
func ReadRespBody(resp io.Reader) string {
	if resp != nil {
		if data, readErr := io.ReadAll(resp); readErr == nil {
			return string(data)
		}
	}
	return ""
}
63 | } 64 | } 65 | 66 | func WithEnvAPIToken() Option { 67 | return func(p *CloudflareClient) error { 68 | if apiToken := os.Getenv("CF_API_TOKEN"); apiToken != "" { 69 | p.APIToken = apiToken 70 | return nil 71 | } 72 | return errors.Errorf("CF_API_TOKEN not set") 73 | } 74 | } 75 | 76 | func WithEnvAccountID() Option { 77 | return func(p *CloudflareClient) error { 78 | if accountID := os.Getenv("CF_ACCOUNT_ID"); accountID != "" { 79 | p.AccountID = accountID 80 | return nil 81 | } 82 | return errors.Errorf("CF_ACCOUNT_ID not set") 83 | } 84 | } 85 | 86 | func WithHTTPClient(client *http.Client) Option { 87 | return func(p *CloudflareClient) error { 88 | if client == nil { 89 | return errors.New("http client cannot be nil") 90 | } 91 | p.Client = client 92 | return nil 93 | } 94 | } 95 | 96 | func WithEnvGatewayEndpoint() Option { 97 | return func(p *CloudflareClient) error { 98 | if endpoint := os.Getenv("CF_GATEWAY_ENDPOINT"); endpoint != "" { 99 | p.BaseAPI = endpoint 100 | p.IsGateway = true 101 | return nil 102 | } 103 | return errors.Errorf("CF_GATEWAY_ENDPOINT not set") 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /pkg/embeddings/cohere/cohere_test.go: -------------------------------------------------------------------------------- 1 | //go:build ef 2 | 3 | package cohere 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "os" 9 | "testing" 10 | 11 | ccommons "github.com/amikos-tech/chroma-go/pkg/commons/cohere" 12 | 13 | "github.com/joho/godotenv" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | func Test_ef(t *testing.T) { 19 | apiKey := os.Getenv("COHERE_API_KEY") 20 | if apiKey == "" { 21 | err := godotenv.Load("../../../.env") 22 | if err != nil { 23 | assert.Failf(t, "Error loading .env file", "%s", err) 24 | } 25 | apiKey = os.Getenv("COHERE_API_KEY") 26 | } 27 | 28 | t.Run("Test Create Embed", func(t *testing.T) { 29 | ef, err := 
NewCohereEmbeddingFunction(WithAPIKey(apiKey)) 30 | require.NoError(t, err) 31 | documents := []string{ 32 | "Document 1 content here", 33 | "Document 2 content here", 34 | // Add more documents as needed 35 | } 36 | resp, err := ef.EmbedDocuments(context.Background(), documents) 37 | require.Nil(t, err) 38 | require.NotNil(t, resp) 39 | }) 40 | 41 | t.Run("Test Create Embed with model option", func(t *testing.T) { 42 | ef, err := NewCohereEmbeddingFunction(WithAPIKey(apiKey), WithModel("embed-multilingual-v3.0")) 43 | require.NoError(t, err) 44 | documents := []string{ 45 | "Document 1 content here", 46 | "Document 2 content here", 47 | // Add more documents as needed 48 | } 49 | resp, rerr := ef.EmbedDocuments(context.Background(), documents) 50 | require.Nil(t, rerr) 51 | require.NotNil(t, resp) 52 | }) 53 | 54 | t.Run("Test Create Embed with model option embeddings type uint8", func(t *testing.T) { 55 | ef, err := NewCohereEmbeddingFunction(WithAPIKey(apiKey), WithModel("embed-multilingual-v3.0"), WithEmbeddingTypes(EmbeddingTypeUInt8)) 56 | require.NoError(t, err) 57 | documents := []string{ 58 | "Document 1 content here", 59 | "Document 2 content here", 60 | // Add more documents as needed 61 | } 62 | resp, err := ef.EmbedDocuments(context.Background(), documents) 63 | require.Nil(t, err) 64 | require.NotNil(t, resp) 65 | require.Len(t, resp, 2) 66 | fmt.Printf("resp %T\n", resp[0]) 67 | require.Empty(t, resp[0].ContentAsFloat32()) 68 | require.NotNil(t, resp[0].ContentAsInt32()) 69 | }) 70 | 71 | t.Run("Test Create Embed with model option embeddings type int8", func(t *testing.T) { 72 | ef, err := NewCohereEmbeddingFunction(WithEnvAPIKey(), WithModel("embed-multilingual-v3.0"), WithEmbeddingTypes(EmbeddingTypeInt8)) 73 | require.NoError(t, err) 74 | documents := []string{ 75 | "Document 1 content here", 76 | "Document 2 content here", 77 | // Add more documents as needed 78 | } 79 | resp, err := ef.EmbedDocuments(context.Background(), documents) 80 | 
require.Nil(t, err) 81 | require.NotNil(t, resp) 82 | require.Len(t, resp, 2) 83 | require.Empty(t, resp[0].ContentAsFloat32()) 84 | require.NotNil(t, resp[0].ContentAsInt32()) 85 | }) 86 | 87 | t.Run("Test Create Embed for query", func(t *testing.T) { 88 | ef, err := NewCohereEmbeddingFunction( 89 | WithEnvAPIKey(), 90 | WithModel("embed-multilingual-v3.0"), 91 | ) 92 | require.NoError(t, err) 93 | resp, err := ef.EmbedQuery(context.Background(), "This is a query") 94 | require.Nil(t, err) 95 | require.NotNil(t, resp) 96 | require.NotNil(t, resp.ContentAsFloat32()) 97 | require.Empty(t, resp.ContentAsInt32()) 98 | }) 99 | 100 | t.Run("Test With API options", func(t *testing.T) { 101 | ef, err := NewCohereEmbeddingFunction( 102 | WithEnvAPIKey(), 103 | WithBaseURL(ccommons.DefaultBaseURL), 104 | WithAPIVersion(ccommons.DefaultAPIVersion), 105 | WithModel("embed-multilingual-v3.0"), 106 | ) 107 | require.NoError(t, err) 108 | resp, err := ef.EmbedQuery(context.Background(), "This is a query") 109 | require.Nil(t, err) 110 | require.NotNil(t, resp) 111 | require.NotNil(t, resp.ContentAsFloat32()) 112 | require.Empty(t, resp.ContentAsInt32()) 113 | }) 114 | } 115 | -------------------------------------------------------------------------------- /pkg/embeddings/cohere/option.go: -------------------------------------------------------------------------------- 1 | package cohere 2 | 3 | import ( 4 | "github.com/pkg/errors" 5 | 6 | ccommons "github.com/amikos-tech/chroma-go/pkg/commons/cohere" 7 | httpc "github.com/amikos-tech/chroma-go/pkg/commons/http" 8 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 9 | ) 10 | 11 | type Option func(p *CohereEmbeddingFunction) ccommons.Option 12 | 13 | // WithBaseURL sets the base URL for the Cohere API - the default is https://api.cohere.ai 14 | func WithBaseURL(baseURL string) Option { 15 | return func(p *CohereEmbeddingFunction) ccommons.Option { 16 | return ccommons.WithBaseURL(baseURL) 17 | } 18 | } 19 | 20 | func 
WithAPIKey(apiKey string) Option { 21 | return func(p *CohereEmbeddingFunction) ccommons.Option { 22 | return ccommons.WithAPIKey(apiKey) 23 | } 24 | } 25 | 26 | // WithEnvAPIKey configures the client to use the COHERE_API_KEY environment variable as the API key 27 | func WithEnvAPIKey() Option { 28 | return func(p *CohereEmbeddingFunction) ccommons.Option { 29 | return ccommons.WithEnvAPIKey() 30 | } 31 | } 32 | 33 | func WithAPIVersion(apiVersion ccommons.APIVersion) Option { 34 | return func(p *CohereEmbeddingFunction) ccommons.Option { 35 | return ccommons.WithAPIVersion(apiVersion) 36 | } 37 | } 38 | 39 | // WithModel sets the default model for the Cohere API - Available models: 40 | // embed-english-v3.0 1024 41 | // embed-multilingual-v3.0 1024 42 | // embed-english-light-v3.0 384 43 | // embed-multilingual-light-v3.0 384 44 | // embed-english-v2.0 4096 (default) 45 | // embed-english-light-v2.0 1024 46 | // embed-multilingual-v2.0 768 47 | func WithModel(model embeddings.EmbeddingModel) Option { 48 | return func(p *CohereEmbeddingFunction) ccommons.Option { 49 | return ccommons.WithDefaultModel(model) 50 | } 51 | } 52 | 53 | // WithDefaultModel sets the default model for the Cohere. This can be overridden in the context of EF embed call. 
Available models: 54 | // embed-english-v3.0 1024 55 | // embed-multilingual-v3.0 1024 56 | // embed-english-light-v3.0 384 57 | // embed-multilingual-light-v3.0 384 58 | // embed-english-v2.0 4096 (default) 59 | // embed-english-light-v2.0 1024 60 | // embed-multilingual-v2.0 768 61 | func WithDefaultModel(model embeddings.EmbeddingModel) Option { 62 | return func(p *CohereEmbeddingFunction) ccommons.Option { 63 | return ccommons.WithDefaultModel(model) 64 | } 65 | } 66 | 67 | // WithTruncateMode sets the default truncate mode for the Cohere API - Available modes: 68 | // NONE 69 | // START 70 | // END (default) 71 | func WithTruncateMode(truncate TruncateMode) Option { 72 | return func(p *CohereEmbeddingFunction) ccommons.Option { 73 | if truncate != NONE && truncate != START && truncate != END { 74 | return func(c *ccommons.CohereClient) error { 75 | return errors.Errorf("invalid truncate mode %s", truncate) 76 | } 77 | } 78 | p.DefaultTruncateMode = truncate 79 | return ccommons.NoOp() 80 | } 81 | } 82 | 83 | // WithEmbeddingTypes sets the default embedding types for the Cohere API - Available types: 84 | // float (default) 85 | // int8 86 | // uint8 87 | // binary 88 | // ubinary 89 | // TODO we do not have support for returning multiple embedding types from the EmbeddingFunction, so for float->int8, unit8 are supported and returned in the that order 90 | func WithEmbeddingTypes(embeddingTypes ...EmbeddingType) Option { 91 | return func(p *CohereEmbeddingFunction) ccommons.Option { 92 | // if embeddingstypes contains binary or ubinary error 93 | for _, et := range embeddingTypes { 94 | if et == EmbeddingTypeBinary || et == EmbeddingTypeUBinary { 95 | return func(c *ccommons.CohereClient) error { 96 | return errors.Errorf("embedding type %s is not supported", et) 97 | } 98 | } 99 | } 100 | // if embeddingstypes is empty, set to default 101 | if len(embeddingTypes) == 0 { 102 | embeddingTypes = []EmbeddingType{EmbeddingTypeFloat32} 103 | } 104 | 
p.DefaultEmbeddingTypes = embeddingTypes 105 | return ccommons.NoOp() 106 | } 107 | } 108 | 109 | // WithRetryStrategy configures the client to use the specified retry strategy 110 | func WithRetryStrategy(retryStrategy httpc.RetryStrategy) Option { 111 | return func(p *CohereEmbeddingFunction) ccommons.Option { 112 | return ccommons.WithRetryStrategy(retryStrategy) 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /pkg/embeddings/default_ef/constants.go: -------------------------------------------------------------------------------- 1 | package defaultef 2 | 3 | const ( 4 | LibTokenizersVersion = "0.9.0" 5 | LibOnnxRuntimeVersion = "1.21.0" 6 | onnxModelDownloadEndpoint = "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" 7 | ChromaCacheDir = ".cache/chroma/" 8 | ) 9 | -------------------------------------------------------------------------------- /pkg/embeddings/default_ef/default_ef_test.go: -------------------------------------------------------------------------------- 1 | package defaultef 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func Test_Default_EF(t *testing.T) { 11 | ef, closeEf, err := NewDefaultEmbeddingFunction() 12 | require.NoError(t, err) 13 | t.Cleanup(func() { 14 | err := closeEf() 15 | if err != nil { 16 | t.Logf("error while closing embedding function: %v", err) 17 | } 18 | }) 19 | require.NotNil(t, ef) 20 | embeddings, err := ef.EmbedDocuments(context.TODO(), []string{"Hello Chroma!", "Hello world!"}) 21 | require.NoError(t, err) 22 | require.NotNil(t, embeddings) 23 | require.Len(t, embeddings, 2) 24 | for _, embedding := range embeddings { 25 | require.Equal(t, embedding.Len(), 384) 26 | } 27 | } 28 | 29 | func TestClose(t *testing.T) { 30 | ef, closeEf, err := NewDefaultEmbeddingFunction() 31 | require.NoError(t, err) 32 | require.NotNil(t, ef) 33 | err = closeEf() 34 | require.NoError(t, err) 35 | 
_, err = ef.EmbedQuery(context.TODO(), "Hello Chroma!") 36 | require.Error(t, err) 37 | require.Contains(t, err.Error(), "embedding function is closed") 38 | } 39 | func TestCloseClosed(t *testing.T) { 40 | ef := &DefaultEmbeddingFunction{} 41 | err := ef.Close() 42 | require.NoError(t, err) 43 | } 44 | -------------------------------------------------------------------------------- /pkg/embeddings/default_ef/download_utils_test.go: -------------------------------------------------------------------------------- 1 | package defaultef 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestDownload(t *testing.T) { 13 | t.Run("Download", func(t *testing.T) { 14 | onnxLibPath = filepath.Join(t.TempDir(), "libonnxruntime."+LibOnnxRuntimeVersion+"."+getExtensionForOs()) 15 | fmt.Println(onnxLibPath) 16 | err := os.RemoveAll(onnxLibPath) 17 | require.NoError(t, err) 18 | err = EnsureOnnxRuntimeSharedLibrary() 19 | require.NoError(t, err) 20 | }) 21 | t.Run("Download Tokenizers", func(t *testing.T) { 22 | libTokenizersLibPath = filepath.Join(t.TempDir(), "libtokenizers."+LibTokenizersVersion+"."+getExtensionForOs()) 23 | err := os.RemoveAll(libTokenizersLibPath) 24 | require.NoError(t, err) 25 | err = EnsureLibTokenizersSharedLibrary() 26 | require.NoError(t, err) 27 | }) 28 | t.Run("Download Model", func(t *testing.T) { 29 | onnxModelCachePath = filepath.Join(t.TempDir(), "onnx_model") 30 | err := os.RemoveAll(onnxModelCachePath) 31 | require.NoError(t, err) 32 | err = EnsureDefaultEmbeddingFunctionModel() 33 | require.NoError(t, err) 34 | }) 35 | } 36 | -------------------------------------------------------------------------------- /pkg/embeddings/distance_metric.go: -------------------------------------------------------------------------------- 1 | package embeddings 2 | 3 | type DistanceMetric string 4 | 5 | const ( 6 | L2 DistanceMetric = "l2" 7 | COSINE DistanceMetric = "cosine" 
8 | IP DistanceMetric = "ip" 9 | ) 10 | 11 | type DistanceMetricOperator interface { 12 | Compare(a, b []float32) float64 13 | } 14 | 15 | type Distance float32 16 | type Distances []Distance 17 | -------------------------------------------------------------------------------- /pkg/embeddings/embedding_test.go: -------------------------------------------------------------------------------- 1 | package embeddings 2 | 3 | import ( 4 | "encoding/json" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestMarshalEmbeddings(t *testing.T) { 11 | embed := NewEmbeddingFromFloat32([]float32{1.1234567891, 2.4, 3.5}) 12 | 13 | bytes, err := json.Marshal(embed) 14 | require.NoError(t, err) 15 | require.JSONEq(t, `[1.1234568,2.4,3.5]`, string(bytes)) 16 | } 17 | 18 | func TestUnmarshalEmbeddings(t *testing.T) { 19 | var embed Float32Embedding 20 | jsonStr := `[1.1234568,2.4,3.5]` 21 | 22 | err := json.Unmarshal([]byte(jsonStr), &embed) 23 | require.NoError(t, err) 24 | require.Equal(t, 3, embed.Len()) 25 | require.Equal(t, float32(1.1234568), embed.ContentAsFloat32()[0]) 26 | require.Equal(t, float32(2.4), embed.ContentAsFloat32()[1]) 27 | require.Equal(t, float32(3.5), embed.ContentAsFloat32()[2]) 28 | } 29 | -------------------------------------------------------------------------------- /pkg/embeddings/gemini/gemini.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/google/generative-ai-go/genai" 7 | "github.com/pkg/errors" 8 | "google.golang.org/api/option" 9 | 10 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 11 | ) 12 | 13 | const ( 14 | DefaultEmbeddingModel = "text-embedding-004" 15 | ModelContextVar = "model" 16 | APIKeyEnvVar = "GEMINI_API_KEY" 17 | ) 18 | 19 | type Client struct { 20 | apiKey string 21 | DefaultModel embeddings.EmbeddingModel 22 | Client *genai.Client 23 | DefaultContext *context.Context 24 | MaxBatchSize int 25 
| } 26 | 27 | func applyDefaults(c *Client) (err error) { 28 | if c.DefaultModel == "" { 29 | c.DefaultModel = DefaultEmbeddingModel 30 | } 31 | 32 | if c.DefaultContext == nil { 33 | ctx := context.Background() 34 | c.DefaultContext = &ctx 35 | } 36 | 37 | if c.Client == nil { 38 | c.Client, err = genai.NewClient(*c.DefaultContext, option.WithAPIKey(c.apiKey)) 39 | if err != nil { 40 | return errors.WithStack(err) 41 | } 42 | } 43 | return nil 44 | } 45 | 46 | func validate(c *Client) error { 47 | if c.apiKey == "" { 48 | return errors.New("API key is required") 49 | } 50 | return nil 51 | } 52 | 53 | func NewGeminiClient(opts ...Option) (*Client, error) { 54 | client := &Client{} 55 | 56 | for _, opt := range opts { 57 | err := opt(client) 58 | if err != nil { 59 | return nil, errors.Wrap(err, "failed to apply Gemini option") 60 | } 61 | } 62 | err := applyDefaults(client) 63 | if err != nil { 64 | return nil, err 65 | } 66 | if err := validate(client); err != nil { 67 | return nil, errors.Wrap(err, "failed to validate Gemini client options") 68 | } 69 | return client, nil 70 | } 71 | 72 | func (c *Client) CreateEmbedding(ctx context.Context, req []string) ([]embeddings.Embedding, error) { 73 | var em *genai.EmbeddingModel 74 | if ctx.Value(ModelContextVar) != nil { 75 | em = c.Client.EmbeddingModel(ctx.Value(ModelContextVar).(string)) 76 | } else { 77 | em = c.Client.EmbeddingModel(string(c.DefaultModel)) 78 | } 79 | b := em.NewBatch() 80 | for _, t := range req { 81 | b.AddContent(genai.Text(t)) 82 | } 83 | res, err := em.BatchEmbedContents(ctx, b) 84 | if err != nil { 85 | return nil, errors.Wrap(err, "failed to embed contents") 86 | } 87 | var embs = make([][]float32, 0) 88 | for _, e := range res.Embeddings { 89 | embs = append(embs, e.Values) 90 | } 91 | 92 | return embeddings.NewEmbeddingsFromFloat32(embs) 93 | } 94 | 95 | // close closes the underlying client 96 | // 97 | //nolint:unused 98 | func (c *Client) close() error { 99 | return c.Client.Close() 
100 | } 101 | 102 | var _ embeddings.EmbeddingFunction = (*GeminiEmbeddingFunction)(nil) 103 | 104 | type GeminiEmbeddingFunction struct { 105 | apiClient *Client 106 | } 107 | 108 | func NewGeminiEmbeddingFunction(opts ...Option) (*GeminiEmbeddingFunction, error) { 109 | client, err := NewGeminiClient(opts...) 110 | if err != nil { 111 | return nil, err 112 | } 113 | 114 | return &GeminiEmbeddingFunction{apiClient: client}, nil 115 | } 116 | 117 | // close closes the underlying client 118 | // 119 | //nolint:unused 120 | func (e *GeminiEmbeddingFunction) close() error { 121 | return e.apiClient.close() 122 | } 123 | 124 | func (e *GeminiEmbeddingFunction) EmbedDocuments(ctx context.Context, documents []string) ([]embeddings.Embedding, error) { 125 | if e.apiClient.MaxBatchSize > 0 && len(documents) > e.apiClient.MaxBatchSize { 126 | return nil, errors.Errorf("number of documents exceeds the maximum batch size %v", e.apiClient.MaxBatchSize) 127 | } 128 | if len(documents) == 0 { 129 | return embeddings.NewEmptyEmbeddings(), nil 130 | } 131 | 132 | response, err := e.apiClient.CreateEmbedding(ctx, documents) 133 | if err != nil { 134 | return nil, errors.Wrap(err, "failed to embed documents") 135 | } 136 | return response, nil 137 | } 138 | 139 | func (e *GeminiEmbeddingFunction) EmbedQuery(ctx context.Context, document string) (embeddings.Embedding, error) { 140 | response, err := e.apiClient.CreateEmbedding(ctx, []string{document}) 141 | if err != nil { 142 | return nil, errors.Wrap(err, "failed to embed query") 143 | } 144 | return response[0], nil 145 | } 146 | -------------------------------------------------------------------------------- /pkg/embeddings/gemini/gemini_test.go: -------------------------------------------------------------------------------- 1 | //go:build ef 2 | 3 | package gemini 4 | 5 | import ( 6 | "context" 7 | "os" 8 | "testing" 9 | 10 | "github.com/joho/godotenv" 11 | "github.com/stretchr/testify/assert" 12 | 
"github.com/stretchr/testify/require" 13 | ) 14 | 15 | func Test_gemini_client(t *testing.T) { 16 | apiKey := os.Getenv(APIKeyEnvVar) 17 | if apiKey == "" { 18 | err := godotenv.Load("../../../.env") 19 | if err != nil { 20 | assert.Failf(t, "Error loading .env file", "%s", err) 21 | } 22 | apiKey = os.Getenv(APIKeyEnvVar) 23 | } 24 | client, err := NewGeminiClient(WithEnvAPIKey()) 25 | require.NoError(t, err) 26 | defer func(client *Client) { 27 | err := client.close() 28 | if err != nil { 29 | 30 | } 31 | }(client) 32 | 33 | t.Run("Test CreateEmbedding", func(t *testing.T) { 34 | resp, rerr := client.CreateEmbedding(context.Background(), []string{"Test document"}) 35 | require.Nil(t, rerr) 36 | require.NotNil(t, resp) 37 | require.Len(t, resp, 1) 38 | }) 39 | } 40 | 41 | func Test_gemini_embedding_function(t *testing.T) { 42 | apiKey := os.Getenv(APIKeyEnvVar) 43 | if apiKey == "" { 44 | err := godotenv.Load("../../../.env") 45 | if err != nil { 46 | assert.Failf(t, "Error loading .env file", "%s", err) 47 | } 48 | apiKey = os.Getenv(APIKeyEnvVar) 49 | } 50 | 51 | t.Run("Test EmbedDocuments with env-based api key", func(t *testing.T) { 52 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey()) 53 | defer func(embeddingFunction *GeminiEmbeddingFunction) { 54 | err := embeddingFunction.close() 55 | if err != nil { 56 | 57 | } 58 | }(embeddingFunction) 59 | require.NoError(t, err) 60 | resp, rerr := embeddingFunction.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 61 | require.Nil(t, rerr) 62 | require.NotNil(t, resp) 63 | require.Len(t, resp, 2) 64 | require.Len(t, resp[0].ContentAsFloat32(), 768) 65 | 66 | }) 67 | 68 | t.Run("Test EmbedDocuments with provided API key", func(t *testing.T) { 69 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithAPIKey(apiKey)) 70 | defer func(embeddingFunction *GeminiEmbeddingFunction) { 71 | err := embeddingFunction.close() 72 | if err != nil { 73 | 74 | } 75 | 
}(embeddingFunction) 76 | require.NoError(t, err) 77 | resp, rerr := embeddingFunction.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 78 | 79 | require.Nil(t, rerr) 80 | require.NotNil(t, resp) 81 | require.Len(t, resp, 2) 82 | require.Len(t, resp[0].ContentAsFloat32(), 768) 83 | 84 | }) 85 | 86 | t.Run("Test EmbedDocuments with provided model", func(t *testing.T) { 87 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel(DefaultEmbeddingModel)) 88 | defer func(embeddingFunction *GeminiEmbeddingFunction) { 89 | err := embeddingFunction.close() 90 | if err != nil { 91 | 92 | } 93 | }(embeddingFunction) 94 | require.NoError(t, err) 95 | resp, rerr := embeddingFunction.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 96 | 97 | require.Nil(t, rerr) 98 | require.NotNil(t, resp) 99 | require.Len(t, resp, 2) 100 | require.Len(t, resp[0].ContentAsFloat32(), 768) 101 | 102 | }) 103 | 104 | t.Run("Test EmbedQuery", func(t *testing.T) { 105 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel(DefaultEmbeddingModel)) 106 | defer func(embeddingFunction *GeminiEmbeddingFunction) { 107 | err := embeddingFunction.close() 108 | if err != nil { 109 | 110 | } 111 | }(embeddingFunction) 112 | require.NoError(t, err) 113 | resp, rerr := embeddingFunction.EmbedQuery(context.Background(), "this is my query") 114 | require.Nil(t, rerr) 115 | require.NotNil(t, resp) 116 | require.Len(t, resp.ContentAsFloat32(), 768) 117 | }) 118 | 119 | t.Run("Test wrong model", func(t *testing.T) { 120 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel("model-does-not-exist")) 121 | defer func(embeddingFunction *GeminiEmbeddingFunction) { 122 | err := embeddingFunction.close() 123 | if err != nil { 124 | 125 | } 126 | }(embeddingFunction) 127 | require.NoError(t, err) 128 | _, rerr := 
embeddingFunction.EmbedQuery(context.Background(), "this is my query") 129 | require.Contains(t, rerr.Error(), "Error 404") 130 | require.Error(t, rerr) 131 | }) 132 | 133 | t.Run("Test wrong API key", func(t *testing.T) { 134 | embeddingFunction, err := NewGeminiEmbeddingFunction(WithAPIKey("wrong-api-key")) 135 | defer func(embeddingFunction *GeminiEmbeddingFunction) { 136 | err := embeddingFunction.close() 137 | if err != nil { 138 | 139 | } 140 | }(embeddingFunction) 141 | require.NoError(t, err) 142 | _, rerr := embeddingFunction.EmbedQuery(context.Background(), "this is my query") 143 | require.Contains(t, rerr.Error(), "API key not valid") 144 | require.Error(t, rerr) 145 | }) 146 | } 147 | -------------------------------------------------------------------------------- /pkg/embeddings/gemini/option.go: -------------------------------------------------------------------------------- 1 | package gemini 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/google/generative-ai-go/genai" 7 | "github.com/pkg/errors" 8 | 9 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 10 | ) 11 | 12 | type Option func(p *Client) error 13 | 14 | // WithDefaultModel sets the default model for the client 15 | func WithDefaultModel(model embeddings.EmbeddingModel) Option { 16 | return func(p *Client) error { 17 | if model == "" { 18 | return errors.New("model cannot be empty") 19 | } 20 | p.DefaultModel = model 21 | return nil 22 | } 23 | } 24 | 25 | // WithAPIKey sets the API key for the client 26 | func WithAPIKey(apiKey string) Option { 27 | return func(p *Client) error { 28 | if apiKey == "" { 29 | return errors.New("API key cannot be empty") 30 | } 31 | p.apiKey = apiKey 32 | return nil 33 | } 34 | } 35 | 36 | // WithEnvAPIKey sets the API key for the client from the environment variable GOOGLE_API_KEY 37 | func WithEnvAPIKey() Option { 38 | return func(p *Client) error { 39 | if apiKey := os.Getenv(APIKeyEnvVar); apiKey != "" { 40 | p.apiKey = apiKey 41 | return nil 42 | } 43 | 
return errors.Errorf("%s not set", APIKeyEnvVar) 44 | } 45 | } 46 | 47 | // WithClient sets the generative AI client for the client 48 | func WithClient(client *genai.Client) Option { 49 | return func(p *Client) error { 50 | if client == nil { 51 | return errors.New("google generative AI client is nil") 52 | } 53 | p.Client = client 54 | return nil 55 | } 56 | } 57 | 58 | // WithMaxBatchSize sets the max batch size for the client - this acts as a limit for the number of embeddings that can be sent in a single request 59 | func WithMaxBatchSize(maxBatchSize int) Option { 60 | return func(p *Client) error { 61 | if maxBatchSize < 1 { 62 | return errors.New("max batch size must be greater than 0") 63 | } 64 | p.MaxBatchSize = maxBatchSize 65 | return nil 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /pkg/embeddings/hf/option.go: -------------------------------------------------------------------------------- 1 | package hf 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | type Option func(p *HuggingFaceClient) error 10 | 11 | func WithBaseURL(baseURL string) Option { 12 | return func(p *HuggingFaceClient) error { 13 | if baseURL == "" { 14 | return errors.New("base URL cannot be empty") 15 | } 16 | p.BaseURL = baseURL 17 | return nil 18 | } 19 | } 20 | 21 | func WithAPIKey(apiKey string) Option { 22 | return func(p *HuggingFaceClient) error { 23 | if apiKey == "" { 24 | return errors.New("API key cannot be empty") 25 | } 26 | p.APIKey = apiKey 27 | return nil 28 | } 29 | } 30 | 31 | func WithEnvAPIKey() Option { 32 | return func(p *HuggingFaceClient) error { 33 | if os.Getenv("HF_API_KEY") == "" { 34 | return errors.New("HF_API_KEY not set") 35 | } 36 | p.APIKey = os.Getenv("HF_API_KEY") 37 | return nil 38 | } 39 | } 40 | 41 | func WithModel(model string) Option { 42 | return func(p *HuggingFaceClient) error { 43 | if model == "" { 44 | return errors.New("model cannot be empty") 45 | } 46 | 
// EmbeddingResponse is the JSON body that sendRequest decodes from an HTTP 200
// reply of the embedding endpoint.
type EmbeddingResponse struct {
	// Model is decoded from the "model" key.
	Model string `json:"model"`
	// Object is decoded from the "object" key.
	Object string `json:"object"`
	// Usage has no json tag, so encoding/json falls back to the field name
	// (matched case-insensitively when decoding).
	// NOTE(review): presumably token accounting for the request - confirm
	// against the Jina API reference.
	Usage struct {
		TotalTokens  int `json:"total_tokens"`
		PromptTokens int `json:"prompt_tokens"`
	}
	// Data has no json tag either. EmbedDocuments converts every entry into
	// one embedding, in slice order.
	Data []struct {
		Object    string    `json:"object"`
		Index     int       `json:"index"`
		Embedding []float32 `json:"embedding"` // TODO what about other embedding types - see cohere for example
	}
}
defaultModel: DefaultEmbeddingModel, 51 | embeddingEndpoint: DefaultBaseAPIEndpoint, 52 | normalized: true, 53 | embeddingType: EmbeddingTypeFloat, 54 | } 55 | } 56 | 57 | type JinaEmbeddingFunction struct { 58 | httpClient *http.Client 59 | apiKey string 60 | defaultModel embeddings.EmbeddingModel 61 | embeddingEndpoint string 62 | normalized bool 63 | embeddingType EmbeddingType 64 | } 65 | 66 | func NewJinaEmbeddingFunction(opts ...Option) (*JinaEmbeddingFunction, error) { 67 | ef := getDefaults() 68 | for _, opt := range opts { 69 | err := opt(ef) 70 | if err != nil { 71 | return nil, err 72 | } 73 | } 74 | return ef, nil 75 | } 76 | 77 | func (e *JinaEmbeddingFunction) sendRequest(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { 78 | payload, err := json.Marshal(req) 79 | if err != nil { 80 | return nil, errors.Wrapf(err, "failed to marshal embedding request body") 81 | } 82 | 83 | httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, e.embeddingEndpoint, bytes.NewBuffer(payload)) 84 | if err != nil { 85 | return nil, errors.Wrapf(err, "failed to create embedding request") 86 | } 87 | 88 | httpReq.Header.Set("Content-Type", "application/json") 89 | httpReq.Header.Set("Accept", "application/json") 90 | httpReq.Header.Set("User-Agent", chttp.ChromaGoClientUserAgent) 91 | httpReq.Header.Set("Authorization", "Bearer "+e.apiKey) 92 | 93 | resp, err := e.httpClient.Do(httpReq) 94 | if err != nil { 95 | return nil, errors.Wrapf(err, "failed to send embedding request") 96 | } 97 | defer resp.Body.Close() 98 | 99 | respData, err := io.ReadAll(resp.Body) 100 | if err != nil { 101 | return nil, errors.Wrapf(err, "failed to read response body") 102 | } 103 | 104 | if resp.StatusCode != http.StatusOK { 105 | return nil, errors.Errorf("unexpected response %v: %s", resp.Status, string(respData)) 106 | } 107 | var response *EmbeddingResponse 108 | if err := json.Unmarshal(respData, &response); err != nil { 109 | return nil, 
errors.Wrapf(err, "failed to unmarshal embedding response") 110 | } 111 | 112 | return response, nil 113 | } 114 | 115 | func (e *JinaEmbeddingFunction) EmbedDocuments(ctx context.Context, documents []string) ([]embeddings.Embedding, error) { 116 | var Input = make([]map[string]string, len(documents)) 117 | 118 | for i, doc := range documents { 119 | Input[i] = map[string]string{ 120 | "text": doc, 121 | } 122 | } 123 | req := &EmbeddingRequest{ 124 | Model: string(e.defaultModel), 125 | Input: Input, 126 | } 127 | response, err := e.sendRequest(ctx, req) 128 | if err != nil { 129 | return nil, errors.Wrapf(err, "failed to embed documents") 130 | } 131 | var embs []embeddings.Embedding 132 | for _, data := range response.Data { 133 | embs = append(embs, embeddings.NewEmbeddingFromFloat32(data.Embedding)) 134 | } 135 | 136 | return embs, nil 137 | } 138 | 139 | func (e *JinaEmbeddingFunction) EmbedQuery(ctx context.Context, document string) (embeddings.Embedding, error) { 140 | var Input = make([]map[string]string, 1) 141 | 142 | Input[0] = map[string]string{ 143 | "text": document, 144 | } 145 | req := &EmbeddingRequest{ 146 | Model: string(e.defaultModel), 147 | Input: Input, 148 | } 149 | response, err := e.sendRequest(ctx, req) 150 | if err != nil { 151 | return nil, errors.Wrapf(err, "failed to embed query") 152 | } 153 | 154 | return embeddings.NewEmbeddingFromFloat32(response.Data[0].Embedding), nil 155 | } 156 | -------------------------------------------------------------------------------- /pkg/embeddings/jina/jina_test.go: -------------------------------------------------------------------------------- 1 | //go:build ef 2 | 3 | package jina 4 | 5 | import ( 6 | "context" 7 | "os" 8 | "testing" 9 | 10 | "github.com/joho/godotenv" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestJinaEmbeddingFunction(t *testing.T) { 16 | apiKey := os.Getenv("JINA_API_KEY") 17 | if apiKey == "" { 18 | err := 
godotenv.Load("../../../.env") 19 | if err != nil { 20 | assert.Failf(t, "Error loading .env file", "%s", err) 21 | } 22 | apiKey = os.Getenv("JINA_API_KEY") 23 | } 24 | 25 | t.Run("Test with defaults", func(t *testing.T) { 26 | ef, err := NewJinaEmbeddingFunction(WithAPIKey(apiKey)) 27 | require.NoError(t, err) 28 | documents := []string{ 29 | "Document 1 content here", 30 | "Document 2 content here", 31 | } 32 | resp, err := ef.EmbedDocuments(context.Background(), documents) 33 | require.NoError(t, err) 34 | require.NotNil(t, resp) 35 | require.Len(t, resp, 2) 36 | require.Equal(t, 768, resp[0].Len()) 37 | }) 38 | 39 | t.Run("Test with env API key", func(t *testing.T) { 40 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey()) 41 | require.NoError(t, err) 42 | documents := []string{ 43 | "Document 1 content here", 44 | } 45 | resp, err := ef.EmbedDocuments(context.Background(), documents) 46 | require.NoError(t, err) 47 | require.NotNil(t, resp) 48 | require.Len(t, resp, 1) 49 | require.Equal(t, 768, resp[0].Len()) 50 | }) 51 | 52 | t.Run("Test with normalized off", func(t *testing.T) { 53 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithNormalized(false)) 54 | require.NoError(t, err) 55 | documents := []string{ 56 | "Document 1 content here", 57 | } 58 | resp, err := ef.EmbedDocuments(context.Background(), documents) 59 | require.NoError(t, err) 60 | require.NotNil(t, resp) 61 | require.Len(t, resp, 1) 62 | require.Equal(t, 768, resp[0].Len()) 63 | }) 64 | 65 | t.Run("Test with model", func(t *testing.T) { 66 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithModel("jina-embeddings-v3")) 67 | require.NoError(t, err) 68 | documents := []string{ 69 | "import chromadb;client=chromadb.Client();collection=client.get_or_create_collection('col_name')", 70 | } 71 | resp, err := ef.EmbedDocuments(context.Background(), documents) 72 | require.NoError(t, err) 73 | require.NotNil(t, resp) 74 | require.Len(t, resp, 1) 75 | require.Equal(t, 1024, 
resp[0].Len()) 76 | }) 77 | 78 | t.Run("Test with EmbeddingType float", func(t *testing.T) { 79 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithEmbeddingType(EmbeddingTypeFloat)) 80 | require.NoError(t, err) 81 | documents := []string{ 82 | "Document 1 content here", 83 | } 84 | resp, err := ef.EmbedDocuments(context.Background(), documents) 85 | require.NoError(t, err) 86 | require.NotNil(t, resp) 87 | require.Len(t, resp, 1) 88 | require.Equal(t, 768, resp[0].Len()) 89 | }) 90 | 91 | t.Run("Test with embedding endpoint", func(t *testing.T) { 92 | ef, err := NewJinaEmbeddingFunction(WithEnvAPIKey(), WithEmbeddingEndpoint(DefaultBaseAPIEndpoint)) 93 | require.NoError(t, err) 94 | documents := []string{ 95 | "Document 1 content here", 96 | } 97 | resp, err := ef.EmbedDocuments(context.Background(), documents) 98 | require.NoError(t, err) 99 | require.NotNil(t, resp) 100 | require.Len(t, resp, 1) 101 | require.Equal(t, 768, resp[0].Len()) 102 | }) 103 | } 104 | -------------------------------------------------------------------------------- /pkg/embeddings/jina/option.go: -------------------------------------------------------------------------------- 1 | package jina 2 | 3 | import ( 4 | "os" 5 | 6 | "github.com/pkg/errors" 7 | 8 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 9 | ) 10 | 11 | type Option func(c *JinaEmbeddingFunction) error 12 | 13 | func WithAPIKey(apiKey string) Option { 14 | return func(c *JinaEmbeddingFunction) error { 15 | c.apiKey = apiKey 16 | return nil 17 | } 18 | } 19 | 20 | func WithEnvAPIKey() Option { 21 | return func(c *JinaEmbeddingFunction) error { 22 | if os.Getenv("JINA_API_KEY") == "" { 23 | return errors.Errorf("JINA_API_KEY not set") 24 | } 25 | c.apiKey = os.Getenv("JINA_API_KEY") 26 | return nil 27 | } 28 | } 29 | 30 | func WithModel(model embeddings.EmbeddingModel) Option { 31 | return func(c *JinaEmbeddingFunction) error { 32 | if model == "" { 33 | return errors.New("model cannot be empty") 34 | } 35 | 
c.defaultModel = model 36 | return nil 37 | } 38 | } 39 | 40 | func WithEmbeddingEndpoint(endpoint string) Option { 41 | return func(c *JinaEmbeddingFunction) error { 42 | if endpoint == "" { 43 | return errors.New("embedding endpoint cannot be empty") 44 | } 45 | c.embeddingEndpoint = endpoint 46 | return nil 47 | } 48 | } 49 | 50 | // WithNormalized sets the flag to indicate to Jina whether to normalize (L2 norm) the output embeddings or not. Defaults to true 51 | func WithNormalized(normalized bool) Option { 52 | return func(c *JinaEmbeddingFunction) error { 53 | c.normalized = normalized 54 | return nil 55 | } 56 | } 57 | 58 | // WithEmbeddingType sets the type of the embedding to be returned by Jina. The default is float. Right now no other options are supported 59 | func WithEmbeddingType(embeddingType EmbeddingType) Option { 60 | return func(c *JinaEmbeddingFunction) error { 61 | if embeddingType == "" { 62 | return errors.New("embedding type cannot be empty") 63 | } 64 | c.embeddingType = embeddingType 65 | return nil 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /pkg/embeddings/mistral/option.go: -------------------------------------------------------------------------------- 1 | package mistral 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "os" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | type Option func(p *Client) error 12 | 13 | // WithDefaultModel sets the default model for the client 14 | func WithDefaultModel(model string) Option { 15 | return func(p *Client) error { 16 | if model == "" { 17 | return errors.Errorf("default model cannot be empty") 18 | } 19 | p.DefaultModel = model 20 | return nil 21 | } 22 | } 23 | 24 | // WithAPIKey sets the API key for the client 25 | func WithAPIKey(apiKey string) Option { 26 | return func(p *Client) error { 27 | if apiKey == "" { 28 | return errors.Errorf("api key cannot be empty") 29 | } 30 | p.apiKey = apiKey 31 | return nil 32 | } 33 | } 34 | 35 | 
// WithEnvAPIKey sets the API key for the client from the environment variable GOOGLE_API_KEY 36 | func WithEnvAPIKey() Option { 37 | return func(p *Client) error { 38 | if apiKey := os.Getenv(APIKeyEnvVar); apiKey != "" { 39 | p.apiKey = apiKey 40 | return nil 41 | } 42 | return errors.Errorf("%s not set", APIKeyEnvVar) 43 | } 44 | } 45 | 46 | // WithHTTPClient sets the generative AI client for the client 47 | func WithHTTPClient(client *http.Client) Option { 48 | return func(p *Client) error { 49 | if client == nil { 50 | return errors.Errorf("http client cannot be nil") 51 | } 52 | p.Client = client 53 | return nil 54 | } 55 | } 56 | 57 | // WithMaxBatchSize sets the max batch size for the client - this acts as a limit for the number of embeddings that can be sent in a single request 58 | func WithMaxBatchSize(maxBatchSize int) Option { 59 | return func(p *Client) error { 60 | if maxBatchSize <= 0 { 61 | return errors.Errorf("max batch size must be greater than 0") 62 | } 63 | p.MaxBatchSize = maxBatchSize 64 | return nil 65 | } 66 | } 67 | 68 | // WithBaseURL sets the base URL for the client 69 | func WithBaseURL(baseURL string) Option { 70 | return func(p *Client) error { 71 | if baseURL == "" { 72 | return errors.Errorf("base URL cannot be empty") 73 | } 74 | var err error 75 | p.EmbeddingEndpoint, err = url.JoinPath(baseURL, EmbeddingsEndpoint) 76 | if err != nil { 77 | return errors.Wrap(err, "failed to parse embedding endpoint") 78 | } 79 | return nil 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /pkg/embeddings/nomic/option.go: -------------------------------------------------------------------------------- 1 | package nomic 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "os" 7 | 8 | "github.com/pkg/errors" 9 | 10 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 11 | ) 12 | 13 | type Option func(p *Client) error 14 | 15 | // WithDefaultModel sets the default model for the client 16 | func 
WithDefaultModel(model embeddings.EmbeddingModel) Option { 17 | return func(p *Client) error { 18 | if model == "" { 19 | return errors.New("default model cannot be empty") 20 | } 21 | p.DefaultModel = model 22 | return nil 23 | } 24 | } 25 | 26 | // WithAPIKey sets the API key for the client 27 | func WithAPIKey(apiKey string) Option { 28 | return func(p *Client) error { 29 | if apiKey == "" { 30 | return errors.New("api key cannot be empty") 31 | } 32 | p.apiKey = apiKey 33 | return nil 34 | } 35 | } 36 | 37 | // WithEnvAPIKey sets the API key for the client from the environment variable GOOGLE_API_KEY 38 | func WithEnvAPIKey() Option { 39 | return func(p *Client) error { 40 | if apiKey := os.Getenv(APIKeyEnvVar); apiKey != "" { 41 | p.apiKey = apiKey 42 | return nil 43 | } 44 | return errors.Errorf("%s not set", APIKeyEnvVar) 45 | } 46 | } 47 | 48 | // WithHTTPClient sets the generative AI client for the client 49 | func WithHTTPClient(client *http.Client) Option { 50 | return func(p *Client) error { 51 | if client == nil { 52 | return errors.New("http client cannot be nil") 53 | } 54 | p.Client = client 55 | return nil 56 | } 57 | } 58 | 59 | // WithMaxBatchSize sets the max batch size for the client - this acts as a limit for the number of embeddings that can be sent in a single request 60 | func WithMaxBatchSize(maxBatchSize int) Option { 61 | return func(p *Client) error { 62 | if maxBatchSize <= 0 { 63 | return errors.New("max batch size must be greater than 0") 64 | } 65 | p.MaxBatchSize = maxBatchSize 66 | return nil 67 | } 68 | } 69 | 70 | // WithBaseURL sets the base URL for the client 71 | func WithBaseURL(baseURL string) Option { 72 | return func(p *Client) error { 73 | if baseURL == "" { 74 | return errors.New("base URL cannot be empty") 75 | } 76 | if _, err := url.ParseRequestURI(baseURL); err != nil { 77 | return errors.Wrap(err, "invalid basePath URL") 78 | } 79 | p.BaseURL = baseURL 80 | return nil 81 | } 82 | } 83 | 84 | // WithTextEmbeddings 
sets the endpoint to text embeddings 85 | func WithTextEmbeddings() Option { 86 | return func(p *Client) error { 87 | p.EmbeddingsEndpointSuffix = TextEmbeddingsEndpoint 88 | return nil 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /pkg/embeddings/ollama/ollama.go: -------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | 11 | "github.com/pkg/errors" 12 | 13 | chttp "github.com/amikos-tech/chroma-go/pkg/commons/http" 14 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 15 | ) 16 | 17 | type OllamaClient struct { 18 | BaseURL string 19 | Model embeddings.EmbeddingModel 20 | Client *http.Client 21 | DefaultHeaders map[string]string 22 | } 23 | 24 | type EmbeddingInput struct { 25 | Input string 26 | Inputs []string 27 | } 28 | 29 | func (e EmbeddingInput) MarshalJSON() ([]byte, error) { 30 | if e.Input != "" { 31 | b, err := json.Marshal(e.Input) 32 | if err != nil { 33 | return nil, errors.Wrap(err, "failed to marshal embedding input") 34 | } 35 | return b, nil 36 | } else if len(e.Inputs) > 0 { 37 | b, err := json.Marshal(e.Inputs) 38 | if err != nil { 39 | return nil, errors.Wrap(err, "failed to marshal embedding input") 40 | } 41 | return b, nil 42 | } 43 | return json.Marshal(nil) 44 | } 45 | 46 | type CreateEmbeddingRequest struct { 47 | Model string `json:"model"` 48 | Input *EmbeddingInput `json:"input"` 49 | } 50 | 51 | type CreateEmbeddingResponse struct { 52 | Embeddings [][]float32 `json:"embeddings"` 53 | } 54 | 55 | func (c *CreateEmbeddingRequest) JSON() (string, error) { 56 | data, err := json.Marshal(c) 57 | if err != nil { 58 | return "", errors.Wrap(err, "failed to marshal embedding request JSON") 59 | } 60 | return string(data), nil 61 | } 62 | 63 | func NewOllamaClient(opts ...Option) (*OllamaClient, error) { 64 | client := &OllamaClient{ 65 | 
Client: &http.Client{}, 66 | } 67 | for _, opt := range opts { 68 | err := opt(client) 69 | if err != nil { 70 | return nil, errors.Wrap(err, "failed to apply Ollama option") 71 | } 72 | } 73 | return client, nil 74 | } 75 | 76 | func (c *OllamaClient) createEmbedding(ctx context.Context, req *CreateEmbeddingRequest) (*CreateEmbeddingResponse, error) { 77 | reqJSON, err := req.JSON() 78 | if err != nil { 79 | return nil, errors.Wrap(err, "failed to marshal embedding request JSON") 80 | } 81 | endpoint, err := url.JoinPath(c.BaseURL, "/api/embed") 82 | if err != nil { 83 | return nil, errors.Wrap(err, "failed to parse Ollama embedding endpoint") 84 | } 85 | 86 | httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBufferString(reqJSON)) 87 | if err != nil { 88 | return nil, errors.Wrap(err, "failed to create HTTP request") 89 | } 90 | for k, v := range c.DefaultHeaders { 91 | httpReq.Header.Set(k, v) 92 | } 93 | httpReq.Header.Set("Accept", "application/json") 94 | httpReq.Header.Set("Content-Type", "application/json") 95 | httpReq.Header.Set("User-Agent", chttp.ChromaGoClientUserAgent) 96 | 97 | resp, err := c.Client.Do(httpReq) 98 | if err != nil { 99 | return nil, errors.Wrap(err, "failed to make HTTP request to Ollama embedding endpoint") 100 | } 101 | defer resp.Body.Close() 102 | 103 | respData, err := io.ReadAll(resp.Body) 104 | if err != nil { 105 | return nil, errors.Wrap(err, "failed to read response body") 106 | } 107 | 108 | if resp.StatusCode != http.StatusOK { 109 | return nil, errors.Errorf("unexpected code [%v] while making a request to %v: %v", resp.Status, endpoint, string(respData)) 110 | } 111 | 112 | var embeddingResponse CreateEmbeddingResponse 113 | if err := json.Unmarshal(respData, &embeddingResponse); err != nil { 114 | return nil, errors.Wrap(err, "failed to unmarshal embedding response") 115 | } 116 | return &embeddingResponse, nil 117 | } 118 | 119 | type OllamaEmbeddingFunction struct { 120 | apiClient 
*OllamaClient 121 | } 122 | 123 | var _ embeddings.EmbeddingFunction = (*OllamaEmbeddingFunction)(nil) 124 | 125 | func NewOllamaEmbeddingFunction(option ...Option) (*OllamaEmbeddingFunction, error) { 126 | client, err := NewOllamaClient(option...) 127 | if err != nil { 128 | return nil, errors.Wrap(err, "failed to initialize OllamaClient") 129 | } 130 | return &OllamaEmbeddingFunction{ 131 | apiClient: client, 132 | }, nil 133 | } 134 | 135 | func (e *OllamaEmbeddingFunction) EmbedDocuments(ctx context.Context, documents []string) ([]embeddings.Embedding, error) { 136 | response, err := e.apiClient.createEmbedding(ctx, &CreateEmbeddingRequest{ 137 | Model: string(e.apiClient.Model), 138 | Input: &EmbeddingInput{Inputs: documents}, 139 | }) 140 | if err != nil { 141 | return nil, errors.Wrap(err, "failed to embed documents") 142 | } 143 | return embeddings.NewEmbeddingsFromFloat32(response.Embeddings) 144 | } 145 | 146 | func (e *OllamaEmbeddingFunction) EmbedQuery(ctx context.Context, document string) (embeddings.Embedding, error) { 147 | response, err := e.apiClient.createEmbedding(ctx, &CreateEmbeddingRequest{ 148 | Model: string(e.apiClient.Model), 149 | Input: &EmbeddingInput{Input: document}, 150 | }) 151 | if err != nil { 152 | return nil, errors.Wrap(err, "failed to embed query") 153 | } 154 | return embeddings.NewEmbeddingFromFloat32(response.Embeddings[0]), nil 155 | } 156 | -------------------------------------------------------------------------------- /pkg/embeddings/ollama/ollama_test.go: -------------------------------------------------------------------------------- 1 | //go:build ef 2 | 3 | package ollama 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 9 | "github.com/stretchr/testify/require" 10 | tcollama "github.com/testcontainers/testcontainers-go/modules/ollama" 11 | "io" 12 | "net/http" 13 | "strings" 14 | "testing" 15 | ) 16 | 17 | func Test_ollama(t *testing.T) { 18 | ctx := 
context.Background() 19 | ollamaContainer, err := tcollama.Run(ctx, "ollama/ollama:latest") 20 | require.NoError(t, err) 21 | // Clean up the container 22 | defer func() { 23 | if err := ollamaContainer.Terminate(ctx); err != nil { 24 | t.Logf("failed to terminate container: %s\n", err) 25 | } 26 | }() 27 | 28 | model := "nomic-embed-text" 29 | connectionStr, err := ollamaContainer.ConnectionString(ctx) 30 | require.NoError(t, err) 31 | pullURL := fmt.Sprintf("%s/api/pull", connectionStr) 32 | pullPayload := fmt.Sprintf(`{"name": "%s"}`, model) 33 | 34 | resp, err := http.Post( 35 | pullURL, 36 | "application/json", 37 | strings.NewReader(pullPayload), 38 | ) 39 | require.NoError(t, err) 40 | respStr, err := io.ReadAll(resp.Body) 41 | require.NoError(t, err) 42 | defer resp.Body.Close() 43 | require.Contains(t, string(respStr), "success") 44 | 45 | // Ensure successful response 46 | require.Equal(t, http.StatusOK, resp.StatusCode) 47 | client, err := NewOllamaClient(WithBaseURL(connectionStr), WithModel(embeddings.EmbeddingModel(model))) 48 | require.NoError(t, err) 49 | t.Run("Test Create Embed Single document", func(t *testing.T) { 50 | resp, rerr := client.createEmbedding(context.Background(), &CreateEmbeddingRequest{Model: "nomic-embed-text", Input: &EmbeddingInput{Input: "Document 1 content here"}}) 51 | require.Nil(t, rerr) 52 | require.NotNil(t, resp) 53 | }) 54 | t.Run("Test Create Embed multi-document", func(t *testing.T) { 55 | documents := []string{ 56 | "Document 1 content here", 57 | "Document 2 content here", 58 | } 59 | ef, err := NewOllamaEmbeddingFunction(WithBaseURL(connectionStr), WithModel(embeddings.EmbeddingModel(model))) 60 | require.NoError(t, err) 61 | resp, rerr := ef.EmbedDocuments(context.Background(), documents) 62 | require.Nil(t, rerr) 63 | require.NotNil(t, resp) 64 | require.Equal(t, 2, len(resp)) 65 | }) 66 | } 67 | -------------------------------------------------------------------------------- /pkg/embeddings/ollama/option.go: 
-------------------------------------------------------------------------------- 1 | package ollama 2 | 3 | import ( 4 | "net/url" 5 | 6 | "github.com/pkg/errors" 7 | 8 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 9 | ) 10 | 11 | type Option func(p *OllamaClient) error 12 | 13 | func WithBaseURL(baseURL string) Option { 14 | return func(p *OllamaClient) error { 15 | if baseURL == "" { 16 | return errors.New("base URL cannot be empty") 17 | } 18 | if _, err := url.ParseRequestURI(baseURL); err != nil { 19 | return errors.Wrap(err, "invalid base URL") 20 | } 21 | p.BaseURL = baseURL 22 | return nil 23 | } 24 | } 25 | func WithModel(model embeddings.EmbeddingModel) Option { 26 | return func(p *OllamaClient) error { 27 | if model == "" { 28 | return errors.New("model cannot be empty") 29 | } 30 | p.Model = model 31 | return nil 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /pkg/embeddings/openai/options.go: -------------------------------------------------------------------------------- 1 | package openai 2 | 3 | import ( 4 | "net/url" 5 | 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | // Option is a function type that can be used to modify the client. 10 | type Option func(c *OpenAIClient) error 11 | 12 | func WithBaseURL(baseURL string) Option { 13 | return func(p *OpenAIClient) error { 14 | if baseURL == "" { 15 | return errors.New("Base URL cannot be empty") 16 | } 17 | if _, err := url.ParseRequestURI(baseURL); err != nil { 18 | return errors.Wrap(err, "invalid base URL") 19 | } 20 | p.BaseURL = baseURL 21 | return nil 22 | } 23 | } 24 | 25 | // WithOpenAIOrganizationID is an option for setting the OpenAI org id. 26 | func WithOpenAIOrganizationID(orgID string) Option { 27 | return func(c *OpenAIClient) error { 28 | if orgID == "" { 29 | return errors.New("OrgID cannot be empty") 30 | } 31 | c.OrgID = orgID 32 | return nil 33 | } 34 | } 35 | 36 | // WithOpenAIUser is an option for setting the OpenAI user. 
The user is passed with every request to OpenAI. It serves for auditing purposes. If not set the user defaults to ChromaGo client. 37 | func WithOpenAIUser(user string) Option { 38 | return func(c *OpenAIClient) error { 39 | if user == "" { 40 | return errors.New("User cannot be empty") 41 | } 42 | c.User = user 43 | return nil 44 | } 45 | } 46 | 47 | // WithModel is an option for setting the model to use. Must be one of: text-embedding-ada-002, text-embedding-3-small, text-embedding-3-large 48 | func WithModel(model EmbeddingModel) Option { 49 | return func(c *OpenAIClient) error { 50 | if string(model) == "" { 51 | return errors.New("Model cannot be empty") 52 | } 53 | if model != TextEmbeddingAda002 && model != TextEmbedding3Small && model != TextEmbedding3Large { 54 | return errors.Errorf("invalid model name %s. Must be one of: %v", model, []string{string(TextEmbeddingAda002), string(TextEmbedding3Small), string(TextEmbedding3Large)}) 55 | } 56 | c.Model = string(model) 57 | return nil 58 | } 59 | } 60 | func WithDimensions(dimensions int) Option { 61 | return func(c *OpenAIClient) error { 62 | if dimensions <= 0 { 63 | return errors.Errorf("dimensions must be greater than 0, got %d", dimensions) 64 | } 65 | c.Dimensions = &dimensions 66 | return nil 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /pkg/embeddings/together/option.go: -------------------------------------------------------------------------------- 1 | package together 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | 7 | "github.com/pkg/errors" 8 | 9 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 10 | ) 11 | 12 | type Option func(p *TogetherAIClient) error 13 | 14 | func WithDefaultModel(model embeddings.EmbeddingModel) Option { 15 | return func(p *TogetherAIClient) error { 16 | if model == "" { 17 | return errors.New("default model cannot be empty") 18 | } 19 | p.DefaultModel = model 20 | return nil 21 | } 22 | } 23 | 24 | func 
WithMaxBatchSize(size int) Option { 25 | return func(p *TogetherAIClient) error { 26 | if size <= 0 { 27 | return errors.New("max batch size must be greater than 0") 28 | } 29 | p.MaxBatchSize = size 30 | return nil 31 | } 32 | } 33 | 34 | func WithDefaultHeaders(headers map[string]string) Option { 35 | return func(p *TogetherAIClient) error { 36 | p.DefaultHeaders = headers 37 | return nil 38 | } 39 | } 40 | 41 | func WithAPIToken(apiToken string) Option { 42 | return func(p *TogetherAIClient) error { 43 | if apiToken == "" { 44 | return errors.New("API token cannot be empty") 45 | } 46 | p.APIToken = apiToken 47 | return nil 48 | } 49 | } 50 | 51 | func WithEnvAPIKey() Option { 52 | return func(p *TogetherAIClient) error { 53 | if apiToken := os.Getenv("TOGETHER_API_KEY"); apiToken != "" { 54 | p.APIToken = apiToken 55 | return nil 56 | } 57 | return errors.New("TOGETHER_API_KEY not set") 58 | } 59 | } 60 | 61 | func WithHTTPClient(client *http.Client) Option { 62 | return func(p *TogetherAIClient) error { 63 | if client == nil { 64 | return errors.New("HTTP client cannot be nil") 65 | } 66 | p.Client = client 67 | return nil 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /pkg/embeddings/together/together_test.go: -------------------------------------------------------------------------------- 1 | //go:build ef 2 | 3 | package together 4 | 5 | import ( 6 | "context" 7 | "net/http" 8 | "os" 9 | "testing" 10 | 11 | "github.com/joho/godotenv" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func Test_client(t *testing.T) { 17 | apiKey := os.Getenv("TOGETHER_API_KEY") 18 | if apiKey == "" { 19 | err := godotenv.Load("../../../.env") 20 | if err != nil { 21 | assert.Failf(t, "Error loading .env file", "%s", err) 22 | } 23 | apiKey = os.Getenv("TOGETHER_API_KEY") 24 | } 25 | client, err := NewTogetherClient(WithEnvAPIKey()) 26 | require.NoError(t, err) 27 | 28 | 
t.Run("Test CreateEmbedding", func(t *testing.T) { 29 | req := CreateEmbeddingRequest{ 30 | Model: "togethercomputer/m2-bert-80M-8k-retrieval", 31 | Input: &EmbeddingInputs{Input: "Test document"}, 32 | } 33 | resp, rerr := client.CreateEmbedding(context.Background(), &req) 34 | 35 | require.Nil(t, rerr) 36 | require.NotNil(t, resp) 37 | require.NotNil(t, resp.Data) 38 | require.Len(t, resp.Data, 1) 39 | }) 40 | } 41 | 42 | func Test_together_embedding_function(t *testing.T) { 43 | apiKey := os.Getenv("TOGETHER_API_KEY") 44 | if apiKey == "" { 45 | err := godotenv.Load("../../../.env") 46 | if err != nil { 47 | assert.Failf(t, "Error loading .env file", "%s", err) 48 | } 49 | } 50 | 51 | t.Run("Test EmbedDocuments with env-based API Key", func(t *testing.T) { 52 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey()) 53 | require.NoError(t, err) 54 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 55 | 56 | require.Nil(t, rerr) 57 | require.NotNil(t, resp) 58 | require.Len(t, resp, 2) 59 | require.Equal(t, 768, resp[0].Len()) 60 | 61 | }) 62 | 63 | t.Run("Test EmbedDocuments for model with env-based API Key", func(t *testing.T) { 64 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel("togethercomputer/m2-bert-80M-2k-retrieval")) 65 | require.NoError(t, err) 66 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 67 | 68 | require.Nil(t, rerr) 69 | require.NotNil(t, resp) 70 | require.Len(t, resp, 2) 71 | require.Equal(t, 768, resp[0].Len()) 72 | }) 73 | 74 | t.Run("Test EmbedDocuments with too large init batch", func(t *testing.T) { 75 | _, err := NewTogetherEmbeddingFunction(WithEnvAPIKey(), WithMaxBatchSize(200)) 76 | require.Error(t, err) 77 | require.Contains(t, err.Error(), "max batch size must be less than") 78 | }) 79 | 80 | t.Run("Test EmbedDocuments with too large batch at inference", func(t *testing.T) 
{ 81 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey()) 82 | require.NoError(t, err) 83 | docs200 := make([]string, 200) 84 | for i := 0; i < 200; i++ { 85 | docs200[i] = "Test document" 86 | } 87 | _, err = client.EmbedDocuments(context.Background(), docs200) 88 | require.Error(t, err) 89 | require.Contains(t, err.Error(), "number of documents exceeds the maximum batch") 90 | }) 91 | 92 | t.Run("Test EmbedQuery", func(t *testing.T) { 93 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey()) 94 | require.NoError(t, err) 95 | resp, err := client.EmbedQuery(context.Background(), "Test query") 96 | require.Nil(t, err) 97 | require.NotNil(t, resp) 98 | require.Equal(t, 768, resp.Len()) 99 | }) 100 | 101 | t.Run("Test EmbedDocuments with env-based API Key and WithDefaultHeaders", func(t *testing.T) { 102 | client, err := NewTogetherEmbeddingFunction(WithEnvAPIKey(), WithDefaultModel("togethercomputer/m2-bert-80M-2k-retrieval"), WithDefaultHeaders(map[string]string{"X-Test-Header": "test"})) 103 | require.NoError(t, err) 104 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 105 | 106 | require.Nil(t, rerr) 107 | require.NotNil(t, resp) 108 | require.Len(t, resp, 2) 109 | require.Equal(t, 768, resp[0].Len()) 110 | }) 111 | 112 | t.Run("Test EmbedDocuments with var API Key", func(t *testing.T) { 113 | client, err := NewTogetherEmbeddingFunction(WithAPIToken(os.Getenv("TOGETHER_API_KEY"))) 114 | require.NoError(t, err) 115 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 116 | 117 | require.Nil(t, rerr) 118 | require.NotNil(t, resp) 119 | require.Len(t, resp, 2) 120 | require.Equal(t, 768, resp[0].Len()) 121 | }) 122 | 123 | t.Run("Test EmbedDocuments with var token and account id and http client", func(t *testing.T) { 124 | client, err := NewTogetherEmbeddingFunction(WithAPIToken(os.Getenv("TOGETHER_API_KEY")), 
WithHTTPClient(http.DefaultClient)) 125 | require.NoError(t, err) 126 | resp, rerr := client.EmbedDocuments(context.Background(), []string{"Test document", "Another test document"}) 127 | 128 | require.Nil(t, rerr) 129 | require.NotNil(t, resp) 130 | require.Equal(t, 2, len(resp)) 131 | require.Equal(t, 768, resp[0].Len()) 132 | }) 133 | } 134 | -------------------------------------------------------------------------------- /pkg/embeddings/voyage/option.go: -------------------------------------------------------------------------------- 1 | package voyage 2 | 3 | import ( 4 | "net/http" 5 | "os" 6 | 7 | "github.com/pkg/errors" 8 | 9 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 10 | ) 11 | 12 | type Option func(p *VoyageAIClient) error 13 | 14 | func WithDefaultModel(model embeddings.EmbeddingModel) Option { 15 | return func(p *VoyageAIClient) error { 16 | if model == "" { 17 | return errors.New("model cannot be empty") 18 | } 19 | p.DefaultModel = model 20 | return nil 21 | } 22 | } 23 | 24 | func WithMaxBatchSize(size int) Option { 25 | return func(p *VoyageAIClient) error { 26 | if size <= 0 { 27 | return errors.New("max batch size must be greater than 0") 28 | } 29 | p.MaxBatchSize = size 30 | return nil 31 | } 32 | } 33 | 34 | func WithDefaultHeaders(headers map[string]string) Option { 35 | return func(p *VoyageAIClient) error { 36 | p.DefaultHeaders = headers 37 | return nil 38 | } 39 | } 40 | 41 | func WithAPIKey(apiToken string) Option { 42 | return func(p *VoyageAIClient) error { 43 | if apiToken == "" { 44 | return errors.New("API key cannot be empty") 45 | } 46 | p.APIKey = apiToken 47 | return nil 48 | } 49 | } 50 | 51 | func WithEnvAPIKey() Option { 52 | return func(p *VoyageAIClient) error { 53 | if apiToken := os.Getenv(APIKeyEnvVar); apiToken != "" { 54 | p.APIKey = apiToken 55 | return nil 56 | } 57 | return errors.Errorf("%s not set", APIKeyEnvVar) 58 | } 59 | } 60 | 61 | func WithHTTPClient(client *http.Client) Option { 62 | return func(p 
*VoyageAIClient) error { 63 | if client == nil { 64 | return errors.New("HTTP client cannot be nil") 65 | } 66 | p.Client = client 67 | return nil 68 | } 69 | } 70 | 71 | func WithTruncation(truncation bool) Option { 72 | return func(p *VoyageAIClient) error { 73 | p.DefaultTruncation = &truncation 74 | return nil 75 | } 76 | } 77 | 78 | func WithEncodingFormat(format EncodingFormat) Option { 79 | return func(p *VoyageAIClient) error { 80 | if format == "" { 81 | return errors.New("encoding format cannot be empty") 82 | } 83 | var defaultEncodingFormat = format 84 | p.DefaultEncodingFormat = &defaultEncodingFormat 85 | return nil 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /pkg/rerankings/cohere/option.go: -------------------------------------------------------------------------------- 1 | package cohere 2 | 3 | import ( 4 | ccommons "github.com/amikos-tech/chroma-go/pkg/commons/cohere" 5 | httpc "github.com/amikos-tech/chroma-go/pkg/commons/http" 6 | "github.com/amikos-tech/chroma-go/pkg/embeddings" 7 | "github.com/amikos-tech/chroma-go/pkg/rerankings" 8 | ) 9 | 10 | type Option func(p *CohereRerankingFunction) ccommons.Option 11 | 12 | func WithBaseURL(baseURL string) Option { 13 | return func(p *CohereRerankingFunction) ccommons.Option { 14 | return ccommons.WithBaseURL(baseURL) 15 | } 16 | } 17 | 18 | func WithDefaultModel(model rerankings.RerankingModel) Option { 19 | return func(p *CohereRerankingFunction) ccommons.Option { 20 | return ccommons.WithDefaultModel(embeddings.EmbeddingModel(model)) 21 | } 22 | } 23 | 24 | func WithAPIKey(apiKey string) Option { 25 | return func(p *CohereRerankingFunction) ccommons.Option { 26 | return ccommons.WithAPIKey(apiKey) 27 | } 28 | } 29 | 30 | // WithEnvAPIKey configures the client to use the COHERE_API_KEY environment variable as the API key 31 | func WithEnvAPIKey() Option { 32 | return func(p *CohereRerankingFunction) ccommons.Option { 33 | return 
ccommons.WithEnvAPIKey() 34 | } 35 | } 36 | 37 | func WithTopN(topN int) Option { 38 | return func(p *CohereRerankingFunction) ccommons.Option { 39 | p.TopN = topN 40 | return ccommons.NoOp() 41 | } 42 | } 43 | 44 | // WithRerankFields configures the client to use the specified fields for reranking if the documents are in JSON format 45 | func WithRerankFields(fields []string) Option { 46 | return func(p *CohereRerankingFunction) ccommons.Option { 47 | p.RerankFields = fields 48 | return ccommons.NoOp() 49 | } 50 | } 51 | 52 | // WithReturnDocuments configures the client to return the original documents in the response 53 | func WithReturnDocuments() Option { 54 | return func(p *CohereRerankingFunction) ccommons.Option { 55 | p.ReturnDocuments = true 56 | return ccommons.NoOp() 57 | } 58 | } 59 | 60 | func WithMaxChunksPerDoc(maxChunks int) Option { 61 | return func(p *CohereRerankingFunction) ccommons.Option { 62 | p.MaxChunksPerDoc = maxChunks 63 | return ccommons.NoOp() 64 | } 65 | } 66 | 67 | // WithRetryStrategy configures the client to use the specified retry strategy 68 | func WithRetryStrategy(retryStrategy httpc.RetryStrategy) Option { 69 | return func(p *CohereRerankingFunction) ccommons.Option { 70 | return ccommons.WithRetryStrategy(retryStrategy) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /pkg/rerankings/hf/huggingface.go: -------------------------------------------------------------------------------- 1 | package huggingface 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | 11 | chromago "github.com/amikos-tech/chroma-go" 12 | chttp "github.com/amikos-tech/chroma-go/pkg/commons/http" 13 | "github.com/amikos-tech/chroma-go/pkg/rerankings" 14 | ) 15 | 16 | const ( 17 | DefaultBaseAPIEndpoint = "http://127.0.0.1:8080/rerank" 18 | ) 19 | 20 | type RerankingRequest struct { 21 | Model *string `json:"model,omitempty"` 22 | Query string `json:"query"` 
23 | Texts []string `json:"texts"` 24 | } 25 | 26 | type RerankingResponse []struct { 27 | Index int `json:"index"` 28 | Score float32 `json:"score"` 29 | } 30 | 31 | var _ rerankings.RerankingFunction = (*HFRerankingFunction)(nil) 32 | 33 | func getDefaults() *HFRerankingFunction { 34 | return &HFRerankingFunction{ 35 | httpClient: http.DefaultClient, 36 | rerankingEndpoint: DefaultBaseAPIEndpoint, 37 | } 38 | } 39 | 40 | type HFRerankingFunction struct { 41 | httpClient *http.Client 42 | apiKey string 43 | defaultModel *rerankings.RerankingModel 44 | rerankingEndpoint string 45 | } 46 | 47 | func NewHFRerankingFunction(opts ...Option) (*HFRerankingFunction, error) { 48 | ef := getDefaults() 49 | for _, opt := range opts { 50 | err := opt(ef) 51 | if err != nil { 52 | return nil, err 53 | } 54 | } 55 | return ef, nil 56 | } 57 | 58 | func (r *HFRerankingFunction) sendRequest(ctx context.Context, req *RerankingRequest) (*RerankingResponse, error) { 59 | payload, err := json.Marshal(req) 60 | if err != nil { 61 | return nil, err 62 | } 63 | 64 | httpReq, err := http.NewRequest("POST", r.rerankingEndpoint, bytes.NewBuffer(payload)) 65 | if err != nil { 66 | return nil, err 67 | } 68 | 69 | httpReq.Header.Set("Content-Type", "application/json") 70 | httpReq.Header.Set("Accept", "application/json") 71 | httpReq.Header.Set("User-Agent", chttp.ChromaGoClientUserAgent) 72 | if r.apiKey != "" { 73 | httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", r.apiKey)) 74 | } 75 | 76 | resp, err := r.httpClient.Do(httpReq) 77 | if err != nil { 78 | return nil, err 79 | } 80 | defer resp.Body.Close() 81 | respData, err := io.ReadAll(resp.Body) 82 | if err != nil { 83 | return nil, err 84 | } 85 | 86 | if resp.StatusCode != http.StatusOK { 87 | // TODO serialize body in error 88 | return nil, fmt.Errorf("unexpected response %v: %s", resp.Status, respData) 89 | } 90 | var response *RerankingResponse 91 | if err := json.Unmarshal(respData, &response); err != nil { 92 | return 
nil, err 93 | } 94 | 95 | return response, nil 96 | } 97 | 98 | func (r *HFRerankingFunction) Rerank(ctx context.Context, query string, results []rerankings.Result) (map[string][]rerankings.RankedResult, error) { 99 | docs := make([]string, 0) 100 | for _, result := range results { 101 | d, err := result.ToText() 102 | if err != nil { 103 | return nil, err 104 | } 105 | docs = append(docs, d) 106 | } 107 | req := &RerankingRequest{ 108 | Model: (*string)(r.defaultModel), 109 | Texts: docs, 110 | Query: query, 111 | } 112 | 113 | rerankResp, err := r.sendRequest(ctx, req) 114 | if err != nil { 115 | return nil, err 116 | } 117 | 118 | rankedResults := map[string][]rerankings.RankedResult{r.ID(): make([]rerankings.RankedResult, len(*rerankResp))} 119 | for i, rr := range *rerankResp { 120 | originalDoc, err := results[rr.Index].ToText() 121 | if err != nil { 122 | return nil, err 123 | } 124 | 125 | rankedResults[r.ID()][i] = rerankings.RankedResult{ 126 | String: originalDoc, 127 | Index: rr.Index, 128 | Rank: rr.Score, 129 | } 130 | } 131 | return rankedResults, nil 132 | } 133 | 134 | // ID returns the of the reranking function. We use `cohere-` prefix with the default model 135 | func (r *HFRerankingFunction) ID() string { 136 | if r.defaultModel != nil { 137 | return fmt.Sprintf("hf-%s", *(*string)(r.defaultModel)) 138 | } 139 | return "hfei" 140 | } 141 | 142 | func (r *HFRerankingFunction) RerankResults(ctx context.Context, queryResults *chromago.QueryResults) (*rerankings.RerankedChromaResults, error) { 143 | rerankedResults := &rerankings.RerankedChromaResults{ 144 | QueryResults: *queryResults, 145 | Ranks: map[string][][]float32{r.ID(): make([][]float32, len(queryResults.Ids))}, 146 | } 147 | for i, rs := range queryResults.Ids { 148 | if len(rs) == 0 { 149 | return nil, fmt.Errorf("no results to rerank") 150 | } 151 | docs := make([]string, 0) 152 | docs = append(docs, queryResults.Documents[i]...) 
153 | req := &RerankingRequest{ 154 | Model: (*string)(r.defaultModel), 155 | Texts: docs, 156 | Query: queryResults.QueryTexts[i], 157 | } 158 | rerankResp, err := r.sendRequest(ctx, req) 159 | if err != nil { 160 | return nil, err 161 | } 162 | rerankedResults.Ranks[r.ID()][i] = make([]float32, len(*rerankResp)) 163 | for _, rr := range *rerankResp { 164 | rerankedResults.Ranks[r.ID()][i][rr.Index] = rr.Score 165 | } 166 | } 167 | return rerankedResults, nil 168 | } 169 | -------------------------------------------------------------------------------- /pkg/rerankings/hf/option.go: -------------------------------------------------------------------------------- 1 | package huggingface 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/amikos-tech/chroma-go/pkg/rerankings" 8 | ) 9 | 10 | type Option func(c *HFRerankingFunction) error 11 | 12 | func WithAPIKey(apiKey string) Option { 13 | return func(c *HFRerankingFunction) error { 14 | c.apiKey = apiKey 15 | return nil 16 | } 17 | } 18 | 19 | func WithEnvAPIKey() Option { 20 | return func(c *HFRerankingFunction) error { 21 | if os.Getenv("HF_API_KEY") == "" { 22 | return fmt.Errorf("HF_API_KEY not set") 23 | } 24 | c.apiKey = os.Getenv("HF_API_KEY") 25 | return nil 26 | } 27 | } 28 | 29 | func WithModel(model rerankings.RerankingModel) Option { 30 | return func(c *HFRerankingFunction) error { 31 | c.defaultModel = &model 32 | return nil 33 | } 34 | } 35 | 36 | func WithRerankingEndpoint(endpoint string) Option { 37 | return func(c *HFRerankingFunction) error { 38 | c.rerankingEndpoint = endpoint 39 | return nil 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /pkg/rerankings/jina/option.go: -------------------------------------------------------------------------------- 1 | package jina 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/amikos-tech/chroma-go/types" 8 | ) 9 | 10 | type Option func(c *JinaRerankingFunction) error 11 | 12 | func 
// Result is one candidate handed to a reranker. It holds either plain text
// or an arbitrary object; the field that is set is non-nil.
type Result struct {
	Text   *string
	Object *any
}

// FromText wraps a string in a text Result.
func FromText(text string) Result {
	return Result{Text: &text}
}

// FromTexts wraps each string in a text Result, preserving order.
func FromTexts(texts []string) []Result {
	wrapped := make([]Result, 0, len(texts))
	for _, t := range texts {
		wrapped = append(wrapped, FromText(t))
	}
	return wrapped
}

// FromObject wraps an arbitrary value in an object Result.
func FromObject(object any) Result {
	return Result{Object: &object}
}

// FromObjects wraps each value in an object Result, preserving order.
func FromObjects(objects []any) []Result {
	wrapped := make([]Result, 0, len(objects))
	for _, o := range objects {
		wrapped = append(wrapped, FromObject(o))
	}
	return wrapped
}

// ToText returns the textual form of the result: the text itself when Text is
// set, otherwise the JSON encoding of Object. It returns an error when the
// object cannot be encoded or when neither field is set.
func (r *Result) ToText() (string, error) {
	switch {
	case r.Text != nil:
		return *r.Text, nil
	case r.Object != nil:
		encoded, err := json.Marshal(r.Object)
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	default:
		return "", fmt.Errorf("result is neither text nor object")
	}
}

// IsText reports whether the result carries text.
func (r *Result) IsText() bool {
	hasText := r.Text != nil
	return hasText
}

// IsObject reports whether the result carries an object.
func (r *Result) IsObject() bool {
	hasObject := r.Object != nil
	return hasObject
}
Rerank(_ context.Context, _ string, results []Result) (map[string][]RankedResult, error) { 27 | if len(results) == 0 { 28 | return nil, fmt.Errorf("no results to rerank") 29 | } 30 | rerankedResults := make([]RankedResult, len(results)) 31 | for i, result := range results { 32 | doc, err := result.ToText() 33 | if err != nil { 34 | return nil, err 35 | } 36 | rerankedResults[i] = RankedResult{ 37 | String: doc, 38 | Index: i, 39 | Rank: rand.Float32(), 40 | } 41 | } 42 | return map[string][]RankedResult{d.ID(): rerankedResults}, nil 43 | } 44 | 45 | func (d *DummyRerankingFunction) RerankResults(_ context.Context, queryResults *chromago.QueryResults) (*RerankedChromaResults, error) { 46 | if len(queryResults.Ids) == 0 { 47 | return nil, fmt.Errorf("no results to rerank") 48 | } 49 | results := &RerankedChromaResults{ 50 | QueryResults: *queryResults, 51 | Ranks: map[string][][]float32{d.ID(): make([][]float32, len(queryResults.Ids))}, 52 | } 53 | for i, qr := range queryResults.Ids { 54 | results.Ranks[d.ID()][i] = make([]float32, len(qr)) 55 | for j := range qr { 56 | results.Ranks[d.ID()][i][j] = rand.Float32() 57 | } 58 | } 59 | return results, nil 60 | } 61 | 62 | func Test_reranking_function(t *testing.T) { 63 | rerankingFunction := NewDummyRerankingFunction() 64 | t.Run("Rerank string results", func(t *testing.T) { 65 | query := "hello world" 66 | results := []string{"hello", "world"} 67 | rerankedResults, err := rerankingFunction.Rerank(context.Background(), query, FromTexts(results)) 68 | require.NoError(t, err) 69 | require.NotNil(t, rerankedResults) 70 | require.Contains(t, rerankedResults, rerankingFunction.ID()) 71 | require.Equal(t, len(results), len(rerankedResults[rerankingFunction.ID()])) 72 | for _, result := range rerankedResults[rerankingFunction.ID()] { 73 | require.Equal(t, results[result.Index], result.String) 74 | } 75 | }) 76 | 77 | t.Run("Rerank chroma results", func(t *testing.T) { 78 | query := "hello world" 79 | results := 
&chromago.QueryResults{ 80 | Ids: [][]string{{"1"}, {"2"}}, 81 | Documents: [][]string{{"hello"}, {"world"}}, 82 | Distances: [][]float32{{0.1}, {0.2}}, 83 | QueryTexts: []string{query}, 84 | } 85 | rerankedResults, err := rerankingFunction.RerankResults(context.Background(), results) 86 | require.NoError(t, err) 87 | require.NotNil(t, rerankedResults) 88 | require.Contains(t, rerankedResults.Ranks, rerankingFunction.ID()) 89 | require.Equal(t, len(results.Ids), len(rerankedResults.Ids)) 90 | require.Equal(t, results.Ids, rerankedResults.Ids) 91 | require.Equal(t, results.Documents, rerankedResults.Documents) 92 | require.Equal(t, results.QueryTexts, rerankedResults.QueryTexts) 93 | }) 94 | } 95 | -------------------------------------------------------------------------------- /pkg/tokenizers/libtokenizers/tokenizers.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | struct EncodeOptions { 5 | bool add_special_token; 6 | bool return_type_ids; 7 | bool return_tokens; 8 | bool return_special_tokens_mask; 9 | bool return_attention_mask; 10 | bool return_offsets; 11 | }; 12 | 13 | struct TokenizerOptions { 14 | bool encode_special_tokens; 15 | }; 16 | 17 | struct Buffer { 18 | uint32_t *ids; 19 | uint32_t *type_ids; 20 | uint32_t *special_tokens_mask; 21 | uint32_t *attention_mask; 22 | char *tokens; 23 | size_t *offsets; 24 | uint32_t len; 25 | }; 26 | 27 | void *from_bytes(const uint8_t *config, uint32_t len, const struct TokenizerOptions *options); 28 | 29 | void *from_bytes_with_truncation(const uint8_t *config, uint32_t len, uint32_t max_len, uint8_t direction); 30 | 31 | void *from_file(const char *config); 32 | 33 | struct Buffer encode(void *ptr, const char *message, const struct EncodeOptions *options); 34 | 35 | char *decode(void *ptr, const uint32_t *ids, uint32_t len, bool skip_special_tokens); 36 | 37 | uint32_t vocab_size(void *ptr); 38 | 39 | void free_tokenizer(void *ptr); 40 | 41 | void 
#!/usr/bin/env bash

# Runs the chromadb/chroma:latest image in the foreground, publishing
# container port 8000 on host port 8000 with ALLOW_RESET=TRUE set in the
# container. The container is removed when it exits (--rm).
docker run --rm -it -p 8000:8000 -e ALLOW_RESET=TRUE chromadb/chroma:latest
45 | 46 | // GetName returns the Name field value 47 | func (o *Collection) GetName() string { 48 | if o == nil { 49 | var ret string 50 | return ret 51 | } 52 | 53 | return o.Name 54 | } 55 | 56 | // GetNameOk returns a tuple with the Name field value 57 | // and a boolean to check if the value has been set. 58 | func (o *Collection) GetNameOk() (*string, bool) { 59 | if o == nil { 60 | return nil, false 61 | } 62 | return &o.Name, true 63 | } 64 | 65 | // SetName sets field value 66 | func (o *Collection) SetName(v string) { 67 | o.Name = v 68 | } 69 | 70 | // GetId returns the Id field value 71 | func (o *Collection) GetId() string { 72 | if o == nil { 73 | var ret string 74 | return ret 75 | } 76 | 77 | return o.Id 78 | } 79 | 80 | // GetIdOk returns a tuple with the Id field value 81 | // and a boolean to check if the value has been set. 82 | func (o *Collection) GetIdOk() (*string, bool) { 83 | if o == nil { 84 | return nil, false 85 | } 86 | return &o.Id, true 87 | } 88 | 89 | // SetId sets field value 90 | func (o *Collection) SetId(v string) { 91 | o.Id = v 92 | } 93 | 94 | // GetMetadata returns the Metadata field value if set, zero value otherwise. 95 | func (o *Collection) GetMetadata() map[string]Metadata { 96 | if o == nil || IsNil(o.Metadata) { 97 | var ret map[string]Metadata 98 | return ret 99 | } 100 | return *o.Metadata 101 | } 102 | 103 | // GetMetadataOk returns a tuple with the Metadata field value if set, nil otherwise 104 | // and a boolean to check if the value has been set. 105 | func (o *Collection) GetMetadataOk() (*map[string]Metadata, bool) { 106 | if o == nil || IsNil(o.Metadata) { 107 | return nil, false 108 | } 109 | return o.Metadata, true 110 | } 111 | 112 | // HasMetadata returns a boolean if a field has been set. 
113 | func (o *Collection) HasMetadata() bool { 114 | if o != nil && !IsNil(o.Metadata) { 115 | return true 116 | } 117 | 118 | return false 119 | } 120 | 121 | // SetMetadata gets a reference to the given map[string]Metadata and assigns it to the Metadata field. 122 | func (o *Collection) SetMetadata(v map[string]Metadata) { 123 | o.Metadata = &v 124 | } 125 | 126 | func (o Collection) MarshalJSON() ([]byte, error) { 127 | toSerialize, err := o.ToMap() 128 | if err != nil { 129 | return []byte{}, err 130 | } 131 | return json.Marshal(toSerialize) 132 | } 133 | 134 | func (o Collection) ToMap() (map[string]interface{}, error) { 135 | toSerialize := map[string]interface{}{} 136 | toSerialize["name"] = o.Name 137 | toSerialize["id"] = o.Id 138 | if !IsNil(o.Metadata) { 139 | toSerialize["metadata"] = o.Metadata 140 | } 141 | return toSerialize, nil 142 | } 143 | 144 | type NullableCollection struct { 145 | value *Collection 146 | isSet bool 147 | } 148 | 149 | func (v NullableCollection) Get() *Collection { 150 | return v.value 151 | } 152 | 153 | func (v *NullableCollection) Set(val *Collection) { 154 | v.value = val 155 | v.isSet = true 156 | } 157 | 158 | func (v NullableCollection) IsSet() bool { 159 | return v.isSet 160 | } 161 | 162 | func (v *NullableCollection) Unset() { 163 | v.value = nil 164 | v.isSet = false 165 | } 166 | 167 | func NewNullableCollection(val *Collection) *NullableCollection { 168 | return &NullableCollection{value: val, isSet: true} 169 | } 170 | 171 | func (v NullableCollection) MarshalJSON() ([]byte, error) { 172 | return json.Marshal(v.value) 173 | } 174 | 175 | func (v *NullableCollection) UnmarshalJSON(src []byte) error { 176 | v.isSet = true 177 | return json.Unmarshal(src, &v.value) 178 | } 179 | -------------------------------------------------------------------------------- /swagger/model_create_database.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is 
OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | ) 16 | 17 | // checks if the CreateDatabase type satisfies the MappedNullable interface at compile time 18 | var _ MappedNullable = &CreateDatabase{} 19 | 20 | // CreateDatabase struct for CreateDatabase 21 | type CreateDatabase struct { 22 | Name string `json:"name"` 23 | } 24 | 25 | // NewCreateDatabase instantiates a new CreateDatabase object 26 | // This constructor will assign default values to properties that have it defined, 27 | // and makes sure properties required by API are set, but the set of arguments 28 | // will change when the set of required properties is changed 29 | func NewCreateDatabase(name string) *CreateDatabase { 30 | this := CreateDatabase{} 31 | this.Name = name 32 | return &this 33 | } 34 | 35 | // NewCreateDatabaseWithDefaults instantiates a new CreateDatabase object 36 | // This constructor will only assign default values to properties that have it defined, 37 | // but it doesn't guarantee that properties required by API are set 38 | func NewCreateDatabaseWithDefaults() *CreateDatabase { 39 | this := CreateDatabase{} 40 | return &this 41 | } 42 | 43 | // GetName returns the Name field value 44 | func (o *CreateDatabase) GetName() string { 45 | if o == nil { 46 | var ret string 47 | return ret 48 | } 49 | 50 | return o.Name 51 | } 52 | 53 | // GetNameOk returns a tuple with the Name field value 54 | // and a boolean to check if the value has been set. 
55 | func (o *CreateDatabase) GetNameOk() (*string, bool) { 56 | if o == nil { 57 | return nil, false 58 | } 59 | return &o.Name, true 60 | } 61 | 62 | // SetName sets field value 63 | func (o *CreateDatabase) SetName(v string) { 64 | o.Name = v 65 | } 66 | 67 | func (o CreateDatabase) MarshalJSON() ([]byte, error) { 68 | toSerialize, err := o.ToMap() 69 | if err != nil { 70 | return []byte{}, err 71 | } 72 | return json.Marshal(toSerialize) 73 | } 74 | 75 | func (o CreateDatabase) ToMap() (map[string]interface{}, error) { 76 | toSerialize := map[string]interface{}{} 77 | toSerialize["name"] = o.Name 78 | return toSerialize, nil 79 | } 80 | 81 | type NullableCreateDatabase struct { 82 | value *CreateDatabase 83 | isSet bool 84 | } 85 | 86 | func (v NullableCreateDatabase) Get() *CreateDatabase { 87 | return v.value 88 | } 89 | 90 | func (v *NullableCreateDatabase) Set(val *CreateDatabase) { 91 | v.value = val 92 | v.isSet = true 93 | } 94 | 95 | func (v NullableCreateDatabase) IsSet() bool { 96 | return v.isSet 97 | } 98 | 99 | func (v *NullableCreateDatabase) Unset() { 100 | v.value = nil 101 | v.isSet = false 102 | } 103 | 104 | func NewNullableCreateDatabase(val *CreateDatabase) *NullableCreateDatabase { 105 | return &NullableCreateDatabase{value: val, isSet: true} 106 | } 107 | 108 | func (v NullableCreateDatabase) MarshalJSON() ([]byte, error) { 109 | return json.Marshal(v.value) 110 | } 111 | 112 | func (v *NullableCreateDatabase) UnmarshalJSON(src []byte) error { 113 | v.isSet = true 114 | return json.Unmarshal(src, &v.value) 115 | } 116 | -------------------------------------------------------------------------------- /swagger/model_create_tenant.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 
10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | ) 16 | 17 | // checks if the CreateTenant type satisfies the MappedNullable interface at compile time 18 | var _ MappedNullable = &CreateTenant{} 19 | 20 | // CreateTenant struct for CreateTenant 21 | type CreateTenant struct { 22 | Name string `json:"name"` 23 | } 24 | 25 | // NewCreateTenant instantiates a new CreateTenant object 26 | // This constructor will assign default values to properties that have it defined, 27 | // and makes sure properties required by API are set, but the set of arguments 28 | // will change when the set of required properties is changed 29 | func NewCreateTenant(name string) *CreateTenant { 30 | this := CreateTenant{} 31 | this.Name = name 32 | return &this 33 | } 34 | 35 | // NewCreateTenantWithDefaults instantiates a new CreateTenant object 36 | // This constructor will only assign default values to properties that have it defined, 37 | // but it doesn't guarantee that properties required by API are set 38 | func NewCreateTenantWithDefaults() *CreateTenant { 39 | this := CreateTenant{} 40 | return &this 41 | } 42 | 43 | // GetName returns the Name field value 44 | func (o *CreateTenant) GetName() string { 45 | if o == nil { 46 | var ret string 47 | return ret 48 | } 49 | 50 | return o.Name 51 | } 52 | 53 | // GetNameOk returns a tuple with the Name field value 54 | // and a boolean to check if the value has been set. 
55 | func (o *CreateTenant) GetNameOk() (*string, bool) { 56 | if o == nil { 57 | return nil, false 58 | } 59 | return &o.Name, true 60 | } 61 | 62 | // SetName sets field value 63 | func (o *CreateTenant) SetName(v string) { 64 | o.Name = v 65 | } 66 | 67 | func (o CreateTenant) MarshalJSON() ([]byte, error) { 68 | toSerialize, err := o.ToMap() 69 | if err != nil { 70 | return []byte{}, err 71 | } 72 | return json.Marshal(toSerialize) 73 | } 74 | 75 | func (o CreateTenant) ToMap() (map[string]interface{}, error) { 76 | toSerialize := map[string]interface{}{} 77 | toSerialize["name"] = o.Name 78 | return toSerialize, nil 79 | } 80 | 81 | type NullableCreateTenant struct { 82 | value *CreateTenant 83 | isSet bool 84 | } 85 | 86 | func (v NullableCreateTenant) Get() *CreateTenant { 87 | return v.value 88 | } 89 | 90 | func (v *NullableCreateTenant) Set(val *CreateTenant) { 91 | v.value = val 92 | v.isSet = true 93 | } 94 | 95 | func (v NullableCreateTenant) IsSet() bool { 96 | return v.isSet 97 | } 98 | 99 | func (v *NullableCreateTenant) Unset() { 100 | v.value = nil 101 | v.isSet = false 102 | } 103 | 104 | func NewNullableCreateTenant(val *CreateTenant) *NullableCreateTenant { 105 | return &NullableCreateTenant{value: val, isSet: true} 106 | } 107 | 108 | func (v NullableCreateTenant) MarshalJSON() ([]byte, error) { 109 | return json.Marshal(v.value) 110 | } 111 | 112 | func (v *NullableCreateTenant) UnmarshalJSON(src []byte) error { 113 | v.isSet = true 114 | return json.Unmarshal(src, &v.value) 115 | } 116 | -------------------------------------------------------------------------------- /swagger/model_embeddings_inner.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 
10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | "fmt" 16 | ) 17 | 18 | // EmbeddingsInner struct for EmbeddingsInner 19 | type EmbeddingsInner struct { 20 | ArrayOfFloat32 *[]float32 21 | ArrayOfInt32 *[]int32 22 | } 23 | 24 | // Unmarshal JSON data into any of the pointers in the struct 25 | func (dst *EmbeddingsInner) UnmarshalJSON(data []byte) error { 26 | var err error 27 | // try to unmarshal JSON data into ArrayOfFloat32 28 | err = json.Unmarshal(data, &dst.ArrayOfFloat32) 29 | if err == nil { 30 | jsonArrayOfFloat32, _ := json.Marshal(dst.ArrayOfFloat32) 31 | if string(jsonArrayOfFloat32) == "{}" { // empty struct 32 | dst.ArrayOfFloat32 = nil 33 | } else { 34 | return nil // data stored in dst.ArrayOfFloat32, return on the first match 35 | } 36 | } else { 37 | dst.ArrayOfFloat32 = nil 38 | } 39 | 40 | // try to unmarshal JSON data into ArrayOfInt32 41 | err = json.Unmarshal(data, &dst.ArrayOfInt32) 42 | if err == nil { 43 | jsonArrayOfInt32, _ := json.Marshal(dst.ArrayOfInt32) 44 | if string(jsonArrayOfInt32) == "{}" { // empty struct 45 | dst.ArrayOfInt32 = nil 46 | } else { 47 | return nil // data stored in dst.ArrayOfInt32, return on the first match 48 | } 49 | } else { 50 | dst.ArrayOfInt32 = nil 51 | } 52 | 53 | return fmt.Errorf("data failed to match schemas in anyOf(EmbeddingsInner)") 54 | } 55 | 56 | // Marshal data from the first non-nil pointers in the struct to JSON 57 | func (src *EmbeddingsInner) MarshalJSON() ([]byte, error) { 58 | if src.ArrayOfFloat32 != nil { 59 | return json.Marshal(&src.ArrayOfFloat32) 60 | } 61 | 62 | if src.ArrayOfInt32 != nil { 63 | return json.Marshal(&src.ArrayOfInt32) 64 | } 65 | 66 | return nil, nil // no data in anyOf schemas 67 | } 68 | 69 | type NullableEmbeddingsInner struct { 70 | value *EmbeddingsInner 71 | isSet bool 72 | } 73 | 74 | func (v NullableEmbeddingsInner) Get() *EmbeddingsInner { 75 | return v.value 76 | } 77 | 78 | func (v *NullableEmbeddingsInner) Set(val *EmbeddingsInner) 
{ 79 | v.value = val 80 | v.isSet = true 81 | } 82 | 83 | func (v NullableEmbeddingsInner) IsSet() bool { 84 | return v.isSet 85 | } 86 | 87 | func (v *NullableEmbeddingsInner) Unset() { 88 | v.value = nil 89 | v.isSet = false 90 | } 91 | 92 | func NewNullableEmbeddingsInner(val *EmbeddingsInner) *NullableEmbeddingsInner { 93 | return &NullableEmbeddingsInner{value: val, isSet: true} 94 | } 95 | 96 | func (v NullableEmbeddingsInner) MarshalJSON() ([]byte, error) { 97 | return json.Marshal(v.value) 98 | } 99 | 100 | func (v *NullableEmbeddingsInner) UnmarshalJSON(src []byte) error { 101 | v.isSet = true 102 | return json.Unmarshal(src, &v.value) 103 | } 104 | -------------------------------------------------------------------------------- /swagger/model_http_validation_error.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 
10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | ) 16 | 17 | // checks if the HTTPValidationError type satisfies the MappedNullable interface at compile time 18 | var _ MappedNullable = &HTTPValidationError{} 19 | 20 | // HTTPValidationError struct for HTTPValidationError 21 | type HTTPValidationError struct { 22 | Detail []ValidationError `json:"detail,omitempty"` 23 | } 24 | 25 | // NewHTTPValidationError instantiates a new HTTPValidationError object 26 | // This constructor will assign default values to properties that have it defined, 27 | // and makes sure properties required by API are set, but the set of arguments 28 | // will change when the set of required properties is changed 29 | func NewHTTPValidationError() *HTTPValidationError { 30 | this := HTTPValidationError{} 31 | return &this 32 | } 33 | 34 | // NewHTTPValidationErrorWithDefaults instantiates a new HTTPValidationError object 35 | // This constructor will only assign default values to properties that have it defined, 36 | // but it doesn't guarantee that properties required by API are set 37 | func NewHTTPValidationErrorWithDefaults() *HTTPValidationError { 38 | this := HTTPValidationError{} 39 | return &this 40 | } 41 | 42 | // GetDetail returns the Detail field value if set, zero value otherwise. 43 | func (o *HTTPValidationError) GetDetail() []ValidationError { 44 | if o == nil || IsNil(o.Detail) { 45 | var ret []ValidationError 46 | return ret 47 | } 48 | return o.Detail 49 | } 50 | 51 | // GetDetailOk returns a tuple with the Detail field value if set, nil otherwise 52 | // and a boolean to check if the value has been set. 53 | func (o *HTTPValidationError) GetDetailOk() ([]ValidationError, bool) { 54 | if o == nil || IsNil(o.Detail) { 55 | return nil, false 56 | } 57 | return o.Detail, true 58 | } 59 | 60 | // HasDetail returns a boolean if a field has been set. 
61 | func (o *HTTPValidationError) HasDetail() bool { 62 | if o != nil && !IsNil(o.Detail) { 63 | return true 64 | } 65 | 66 | return false 67 | } 68 | 69 | // SetDetail gets a reference to the given []ValidationError and assigns it to the Detail field. 70 | func (o *HTTPValidationError) SetDetail(v []ValidationError) { 71 | o.Detail = v 72 | } 73 | 74 | func (o HTTPValidationError) MarshalJSON() ([]byte, error) { 75 | toSerialize, err := o.ToMap() 76 | if err != nil { 77 | return []byte{}, err 78 | } 79 | return json.Marshal(toSerialize) 80 | } 81 | 82 | func (o HTTPValidationError) ToMap() (map[string]interface{}, error) { 83 | toSerialize := map[string]interface{}{} 84 | if !IsNil(o.Detail) { 85 | toSerialize["detail"] = o.Detail 86 | } 87 | return toSerialize, nil 88 | } 89 | 90 | type NullableHTTPValidationError struct { 91 | value *HTTPValidationError 92 | isSet bool 93 | } 94 | 95 | func (v NullableHTTPValidationError) Get() *HTTPValidationError { 96 | return v.value 97 | } 98 | 99 | func (v *NullableHTTPValidationError) Set(val *HTTPValidationError) { 100 | v.value = val 101 | v.isSet = true 102 | } 103 | 104 | func (v NullableHTTPValidationError) IsSet() bool { 105 | return v.isSet 106 | } 107 | 108 | func (v *NullableHTTPValidationError) Unset() { 109 | v.value = nil 110 | v.isSet = false 111 | } 112 | 113 | func NewNullableHTTPValidationError(val *HTTPValidationError) *NullableHTTPValidationError { 114 | return &NullableHTTPValidationError{value: val, isSet: true} 115 | } 116 | 117 | func (v NullableHTTPValidationError) MarshalJSON() ([]byte, error) { 118 | return json.Marshal(v.value) 119 | } 120 | 121 | func (v *NullableHTTPValidationError) UnmarshalJSON(src []byte) error { 122 | v.isSet = true 123 | return json.Unmarshal(src, &v.value) 124 | } 125 | -------------------------------------------------------------------------------- /swagger/model_include_inner.go: -------------------------------------------------------------------------------- 1 | /* 2 | 
ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | "fmt" 16 | ) 17 | 18 | // IncludeInner struct for IncludeInner 19 | type IncludeInner struct { 20 | String *string 21 | } 22 | 23 | // Unmarshal JSON data into any of the pointers in the struct 24 | func (dst *IncludeInner) UnmarshalJSON(data []byte) error { 25 | var err error 26 | // try to unmarshal JSON data into String 27 | err = json.Unmarshal(data, &dst.String) 28 | if err == nil { 29 | jsonString, _ := json.Marshal(dst.String) 30 | if string(jsonString) == "{}" { // empty struct 31 | dst.String = nil 32 | } else { 33 | return nil // data stored in dst.String, return on the first match 34 | } 35 | } else { 36 | dst.String = nil 37 | } 38 | 39 | return fmt.Errorf("data failed to match schemas in anyOf(IncludeInner)") 40 | } 41 | 42 | // Marshal data from the first non-nil pointers in the struct to JSON 43 | func (src *IncludeInner) MarshalJSON() ([]byte, error) { 44 | if src.String != nil { 45 | return json.Marshal(&src.String) 46 | } 47 | 48 | return nil, nil // no data in anyOf schemas 49 | } 50 | 51 | type NullableIncludeInner struct { 52 | value *IncludeInner 53 | isSet bool 54 | } 55 | 56 | func (v NullableIncludeInner) Get() *IncludeInner { 57 | return v.value 58 | } 59 | 60 | func (v *NullableIncludeInner) Set(val *IncludeInner) { 61 | v.value = val 62 | v.isSet = true 63 | } 64 | 65 | func (v NullableIncludeInner) IsSet() bool { 66 | return v.isSet 67 | } 68 | 69 | func (v *NullableIncludeInner) Unset() { 70 | v.value = nil 71 | v.isSet = false 72 | } 73 | 74 | func NewNullableIncludeInner(val *IncludeInner) *NullableIncludeInner { 75 | return &NullableIncludeInner{value: val, isSet: true} 76 | } 77 | 78 | func (v NullableIncludeInner) MarshalJSON() ([]byte, error) { 79 | return 
json.Marshal(v.value) 80 | } 81 | 82 | func (v *NullableIncludeInner) UnmarshalJSON(src []byte) error { 83 | v.isSet = true 84 | return json.Unmarshal(src, &v.value) 85 | } 86 | -------------------------------------------------------------------------------- /swagger/model_location_inner.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | "fmt" 16 | ) 17 | 18 | // LocationInner struct for LocationInner 19 | type LocationInner struct { 20 | Int32 *int32 21 | String *string 22 | } 23 | 24 | // Unmarshal JSON data into any of the pointers in the struct 25 | func (dst *LocationInner) UnmarshalJSON(data []byte) error { 26 | var err error 27 | // try to unmarshal JSON data into Int32 28 | err = json.Unmarshal(data, &dst.Int32) 29 | if err == nil { 30 | jsonInt32, _ := json.Marshal(dst.Int32) 31 | if string(jsonInt32) == "{}" { // empty struct 32 | dst.Int32 = nil 33 | } else { 34 | return nil // data stored in dst.Int32, return on the first match 35 | } 36 | } else { 37 | dst.Int32 = nil 38 | } 39 | 40 | // try to unmarshal JSON data into String 41 | err = json.Unmarshal(data, &dst.String) 42 | if err == nil { 43 | jsonString, _ := json.Marshal(dst.String) 44 | if string(jsonString) == "{}" { // empty struct 45 | dst.String = nil 46 | } else { 47 | return nil // data stored in dst.String, return on the first match 48 | } 49 | } else { 50 | dst.String = nil 51 | } 52 | 53 | return fmt.Errorf("data failed to match schemas in anyOf(LocationInner)") 54 | } 55 | 56 | // Marshal data from the first non-nil pointers in the struct to JSON 57 | func (src *LocationInner) MarshalJSON() ([]byte, error) { 58 | if src.Int32 != nil { 59 | return json.Marshal(&src.Int32) 60 | } 61 | 62 | 
if src.String != nil { 63 | return json.Marshal(&src.String) 64 | } 65 | 66 | return nil, nil // no data in anyOf schemas 67 | } 68 | 69 | type NullableLocationInner struct { 70 | value *LocationInner 71 | isSet bool 72 | } 73 | 74 | func (v NullableLocationInner) Get() *LocationInner { 75 | return v.value 76 | } 77 | 78 | func (v *NullableLocationInner) Set(val *LocationInner) { 79 | v.value = val 80 | v.isSet = true 81 | } 82 | 83 | func (v NullableLocationInner) IsSet() bool { 84 | return v.isSet 85 | } 86 | 87 | func (v *NullableLocationInner) Unset() { 88 | v.value = nil 89 | v.isSet = false 90 | } 91 | 92 | func NewNullableLocationInner(val *LocationInner) *NullableLocationInner { 93 | return &NullableLocationInner{value: val, isSet: true} 94 | } 95 | 96 | func (v NullableLocationInner) MarshalJSON() ([]byte, error) { 97 | return json.Marshal(v.value) 98 | } 99 | 100 | func (v *NullableLocationInner) UnmarshalJSON(src []byte) error { 101 | v.isSet = true 102 | return json.Unmarshal(src, &v.value) 103 | } 104 | -------------------------------------------------------------------------------- /swagger/model_metadata.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 
10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | "fmt" 16 | ) 17 | 18 | // Metadata struct for Metadata 19 | type Metadata struct { 20 | Bool *bool 21 | Float32 *float32 22 | Int32 *int32 23 | String *string 24 | } 25 | 26 | // Unmarshal JSON data into any of the pointers in the struct 27 | func (dst *Metadata) UnmarshalJSON(data []byte) error { 28 | var err error 29 | // try to unmarshal JSON data into Bool 30 | err = json.Unmarshal(data, &dst.Bool) 31 | if err == nil { 32 | jsonBool, _ := json.Marshal(dst.Bool) 33 | if string(jsonBool) == "{}" { // empty struct 34 | dst.Bool = nil 35 | } else { 36 | return nil // data stored in dst.Bool, return on the first match 37 | } 38 | } else { 39 | dst.Bool = nil 40 | } 41 | 42 | // try to unmarshal JSON data into Int32 43 | err = json.Unmarshal(data, &dst.Int32) 44 | if err == nil { 45 | jsonInt32, _ := json.Marshal(dst.Int32) 46 | if string(jsonInt32) == "{}" { // empty struct 47 | dst.Int32 = nil 48 | } else { 49 | return nil // data stored in dst.Int32, return on the first match 50 | } 51 | } else { 52 | dst.Int32 = nil 53 | } 54 | // try to unmarshal JSON data into Float32 55 | err = json.Unmarshal(data, &dst.Float32) 56 | if err == nil { 57 | jsonFloat32, _ := json.Marshal(dst.Float32) 58 | if string(jsonFloat32) == "{}" { // empty struct 59 | dst.Float32 = nil 60 | } else { 61 | return nil // data stored in dst.Float32, return on the first match 62 | } 63 | } else { 64 | dst.Float32 = nil 65 | } 66 | 67 | // try to unmarshal JSON data into String 68 | err = json.Unmarshal(data, &dst.String) 69 | if err == nil { 70 | jsonString, _ := json.Marshal(dst.String) 71 | if string(jsonString) == "{}" { // empty struct 72 | dst.String = nil 73 | } else { 74 | return nil // data stored in dst.String, return on the first match 75 | } 76 | } else { 77 | dst.String = nil 78 | } 79 | 80 | return fmt.Errorf("data failed to match schemas in anyOf(Metadata)") 81 | } 82 | 83 | // Marshal data from the first 
non-nil pointers in the struct to JSON 84 | func (src *Metadata) MarshalJSON() ([]byte, error) { 85 | if src.Bool != nil { 86 | return json.Marshal(&src.Bool) 87 | } 88 | 89 | if src.Float32 != nil { 90 | return json.Marshal(&src.Float32) 91 | } 92 | 93 | if src.Int32 != nil { 94 | return json.Marshal(&src.Int32) 95 | } 96 | 97 | if src.String != nil { 98 | return json.Marshal(&src.String) 99 | } 100 | 101 | return nil, nil // no data in anyOf schemas 102 | } 103 | 104 | type NullableMetadata struct { 105 | value *Metadata 106 | isSet bool 107 | } 108 | 109 | func (v NullableMetadata) Get() *Metadata { 110 | return v.value 111 | } 112 | 113 | func (v *NullableMetadata) Set(val *Metadata) { 114 | v.value = val 115 | v.isSet = true 116 | } 117 | 118 | func (v NullableMetadata) IsSet() bool { 119 | return v.isSet 120 | } 121 | 122 | func (v *NullableMetadata) Unset() { 123 | v.value = nil 124 | v.isSet = false 125 | } 126 | 127 | func NewNullableMetadata(val *Metadata) *NullableMetadata { 128 | return &NullableMetadata{value: val, isSet: true} 129 | } 130 | 131 | func (v NullableMetadata) MarshalJSON() ([]byte, error) { 132 | return json.Marshal(v.value) 133 | } 134 | 135 | func (v *NullableMetadata) UnmarshalJSON(src []byte) error { 136 | v.isSet = true 137 | return json.Unmarshal(src, &v.value) 138 | } 139 | -------------------------------------------------------------------------------- /swagger/model_tenant.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 
10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | ) 16 | 17 | // checks if the Tenant type satisfies the MappedNullable interface at compile time 18 | var _ MappedNullable = &Tenant{} 19 | 20 | // Tenant struct for Tenant 21 | type Tenant struct { 22 | Name *string `json:"name,omitempty"` 23 | } 24 | 25 | // NewTenant instantiates a new Tenant object 26 | // This constructor will assign default values to properties that have it defined, 27 | // and makes sure properties required by API are set, but the set of arguments 28 | // will change when the set of required properties is changed 29 | func NewTenant() *Tenant { 30 | this := Tenant{} 31 | return &this 32 | } 33 | 34 | // NewTenantWithDefaults instantiates a new Tenant object 35 | // This constructor will only assign default values to properties that have it defined, 36 | // but it doesn't guarantee that properties required by API are set 37 | func NewTenantWithDefaults() *Tenant { 38 | this := Tenant{} 39 | return &this 40 | } 41 | 42 | // GetName returns the Name field value if set, zero value otherwise. 43 | func (o *Tenant) GetName() string { 44 | if o == nil || IsNil(o.Name) { 45 | var ret string 46 | return ret 47 | } 48 | return *o.Name 49 | } 50 | 51 | // GetNameOk returns a tuple with the Name field value if set, nil otherwise 52 | // and a boolean to check if the value has been set. 53 | func (o *Tenant) GetNameOk() (*string, bool) { 54 | if o == nil || IsNil(o.Name) { 55 | return nil, false 56 | } 57 | return o.Name, true 58 | } 59 | 60 | // HasName returns a boolean if a field has been set. 61 | func (o *Tenant) HasName() bool { 62 | if o != nil && !IsNil(o.Name) { 63 | return true 64 | } 65 | 66 | return false 67 | } 68 | 69 | // SetName gets a reference to the given string and assigns it to the Name field. 
70 | func (o *Tenant) SetName(v string) { 71 | o.Name = &v 72 | } 73 | 74 | func (o Tenant) MarshalJSON() ([]byte, error) { 75 | toSerialize, err := o.ToMap() 76 | if err != nil { 77 | return []byte{}, err 78 | } 79 | return json.Marshal(toSerialize) 80 | } 81 | 82 | func (o Tenant) ToMap() (map[string]interface{}, error) { 83 | toSerialize := map[string]interface{}{} 84 | if !IsNil(o.Name) { 85 | toSerialize["name"] = o.Name 86 | } 87 | return toSerialize, nil 88 | } 89 | 90 | type NullableTenant struct { 91 | value *Tenant 92 | isSet bool 93 | } 94 | 95 | func (v NullableTenant) Get() *Tenant { 96 | return v.value 97 | } 98 | 99 | func (v *NullableTenant) Set(val *Tenant) { 100 | v.value = val 101 | v.isSet = true 102 | } 103 | 104 | func (v NullableTenant) IsSet() bool { 105 | return v.isSet 106 | } 107 | 108 | func (v *NullableTenant) Unset() { 109 | v.value = nil 110 | v.isSet = false 111 | } 112 | 113 | func NewNullableTenant(val *Tenant) *NullableTenant { 114 | return &NullableTenant{value: val, isSet: true} 115 | } 116 | 117 | func (v NullableTenant) MarshalJSON() ([]byte, error) { 118 | return json.Marshal(v.value) 119 | } 120 | 121 | func (v *NullableTenant) UnmarshalJSON(src []byte) error { 122 | v.isSet = true 123 | return json.Unmarshal(src, &v.value) 124 | } 125 | -------------------------------------------------------------------------------- /swagger/model_update_collection.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 
10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | ) 16 | 17 | // checks if the UpdateCollection type satisfies the MappedNullable interface at compile time 18 | var _ MappedNullable = &UpdateCollection{} 19 | 20 | // UpdateCollection struct for UpdateCollection 21 | type UpdateCollection struct { 22 | NewName *string `json:"new_name,omitempty"` 23 | NewMetadata map[string]interface{} `json:"new_metadata,omitempty"` 24 | } 25 | 26 | // NewUpdateCollection instantiates a new UpdateCollection object 27 | // This constructor will assign default values to properties that have it defined, 28 | // and makes sure properties required by API are set, but the set of arguments 29 | // will change when the set of required properties is changed 30 | func NewUpdateCollection() *UpdateCollection { 31 | this := UpdateCollection{} 32 | return &this 33 | } 34 | 35 | // NewUpdateCollectionWithDefaults instantiates a new UpdateCollection object 36 | // This constructor will only assign default values to properties that have it defined, 37 | // but it doesn't guarantee that properties required by API are set 38 | func NewUpdateCollectionWithDefaults() *UpdateCollection { 39 | this := UpdateCollection{} 40 | return &this 41 | } 42 | 43 | // GetNewName returns the NewName field value if set, zero value otherwise. 44 | func (o *UpdateCollection) GetNewName() string { 45 | if o == nil || IsNil(o.NewName) { 46 | var ret string 47 | return ret 48 | } 49 | return *o.NewName 50 | } 51 | 52 | // GetNewNameOk returns a tuple with the NewName field value if set, nil otherwise 53 | // and a boolean to check if the value has been set. 54 | func (o *UpdateCollection) GetNewNameOk() (*string, bool) { 55 | if o == nil || IsNil(o.NewName) { 56 | return nil, false 57 | } 58 | return o.NewName, true 59 | } 60 | 61 | // HasNewName returns a boolean if a field has been set. 
62 | func (o *UpdateCollection) HasNewName() bool { 63 | if o != nil && !IsNil(o.NewName) { 64 | return true 65 | } 66 | 67 | return false 68 | } 69 | 70 | // SetNewName gets a reference to the given string and assigns it to the NewName field. 71 | func (o *UpdateCollection) SetNewName(v string) { 72 | o.NewName = &v 73 | } 74 | 75 | // GetNewMetadata returns the NewMetadata field value if set, zero value otherwise. 76 | func (o *UpdateCollection) GetNewMetadata() map[string]interface{} { 77 | if o == nil || IsNil(o.NewMetadata) { 78 | var ret map[string]interface{} 79 | return ret 80 | } 81 | return o.NewMetadata 82 | } 83 | 84 | // GetNewMetadataOk returns a tuple with the NewMetadata field value if set, nil otherwise 85 | // and a boolean to check if the value has been set. 86 | func (o *UpdateCollection) GetNewMetadataOk() (map[string]interface{}, bool) { 87 | if o == nil || IsNil(o.NewMetadata) { 88 | return map[string]interface{}{}, false 89 | } 90 | return o.NewMetadata, true 91 | } 92 | 93 | // HasNewMetadata returns a boolean if a field has been set. 94 | func (o *UpdateCollection) HasNewMetadata() bool { 95 | if o != nil && !IsNil(o.NewMetadata) { 96 | return true 97 | } 98 | 99 | return false 100 | } 101 | 102 | // SetNewMetadata gets a reference to the given map[string]interface{} and assigns it to the NewMetadata field. 
103 | func (o *UpdateCollection) SetNewMetadata(v map[string]interface{}) { 104 | o.NewMetadata = v 105 | } 106 | 107 | func (o UpdateCollection) MarshalJSON() ([]byte, error) { 108 | toSerialize, err := o.ToMap() 109 | if err != nil { 110 | return []byte{}, err 111 | } 112 | return json.Marshal(toSerialize) 113 | } 114 | 115 | func (o UpdateCollection) ToMap() (map[string]interface{}, error) { 116 | toSerialize := map[string]interface{}{} 117 | if !IsNil(o.NewName) { 118 | toSerialize["new_name"] = o.NewName 119 | } 120 | if !IsNil(o.NewMetadata) { 121 | toSerialize["new_metadata"] = o.NewMetadata 122 | } 123 | return toSerialize, nil 124 | } 125 | 126 | type NullableUpdateCollection struct { 127 | value *UpdateCollection 128 | isSet bool 129 | } 130 | 131 | func (v NullableUpdateCollection) Get() *UpdateCollection { 132 | return v.value 133 | } 134 | 135 | func (v *NullableUpdateCollection) Set(val *UpdateCollection) { 136 | v.value = val 137 | v.isSet = true 138 | } 139 | 140 | func (v NullableUpdateCollection) IsSet() bool { 141 | return v.isSet 142 | } 143 | 144 | func (v *NullableUpdateCollection) Unset() { 145 | v.value = nil 146 | v.isSet = false 147 | } 148 | 149 | func NewNullableUpdateCollection(val *UpdateCollection) *NullableUpdateCollection { 150 | return &NullableUpdateCollection{value: val, isSet: true} 151 | } 152 | 153 | func (v NullableUpdateCollection) MarshalJSON() ([]byte, error) { 154 | return json.Marshal(v.value) 155 | } 156 | 157 | func (v *NullableUpdateCollection) UnmarshalJSON(src []byte) error { 158 | v.isSet = true 159 | return json.Unmarshal(src, &v.value) 160 | } 161 | -------------------------------------------------------------------------------- /swagger/model_validation_error.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 
5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 10 | 11 | package openapi 12 | 13 | import ( 14 | "encoding/json" 15 | ) 16 | 17 | // checks if the ValidationError type satisfies the MappedNullable interface at compile time 18 | var _ MappedNullable = &ValidationError{} 19 | 20 | // ValidationError struct for ValidationError 21 | type ValidationError struct { 22 | Loc []LocationInner `json:"loc"` 23 | Msg string `json:"msg"` 24 | Type string `json:"type"` 25 | } 26 | 27 | // NewValidationError instantiates a new ValidationError object 28 | // This constructor will assign default values to properties that have it defined, 29 | // and makes sure properties required by API are set, but the set of arguments 30 | // will change when the set of required properties is changed 31 | func NewValidationError(loc []LocationInner, msg string, type_ string) *ValidationError { 32 | this := ValidationError{} 33 | this.Loc = loc 34 | this.Msg = msg 35 | this.Type = type_ 36 | return &this 37 | } 38 | 39 | // NewValidationErrorWithDefaults instantiates a new ValidationError object 40 | // This constructor will only assign default values to properties that have it defined, 41 | // but it doesn't guarantee that properties required by API are set 42 | func NewValidationErrorWithDefaults() *ValidationError { 43 | this := ValidationError{} 44 | return &this 45 | } 46 | 47 | // GetLoc returns the Loc field value 48 | func (o *ValidationError) GetLoc() []LocationInner { 49 | if o == nil { 50 | var ret []LocationInner 51 | return ret 52 | } 53 | 54 | return o.Loc 55 | } 56 | 57 | // GetLocOk returns a tuple with the Loc field value 58 | // and a boolean to check if the value has been set. 
59 | func (o *ValidationError) GetLocOk() ([]LocationInner, bool) { 60 | if o == nil { 61 | return nil, false 62 | } 63 | return o.Loc, true 64 | } 65 | 66 | // SetLoc sets field value 67 | func (o *ValidationError) SetLoc(v []LocationInner) { 68 | o.Loc = v 69 | } 70 | 71 | // GetMsg returns the Msg field value 72 | func (o *ValidationError) GetMsg() string { 73 | if o == nil { 74 | var ret string 75 | return ret 76 | } 77 | 78 | return o.Msg 79 | } 80 | 81 | // GetMsgOk returns a tuple with the Msg field value 82 | // and a boolean to check if the value has been set. 83 | func (o *ValidationError) GetMsgOk() (*string, bool) { 84 | if o == nil { 85 | return nil, false 86 | } 87 | return &o.Msg, true 88 | } 89 | 90 | // SetMsg sets field value 91 | func (o *ValidationError) SetMsg(v string) { 92 | o.Msg = v 93 | } 94 | 95 | // GetType returns the Type field value 96 | func (o *ValidationError) GetType() string { 97 | if o == nil { 98 | var ret string 99 | return ret 100 | } 101 | 102 | return o.Type 103 | } 104 | 105 | // GetTypeOk returns a tuple with the Type field value 106 | // and a boolean to check if the value has been set. 
107 | func (o *ValidationError) GetTypeOk() (*string, bool) { 108 | if o == nil { 109 | return nil, false 110 | } 111 | return &o.Type, true 112 | } 113 | 114 | // SetType sets field value 115 | func (o *ValidationError) SetType(v string) { 116 | o.Type = v 117 | } 118 | 119 | func (o ValidationError) MarshalJSON() ([]byte, error) { 120 | toSerialize, err := o.ToMap() 121 | if err != nil { 122 | return []byte{}, err 123 | } 124 | return json.Marshal(toSerialize) 125 | } 126 | 127 | func (o ValidationError) ToMap() (map[string]interface{}, error) { 128 | toSerialize := map[string]interface{}{} 129 | toSerialize["loc"] = o.Loc 130 | toSerialize["msg"] = o.Msg 131 | toSerialize["type"] = o.Type 132 | return toSerialize, nil 133 | } 134 | 135 | type NullableValidationError struct { 136 | value *ValidationError 137 | isSet bool 138 | } 139 | 140 | func (v NullableValidationError) Get() *ValidationError { 141 | return v.value 142 | } 143 | 144 | func (v *NullableValidationError) Set(val *ValidationError) { 145 | v.value = val 146 | v.isSet = true 147 | } 148 | 149 | func (v NullableValidationError) IsSet() bool { 150 | return v.isSet 151 | } 152 | 153 | func (v *NullableValidationError) Unset() { 154 | v.value = nil 155 | v.isSet = false 156 | } 157 | 158 | func NewNullableValidationError(val *ValidationError) *NullableValidationError { 159 | return &NullableValidationError{value: val, isSet: true} 160 | } 161 | 162 | func (v NullableValidationError) MarshalJSON() ([]byte, error) { 163 | return json.Marshal(v.value) 164 | } 165 | 166 | func (v *NullableValidationError) UnmarshalJSON(src []byte) error { 167 | v.isSet = true 168 | return json.Unmarshal(src, &v.value) 169 | } 170 | -------------------------------------------------------------------------------- /swagger/response.go: -------------------------------------------------------------------------------- 1 | /* 2 | ChromaDB API 3 | 4 | This is OpenAPI schema for ChromaDB API. 
5 | 6 | API version: 1.0.0 7 | */ 8 | 9 | // Code generated by OpenAPI Generator (https://openapi-generator.tech); DO NOT EDIT. 10 | 11 | package openapi 12 | 13 | import ( 14 | "net/http" 15 | ) 16 | 17 | // APIResponse stores the API response returned by the server. 18 | type APIResponse struct { 19 | *http.Response `json:"-"` 20 | Message string `json:"message,omitempty"` 21 | // Operation is the name of the OpenAPI operation. 22 | Operation string `json:"operation,omitempty"` 23 | // RequestURL is the request URL. This value is always available, even if the 24 | // embedded *http.Response is nil. 25 | RequestURL string `json:"url,omitempty"` 26 | // Method is the HTTP method used for the request. This value is always 27 | // available, even if the embedded *http.Response is nil. 28 | Method string `json:"method,omitempty"` 29 | // Payload holds the contents of the response body (which may be nil or empty). 30 | // This is provided here as the raw response.Body() reader will have already 31 | // been drained. 32 | Payload []byte `json:"-"` 33 | } 34 | 35 | // NewAPIResponse returns a new APIResponse object. 36 | func NewAPIResponse(r *http.Response) *APIResponse { 37 | 38 | response := &APIResponse{Response: r} 39 | return response 40 | } 41 | 42 | // NewAPIResponseWithError returns a new APIResponse object with the provided error message. 
43 | func NewAPIResponseWithError(errorMessage string) *APIResponse { 44 | 45 | response := &APIResponse{Message: errorMessage} 46 | return response 47 | } 48 | -------------------------------------------------------------------------------- /types/record_test.go: -------------------------------------------------------------------------------- 1 | //go:build basic 2 | 3 | package types 4 | 5 | import ( 6 | "context" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestNewRecordSet(t *testing.T) { 13 | t.Run("Test NewRecordSet", func(t *testing.T) { 14 | recordSet, err := NewRecordSet() 15 | require.NoError(t, err) 16 | require.NotNil(t, recordSet) 17 | require.NotNil(t, recordSet.Records) 18 | require.Equal(t, len(recordSet.Records), 0) 19 | }) 20 | t.Run("Test NewRecordSet with options", func(t *testing.T) { 21 | recordSet, err := NewRecordSet(WithIDGenerator(NewULIDGenerator()), WithEmbeddingFunction(NewConsistentHashEmbeddingFunction())) 22 | require.NoError(t, err) 23 | require.NotNil(t, recordSet) 24 | require.NotNil(t, recordSet.IDGenerator) 25 | require.NotNil(t, recordSet.EmbeddingFunction) 26 | require.NotNil(t, recordSet.Records) 27 | require.Equal(t, len(recordSet.Records), 0) 28 | }) 29 | 30 | t.Run("Test NewRecordSet with IDGenerator", func(t *testing.T) { 31 | recordSet, err := NewRecordSet(WithIDGenerator(NewULIDGenerator())) 32 | recordSet.WithRecord(WithDocument("test document")) 33 | require.NoError(t, err) 34 | require.NotNil(t, recordSet) 35 | require.NotNil(t, recordSet.IDGenerator) 36 | require.Nil(t, recordSet.EmbeddingFunction) 37 | require.NotNil(t, recordSet.Records) 38 | require.Equal(t, len(recordSet.Records), 1) 39 | require.NotNil(t, recordSet.Records[0].ID) 40 | }) 41 | 42 | t.Run("Test NewRecordSet with EmbeddingFunction", func(t *testing.T) { 43 | recordSet, err := NewRecordSet(WithEmbeddingFunction(NewConsistentHashEmbeddingFunction())) 44 | require.NoError(t, err) 45 | 
recordSet.WithRecord(WithDocument("test document"), WithID("1")) 46 | _, err = recordSet.BuildAndValidate(context.TODO()) 47 | require.NoError(t, err) 48 | require.NotNil(t, recordSet) 49 | require.Nil(t, recordSet.IDGenerator) 50 | require.NotNil(t, recordSet.EmbeddingFunction) 51 | require.NotNil(t, recordSet.Records) 52 | require.Equal(t, len(recordSet.Records), 1) 53 | require.NotNil(t, recordSet.Records[0].ID) 54 | require.Equal(t, recordSet.Records[0].ID, "1") 55 | require.NotNil(t, recordSet.Records[0].Document) 56 | require.Equal(t, recordSet.Records[0].Document, "test document") 57 | require.NotNil(t, recordSet.Records[0].Embedding.GetFloat32()) 58 | }) 59 | 60 | t.Run("Test NewRecordSet with complete Record", func(t *testing.T) { 61 | recordSet, err := NewRecordSet(WithEmbeddingFunction(NewConsistentHashEmbeddingFunction())) 62 | require.NoError(t, err) 63 | var embeddings = []float32{0.1, 0.2, 0.3} 64 | recordSet.WithRecord( 65 | WithID("1"), 66 | WithDocument("test document"), 67 | WithEmbedding(*NewEmbeddingFromFloat32(embeddings)), 68 | WithMetadata("testKey", 1), 69 | ) 70 | _, err = recordSet.BuildAndValidate(context.TODO()) 71 | require.NoError(t, err) 72 | require.NotNil(t, recordSet) 73 | require.Nil(t, recordSet.IDGenerator) 74 | require.NotNil(t, recordSet.EmbeddingFunction) 75 | require.NotNil(t, recordSet.Records) 76 | require.Equal(t, len(recordSet.Records), 1) 77 | require.NotNil(t, recordSet.Records[0].ID) 78 | require.Equal(t, recordSet.Records[0].ID, "1") 79 | require.NotNil(t, recordSet.Records[0].Embedding.GetFloat32()) 80 | require.Equal(t, recordSet.Records[0].Embedding.GetFloat32(), &embeddings) 81 | require.Equal(t, recordSet.Records[0].Document, "test document") 82 | require.Equal(t, recordSet.Records[0].Metadata["testKey"], 1) 83 | }) 84 | } 85 | -------------------------------------------------------------------------------- /where_document/wheredoc.go: 
// --------------------------------------------------------------------------------
// where_document/wheredoc.go (package wheredoc)

// InvalidWhereDocumentValueError is reported when an operand passed to a
// where-document operator is not a string.
type InvalidWhereDocumentValueError struct {
	// Value is the rejected operand.
	Value interface{}
}

// Error implements the error interface.
func (e *InvalidWhereDocumentValueError) Error() string {
	return fmt.Sprintf("Invalid value for where document clause for value %v. Allowed values are string", e.Value)
}

// Builder incrementally assembles a where-document filter as a nested map.
//
// The first error encountered is remembered; every later call becomes a no-op
// and Build returns that error.
type Builder struct {
	// WhereClause is the filter assembled so far, keyed by operator name.
	WhereClause map[string]interface{}
	err         error
}

// NewWhereDocumentBuilder returns an empty Builder ready for use.
func NewWhereDocumentBuilder() *Builder {
	return &Builder{WhereClause: make(map[string]interface{})}
}

// set stores value under key, allocating WhereClause on first use so that a
// zero-value Builder does not panic on a nil map write.
func (w *Builder) set(key string, value interface{}) {
	if w.WhereClause == nil {
		w.WhereClause = make(map[string]interface{})
	}
	w.WhereClause[key] = value
}

// operation records a leaf operator (for example "$contains") with its
// operand. Only string operands are accepted; any other type records an
// *InvalidWhereDocumentValueError on the builder.
func (w *Builder) operation(operation string, value interface{}) *Builder {
	if w.err != nil {
		return w
	}
	if _, ok := value.(string); !ok {
		w.err = &InvalidWhereDocumentValueError{Value: value}
		return w
	}
	w.set(operation, value)
	return w
}

// combine builds every sub-builder and stores the resulting clauses under the
// given logical operator ("$and" / "$or"). A nil sub-builder or a failing
// sub-builder records an error on w instead of panicking.
func (w *Builder) combine(operator string, builders []*Builder) *Builder {
	if w.err != nil {
		return w
	}
	// Allocated up front so an empty operand list is stored as an empty
	// slice (JSON "[]") rather than a nil slice (JSON "null").
	clauses := make([]map[string]interface{}, 0, len(builders))
	for _, b := range builders {
		if b == nil {
			w.err = fmt.Errorf("where document: nil builder passed to %s", operator)
			return w
		}
		expr, err := b.Build()
		if err != nil {
			w.err = err
			return w
		}
		clauses = append(clauses, expr)
	}
	w.set(operator, clauses)
	return w
}

// Contains adds a "$contains" clause. value must be a string.
func (w *Builder) Contains(value interface{}) *Builder {
	return w.operation("$contains", value)
}

// NotContains adds a "$not_contains" clause. value must be a string.
func (w *Builder) NotContains(value interface{}) *Builder {
	return w.operation("$not_contains", value)
}

// And adds an "$and" clause made of the clauses built by builders.
func (w *Builder) And(builders ...*Builder) *Builder {
	return w.combine("$and", builders)
}

// Or adds an "$or" clause made of the clauses built by builders.
func (w *Builder) Or(builders ...*Builder) *Builder {
	return w.combine("$or", builders)
}

// Build returns the assembled clause, or the first error recorded while
// building it. The returned map is the builder's own WhereClause, not a copy.
func (w *Builder) Build() (map[string]interface{}, error) {
	if w.err != nil {
		return nil, w.err
	}
	return w.WhereClause, nil
}

// WhereDocumentOperation applies one where-document operator to a Builder and
// reports any error the builder recorded while doing so.
type WhereDocumentOperation func(builder *Builder) error

// Contains returns an operation that adds a "$contains" clause.
func Contains(value interface{}) WhereDocumentOperation {
	return func(w *Builder) error {
		return w.Contains(value).err
	}
}

// NotContains returns an operation that adds a "$not_contains" clause.
func NotContains(value interface{}) WhereDocumentOperation {
	return func(w *Builder) error {
		return w.NotContains(value).err
	}
}

// buildOperands applies each operation to its own fresh Builder and returns
// the builders in order. It stops at the first nil or failing operation.
func buildOperands(operator string, ops []WhereDocumentOperation) ([]*Builder, error) {
	operands := make([]*Builder, 0, len(ops))
	for _, op := range ops {
		if op == nil {
			return nil, fmt.Errorf("where document: nil operation passed to %s", operator)
		}
		sub := NewWhereDocumentBuilder()
		if err := op(sub); err != nil {
			return nil, err
		}
		operands = append(operands, sub)
	}
	return operands, nil
}

// And returns an operation that combines ops under an "$and" clause.
func And(ops ...WhereDocumentOperation) WhereDocumentOperation {
	return func(w *Builder) error {
		operands, err := buildOperands("$and", ops)
		if err != nil {
			return err
		}
		return w.And(operands...).err
	}
}

// Or returns an operation that combines ops under an "$or" clause.
func Or(ops ...WhereDocumentOperation) WhereDocumentOperation {
	return func(w *Builder) error {
		operands, err := buildOperands("$or", ops)
		if err != nil {
			return err
		}
		return w.Or(operands...).err
	}
}

// WhereDocument applies operation to a fresh Builder and returns the
// resulting clause. A nil operation is reported as an error.
func WhereDocument(operation WhereDocumentOperation) (map[string]interface{}, error) {
	if operation == nil {
		return nil, fmt.Errorf("where document: nil operation")
	}
	w := NewWhereDocumentBuilder()
	if err := operation(w); err != nil {
		return nil, err
	}
	return w.Build()
}

// --------------------------------------------------------------------------------