├── .github └── workflows │ ├── CI.yml │ ├── PyPI.yml │ └── TestPyPI.yml ├── .gitignore ├── .python-version ├── LICENSE ├── README.md ├── demo_ss.png ├── pyproject.toml ├── pyrightconfig.json ├── requirements.txt ├── src ├── streamlit_chromadb_connection.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt └── streamlit_chromadb_connection │ ├── __init__.py │ └── chromadb_connection.py └── tests └── unit_tests ├── collection_test.py └── connect_test.py /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ "dev/*", "fix/*" ] 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - name: Checkout branch 13 | uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v4 17 | 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install -r requirements.txt 22 | 23 | - name: Unit test 24 | run: python -m unittest discover -s tests/unit_tests -p '*_test.py' 25 | -------------------------------------------------------------------------------- /.github/workflows/PyPI.yml: -------------------------------------------------------------------------------- 1 | name: PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | publish: 9 | name: Publish to PyPI 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: pypi 13 | url: https://pypi.org/project/streamlit-chromadb-connection/ 14 | permissions: 15 | id-token: write 16 | 17 | steps: 18 | - name: Checkout branch 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v4 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -r requirements.txt 28 | 29 | - name: Build package 30 | run: python -m build 31 | 32 | - name: Publish package 33 | uses: 
pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | password: ${{ secrets.PYPI_API_TOKEN }} 36 | -------------------------------------------------------------------------------- /.github/workflows/TestPyPI.yml: -------------------------------------------------------------------------------- 1 | name: TestPyPI 2 | 3 | on: 4 | push: 5 | branches: [ "main"] 6 | 7 | jobs: 8 | publish: 9 | name: Publish to TestPyPI 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: testpypi 13 | url: https://test.pypi.org/project/streamlit-chromadb-connection/ 14 | permissions: 15 | id-token: write 16 | 17 | steps: 18 | - name: Checkout branch 19 | uses: actions/checkout@v4 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v4 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -r requirements.txt 28 | 29 | - name: Build package 30 | run: python -m build 31 | 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | repository-url: https://test.pypi.org/legacy/ 36 | password: ${{ secrets.TEST_PYPI_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *venv 2 | **__pycache__** 3 | **dist** 4 | **.zed** 5 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 The Python Packaging Authority 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including 
without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📂 ChromaDBConnection 2 | 3 | ![Demo Screen Shot](https://github.com/Dev317/streamlit_chromadb_connection/blob/236d4c4cecbd56c19695f55b20b58492518e8300/demo_ss.png?raw=True) 4 | 5 | A connection for the Chroma vector database, `ChromaDBConnection`, has been released, making it easy to connect any Streamlit LLM-powered app to Chroma.
6 | 7 | With `st.connection()`, connecting to a Chroma vector database becomes just a few lines of code: 8 | 9 | 10 | ```python 11 | import streamlit as st 12 | from streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 13 | 14 | configuration = { 15 | "client": "PersistentClient", 16 | "path": "/tmp/.chroma" 17 | } 18 | 19 | collection_name = "documents_collection" 20 | 21 | conn = st.connection("chromadb", 22 | type=ChromadbConnection, 23 | **configuration) 24 | documents_collection_df = conn.get_collection_data(collection_name) 25 | st.dataframe(documents_collection_df) 26 | ``` 27 | 28 | ## 📑 ChromaDBConnection API 29 | 30 | ### _connect() 31 | There are 2 ways to connect to a Chroma client: 32 | 1. **PersistentClient**: Data will be persisted to a local machine 33 | ```python 34 | import streamlit as st 35 | from streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 36 | 37 | configuration = { 38 | "client": "PersistentClient", 39 | "path": "/tmp/.chroma" 40 | } 41 | 42 | conn = st.connection(name="persistent_chromadb", 43 | type=ChromadbConnection, 44 | **configuration) 45 | ``` 46 | 47 | 2. **HttpClient**: Data will be persisted to a cloud server where Chroma resides 48 | ```python 49 | import streamlit as st 50 | from streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 51 | 52 | configuration = { 53 | "client": "HttpClient", 54 | "host": "localhost", 55 | "port": 8000, 56 | } 57 | 58 | conn = st.connection(name="http_connection", 59 | type=ChromadbConnection, 60 | **configuration) 61 | ``` 62 | 63 | 64 | ### create_collection() 65 | In order to create a Chroma collection, one needs to supply a `collection_name`, an `embedding_function_name`, an `embedding_config` and, optionally, `metadata`.
66 | 67 | The currently supported options for `embedding_function_name` are: 68 | - DefaultEmbeddingFunction 69 | - SentenceTransformerEmbeddingFunction 70 | - OpenAIEmbeddingFunction 71 | - CohereEmbeddingFunction 72 | - GooglePalmEmbeddingFunction 73 | - GoogleVertexEmbeddingFunction 74 | - HuggingFaceEmbeddingFunction 75 | - InstructorEmbeddingFunction 76 | - Text2VecEmbeddingFunction 77 | - ONNXMiniLM_L6_V2 78 | 79 | For `DefaultEmbeddingFunction`, the `embedding_config` argument can be left empty (`{}`, as in the samples below). However, for other embedding functions such as `OpenAIEmbeddingFunction`, one needs to provide configuration such as: 80 | 81 | ```python 82 | embedding_config = { 83 | "api_key": "{OPENAI_API_KEY}", 84 | "model_name": "{OPENAI_MODEL}", 85 | } 86 | ``` 87 | 88 | One can also change the distance function by changing the `metadata` argument, such as: 89 | 90 | ```python 91 | metadata = {"hnsw:space": "l2"} # Squared L2 norm 92 | metadata = {"hnsw:space": "cosine"} # Cosine similarity 93 | metadata = {"hnsw:space": "ip"} # Inner product 94 | ``` 95 | 96 | Sample code to create a collection: 97 | 98 | ```python 99 | collection_name = "documents_collection" 100 | embedding_function_name = "DefaultEmbeddingFunction" 101 | conn.create_collection(collection_name=collection_name, 102 | embedding_function_name=embedding_function_name, 103 | embedding_config={}, 104 | metadata = {"hnsw:space": "cosine"}) 105 | ``` 106 | 107 | ### get_collection_data() 108 | This method returns a dataframe that consists of the embeddings and documents of a collection. 109 | The `attributes` argument is a list of attributes to be included in the DataFrame. 110 | The following code snippet will return all data in a collection in the form of a DataFrame, with 2 columns: `documents` and `embeddings`.
111 | 112 | ```python 113 | collection_name = "documents_collection" 114 | conn.get_collection_data(collection_name=collection_name, 115 | attributes= ["documents", "embeddings"]) 116 | ``` 117 | 118 | ### delete_collection() 119 | This method deletes the stated collection name. 120 | 121 | ```python 122 | collection_name = "documents_collection" 123 | conn.delete_collection(collection_name=collection_name) 124 | ``` 125 | 126 | ### upload_documents() 127 | This method uploads documents to a collection. 128 | If embeddings are not provided, the method will embed the documents using the embedding function specified in the collection. 129 | 130 | 131 | ```python 132 | collection_name = "documents_collection" 133 | embedding_function_name = "DefaultEmbeddingFunction" 134 | embedding_config = {} 135 | conn.upload_documents(collection_name=collection_name, 136 | documents=["lorem ipsum", "doc2", "doc3"], 137 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 138 | embeeding_function_name=embedding_function_name, 139 | embedding_config=embedding_config, 140 | ids=["id1", "id2", "id3"]) 141 | ``` 142 | 143 | ### update_collection_data() 144 | This method updates documents in a collection based on their ids. 
145 | 146 | ```python 147 | embedding_function_name = "DefaultEmbeddingFunction" 148 | embedding_config = {} 149 | conn.upload_documents(collection_name=collection_name, 150 | documents=["this is a", "this is b", "this is c"], 151 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 152 | embeeding_function_name=embedding_function_name, 153 | embedding_config=embedding_config, 154 | ids=["id1", "id2", "id3"]) 155 | 156 | conn.update_collection_data(collection_name=collection_name, 157 | documents=["this is b", "this is c", "this is d"], 158 | embeeding_function_name=embedding_function_name, 159 | embedding_config=embedding_config, 160 | ids=["id1", "id2", "id3"]) 161 | ``` 162 | 163 | ### query() 164 | This method retrieves the top k relevant documents for each query in the supplied list of queries. 165 | The result is a dataframe in which each row shows the top k relevant documents for one query. 166 | 167 | ```python 168 | collection_name = "documents_collection" 169 | embedding_function_name = "DefaultEmbeddingFunction" 170 | embedding_config = {} 171 | conn.upload_documents(collection_name=collection_name, 172 | documents=["lorem ipsum", "doc2", "doc3"], 173 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 174 | ids=["id1", "id2", "id3"], 175 | embeeding_function_name=embedding_function_name, 176 | embedding_config=embedding_config, 177 | embeddings=None) 178 | 179 | queried_data = conn.query(collection_name=collection_name, 180 | query=["random_query1", "random_query2"], 181 | num_results_limit=10, 182 | attributes=["documents", "embeddings", "metadatas", "data"]) 183 | ``` 184 | 185 | Metadata and document filters can also be supplied through the `where_metadata_filter` and `where_document_filter` arguments, respectively, for a more relevant search.
For better understanding on the usage of where filters, please refer to: https://docs.trychroma.com/usage-guide#using-where-filters 186 | 187 | ```python 188 | queried_data = conn.query(collection_name=collection_name, 189 | query=["this is"], 190 | num_results_limit=10, 191 | attributes=["documents", "embeddings", "metadatas", "data"], 192 | where_metadata_filter={"chapter": "3"}) 193 | ``` 194 | 195 | 196 | *** 197 | 🎉 That's it! `ChromaDBConnection` is ready to be used with `st.connection()`. 🎉 198 | *** 199 | 200 | ## Contribution 🔥 201 | ``` 202 | author={Vu Quang Minh}, 203 | github={Dev317}, 204 | year={2023} 205 | ``` 206 | -------------------------------------------------------------------------------- /demo_ss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dev317/streamlit_chromadb_connection/5b89190df285c2e46f02bb10c10d05b8d6d98962/demo_ss.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "streamlit_chromadb_connection" 3 | version = "1.0.5" 4 | authors = [ 5 | { name="Dev317", email="mineskiroxro@gmail.com" }, 6 | ] 7 | description = "A simple adapter connection for any Streamlit LLM-powered app to use ChromaDB vector database." 
8 | readme = "README.md" 9 | requires-python = ">=3.8" 10 | classifiers = [ 11 | "Programming Language :: Python :: 3", 12 | "License :: OSI Approved :: MIT License", 13 | "Operating System :: OS Independent", 14 | ] 15 | dynamic = ["dependencies"] 16 | 17 | [tool.setuptools.dynamic] 18 | dependencies = {file = ["requirements.txt"]} 19 | 20 | [project.urls] 21 | Homepage = "https://github.com/Dev317/streamlit_chromadb_connection" 22 | Issues = "https://github.com/Dev317/streamlit_chromadb_connection/issues" 23 | 24 | [build-system] 25 | requires = ["setuptools>=61.0"] 26 | build-backend = "setuptools.build_meta" 27 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "venvPath": ".", 3 | "venv": "venv" 4 | } 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | altair==5.3.0 2 | annotated-types==0.7.0 3 | anyio==4.4.0 4 | asgiref==3.8.1 5 | attrs==23.2.0 6 | backoff==2.2.1 7 | bcrypt==4.1.3 8 | blinker==1.8.2 9 | build==1.2.1 10 | cachetools==5.3.3 11 | certifi==2024.7.4 12 | charset-normalizer==3.3.2 13 | chroma-hnswlib==0.7.3 14 | chromadb==0.5.3 15 | click==8.1.7 16 | coloredlogs==15.0.1 17 | Deprecated==1.2.14 18 | dnspython==2.6.1 19 | email_validator==2.2.0 20 | fastapi==0.111.0 21 | fastapi-cli==0.0.4 22 | filelock==3.15.4 23 | flatbuffers==24.3.25 24 | fsspec==2024.6.1 25 | gitdb==4.0.11 26 | GitPython==3.1.43 27 | google-auth==2.32.0 28 | googleapis-common-protos==1.63.2 29 | grpcio==1.64.1 30 | h11==0.14.0 31 | httpcore==1.0.5 32 | httptools==0.6.1 33 | httpx==0.27.0 34 | huggingface-hub==0.23.4 35 | humanfriendly==10.0 36 | idna==3.7 37 | importlib_metadata==7.1.0 38 | importlib_resources==6.4.0 39 | Jinja2==3.1.4 40 | jsonschema==4.23.0 41 | jsonschema-specifications==2023.12.1 42 
| kubernetes==30.1.0 43 | markdown-it-py==3.0.0 44 | MarkupSafe==2.1.5 45 | mdurl==0.1.2 46 | mmh3==4.1.0 47 | monotonic==1.6 48 | mpmath==1.3.0 49 | numpy==1.26.4 50 | oauthlib==3.2.2 51 | onnxruntime==1.18.1 52 | opentelemetry-api==1.25.0 53 | opentelemetry-exporter-otlp-proto-common==1.25.0 54 | opentelemetry-exporter-otlp-proto-grpc==1.25.0 55 | opentelemetry-instrumentation==0.46b0 56 | opentelemetry-instrumentation-asgi==0.46b0 57 | opentelemetry-instrumentation-fastapi==0.46b0 58 | opentelemetry-proto==1.25.0 59 | opentelemetry-sdk==1.25.0 60 | opentelemetry-semantic-conventions==0.46b0 61 | opentelemetry-util-http==0.46b0 62 | orjson==3.10.6 63 | overrides==7.7.0 64 | packaging==24.1 65 | pandas==2.2.2 66 | pillow==10.4.0 67 | posthog==3.5.0 68 | protobuf==4.25.3 69 | pyarrow==16.1.0 70 | pyasn1==0.6.0 71 | pyasn1_modules==0.4.0 72 | pydantic==2.8.2 73 | pydantic_core==2.20.1 74 | pydeck==0.9.1 75 | Pygments==2.18.0 76 | PyPika==0.48.9 77 | pyproject_hooks==1.1.0 78 | python-dateutil==2.9.0.post0 79 | python-dotenv==1.0.1 80 | python-multipart==0.0.9 81 | pytz==2024.1 82 | PyYAML==6.0.1 83 | referencing==0.35.1 84 | requests==2.32.3 85 | requests-oauthlib==2.0.0 86 | rich==13.7.1 87 | rpds-py==0.19.0 88 | rsa==4.9 89 | setuptools==70.3.0 90 | shellingham==1.5.4 91 | six==1.16.0 92 | smmap==5.0.1 93 | sniffio==1.3.1 94 | starlette==0.37.2 95 | streamlit==1.36.0 96 | sympy==1.13.0 97 | tenacity==8.5.0 98 | tokenizers==0.19.1 99 | toml==0.10.2 100 | toolz==0.12.1 101 | tornado==6.4.1 102 | tqdm==4.66.4 103 | typer==0.12.3 104 | typing_extensions==4.12.2 105 | tzdata==2024.1 106 | ujson==5.10.0 107 | urllib3==2.2.2 108 | uvicorn==0.30.1 109 | uvloop==0.19.0 110 | watchfiles==0.22.0 111 | websocket-client==1.8.0 112 | websockets==12.0 113 | wrapt==1.16.0 114 | zipp==3.19.2 115 | -------------------------------------------------------------------------------- /src/streamlit_chromadb_connection.egg-info/PKG-INFO: 
-------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: streamlit-chromadb-connection 3 | Version: 0.0.5 4 | Summary: A simple adapter connection for any Streamlit LLM-powered app to use ChromaDB vector database. 5 | Author-email: Dev317 6 | Project-URL: Homepage, https://github.com/Dev317/streamlit_chromadb_connection 7 | Project-URL: Issues, https://github.com/Dev317/streamlit_chromadb_connection/issues 8 | Classifier: Programming Language :: Python :: 3 9 | Classifier: License :: OSI Approved :: MIT License 10 | Classifier: Operating System :: OS Independent 11 | Requires-Python: >=3.8 12 | Description-Content-Type: text/markdown 13 | License-File: LICENSE 14 | Requires-Dist: altair==5.2.0 15 | Requires-Dist: annotated-types==0.6.0 16 | Requires-Dist: anyio==3.7.1 17 | Requires-Dist: asgiref==3.7.2 18 | Requires-Dist: attrs==23.1.0 19 | Requires-Dist: backoff==2.2.1 20 | Requires-Dist: bcrypt==4.1.1 21 | Requires-Dist: blinker==1.7.0 22 | Requires-Dist: build==1.0.3 23 | Requires-Dist: cachetools==5.3.2 24 | Requires-Dist: certifi==2023.11.17 25 | Requires-Dist: charset-normalizer==3.3.2 26 | Requires-Dist: chroma-hnswlib==0.7.3 27 | Requires-Dist: chromadb==0.4.18 28 | Requires-Dist: click==8.1.7 29 | Requires-Dist: coloredlogs==15.0.1 30 | Requires-Dist: Deprecated==1.2.14 31 | Requires-Dist: docutils==0.20.1 32 | Requires-Dist: exceptiongroup==1.2.0 33 | Requires-Dist: fastapi==0.104.1 34 | Requires-Dist: filelock==3.13.1 35 | Requires-Dist: flatbuffers==23.5.26 36 | Requires-Dist: fsspec==2023.10.0 37 | Requires-Dist: gitdb==4.0.11 38 | Requires-Dist: GitPython==3.1.40 39 | Requires-Dist: google-auth==2.24.0 40 | Requires-Dist: googleapis-common-protos==1.61.0 41 | Requires-Dist: grpcio==1.59.3 42 | Requires-Dist: h11==0.14.0 43 | Requires-Dist: httptools==0.6.1 44 | Requires-Dist: huggingface-hub==0.19.4 45 | Requires-Dist: humanfriendly==10.0 46 | Requires-Dist: idna==3.6 47 | Requires-Dist: 
importlib-metadata==6.9.0 48 | Requires-Dist: importlib-resources==6.1.1 49 | Requires-Dist: jaraco.classes==3.3.0 50 | Requires-Dist: Jinja2==3.1.2 51 | Requires-Dist: jsonschema==4.20.0 52 | Requires-Dist: jsonschema-specifications==2023.11.2 53 | Requires-Dist: keyring==24.3.0 54 | Requires-Dist: kubernetes==28.1.0 55 | Requires-Dist: markdown-it-py==3.0.0 56 | Requires-Dist: MarkupSafe==2.1.3 57 | Requires-Dist: mdurl==0.1.2 58 | Requires-Dist: mmh3==4.0.1 59 | Requires-Dist: monotonic==1.6 60 | Requires-Dist: more-itertools==10.1.0 61 | Requires-Dist: mpmath==1.3.0 62 | Requires-Dist: nh3==0.2.14 63 | Requires-Dist: numpy==1.26.2 64 | Requires-Dist: oauthlib==3.2.2 65 | Requires-Dist: onnxruntime==1.16.3 66 | Requires-Dist: opentelemetry-api==1.21.0 67 | Requires-Dist: opentelemetry-exporter-otlp-proto-common==1.21.0 68 | Requires-Dist: opentelemetry-exporter-otlp-proto-grpc==1.21.0 69 | Requires-Dist: opentelemetry-instrumentation==0.42b0 70 | Requires-Dist: opentelemetry-instrumentation-asgi==0.42b0 71 | Requires-Dist: opentelemetry-instrumentation-fastapi==0.42b0 72 | Requires-Dist: opentelemetry-proto==1.21.0 73 | Requires-Dist: opentelemetry-sdk==1.21.0 74 | Requires-Dist: opentelemetry-semantic-conventions==0.42b0 75 | Requires-Dist: opentelemetry-util-http==0.42b0 76 | Requires-Dist: overrides==7.4.0 77 | Requires-Dist: packaging==23.2 78 | Requires-Dist: pandas==2.1.3 79 | Requires-Dist: Pillow==10.1.0 80 | Requires-Dist: pkginfo==1.9.6 81 | Requires-Dist: posthog==3.0.2 82 | Requires-Dist: protobuf==4.25.1 83 | Requires-Dist: pulsar-client==3.3.0 84 | Requires-Dist: pyarrow==14.0.1 85 | Requires-Dist: pyasn1==0.5.1 86 | Requires-Dist: pyasn1-modules==0.3.0 87 | Requires-Dist: pydantic==2.5.2 88 | Requires-Dist: pydantic_core==2.14.5 89 | Requires-Dist: pydeck==0.8.1b0 90 | Requires-Dist: Pygments==2.17.2 91 | Requires-Dist: PyPika==0.48.9 92 | Requires-Dist: pyproject_hooks==1.0.0 93 | Requires-Dist: python-dateutil==2.8.2 94 | Requires-Dist: 
python-dotenv==1.0.0 95 | Requires-Dist: pytz==2023.3.post1 96 | Requires-Dist: PyYAML==6.0.1 97 | Requires-Dist: readme-renderer==42.0 98 | Requires-Dist: referencing==0.31.1 99 | Requires-Dist: requests==2.31.0 100 | Requires-Dist: requests-oauthlib==1.3.1 101 | Requires-Dist: requests-toolbelt==1.0.0 102 | Requires-Dist: rfc3986==2.0.0 103 | Requires-Dist: rich==13.7.0 104 | Requires-Dist: rpds-py==0.13.2 105 | Requires-Dist: rsa==4.9 106 | Requires-Dist: six==1.16.0 107 | Requires-Dist: smmap==5.0.1 108 | Requires-Dist: sniffio==1.3.0 109 | Requires-Dist: starlette==0.27.0 110 | Requires-Dist: streamlit==1.29.0 111 | Requires-Dist: sympy==1.12 112 | Requires-Dist: tenacity==8.2.3 113 | Requires-Dist: tokenizers==0.15.0 114 | Requires-Dist: toml==0.10.2 115 | Requires-Dist: tomli==2.0.1 116 | Requires-Dist: toolz==0.12.0 117 | Requires-Dist: tornado==6.4 118 | Requires-Dist: tqdm==4.66.1 119 | Requires-Dist: twine==4.0.2 120 | Requires-Dist: typer==0.9.0 121 | Requires-Dist: typing_extensions==4.8.0 122 | Requires-Dist: tzdata==2023.3 123 | Requires-Dist: tzlocal==5.2 124 | Requires-Dist: urllib3==1.26.18 125 | Requires-Dist: uvicorn==0.24.0.post1 126 | Requires-Dist: uvloop==0.19.0 127 | Requires-Dist: validators==0.22.0 128 | Requires-Dist: watchfiles==0.21.0 129 | Requires-Dist: websocket-client==1.6.4 130 | Requires-Dist: websockets==12.0 131 | Requires-Dist: wrapt==1.16.0 132 | Requires-Dist: zipp==3.17.0 133 | 134 | # 📂 ChromaDBConnection 135 | 136 | Connection for Chroma vector database, `ChromaDBConnection`, has been released which makes it easy to connect any Streamlit LLM-powered app to. 
137 | 138 | With `st.connection()`, connecting to a Chroma vector database becomes just a few lines of code: 139 | 140 | 141 | ```python 142 | import streamlit as st 143 | from streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 144 | 145 | configuration = { 146 | "client": "PersistentClient", 147 | "path": "/tmp/.chroma" 148 | } 149 | 150 | collection_name = "documents_collection" 151 | 152 | conn = st.connection("chromadb", 153 | type=ChromaDBConnection, 154 | **configuration) 155 | documents_collection_df = conn.get_collection_data(collection_name) 156 | st.dataframe(documents_collection_df) 157 | ``` 158 | 159 | ## 📑 ChromaDBConnection API 160 | 161 | ### _connect() 162 | There are 2 ways to connect to a Chroma client: 163 | 1. **PersistentClient**: Data will be persisted to a local machine 164 | ```python 165 | import streamlit as st 166 | from streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 167 | 168 | configuration = { 169 | "client": "PersistentClient", 170 | "path": "/tmp/.chroma" 171 | } 172 | 173 | conn = st.connection(name="persistent_chromadb", 174 | type=ChromadbConnection, 175 | **configuration) 176 | ``` 177 | 178 | 2. **HttpClient**: Data will be persisted to a cloud server where Chroma resides 179 | ```python 180 | import streamlit as st 181 | from streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 182 | 183 | configuration = { 184 | "client": "HttpClient", 185 | "host": "localhost", 186 | "port": 8000, 187 | } 188 | 189 | conn = st.connection(name="http_connection", 190 | type=ChromadbConnection, 191 | **configuration) 192 | ``` 193 | 194 | 195 | ### create_collection() 196 | In order to create a Chroma collection, one needs to supply a `collection_name` and `embedding_function_name`, `embedding_config` and (optional) `metadata`. 
197 | 198 | There are current possible options for `embedding_function_name`: 199 | - DefaultEmbeddingFunction 200 | - SentenceTransformerEmbeddingFunction 201 | - OpenAIEmbeddingFunction 202 | - CohereEmbeddingFunction 203 | - GooglePalmEmbeddingFunction 204 | - GoogleVertexEmbeddingFunction 205 | - HuggingFaceEmbeddingFunction 206 | - InstructorEmbeddingFunction 207 | - Text2VecEmbeddingFunction 208 | - ONNXMiniLM_L6_V2 209 | 210 | For `DefaultEmbeddingFunction`, the `embedding_config` argument can be left as an empty string. However, for other embedding functions such as `OpenAIEmbeddingFunction`, one needs to provide configuration such as: 211 | 212 | ```python 213 | embedding_config = { 214 | api_key: "{OPENAI_API_KEY}", 215 | model_name: "{OPENAI_MODEL}", 216 | } 217 | ``` 218 | 219 | One can also change the distance function by changing the `metadata` argument, such as: 220 | 221 | ```python 222 | metadata = {"hnsw:space": "l2"} # Squared L2 norm 223 | metadata = {"hnsw:space": "cosine"} # Cosine similarity 224 | metadata = {"hnsw:space": "ip"} # Inner product 225 | ``` 226 | 227 | Sample code to create connection: 228 | 229 | ```python 230 | collection_name = "documents_collection" 231 | embedding_function_name = "DefaultEmbeddingFunction" 232 | conn.create_collection(collection_name=collection_name, 233 | embedding_function_name=embedding_function_name, 234 | embedding_config={}, 235 | metadata = {"hnsw:space": "cosine"}) 236 | ``` 237 | 238 | ### get_collection_data() 239 | This method returns a dataframe that consists of the embeddings and documents of a collection. 240 | The `attributes` argument is a list of attributes to be included in the DataFrame. 241 | The following code snippet will return all data in a collection in the form of a DataFrame, with 2 columns: `documents` and `embeddings`. 
242 | 243 | ```python 244 | collection_name = "documents_collection" 245 | conn.get_collection_data(collection_name=collection_name, 246 | attributes= ["documents", "embeddings"]) 247 | ``` 248 | 249 | ### delete_collection() 250 | This method deletes the stated collection name. 251 | 252 | ```python 253 | collection_name = "documents_collection" 254 | conn.delete_collection(collection_name=collection_name) 255 | ``` 256 | 257 | ### upload_document() 258 | This method uploads documents to a collection. 259 | If embeddings are not provided, the method will embed the documents using the embedding function specified in the collection. 260 | 261 | 262 | ```python 263 | collection_name = "documents_collection" 264 | conn.upload_document(collection_name=collection_name, 265 | documents=["lorem ipsum", "doc2", "doc3"], 266 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 267 | ids=["id1", "id2", "id3"], 268 | embeddings=None) 269 | ``` 270 | 271 | ### query() 272 | This method retrieves top k relevant document based on a list of queries supplied. 273 | The result will be in a dataframe where each row will shows the top k relevant documents of each query. 274 | 275 | ```python 276 | collection_name = "documents_collection" 277 | conn.upload_document(collection_name=collection_name, 278 | documents=["lorem ipsum", "doc2", "doc3"], 279 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 280 | ids=["id1", "id2", "id3"], 281 | embeddings=None) 282 | 283 | queried_data = conn.query(collection_name=collection_name, 284 | query=["random_query1", "random_query2"], 285 | num_results_limit=10, 286 | attributes=["documents", "embeddings", "metadatas", "data"]) 287 | ``` 288 | 289 | Metadata and document filters are also provided in `where_metadata_filter` and `where_document_filter` arguments respectively for more relevant search. 
For better understanding on the usage of where filters, please refer to: https://docs.trychroma.com/usage-guide#using-where-filters 290 | 291 | ```python 292 | queried_data = conn.query(collection_name=collection_name, 293 | query=["this is"], 294 | num_results_limit=10, 295 | attributes=["documents", "embeddings", "metadatas", "data"], 296 | where_metadata_filter={"chapter": "3"}) 297 | ``` 298 | 299 | 300 | *** 301 | 🎉 That's it! `ChromaDBConnection` is ready to be used with `st.connection()`. 🎉 302 | *** 303 | 304 | ## Contribution 🔥 305 | ``` 306 | author={Vu Quang Minh}, 307 | github={Dev317}, 308 | year={2023} 309 | ``` 310 | -------------------------------------------------------------------------------- /src/streamlit_chromadb_connection.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | pyproject.toml 4 | requirements.txt 5 | src/streamlit_chromadb_connection/__init__.py 6 | src/streamlit_chromadb_connection/chromadb_connection.py 7 | src/streamlit_chromadb_connection.egg-info/PKG-INFO 8 | src/streamlit_chromadb_connection.egg-info/SOURCES.txt 9 | src/streamlit_chromadb_connection.egg-info/dependency_links.txt 10 | src/streamlit_chromadb_connection.egg-info/requires.txt 11 | src/streamlit_chromadb_connection.egg-info/top_level.txt -------------------------------------------------------------------------------- /src/streamlit_chromadb_connection.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/streamlit_chromadb_connection.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | altair==5.2.0 2 | annotated-types==0.6.0 3 | anyio==3.7.1 4 | asgiref==3.7.2 5 | attrs==23.1.0 6 | backoff==2.2.1 7 | bcrypt==4.1.1 8 | blinker==1.7.0 9 | build==1.0.3 10 | cachetools==5.3.2 11 | 
certifi==2023.11.17 12 | charset-normalizer==3.3.2 13 | chroma-hnswlib==0.7.3 14 | chromadb==0.4.18 15 | click==8.1.7 16 | coloredlogs==15.0.1 17 | Deprecated==1.2.14 18 | docutils==0.20.1 19 | exceptiongroup==1.2.0 20 | fastapi==0.104.1 21 | filelock==3.13.1 22 | flatbuffers==23.5.26 23 | fsspec==2023.10.0 24 | gitdb==4.0.11 25 | GitPython==3.1.40 26 | google-auth==2.24.0 27 | googleapis-common-protos==1.61.0 28 | grpcio==1.59.3 29 | h11==0.14.0 30 | httptools==0.6.1 31 | huggingface-hub==0.19.4 32 | humanfriendly==10.0 33 | idna==3.6 34 | importlib-metadata==6.9.0 35 | importlib-resources==6.1.1 36 | jaraco.classes==3.3.0 37 | Jinja2==3.1.2 38 | jsonschema==4.20.0 39 | jsonschema-specifications==2023.11.2 40 | keyring==24.3.0 41 | kubernetes==28.1.0 42 | markdown-it-py==3.0.0 43 | MarkupSafe==2.1.3 44 | mdurl==0.1.2 45 | mmh3==4.0.1 46 | monotonic==1.6 47 | more-itertools==10.1.0 48 | mpmath==1.3.0 49 | nh3==0.2.14 50 | numpy==1.26.2 51 | oauthlib==3.2.2 52 | onnxruntime==1.16.3 53 | opentelemetry-api==1.21.0 54 | opentelemetry-exporter-otlp-proto-common==1.21.0 55 | opentelemetry-exporter-otlp-proto-grpc==1.21.0 56 | opentelemetry-instrumentation==0.42b0 57 | opentelemetry-instrumentation-asgi==0.42b0 58 | opentelemetry-instrumentation-fastapi==0.42b0 59 | opentelemetry-proto==1.21.0 60 | opentelemetry-sdk==1.21.0 61 | opentelemetry-semantic-conventions==0.42b0 62 | opentelemetry-util-http==0.42b0 63 | overrides==7.4.0 64 | packaging==23.2 65 | pandas==2.1.3 66 | Pillow==10.1.0 67 | pkginfo==1.9.6 68 | posthog==3.0.2 69 | protobuf==4.25.1 70 | pulsar-client==3.3.0 71 | pyarrow==14.0.1 72 | pyasn1==0.5.1 73 | pyasn1-modules==0.3.0 74 | pydantic==2.5.2 75 | pydantic_core==2.14.5 76 | pydeck==0.8.1b0 77 | Pygments==2.17.2 78 | PyPika==0.48.9 79 | pyproject_hooks==1.0.0 80 | python-dateutil==2.8.2 81 | python-dotenv==1.0.0 82 | pytz==2023.3.post1 83 | PyYAML==6.0.1 84 | readme-renderer==42.0 85 | referencing==0.31.1 86 | requests==2.31.0 87 | 
requests-oauthlib==1.3.1 88 | requests-toolbelt==1.0.0 89 | rfc3986==2.0.0 90 | rich==13.7.0 91 | rpds-py==0.13.2 92 | rsa==4.9 93 | six==1.16.0 94 | smmap==5.0.1 95 | sniffio==1.3.0 96 | starlette==0.27.0 97 | streamlit==1.29.0 98 | sympy==1.12 99 | tenacity==8.2.3 100 | tokenizers==0.15.0 101 | toml==0.10.2 102 | tomli==2.0.1 103 | toolz==0.12.0 104 | tornado==6.4 105 | tqdm==4.66.1 106 | twine==4.0.2 107 | typer==0.9.0 108 | typing_extensions==4.8.0 109 | tzdata==2023.3 110 | tzlocal==5.2 111 | urllib3==1.26.18 112 | uvicorn==0.24.0.post1 113 | uvloop==0.19.0 114 | validators==0.22.0 115 | watchfiles==0.21.0 116 | websocket-client==1.6.4 117 | websockets==12.0 118 | wrapt==1.16.0 119 | zipp==3.17.0 120 | -------------------------------------------------------------------------------- /src/streamlit_chromadb_connection.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | streamlit_chromadb_connection 2 | -------------------------------------------------------------------------------- /src/streamlit_chromadb_connection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dev317/streamlit_chromadb_connection/5b89190df285c2e46f02bb10c10d05b8d6d98962/src/streamlit_chromadb_connection/__init__.py -------------------------------------------------------------------------------- /src/streamlit_chromadb_connection/chromadb_connection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import streamlit 3 | from streamlit.connections import BaseConnection 4 | import chromadb 5 | from chromadb.utils.embedding_functions import ( 6 | OpenAIEmbeddingFunction, 7 | CohereEmbeddingFunction, 8 | GooglePalmEmbeddingFunction, 9 | HuggingFaceEmbeddingFunction, 10 | InstructorEmbeddingFunction, 11 | DefaultEmbeddingFunction, 12 | SentenceTransformerEmbeddingFunction, 13 | Text2VecEmbeddingFunction, 14 | 
ONNXMiniLM_L6_V2, 15 | GoogleVertexEmbeddingFunction 16 | ) 17 | from chromadb.api import ClientAPI 18 | from chromadb import PersistentClient, HttpClient 19 | from chromadb.api.types import Documents, EmbeddingFunction 20 | from chromadb.api.models.Collection import Collection 21 | from typing import Dict, List, Union, cast 22 | from typing_extensions import override 23 | import pandas as pd 24 | 25 | 26 | class ChromadbConnection(BaseConnection): 27 | 28 | EMBEDDING_FUNCTION_MAP = { 29 | "DefaultEmbeddingFunction": DefaultEmbeddingFunction, 30 | "SentenceTransformerEmbeddingFunction": SentenceTransformerEmbeddingFunction, 31 | "OpenAIEmbeddingFunction": OpenAIEmbeddingFunction, 32 | "CohereEmbeddingFunction": CohereEmbeddingFunction, 33 | "GooglePalmEmbeddingFunction": GooglePalmEmbeddingFunction, 34 | "GoogleVertexEmbeddingFunction": GoogleVertexEmbeddingFunction, 35 | "HuggingFaceEmbeddingFunction": HuggingFaceEmbeddingFunction, 36 | "InstructorEmbeddingFunction": InstructorEmbeddingFunction, 37 | "Text2VecEmbeddingFunction": Text2VecEmbeddingFunction, 38 | "ONNXMiniLM_L6_V2": ONNXMiniLM_L6_V2 39 | } 40 | 41 | """ 42 | This class acts as an adapter to connect to ChromaDB vector database. 43 | It extends the BaseConnection class by overidding _connect(). 44 | It also provides other helpful methods to interact with the ChromaDB client. 
45 | """ 46 | 47 | @override 48 | def _connect(self, 49 | client: str = "PersistentClient", 50 | **kwargs) -> ClientAPI: 51 | 52 | if client == "PersistentClient": 53 | if "path" not in self._kwargs: 54 | raise Exception("`path` argument is not provided!") 55 | 56 | path = self._kwargs["path"] 57 | if not os.path.exists(path): 58 | raise Exception(f"Path `{path}` does not exist!") 59 | 60 | return chromadb.PersistentClient( 61 | path=path 62 | ) 63 | 64 | if client == "HttpClient": 65 | if "host" not in self._kwargs: 66 | raise Exception("`host` argument is not provided!") 67 | if "port" not in self._kwargs: 68 | raise Exception("`port` argument is not provided!") 69 | 70 | return chromadb.HttpClient( 71 | host=self._kwargs["host"], 72 | port=self._kwargs["port"], 73 | ) 74 | 75 | else: 76 | raise Exception("Invalid client type provided in `client` argument!") 77 | 78 | def create_collection(self, 79 | collection_name: str, 80 | embedding_function_name: str, 81 | embedding_config: Dict, 82 | metadata: Dict = {"hnsw:space": "l2"}) -> None: 83 | """ 84 | This method creates a collection in ChromaDB that requires an embedding function and distance method. 85 | The embedding function is specified by the `embedding_function_name` argument. 86 | The `embedding_config` argument is a dictionary that contains the configuration for the embedding function. 87 | The `metadata` argument is a dictionary that contains the configuration for the distance method. 
88 | """ 89 | 90 | if embedding_function_name not in self.EMBEDDING_FUNCTION_MAP: 91 | raise Exception("Invalid embedding function provided in `embedding_function` argument!") 92 | 93 | try: 94 | embedding_function = self.EMBEDDING_FUNCTION_MAP[embedding_function_name](**embedding_config) 95 | self._instance.create_collection( 96 | name=collection_name, 97 | embedding_function=embedding_function, 98 | metadata=metadata 99 | ) 100 | except Exception as exception: 101 | raise Exception(f"Error while creating collection `{collection_name}`: {str(exception)}") 102 | 103 | 104 | def delete_collection(self, collection_name: str) -> None: 105 | """ 106 | This method deletes a collection in ChromaDB. 107 | If the collection does not exist, it will raise an exception. 108 | """ 109 | 110 | try: 111 | self._instance.delete_collection(name=collection_name) 112 | except Exception as exception: 113 | raise Exception(f"Error while deleting collection `{collection_name}`: {str(exception)}") 114 | 115 | def get_collection(self, collection_name: str) -> Collection: 116 | """ 117 | This method gets a collection in ChromaDB. 118 | If the collection does not exist, it will raise an exception. 119 | """ 120 | 121 | try: 122 | return self._instance.get_collection(name=collection_name) 123 | except Exception as exception: 124 | raise Exception(f"Error while getting collection `{collection_name}`: {str(exception)}") 125 | 126 | def get_all_collection_names(self) -> List: 127 | """ 128 | This method gets all collection names in ChromaDB. 
129 | """ 130 | 131 | collection_names = [] 132 | collections = self._instance.list_collections() 133 | for col in collections: 134 | collection_names.append(col.name) 135 | return collection_names 136 | 137 | def upload_documents(self, 138 | collection_name: str, 139 | documents: List, 140 | metadatas: List, 141 | ids: List, 142 | embedding_function_name: str = "", 143 | embedding_config: Dict = {}, 144 | embeddings: List = None) -> None: 145 | """ 146 | This method uploads documents to a collection in ChromaDB. 147 | The `documents` argument is a list of documents, which contains list of texts to be embedded. 148 | The `metadatas` argument is a list of metadatas, which contains list of dictionaries that provide details about each document. 149 | The `embeddings` argument is a list of embeddings, which contains list of embeddings for each document. 150 | The `ids` argument is a list of ids, which contains list of ids for each document. 151 | 152 | If embeddings are not provided, the method will embed the documents using the embedding function specified in the collection. 
153 | """ 154 | if embedding_function_name not in self.EMBEDDING_FUNCTION_MAP: 155 | raise Exception("Invalid embedding function provided in `embedding_function` argument!") 156 | 157 | embedding_function = cast( 158 | EmbeddingFunction[Documents], 159 | self.EMBEDDING_FUNCTION_MAP[embedding_function_name](**embedding_config), 160 | ) 161 | 162 | try: 163 | collection = self._instance.get_collection( 164 | name=collection_name, 165 | embedding_function=embedding_function 166 | ) 167 | for idx, doc in enumerate(documents): 168 | if not embeddings: 169 | embedding = collection._embedding_function([doc]) 170 | else: 171 | embedding = embeddings[idx] 172 | 173 | collection.add(ids=ids[idx], 174 | metadatas=metadatas[idx], 175 | documents=doc, 176 | embeddings=embedding) 177 | 178 | except Exception as exception: 179 | raise Exception(f"Error while adding document to collection `{collection_name}`: {str(exception)}") 180 | 181 | def update_collection_data(self, 182 | collection_name: str, 183 | ids: List, 184 | documents: List, 185 | metadatas: List, 186 | embedding_function_name: str = "", 187 | embedding_config: Dict = {}, 188 | embeddings: List = None) -> None: 189 | """ 190 | This method updates documents in a collection in ChromaDB based on their existing ids. 
191 | """ 192 | if embedding_function_name not in self.EMBEDDING_FUNCTION_MAP: 193 | raise Exception("Invalid embedding function provided in `embedding_function` argument!") 194 | 195 | embedding_function = cast( 196 | EmbeddingFunction[Documents], 197 | self.EMBEDDING_FUNCTION_MAP[embedding_function_name](**embedding_config), 198 | ) 199 | 200 | try: 201 | collection = self._instance.get_collection(collection_name) 202 | for idx, doc in enumerate(documents): 203 | if not embeddings: 204 | embedding = collection._embedding_function([doc]) 205 | else: 206 | embedding = embeddings[idx] 207 | 208 | collection.update(ids=ids[idx], 209 | metadatas=metadatas[idx], 210 | documents=doc, 211 | embeddings=embedding) 212 | except Exception as exception: 213 | raise Exception(f"Error while updating document in collection `{collection_name}`: {str(exception)}") 214 | 215 | 216 | def get_collection_data(self, 217 | collection_name: str, 218 | attributes: List = ["documents", "embeddings", "metadatas"]): 219 | """ 220 | This method gets all data from a collection in ChromaDB in form of a Pandas DataFrame. 221 | The `attributes` argument is a list of attributes to be included in the DataFrame. 
222 | """ 223 | 224 | @streamlit.cache_data(ttl=10) 225 | def get_data() -> pd.DataFrame: 226 | try: 227 | collection = self._instance.get_collection(collection_name) 228 | collection_data = collection.get( 229 | include=attributes 230 | ) 231 | return pd.DataFrame(data=collection_data) 232 | except Exception as exception: 233 | raise Exception(f"Error while getting data from collection `{collection_name}`: {str(exception)}") 234 | return get_data() 235 | 236 | def query(self, 237 | collection_name: str, 238 | query: List, 239 | where_metadata_filter: Dict = None, 240 | where_document_filter: Dict = None, 241 | num_results_limit: int = 10, 242 | attributes: List = ["distances", "documents", "embeddings", "metadatas", "uris", "data"], 243 | ) -> pd.DataFrame: 244 | """ 245 | This method queries a collection in ChromaDB based on a list of query texts. 246 | The `where_metadata_filter` argument is a dictionary that contains the metadata filter. 247 | The `where_document_filter` argument is a dictionary that contains the document filter. 248 | The `num_results_limit` argument is the number of results to be returned. 249 | The `attributes` argument is a list of attributes to be included in the DataFrame. 250 | 251 | The return dataframe will only contain one row of result for each query text. 
252 | """ 253 | 254 | try: 255 | collection = self._instance.get_collection(collection_name) 256 | results = collection.query( 257 | query_texts=query, 258 | n_results=num_results_limit, 259 | include=attributes, 260 | where=where_metadata_filter, 261 | where_document=where_document_filter 262 | ) 263 | df = pd.DataFrame() 264 | df["ids"] = results["ids"] 265 | 266 | for k in attributes: 267 | if results[k] is None: 268 | df[k] = [None] * len(df) 269 | else: 270 | df[k] = results[k] 271 | 272 | return df 273 | 274 | except Exception as exception: 275 | raise Exception(f"Error while querying collection `{collection_name}`: {str(exception)}") 276 | -------------------------------------------------------------------------------- /tests/unit_tests/collection_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from unittest import TestCase 4 | from unittest.mock import patch, MagicMock 5 | from src.streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 6 | import streamlit 7 | 8 | 9 | class TestCollection(TestCase): 10 | 11 | def __init__(self, *args, **kwargs): 12 | super(TestCollection, self).__init__(*args, **kwargs) 13 | 14 | def test_create_collection(self): 15 | mock_persistent_path = f"{os.getcwd()}/tests/unit_tests/create_collection_persistent" 16 | os.mkdir(mock_persistent_path) 17 | mock_connection = streamlit.connection( 18 | name="test_create_collection", 19 | type=ChromadbConnection, 20 | client="PersistentClient", 21 | path=mock_persistent_path 22 | ) 23 | 24 | try: 25 | mock_connection.create_collection( 26 | collection_name="test_create_collection", 27 | embedding_function_name="DefaultEmbeddingFunction", 28 | embedding_config={}, 29 | ) 30 | except Exception as ex: 31 | self.fail(f"create_collection() raised Exception: {str(ex)}!") 32 | finally: 33 | shutil.rmtree(mock_persistent_path) 34 | 35 | # TODO: resolve sqlite operation 36 | # def 
test_create_collection_invalid_embedding_function(self): 37 | # mock_persistent_path = f"{os.getcwd()}/tests/unit_tests/create_collection_persistent" 38 | # os.mkdir(mock_persistent_path) 39 | # mock_connection = streamlit.connection( 40 | # name="test_create_collection_invalid_embedding_function", 41 | # type=ChromadbConnection, 42 | # client="PersistentClient", 43 | # path=mock_persistent_path 44 | # ) 45 | 46 | # with self.assertRaises(Exception) as context: 47 | # mock_connection.create_collection( 48 | # collection_name="test_create_invalid_embedding_collection", 49 | # embedding_function_name="InvalidEmbeddingFunction", 50 | # embedding_config={}, 51 | # ) 52 | # self.assertTrue("Invalid embedding function provided in `embedding_function` argument!" in str(context.exception)) 53 | # shutil.rmtree(mock_persistent_path) 54 | 55 | # def test_create_collection_existing_collection(self): 56 | # mock_persistent_path = f"{os.getcwd()}/tests/unit_tests/create_collection_persistent" 57 | # os.mkdir(mock_persistent_path) 58 | # mock_connection = streamlit.connection( 59 | # name="test_create_collection_existing_collection", 60 | # type=ChromadbConnection, 61 | # client="PersistentClient", 62 | # path=mock_persistent_path 63 | # ) 64 | # try: 65 | # with self.assertRaises(Exception) as context: 66 | # mock_connection.create_collection( 67 | # collection_name="test_create_existing_collection", 68 | # embedding_function_name="DefaultEmbeddingFunction", 69 | # embedding_config={}, 70 | # ) 71 | 72 | # mock_connection.create_collection( 73 | # collection_name="test_create_existing_collection", 74 | # embedding_function_name="DefaultEmbeddingFunction", 75 | # embedding_config={}, 76 | # ) 77 | 78 | # self.assertTrue(f"Error while creating collection `{self._connection_name}`: Collection already exists!" 
in str(context.exception)) 79 | # finally: 80 | # shutil.rmtree(mock_persistent_path) 81 | 82 | def test_delete_collection(self): 83 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/delete_persistent" 84 | os.mkdir(mock_persistent_dir) 85 | mock_connection = streamlit.connection( 86 | name="test_delete_collection", 87 | type=ChromadbConnection, 88 | client="PersistentClient", 89 | path=mock_persistent_dir 90 | ) 91 | 92 | try: 93 | mock_connection.create_collection( 94 | collection_name="test_delete_collection", 95 | embedding_function_name="DefaultEmbeddingFunction", 96 | embedding_config={}, 97 | ) 98 | mock_connection.delete_collection( 99 | collection_name="test_delete_collection" 100 | ) 101 | except Exception as ex: 102 | self.fail(f"delete_collection() raised Exception: {str(ex)}!") 103 | finally: 104 | shutil.rmtree(mock_persistent_dir) 105 | 106 | def test_delete_non_existing_collection(self): 107 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/delete_non_existing_persistent" 108 | os.mkdir(mock_persistent_dir) 109 | mock_connection = streamlit.connection( 110 | name="test_delete_non_existing_collection", 111 | type=ChromadbConnection, 112 | client="PersistentClient", 113 | path=mock_persistent_dir 114 | ) 115 | 116 | with self.assertRaises(Exception) as context: 117 | mock_connection.delete_collection( 118 | collection_name="test_delete_non_existing_collection" 119 | ) 120 | self.assertTrue(f"Error while deleting collection `{self.collection_name}`: Collection does not exist!" 
in str(context.exception)) 121 | shutil.rmtree(mock_persistent_dir) 122 | 123 | def test_get_collection(self): 124 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/get_collection_persistent" 125 | os.mkdir(mock_persistent_dir) 126 | mock_connection = streamlit.connection( 127 | name="test_get_collection", 128 | type=ChromadbConnection, 129 | client="PersistentClient", 130 | path=mock_persistent_dir 131 | ) 132 | 133 | try: 134 | mock_connection.create_collection( 135 | collection_name="test_get_existing_collection", 136 | embedding_function_name="DefaultEmbeddingFunction", 137 | embedding_config={}, 138 | ) 139 | 140 | mock_connection.get_collection( 141 | collection_name="test_get_existing_collection" 142 | ) 143 | except Exception as ex: 144 | self.fail(f"get_collection() raised Exception: {str(ex)}!") 145 | finally: 146 | shutil.rmtree(mock_persistent_dir) 147 | 148 | def test_get_non_existing_collection(self): 149 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/get_collection_non_persistent" 150 | os.mkdir(mock_persistent_dir) 151 | mock_connection = streamlit.connection( 152 | name="test_get_non_existing_collection", 153 | type=ChromadbConnection, 154 | client="PersistentClient", 155 | path=mock_persistent_dir 156 | ) 157 | 158 | with self.assertRaises(Exception) as context: 159 | mock_connection.get_collection( 160 | collection_name="test_get_non_existing_collection" 161 | ) 162 | self.assertTrue(f"Error while getting collection `{self.collection_name}`: Collection does not exist!" 
in str(context.exception)) 163 | shutil.rmtree(mock_persistent_dir) 164 | 165 | def test_get_all_collections(self): 166 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/get_all_collection_persistent" 167 | os.mkdir(mock_persistent_dir) 168 | mock_connection = streamlit.connection( 169 | name="test_get_all_collections", 170 | type=ChromadbConnection, 171 | client="PersistentClient", 172 | path=mock_persistent_dir 173 | ) 174 | mock_connection.create_collection( 175 | collection_name="test_get_existing_collection_1", 176 | embedding_function_name="DefaultEmbeddingFunction", 177 | embedding_config={}, 178 | ) 179 | mock_connection.create_collection( 180 | collection_name="test_get_existing_collection_2", 181 | embedding_function_name="DefaultEmbeddingFunction", 182 | embedding_config={}, 183 | ) 184 | 185 | collection_names = mock_connection.get_all_collection_names() 186 | self.assertGreaterEqual(len(collection_names), 2) 187 | shutil.rmtree(mock_persistent_dir) 188 | 189 | def test_add_document(self): 190 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/add_docs_persistent" 191 | os.mkdir(mock_persistent_dir) 192 | mock_connection = streamlit.connection( 193 | name="test_add_document", 194 | type=ChromadbConnection, 195 | client="PersistentClient", 196 | path=mock_persistent_dir 197 | ) 198 | 199 | try: 200 | mock_connection.create_collection( 201 | collection_name="test_add_collection", 202 | embedding_function_name="DefaultEmbeddingFunction", 203 | embedding_config={}, 204 | ) 205 | 206 | mock_connection.upload_documents( 207 | collection_name="test_add_collection", 208 | documents=["lorem ipsum", "doc2", "doc3"], 209 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 210 | ids=["id1", "id2", "id3"], 211 | embedding_function_name="DefaultEmbeddingFunction", 212 | embedding_config={}, 213 | ) 214 | except Exception as ex: 215 | self.fail(f"add_document() raised Exception: {str(ex)}!") 216 
| finally: 217 | shutil.rmtree(mock_persistent_dir) 218 | 219 | def test_get_collection_data(self): 220 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/get_data_persistent" 221 | os.mkdir(mock_persistent_dir) 222 | mock_connection = streamlit.connection( 223 | name="test_get_data", 224 | type=ChromadbConnection, 225 | client="PersistentClient", 226 | path=mock_persistent_dir 227 | ) 228 | 229 | try: 230 | mock_connection.create_collection( 231 | collection_name="test_get_data", 232 | embedding_function_name="DefaultEmbeddingFunction", 233 | embedding_config={}, 234 | ) 235 | 236 | mock_connection.upload_documents( 237 | collection_name="test_get_data", 238 | documents=["lorem ipsum", "doc2", "doc3"], 239 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 240 | ids=["id1", "id2", "id3"], 241 | embedding_function_name="DefaultEmbeddingFunction", 242 | embedding_config={}, 243 | embeddings=None, 244 | ) 245 | existing_data = mock_connection.get_collection_data( 246 | collection_name="test_get_data", 247 | attributes=["documents", "embeddings", "metadatas"] 248 | ) 249 | self.assertEqual(len(existing_data), 3) 250 | except Exception as ex: 251 | self.fail(f"get_collection_data() raised Exception: {str(ex)}!") 252 | finally: 253 | shutil.rmtree(mock_persistent_dir) 254 | 255 | def test_query_collection(self): 256 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/query_data_persistent" 257 | os.mkdir(mock_persistent_dir) 258 | mock_connection = streamlit.connection( 259 | name="test_query_collection", 260 | type=ChromadbConnection, 261 | client="PersistentClient", 262 | path=mock_persistent_dir 263 | ) 264 | 265 | try: 266 | mock_connection.create_collection( 267 | collection_name="test_query_collection", 268 | embedding_function_name="DefaultEmbeddingFunction", 269 | embedding_config={}, 270 | ) 271 | 272 | mock_connection.upload_documents( 273 | collection_name="test_query_collection", 274 | 
documents=["this is a", "this is b", "this is c"], 275 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 276 | ids=["id1", "id2", "id3"], 277 | embedding_function_name="DefaultEmbeddingFunction", 278 | embedding_config={}, 279 | embeddings=None, 280 | ) 281 | existing_data = mock_connection.get_collection_data( 282 | collection_name="test_query_collection", 283 | attributes=["documents", "embeddings", "metadatas"] 284 | ) 285 | self.assertEqual(len(existing_data), 3) 286 | queried_data = mock_connection.query( 287 | collection_name="test_query_collection", 288 | query=["this is"], 289 | num_results_limit=10, 290 | attributes=["documents", "embeddings", "metadatas", "data"] 291 | ) 292 | self.assertLessEqual(len(queried_data["ids"]), 3) 293 | 294 | except Exception as ex: 295 | self.fail(f"query() raised Exception: {str(ex)}!") 296 | finally: 297 | shutil.rmtree(mock_persistent_dir) 298 | 299 | def test_query_collection_with_filter(self): 300 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/query_filter_data_persistent" 301 | os.mkdir(mock_persistent_dir) 302 | mock_connection = streamlit.connection( 303 | name="test_query_collection_with_filter", 304 | type=ChromadbConnection, 305 | client="PersistentClient", 306 | path=mock_persistent_dir 307 | ) 308 | 309 | try: 310 | mock_connection.create_collection( 311 | collection_name="test_query_filter_collection", 312 | embedding_function_name="DefaultEmbeddingFunction", 313 | embedding_config={}, 314 | ) 315 | 316 | mock_connection.upload_documents( 317 | collection_name="test_query_filter_collection", 318 | documents=["this is a", "this is b", "this is c"], 319 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 320 | ids=["id1", "id2", "id3"], 321 | embedding_function_name="DefaultEmbeddingFunction", 322 | embedding_config={}, 323 | embeddings=None, 324 | ) 325 | existing_data = 
mock_connection.get_collection_data( 326 | collection_name="test_query_filter_collection", 327 | attributes=["documents", "embeddings", "metadatas"] 328 | ) 329 | self.assertEqual(len(existing_data), 3) 330 | queried_data = mock_connection.query( 331 | collection_name="test_query_filter_collection", 332 | query=["this is"], 333 | num_results_limit=10, 334 | attributes=["documents", "embeddings", "metadatas", "data"], 335 | where_metadata_filter={"chapter": "3"}, 336 | ) 337 | self.assertLessEqual(len(queried_data["ids"]), 1) 338 | 339 | except Exception as ex: 340 | self.fail(f"query() raised Exception: {str(ex)}!") 341 | finally: 342 | shutil.rmtree(mock_persistent_dir) 343 | 344 | def test_update_collection_data(self): 345 | mock_persistent_dir = f"{os.getcwd()}/tests/unit_tests/update_data_persistent" 346 | os.mkdir(mock_persistent_dir) 347 | mock_connection = streamlit.connection( 348 | name="test_update_collection_data", 349 | type=ChromadbConnection, 350 | client="PersistentClient", 351 | path=mock_persistent_dir 352 | ) 353 | 354 | try: 355 | mock_connection.create_collection( 356 | collection_name="test_update_collection_data", 357 | embedding_function_name="DefaultEmbeddingFunction", 358 | embedding_config={}, 359 | ) 360 | 361 | mock_connection.upload_documents( 362 | collection_name="test_update_collection_data", 363 | documents=["this is a", "this is b", "this is c"], 364 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 365 | embedding_function_name="DefaultEmbeddingFunction", 366 | embedding_config={}, 367 | ids=["id1", "id2", "id3"], 368 | ) 369 | mock_connection.update_collection_data( 370 | collection_name="test_update_collection_data", 371 | ids=["id1", "id2", "id3"], 372 | documents=["this is b", "this is c", "this is d"], 373 | metadatas=[{"chapter": "3", "verse": "16"}, {"chapter": "3", "verse": "5"}, {"chapter": "29", "verse": "11"}], 374 | 
embedding_function_name="DefaultEmbeddingFunction", 375 | embedding_config={}, 376 | ) 377 | updated_data = mock_connection.get_collection_data( 378 | collection_name="test_update_collection_data", 379 | attributes=["documents", "embeddings", "metadatas"] 380 | ) 381 | self.assertEqual(updated_data["documents"].tolist(), ["this is b", "this is c", "this is d"]) 382 | except Exception as ex: 383 | self.fail(f"update_collection_data() raised Exception: {str(ex)}!") 384 | finally: 385 | shutil.rmtree(mock_persistent_dir) 386 | -------------------------------------------------------------------------------- /tests/unit_tests/connect_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import TestCase 3 | from unittest.mock import patch, MagicMock 4 | from src.streamlit_chromadb_connection.chromadb_connection import ChromadbConnection 5 | import streamlit 6 | import chromadb 7 | import shutil 8 | 9 | 10 | class TestConnect(TestCase): 11 | 12 | def __init__(self, *args, **kwargs): 13 | super(TestConnect, self).__init__(*args, **kwargs) 14 | self.persistent_path = f"{os.getcwd()}/tests/unit_tests/persistent" 15 | 16 | def setUp(self) -> None: 17 | if not os.path.exists(self.persistent_path): 18 | os.mkdir(self.persistent_path) 19 | 20 | def tearDown(self) -> None: 21 | if os.path.exists(self.persistent_path): 22 | shutil.rmtree(self.persistent_path) 23 | 24 | def test_persistent_connect(self): 25 | connection = streamlit.connection( 26 | name="persistent_connection", 27 | type=ChromadbConnection, 28 | client="PersistentClient", 29 | path=self.persistent_path 30 | ) 31 | self.assertEqual(connection._raw_instance._server._system.settings.is_persistent, True) 32 | self.assertEqual(connection._raw_instance._server._system.settings.persist_directory, self.persistent_path) 33 | 34 | def test_persistent_connect_missing_path(self): 35 | with self.assertRaises(Exception) as context: 36 | streamlit.connection( 37 
| name="persistent_connection", 38 | type=ChromadbConnection, 39 | client="PersistentClient", 40 | ) 41 | self.assertTrue("`path` argument is not provided!" in str(context.exception)) 42 | 43 | def test_persistent_connect_invalid_path(self): 44 | invalid_path = f"{os.getcwd()}/tests/unit_tests/invalid_persistent" 45 | with self.assertRaises(Exception) as context: 46 | streamlit.connection( 47 | name="persistent_connection", 48 | type=ChromadbConnection, 49 | client="PersistentClient", 50 | path=invalid_path 51 | ) 52 | self.assertTrue(f"Path `{invalid_path}` does not exist!" in str(context.exception)) 53 | 54 | @patch("src.streamlit_chromadb_connection.chromadb_connection.ChromadbConnection._connect") 55 | def test_http_connect(self, mock_client_object): 56 | mock_client_object.return_value = MagicMock(spec=chromadb.HttpClient) 57 | streamlit.connection( 58 | name="http_connection", 59 | type=ChromadbConnection, 60 | client="HttpClient", 61 | host="localhost", 62 | port=8080 63 | ) 64 | mock_client_object.assert_called_once_with(client="HttpClient", host="localhost", port=8080) 65 | 66 | def test_http_connect_missing_host(self): 67 | with self.assertRaises(Exception) as context: 68 | streamlit.connection( 69 | name="http_connection", 70 | type=ChromadbConnection, 71 | client="HttpClient", 72 | port=8080 73 | ) 74 | self.assertTrue("`host` argument is not provided!" in str(context.exception)) 75 | 76 | def test_http_connect_missing_port(self): 77 | with self.assertRaises(Exception) as context: 78 | streamlit.connection( 79 | name="http_connection", 80 | type=ChromadbConnection, 81 | client="HttpClient", 82 | host="localhost" 83 | ) 84 | self.assertTrue("`port` argument is not provided!" in str(context.exception)) 85 | 86 | --------------------------------------------------------------------------------