├── .github └── workflows │ ├── integration-test.yml │ └── release.yml ├── .gitignore ├── CODEOWNERS ├── LICENSE.md ├── README.md ├── pom.xml └── src ├── main ├── java │ └── tech │ │ └── amikos │ │ └── chromadb │ │ ├── ChromaException.java │ │ ├── Client.java │ │ ├── Collection.java │ │ ├── Constants.java │ │ ├── EFException.java │ │ ├── Embedding.java │ │ └── embeddings │ │ ├── DefaultEmbeddingFunction.java │ │ ├── EmbeddingFunction.java │ │ ├── WithParam.java │ │ ├── cohere │ │ ├── CohereEmbeddingFunction.java │ │ ├── CreateEmbeddingRequest.java │ │ └── CreateEmbeddingResponse.java │ │ ├── hf │ │ ├── CreateEmbeddingRequest.java │ │ ├── CreateEmbeddingResponse.java │ │ └── HuggingFaceEmbeddingFunction.java │ │ ├── ollama │ │ ├── CreateEmbeddingRequest.java │ │ ├── CreateEmbeddingResponse.java │ │ └── OllamaEmbeddingFunction.java │ │ └── openai │ │ ├── CreateEmbeddingRequest.java │ │ ├── CreateEmbeddingResponse.java │ │ └── OpenAIEmbeddingFunction.java └── resources │ └── openapi │ ├── api.yaml │ └── openai.yaml └── test └── java └── tech └── amikos └── chromadb ├── TestAPI.java ├── Utils.java └── embeddings ├── TestDefaultEmbeddings.java ├── cohere └── TestCohereEmbeddings.java ├── hf └── TestHuggingFaceEmbeddings.java ├── ollama └── TestOllamaEmbeddings.java └── openai └── TestOpenAIEmbeddings.java /.github/workflows/integration-test.yml: -------------------------------------------------------------------------------- 1 | name: Integration test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | - "**" 8 | 9 | jobs: 10 | integration-test: 11 | strategy: 12 | matrix: 13 | chroma-version: [0.4.24, 0.5.0, 0.5.5, 0.5.15 ] 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up JDK 8 18 | uses: actions/setup-java@v3 19 | with: 20 | java-version: '8' 21 | distribution: 'adopt' 22 | cache: maven 23 | - name: Test with Maven 24 | run: mvn --batch-mode clean test 25 | env: 26 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 27 | 
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} 28 | HF_API_KEY: ${{ secrets.HF_API_KEY }} 29 | CHROMA_VERSION: ${{ matrix.chroma-version }} 30 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | release: 5 | types: [created] 6 | branches: [ "main" ] 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up JDK 8 14 | uses: actions/setup-java@v3 15 | with: 16 | java-version: '8' 17 | distribution: 'adopt' 18 | cache: maven 19 | server-id: ossrh 20 | server-username: MAVEN_USERNAME 21 | server-password: MAVEN_PASSWORD 22 | 23 | - id: install-secret-key 24 | name: Install gpg secret key 25 | run: | 26 | # Install gpg secret key 27 | cat <(echo -e "${GPG_KEY}") | gpg --batch --import 28 | # Verify gpg secret key 29 | gpg --list-secret-keys --keyid-format LONG 30 | env: 31 | GPG_KEY: ${{ secrets.AMIKOS_OSS_GPG_SECRET_KEY }} 32 | - name: Version bump 33 | run: | 34 | mvn \ 35 | --no-transfer-progress \ 36 | --batch-mode \ 37 | -Dgpg.skip=true \ 38 | -DskipTests \ 39 | versions:set \ 40 | -DnewVersion=${{ github.ref_name }} 41 | - name: Publish package 42 | run: | 43 | mvn \ 44 | --no-transfer-progress \ 45 | --batch-mode \ 46 | -Dgpg.passphrase=${GPG_PASSPHRASE} \ 47 | -DskipTests \ 48 | clean package deploy 49 | env: 50 | # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 51 | MAVEN_USERNAME: ${{ secrets.OSSRH_USERNAME }} 52 | MAVEN_PASSWORD: ${{ secrets.OSSRH_TOKEN }} 53 | GPG_PASSPHRASE: ${{ secrets.AMIKOS_OSS_GPG_SECRET_KEY_PASSWORD }} 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | !.mvn/wrapper/maven-wrapper.jar 3 | !**/src/main/**/target/ 4 | !**/src/test/**/target/ 5 | 6 | ### IntelliJ IDEA 
### 7 | .idea/ 8 | **.iml 9 | 10 | ### Eclipse ### 11 | .apt_generated 12 | .classpath 13 | .factorypath 14 | .project 15 | .settings 16 | .springBeans 17 | .sts4-cache 18 | 19 | ### NetBeans ### 20 | /nbproject/private/ 21 | /nbbuild/ 22 | /dist/ 23 | /nbdist/ 24 | /.nb-gradle/ 25 | build/ 26 | !**/src/main/**/build/ 27 | !**/src/test/**/build/ 28 | 29 | ### VS Code ### 30 | .vscode/ 31 | 32 | ### Mac OS ### 33 | .DS_Store 34 | *.class 35 | *.log 36 | *.ctxt 37 | .mtj.tmp/ 38 | *.jar 39 | *.war 40 | *.nar 41 | *.ear 42 | *.zip 43 | *.tar.gz 44 | *.rar 45 | hs_err_pid* 46 | replay_pid* 47 | cmake-build-*/ 48 | .idea/**/mongoSettings.xml 49 | out/ 50 | .idea_modules/ 51 | atlassian-ide-plugin.xml 52 | .idea/replstate.xml 53 | .idea/sonarlint/ 54 | com_crashlytics_export_strings.xml 55 | crashlytics.properties 56 | crashlytics-build.properties 57 | fabric.properties 58 | .idea/httpRequests 59 | .idea/caches/build_file_checksums.ser 60 | **.env -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @tazarov -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2023 Amikos Tech Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chroma Vector Database Java Client 2 | 3 | This is a very basic/naive implementation in Java of the Chroma Vector Database API. 4 | 5 | This client works with Chroma Versions `0.4.3+` 6 | 7 | ## Features 8 | 9 | ### Embeddings Support 10 | 11 | - ✅ Default Embedding Function (all-mini-lm model) 12 | - ✅ OpenAI Embedding Function 13 | - ✅ Cohere Embedding Function 14 | - ✅ HuggingFace Embedding Function (Inference API) 15 | - ✅ Ollama Embedding Function 16 | - ✅ Hugging Face Text Embedding Inference (HFEI) API 17 | - [ ] Sentence Transformers 18 | - [ ] PaLM API 19 | - [ ] Custom Embedding Function 20 | - [ ] Cloudflare Workers AI 21 | 22 | ### Feature Parity with ChromaDB API 23 | 24 | - [x] Reset 25 | - [x] Heartbeat 26 | - [x] List Collections 27 | - [x] Get Version 28 | - [x] Create Collection 29 | - [x] Delete Collection 30 | - [x] Collection Add 31 | - [x] Collection Get (partial without additional parameters) 32 | - [x] Collection Count 33 | - [x] Collection Query 34 | - [x] Collection Modify 35 | - [x] Collection Update 36 | - [x] Collection Upsert 37 | - [x] Collection Create Index 38 | - [x] Collection Delete - delete documents in collection 39 | 40 | ## TODO 41 | 42 | - [x] Push the package to Maven 43 | Central - https://docs.github.com/en/actions/publishing-packages/publishing-java-packages-with-maven 44 | - ⚒️ 
Fluent API - make it easier for users to make use of the library 45 | - [ ] Support for PaLM API 46 | - [x] Support for Sentence Transformers with Hugging Face API 47 | - ⚒️ Authentication ⚒️ 48 | 49 | ## Usage 50 | 51 | Add Maven dependency: 52 | 53 | ```xml 54 | 55 | 56 | io.github.amikos-tech 57 | chromadb-java-client 58 | 0.1.7 59 | 60 | ``` 61 | 62 | Ensure you have a running instance of Chroma. We recommend one of the following two options: 63 | 64 | - Official documentation - https://docs.trychroma.com/usage-guide#running-chroma-in-clientserver-mode 65 | - If you are a fan of Kubernetes, you can use the Helm chart - https://github.com/amikos-tech/chromadb-chart (Note: You 66 | will need `Docker`, `minikube` and `kubectl` installed) 67 | 68 | ### Default Embedding Function 69 | 70 | Since version `0.1.6` the library also offers a built-in default embedding function which does not rely on any external 71 | API to generate embeddings and works the same way as in the core Chroma Python package. 72 | 73 | ```java 74 | package tech.amikos; 75 | 76 | import tech.amikos.chromadb.*; 77 | import tech.amikos.chromadb.Collection; 78 | import tech.amikos.chromadb.embeddings.DefaultEmbeddingFunction; 79 | 80 | import java.util.*; 81 | 82 | public class Main { 83 | public static void main(String[] args) { 84 | try { 85 | Client client = new Client(System.getenv("CHROMA_URL")); 86 | client.reset(); 87 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 88 | Collection collection = client.createCollection("test-collection", null, true, ef); 89 | List> metadata = new ArrayList<>(); 90 | metadata.add(new HashMap() {{ 91 | put("type", "scientist"); 92 | }}); 93 | metadata.add(new HashMap() {{ 94 | put("type", "spy"); 95 | }}); 96 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond.
I am a Spy."), Arrays.asList("1", "2")); 97 | Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); 98 | System.out.println(qr); 99 | } catch (Exception e) { 100 | System.out.println(e); 101 | } 102 | } 103 | } 104 | ``` 105 | 106 | ### Example OpenAI Embedding Function 107 | 108 | In this example we rely on `tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction` to generate embeddings for 109 | our documents. 110 | 111 | | **Important**: Ensure you have `OPENAI_API_KEY` environment variable set 112 | 113 | ```java 114 | package tech.amikos; 115 | 116 | import tech.amikos.chromadb.Client; 117 | import tech.amikos.chromadb.Collection; 118 | import tech.amikos.chromadb.EmbeddingFunction; 119 | import tech.amikos.chromadb.embeddings.openai.OpenAIEmbeddingFunction; 120 | 121 | import java.util.*; 122 | 123 | public class Main { 124 | public static void main(String[] args) { 125 | try { 126 | Client client = new Client(System.getenv("CHROMA_URL")); 127 | String apiKey = System.getenv("OPENAI_API_KEY"); 128 | EmbeddingFunction ef = new OpenAIEmbeddingFunction(apiKey, "text-embedding-3-small"); 129 | Collection collection = client.createCollection("test-collection", null, true, ef); 130 | List> metadata = new ArrayList<>(); 131 | metadata.add(new HashMap() {{ 132 | put("type", "scientist"); 133 | }}); 134 | metadata.add(new HashMap() {{ 135 | put("type", "spy"); 136 | }}); 137 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); 138 | Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); 139 | System.out.println(qr); 140 | } catch (Exception e) { 141 | e.printStackTrace(); 142 | System.out.println(e); 143 | } 144 | } 145 | } 146 | ``` 147 | 148 | The above should output: 149 | 150 | ```bash 151 | {"documents":[["Hello, my name is Bond. 
I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.28461432,0.50961685]]} 152 | ``` 153 | 154 | #### Custom OpenAI Endpoint 155 | 156 | For endpoints compatible with OpenAI Embeddings API (e.g. [ollama](https://github.com/ollama/ollama)), you can use the 157 | following: 158 | 159 | > Note: We have added a builder to help with the configuration of the OpenAIEmbeddingFunction 160 | 161 | ```java 162 | EmbeddingFunction ef = OpenAIEmbeddingFunction.Instance() 163 | .withOpenAIAPIKey(apiKey) 164 | .withModelName("llama2") 165 | .withApiEndpoint("http://localhost:11434/api/embedding") // not really custom, but just to test the method 166 | .build(); 167 | ``` 168 | 169 | Quick Start Guide with Ollama: 170 | 171 | ```bash 172 | docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama 173 | docker exec -it ollama ollama run llama2 # press Ctrl+D to exit after model downloads successfully 174 | # test it 175 | curl http://localhost:11434/api/embeddings -d '{\n "model": "llama2",\n "prompt": "Here is an article about llamas..."\n}' 176 | ``` 177 | 178 | ### Example Cohere Embedding Function 179 | 180 | In this example we rely on `tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction` to generate embeddings for 181 | our documents. 
182 | 183 | | **Important**: Ensure you have `COHERE_API_KEY` environment variable set 184 | 185 | ```java 186 | package tech.amikos; 187 | 188 | import tech.amikos.chromadb.*; 189 | import tech.amikos.chromadb.Collection; 190 | import tech.amikos.chromadb.embeddings.cohere.CohereEmbeddingFunction; 191 | 192 | import java.util.*; 193 | 194 | public class Main { 195 | public static void main(String[] args) { 196 | try { 197 | Client client = new Client(System.getenv("CHROMA_URL")); 198 | client.reset(); 199 | String apiKey = System.getenv("COHERE_API_KEY"); 200 | EmbeddingFunction ef = new CohereEmbeddingFunction(apiKey); 201 | Collection collection = client.createCollection("test-collection", null, true, ef); 202 | List> metadata = new ArrayList<>(); 203 | metadata.add(new HashMap() {{ 204 | put("type", "scientist"); 205 | }}); 206 | metadata.add(new HashMap() {{ 207 | put("type", "spy"); 208 | }}); 209 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); 210 | Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); 211 | System.out.println(qr); 212 | } catch (Exception e) { 213 | e.printStackTrace(); 214 | System.out.println(e); 215 | } 216 | } 217 | } 218 | ``` 219 | 220 | The above should output: 221 | 222 | ```bash 223 | {"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[5112.614,10974.804]]} 224 | ``` 225 | 226 | ### Example Hugging Face Sentence Transformers Embedding Function 227 | 228 | #### Hugging Face Inference API 229 | 230 | In this example we rely on `tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction` to generate embeddings for 231 | our documents using HuggingFace cloud-based inference API. 
232 | 233 | | **Important**: Ensure you have the `HF_API_KEY` environment variable set 234 | 235 | ```java 236 | package tech.amikos; 237 | 238 | import tech.amikos.chromadb.*; 239 | import tech.amikos.chromadb.Collection; 240 | import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction; 241 | 242 | import java.util.*; 243 | 244 | public class Main { 245 | public static void main(String[] args) { 246 | try { 247 | Client client = new Client("http://localhost:8000"); 248 | String apiKey = System.getenv("HF_API_KEY"); 249 | EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(apiKey); 250 | Collection collection = client.createCollection("test-collection", null, true, ef); 251 | List> metadata = new ArrayList<>(); 252 | metadata.add(new HashMap() {{ 253 | put("type", "scientist"); 254 | }}); 255 | metadata.add(new HashMap() {{ 256 | put("type", "spy"); 257 | }}); 258 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); 259 | Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); 260 | System.out.println(qr); 261 | } catch (Exception e) { 262 | System.out.println(e); 263 | } 264 | } 265 | } 266 | ``` 267 | 268 | The above should output: 269 | 270 | ```bash 271 | {"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.9073759,1.6440368]]} 272 | ``` 273 | 274 | #### Hugging Face Text Embedding Inference (HFEI) API 275 | 276 | In this example we'll use a local Docker-based server to generate the embeddings with the 277 | `Snowflake/snowflake-arctic-embed-s` model.
278 | 279 | First let's start the HFEI server: 280 | 281 | ```bash 282 | docker run -d -p 8008:80 --platform linux/amd64 --name hfei ghcr.io/huggingface/text-embeddings-inference:cpu-1.5.0 --model-id Snowflake/snowflake-arctic-embed-s --revision main 283 | ``` 284 | 285 | > Note: Check the official documentation for more details - https://github.com/huggingface/text-embeddings-inference 286 | 287 | Then we can use the following code to generate embeddings. Note the use of 288 | `new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API));` to define the API type, 289 | this will ensure the client uses the correct endpoint. 290 | 291 | ```java 292 | package tech.amikos; 293 | 294 | import tech.amikos.chromadb.*; 295 | import tech.amikos.chromadb.Collection; 296 | import tech.amikos.chromadb.embeddings.hf.HuggingFaceEmbeddingFunction; 297 | 298 | import java.util.*; 299 | 300 | public class Main { 301 | public static void main(String[] args) { 302 | try { 303 | Client client = new Client("http://localhost:8000"); 304 | EmbeddingFunction ef = new HuggingFaceEmbeddingFunction( 305 | WithParam.baseAPI("http://localhost:8008"), 306 | new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API)); 307 | Collection collection = client.createCollection("test-collection", null, true, ef); 308 | List> metadata = new ArrayList<>(); 309 | metadata.add(new HashMap() {{ 310 | put("type", "scientist"); 311 | }}); 312 | metadata.add(new HashMap() {{ 313 | put("type", "spy"); 314 | }}); 315 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. 
I am a Spy."), Arrays.asList("1", "2")); 316 | Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); 317 | System.out.println(qr); 318 | } catch (Exception e) { 319 | System.out.println(e); 320 | } 321 | } 322 | } 323 | ``` 324 | 325 | The above should produce output similar to the following: 326 | 327 | ```bash 328 | {"documents":[["Hello, my name is Bond. I am a Spy.","Hello, my name is John. I am a Data Scientist."]],"ids":[["2","1"]],"metadatas":[[{"type":"spy"},{"type":"scientist"}]],"distances":[[0.19665092,0.42433012]]} 329 | ``` 330 | 331 | ### Ollama Embedding Function 332 | 333 | In this example we rely on `tech.amikos.chromadb.embeddings.ollama.OllamaEmbeddingFunction` to generate embeddings for 334 | our documents. 335 | 336 | ```java 337 | package tech.amikos; 338 | 339 | import tech.amikos.chromadb.*; 340 | import tech.amikos.chromadb.embeddings.ollama.OllamaEmbeddingFunction; 341 | import tech.amikos.chromadb.Collection; 342 | 343 | import java.util.*; 344 | 345 | public class Main { 346 | public static void main(String[] args) { 347 | try { 348 | Client client = new Client(System.getenv("CHROMA_URL")); 349 | client.reset(); 350 | EmbeddingFunction ef = new OllamaEmbeddingFunction(); 351 | Collection collection = client.createCollection("test-collection", null, true, ef); 352 | List> metadata = new ArrayList<>(); 353 | metadata.add(new HashMap() {{ 354 | put("type", "scientist"); 355 | }}); 356 | metadata.add(new HashMap() {{ 357 | put("type", "spy"); 358 | }}); 359 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond.
I am a Spy."), Arrays.asList("1", "2")); 360 | Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); 361 | System.out.println(qr); 362 | } catch (Exception e) { 363 | System.out.println(e); 364 | } 365 | } 366 | } 367 | ``` 368 | 369 | ### Example Auth 370 | 371 | > Note: This is a workaround until the client overhaul is completed 372 | 373 | **Basic Auth**: 374 | 375 | ```java 376 | package tech.amikos; 377 | 378 | import tech.amikos.chromadb.*; 379 | import tech.amikos.chromadb.Collection; 380 | 381 | import java.util.*; 382 | 383 | public class Main { 384 | public static void main(String[] args) { 385 | try { 386 | Client client = new Client(System.getenv("CHROMA_URL")); 387 | String encodedString = Base64.getEncoder().encodeToString("admin:admin".getBytes()); 388 | client.setDefaultHeaders(new HashMap<>() {{ 389 | put("Authorization", "Basic " + encodedString); 390 | }}); 391 | // your code here 392 | } catch (Exception e) { 393 | System.out.println(e); 394 | } 395 | } 396 | } 397 | ``` 398 | 399 | **Static Auth - Authorization**: 400 | 401 | ```java 402 | package tech.amikos; 403 | 404 | import tech.amikos.chromadb.*; 405 | import tech.amikos.chromadb.Collection; 406 | 407 | import java.util.*; 408 | 409 | public class Main { 410 | public static void main(String[] args) { 411 | try { 412 | Client client = new Client(System.getenv("CHROMA_URL")); 413 | String encodedString = Base64.getEncoder().encodeToString("admin:admin".getBytes()); 414 | client.setDefaultHeaders(new HashMap<>() {{ 415 | put("Authorization", "Bearer test-token"); 416 | }}); 417 | // your code here 418 | } catch (Exception e) { 419 | System.out.println(e); 420 | } 421 | } 422 | } 423 | ``` 424 | 425 | **Static Auth - X-Chroma-Token**: 426 | 427 | ```java 428 | package tech.amikos; 429 | 430 | import tech.amikos.chromadb.*; 431 | import tech.amikos.chromadb.Collection; 432 | 433 | import java.util.*; 434 | 435 | public class Main { 436 | public 
static void main(String[] args) { 437 | try { 438 | Client client = new Client(System.getenv("CHROMA_URL")); 439 | String encodedString = Base64.getEncoder().encodeToString("admin:admin".getBytes()); 440 | client.setDefaultHeaders(new HashMap<>() {{ 441 | put("X-Chroma-Token", "test-token"); 442 | }}); 443 | // your code here 444 | } catch (Exception e) { 445 | System.out.println(e); 446 | } 447 | } 448 | } 449 | ``` 450 | 451 | ## Development Notes 452 | 453 | We have made some minor changes on top of the ChromaDB API (`src/main/resources/openapi/api.yaml`) so that the API can 454 | work with Java and Swagger Codegen. The reason is that statically typed languages like Java don't handle the `anyOf` 455 | and `oneOf` keywords well (this is also why we don't use the generated Java client for the OpenAI API). 456 | 457 | ## Contributing 458 | 459 | Pull requests are welcome. 460 | 461 | ## References 462 | 463 | - https://docs.trychroma.com/ - Official Chroma documentation 464 | - https://github.com/amikos-tech/chromadb-chart - Chroma Helm chart for cloud-native deployments 465 | - https://github.com/openai/openai-openapi - OpenAI OpenAPI specification (While we don't use it to generate a client 466 | for Java, it helps us understand the API better) 467 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | Chroma Vector DB Java Client Library 7 | io.github.amikos-tech 8 | chromadb-java-client 9 | 0.1.7 10 | 11 | Chroma Vector DB Java Client 12 | 13 | https://github.com/amikos-tech/chromadb-java-client 14 | 15 | GitHub 16 | https://github.com/amikos-tech/chromadb-java-client/issues 17 | 18 | 19 | 20 | tazarov 21 | Trayan Azarov 22 | opensource@amikos.tech 23 | Amikos Tech OOD 24 | 25 | 26 | 27 | Amikos Tech OOD 28 | https://amikos.tech 29 | 30 | 31 | scm:git:git://github.com/amikos-tech/chromadb-java-client.git 32 |
scm:git:ssh://git@github.com:amikos-tech/chromadb-java-client.git 33 | https://github.com/amikos-tech/chromadb-java-client/tree/main 34 | 35 | 36 | MIT 37 | https://opensource.org/licenses/MIT 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | ossrh 52 | https://s01.oss.sonatype.org/content/repositories/snapshots 53 | 54 | 55 | ossrh 56 | https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ 57 | 58 | 59 | 60 | 1.8 61 | ${java.version} 62 | ${java.version} 63 | 1.8.5 64 | 1.6.9 65 | 4.12.0 66 | 2.10.1 67 | 1.7.0 68 | 1.0.0 69 | 4.13.2 70 | 3.11.1 71 | UTF-8 72 | 73 | 74 | 75 | 76 | javax.annotation 77 | javax.annotation-api 78 | 1.3.2 79 | 80 | 81 | joda-time 82 | joda-time 83 | 2.10.10 84 | 85 | 86 | io.swagger.core.v3 87 | swagger-annotations 88 | 2.2.10 89 | 90 | 91 | com.squareup.okhttp3 92 | okhttp 93 | ${okhttp-version} 94 | 95 | 96 | com.squareup.okhttp3 97 | logging-interceptor 98 | ${okhttp-version} 99 | 100 | 101 | com.google.code.gson 102 | gson 103 | ${gson-version} 104 | 105 | 106 | io.gsonfire 107 | gson-fire 108 | ${gson-fire-version} 109 | 110 | 111 | org.threeten 112 | threetenbp 113 | ${threetenbp-version} 114 | 115 | 116 | commons-net 117 | commons-net 118 | ${commons-net-version} 119 | 120 | 121 | ai.djl.huggingface 122 | tokenizers 123 | 0.29.0 124 | 125 | 126 | com.microsoft.onnxruntime 127 | onnxruntime 128 | 1.18.0 129 | 130 | 131 | commons-io 132 | commons-io 133 | 2.16.1 134 | 135 | 136 | org.nd4j 137 | nd4j-native-platform 138 | 1.0.0-M2 139 | 140 | 141 | org.apache.commons 142 | commons-compress 143 | 1.27.0 144 | 145 | 146 | 147 | junit 148 | junit 149 | ${junit-version} 150 | test 151 | 152 | 153 | com.github.tomakehurst 154 | wiremock-jre8 155 | 2.35.2 156 | test 157 | 158 | 159 | org.testcontainers 160 | testcontainers-bom 161 | 1.20.1 162 | pom 163 | import 164 | 165 | 166 | org.testcontainers 167 | chromadb 168 | 1.20.1 169 | test 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 
182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | io.swagger.codegen.v3 206 | swagger-codegen-maven-plugin 207 | 3.0.46 208 | 209 | 210 | com.github.jknack 211 | handlebars 212 | 4.3.0 213 | 214 | 215 | 216 | 217 | chroma-api 218 | 219 | generate 220 | 221 | 222 | ${project.basedir}/src/main/resources/openapi/api.yaml 223 | java 224 | 225 | okhttp4-gson 226 | 227 | joda 228 | 229 | 230 | tech.amikos.chromadb.handler 231 | tech.amikos.chromadb.model 232 | tech.amikos.chromadb.handler 233 | 234 | 235 | 236 | 237 | CreateEmbeddingRequest=tech.amikos.chromadb.embeddings.openai.CreateEmbeddingRequest 238 | 239 | 240 | false 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | org.codehaus.mojo 291 | build-helper-maven-plugin 292 | 3.2.0 293 | 294 | 295 | add_sources 296 | generate-sources 297 | 298 | add-source 299 | 300 | 301 | 302 | src/main/java 303 | target/generated-sources/swagger/src/main/java 304 | 305 | 306 | 307 | 308 | add_test_sources 309 | generate-test-sources 310 | 311 | add-test-source 312 | 313 | 314 | 315 | src/test/java 316 | 317 | 318 | 319 | 320 | 321 | 322 | org.apache.maven.plugins 323 | maven-jar-plugin 324 | 2.2 325 | 326 | 327 | 328 | jar 329 | test-jar 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | org.apache.maven.plugins 338 | maven-dependency-plugin 339 | 3.6.0 340 | 341 | 342 | package 343 | 344 | copy-dependencies 345 | 346 | 347 | ${project.build.directory}/lib 348 | 349 | 350 | 351 | 352 | 353 | org.apache.maven.plugins 354 | maven-gpg-plugin 355 | 1.5 356 | 357 | 358 | sign-artifacts 359 | verify 360 | 361 | sign 362 | 363 | 364 | 365 | --pinentry-mode 366 | loopback 367 | 368 | 369 | 370 | 371 | 372 
| 373 | org.apache.maven.plugins 374 | maven-source-plugin 375 | 2.2.1 376 | 377 | 378 | attach-sources 379 | 380 | jar-no-fork 381 | 382 | 383 | 384 | 385 | 386 | org.apache.maven.plugins 387 | maven-javadoc-plugin 388 | 2.9.1 389 | 390 | 391 | attach-javadocs 392 | 393 | jar 394 | 395 | 396 | 397 | 398 | 399 | net.nicoulaj.maven.plugins 400 | checksum-maven-plugin 401 | 1.7 402 | 403 | 404 | 405 | files 406 | 407 | package 408 | 409 | 410 | MD5 411 | SHA-512 412 | 413 | 414 | 415 | ${project.build.directory} 416 | 417 | **/*.jar 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | org.codehaus.mojo 427 | versions-maven-plugin 428 | 2.8.1 429 | 430 | 431 | org.sonatype.plugins 432 | nexus-staging-maven-plugin 433 | 1.6.7 434 | true 435 | 436 | ossrh 437 | https://s01.oss.sonatype.org/ 438 | true 439 | 440 | 441 | 442 | org.apache.maven.plugins 443 | maven-enforcer-plugin 444 | 3.5.0 445 | 446 | 447 | enforce-java 448 | 449 | enforce 450 | 451 | 452 | 453 | 454 | [1.8,) 455 | 456 | 457 | true 458 | 459 | 460 | 461 | 462 | 463 | org.apache.maven.plugins 464 | maven-compiler-plugin 465 | 3.8.1 466 | 467 | ${maven.compiler.source} 468 | ${maven.compiler.target} 469 | 470 | 471 | 472 | 473 | 474 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/ChromaException.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb; 2 | 3 | /** Checked exception type for the Chroma client library; it only exposes the standard {@link Exception} constructors. */ public class ChromaException extends Exception { 4 | /** Creates an exception with the given detail message. */ public ChromaException(String message) { 5 | super(message); 6 | } 7 | 8 | /** Creates an exception with the given detail message and underlying cause. */ public ChromaException(String message, Throwable cause) { 9 | super(message, cause); 10 | } 11 | 12 | /** Creates an exception that wraps the given cause (the message is derived from it by {@link Exception}). */ public ChromaException(Throwable cause) { 13 | super(cause); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/Client.java: -------------------------------------------------------------------------------- 1
| package tech.amikos.chromadb; 2 | 3 | import com.google.gson.internal.LinkedTreeMap; 4 | import tech.amikos.chromadb.embeddings.EmbeddingFunction; 5 | import tech.amikos.chromadb.handler.ApiClient; 6 | import tech.amikos.chromadb.handler.ApiException; 7 | import tech.amikos.chromadb.handler.DefaultApi; 8 | import tech.amikos.chromadb.model.CreateCollection; 9 | 10 | import java.math.BigDecimal; 11 | import java.util.List; 12 | import java.util.Map; 13 | import java.util.stream.Collectors; 14 | 15 | /** 16 | * ChromaDB Client 17 | */ 18 | public class Client { 19 | final ApiClient apiClient = new ApiClient(); 20 | private int timeout = 60; 21 | DefaultApi api; 22 | 23 | public Client(String basePath) { 24 | apiClient.setBasePath(basePath); 25 | api = new DefaultApi(apiClient); 26 | apiClient.setHttpClient(apiClient.getHttpClient().newBuilder() 27 | .readTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) 28 | .writeTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) 29 | .build()); 30 | api.getApiClient().setUserAgent("Chroma-JavaClient/0.1.x"); 31 | } 32 | 33 | /** 34 | * Set the timeout for the client 35 | * @param timeout timeout in seconds 36 | */ 37 | public void setTimeout(int timeout) { 38 | this.timeout = timeout; 39 | apiClient.setHttpClient(apiClient.getHttpClient().newBuilder() 40 | .readTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) 41 | .writeTimeout(this.timeout, java.util.concurrent.TimeUnit.SECONDS) 42 | .build()); 43 | } 44 | 45 | /** 46 | * Set the default headers for the client to be sent with every request 47 | * @param headers 48 | */ 49 | public void setDefaultHeaders(Map headers) { 50 | for (Map.Entry entry : headers.entrySet()) { 51 | apiClient.addDefaultHeader(entry.getKey(), entry.getValue()); 52 | } 53 | } 54 | 55 | public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException { 56 | return new Collection(api, collectionName, embeddingFunction).fetch(); 57 
| } 58 | 59 | public Map heartbeat() throws ApiException { 60 | return api.heartbeat(); 61 | } 62 | 63 | public Collection createCollection(String collectionName, Map metadata, Boolean createOrGet, EmbeddingFunction embeddingFunction) throws ApiException { 64 | CreateCollection req = new CreateCollection(); 65 | req.setName(collectionName); 66 | Map _metadata = metadata; 67 | if (metadata == null || metadata.isEmpty() || !metadata.containsKey("embedding_function")) { 68 | _metadata = new LinkedTreeMap<>(); 69 | _metadata.put("embedding_function", embeddingFunction.getClass().getName()); 70 | } 71 | req.setMetadata(_metadata); 72 | req.setGetOrCreate(createOrGet); 73 | LinkedTreeMap resp = (LinkedTreeMap) api.createCollection(req); 74 | return new Collection(api, (String) resp.get("name"), embeddingFunction).fetch(); 75 | } 76 | 77 | public Collection deleteCollection(String collectionName) throws ApiException { 78 | Collection collection = Collection.getInstance(api, collectionName); 79 | api.deleteCollection(collectionName); 80 | return collection; 81 | } 82 | 83 | public Collection upsert(String collectionName, EmbeddingFunction ef) throws ApiException { 84 | Collection collection = getCollection(collectionName, ef); 85 | // collection.upsert(); 86 | return collection; 87 | } 88 | 89 | public Boolean reset() throws ApiException { 90 | return api.reset(); 91 | } 92 | 93 | public List listCollections() throws ApiException { 94 | List apiResponse = (List) api.listCollections(); 95 | return apiResponse.stream().map((LinkedTreeMap m) -> { 96 | try { 97 | return getCollection((String) m.get("name"), null); 98 | } catch (ApiException e) { 99 | e.printStackTrace(); //this is not great as we're swallowing the exception 100 | } 101 | return null; 102 | }).collect(Collectors.toList()); 103 | } 104 | 105 | public String version() throws ApiException { 106 | return api.version(); 107 | } 108 | } 109 | 
--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/Collection.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
import com.google.gson.internal.LinkedTreeMap;
import tech.amikos.chromadb.embeddings.EmbeddingFunction;
import tech.amikos.chromadb.handler.ApiException;
import tech.amikos.chromadb.handler.DefaultApi;
import tech.amikos.chromadb.model.*;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static java.lang.Thread.sleep;

/**
 * Client-side handle for one Chroma collection.
 * <p>
 * The handle stores the collection's name, id and metadata as of the last {@link #fetch()}.
 * It also holds an optional {@link EmbeddingFunction}, used to embed documents or query texts
 * when no embeddings are passed in.
 */
public class Collection {
    /** Metadata key under which the embedding function's class name is recorded. */
    private static final String EMBEDDING_FUNCTION_KEY = "embedding_function";

    static Gson gson = new Gson();
    DefaultApi api;
    String collectionName;

    String collectionId;

    LinkedTreeMap<String, Object> metadata = new LinkedTreeMap<>();

    private EmbeddingFunction embeddingFunction;

    /**
     * Creates an unfetched handle. Call {@link #fetch()} to populate the id and metadata.
     *
     * @param api               generated API used for all requests
     * @param collectionName    name of the collection
     * @param embeddingFunction function used to embed documents/queries; may be {@code null}
     */
    public Collection(DefaultApi api, String collectionName, EmbeddingFunction embeddingFunction) {
        this.api = api;
        this.collectionName = collectionName;
        this.embeddingFunction = embeddingFunction;
    }

    public String getName() {
        return collectionName;
    }

    public String getId() {
        return collectionId;
    }

    public Map<String, Object> getMetadata() {
        return metadata;
    }

    /**
     * Reloads name, id and metadata from the server by the current collection name.
     *
     * @return this handle
     * @throws ApiException if the server request fails
     */
    @SuppressWarnings("unchecked")
    public Collection fetch() throws ApiException {
        Map<?, ?> resp = (Map<?, ?>) api.getCollection(collectionName);
        this.collectionName = resp.get("name").toString();
        this.collectionId = resp.get("id").toString();
        this.metadata = (LinkedTreeMap<String, Object>) resp.get("metadata");
        return this;
    }

    /**
     * Creates an unfetched handle without an embedding function.
     */
    public static Collection getInstance(DefaultApi api, String collectionName) throws ApiException {
        return new Collection(api, collectionName, null);
    }

    @Override
    public String toString() {
        return "Collection{" +
                "collectionName='" + collectionName + '\'' +
                ", collectionId='" + collectionId + '\'' +
                ", metadata=" + metadata +
                '}';
    }

    /**
     * Fetches records matching the given filters; any filter may be {@code null}.
     */
    public GetResult get(List<String> ids, Map<String, Object> where, Map<String, Object> whereDocument) throws ApiException {
        GetEmbedding req = new GetEmbedding();
        req.ids(ids).where(where).whereDocument(whereDocument);
        // Round-trip through JSON to map the untyped API response onto GetResult.
        String json = gson.toJson(api.get(req, this.collectionId));
        return gson.fromJson(json, GetResult.class);
    }

    public GetResult get() throws ApiException {
        return this.get(null, null, null);
    }

    public Object delete() throws ApiException {
        return this.delete(null, null, null);
    }

    /**
     * Inserts or updates records. When {@code embeddings} is {@code null}, the documents are
     * embedded with this collection's embedding function.
     *
     * @throws ChromaException if embeddings cannot be computed or the request fails
     */
    public Object upsert(List<Embedding> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ChromaException {
        AddEmbedding req = buildAddRequest(embeddings, metadatas, documents, ids);
        try {
            return api.upsert(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /**
     * Adds records. When {@code embeddings} is {@code null}, the documents are embedded with
     * this collection's embedding function.
     *
     * @throws ChromaException if embeddings cannot be computed or the request fails
     */
    public Object add(List<Embedding> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ChromaException {
        AddEmbedding req = buildAddRequest(embeddings, metadatas, documents, ids);
        try {
            return api.add(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /** Builds the request body shared by {@link #add} and {@link #upsert}. */
    @SuppressWarnings("unchecked")
    private AddEmbedding buildAddRequest(List<Embedding> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ChromaException {
        AddEmbedding req = new AddEmbedding();
        req.setEmbeddings(resolveEmbeddings(embeddings, documents).stream().map(Embedding::asArray).collect(Collectors.toList()));
        req.setMetadatas((List<Map<String, Object>>) (Object) metadatas);
        req.setDocuments(documents);
        req.incrementIndex(true);
        req.setIds(ids);
        return req;
    }

    /** Returns {@code embeddings} if given, otherwise embeds {@code documents}. */
    private List<Embedding> resolveEmbeddings(List<Embedding> embeddings, List<String> documents) throws ChromaException {
        if (embeddings != null) {
            return embeddings;
        }
        if (documents == null) {
            throw new ChromaException("Either embeddings or documents must be provided");
        }
        return requireEmbeddingFunction().embedDocuments(documents);
    }

    /** Returns the embedding function, failing clearly instead of with a NullPointerException. */
    private EmbeddingFunction requireEmbeddingFunction() throws ChromaException {
        if (embeddingFunction == null) {
            throw new ChromaException("Collection '" + collectionName + "' has no embedding function; pass embeddings explicitly");
        }
        return embeddingFunction;
    }

    public Integer count() throws ApiException {
        return api.count(this.collectionId);
    }

    /**
     * Deletes records matching the given ids and/or filters; any argument may be {@code null}.
     */
    public Object delete(List<String> ids, Map<String, Object> where, Map<String, Object> whereDocument) throws ApiException {
        DeleteEmbedding req = new DeleteEmbedding();
        req.setIds(ids);
        if (where != null) {
            req.where(new HashMap<>(where));
        }
        if (whereDocument != null) {
            req.whereDocument(new HashMap<>(whereDocument));
        }
        return api.delete(req, this.collectionId);
    }

    public Object deleteWithIds(List<String> ids) throws ApiException {
        return delete(ids, null, null);
    }

    public Object deleteWhere(Map<String, Object> where) throws ApiException {
        return delete(null, where, null);
    }

    public Object deleteWhereWhereDocuments(Map<String, Object> where, Map<String, Object> whereDocument) throws ApiException {
        return delete(null, where, whereDocument);
    }

    public Object deleteWhereDocuments(Map<String, Object> whereDocument) throws ApiException {
        return delete(null, null, whereDocument);
    }

    /**
     * Renames the collection and/or replaces its metadata, then re-fetches this handle.
     * <p>
     * A {@code null} argument leaves that property unchanged. If the new metadata lacks
     * {@code embedding_function}, it is filled from this handle's embedding function, or else
     * carried over from the current metadata. The caller's map is not modified.
     *
     * @param newName     new collection name, or {@code null} to keep the current one
     * @param newMetadata replacement metadata, or {@code null} to keep the current metadata
     * @return the raw server response
     * @throws ApiException if the update or the subsequent fetch fails
     */
    public Object update(String newName, Map<String, Object> newMetadata) throws ApiException {
        UpdateCollection req = new UpdateCollection();
        if (newName != null) {
            req.setNewName(newName);
        }
        if (newMetadata != null) {
            req.setNewMetadata(mergeEmbeddingFunction(newMetadata));
        }
        Object resp = api.updateCollection(req, this.collectionId);
        if (newName != null) {
            this.collectionName = newName;
        }
        this.fetch(); // keep this handle in sync with the server after the update
        return resp;
    }

    /** Copies {@code newMetadata}, filling in {@code embedding_function} when it is missing. */
    private Map<String, Object> mergeEmbeddingFunction(Map<String, Object> newMetadata) {
        Map<String, Object> merged = new HashMap<>(newMetadata);
        if (!merged.containsKey(EMBEDDING_FUNCTION_KEY)) {
            if (embeddingFunction != null) {
                merged.put(EMBEDDING_FUNCTION_KEY, embeddingFunction.getClass().getName());
            } else if (metadata != null && metadata.containsKey(EMBEDDING_FUNCTION_KEY)) {
                merged.put(EMBEDDING_FUNCTION_KEY, metadata.get(EMBEDDING_FUNCTION_KEY));
            }
        }
        return merged;
    }

    /**
     * Updates existing records.
     * <p>
     * When {@code embeddings} is {@code null} and documents are given, the documents are embedded
     * with this collection's embedding function. When both are {@code null}, no embeddings are
     * sent, so metadata can be updated on its own.
     *
     * @throws ChromaException if embeddings cannot be computed or the request fails
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Object updateEmbeddings(List<Embedding> embeddings, List<Map<String, String>> metadatas, List<String> documents, List<String> ids) throws ChromaException {
        UpdateEmbedding req = new UpdateEmbedding();
        if (embeddings != null || documents != null) {
            req.setEmbeddings(resolveEmbeddings(embeddings, documents).stream().map(Embedding::asArray).collect(Collectors.toList()));
        }
        req.setDocuments(documents);
        req.setMetadatas((List) (Object) metadatas);
        req.setIds(ids);
        try {
            return api.update(req, this.collectionId);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /**
     * Finds the nearest neighbours of {@code queryTexts}, which are embedded with this
     * collection's embedding function.
     *
     * @throws ChromaException if the collection has no embedding function, embedding fails,
     *                         or the request fails
     */
    public QueryResponse query(List<String> queryTexts, Integer nResults, Map<String, Object> where, Map<String, Object> whereDocument, List<QueryEmbedding.IncludeEnum> include) throws ChromaException {
        QueryEmbedding body = new QueryEmbedding();
        body.queryEmbeddings(requireEmbeddingFunction().embedDocuments(queryTexts).stream().map(Embedding::asArray).collect(Collectors.toList()));
        body.nResults(nResults);
        body.include(include);
        if (where != null) {
            body.where(new HashMap<>(where));
        }
        if (whereDocument != null) {
            body.whereDocument(new HashMap<>(whereDocument));
        }
        try {
            // Round-trip through JSON to map the untyped API response onto QueryResponse.
            String json = gson.toJson(api.getNearestNeighbors(body, this.collectionId));
            return gson.fromJson(json, QueryResponse.class);
        } catch (ApiException e) {
            throw new ChromaException(e);
        }
    }

    /** Query results: one inner list per query text. */
    public static class QueryResponse {
        @SerializedName("documents")
        private List<List<String>> documents;
        @SerializedName("embeddings")
        private List<List<Float>> embeddings;
        @SerializedName("ids")
        private List<List<String>> ids;
        @SerializedName("metadatas")
        private List<List<Map<String, Object>>> metadatas;
        @SerializedName("distances")
        private List<List<Float>> distances;

        public List<List<String>> getDocuments() {
            return documents;
        }

        public List<List<Float>> getEmbeddings() {
            return embeddings;
        }

        public List<List<String>> getIds() {
            return ids;
        }

        public List<List<Map<String, Object>>> getMetadatas() {
            return metadatas;
        }

        public List<List<Float>> getDistances() {
            return distances;
        }

        @Override
        public String toString() {
            return gson.toJson(this);
        }
    }

    /** Result of {@link Collection#get}. */
    public static class GetResult {
        @SerializedName("documents")
        private List<String> documents;
        @SerializedName("embeddings")
        private List<Float> embeddings;
        @SerializedName("ids")
        private List<String> ids;
        @SerializedName("metadatas")
        private List<Map<String, Object>> metadatas;

        public List<String> getDocuments() {
            return documents;
        }

        public List<Float> getEmbeddings() {
            return embeddings;
        }

        public List<String> getIds() {
            return ids;
        }

        public List<Map<String, Object>> getMetadatas() {
            return metadatas;
        }

        @Override
        public String toString() {
            return gson.toJson(this);
        }
    }
}

--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/Constants.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb;

import okhttp3.MediaType;

/**
 * Shared constants: embedding-function configuration keys, the JSON media type, and the
 * HTTP agent string sent by embedding functions.
 */
public class Constants {

    /** Config key for an embedding provider's base API URL. */
    public static final String EF_PARAMS_BASE_API = "baseAPI";
    /** Config key for the model name. */
    public static final String EF_PARAMS_MODEL = "modelName";
    /** Config key for an explicitly supplied API key. */
    public static final String EF_PARAMS_API_KEY = "apiKey";
    /** Config key for an API key read from an environment variable. */
    public static final String EF_PARAMS_API_KEY_FROM_ENV = "envAPIKey";
    /** Default environment variable name from which the model name is read. */
    public static final String MODEL_NAME = "MODEL_NAME";
    public static final MediaType JSON = MediaType.parse("application/json; charset=utf-8");
    public static final String HTTP_AGENT = "chroma-java-client";
}

--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/EFException.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb;

/**
 * This exception encapsulates all exceptions thrown by the EmbeddingFunction class.
 */
public class EFException extends ChromaException {
    public EFException(String message) {
        super(message);
    }

    public EFException(String message, Throwable cause) {
        super(message, cause);
    }

    public EFException(Throwable cause) {
        super(cause);
    }
}

--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/Embedding.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A single embedding vector, stored as a {@code float[]}.
 */
public class Embedding {
    private final float[] embedding;

    /**
     * Wraps {@code embeddings} directly (the array is not copied).
     */
    public Embedding(float[] embeddings) {
        this.embedding = embeddings;
    }

    /**
     * Creates an embedding from a list of numbers, converting each element with
     * {@link Number#floatValue()}. Any numeric element type is accepted, including integers.
     */
    public Embedding(List<? extends Number> embedding) {
        this.embedding = new float[embedding.size()];
        int i = 0;
        // Iterate rather than index so non-random-access lists stay O(n).
        for (Number value : embedding) {
            this.embedding[i++] = value.floatValue();
        }
    }

    /** Returns a new boxed list copy of the vector. */
    public List<Float> asList() {
        return IntStream.range(0, embedding.length)
                .mapToObj(i -> embedding[i])
                .collect(Collectors.toList());
    }

    public int getDimensions() {
        return embedding.length;
    }

    /** Returns the backing array itself (not a copy). */
    public float[] asArray() {
        return embedding;
    }

    public static Embedding fromList(List<? extends Number> embedding) {
        return new Embedding(embedding);
    }

    public static Embedding fromArray(float[] embedding) {
        return new Embedding(embedding);
    }
}

--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/embeddings/DefaultEmbeddingFunction.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb.embeddings;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.onnxruntime.*;

import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;

import org.apache.commons.compress.archivers.tar.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Floats;
import tech.amikos.chromadb.EFException;
import tech.amikos.chromadb.Embedding;

import java.io.*;
import java.net.URL;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.*;

/**
 * Local embedding function backed by the all-MiniLM-L6-v2 ONNX model.
 * <p>
 * On first use, the model archive is downloaded into {@link #MODEL_CACHE_DIR}, verified
 * against a SHA-256 checksum, and extracted. Embeddings are computed by masked mean pooling
 * over the model's first output, followed by L2 normalization.
 */
public class DefaultEmbeddingFunction implements EmbeddingFunction {
    public static final String MODEL_NAME = "all-MiniLM-L6-v2";
    private static final String ARCHIVE_FILENAME = "onnx.tar.gz";
    private static final String MODEL_DOWNLOAD_URL = "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz";
    private static final String MODEL_SHA256_CHECKSUM = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3";
    public static final Path MODEL_CACHE_DIR = Paths.get(System.getProperty("user.home"), ".cache", "chroma", "onnx_models", MODEL_NAME);
    private static final Path modelPath = Paths.get(MODEL_CACHE_DIR.toString(), "onnx");
    private static final Path modelFile = Paths.get(modelPath.toString(), "model.onnx");
    private final HuggingFaceTokenizer tokenizer;
    private final OrtEnvironment env;
    final OrtSession session;

    /**
     * L2-normalizes each row of {@code v}. A row whose norm is zero is divided by
     * {@code 1e-12f} instead, which leaves it all zeros.
     *
     * @param v matrix to normalize; not modified
     * @return a new matrix of normalized rows (empty when {@code v} is empty)
     */
    public static float[][] normalize(float[][] v) {
        float[][] normalized = new float[v.length][];
        for (int i = 0; i < v.length; i++) {
            float[] row = v[i];
            float sum = 0;
            for (float x : row) {
                sum += x * x;
            }
            float norm = (float) Math.sqrt(sum);
            if (norm == 0) {
                norm = 1e-12f;
            }
            float[] out = new float[row.length];
            for (int j = 0; j < row.length; j++) {
                out[j] = row[j] / norm;
            }
            normalized[i] = out;
        }
        return normalized;
    }

    /**
     * Loads the tokenizer and ONNX session, downloading the model first if it is missing.
     *
     * @throws EFException if the model cannot be downloaded, verified or loaded
     */
    public DefaultEmbeddingFunction() throws EFException {
        if (!validateModel()) {
            downloadAndSetupModel();
        }

        Map<String, String> config = new HashMap<>();
        config.put("padding", "MAX_LENGTH");
        config.put("maxLength", "256");
        Map<String, String> tokenizerConfig = Collections.unmodifiableMap(config);

        try {
            tokenizer = HuggingFaceTokenizer.newInstance(modelPath, tokenizerConfig);

            this.env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();
            this.session = env.createSession(modelFile.toString(), options);
        } catch (OrtException | IOException e) {
            throw new EFException(e);
        }
    }

    /**
     * Runs the model on {@code documents} and returns one pooled, L2-normalized vector per document.
     *
     * @param documents texts to embed
     * @return one embedding per input document, in input order; empty for empty input
     * @throws OrtException if ONNX Runtime fails
     */
    public List<List<Float>> forward(List<String> documents) throws OrtException {
        if (documents.isEmpty()) {
            return new ArrayList<>();
        }
        Encoding[] encodings = tokenizer.batchEncode(documents, true, false);
        int batchSize = encodings.length;
        int seqLen = 0;
        for (Encoding encoding : encodings) {
            seqLen = Math.max(seqLen, encoding.getIds().length);
        }
        // Flatten row-major into primitive arrays. Any shorter row stays zero-padded, with mask 0.
        long[] inputIds = new long[batchSize * seqLen];
        long[] attentionMask = new long[batchSize * seqLen];
        long[] tokenTypeIds = new long[batchSize * seqLen];
        for (int row = 0; row < batchSize; row++) {
            int offset = row * seqLen;
            long[] ids = encodings[row].getIds();
            long[] mask = encodings[row].getAttentionMask();
            long[] types = encodings[row].getTypeIds();
            System.arraycopy(ids, 0, inputIds, offset, ids.length);
            System.arraycopy(mask, 0, attentionMask, offset, mask.length);
            System.arraycopy(types, 0, tokenTypeIds, offset, types.length);
        }
        long[] inputShape = new long[]{batchSize, seqLen};

        INDArray lastHiddenState;
        // Input tensors hold native memory and must be closed once the session has run.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputIds), inputShape);
             OnnxTensor attentionTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attentionMask), inputShape);
             OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(tokenTypeIds), inputShape)) {
            // Inputs for all-MiniLM-L6-v2 model
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", inputTensor);
            inputs.put("attention_mask", attentionTensor);
            inputs.put("token_type_ids", tokenTypeTensor);
            try (OrtSession.Result results = session.run(inputs)) {
                lastHiddenState = Nd4j.create((float[][][]) results.get(0).getValue());
            }
        }

        // Masked mean pooling: sum token vectors where mask == 1, divide by the clipped token count.
        INDArray attMask = Nd4j.create(Arrays.stream(attentionMask).asDoubleStream().toArray(), inputShape, 'c');
        INDArray expandedMask = Nd4j.expandDims(attMask, 2).broadcast(lastHiddenState.shape());
        INDArray summed = lastHiddenState.mul(expandedMask).sum(1);
        INDArray[] clippedSumMask = Nd4j.getExecutioner().exec(
                new ClipByValue(expandedMask.sum(1), 1e-9, Double.MAX_VALUE)
        );
        INDArray embeddings = summed.div(clippedSumMask[0]);
        float[][] embeddingsArray = normalize(embeddings.toFloatMatrix());
        List<List<Float>> embeddingsList = new ArrayList<>(embeddingsArray.length);
        for (float[] embedding : embeddingsArray) {
            embeddingsList.add(Floats.asList(embedding));
        }
        return embeddingsList;
    }

    /** Returns the lowercase hex SHA-256 digest of the file at {@code filePath}. */
    private static String getSHA256Checksum(String filePath) throws IOException, NoSuchAlgorithmException {
        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        try (FileInputStream fis = new FileInputStream(filePath)) {
            byte[] byteArray = new byte[8192];
            int bytesCount;
            while ((bytesCount = fis.read(byteArray)) != -1) {
                digest.update(byteArray, 0, bytesCount);
            }
        }
        StringBuilder sb = new StringBuilder();
        for (byte b : digest.digest()) {
            sb.append(String.format("%02x", b));
        }
        return sb.toString();
    }

    /**
     * Extracts a {@code .tar.gz} archive into {@code extractDir}. Entries whose path would
     * resolve outside {@code extractDir} are rejected.
     */
    private static void extractTarGz(Path tarGzPath, Path extractDir) throws IOException {
        Path root = extractDir.toAbsolutePath().normalize();
        try (InputStream fileIn = Files.newInputStream(tarGzPath);
             GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
             TarArchiveInputStream tarIn = new TarArchiveInputStream(gzipIn)) {

            TarArchiveEntry entry;
            while ((entry = tarIn.getNextTarEntry()) != null) {
                Path entryPath = root.resolve(entry.getName()).normalize();
                // Guard against "../" or absolute entry names (path traversal).
                if (!entryPath.startsWith(root)) {
                    throw new IOException("Archive entry is outside of the target directory: " + entry.getName());
                }
                if (entry.isDirectory()) {
                    Files.createDirectories(entryPath);
                } else {
                    Files.createDirectories(entryPath.getParent());
                    // tarIn reports end-of-stream at the end of the current entry.
                    Files.copy(tarIn, entryPath, StandardCopyOption.REPLACE_EXISTING);
                }
            }
        }
    }

    /**
     * Downloads the model archive if it is not cached, verifies its checksum, extracts it and
     * deletes the archive. The download goes to a temporary ".part" file first, so an
     * interrupted download never leaves a truncated archive under the final name.
     *
     * @throws EFException on I/O failure or checksum mismatch
     */
    private void downloadAndSetupModel() throws EFException {
        Path archivePath = Paths.get(MODEL_CACHE_DIR.toString(), ARCHIVE_FILENAME);
        try {
            Files.createDirectories(MODEL_CACHE_DIR);
            if (!Files.exists(archivePath)) {
                System.err.println("Model not found under " + archivePath + ". Downloading...");
                Path partialPath = Paths.get(MODEL_CACHE_DIR.toString(), ARCHIVE_FILENAME + ".part");
                // Only open the network stream when a download is actually needed.
                try (InputStream in = new URL(MODEL_DOWNLOAD_URL).openStream()) {
                    Files.copy(in, partialPath, StandardCopyOption.REPLACE_EXISTING);
                }
                Files.move(partialPath, archivePath, StandardCopyOption.REPLACE_EXISTING);
            }
            if (!MODEL_SHA256_CHECKSUM.equals(getSHA256Checksum(archivePath.toString()))) {
                throw new EFException("Checksum does not match. Delete the whole directory " + MODEL_CACHE_DIR + " and try again.");
            }
            extractTarGz(archivePath, MODEL_CACHE_DIR);
            // Best effort: a leftover archive is re-verified (not re-downloaded) next time.
            archivePath.toFile().delete();
        } catch (IOException | NoSuchAlgorithmException e) {
            throw new EFException(e);
        }
    }

    /**
     * Check if the model is present at the expected location.
     *
     * @return {@code true} if {@code model.onnx} exists as a regular file
     */
    private boolean validateModel() {
        return modelFile.toFile().exists() && modelFile.toFile().isFile();
    }

    @Override
    public Embedding embedQuery(String query) throws EFException {
        try {
            return Embedding.fromList(forward(Collections.singletonList(query)).get(0));
        } catch (OrtException e) {
            throw new EFException(e);
        }
    }

    @Override
    public List<Embedding> embedDocuments(List<String> documents) throws EFException {
        try {
            return forward(documents).stream().map(Embedding::new).collect(Collectors.toList());
        } catch (OrtException e) {
            throw new EFException(e);
        }
    }

    @Override
    public List<Embedding> embedDocuments(String[] documents) throws EFException {
        return embedDocuments(Arrays.asList(documents));
    }
}
--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/embeddings/EmbeddingFunction.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb.embeddings;

import tech.amikos.chromadb.EFException;
import tech.amikos.chromadb.Embedding;

import java.util.Arrays;
import java.util.List;

/**
 * Computes vector embeddings for text.
 */
public interface EmbeddingFunction {

    /**
     * Embeds a single query string.
     */
    Embedding embedQuery(String query) throws EFException;

    /**
     * Embeds each document, returning one embedding per document.
     */
    List<Embedding> embedDocuments(List<String> documents) throws EFException;

    /**
     * Array convenience overload; by default delegates to {@link #embedDocuments(List)}.
     */
    default List<Embedding> embedDocuments(String[] documents) throws EFException {
        return embedDocuments(Arrays.asList(documents));
    }
}

--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/embeddings/WithParam.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb.embeddings;


import tech.amikos.chromadb.Constants;
import tech.amikos.chromadb.EFException;

import java.util.Map;

/**
 * A configuration option for an embedding function. Each option writes one entry into the
 * function's parameter map under a key from {@link Constants}.
 */
public abstract class WithParam {
    /**
     * Writes this option into {@code params}.
     *
     * @throws EFException if the option's value cannot be resolved
     */
    public abstract void apply(Map<String, Object> params) throws EFException;

    public static WithParam apiKey(String apiKey) {
        return new WithAPIKey(apiKey);
    }

    public static WithParam apiKeyFromEnv(String apiKeyEnvVarName) {
        return new WithEnvAPIKey(apiKeyEnvVarName);
    }

    public static WithParam model(String model) {
        return new WithModel(model);
    }

    public static WithParam modelFromEnv(String modelEnvVarName) {
        return new WithModelFromEnv(modelEnvVarName);
    }

    public static WithParam baseAPI(String baseAPI) {
        return new WithBaseAPI(baseAPI);
    }

    public static WithParam defaultModel(String model) {
        return new WithDefaultModel(model);
    }
}

/** Sets {@link Constants#EF_PARAMS_BASE_API}. */
class WithBaseAPI extends WithParam {
    private final String baseAPI;

    public WithBaseAPI(String baseAPI) {
        this.baseAPI = baseAPI;
    }

    @Override
    public void apply(Map<String, Object> params) {
        params.put(Constants.EF_PARAMS_BASE_API, baseAPI);
    }
}

/** Sets {@link Constants#EF_PARAMS_MODEL}. */
class WithModel extends WithParam {
    private final String model;

    public WithModel(String model) {
        this.model = model;
    }

    @Override
    public void apply(Map<String, Object> params) {
        params.put(Constants.EF_PARAMS_MODEL, model);
    }
}

/** Sets {@link Constants#EF_PARAMS_MODEL} from an environment variable. */
class WithModelFromEnv extends WithParam {

    private final String modelEnvVarName;

    public WithModelFromEnv(String modelEnvVarName) {
        this.modelEnvVarName = modelEnvVarName;
    }

    /**
     * Reads MODEL_NAME from the environment
     */
    public WithModelFromEnv() {
        this(Constants.MODEL_NAME);
    }

    @Override
    public void apply(Map<String, Object> params) throws EFException {
        String model = System.getenv(modelEnvVarName);
        if (model == null) {
            throw new EFException("Model not found in environment variable: " + modelEnvVarName);
        }
        params.put(Constants.EF_PARAMS_MODEL, model);
    }
}

/** Sets {@link Constants#EF_PARAMS_MODEL}; used for per-function defaults. */
class WithDefaultModel extends WithParam {

    private final String model;

    public WithDefaultModel(String model) {
        this.model = model;
    }

    @Override
    public void apply(Map<String, Object> params) {
        params.put(Constants.EF_PARAMS_MODEL, model);
    }
}

/** Sets {@link Constants#EF_PARAMS_API_KEY}. */
class WithAPIKey extends WithParam {
    private final String apiKey;

    public WithAPIKey(String apiKey) {
        this.apiKey = apiKey;
    }

    @Override
    public void apply(Map<String, Object> params) {
        params.put(Constants.EF_PARAMS_API_KEY, apiKey);
    }
}

/** Sets {@link Constants#EF_PARAMS_API_KEY_FROM_ENV} from an environment variable. */
class WithEnvAPIKey extends WithParam {
    private final String apiKeyEnvVarName;

    public WithEnvAPIKey(String apiKeyEnvVarName) {
        this.apiKeyEnvVarName = apiKeyEnvVarName;
    }

    @Override
    public void apply(Map<String, Object> params) throws EFException {
        String apiKey = System.getenv(apiKeyEnvVarName);
        if (apiKey == null) {
            throw new EFException("API Key not found in environment variable: " + apiKeyEnvVarName);
        }
        params.put(Constants.EF_PARAMS_API_KEY_FROM_ENV, apiKey);
    }
}
--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/embeddings/cohere/CohereEmbeddingFunction.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.Gson;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import org.jetbrains.annotations.NotNull;
import tech.amikos.chromadb.*;
import tech.amikos.chromadb.embeddings.EmbeddingFunction;
import tech.amikos.chromadb.embeddings.WithParam;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static tech.amikos.chromadb.Constants.JSON;

/**
 * Embedding function that calls Cohere's {@code embed} HTTP endpoint.
 */
public class CohereEmbeddingFunction implements EmbeddingFunction {
    public static final String DEFAULT_MODEL_NAME = "embed-english-v2.0";
    public static final String DEFAULT_BASE_API = "https://api.cohere.ai/v1/";
    public static final String COHERE_API_KEY_ENV = "COHERE_API_KEY";

    private final OkHttpClient client = new OkHttpClient();
    private final Map<String, Object> configParams = new HashMap<>();
    private static final Gson gson = new Gson();


    private static final List<WithParam> defaults = Arrays.asList(
            WithParam.baseAPI(DEFAULT_BASE_API),
            WithParam.defaultModel(DEFAULT_MODEL_NAME)
    );

    /**
     * Uses the default base API and model, and reads the API key from {@code COHERE_API_KEY}.
     *
     * @throws EFException if the environment variable is not set
     */
    public CohereEmbeddingFunction() throws EFException {
        for (WithParam param : defaults) {
            param.apply(this.configParams);
        }
        WithParam.apiKeyFromEnv(COHERE_API_KEY_ENV).apply(this.configParams);
    }

    /**
     * Applies the defaults, then {@code params} in order (later params override earlier ones).
     *
     * @throws EFException if any parameter cannot be applied
     */
    public CohereEmbeddingFunction(WithParam... params) throws EFException {
        // apply defaults
        for (WithParam param : defaults) {
            param.apply(this.configParams);
        }
        for (WithParam param : params) {
            param.apply(this.configParams);
        }
    }

    /**
     * Returns the configured API key. An explicit key is preferred over one read from the
     * environment; {@code WithParam.apiKeyFromEnv} stores its value under a separate key.
     */
    private String resolveApiKey() throws EFException {
        Object apiKey = configParams.get(Constants.EF_PARAMS_API_KEY);
        if (apiKey == null) {
            apiKey = configParams.get(Constants.EF_PARAMS_API_KEY_FROM_ENV);
        }
        if (apiKey == null) {
            throw new EFException("Cohere API key is not configured; use WithParam.apiKey(...) or WithParam.apiKeyFromEnv(...)");
        }
        return apiKey.toString();
    }

    /**
     * Sends {@code req} to the {@code embed} endpoint and parses the JSON response.
     *
     * @throws EFException if the API key is missing, the call fails, or the response is not successful
     */
    public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException {
        Request request = new Request.Builder()
                .url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + "embed")
                .post(RequestBody.create(req.json(), JSON))
                .addHeader("Accept", "application/json")
                .addHeader("Content-Type", "application/json")
                .addHeader("X-Client-Name", Constants.HTTP_AGENT)
                .addHeader("User-Agent", Constants.HTTP_AGENT)
                .addHeader("Authorization", "Bearer " + resolveApiKey())
                .build();
        try (Response response = client.newCall(request).execute()) {
            if (!response.isSuccessful()) {
                throw new IOException("Unexpected code " + response);
            }
            ResponseBody body = response.body();
            if (body == null) {
                throw new IOException("Empty response body from " + request.url());
            }
            return gson.fromJson(body.string(), CreateEmbeddingResponse.class);
        } catch (IOException e) {
            throw new EFException(e);
        }
    }

    @Override
    public Embedding embedQuery(String query) throws EFException {
        CreateEmbeddingResponse response = createEmbedding(
                new CreateEmbeddingRequest()
                        .model(this.configParams.get(Constants.EF_PARAMS_MODEL).toString())
                        .inputType("search_query")
                        .texts(new String[]{query})
        );
        return new Embedding(response.getEmbeddings().get(0));
    }

    @Override
    public List<Embedding> embedDocuments(@NotNull List<String> documents) throws EFException {
        CreateEmbeddingResponse response = createEmbedding(
                new CreateEmbeddingRequest()
                        .model(this.configParams.get(Constants.EF_PARAMS_MODEL).toString())
                        .inputType("search_document")
                        .texts(documents.toArray(new String[0]))
        );
        return response.getEmbeddings().stream().map(Embedding::new).collect(Collectors.toList());
    }

    @Override
    public List<Embedding> embedDocuments(String[] documents) throws EFException {
        return embedDocuments(Arrays.asList(documents));
    }
}

--------------------------------------------------------------------------------
/src/main/java/tech/amikos/chromadb/embeddings/cohere/CreateEmbeddingRequest.java:
--------------------------------------------------------------------------------
package tech.amikos.chromadb.embeddings.cohere;

import com.google.gson.*;
import com.google.gson.annotations.SerializedName;

import java.lang.reflect.Type;

/**
 * Request body for Cohere's {@code embed} endpoint, with a fluent setter per field.
 */
public class CreateEmbeddingRequest {
    /** Serializer built once and reused by {@link #json()}; static, so never serialized itself. */
    private static final Gson GSON = new GsonBuilder()
            .registerTypeAdapter(TruncateMode.class, new TruncateModeSerializer())
            .create();

    @SerializedName("model")
    private String model = "embed-english-v2.0";

    @SerializedName("texts")
    private String[] texts;

    @SerializedName("truncate")
    private TruncateMode truncateMode = TruncateMode.END;
    @SerializedName("compress")
    private Boolean compress;

    @SerializedName("compression_codebook")
    private String compressionCodebook;

    @SerializedName("input_type")
    private String inputType = "search_document";

    public enum TruncateMode {
        NONE,
        START,
        END,
    }

    /** Serializes {@link TruncateMode} as its upper-case name. */
    public static class TruncateModeSerializer implements JsonSerializer<TruncateMode> {
        @Override
        public JsonElement serialize(TruncateMode src, Type typeOfSrc, JsonSerializationContext context) {
            return new JsonPrimitive(src.toString().toUpperCase());
        }
    }

    //create fluent api for all fields
    public CreateEmbeddingRequest model(String model) {
        this.model = model;
        return this;
    }

    public CreateEmbeddingRequest texts(String[] texts) {
        this.texts = texts;
        return this;
    }

    public CreateEmbeddingRequest truncateMode(TruncateMode truncateMode) {
        this.truncateMode = truncateMode;
        return this;
    }

    public CreateEmbeddingRequest compress(Boolean compress) {
        this.compress = compress;
        return this;
    }

    public CreateEmbeddingRequest compressionCodebook(String compressionCodebook) {
        this.compressionCodebook = compressionCodebook;
        return this;
    }

    public CreateEmbeddingRequest inputType(String inputType) {
        this.inputType = inputType;
        return this;
    }

    /** Returns this request as a JSON string. */
    public String json() {
        return GSON.toJson(this);
    }

    @Override
    public String toString() {
        return json();
    }
}

-------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/cohere/CreateEmbeddingResponse.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.cohere; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.annotations.SerializedName; 5 | import com.google.gson.internal.LinkedTreeMap; 6 | 7 | import java.util.List; 8 | 9 | public class CreateEmbeddingResponse { 10 | 11 | @SerializedName("id") 12 | public String id; 13 | 14 | @SerializedName("texts") 15 | public String[] texts; 16 | 17 | @SerializedName("embeddings") 18 | public List> embeddings; 19 | 20 | @SerializedName("meta") 21 | public LinkedTreeMap meta; 22 | 23 | // create getters for all fields 24 | 25 | public String getId() { 26 | return id; 27 | } 28 | 29 | public String[] getTexts() { 30 | return texts; 31 | } 32 | 33 | public List> getEmbeddings() { 34 | return embeddings; 35 | } 36 | 37 | public LinkedTreeMap getMeta() { 38 | return meta; 39 | } 40 | 
41 | 42 | @Override 43 | public String toString() { 44 | return new Gson().toJson(this); 45 | } 46 | 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/hf/CreateEmbeddingRequest.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.hf; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.annotations.SerializedName; 5 | 6 | import java.util.HashMap; 7 | 8 | public class CreateEmbeddingRequest { 9 | 10 | @SerializedName("inputs") 11 | private String[] inputs; 12 | @SerializedName("options") 13 | private HashMap options; 14 | 15 | public CreateEmbeddingRequest inputs(String[] inputs) { 16 | this.inputs = inputs; 17 | return this; 18 | } 19 | 20 | public CreateEmbeddingRequest options(HashMap options) { 21 | this.options = options; 22 | return this; 23 | } 24 | 25 | public String[] getInputs() { 26 | return inputs; 27 | } 28 | 29 | public HashMap getOptions() { 30 | return options; 31 | } 32 | 33 | public String toString() { 34 | return this.json(); 35 | } 36 | 37 | public String json() { 38 | return new Gson().toJson(this); 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/hf/CreateEmbeddingResponse.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.hf; 2 | 3 | import com.google.gson.Gson; 4 | 5 | import java.util.List; 6 | 7 | public class CreateEmbeddingResponse { 8 | public List> embeddings; 9 | 10 | public List> getEmbeddings() { 11 | return embeddings; 12 | } 13 | 14 | public CreateEmbeddingResponse(List> embeddings) { 15 | this.embeddings = embeddings; 16 | } 17 | 18 | @Override 19 | public String toString() { 20 | return new Gson().toJson(this); 21 | } 22 | } 23 | 
-------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/hf/HuggingFaceEmbeddingFunction.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.hf; 2 | 3 | 4 | import com.google.gson.Gson; 5 | import okhttp3.*; 6 | import org.jetbrains.annotations.NotNull; 7 | import tech.amikos.chromadb.*; 8 | import tech.amikos.chromadb.embeddings.EmbeddingFunction; 9 | import tech.amikos.chromadb.embeddings.WithParam; 10 | 11 | import java.io.IOException; 12 | import java.util.*; 13 | import java.util.stream.Collectors; 14 | 15 | import static tech.amikos.chromadb.Constants.JSON; 16 | 17 | public class HuggingFaceEmbeddingFunction implements EmbeddingFunction { 18 | public static final String DEFAULT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"; 19 | public static final String DEFAULT_BASE_API = "https://api-inference.huggingface.co/pipeline/feature-extraction/"; 20 | public static final String HFEI_API_PATH = "/embed"; 21 | public static final String HF_API_KEY_ENV = "HF_API_KEY"; 22 | public static final String API_TYPE_CONFIG_KEY = "apiType"; 23 | private final OkHttpClient client = new OkHttpClient(); 24 | private final Map configParams = new HashMap<>(); 25 | private static final Gson gson = new Gson(); 26 | 27 | private static final List defaults = Arrays.asList( 28 | new WithAPIType(APIType.HF_API), 29 | WithParam.baseAPI(DEFAULT_BASE_API), 30 | WithParam.defaultModel(DEFAULT_MODEL_NAME) 31 | ); 32 | 33 | public HuggingFaceEmbeddingFunction() throws EFException { 34 | for (WithParam param : defaults) { 35 | param.apply(this.configParams); 36 | } 37 | WithParam.apiKeyFromEnv(HF_API_KEY_ENV).apply(this.configParams); 38 | } 39 | 40 | public HuggingFaceEmbeddingFunction(WithParam... 
params) throws EFException { 41 | // apply defaults 42 | 43 | for (WithParam param : defaults) { 44 | param.apply(this.configParams); 45 | } 46 | for (WithParam param : params) { 47 | param.apply(this.configParams); 48 | } 49 | } 50 | 51 | public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException { 52 | Request.Builder rb = new Request.Builder() 53 | 54 | .post(RequestBody.create(req.json(), JSON)) 55 | .addHeader("Accept", "application/json") 56 | .addHeader("Content-Type", "application/json") 57 | .addHeader("User-Agent", Constants.HTTP_AGENT); 58 | if (configParams.containsKey(API_TYPE_CONFIG_KEY) && configParams.get(API_TYPE_CONFIG_KEY).equals(APIType.HFEI_API)) { 59 | rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + HFEI_API_PATH); 60 | } else { 61 | rb.url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString() + this.configParams.get(Constants.EF_PARAMS_MODEL).toString()); 62 | } 63 | if (configParams.containsKey(Constants.EF_PARAMS_API_KEY)) { 64 | rb.addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString()); 65 | } 66 | Request request = rb.build(); 67 | try (Response response = client.newCall(request).execute()) { 68 | if (!response.isSuccessful()) { 69 | throw new IOException("Unexpected code " + response); 70 | } 71 | 72 | String responseData = response.body().string(); 73 | 74 | List parsedResponse = gson.fromJson(responseData, List.class); 75 | 76 | return new CreateEmbeddingResponse(parsedResponse); 77 | } catch (IOException e) { 78 | throw new EFException(e); 79 | } 80 | } 81 | 82 | @Override 83 | public Embedding embedQuery(String query) throws EFException { 84 | CreateEmbeddingResponse response = this.createEmbedding(new CreateEmbeddingRequest().inputs(new String[]{query})); 85 | return new Embedding(response.getEmbeddings().get(0)); 86 | } 87 | 88 | @Override 89 | public List embedDocuments(@NotNull List documents) throws EFException { 90 | 
CreateEmbeddingResponse response = this.createEmbedding(new CreateEmbeddingRequest().inputs(documents.toArray(new String[0]))); 91 | return response.getEmbeddings().stream().map(Embedding::fromList).collect(Collectors.toList()); 92 | } 93 | 94 | @Override 95 | public List embedDocuments(String[] documents) throws EFException { 96 | CreateEmbeddingResponse response = this.createEmbedding(new CreateEmbeddingRequest().inputs(documents)); 97 | return response.getEmbeddings().stream().map(Embedding::fromList).collect(Collectors.toList()); 98 | } 99 | 100 | public static class WithAPIType extends WithParam { 101 | private final APIType apiType; 102 | 103 | public WithAPIType(APIType apitype) { 104 | this.apiType = apitype; 105 | } 106 | 107 | @Override 108 | public void apply(Map params) { 109 | params.put(API_TYPE_CONFIG_KEY, apiType); 110 | } 111 | } 112 | 113 | public enum APIType{ 114 | HF_API, 115 | HFEI_API 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/ollama/CreateEmbeddingRequest.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.ollama; 2 | 3 | import com.google.gson.*; 4 | import com.google.gson.annotations.SerializedName; 5 | 6 | import java.lang.reflect.Type; 7 | 8 | public class CreateEmbeddingRequest { 9 | @SerializedName("model") 10 | private String model = "embed-english-v2.0"; 11 | 12 | @SerializedName("input") 13 | private String[] input; 14 | 15 | 16 | 17 | //create fluent api for all fields 18 | public CreateEmbeddingRequest model(String model) { 19 | this.model = model; 20 | return this; 21 | } 22 | 23 | public CreateEmbeddingRequest input(String[] input) { 24 | this.input = input; 25 | return this; 26 | } 27 | 28 | 29 | public String json() { 30 | GsonBuilder gsonBuilder = new GsonBuilder(); 31 | Gson customGson = gsonBuilder.create(); 32 | return 
customGson.toJson(this); 33 | } 34 | 35 | public String toString() { 36 | return json(); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/ollama/CreateEmbeddingResponse.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.ollama; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.annotations.SerializedName; 5 | import com.google.gson.internal.LinkedTreeMap; 6 | 7 | import java.util.List; 8 | 9 | public class CreateEmbeddingResponse { 10 | 11 | @SerializedName("model") 12 | public String model; 13 | 14 | @SerializedName("embeddings") 15 | public List> embeddings; 16 | 17 | // create getters for all fields 18 | 19 | public String getModel() { 20 | return model; 21 | } 22 | 23 | 24 | public List> getEmbeddings() { 25 | return embeddings; 26 | } 27 | 28 | 29 | @Override 30 | public String toString() { 31 | return new Gson().toJson(this); 32 | } 33 | 34 | 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/ollama/OllamaEmbeddingFunction.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.ollama; 2 | 3 | import com.google.gson.Gson; 4 | import okhttp3.*; 5 | import tech.amikos.chromadb.*; 6 | import tech.amikos.chromadb.embeddings.EmbeddingFunction; 7 | import tech.amikos.chromadb.embeddings.WithParam; 8 | 9 | import java.io.IOException; 10 | import java.util.*; 11 | import java.util.stream.Collectors; 12 | 13 | import static tech.amikos.chromadb.Constants.JSON; 14 | 15 | public class OllamaEmbeddingFunction implements EmbeddingFunction { 16 | public final static String DEFAULT_BASE_API = "http://localhost:11434/api/embed"; 17 | public final static String DEFAULT_MODEL_NAME = "nomic-embed-text"; 18 | private final 
OkHttpClient client = new OkHttpClient(); 19 | private final Gson gson = new Gson(); 20 | private final Map configParams = new HashMap<>(); 21 | 22 | private static final List defaults = Arrays.asList( 23 | WithParam.baseAPI(DEFAULT_BASE_API), 24 | WithParam.defaultModel(DEFAULT_MODEL_NAME) 25 | ); 26 | 27 | public OllamaEmbeddingFunction() throws EFException { 28 | for (WithParam param : defaults) { 29 | param.apply(this.configParams); 30 | } 31 | } 32 | 33 | public OllamaEmbeddingFunction(WithParam... params) throws EFException { 34 | // apply defaults 35 | 36 | for (WithParam param : defaults) { 37 | param.apply(this.configParams); 38 | } 39 | for (WithParam param : params) { 40 | param.apply(this.configParams); 41 | } 42 | } 43 | 44 | private CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException { 45 | Request request = new Request.Builder() 46 | .url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString()) 47 | .post(RequestBody.create(req.json(), JSON)) 48 | .addHeader("Accept", "application/json") 49 | .addHeader("Content-Type", "application/json") 50 | .addHeader("User-Agent", Constants.HTTP_AGENT) 51 | .build(); 52 | try (Response response = client.newCall(request).execute()) { 53 | if (!response.isSuccessful()) { 54 | throw new IOException("Unexpected code " + response); 55 | } 56 | String responseData = response.body().string(); 57 | 58 | return gson.fromJson(responseData, CreateEmbeddingResponse.class); 59 | } catch (IOException e) { 60 | throw new EFException(e); 61 | } 62 | } 63 | 64 | @Override 65 | public Embedding embedQuery(String query) throws EFException { 66 | CreateEmbeddingResponse response = createEmbedding( 67 | new CreateEmbeddingRequest() 68 | .model(this.configParams.get(Constants.EF_PARAMS_MODEL).toString()) 69 | .input(new String[]{query}) 70 | ); 71 | return new Embedding(response.getEmbeddings().get(0)); 72 | } 73 | 74 | @Override 75 | public List embedDocuments(List documents) throws 
EFException { 76 | CreateEmbeddingResponse response = createEmbedding( 77 | new CreateEmbeddingRequest() 78 | .model(this.configParams.get(Constants.EF_PARAMS_MODEL).toString()) 79 | .input(documents.toArray(new String[0])) 80 | ); 81 | return response.getEmbeddings().stream().map(Embedding::new).collect(Collectors.toList()); 82 | } 83 | 84 | @Override 85 | public List embedDocuments(String[] documents) throws EFException { 86 | return embedDocuments(Arrays.asList(documents)); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/openai/CreateEmbeddingRequest.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.openai; 2 | 3 | import com.google.gson.*; 4 | import com.google.gson.annotations.SerializedName; 5 | import io.swagger.v3.oas.annotations.media.Schema; 6 | 7 | import java.lang.reflect.Type; 8 | import java.util.List; 9 | 10 | public class CreateEmbeddingRequest { 11 | @SerializedName("model") 12 | private String model = "text-embedding-ada-002"; 13 | @SerializedName("user") 14 | private String user = "java-client=0.0.1"; 15 | 16 | public static class Input { 17 | private String text; 18 | private String[] texts; 19 | private Integer[] integers; 20 | 21 | private List listOfListOfIntegers; 22 | 23 | public Input(String text) { 24 | this.text = text; 25 | } 26 | 27 | public Input(String[] texts) { 28 | this.texts = texts; 29 | } 30 | 31 | public Input(Integer[] integers) { 32 | this.integers = integers; 33 | } 34 | 35 | public Input(List listOfListOfIntegers) { 36 | this.listOfListOfIntegers = listOfListOfIntegers; 37 | } 38 | 39 | public Object serialize() { 40 | if (text != null) { 41 | return text; 42 | } else if (texts != null) { 43 | return texts; 44 | } else if (integers != null) { 45 | return integers; 46 | } else if (listOfListOfIntegers != null) { 47 | return listOfListOfIntegers; 
48 | } else { 49 | throw new RuntimeException("Invalid input"); 50 | } 51 | } 52 | } 53 | 54 | @SerializedName("input") 55 | private Input input; 56 | 57 | public CreateEmbeddingRequest() { 58 | } 59 | 60 | public CreateEmbeddingRequest(String model, String user) { 61 | this.model = model; 62 | this.user = user; 63 | } 64 | 65 | public CreateEmbeddingRequest(String model) { 66 | this.model = model; 67 | } 68 | 69 | public CreateEmbeddingRequest model(String model) { 70 | this.model = model; 71 | return this; 72 | } 73 | 74 | @Schema(example = "text-embedding-ada-002", required = true, description = "ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them. ") 75 | public String getModel() { 76 | return model; 77 | } 78 | 79 | public CreateEmbeddingRequest input(Input input) { 80 | this.input = input; 81 | return this; 82 | } 83 | 84 | public void setModel(String model) { 85 | this.model = model; 86 | } 87 | 88 | @Schema(example = "The quick brown fox jumped over the lazy dog", required = true, description = "Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. Each input must not exceed the max input tokens for the model (8191 tokens for `text-embedding-ada-002`). [Example Python code](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) for counting tokens. ") 89 | public Input getInput() { 90 | return input; 91 | } 92 | 93 | @Schema(example = "user-1234", description = "A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids). 
") 94 | public String getUser() { 95 | return user; 96 | } 97 | 98 | public CreateEmbeddingRequest user(String user) { 99 | this.user = user; 100 | return this; 101 | } 102 | 103 | public static class CreateEmbeddingRequestSerializer implements JsonSerializer { 104 | @Override 105 | public JsonElement serialize(CreateEmbeddingRequest req, Type type, JsonSerializationContext context) { 106 | JsonObject jsonObject = new JsonObject(); 107 | jsonObject.addProperty("model", req.getModel()); 108 | jsonObject.addProperty("user", req.getUser()); 109 | JsonElement input = context.serialize(req.getInput().serialize()); 110 | jsonObject.add("input", input); 111 | return jsonObject; 112 | } 113 | } 114 | 115 | public String json() { 116 | Gson gson = new GsonBuilder() 117 | .registerTypeAdapter(CreateEmbeddingRequest.class, new CreateEmbeddingRequestSerializer()) 118 | .create(); 119 | return gson.toJson(this); 120 | } 121 | 122 | public String toString() { 123 | return new Gson().toJson(this); 124 | } 125 | 126 | } 127 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/openai/CreateEmbeddingResponse.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.openai; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.annotations.SerializedName; 5 | 6 | import java.util.List; 7 | 8 | public class CreateEmbeddingResponse { 9 | 10 | @SerializedName("object") 11 | private String object; 12 | 13 | @SerializedName("data") 14 | private List data; 15 | 16 | @SerializedName("model") 17 | private String model; 18 | 19 | @SerializedName("usage") 20 | private Usage usage; 21 | 22 | public List getData() { 23 | return data; 24 | } 25 | 26 | public static class EmbeddingData { 27 | @SerializedName("object") 28 | private String object; 29 | 30 | @SerializedName("index") 31 | private int index; 32 | 33 | @SerializedName("embedding") 
34 | private List embedding; 35 | 36 | public List getEmbedding() { 37 | return embedding; 38 | } 39 | } 40 | 41 | 42 | @Override 43 | public String toString() { 44 | return new Gson().toJson(this); 45 | } 46 | 47 | 48 | private static class Usage { 49 | @SerializedName("prompt_tokens") 50 | private Integer promptTokens; 51 | @SerializedName("total_tokens") 52 | private Integer totalTokens; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/tech/amikos/chromadb/embeddings/openai/OpenAIEmbeddingFunction.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.openai; 2 | 3 | import com.google.gson.Gson; 4 | import okhttp3.OkHttpClient; 5 | import okhttp3.Request; 6 | import okhttp3.RequestBody; 7 | import okhttp3.Response; 8 | import tech.amikos.chromadb.*; 9 | import tech.amikos.chromadb.embeddings.EmbeddingFunction; 10 | import tech.amikos.chromadb.embeddings.WithParam; 11 | 12 | import java.io.IOException; 13 | import java.util.Arrays; 14 | import java.util.HashMap; 15 | import java.util.List; 16 | import java.util.Map; 17 | import java.util.stream.Collectors; 18 | 19 | import static tech.amikos.chromadb.Constants.JSON; 20 | 21 | public class OpenAIEmbeddingFunction implements EmbeddingFunction { 22 | 23 | public static final String DEFAULT_MODEL_NAME = "text-embedding-ada-002"; 24 | public static final String DEFAULT_BASE_API = "https://api.openai.com/v1/embeddings"; 25 | public static final String OPENAI_API_KEY_ENV = "OPENAI_API_KEY"; 26 | private final OkHttpClient client = new OkHttpClient(); 27 | private final Gson gson = new Gson(); 28 | private final Map configParams = new HashMap<>(); 29 | private static final List defaults = Arrays.asList( 30 | WithParam.baseAPI(DEFAULT_BASE_API), 31 | WithParam.defaultModel(DEFAULT_MODEL_NAME) 32 | ); 33 | 34 | 35 | public OpenAIEmbeddingFunction() throws EFException { 36 | for 
(WithParam param : defaults) { 37 | param.apply(this.configParams); 38 | } 39 | WithParam.apiKeyFromEnv(OPENAI_API_KEY_ENV).apply(this.configParams); 40 | } 41 | 42 | public OpenAIEmbeddingFunction(WithParam... params) throws EFException { 43 | // apply defaults 44 | 45 | for (WithParam param : defaults) { 46 | param.apply(this.configParams); 47 | } 48 | for (WithParam param : params) { 49 | param.apply(this.configParams); 50 | } 51 | } 52 | 53 | public CreateEmbeddingResponse createEmbedding(CreateEmbeddingRequest req) throws EFException { 54 | Request request = new Request.Builder() 55 | .url(this.configParams.get(Constants.EF_PARAMS_BASE_API).toString()) 56 | .post(RequestBody.create(req.json(), JSON)) 57 | .addHeader("Accept", "application/json") 58 | .addHeader("Content-Type", "application/json") 59 | .addHeader("Authorization", "Bearer " + configParams.get(Constants.EF_PARAMS_API_KEY).toString()) 60 | .build(); 61 | try (Response response = client.newCall(request).execute()) { 62 | if (!response.isSuccessful()) { 63 | throw new IOException("Unexpected code " + response); 64 | } 65 | 66 | String responseData = response.body().string(); 67 | 68 | return gson.fromJson(responseData, CreateEmbeddingResponse.class); 69 | } catch (IOException e) { 70 | throw new EFException(e); 71 | } 72 | } 73 | 74 | @Override 75 | public Embedding embedQuery(String query) throws EFException { 76 | CreateEmbeddingRequest req = new CreateEmbeddingRequest().model(this.configParams.get(Constants.EF_PARAMS_MODEL).toString()); 77 | req.input(new CreateEmbeddingRequest.Input(query)); 78 | CreateEmbeddingResponse response = this.createEmbedding(req); 79 | return new Embedding(response.getData().get(0).getEmbedding()); 80 | } 81 | 82 | @Override 83 | public List embedDocuments(List documents) throws EFException { 84 | CreateEmbeddingRequest req = new CreateEmbeddingRequest().model(this.configParams.get(Constants.EF_PARAMS_MODEL).toString()); 85 | req.input(new 
CreateEmbeddingRequest.Input(documents.toArray(new String[0]))); 86 | CreateEmbeddingResponse response = this.createEmbedding(req); 87 | return response.getData().stream().map(emb -> new Embedding(emb.getEmbedding())).collect(Collectors.toList()); 88 | } 89 | 90 | @Override 91 | public List embedDocuments(String[] documents) throws EFException { 92 | return embedDocuments(Arrays.asList(documents)); 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /src/main/resources/openapi/api.yaml: -------------------------------------------------------------------------------- 1 | openapi: "3.0.0" 2 | info: 3 | title: FastAPI 4 | version: 0.1.0 5 | paths: 6 | /api/v1: 7 | get: 8 | summary: Root 9 | operationId: root 10 | responses: 11 | '200': 12 | description: Successful Response 13 | content: 14 | application/json: 15 | schema: 16 | additionalProperties: 17 | type: integer 18 | type: object 19 | title: Response Root Api V1 Get 20 | /api/v1/reset: 21 | post: 22 | summary: Reset 23 | operationId: reset 24 | responses: 25 | '200': 26 | description: Successful Response 27 | content: 28 | application/json: 29 | schema: 30 | type: boolean 31 | title: Response Reset Api V1 Reset Post 32 | /api/v1/version: 33 | get: 34 | summary: Version 35 | operationId: version 36 | responses: 37 | '200': 38 | description: Successful Response 39 | content: 40 | application/json: 41 | schema: 42 | type: string 43 | title: Response Version Api V1 Version Get 44 | /api/v1/heartbeat: 45 | get: 46 | summary: Heartbeat 47 | operationId: heartbeat 48 | responses: 49 | '200': 50 | description: Successful Response 51 | content: 52 | application/json: 53 | schema: 54 | additionalProperties: 55 | type: number 56 | type: object 57 | title: Response Heartbeat Api V1 Heartbeat Get 58 | /api/v1/raw_sql: 59 | post: 60 | summary: Raw Sql 61 | operationId: raw_sql 62 | requestBody: 63 | content: 64 | application/json: 65 | schema: 66 | $ref: 
'#/components/schemas/RawSql' 67 | required: true 68 | responses: 69 | '200': 70 | description: Successful Response 71 | content: 72 | application/json: 73 | schema: {} 74 | '422': 75 | description: Validation Error 76 | content: 77 | application/json: 78 | schema: 79 | $ref: '#/components/schemas/HTTPValidationError' 80 | /api/v1/collections: 81 | get: 82 | summary: List Collections 83 | operationId: list_collections 84 | responses: 85 | '200': 86 | description: Successful Response 87 | content: 88 | application/json: 89 | schema: {} 90 | post: 91 | summary: Create Collection 92 | operationId: create_collection 93 | requestBody: 94 | content: 95 | application/json: 96 | schema: 97 | $ref: '#/components/schemas/CreateCollection' 98 | required: true 99 | responses: 100 | '200': 101 | description: Successful Response 102 | content: 103 | application/json: 104 | schema: {} 105 | '422': 106 | description: Validation Error 107 | content: 108 | application/json: 109 | schema: 110 | $ref: '#/components/schemas/HTTPValidationError' 111 | /api/v1/collections/{collection_id}/add: 112 | post: 113 | summary: Add 114 | operationId: add 115 | parameters: 116 | - required: true 117 | schema: 118 | type: string 119 | title: Collection Id 120 | name: collection_id 121 | in: path 122 | requestBody: 123 | content: 124 | application/json: 125 | schema: 126 | $ref: '#/components/schemas/AddEmbedding' 127 | required: true 128 | responses: 129 | '201': 130 | description: Successful Response 131 | content: 132 | application/json: 133 | schema: {} 134 | '422': 135 | description: Validation Error 136 | content: 137 | application/json: 138 | schema: 139 | $ref: '#/components/schemas/HTTPValidationError' 140 | /api/v1/collections/{collection_id}/update: 141 | post: 142 | summary: Update 143 | operationId: update 144 | parameters: 145 | - required: true 146 | schema: 147 | type: string 148 | title: Collection Id 149 | name: collection_id 150 | in: path 151 | requestBody: 152 | content: 153 | 
application/json: 154 | schema: 155 | $ref: '#/components/schemas/UpdateEmbedding' 156 | required: true 157 | responses: 158 | '200': 159 | description: Successful Response 160 | content: 161 | application/json: 162 | schema: {} 163 | '422': 164 | description: Validation Error 165 | content: 166 | application/json: 167 | schema: 168 | $ref: '#/components/schemas/HTTPValidationError' 169 | /api/v1/collections/{collection_id}/upsert: 170 | post: 171 | summary: Upsert 172 | operationId: upsert 173 | parameters: 174 | - required: true 175 | schema: 176 | type: string 177 | title: Collection Id 178 | name: collection_id 179 | in: path 180 | requestBody: 181 | content: 182 | application/json: 183 | schema: 184 | $ref: '#/components/schemas/AddEmbedding' 185 | required: true 186 | responses: 187 | '200': 188 | description: Successful Response 189 | content: 190 | application/json: 191 | schema: {} 192 | '422': 193 | description: Validation Error 194 | content: 195 | application/json: 196 | schema: 197 | $ref: '#/components/schemas/HTTPValidationError' 198 | /api/v1/collections/{collection_id}/get: 199 | post: 200 | summary: Get 201 | operationId: get 202 | parameters: 203 | - required: true 204 | schema: 205 | type: string 206 | title: Collection Id 207 | name: collection_id 208 | in: path 209 | requestBody: 210 | content: 211 | application/json: 212 | schema: 213 | $ref: '#/components/schemas/GetEmbedding' 214 | required: true 215 | responses: 216 | '200': 217 | description: Successful Response 218 | content: 219 | application/json: 220 | schema: {} #TODO add actual GetResult Body 221 | '422': 222 | description: Validation Error 223 | content: 224 | application/json: 225 | schema: 226 | $ref: '#/components/schemas/HTTPValidationError' 227 | /api/v1/collections/{collection_id}/delete: 228 | post: 229 | summary: Delete 230 | operationId: delete 231 | parameters: 232 | - required: true 233 | schema: 234 | type: string 235 | title: Collection Id 236 | name: collection_id 237 
| in: path 238 | requestBody: 239 | content: 240 | application/json: 241 | schema: 242 | $ref: '#/components/schemas/DeleteEmbedding' 243 | required: true 244 | responses: 245 | '200': 246 | description: Successful Response 247 | content: 248 | application/json: 249 | schema: {} 250 | '422': 251 | description: Validation Error 252 | content: 253 | application/json: 254 | schema: 255 | $ref: '#/components/schemas/HTTPValidationError' 256 | /api/v1/collections/{collection_id}/count: 257 | get: 258 | summary: Count 259 | operationId: count 260 | parameters: 261 | - required: true 262 | schema: 263 | type: string 264 | title: Collection Id 265 | name: collection_id 266 | in: path 267 | responses: 268 | '200': 269 | description: Successful Response 270 | content: 271 | application/json: 272 | schema: 273 | type: integer 274 | '422': 275 | description: Validation Error 276 | content: 277 | application/json: 278 | schema: 279 | $ref: '#/components/schemas/HTTPValidationError' 280 | /api/v1/collections/{collection_id}/query: 281 | post: 282 | summary: Get Nearest Neighbors 283 | operationId: get_nearest_neighbors 284 | parameters: 285 | - required: true 286 | schema: 287 | type: string 288 | title: Collection Id 289 | name: collection_id 290 | in: path 291 | requestBody: 292 | content: 293 | application/json: 294 | schema: 295 | $ref: '#/components/schemas/QueryEmbedding' 296 | required: true 297 | responses: 298 | '200': 299 | description: Successful Response 300 | content: 301 | application/json: 302 | schema: {} 303 | '422': 304 | description: Validation Error 305 | content: 306 | application/json: 307 | schema: 308 | $ref: '#/components/schemas/HTTPValidationError' 309 | /api/v1/collections/{collection_name}/create_index: 310 | post: 311 | summary: Create Index 312 | operationId: create_index 313 | parameters: 314 | - required: true 315 | schema: 316 | type: string 317 | title: Collection Name 318 | name: collection_name 319 | in: path 320 | responses: 321 | '200': 322 
| description: Successful Response 323 | content: 324 | application/json: 325 | schema: 326 | type: boolean 327 | '422': 328 | description: Validation Error 329 | content: 330 | application/json: 331 | schema: 332 | $ref: '#/components/schemas/HTTPValidationError' 333 | /api/v1/collections/{collection_name}: 334 | get: 335 | summary: Get Collection 336 | operationId: get_collection 337 | parameters: 338 | - required: true 339 | schema: 340 | type: string 341 | title: Collection Name 342 | name: collection_name 343 | in: path 344 | responses: 345 | '200': 346 | description: Successful Response 347 | content: 348 | application/json: 349 | schema: {} 350 | '422': 351 | description: Validation Error 352 | content: 353 | application/json: 354 | schema: 355 | $ref: '#/components/schemas/HTTPValidationError' 356 | delete: 357 | summary: Delete Collection 358 | operationId: delete_collection 359 | parameters: 360 | - required: true 361 | schema: 362 | type: string 363 | title: Collection Name 364 | name: collection_name 365 | in: path 366 | responses: 367 | '200': 368 | description: Successful Response 369 | content: 370 | application/json: 371 | schema: {} 372 | '422': 373 | description: Validation Error 374 | content: 375 | application/json: 376 | schema: 377 | $ref: '#/components/schemas/HTTPValidationError' 378 | /api/v1/collections/{collection_id}: 379 | put: 380 | summary: Update Collection 381 | operationId: update_collection 382 | parameters: 383 | - required: true 384 | schema: 385 | type: string 386 | title: Collection Id 387 | name: collection_id 388 | in: path 389 | requestBody: 390 | content: 391 | application/json: 392 | schema: 393 | $ref: '#/components/schemas/UpdateCollection' 394 | required: true 395 | responses: 396 | '200': 397 | description: Successful Response 398 | content: 399 | application/json: 400 | schema: {} 401 | '422': 402 | description: Validation Error 403 | content: 404 | application/json: 405 | schema: 406 | $ref: 
'#/components/schemas/HTTPValidationError' 407 | components: 408 | schemas: 409 | AddEmbedding: 410 | properties: 411 | embeddings: 412 | items: {} 413 | type: array 414 | title: Embeddings 415 | metadatas: 416 | items: 417 | type: object 418 | additionalProperties: true 419 | type: array 420 | title: Metadatas 421 | documents: 422 | items: 423 | type: string 424 | type: array 425 | title: Documents 426 | ids: 427 | items: 428 | type: string 429 | type: array 430 | title: Ids 431 | increment_index: 432 | type: boolean 433 | title: Increment Index 434 | default: true 435 | type: object 436 | required: 437 | - ids 438 | title: AddEmbedding 439 | CreateCollection: 440 | properties: 441 | name: 442 | type: string 443 | title: Name 444 | metadata: 445 | type: object 446 | title: Metadata 447 | get_or_create: 448 | type: boolean 449 | title: Get Or Create 450 | default: false 451 | type: object 452 | required: 453 | - name 454 | title: CreateCollection 455 | DeleteEmbedding: 456 | properties: 457 | ids: 458 | items: 459 | type: string 460 | type: array 461 | title: Ids 462 | where: 463 | type: object 464 | title: Where 465 | where_document: 466 | type: object 467 | title: Where Document 468 | type: object 469 | title: DeleteEmbedding 470 | GetEmbedding: 471 | properties: 472 | ids: 473 | items: 474 | type: string 475 | type: array 476 | title: Ids 477 | where: 478 | type: object 479 | title: Where 480 | where_document: 481 | type: object 482 | title: Where Document 483 | sort: 484 | type: string 485 | title: Sort 486 | limit: 487 | type: integer 488 | title: Limit 489 | offset: 490 | type: integer 491 | title: Offset 492 | include: 493 | items: 494 | anyOf: 495 | - type: string 496 | enum: 497 | - documents 498 | - type: string 499 | enum: 500 | - embeddings 501 | - type: string 502 | enum: 503 | - metadatas 504 | - type: string 505 | enum: 506 | - distances 507 | type: array 508 | title: Include 509 | default: 510 | - metadatas 511 | - documents 512 | type: object 513 | 
title: GetEmbedding 514 | HTTPValidationError: 515 | properties: 516 | detail: 517 | items: 518 | $ref: '#/components/schemas/ValidationError' 519 | type: array 520 | title: Detail 521 | type: object 522 | title: HTTPValidationError 523 | QueryEmbedding: 524 | properties: 525 | where: 526 | type: object 527 | title: Where 528 | additionalProperties: true 529 | default: {} 530 | where_document: 531 | type: object 532 | title: Where Document 533 | additionalProperties: true 534 | default: {} 535 | query_embeddings: 536 | items: {} 537 | type: array 538 | additionalProperties: true 539 | title: Query Embeddings 540 | n_results: 541 | type: integer 542 | title: N Results 543 | default: 10 544 | include: 545 | items: 546 | type: string 547 | enum: 548 | - documents 549 | - embeddings 550 | - metadatas 551 | - distances 552 | type: array 553 | title: Include 554 | default: 555 | - metadatas 556 | - documents 557 | - distances 558 | type: object 559 | required: 560 | - query_embeddings 561 | title: QueryEmbedding 562 | RawSql: 563 | properties: 564 | raw_sql: 565 | type: string 566 | title: Raw Sql 567 | type: object 568 | required: 569 | - raw_sql 570 | title: RawSql 571 | UpdateCollection: 572 | properties: 573 | new_name: 574 | type: string 575 | title: New Name 576 | new_metadata: 577 | type: object 578 | title: New Metadata 579 | type: object 580 | title: UpdateCollection 581 | UpdateEmbedding: 582 | properties: 583 | embeddings: 584 | items: {} 585 | type: array 586 | title: Embeddings 587 | metadatas: 588 | items: 589 | type: object 590 | type: array 591 | title: Metadatas 592 | documents: 593 | items: 594 | type: string 595 | type: array 596 | title: Documents 597 | ids: 598 | items: 599 | type: string 600 | type: array 601 | title: Ids 602 | increment_index: 603 | type: boolean 604 | title: Increment Index 605 | default: true 606 | type: object 607 | required: 608 | - ids 609 | title: UpdateEmbedding 610 | ValidationError: 611 | properties: 612 | loc: 613 | 
items: 614 | anyOf: 615 | - type: string 616 | - type: integer 617 | type: array 618 | title: Location 619 | msg: 620 | type: string 621 | title: Message 622 | type: 623 | type: string 624 | title: Error Type 625 | type: object 626 | required: 627 | - loc 628 | - msg 629 | - type 630 | title: ValidationError 631 | -------------------------------------------------------------------------------- /src/test/java/tech/amikos/chromadb/TestAPI.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb; 2 | 3 | import com.github.tomakehurst.wiremock.junit.WireMockRule; 4 | import org.junit.BeforeClass; 5 | import org.junit.Rule; 6 | import org.junit.Test; 7 | import org.testcontainers.chromadb.ChromaDBContainer; 8 | import org.testcontainers.containers.wait.strategy.Wait; 9 | import tech.amikos.chromadb.embeddings.DefaultEmbeddingFunction; 10 | import tech.amikos.chromadb.embeddings.EmbeddingFunction; 11 | import tech.amikos.chromadb.handler.ApiException; 12 | 13 | import java.io.IOException; 14 | import java.math.BigDecimal; 15 | import java.util.*; 16 | 17 | import static com.github.tomakehurst.wiremock.client.WireMock.*; 18 | import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; 19 | import static org.junit.Assert.*; 20 | 21 | public class TestAPI { 22 | 23 | @Rule 24 | public WireMockRule wireMockRule = new WireMockRule(8001); 25 | static ChromaDBContainer chromaContainer; 26 | 27 | @BeforeClass 28 | public static void setupChromaDB() throws Exception { 29 | try { 30 | Utils.loadEnvFile(".env"); 31 | String chromaVersion = System.getenv("CHROMA_VERSION"); 32 | if (chromaVersion == null) { 33 | chromaVersion = "0.5.15"; 34 | } 35 | chromaContainer = new ChromaDBContainer("chromadb/chroma:" + chromaVersion).withEnv("ALLOW_RESET", "TRUE"); 36 | chromaContainer.start(); 37 | chromaContainer.waitingFor(Wait.forHttp("/api/v1/heartbeat").forStatusCode(200)); 38 | System.setProperty("CHROMA_URL", 
chromaContainer.getEndpoint()); 39 | } catch (Exception e) { 40 | System.err.println("ChromaDBContainer failed to start"); 41 | throw e; 42 | } 43 | } 44 | 45 | public static void tearDownChromaDB() { 46 | System.out.println(chromaContainer.getLogs()); 47 | chromaContainer.stop(); 48 | } 49 | 50 | 51 | @Test 52 | public void testHeartbeat() throws ApiException, IOException, InterruptedException { 53 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 54 | Map hb = client.heartbeat(); 55 | assertTrue(hb.containsKey("nanosecond heartbeat")); 56 | } 57 | 58 | @Test 59 | public void testGetCollectionGet() throws ApiException, IOException, EFException { 60 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 61 | client.reset(); 62 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 63 | client.createCollection("test-collection", null, true, ef); 64 | assertTrue(client.getCollection("test-collection", ef).get() != null); 65 | } 66 | 67 | 68 | @Test 69 | public void testCreateCollection() throws ApiException, EFException { 70 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 71 | client.reset(); 72 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 73 | Collection collection = client.createCollection("test-collection", null, true, ef); 74 | assertEquals(collection.getName(), "test-collection"); 75 | } 76 | 77 | @Test 78 | public void testDeleteCollection() throws ApiException, EFException { 79 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 80 | client.reset(); 81 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 82 | client.createCollection("test-collection", null, true, ef); 83 | client.deleteCollection("test-collection"); 84 | 85 | try { 86 | client.getCollection("test-collection", ef); 87 | } catch (ApiException e) { 88 | assertTrue(Arrays.asList(400, 500).contains(e.getCode())); 89 | } 90 | } 91 | 92 | @Test 93 | public void testCreateUpsert() throws ApiException, ChromaException 
{ 94 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 95 | client.reset(); 96 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 97 | Collection collection = client.createCollection("test-collection", null, true, ef); 98 | List> metadata = new ArrayList<>(); 99 | metadata.add(new HashMap() {{ 100 | put("key", "value"); 101 | }}); 102 | Object resp = collection.upsert(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 103 | assertEquals(1, (int) collection.count()); 104 | } 105 | 106 | @Test 107 | public void testCreateAdd() throws ApiException, ChromaException { 108 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 109 | client.reset(); 110 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 111 | Collection collection = client.createCollection("test-collection", null, true, ef); 112 | List> metadata = new ArrayList<>(); 113 | metadata.add(new HashMap() {{ 114 | put("key", "value"); 115 | }}); 116 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 117 | System.out.println(resp); 118 | System.out.println(collection.get()); 119 | assertEquals(1, (int) collection.count()); 120 | } 121 | 122 | @Test 123 | public void testQuery() throws ApiException, ChromaException { 124 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 125 | client.reset(); 126 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 127 | Collection collection = client.createCollection("test-collection", null, true, ef); 128 | List> metadata = new ArrayList<>(); 129 | metadata.add(new HashMap() {{ 130 | put("key", "value"); 131 | }}); 132 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 133 | collection.add(null, metadata, Arrays.asList("Hello, my name is Bond. 
I am a Spy."), Arrays.asList("2")); 134 | Collection.QueryResponse qr = collection.query(Arrays.asList("name is John"), 10, null, null, null); 135 | assertEquals(qr.getIds().get(0).get(0), "1"); //we check that Bond doc is first 136 | } 137 | 138 | @Test 139 | public void testQueryExample() throws ApiException, ChromaException { 140 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 141 | client.reset(); 142 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 143 | Collection collection = client.createCollection("test-collection", null, true, ef); 144 | List> metadata = new ArrayList<>(); 145 | metadata.add(new HashMap() {{ 146 | put("type", "scientist"); 147 | }}); 148 | metadata.add(new HashMap() {{ 149 | put("type", "spy"); 150 | }}); 151 | collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist.", "Hello, my name is Bond. I am a Spy."), Arrays.asList("1", "2")); 152 | Collection.QueryResponse qr = collection.query(Arrays.asList("Who is the spy"), 10, null, null, null); 153 | assertEquals(qr.getIds().get(0).get(0), "2"); //we check that Bond doc is first 154 | } 155 | 156 | @Test 157 | public void testReset() throws ApiException, EFException { 158 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 159 | client.reset(); 160 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 161 | Collection collection = client.createCollection("test-collection", null, true, ef); 162 | List> metadata = new ArrayList<>(); 163 | client.reset(); 164 | 165 | try { 166 | client.getCollection("test-collection", ef); 167 | } catch (ApiException e) { 168 | assertTrue(Arrays.asList(400, 500).contains(e.getCode())); 169 | } 170 | } 171 | 172 | @Test 173 | public void testListCollections() throws ApiException, EFException { 174 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 175 | client.reset(); 176 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 177 | 
client.createCollection("test-collection", null, true, ef); 178 | assertEquals(client.listCollections().size(), 1); 179 | } 180 | 181 | @Test 182 | public void testCollectionCount() throws ApiException, ChromaException { 183 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 184 | client.reset(); 185 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 186 | Collection collection = client.createCollection("test-collection", null, true, ef); 187 | List> metadata = new ArrayList<>(); 188 | metadata.add(new HashMap() {{ 189 | put("key", "value"); 190 | }}); 191 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 192 | assertEquals(1, (int) collection.count()); 193 | } 194 | 195 | @Test 196 | public void testCollectionDeleteIds() throws ApiException, ChromaException { 197 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 198 | client.reset(); 199 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 200 | Collection collection = client.createCollection("test-collection", null, true, ef); 201 | List> metadata = new ArrayList<>(); 202 | metadata.add(new HashMap() {{ 203 | put("key", "value"); 204 | }}); 205 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. 
I am a Data Scientist."), Arrays.asList("1")); 206 | collection.deleteWithIds(Arrays.asList("1")); 207 | assertEquals(0, collection.get().getDocuments().size()); 208 | } 209 | 210 | @Test 211 | public void testCollectionDeleteWhere() throws ApiException, ChromaException { 212 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 213 | client.reset(); 214 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 215 | Collection collection = client.createCollection("test-collection", null, true, ef); 216 | List> metadata = new ArrayList<>(); 217 | metadata.add(new HashMap() {{ 218 | put("key", "value"); 219 | }}); 220 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 221 | Map where = new HashMap<>(); 222 | where.put("key", "value"); 223 | collection.deleteWhere(where); 224 | assertEquals(0, (int) collection.count()); 225 | } 226 | 227 | @Test 228 | public void testCollectionDeleteLogicalOrWhere() throws ApiException, ChromaException { 229 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 230 | client.reset(); 231 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 232 | Collection collection = client.createCollection("test-collection", null, true, ef); 233 | List> metadata = new ArrayList<>(); 234 | metadata.add(new HashMap() {{ 235 | put("key1", "value1"); 236 | }}); 237 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. 
I am a Data Scientist."), Arrays.asList("1")); 238 | Map where = new HashMap<>(); 239 | List> conditions = new ArrayList<>(); 240 | Map condition1 = new HashMap<>(); 241 | condition1.put("key1", "value1"); 242 | conditions.add(condition1); 243 | 244 | Map condition2 = new HashMap<>(); 245 | condition2.put("key2", "value2"); 246 | conditions.add(condition2); 247 | where.put("$or", conditions); 248 | collection.deleteWhere(where); 249 | assertEquals(0, (int) collection.count()); 250 | } 251 | 252 | @Test 253 | public void testCollectionDeleteWhereNoMatch() throws ApiException, ChromaException { 254 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 255 | client.reset(); 256 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 257 | Collection collection = client.createCollection("test-collection", null, true, ef); 258 | List> metadata = new ArrayList<>(); 259 | metadata.add(new HashMap() {{ 260 | put("key", "value"); 261 | }}); 262 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 263 | Map where = new HashMap<>(); 264 | where.put("key", "value2"); 265 | collection.deleteWhere(where); 266 | assertEquals(1, (int) collection.count()); 267 | } 268 | 269 | @Test 270 | public void testCollectionDeleteWhereDocuments() throws ApiException, ChromaException { 271 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 272 | client.reset(); 273 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 274 | Collection collection = client.createCollection("test-collection", null, true, ef); 275 | List> metadata = new ArrayList<>(); 276 | metadata.add(new HashMap() {{ 277 | put("key", "value"); 278 | }}); 279 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. 
I am a Data Scientist."), Arrays.asList("1")); 280 | Map whereDocuments = new HashMap<>(); 281 | whereDocuments.put("$contains", "John"); 282 | collection.deleteWhereDocuments(whereDocuments); 283 | assertEquals(0, (int) collection.count()); 284 | 285 | } 286 | 287 | @Test 288 | public void testCollectionDeleteWhereDocumentsNoMatch() throws ApiException, ChromaException { 289 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 290 | client.reset(); 291 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 292 | Collection collection = client.createCollection("test-collection", null, true, ef); 293 | List> metadata = new ArrayList<>(); 294 | metadata.add(new HashMap() {{ 295 | put("key", "value"); 296 | }}); 297 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 298 | Map whereDocuments = new HashMap<>(); 299 | whereDocuments.put("$contains", "Mohn"); 300 | collection.deleteWhereDocuments(whereDocuments); 301 | assertEquals(1, (int) collection.count()); 302 | } 303 | 304 | 305 | @Test 306 | public void testVersion() throws ApiException { 307 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 308 | client.reset(); 309 | String version = client.version(); 310 | assertNotNull(version); 311 | } 312 | 313 | 314 | @Test 315 | public void testUpdateCollection() throws ApiException, ChromaException { 316 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 317 | client.reset(); 318 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 319 | Collection collection = client.createCollection("test-collection", null, true, ef); 320 | List> metadata = new ArrayList<>(); 321 | metadata.add(new HashMap() {{ 322 | put("key", "value"); 323 | }}); 324 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. 
I am a Data Scientist."), Arrays.asList("1")); 325 | collection.update("test-collection2", null); 326 | assertEquals(collection.getName(), "test-collection2"); 327 | } 328 | 329 | @Test 330 | public void testCollectionUpdateEmbeddings() throws ApiException, ChromaException { 331 | Client client = new Client(Utils.getEnvOrProperty("CHROMA_URL")); 332 | client.reset(); 333 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 334 | Collection collection = client.createCollection("test-collection", null, true, ef); 335 | List> metadata = new ArrayList<>(); 336 | metadata.add(new HashMap() {{ 337 | put("key", "value"); 338 | }}); 339 | Object resp = collection.add(null, metadata, Arrays.asList("Hello, my name is John. I am a Data Scientist."), Arrays.asList("1")); 340 | collection.updateEmbeddings(null, null, Arrays.asList("Hello, my name is Bonh. I am a Data Scientist."), Arrays.asList("1")); 341 | 342 | } 343 | 344 | @Test 345 | public void testTimeoutOk() throws ApiException, IOException { 346 | stubFor(get(urlEqualTo("/api/v1/heartbeat")) 347 | .willReturn(aResponse() 348 | .withHeader("Content-Type", "application/json") 349 | .withBody("{\"nanosecond heartbeat\": 123456789}").withFixedDelay(2000))); 350 | 351 | Utils.loadEnvFile(".env"); 352 | Client client = new Client("http://127.0.0.1:8001"); 353 | client.setTimeout(3); 354 | Map hb = client.heartbeat(); 355 | assertTrue(hb.containsKey("nanosecond heartbeat")); 356 | } 357 | 358 | @Test(expected = ApiException.class) 359 | public void testTimeoutExpires() throws ApiException, IOException { 360 | stubFor(get(urlEqualTo("/api/v1/heartbeat")) 361 | .willReturn(aResponse() 362 | .withHeader("Content-Type", "application/json") 363 | .withBody("{\"nanosecond heartbeat\": 123456789}").withFixedDelay(2000))); 364 | 365 | Utils.loadEnvFile(".env"); 366 | Client client = new Client("http://127.0.0.1:8001"); 367 | client.setTimeout(1); 368 | try { 369 | client.heartbeat(); 370 | } catch (ApiException e) { 371 | 
assertTrue(e.getMessage().contains("Read timed out") || e.getMessage().contains("timeout")); 372 | throw e; 373 | } 374 | 375 | } 376 | 377 | 378 | @Test 379 | public void testClientHeaders() throws ApiException, IOException { 380 | stubFor(get(urlEqualTo("/api/v1/heartbeat")) 381 | .willReturn(aResponse() 382 | .withHeader("Content-Type", "application/json") 383 | .withBody("{\"nanosecond heartbeat\": 123456789}"))); 384 | Utils.loadEnvFile(".env"); 385 | Client client = new Client("http://127.0.0.1:8001"); 386 | client.setDefaultHeaders(new HashMap() {{ 387 | put("Your-Header-Key", "Your-Expected-Header-Value"); 388 | }}); 389 | Map hb = client.heartbeat(); 390 | assertTrue(hb.containsKey("nanosecond heartbeat")); 391 | // Verify that a GET request was made with a specific header 392 | verify(getRequestedFor(urlEqualTo("/api/v1/heartbeat")) 393 | .withHeader("Your-Header-Key", equalTo("Your-Expected-Header-Value"))); 394 | } 395 | 396 | @Test 397 | public void testClientAuthorizationBasicHeader() throws ApiException, IOException { 398 | stubFor(get(urlEqualTo("/api/v1/heartbeat")) 399 | .willReturn(aResponse() 400 | .withHeader("Content-Type", "application/json") 401 | .withBody("{\"nanosecond heartbeat\": 123456789}"))); 402 | Utils.loadEnvFile(".env"); 403 | Client client = new Client("http://127.0.0.1:8001"); 404 | String encodedString = Base64.getEncoder().encodeToString("admin:admin".getBytes()); 405 | client.setDefaultHeaders(new HashMap() {{ 406 | put("Authorization", "Basic " + encodedString); 407 | }}); 408 | Map hb = client.heartbeat(); 409 | assertTrue(hb.containsKey("nanosecond heartbeat")); 410 | // Verify that a GET request was made with a specific header 411 | verify(getRequestedFor(urlEqualTo("/api/v1/heartbeat")) 412 | .withHeader("Authorization", equalTo("Basic " + encodedString))); 413 | } 414 | 415 | @Test 416 | public void testClientAuthorizationBearerHeader() throws ApiException, IOException { 417 | 
stubFor(get(urlEqualTo("/api/v1/heartbeat")) 418 | .willReturn(aResponse() 419 | .withHeader("Content-Type", "application/json") 420 | .withBody("{\"nanosecond heartbeat\": 123456789}"))); 421 | Utils.loadEnvFile(".env"); 422 | Client client = new Client("http://127.0.0.1:8001"); 423 | client.setDefaultHeaders(new HashMap() {{ 424 | put("Authorization", "Bearer test-token"); 425 | }}); 426 | Map hb = client.heartbeat(); 427 | assertTrue(hb.containsKey("nanosecond heartbeat")); 428 | // Verify that a GET request was made with a specific header 429 | verify(getRequestedFor(urlEqualTo("/api/v1/heartbeat")) 430 | .withHeader("Authorization", equalTo("Bearer test-token"))); 431 | } 432 | 433 | @Test 434 | public void testClientXChromaTokenHeader() throws ApiException, IOException { 435 | stubFor(get(urlEqualTo("/api/v1/heartbeat")) 436 | .willReturn(aResponse() 437 | .withHeader("Content-Type", "application/json") 438 | .withBody("{\"nanosecond heartbeat\": 123456789}"))); 439 | Utils.loadEnvFile(".env"); 440 | Client client = new Client("http://127.0.0.1:8001"); 441 | client.setDefaultHeaders(new HashMap() {{ 442 | put("X-Chroma-Token", "test-token"); 443 | }}); 444 | Map hb = client.heartbeat(); 445 | assertTrue(hb.containsKey("nanosecond heartbeat")); 446 | // Verify that a GET request was made with a specific header 447 | verify(getRequestedFor(urlEqualTo("/api/v1/heartbeat")) 448 | .withHeader("X-Chroma-Token", equalTo("test-token"))); 449 | } 450 | } 451 | -------------------------------------------------------------------------------- /src/test/java/tech/amikos/chromadb/Utils.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.FileReader; 6 | import java.io.IOException; 7 | 8 | public class Utils { 9 | public static void loadEnvFile(String path) { 10 | File file = new File(path); 11 | if (!file.exists()) { 12 | 
System.out.println(".env file does not exist. Skipping loading."); 13 | return; 14 | } 15 | try (BufferedReader br = new BufferedReader(new FileReader(path))) { 16 | String line; 17 | while ((line = br.readLine()) != null) { 18 | int equalIndex = line.indexOf('='); 19 | if (equalIndex > 0) { 20 | String key = line.substring(0, equalIndex).trim(); 21 | String value = line.substring(equalIndex + 1).trim(); 22 | System.setProperty(key, value); 23 | } 24 | } 25 | } catch (IOException e) { 26 | System.err.println("Failed to load .env file: " + e.getMessage()); 27 | } 28 | } 29 | 30 | public static String getEnvOrProperty(String key) { 31 | // Try to get the value from the environment variables 32 | String value = System.getenv(key); 33 | 34 | // If not found, try to get the value from the system properties 35 | if (value == null) { 36 | value = System.getProperty(key); 37 | } 38 | 39 | return value; 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/test/java/tech/amikos/chromadb/embeddings/TestDefaultEmbeddings.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings; 2 | 3 | import org.apache.commons.io.FileUtils; 4 | import org.junit.Test; 5 | import tech.amikos.chromadb.Embedding; 6 | 7 | import java.util.Arrays; 8 | import java.util.List; 9 | 10 | import static org.junit.Assert.*; 11 | 12 | public class TestDefaultEmbeddings { 13 | // this represents the output of sentence-transformers/all-MiniLM-L6-v2 for "Hello, my name is John. I am a Data Scientist.", "Hello, I am Jane and I am an ML researcher." 
14 | private static float[][] groundThruth = {{-0.09585458785295486f, 15 | 0.00948028638958931f, 16 | 0.06449528783559799f, 17 | 0.06595586985349655f, 18 | -0.04144003242254257f, 19 | -0.10210221260786057f, 20 | 0.041818294674158096f, 21 | -0.004058007150888443f, 22 | -0.030316634103655815f, 23 | 0.02270486019551754f, 24 | -0.035775184631347656f, 25 | -0.019363753497600555f, 26 | 0.05840573087334633f, 27 | -0.019858120009303093f, 28 | -0.03416302800178528f, 29 | 0.023755116388201714f, 30 | -0.04887963458895683f, 31 | -0.023338720202445984f, 32 | -0.03097701258957386f, 33 | -0.06726820021867752f, 34 | -0.047849390655756f, 35 | 0.002262122929096222f, 36 | -0.050954557955265045f, 37 | -0.0704192966222763f, 38 | 0.03851667791604996f, 39 | 0.05452864617109299f, 40 | 0.0472191721200943f, 41 | 0.05565935745835304f, 42 | -0.06013466790318489f, 43 | -0.060048338025808334f, 44 | -0.06324896961450577f, 45 | -0.004660584963858128f, 46 | 0.0833488330245018f, 47 | 0.06518704444169998f, 48 | 0.061783649027347565f, 49 | -0.0057580843567848206f, 50 | -0.06831730902194977f, 51 | -0.0027126308996230364f, 52 | 0.0032052304595708847f, 53 | 0.004442276898771524f, 54 | 0.02566530555486679f, 55 | -0.06587797403335571f, 56 | 0.04000983387231827f, 57 | 0.020301034674048424f, 58 | -0.042582057416439056f, 59 | -0.034585196524858475f, 60 | 0.01486957911401987f, 61 | 0.006351375486701727f, 62 | 0.012819512747228146f, 63 | 0.016794603317975998f, 64 | -0.1442369818687439f, 65 | -0.036635056138038635f, 66 | -0.012132037431001663f, 67 | 0.054017357528209686f, 68 | 0.03453433886170387f, 69 | -0.03326355293393135f, 70 | 0.03831372782588005f, 71 | 0.03907116502523422f, 72 | -0.04063969478011131f, 73 | -0.011150147765874863f, 74 | -0.06000356003642082f, 75 | -0.012438077479600906f, 76 | -0.1148035079240799f, 77 | 0.05022793635725975f, 78 | 0.08595263212919235f, 79 | 0.019273564219474792f, 80 | -0.03813284635543823f, 81 | 0.04905250295996666f, 82 | 0.021730052307248116f, 83 | -0.06437823921442032f, 84 | 
-0.06207166984677315f, 85 | -0.010931634344160557f, 86 | -0.07005178183317184f, 87 | 0.048188939690589905f, 88 | -0.021501995623111725f, 89 | -0.01519235409796238f, 90 | -0.01088312640786171f, 91 | 0.036744046956300735f, 92 | 0.11533796042203903f, 93 | 0.04394731670618057f, 94 | 0.022351941093802452f, 95 | -0.0694778710603714f, 96 | -0.14287206530570984f, 97 | 0.002787005854770541f, 98 | -0.02520938776433468f, 99 | 0.018195422366261482f, 100 | -0.03290892764925957f, 101 | 0.09594959020614624f, 102 | -0.0790214091539383f, 103 | -0.06380059570074081f, 104 | 0.0020406043622642756f, 105 | -0.0022298344410955906f, 106 | 0.0011398063506931067f, 107 | 0.08340359479188919f, 108 | -0.09954331070184708f, 109 | -0.024701174348592758f, 110 | 0.019631244242191315f, 111 | -0.006075597368180752f, 112 | -0.04857270419597626f, 113 | 0.036346402019262314f, 114 | -0.011533500626683235f, 115 | 0.03277066349983215f, 116 | 0.022185729816555977f, 117 | 0.10524636507034302f, 118 | 0.0022879349999129772f, 119 | 0.01772546023130417f, 120 | -0.04884761571884155f, 121 | 0.05326351150870323f, 122 | 0.03047122433781624f, 123 | -0.08757409453392029f, 124 | -0.030721280723810196f, 125 | -0.00772859575226903f, 126 | -0.08797163516283035f, 127 | -0.028308028355240822f, 128 | -0.001540705212391913f, 129 | -0.06213157996535301f, 130 | -0.04336702451109886f, 131 | 0.05428041145205498f, 132 | 0.026199134066700935f, 133 | -0.0470283217728138f, 134 | 0.028990667313337326f, 135 | 0.05769278481602669f, 136 | -0.08303508162498474f, 137 | -0.05381003022193909f, 138 | 0.036160971969366074f, 139 | 0.058339837938547134f, 140 | -0.09613247960805893f, 141 | -6.687360178521131e-33f, 142 | 0.03682004287838936f, 143 | 0.043789900839328766f, 144 | 0.10437802970409393f, 145 | 0.15669868886470795f, 146 | -0.06030746549367905f, 147 | -0.0011067378800362349f, 148 | -0.10755843669176102f, 149 | -0.0020333367865532637f, 150 | -0.03200452774763107f, 151 | 0.010678966529667377f, 152 | -0.000848894938826561f, 153 | 
0.04295719414949417f, 154 | 0.019144756719470024f, 155 | -0.009240921586751938f, 156 | -0.07697387784719467f, 157 | 0.031998272985219955f, 158 | 0.012779496610164642f, 159 | 0.006133145187050104f, 160 | -0.026671690866351128f, 161 | 0.052042942494153976f, 162 | -0.002637297147884965f, 163 | -0.04257839545607567f, 164 | 0.030323265120387077f, 165 | 0.04461345449090004f, 166 | 0.08032087981700897f, 167 | -0.038311731070280075f, 168 | 0.02064959704875946f, 169 | -0.08821066468954086f, 170 | 0.15501201152801514f, 171 | 0.005460179410874844f, 172 | 0.012035267427563667f, 173 | 0.01326702255755663f, 174 | 0.006971028633415699f, 175 | -0.03180614486336708f, 176 | 0.07937445491552353f, 177 | -0.027660902589559555f, 178 | -0.0010033247526735067f, 179 | -0.07399091124534607f, 180 | 0.005404886789619923f, 181 | 0.016363127157092094f, 182 | -0.037255868315696716f, 183 | 0.03943704441189766f, 184 | 0.03164369985461235f, 185 | 0.025043606758117676f, 186 | -0.0387364886701107f, 187 | -0.04596373811364174f, 188 | 0.08971153199672699f, 189 | 0.009558212012052536f, 190 | 0.061766281723976135f, 191 | -0.00535972835496068f, 192 | -0.06706439703702927f, 193 | 0.008159090764820576f, 194 | -0.06945706158876419f, 195 | 0.03745569661259651f, 196 | -0.016799751669168472f, 197 | 0.05164659395813942f, 198 | 0.048563942313194275f, 199 | -0.0007197027443908155f, 200 | 0.00350027228705585f, 201 | 0.049179140478372574f, 202 | -0.10629745572805405f, 203 | 0.06949912756681442f, 204 | 0.004620889667421579f, 205 | 0.024071011692285538f, 206 | -0.01602218858897686f, 207 | 0.005312735214829445f, 208 | -0.030950920656323433f, 209 | -0.003942802082747221f, 210 | 0.09100464731454849f, 211 | 0.08968598395586014f, 212 | 0.026305897161364555f, 213 | -0.021082688122987747f, 214 | 0.0569409541785717f, 215 | 0.051254235208034515f, 216 | 0.09538979083299637f, 217 | 0.06134245917201042f, 218 | -0.04533593729138374f, 219 | 0.015669571235775948f, 220 | -0.052699241787195206f, 221 | 0.0024702614173293114f, 222 | 
-0.02752934768795967f, 223 | -0.017606819048523903f, 224 | 0.02029530517756939f, 225 | -0.020326904952526093f, 226 | 0.042705975472927094f, 227 | 0.07871954143047333f, 228 | -0.0032613887451589108f, 229 | -0.08032066375017166f, 230 | 0.025115517899394035f, 231 | -0.0027041647117584944f, 232 | -0.04574376717209816f, 233 | 0.032349593937397f, 234 | 0.02724195271730423f, 235 | 0.01871751807630062f, 236 | -0.06770546734333038f, 237 | 3.749572778469732e-33f, 238 | -0.031197205185890198f, 239 | 0.02471674047410488f, 240 | 0.025199811905622482f, 241 | 0.02507738396525383f, 242 | 0.07604343444108963f, 243 | -0.06405463814735413f, 244 | 0.021355606615543365f, 245 | 0.01512217242270708f, 246 | -0.011929183267056942f, 247 | 0.026375196874141693f, 248 | 0.006126232445240021f, 249 | -0.0052116625010967255f, 250 | 0.036424219608306885f, 251 | -0.033887360244989395f, 252 | 0.04992717504501343f, 253 | 0.050488606095314026f, 254 | 0.010197697207331657f, 255 | -0.02888398990035057f, 256 | -0.0836842805147171f, 257 | -0.015586711466312408f, 258 | -0.11193014681339264f, 259 | 0.03552526235580444f, 260 | -0.08755801618099213f, 261 | -0.0011521631386131048f, 262 | 0.05781976133584976f, 263 | -0.05618219077587128f, 264 | 0.030507076531648636f, 265 | 0.00858727004379034f, 266 | 0.0405815914273262f, 267 | 0.0014287643134593964f, 268 | -0.0647103562951088f, 269 | -0.004953945521265268f, 270 | -0.11465773731470108f, 271 | -0.05401500687003136f, 272 | -0.05764947831630707f, 273 | 0.013947632163763046f, 274 | 0.07131663709878922f, 275 | -0.09751860797405243f, 276 | 0.02640238218009472f, 277 | -0.05058669671416283f, 278 | 0.003804068546742201f, 279 | 0.030711550265550613f, 280 | -0.02500280737876892f, 281 | 0.007008112035691738f, 282 | -0.0016476899618282914f, 283 | -0.07471956312656403f, 284 | 3.1293591746361926e-05f, 285 | 0.03754337877035141f, 286 | -0.053971726447343826f, 287 | 0.023011285811662674f, 288 | -0.03540029376745224f, 289 | 0.009353957138955593f, 290 | 0.023803526535630226f, 291 
| -0.07089712470769882f, 292 | 0.0332937054336071f, 293 | 0.038956545293331146f, 294 | 0.10356861352920532f, 295 | 0.020697198808193207f, 296 | -0.0017958646640181541f, 297 | 0.06905552744865417f, 298 | -0.02837231568992138f, 299 | -0.011508774012327194f, 300 | 0.02831237018108368f, 301 | 0.11676878482103348f, 302 | 0.008890269324183464f, 303 | -0.014236279763281345f, 304 | 0.008686048910021782f, 305 | -0.0027900345157831907f, 306 | -0.01813756488263607f, 307 | -0.08539675176143646f, 308 | 0.09203267842531204f, 309 | -0.015694696456193924f, 310 | 0.01952199824154377f, 311 | -0.059190526604652405f, 312 | 0.012267987243831158f, 313 | -0.036826860159635544f, 314 | -0.03677833825349808f, 315 | 0.0064687347039580345f, 316 | -0.0672142431139946f, 317 | 0.0897166058421135f, 318 | 0.03025389276444912f, 319 | 0.011798573657870293f, 320 | 0.019593898206949234f, 321 | 0.015443249605596066f, 322 | 0.01595941185951233f, 323 | 0.015705550089478493f, 324 | -0.0004455214075278491f, 325 | 0.02854067087173462f, 326 | -0.04562589153647423f, 327 | -0.047964513301849365f, 328 | -0.12343782931566238f, 329 | -0.014698817394673824f, 330 | -0.03646207973361015f, 331 | -0.006195136345922947f, 332 | -0.04197908192873001f, 333 | -2.204232885105739e-08f, 334 | 0.014136684127151966f, 335 | -0.061789777129888535f, 336 | 0.046023767441511154f, 337 | 0.0072265504859387875f, 338 | 0.004496420733630657f, 339 | 0.04225821793079376f, 340 | -0.05978027358651161f, 341 | 0.06642647832632065f, 342 | -0.0011911533074453473f, 343 | -0.011026511900126934f, 344 | 0.07080002129077911f, 345 | 0.022587932646274567f, 346 | -0.05675440654158592f, 347 | -0.01975814253091812f, 348 | 0.08161557465791702f, 349 | -0.005019841715693474f, 350 | 0.044190991669893265f, 351 | -0.0009667372796684504f, 352 | -0.01896727830171585f, 353 | 0.028884947299957275f, 354 | 0.032576825469732285f, 355 | 0.0770535096526146f, 356 | -0.046847864985466f, 357 | 0.07227548956871033f, 358 | 0.09427182376384735f, 359 | 0.0531928576529026f, 360 
| 0.08638159930706024f, 361 | 0.10740622878074646f, 362 | -0.09499743580818176f, 363 | 0.025405744090676308f, 364 | -0.01263652928173542f, 365 | 0.04843660816550255f, 366 | -0.007302490994334221f, 367 | 0.048734281212091446f, 368 | -0.015140815638005733f, 369 | -0.10783609747886658f, 370 | 0.05389820411801338f, 371 | 0.024867836385965347f, 372 | 0.05159829556941986f, 373 | -0.008217083290219307f, 374 | -0.07178572565317154f, 375 | 0.09328121691942215f, 376 | -0.023542558774352074f, 377 | 0.05239303782582283f, 378 | -0.03273957967758179f, 379 | 0.07256103307008743f, 380 | -0.024191774427890778f, 381 | -0.00945044495165348f, 382 | -0.014314060099422932f, 383 | 0.017101649194955826f, 384 | -0.015191149897873402f, 385 | -0.029306570068001747f, 386 | 0.06332913041114807f, 387 | 0.030651697888970375f, 388 | 0.10759934782981873f, 389 | -0.004157477058470249f, 390 | 0.06462010741233826f, 391 | 0.06584816426038742f, 392 | -0.04045629873871803f, 393 | -0.07335453480482101f, 394 | 0.05981225520372391f, 395 | -0.05171000584959984f, 396 | -0.04719773307442665f, 397 | -0.033861324191093445f}, 398 | {0.032470982521772385f, 399 | -0.017813976854085922f, 400 | 0.09103047847747803f, 401 | 0.0662265345454216f, 402 | -0.004986118525266647f, 403 | -0.036341968923807144f, 404 | 0.004300299100577831f, 405 | 0.005200391635298729f, 406 | -0.03274604678153992f, 407 | -0.014939035288989544f, 408 | 0.01147277932614088f, 409 | -0.0029152948409318924f, 410 | -0.02211623452603817f, 411 | -0.042097341269254684f, 412 | -0.049638181924819946f, 413 | 0.09333044290542603f, 414 | -0.020927147939801216f, 415 | -0.026947692036628723f, 416 | -0.12432024627923965f, 417 | 0.013082092627882957f, 418 | -0.09639916568994522f, 419 | -0.03806604444980621f, 420 | -0.012412002310156822f, 421 | 0.0435180589556694f, 422 | 0.017384078353643417f, 423 | -0.018446993082761765f, 424 | 0.05199646204710007f, 425 | 0.02032429352402687f, 426 | 0.0744006410241127f, 427 | 0.03578963130712509f, 428 | -0.03855806216597557f, 429 
| 0.022090259939432144f, 430 | 0.06836511194705963f, 431 | 0.06377742439508438f, 432 | 0.06869128346443176f, 433 | 0.05386146157979965f, 434 | 0.047937557101249695f, 435 | -0.014405525289475918f, 436 | 0.042463596910238266f, 437 | 0.035660456866025925f, 438 | 0.01611713506281376f, 439 | -0.09118328243494034f, 440 | 0.03768451139330864f, 441 | -0.029440438374876976f, 442 | -0.02313595451414585f, 443 | -0.029354600235819817f, 444 | 0.009454349987208843f, 445 | -0.004506702534854412f, 446 | -0.02972273714840412f, 447 | 0.018433788791298866f, 448 | -0.11473238468170166f, 449 | -0.010125036351382732f, 450 | -0.0013624498387798667f, 451 | 0.024189213290810585f, 452 | 0.023019906133413315f, 453 | -0.04870212450623512f, 454 | -0.011203262023627758f, 455 | 0.004811882507055998f, 456 | 0.005358981899917126f, 457 | 0.009688693098723888f, 458 | -0.032397400587797165f, 459 | 0.00649231718853116f, 460 | -0.0943230614066124f, 461 | 0.05091932415962219f, 462 | 0.04015204682946205f, 463 | 0.006206254940479994f, 464 | -0.12947048246860504f, 465 | 0.03915683180093765f, 466 | 0.0038421349599957466f, 467 | -0.05188129097223282f, 468 | 0.004851798061281443f, 469 | -0.038142792880535126f, 470 | -0.08346209675073624f, 471 | 0.110108382999897f, 472 | 0.01975100301206112f, 473 | 0.024924926459789276f, 474 | 0.10811963677406311f, 475 | -0.03592298552393913f, 476 | 0.11878448724746704f, 477 | 0.01458478532731533f, 478 | 0.03076137788593769f, 479 | -0.0074920919723808765f, 480 | -0.04760413616895676f, 481 | 0.015651658177375793f, 482 | -0.07754474878311157f, 483 | -0.07776712626218796f, 484 | -0.05155258998274803f, 485 | 0.08887834846973419f, 486 | 0.019695332273840904f, 487 | 0.004609257448464632f, 488 | -0.03667515516281128f, 489 | 0.026315441355109215f, 490 | -0.036759596318006516f, 491 | 0.025319784879684448f, 492 | -0.03272395208477974f, 493 | -0.01455981470644474f, 494 | 0.008798807859420776f, 495 | -0.027347559109330177f, 496 | -0.004474400542676449f, 497 | 0.08458873629570007f, 498 | 
-0.022725073620676994f, 499 | 0.08203046023845673f, 500 | -0.009424270130693913f, 501 | 0.05794243514537811f, 502 | -0.027780253440141678f, 503 | -0.039574600756168365f, 504 | -0.03985302895307541f, 505 | 0.03268566355109215f, 506 | 0.017434800043702126f, 507 | -0.07331137359142303f, 508 | 0.04008691385388374f, 509 | -0.049156155437231064f, 510 | -0.041357241570949554f, 511 | 0.015092026442289352f, 512 | 0.042833536863327026f, 513 | -0.09655860811471939f, 514 | -0.009191427379846573f, 515 | 0.05041028931736946f, 516 | 0.025020930916070938f, 517 | -0.0010903703514486551f, 518 | -0.01153971441090107f, 519 | 0.055329855531454086f, 520 | -0.07093106955289841f, 521 | -0.10202781111001968f, 522 | -0.05860423669219017f, 523 | -0.03908321261405945f, 524 | -0.06963896006345749f, 525 | -6.819881738212989e-33f, 526 | -0.05034259706735611f, 527 | 0.046635307371616364f, 528 | 0.061453815549612045f, 529 | 0.1632499247789383f, 530 | -0.04075726494193077f, 531 | -0.022670969367027283f, 532 | -0.006809735205024481f, 533 | -0.032284870743751526f, 534 | 0.03480874374508858f, 535 | -0.012414581142365932f, 536 | -0.0071273562498390675f, 537 | 0.03807320073246956f, 538 | -0.006105185952037573f, 539 | -0.022729644551873207f, 540 | -0.061231229454278946f, 541 | 0.0619988888502121f, 542 | 0.05873256176710129f, 543 | 0.032221317291259766f, 544 | -0.08261745423078537f, 545 | 0.022939501330256462f, 546 | 0.06733444333076477f, 547 | -0.06015993282198906f, 548 | 0.021170560270547867f, 549 | -0.054680775851011276f, 550 | 0.04987915977835655f, 551 | -0.039290230721235275f, 552 | 0.04173582047224045f, 553 | -0.06061732769012451f, 554 | 0.07556696981191635f, 555 | 0.01709105633199215f, 556 | -0.01861475221812725f, 557 | 0.07065264135599136f, 558 | -0.026991551741957664f, 559 | -0.05504021421074867f, 560 | 0.002307836664840579f, 561 | 0.06408514827489853f, 562 | 0.03251758590340614f, 563 | -0.08594448864459991f, 564 | 0.01902201771736145f, 565 | 0.03420829772949219f, 566 | -0.04273582994937897f, 567 
| 0.03365328162908554f, 568 | 0.15104778110980988f, 569 | -0.024075830355286598f, 570 | -0.0030312633607536554f, 571 | 0.03414997085928917f, 572 | 0.021509287878870964f, 573 | -0.08260119706392288f, 574 | 0.05932952091097832f, 575 | -0.022917667403817177f, 576 | -0.10696452856063843f, 577 | 0.003822541097179055f, 578 | -0.05484633520245552f, 579 | -0.0009232646552845836f, 580 | 0.04648015275597572f, 581 | 0.028328044340014458f, 582 | -0.011910024099051952f, 583 | 0.07007365673780441f, 584 | -0.0032808773685246706f, 585 | 0.025030862540006638f, 586 | -0.0588592030107975f, 587 | 0.032608259469270706f, 588 | -0.03625897690653801f, 589 | 0.0575789213180542f, 590 | 0.0219504963606596f, 591 | -0.00031610886799171567f, 592 | 0.008611606433987617f, 593 | 0.005103116389364004f, 594 | 0.0544932521879673f, 595 | 0.02514653280377388f, 596 | -0.01069401390850544f, 597 | 0.040063243359327316f, 598 | 0.0030146585777401924f, 599 | -0.03559613227844238f, 600 | -0.042914632707834244f, 601 | 0.03892488405108452f, 602 | -0.022482896223664284f, 603 | -0.006886311341077089f, 604 | 0.060940083116292953f, 605 | -0.007247471250593662f, 606 | -0.018828019499778748f, 607 | 0.02945023402571678f, 608 | -0.04034523665904999f, 609 | -0.027738675475120544f, 610 | 0.011778023093938828f, 611 | 0.053314242511987686f, 612 | -0.012108864262700081f, 613 | -0.003313975175842643f, 614 | 0.06962978094816208f, 615 | -0.025122152641415596f, 616 | -0.06273254752159119f, 617 | 0.0005014355410821736f, 618 | 0.006060570478439331f, 619 | 0.013680621981620789f, 620 | -0.12261677533388138f, 621 | 3.233225129927534e-33f, 622 | -0.0589633509516716f, 623 | 0.040580566972494125f, 624 | 0.019625846296548843f, 625 | 0.02602434903383255f, 626 | 0.06306370347738266f, 627 | 0.009427609853446484f, 628 | -0.00481993006542325f, 629 | -0.003048526355996728f, 630 | -0.05872064456343651f, 631 | 0.12210481613874435f, 632 | -0.0001253927475772798f, 633 | -0.07909934222698212f, 634 | -0.02846868336200714f, 635 | 
0.05321290343999863f, 636 | -0.06128118187189102f, 637 | -0.01313082966953516f, 638 | 0.061504822224378586f, 639 | -0.050759200006723404f, 640 | -0.06360293179750443f, 641 | -0.009319847449660301f, 642 | -0.08136218786239624f, 643 | 0.06484118849039078f, 644 | -0.06656011939048767f, 645 | 0.052685387432575226f, 646 | 0.0656621903181076f, 647 | -0.01696452498435974f, 648 | 0.02894679643213749f, 649 | 0.010055086575448513f, 650 | -0.03425998613238335f, 651 | -0.023136617615818977f, 652 | 0.013063939288258553f, 653 | -0.0001964945695362985f, 654 | -0.07823090255260468f, 655 | -0.049299705773591995f, 656 | -0.01307445764541626f, 657 | 0.07255972176790237f, 658 | 0.11026933789253235f, 659 | -0.13309241831302643f, 660 | 0.03477446362376213f, 661 | -0.031274206936359406f, 662 | 0.005546981003135443f, 663 | -0.014786618761718273f, 664 | -0.04382247105240822f, 665 | -0.022944122552871704f, 666 | -0.005763736087828875f, 667 | -0.02803990989923477f, 668 | 0.03973599895834923f, 669 | -0.06982606649398804f, 670 | 0.055903855711221695f, 671 | -0.01849217340350151f, 672 | -0.029178695753216743f, 673 | -0.010836802423000336f, 674 | -0.03435875475406647f, 675 | -0.0006889880751259625f, 676 | -0.02007262594997883f, 677 | -0.039470091462135315f, 678 | 0.08727303892374039f, 679 | -0.05273575335741043f, 680 | -0.034710273146629333f, 681 | 0.06429289281368256f, 682 | -0.0441647432744503f, 683 | 0.002606051042675972f, 684 | -0.041703928261995316f, 685 | 0.08004797995090485f, 686 | -0.022487612441182137f, 687 | -0.03471028432250023f, 688 | -0.0008381651132367551f, 689 | 0.0909014567732811f, 690 | -0.06743320822715759f, 691 | -0.07035818696022034f, 692 | 0.08347608149051666f, 693 | 0.006671304348856211f, 694 | 0.009866351261734962f, 695 | -0.03535081073641777f, 696 | -0.1096097007393837f, 697 | -0.020150264725089073f, 698 | 0.030442936345934868f, 699 | -0.08029218018054962f, 700 | 0.00021869537886232138f, 701 | 0.008556777611374855f, 702 | 0.021512867882847786f, 703 | 
-0.10086994618177414f, 704 | 0.06591685861349106f, 705 | 0.02657444402575493f, 706 | 0.03068164363503456f, 707 | 0.00835234671831131f, 708 | -0.01443871483206749f, 709 | 0.026387102901935577f, 710 | -0.057937491685152054f, 711 | -0.061116188764572144f, 712 | -0.07227037847042084f, 713 | -0.06715630739927292f, 714 | -0.029829327017068863f, 715 | -0.04155528545379639f, 716 | -0.01873711682856083f, 717 | -2.313823621591382e-08f, 718 | -0.04755893349647522f, 719 | -0.011731435544788837f, 720 | 0.03431084752082825f, 721 | 0.04400825873017311f, 722 | -0.060959603637456894f, 723 | 0.09364251792430878f, 724 | -0.14093442261219025f, 725 | 0.0018414495280012488f, 726 | 0.006476705428212881f, 727 | 0.04462616890668869f, 728 | -0.02774716541171074f, 729 | -0.040246184915304184f, 730 | -0.0867241695523262f, 731 | -0.0030751472804695368f, 732 | 0.06490536034107208f, 733 | 0.001515702810138464f, 734 | 0.020403634756803513f, 735 | 0.0036018516402691603f, 736 | 0.003295356407761574f, 737 | -0.01664314605295658f, 738 | 0.05650820955634117f, 739 | 0.0651976615190506f, 740 | 0.017309093847870827f, 741 | -0.008219453506171703f, 742 | 0.045807670801877975f, 743 | 0.028303757309913635f, 744 | 0.0075932214967906475f, 745 | 0.10886969417333603f, 746 | -0.09187823534011841f, 747 | -0.01187183428555727f, 748 | -0.08677797019481659f, 749 | 0.06410083174705505f, 750 | -0.048447176814079285f, 751 | 0.03530236706137657f, 752 | -0.006899421568959951f, 753 | -0.057062678039073944f, 754 | 0.0431440994143486f, 755 | 0.00041753044934011996f, 756 | 0.006904853507876396f, 757 | -0.007148951292037964f, 758 | -0.01590801402926445f, 759 | 0.057071223855018616f, 760 | -0.07607273757457733f, 761 | 0.020105505362153053f, 762 | 0.04024725407361984f, 763 | 0.08025668561458588f, 764 | 0.02379310131072998f, 765 | -0.040123388171195984f, 766 | -0.00920546893030405f, 767 | 0.05223090201616287f, 768 | 0.041356317698955536f, 769 | -0.02788914553821087f, 770 | 0.19586071372032166f, 771 | 0.0018509950023144484f, 772 | 
0.05987996608018875f, 773 | 0.046917349100112915f, 774 | 0.02290949784219265f, 775 | 0.04779403284192085f, 776 | -0.029831457883119583f, 777 | -0.013816923834383488f, 778 | 0.12134341895580292f, 779 | -0.0315195731818676f, 780 | -0.0016703865258023143f, 781 | -0.05963480845093727f}}; 782 | 783 | public static boolean compareFloatArrays(Float[] array1, float[] array2, float marginOfError) { 784 | if (array1.length != array2.length) { 785 | return false; // Arrays are of different lengths 786 | } 787 | 788 | for (int i = 0; i < array1.length; i++) { 789 | if (Math.abs(array1[i] - array2[i]) > marginOfError) { 790 | System.out.println("Difference at index " + i + ": " + Math.abs(array1[i] - array2[i]) + " > " + marginOfError); 791 | System.out.println("Expected: " + array2[i] + ", Actual: " + array1[i]); 792 | System.out.println("abs(expected - actual) / abs(expected): " + Math.abs(array2[i] - array1[i]) / Math.abs(array2[i])); 793 | return false; // Difference exceeds margin of error 794 | } 795 | } 796 | 797 | return true; // All elements are within the margin of error 798 | } 799 | 800 | @Test 801 | public void testDefaultEmbeddingFunction() throws Exception { 802 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 803 | List result = ef.embedDocuments(Arrays.asList("Hello, my name is John. 
I am a Data Scientist.", "Hello, I am Jane and I am an ML researcher.")); 804 | assertEquals(2, result.size()); 805 | assertEquals(384, result.get(0).getDimensions()); 806 | for (int i = 0; i < result.size(); i++) { 807 | assertTrue(compareFloatArrays(result.get(i).asList().toArray(new Float[0]), groundThruth[i], 1e-5f)); 808 | } 809 | } 810 | 811 | @Test 812 | public void testDownloadModel() throws Exception { 813 | // delete directory if exists 814 | if (DefaultEmbeddingFunction.MODEL_CACHE_DIR.toFile().exists()) { 815 | FileUtils.deleteDirectory(DefaultEmbeddingFunction.MODEL_CACHE_DIR.toFile()); 816 | } 817 | EmbeddingFunction ef = new DefaultEmbeddingFunction(); 818 | 819 | assertTrue(DefaultEmbeddingFunction.MODEL_CACHE_DIR.toFile().exists()); 820 | 821 | } 822 | 823 | 824 | } 825 | -------------------------------------------------------------------------------- /src/test/java/tech/amikos/chromadb/embeddings/cohere/TestCohereEmbeddings.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.cohere; 2 | 3 | import org.junit.Test; 4 | import tech.amikos.chromadb.EFException; 5 | import tech.amikos.chromadb.Embedding; 6 | import tech.amikos.chromadb.Utils; 7 | import tech.amikos.chromadb.embeddings.WithParam; 8 | 9 | import java.util.List; 10 | 11 | import static org.junit.Assert.*; 12 | 13 | public class TestCohereEmbeddings { 14 | 15 | 16 | @Test 17 | public void testEmbedDocuments() throws EFException { 18 | Utils.loadEnvFile(".env"); 19 | String apiKey = Utils.getEnvOrProperty("COHERE_API_KEY"); 20 | CohereEmbeddingFunction ef = new CohereEmbeddingFunction(WithParam.apiKey(apiKey)); 21 | List results = ef.embedDocuments(new String[]{"Hello world", "How are you?"}); 22 | assertNotNull(results); 23 | assertEquals(2, results.size()); 24 | assertEquals(4096, results.get(0).getDimensions()); 25 | } 26 | 27 | @Test 28 | public void testEmbedQuery() throws EFException { 29 | 
Utils.loadEnvFile(".env"); 30 | String apiKey = Utils.getEnvOrProperty("COHERE_API_KEY"); 31 | CohereEmbeddingFunction ef = new CohereEmbeddingFunction(WithParam.apiKey(apiKey)); 32 | Embedding results = ef.embedQuery("How are you?"); 33 | assertNotNull(results); 34 | assertEquals(4096, results.getDimensions()); 35 | } 36 | 37 | @Test 38 | public void testWithModel() throws EFException { 39 | Utils.loadEnvFile(".env"); 40 | String apiKey = Utils.getEnvOrProperty("COHERE_API_KEY"); 41 | CohereEmbeddingFunction ef = new CohereEmbeddingFunction(WithParam.apiKey(apiKey), WithParam.model("embed-english-light-v3.0")); 42 | Embedding results = ef.embedQuery("How are you?"); 43 | assertNotNull(results); 44 | assertEquals(384, results.getDimensions()); 45 | } 46 | 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/test/java/tech/amikos/chromadb/embeddings/hf/TestHuggingFaceEmbeddings.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.hf; 2 | 3 | import org.junit.BeforeClass; 4 | import org.junit.Test; 5 | import org.testcontainers.containers.GenericContainer; 6 | import org.testcontainers.containers.wait.strategy.Wait; 7 | import tech.amikos.chromadb.*; 8 | import tech.amikos.chromadb.embeddings.EmbeddingFunction; 9 | import tech.amikos.chromadb.embeddings.WithParam; 10 | import tech.amikos.chromadb.handler.ApiException; 11 | 12 | import java.util.*; 13 | 14 | import static org.junit.Assert.assertEquals; 15 | import static org.junit.Assert.assertNotNull; 16 | 17 | public class TestHuggingFaceEmbeddings { 18 | static GenericContainer hfeiContainer; 19 | 20 | @BeforeClass 21 | public static void setup() throws Exception { 22 | Utils.loadEnvFile(".env"); 23 | 24 | try { 25 | hfeiContainer = new GenericContainer("ghcr.io/huggingface/text-embeddings-inference:cpu-1.5.0") 26 | .withCommand("--model-id Snowflake/snowflake-arctic-embed-s --revision 
main") 27 | .withExposedPorts(80) 28 | .waitingFor(Wait.forHttp("/").forStatusCode(200)); 29 | hfeiContainer.start(); 30 | System.setProperty("HFEI_URL", "http://" + hfeiContainer.getHost() + ":" + hfeiContainer.getMappedPort(80)); 31 | } catch (Exception e) { 32 | System.err.println("HFEI container failed to start"); 33 | throw e; 34 | } 35 | } 36 | 37 | @Test 38 | public void testEmbedDocuments() throws ApiException, EFException { 39 | String apiKey = Utils.getEnvOrProperty("HF_API_KEY"); 40 | EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(WithParam.apiKey(apiKey)); 41 | List results = ef.embedDocuments(Arrays.asList("Hello world", "How are you?")); 42 | assertEquals(2, results.size()); 43 | assertEquals(384, results.get(0).getDimensions()); 44 | } 45 | 46 | @Test 47 | public void testEmbedQuery() throws ApiException, EFException { 48 | String apiKey = Utils.getEnvOrProperty("HF_API_KEY"); 49 | EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(WithParam.apiKey(apiKey)); 50 | Embedding results = ef.embedQuery("How are you?"); 51 | assertNotNull(results); 52 | assertEquals(384, results.getDimensions()); 53 | } 54 | 55 | @Test 56 | public void testWithModel() throws ApiException, EFException { 57 | String apiKey = Utils.getEnvOrProperty("HF_API_KEY"); 58 | EmbeddingFunction ef = new HuggingFaceEmbeddingFunction(WithParam.apiKey(apiKey), WithParam.model("sentence-transformers/all-mpnet-base-v2")); 59 | Embedding results = ef.embedQuery("How are you?"); 60 | assertNotNull(results); 61 | assertEquals(768, results.getDimensions()); 62 | } 63 | 64 | @Test 65 | public void testWithURL() throws EFException { 66 | EmbeddingFunction ef = new HuggingFaceEmbeddingFunction( 67 | WithParam.baseAPI(System.getProperty("HFEI_URL")), 68 | new HuggingFaceEmbeddingFunction.WithAPIType(HuggingFaceEmbeddingFunction.APIType.HFEI_API)); 69 | Embedding results = ef.embedQuery("How are you?"); 70 | assertNotNull(results); 71 | assertEquals(384, results.getDimensions()); 72 | 
} 73 | } 74 | 75 | -------------------------------------------------------------------------------- /src/test/java/tech/amikos/chromadb/embeddings/ollama/TestOllamaEmbeddings.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.ollama; 2 | 3 | import org.junit.AfterClass; 4 | import org.junit.BeforeClass; 5 | import org.junit.Test; 6 | import org.testcontainers.containers.GenericContainer; 7 | import org.testcontainers.containers.wait.strategy.Wait; 8 | import tech.amikos.chromadb.EFException; 9 | import tech.amikos.chromadb.Embedding; 10 | import tech.amikos.chromadb.embeddings.WithParam; 11 | 12 | import java.util.Arrays; 13 | import java.util.List; 14 | 15 | import static org.junit.Assert.*; 16 | 17 | public class TestOllamaEmbeddings { 18 | static GenericContainer ollamaContainer; 19 | 20 | @BeforeClass 21 | public static void setupChromaDB() throws Exception { 22 | try { 23 | ollamaContainer = new GenericContainer("ollama/ollama:latest") 24 | .withExposedPorts(11434) 25 | .waitingFor(Wait.forHttp("/api/version").forStatusCode(200)); 26 | ollamaContainer.start(); 27 | ollamaContainer.waitingFor(Wait.forHttp("/api/version").forStatusCode(200)); 28 | System.setProperty("OLLAMA_URL", "http://" + ollamaContainer.getHost() + ":" + ollamaContainer.getMappedPort(11434) + "/api/embed"); 29 | ollamaContainer.execInContainer("ollama", "pull", "nomic-embed-text"); 30 | } catch (Exception e) { 31 | System.err.println("Ollama container failed to start"); 32 | throw e; 33 | } 34 | } 35 | 36 | @AfterClass 37 | public static void teardownChromaDB() { 38 | if (ollamaContainer != null) { 39 | ollamaContainer.stop(); 40 | } 41 | } 42 | 43 | @Test 44 | public void testOllamaEmbedDocuments() throws EFException { 45 | OllamaEmbeddingFunction ef = new OllamaEmbeddingFunction(WithParam.baseAPI(System.getProperty("OLLAMA_URL"))); 46 | List results = ef.embedDocuments(Arrays.asList("Hello, my name is John. 
I am a Data Scientist.", "Hello, I am Jane and I am an ML researcher.")); 47 | assertEquals(2, results.size()); 48 | assertEquals(768, results.get(0).getDimensions()); 49 | } 50 | 51 | @Test 52 | public void testOllamaEmbedQuery() throws EFException { 53 | OllamaEmbeddingFunction ef = new OllamaEmbeddingFunction(WithParam.baseAPI(System.getProperty("OLLAMA_URL"))); 54 | Embedding results = ef.embedQuery("Hello, I am Jane and I am an ML researcher."); 55 | assertNotNull(results); 56 | assertEquals(768, results.getDimensions()); 57 | } 58 | 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/test/java/tech/amikos/chromadb/embeddings/openai/TestOpenAIEmbeddings.java: -------------------------------------------------------------------------------- 1 | package tech.amikos.chromadb.embeddings.openai; 2 | 3 | import org.junit.BeforeClass; 4 | import org.junit.Test; 5 | import tech.amikos.chromadb.EFException; 6 | import tech.amikos.chromadb.Embedding; 7 | import tech.amikos.chromadb.Utils; 8 | import tech.amikos.chromadb.embeddings.WithParam; 9 | 10 | import java.util.Arrays; 11 | import java.util.List; 12 | 13 | import static org.junit.Assert.assertEquals; 14 | import static org.junit.Assert.assertNotNull; 15 | 16 | public class TestOpenAIEmbeddings { 17 | @BeforeClass 18 | public static void setup() { 19 | Utils.loadEnvFile(".env"); 20 | } 21 | 22 | @Test 23 | public void testSerialization() { 24 | CreateEmbeddingRequest req = new CreateEmbeddingRequest(); 25 | req.model("text-ada"); 26 | req.user("user-1234567890"); 27 | req.input(new CreateEmbeddingRequest.Input("Hello, my name is John. I am a Data Scientist.")); 28 | assertEquals("{\"model\":\"text-ada\",\"user\":\"user-1234567890\",\"input\":\"Hello, my name is John. 
I am a Data Scientist.\"}", req.json()); 29 | } 30 | 31 | @Test 32 | public void testSerializationListOfStrings() { 33 | CreateEmbeddingRequest req = new CreateEmbeddingRequest(); 34 | req.model("text-ada"); 35 | req.user("user-1234567890"); 36 | req.input(new CreateEmbeddingRequest.Input(new String[]{"Hello, my name is John. I am a Data Scientist.", "Hello, my name is John. I am a Data Scientist."})); 37 | assertEquals("{\"model\":\"text-ada\",\"user\":\"user-1234567890\",\"input\":[\"Hello, my name is John. I am a Data Scientist.\",\"Hello, my name is John. I am a Data Scientist.\"]}", req.json()); 38 | } 39 | 40 | @Test 41 | public void testSerializationListOfIntegers() { 42 | CreateEmbeddingRequest req = new CreateEmbeddingRequest(); 43 | req.model("text-ada"); 44 | req.user("user-1234567890"); 45 | req.input(new CreateEmbeddingRequest.Input(new Integer[]{1, 2, 3, 4, 5})); 46 | assertEquals("{\"model\":\"text-ada\",\"user\":\"user-1234567890\",\"input\":[1,2,3,4,5]}", req.json()); 47 | } 48 | 49 | @Test 50 | public void testSerializationListOfListOfIntegers() { 51 | CreateEmbeddingRequest req = new CreateEmbeddingRequest(); 52 | req.model("text-ada"); 53 | req.user("user-1234567890"); 54 | List list = Arrays.asList(new Integer[]{1, 2, 3}, new Integer[]{4, 5, 6}); 55 | req.input(new CreateEmbeddingRequest.Input(list)); 56 | assertEquals("{\"model\":\"text-ada\",\"user\":\"user-1234567890\",\"input\":[[1,2,3],[4,5,6]]}", req.json()); 57 | } 58 | 59 | @Test 60 | public void testEmbedDocuments() throws EFException { 61 | String apiKey = Utils.getEnvOrProperty("OPENAI_API_KEY"); 62 | List embeddings = new OpenAIEmbeddingFunction( 63 | WithParam.apiKey(apiKey), 64 | WithParam.model("text-embedding-3-small")) 65 | .embedDocuments(Arrays.asList("Hello, my name is John. 
I am a Data Scientist.")); 66 | assertNotNull(embeddings); 67 | assertEquals(1536, embeddings.get(0).getDimensions()); 68 | } 69 | 70 | @Test 71 | public void testEFBuilderWithCustomURL() throws EFException { 72 | String apiKey = Utils.getEnvOrProperty("OPENAI_API_KEY"); 73 | List embeddings = new OpenAIEmbeddingFunction( 74 | WithParam.apiKey(apiKey), 75 | WithParam.model("text-embedding-3-small"), 76 | WithParam.baseAPI("https://api.openai.com/v1/embeddings")) 77 | .embedDocuments(Arrays.asList("Hello, my name is John. I am a Data Scientist.")); 78 | assertNotNull(embeddings); 79 | assertEquals(1536, embeddings.get(0).getDimensions()); 80 | } 81 | 82 | 83 | @Test 84 | public void testEmbedQuery() throws EFException { 85 | String apiKey = Utils.getEnvOrProperty("OPENAI_API_KEY"); 86 | CreateEmbeddingRequest req = new CreateEmbeddingRequest().model("text-embedding-3-small"); 87 | OpenAIEmbeddingFunction ef = new OpenAIEmbeddingFunction(WithParam.apiKey(apiKey)); 88 | Embedding embeddings = ef.embedQuery("Hello, my name is John. I am a Data Scientist."); 89 | assertNotNull(embeddings); 90 | assertEquals(1536, embeddings.getDimensions()); 91 | } 92 | } 93 | --------------------------------------------------------------------------------